
Utils

utils

Shared utility functions and helpers for psyphy.

This subpackage provides:
  • bootstrap : frequentist confidence intervals via resampling.
  • candidates : functions for generating candidate stimulus pools.
  • diagnostics : parameter summaries and threshold uncertainty estimation.
  • math : mathematical utilities (basis functions, distances, kernels).
  • rng : random number handling for reproducibility.

MVP implementation
  • bootstrap: prediction CIs, model comparison, arbitrary statistics.
  • candidates: grid, Sobol, custom pools.
  • diagnostics: parameter summaries, threshold uncertainty.
  • math: Chebyshev basis, Mahalanobis distance, RBF kernel.
  • rng: seed() and split() for JAX PRNG keys.
Full WPPM mode
  • candidates: adaptive refinement around posterior uncertainty.
  • diagnostics: parameter sensitivity analysis, model comparison.
  • math: richer kernels and basis expansions for Wishart processes.
  • rng: experiment-wide RNG registry.
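A minimal usage sketch of the MVP utilities (values are illustrative; import paths follow the module layout listed above):

import jax.numpy as jnp

from psyphy.utils.rng import seed, split
from psyphy.utils.candidates import grid_candidates
from psyphy.utils.math import mahalanobis_distance

# Reproducible JAX randomness
key = seed(0)
key_fit, key_place = split(key)

# Candidate (reference, probe) pairs on concentric circles around a reference
reference = jnp.array([0.5, 0.3])
pool = grid_candidates(reference, radii=[0.05, 0.1, 0.2], directions=8)

# Squared Mahalanobis distance, as used in WPPM discriminability calculations
d2 = mahalanobis_distance(jnp.array([0.6, 0.4]), reference, jnp.eye(2))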

Functions:

Name Description
bootstrap_compare_models

Bootstrap comparison of two models' predictive performance.

bootstrap_predictions

Bootstrap confidence intervals for model predictions.

bootstrap_statistic

Bootstrap confidence interval for any model-derived statistic.

chebyshev_basis

Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

custom_candidates

Wrap a user-defined list of probes into candidate pairs.

estimate_threshold_contour_uncertainty

Estimate threshold contour and its uncertainty around a reference point.

estimate_threshold_uncertainty

Estimate threshold location and uncertainty via parameter sampling.

grid_candidates

Generate grid-based candidate probes around a reference.

mahalanobis_distance

Compute squared Mahalanobis distance between x and mean.

parameter_summary

Compute summary statistics for all model parameters.

print_parameter_summary

Print a human-readable parameter summary.

rbf_kernel

Radial Basis Function (RBF) kernel between two sets of points.

seed

Create a new PRNG key from an integer seed.

sobol_candidates

Generate Sobol quasi-random candidates within bounds.

split

Split a PRNG key into multiple independent keys.

bootstrap_compare_models

bootstrap_compare_models(
    model1: Model,
    model2: Model,
    X_train: ndarray,
    y_train: ndarray,
    X_test: ndarray,
    y_test: ndarray,
    *,
    metric_fn: Callable[[ndarray, ndarray], float]
    | None = None,
    n_bootstrap: int = 100,
    confidence_level: float = 0.95,
    probes: ndarray | None = None,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[float, float, float, bool]

Bootstrap comparison of two models' predictive performance.

Tests whether model1 performs significantly better/worse than model2 by computing confidence intervals on the performance difference.

Parameters:

Name Type Description Default
model1 Model

Unfitted model instances to compare

required
model2 Model

Unfitted model instances to compare

required
X_train Training data
required
y_train Training data
required
X_test Test data for evaluation
required
y_test Test data for evaluation
required
metric_fn callable

Function that takes (y_true, y_pred) and returns a scalar. Default: accuracy for binary classification

None
n_bootstrap int

Number of bootstrap samples

100
confidence_level float

Confidence level

0.95
probes optional

Test probes for discrimination tasks

None
inference str

Inference method

'map'
inference_config dict

Inference configuration

None
key KeyArray

Random key

required

Returns:

Name Type Description
diff_estimate float

Estimated difference in performance (model1 - model2). Positive = model1 is better

ci_lower float

Lower bound on difference

ci_upper float

Upper bound on difference

is_significant bool

True if the difference is statistically significant (i.e., confidence interval excludes zero)

Examples:

>>> # Compare two models on held-out data
>>> from psyphy.utils.bootstrap import bootstrap_compare_models
>>>
>>> model1 = WPPM(input_dim=2, prior=Prior(input_dim=2, scale=0.5), ...)
>>> model2 = WPPM(input_dim=2, prior=Prior(input_dim=2, scale=1.0), ...)
>>>
>>> diff, lower, upper, significant = bootstrap_compare_models(
...     model1,
...     model2,
...     X_train,
...     y_train,
...     X_test,
...     y_test,
...     n_bootstrap=200,
...     key=jr.PRNGKey(0),
... )
>>>
>>> if significant:
...     winner = "Model 1" if diff > 0 else "Model 2"
...     print(f"{winner} is significantly better")
...     print(f"Difference: {diff:.3f} [{lower:.3f}, {upper:.3f}]")
... else:
...     print("No significant difference")
>>> # Custom metric: mean squared error
>>> def mse(y_true, y_pred):
...     return jnp.mean((y_true - y_pred) ** 2)
>>>
>>> diff, lower, upper, sig = bootstrap_compare_models(
...     model1,
...     model2,
...     X_train,
...     y_train,
...     X_test,
...     y_test,
...     metric_fn=mse,  # Lower is better
...     n_bootstrap=100,
...     key=jr.PRNGKey(1),
... )
Notes

This function performs paired bootstrap comparison: for each bootstrap sample, both models are fit on the same resampled training data and evaluated on the same test data. This controls for data sampling variability.

The null hypothesis is that the models have equal performance; we reject it if the CI on the difference excludes zero.

Source code in src/psyphy/utils/bootstrap.py
def bootstrap_compare_models(
    model1: Model,
    model2: Model,
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    X_test: jnp.ndarray,
    y_test: jnp.ndarray,
    *,
    metric_fn: Callable[[jnp.ndarray, jnp.ndarray], float] | None = None,
    n_bootstrap: int = 100,
    confidence_level: float = 0.95,
    probes: jnp.ndarray | None = None,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[float, float, float, bool]:
    """
    Bootstrap comparison of two models' predictive performance.

    Tests whether model1 performs significantly better/worse than model2
    by computing confidence intervals on the performance difference.

    Parameters
    ----------
    model1, model2 : Model
        Unfitted model instances to compare
    X_train, y_train : Training data
    X_test, y_test : Test data for evaluation
    metric_fn : callable, optional
        Function that takes (y_true, y_pred) and returns a scalar.
        Default: accuracy for binary classification
    n_bootstrap : int, default=100
        Number of bootstrap samples
    confidence_level : float, default=0.95
        Confidence level
    probes : optional
        Test probes for discrimination tasks
    inference : str
        Inference method
    inference_config : dict, optional
        Inference configuration
    key : jr.KeyArray
        Random key

    Returns
    -------
    diff_estimate : float
        Estimated difference in performance (model1 - model2)
        Positive = model1 is better
    ci_lower : float
        Lower bound on difference
    ci_upper : float
        Upper bound on difference
    is_significant : bool
        True if the difference is statistically significant
        (i.e., confidence interval excludes zero)

    Examples
    --------
    >>> # Compare two models on held-out data
    >>> from psyphy.utils.bootstrap import bootstrap_compare_models
    >>>
    >>> model1 = WPPM(input_dim=2, prior=Prior(input_dim=2, scale=0.5), ...)
    >>> model2 = WPPM(input_dim=2, prior=Prior(input_dim=2, scale=1.0), ...)
    >>>
    >>> diff, lower, upper, significant = bootstrap_compare_models(
    ...     model1,
    ...     model2,
    ...     X_train,
    ...     y_train,
    ...     X_test,
    ...     y_test,
    ...     n_bootstrap=200,
    ...     key=jr.PRNGKey(0),
    ... )
    >>>
    >>> if significant:
    ...     winner = "Model 1" if diff > 0 else "Model 2"
    ...     print(f"{winner} is significantly better")
    ...     print(f"Difference: {diff:.3f} [{lower:.3f}, {upper:.3f}]")
    ... else:
    ...     print("No significant difference")

    >>> # Custom metric: mean squared error
    >>> def mse(y_true, y_pred):
    ...     return jnp.mean((y_true - y_pred) ** 2)
    >>>
    >>> diff, lower, upper, sig = bootstrap_compare_models(
    ...     model1,
    ...     model2,
    ...     X_train,
    ...     y_train,
    ...     X_test,
    ...     y_test,
    ...     metric_fn=mse,  # Lower is better
    ...     n_bootstrap=100,
    ...     key=jr.PRNGKey(1),
    ... )

    Notes
    -----
    This function performs paired bootstrap comparison: for each
    bootstrap sample, both models are fit on the same resampled
    training data and evaluated on the same test data. This controls
    for data sampling variability.

    The null hypothesis is that the models have equal performance.
    We reject this if the CI on the difference excludes zero.
    """
    # Set default metric if not provided
    _metric: Callable[[Any, Any], float]
    if metric_fn is None:
        # Default: accuracy
        def default_metric(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float:
            return float(jnp.mean(y_pred == y_true))

        _metric = default_metric
    else:
        _metric = metric_fn  # type: ignore[assignment]

    n_train = len(X_train)
    alpha = 1 - confidence_level

    differences = []

    for _ in range(n_bootstrap):
        # Resample training data (same sample for both models)
        key, subkey = jr.split(key)
        indices = jr.randint(subkey, (n_train,), 0, n_train)
        X_boot = X_train[indices]
        y_boot = y_train[indices]

        # Fit model 1
        m1_boot = _clone_model(model1)
        m1_boot.fit(
            X_boot,
            y_boot,
            inference=inference,
            inference_config=inference_config,
        )

        # Fit model 2
        m2_boot = _clone_model(model2)
        m2_boot.fit(
            X_boot,
            y_boot,
            inference=inference,
            inference_config=inference_config,
        )

        # Evaluate on test data
        post1 = m1_boot.posterior(X_test, probes=probes)
        post2 = m2_boot.posterior(X_test, probes=probes)

        # Convert to predictions (threshold at 0.5)
        y_pred1 = (post1.mean > 0.5).astype(int)  # type: ignore[attr-defined, union-attr]
        y_pred2 = (post2.mean > 0.5).astype(int)  # type: ignore[attr-defined, union-attr]

        # Compute metrics
        score1 = _metric(y_test, y_pred1)
        score2 = _metric(y_test, y_pred2)

        differences.append(score1 - score2)

    # Stack and compute statistics
    differences = jnp.array(differences)

    diff_estimate = float(jnp.mean(differences))
    ci_lower = float(jnp.percentile(differences, 100 * alpha / 2))
    ci_upper = float(jnp.percentile(differences, 100 * (1 - alpha / 2)))

    # Test significance: CI excludes zero?
    is_significant = bool((ci_lower > 0) or (ci_upper < 0))

    return diff_estimate, ci_lower, ci_upper, is_significant

bootstrap_predictions

bootstrap_predictions(
    model: Model,
    X_train: ndarray,
    y_train: ndarray,
    X_test: ndarray,
    *,
    n_bootstrap: int = 100,
    probes: ndarray | None = None,
    confidence_level: float = 0.95,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[ndarray, ndarray, ndarray]

Bootstrap confidence intervals for model predictions.

Resamples training data with replacement, refits model N times, and computes prediction quantiles at test points.

Parameters:

Name Type Description Default
model Model

Unfitted model instance (will be cloned for each bootstrap sample)

required
X_train (ndarray, shape(n_train, ...))

Training stimuli

required
y_train (ndarray, shape(n_train))

Training responses

required
X_test (ndarray, shape(n_test, ...))

Test points for predictions

required
n_bootstrap int

Number of bootstrap samples. Typical values: 100 (quick), 1000 (publication quality)

100
probes ndarray

Test probes for discrimination tasks

None
confidence_level float

Confidence level (e.g., 0.95 for 95% CI, 0.99 for 99% CI)

0.95
inference str

Inference method for each bootstrap fit

"map"
inference_config dict

Configuration for inference engine

None
key Any

JAX random key for reproducibility

required

Returns:

Name Type Description
mean_estimate (ndarray, shape(n_test))

Average prediction across bootstrap samples

ci_lower (ndarray, shape(n_test))

Lower confidence bound at each test point

ci_upper (ndarray, shape(n_test))

Upper confidence bound at each test point

Examples:

>>> # Fit model and get bootstrap CIs
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.noise import GaussianNoise
>>>
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2),
...     task=OddityTask(),
...     noise=GaussianNoise(),
... )
>>>
>>> # Bootstrap CIs for psychometric function
>>> X_test = jnp.linspace(-1, 1, 50)[:, None]
>>> probes_test = X_test + 0.1
>>>
>>> mean, lower, upper = bootstrap_predictions(
...     model,
...     X_train,
...     y_train,
...     X_test,
...     probes=probes_test,
...     n_bootstrap=100,
...     key=jr.PRNGKey(0),
... )
>>>
>>> # Plot with confidence bands
>>> import matplotlib.pyplot as plt
>>> plt.plot(X_test, mean, label="Mean prediction")
>>> plt.fill_between(X_test[:, 0], lower, upper, alpha=0.3, label="95% CI")
>>> plt.legend()
>>> # Quick diagnostic: is model stable?
>>> mean, lower, upper = bootstrap_predictions(
...     model,
...     X_train,
...     y_train,
...     X_test,
...     n_bootstrap=50,  # Faster for diagnostics
...     key=jr.PRNGKey(42),
... )
>>> ci_width = upper - lower
>>> print(f"Average CI width: {jnp.mean(ci_width):.3f}")
Notes

Computational cost:
  • Each bootstrap sample requires a full model refit.
  • Total time ≈ n_bootstrap × (time per fit).
  • For MAP with 100 samples: typically 10-100 seconds.

Assumptions:
  • Training data are IID (independent and identically distributed).
  • For sequential data, consider block bootstrap instead.

The bootstrap estimates sampling uncertainty (how stable are predictions if we collected different data?), not model uncertainty (what is the range of plausible predictions given the data?). For model uncertainty, use the Bayesian posterior.variance instead.

Source code in src/psyphy/utils/bootstrap.py
def bootstrap_predictions(
    model: Model,
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    X_test: jnp.ndarray,
    *,
    n_bootstrap: int = 100,
    probes: jnp.ndarray | None = None,
    confidence_level: float = 0.95,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Bootstrap confidence intervals for model predictions.

    Resamples training data with replacement, refits model N times,
    and computes prediction quantiles at test points.

    Parameters
    ----------
    model : Model
        Unfitted model instance (will be cloned for each bootstrap sample)
    X_train : jnp.ndarray, shape (n_train, ...)
        Training stimuli
    y_train : jnp.ndarray, shape (n_train,)
        Training responses
    X_test : jnp.ndarray, shape (n_test, ...)
        Test points for predictions
    n_bootstrap : int, default=100
        Number of bootstrap samples.
        Typical values: 100 (quick), 1000 (publication quality)
    probes : jnp.ndarray, optional
        Test probes for discrimination tasks
    confidence_level : float, default=0.95
        Confidence level (e.g., 0.95 for 95% CI, 0.99 for 99% CI)
    inference : str, default="map"
        Inference method for each bootstrap fit
    inference_config : dict, optional
        Configuration for inference engine
    key : Any
        JAX random key for reproducibility

    Returns
    -------
    mean_estimate : jnp.ndarray, shape (n_test,)
        Average prediction across bootstrap samples
    ci_lower : jnp.ndarray, shape (n_test,)
        Lower confidence bound at each test point
    ci_upper : jnp.ndarray, shape (n_test,)
        Upper confidence bound at each test point

    Examples
    --------
    >>> # Fit model and get bootstrap CIs
    >>> from psyphy.model import WPPM, Prior
    >>> from psyphy.model.task import OddityTask
    >>> from psyphy.model.noise import GaussianNoise
    >>>
    >>> model = WPPM(
    ...     input_dim=2,
    ...     prior=Prior(input_dim=2),
    ...     task=OddityTask(),
    ...     noise=GaussianNoise(),
    ... )
    >>>
    >>> # Bootstrap CIs for psychometric function
    >>> X_test = jnp.linspace(-1, 1, 50)[:, None]
    >>> probes_test = X_test + 0.1
    >>>
    >>> mean, lower, upper = bootstrap_predictions(
    ...     model,
    ...     X_train,
    ...     y_train,
    ...     X_test,
    ...     probes=probes_test,
    ...     n_bootstrap=100,
    ...     key=jr.PRNGKey(0),
    ... )
    >>>
    >>> # Plot with confidence bands
    >>> import matplotlib.pyplot as plt
    >>> plt.plot(X_test, mean, label="Mean prediction")
    >>> plt.fill_between(X_test[:, 0], lower, upper, alpha=0.3, label="95% CI")
    >>> plt.legend()

    >>> # Quick diagnostic: is model stable?
    >>> mean, lower, upper = bootstrap_predictions(
    ...     model,
    ...     X_train,
    ...     y_train,
    ...     X_test,
    ...     n_bootstrap=50,  # Faster for diagnostics
    ...     key=jr.PRNGKey(42),
    ... )
    >>> ci_width = upper - lower
    >>> print(f"Average CI width: {jnp.mean(ci_width):.3f}")

    Notes
    -----
    Computational cost:
    - Each bootstrap sample requires a full model refit
    - Total time ≈ n_bootstrap × (time per fit)
    - For MAP with 100 samples: typically 10-100 seconds

    Assumptions:
    - Training data are IID (independent and identically distributed)
    - For sequential data, consider block bootstrap instead

    The bootstrap estimates sampling uncertainty (how stable are
    predictions if we collected different data?), not model uncertainty
    (what is the range of plausible predictions given the data?).
    For model uncertainty, use the Bayesian posterior.variance instead.
    """
    n_train = len(X_train)
    alpha = 1 - confidence_level

    # Store predictions from each bootstrap sample
    predictions = []

    for _ in range(n_bootstrap):
        # Resample with replacement
        key, subkey = jr.split(key)
        indices = jr.randint(subkey, (n_train,), 0, n_train)
        X_boot = X_train[indices]
        y_boot = y_train[indices]

        # Create fresh model instance and fit
        model_boot = _clone_model(model)
        model_boot.fit(
            X_boot,
            y_boot,
            inference=inference,
            inference_config=inference_config,
        )

        # Get predictions
        posterior_boot = model_boot.posterior(X_test, probes=probes)
        predictions.append(posterior_boot.mean)  # type: ignore[attr-defined, union-attr]

    # Stack and compute quantiles
    predictions = jnp.stack(predictions, axis=0)  # (n_bootstrap, n_test)

    mean_estimate = jnp.mean(predictions, axis=0)
    ci_lower = jnp.percentile(predictions, 100 * alpha / 2, axis=0)
    ci_upper = jnp.percentile(predictions, 100 * (1 - alpha / 2), axis=0)

    return mean_estimate, ci_lower, ci_upper

bootstrap_statistic

bootstrap_statistic(
    model: Model,
    X: ndarray,
    y: ndarray,
    statistic_fn: Callable[[Model], float | ndarray],
    *,
    n_bootstrap: int = 100,
    confidence_level: float = 0.95,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[
    float | ndarray, float | ndarray, float | ndarray
]

Bootstrap confidence interval for any model-derived statistic.

Resamples data, refits model, and computes statistic for each bootstrap sample. Returns point estimate and confidence interval.

Parameters:

Name Type Description Default
model Model

Unfitted model instance

required
X (ndarray, shape(n_trials, ...))

Training stimuli

required
y (ndarray, shape(n_trials))

Training responses

required
statistic_fn callable

Function that takes a fitted Model and returns a scalar or array. Examples:
  • lambda m: m.estimate_threshold(criterion=0.75)
  • lambda m: m.posterior(X_test).mean
  • lambda m: jnp.linalg.norm(m._posterior.params["lengthscales"])

required
n_bootstrap int

Number of bootstrap samples

100
confidence_level float

Confidence level for interval

0.95
inference str

Inference method

"map"
inference_config dict

Inference configuration

None
key KeyArray

Random key

required

Returns:

Name Type Description
estimate float or ndarray

Point estimate (mean across bootstrap samples)

ci_lower float or ndarray

Lower confidence bound

ci_upper float or ndarray

Upper confidence bound

Examples:

>>> # Bootstrap CI for threshold estimate
>>> def get_threshold(fitted_model):
...     # Example threshold estimation
...     X_grid = jnp.linspace(-2, 2, 100)[:, None]
...     probes = X_grid + 0.1
...     posterior = fitted_model.posterior(X_grid, probes=probes)
...     probs = posterior.mean
...     idx = jnp.argmin(jnp.abs(probs - 0.75))
...     return X_grid[idx, 0]
>>>
>>> threshold, lower, upper = bootstrap_statistic(
...     model,
...     X,
...     y,
...     statistic_fn=get_threshold,
...     n_bootstrap=200,
...     key=jr.PRNGKey(0),
... )
>>> print(f"Threshold: {threshold:.3f} [{lower:.3f}, {upper:.3f}]")
>>> # Bootstrap CI for model parameter
>>> def get_lengthscale(fitted_model):
...     return fitted_model._posterior.params["lengthscales"][0]
>>>
>>> ls, ls_lower, ls_upper = bootstrap_statistic(
...     model,
...     X,
...     y,
...     statistic_fn=get_lengthscale,
...     n_bootstrap=100,
...     key=jr.PRNGKey(42),
... )
>>> # Compare two models
>>> def test_accuracy(fitted_model):
...     posterior = fitted_model.posterior(X_test, probes=probes_test)
...     preds = (posterior.mean > 0.5).astype(int)
...     return jnp.mean(preds == y_test)
>>>
>>> acc1, l1, u1 = bootstrap_statistic(
...     model1, X_train, y_train, test_accuracy, n_bootstrap=100, key=jr.PRNGKey(0)
... )
>>> acc2, l2, u2 = bootstrap_statistic(
...     model2, X_train, y_train, test_accuracy, n_bootstrap=100, key=jr.PRNGKey(1)
... )
>>> # If CIs don't overlap, difference is statistically significant
>>> print(f"Model 1: {acc1:.3f} [{l1:.3f}, {u1:.3f}]")
>>> print(f"Model 2: {acc2:.3f} [{l2:.3f}, {u2:.3f}]")
Notes

This is a general-purpose function for any statistic you can compute from a fitted model. The statistic_fn should:
  • Take a fitted Model as input
  • Return a scalar or array (but shape must be consistent across samples)
  • Not modify the model

For vector-valued statistics, confidence intervals are computed element-wise.

Source code in src/psyphy/utils/bootstrap.py
def bootstrap_statistic(
    model: Model,
    X: jnp.ndarray,
    y: jnp.ndarray,
    statistic_fn: Callable[[Model], float | jnp.ndarray],
    *,
    n_bootstrap: int = 100,
    confidence_level: float = 0.95,
    inference: str = "map",
    inference_config: dict[str, Any] | None = None,
    key: Any,
) -> tuple[float | jnp.ndarray, float | jnp.ndarray, float | jnp.ndarray]:
    """
    Bootstrap confidence interval for any model-derived statistic.

    Resamples data, refits model, and computes statistic for each
    bootstrap sample. Returns point estimate and confidence interval.

    Parameters
    ----------
    model : Model
        Unfitted model instance
    X : jnp.ndarray, shape (n_trials, ...)
        Training stimuli
    y : jnp.ndarray, shape (n_trials,)
        Training responses
    statistic_fn : callable
        Function that takes a fitted Model and returns a scalar or array.
        Examples:
        - lambda m: m.estimate_threshold(criterion=0.75)
        - lambda m: m.posterior(X_test).mean
        - lambda m: jnp.linalg.norm(m._posterior.params["lengthscales"])
    n_bootstrap : int, default=100
        Number of bootstrap samples
    confidence_level : float, default=0.95
        Confidence level for interval
    inference : str, default="map"
        Inference method
    inference_config : dict, optional
        Inference configuration
    key : jr.KeyArray
        Random key

    Returns
    -------
    estimate : float or jnp.ndarray
        Point estimate (mean across bootstrap samples)
    ci_lower : float or jnp.ndarray
        Lower confidence bound
    ci_upper : float or jnp.ndarray
        Upper confidence bound

    Examples
    --------
    >>> # Bootstrap CI for threshold estimate
    >>> def get_threshold(fitted_model):
    ...     # Example threshold estimation
    ...     X_grid = jnp.linspace(-2, 2, 100)[:, None]
    ...     probes = X_grid + 0.1
    ...     posterior = fitted_model.posterior(X_grid, probes=probes)
    ...     probs = posterior.mean
    ...     idx = jnp.argmin(jnp.abs(probs - 0.75))
    ...     return X_grid[idx, 0]
    >>>
    >>> threshold, lower, upper = bootstrap_statistic(
    ...     model,
    ...     X,
    ...     y,
    ...     statistic_fn=get_threshold,
    ...     n_bootstrap=200,
    ...     key=jr.PRNGKey(0),
    ... )
    >>> print(f"Threshold: {threshold:.3f} [{lower:.3f}, {upper:.3f}]")

    >>> # Bootstrap CI for model parameter
    >>> def get_lengthscale(fitted_model):
    ...     return fitted_model._posterior.params["lengthscales"][0]
    >>>
    >>> ls, ls_lower, ls_upper = bootstrap_statistic(
    ...     model,
    ...     X,
    ...     y,
    ...     statistic_fn=get_lengthscale,
    ...     n_bootstrap=100,
    ...     key=jr.PRNGKey(42),
    ... )

    >>> # Compare two models
    >>> def test_accuracy(fitted_model):
    ...     posterior = fitted_model.posterior(X_test, probes=probes_test)
    ...     preds = (posterior.mean > 0.5).astype(int)
    ...     return jnp.mean(preds == y_test)
    >>>
    >>> acc1, l1, u1 = bootstrap_statistic(
    ...     model1, X_train, y_train, test_accuracy, n_bootstrap=100, key=jr.PRNGKey(0)
    ... )
    >>> acc2, l2, u2 = bootstrap_statistic(
    ...     model2, X_train, y_train, test_accuracy, n_bootstrap=100, key=jr.PRNGKey(1)
    ... )
    >>> # If CIs don't overlap, difference is statistically significant
    >>> print(f"Model 1: {acc1:.3f} [{l1:.3f}, {u1:.3f}]")
    >>> print(f"Model 2: {acc2:.3f} [{l2:.3f}, {u2:.3f}]")

    Notes
    -----
    This is a general-purpose function for any statistic you can compute
    from a fitted model. The statistic_fn should:
    - Take a fitted Model as input
    - Return a scalar or array (but shape must be consistent across samples)
    - Not modify the model

    For vector-valued statistics, confidence intervals are computed
    element-wise.
    """
    n_train = len(X)
    alpha = 1 - confidence_level

    statistics = []

    for _ in range(n_bootstrap):
        # Resample with replacement
        key, subkey = jr.split(key)
        indices = jr.randint(subkey, (n_train,), 0, n_train)
        X_boot = X[indices]
        y_boot = y[indices]

        # Fit and compute statistic
        model_boot = _clone_model(model)
        model_boot.fit(
            X_boot,
            y_boot,
            inference=inference,
            inference_config=inference_config,
        )

        stat = statistic_fn(model_boot)
        statistics.append(stat)

    # Stack and compute quantiles
    statistics = jnp.stack(statistics, axis=0)

    estimate = jnp.mean(statistics, axis=0)
    ci_lower = jnp.percentile(statistics, 100 * alpha / 2, axis=0)
    ci_upper = jnp.percentile(statistics, 100 * (1 - alpha / 2), axis=0)

    return estimate, ci_lower, ci_upper

chebyshev_basis

chebyshev_basis(x: ndarray, degree: int) -> ndarray

Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

Parameters:

Name Type Description Default
x ndarray

Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].

required
degree int

Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.

required

Returns:

Type Description
ndarray

Array of shape (N, degree + 1) where column j contains T_j(x).

Raises:

Type Description
ValueError

If degree is negative or x is not 1-D.

Notes

Uses the three-term recurrence:
    T_0(x) = 1
    T_1(x) = x
    T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)
The Chebyshev polynomials are orthogonal on [-1, 1] with weight 1 / sqrt(1 - x^2).

Examples:

>>> import jax.numpy as jnp
>>> x = jnp.linspace(-1, 1, 5)
>>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
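As a quick sanity check of the recurrence (a sketch, not part of the library's test suite), the closed forms T_2(x) = 2x^2 - 1 and T_3(x) = 4x^3 - 3x should match the last two columns:

import jax.numpy as jnp
from psyphy.utils.math import chebyshev_basis

x = jnp.linspace(-1.0, 1.0, 5)
B = chebyshev_basis(x, degree=3)                # columns: T0, T1, T2, T3

assert jnp.allclose(B[:, 2], 2 * x**2 - 1)      # T2
assert jnp.allclose(B[:, 3], 4 * x**3 - 3 * x)  # T3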
Source code in src/psyphy/utils/math.py
def chebyshev_basis(x: jnp.ndarray, degree: int) -> jnp.ndarray:
    """
    Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

    Parameters
    ----------
    x : jnp.ndarray
        Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].
    degree : int
        Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.

    Returns
    -------
    jnp.ndarray
        Array of shape (N, degree + 1) where column j contains T_j(x).

    Raises
    ------
    ValueError
        If `degree` is negative or `x` is not 1-D.

    Notes
    -----
    Uses the three-term recurrence:
        T_0(x) = 1
        T_1(x) = x
        T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)
    The Chebyshev polynomials are orthogonal on [-1, 1] with weight (1 / sqrt(1 - x^2)).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.linspace(-1, 1, 5)
    >>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
    """
    if degree < 0:
        raise ValueError("degree must be >= 0")
    if x.ndim != 1:
        raise ValueError("x must be 1-D (shape (N,))")

    # Ensure a floating dtype (Chebyshev recurrences are polynomial in x)
    x = x.astype(jnp.result_type(x, 0.0))

    N = x.shape[0]

    # Handle small degrees explicitly.
    if degree == 0:
        return jnp.ones((N, 1), dtype=x.dtype)
    if degree == 1:
        return jnp.stack([jnp.ones_like(x), x], axis=1)

    # Initialize T0 and T1 columns.
    T0 = jnp.ones_like(x)
    T1 = x

    # Scan to generate T2..T_degree in a JIT-friendly way (avoids Python-side loops).
    def step(carry, _):
        # compute next Chebyshev polynomial
        Tm1, Tm = carry
        Tnext = 2.0 * x * Tm - Tm1
        return (Tm, Tnext), Tnext  # new carry, plus an output to collect

    # Jax friendly loop
    (final_Tm1_ignored, final_Tm_ignored), Ts = lax.scan(
        step, (T0, T1), xs=None, length=degree - 1
    )
    # Ts has shape (degree-1, N) and holds [T2, T3, ..., T_degree]
    B = jnp.concatenate([T0[:, None], T1[:, None], jnp.swapaxes(Ts, 0, 1)], axis=1)
    return B

custom_candidates

custom_candidates(
    reference: ndarray, probe_list: list[ndarray]
) -> list[Stimulus]

Wrap a user-defined list of probes into candidate pairs.

Parameters:

Name Type Description Default
reference (ndarray, shape(D))

Reference stimulus.

required
probe_list list of jnp.ndarray

Explicitly chosen probe vectors.

required

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • Useful when hardware constraints (monitor gamut, auditory frequencies) restrict the set of valid stimuli.
  • Full WPPM mode: this pool could be pruned or expanded dynamically depending on posterior fit quality.
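A minimal usage sketch (probe values are illustrative):

import jax.numpy as jnp
from psyphy.utils.candidates import custom_candidates

reference = jnp.array([0.5, 0.3])
# Only probes the display hardware can actually produce
probe_list = [
    jnp.array([0.52, 0.30]),
    jnp.array([0.50, 0.33]),
    jnp.array([0.47, 0.28]),
]
pairs = custom_candidates(reference, probe_list)  # list of (reference, probe) tuples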
Source code in src/psyphy/utils/candidates.py
def custom_candidates(
    reference: jnp.ndarray, probe_list: list[jnp.ndarray]
) -> list[Stimulus]:
    """
    Wrap a user-defined list of probes into candidate pairs.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus.
    probe_list : list of jnp.ndarray
        Explicitly chosen probe vectors.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - Useful when hardware constraints (monitor gamut, auditory frequencies)
      restrict the set of valid stimuli.
    - Full WPPM mode: this pool could be pruned or expanded dynamically
      depending on posterior fit quality.
    """
    return [(reference, probe) for probe in probe_list]

estimate_threshold_contour_uncertainty

estimate_threshold_contour_uncertainty(
    model: Model,
    reference: ndarray,
    n_angles: int = 16,
    max_distance: float = 0.5,
    n_grid_points: int = 100,
    probe_offset: float = 0.05,
    threshold_criterion: float = 0.75,
    n_samples: int = 100,
    *,
    key: Any,
) -> dict[str, Any]

Estimate threshold contour and its uncertainty around a reference point.

Searches radially in multiple directions to find threshold locations and their uncertainty.

Parameters:

Name Type Description Default
model Model

Fitted model

required
reference (ndarray, shape(input_dim))

Reference stimulus (center of contour)

required
n_angles int

Number of directions to search

16
max_distance float

Maximum search distance from reference

0.5
n_grid_points int

Grid resolution per direction

100
probe_offset float

Probe offset for discrimination

0.05
threshold_criterion float

Target accuracy level

0.75
n_samples int

Parameter samples for uncertainty estimation

100
key JAX random key
required

Returns:

Name Type Description
results dict

Dictionary with keys:
  • "angles": (n_angles,) - angles in radians
  • "threshold_mean": (n_angles, input_dim) - mean threshold coords
  • "threshold_std": (n_angles,) - std of threshold distance
  • "threshold_samples": (n_angles, n_samples) - all sample indices

Examples:

>>> # Estimate full contour
>>> reference = jnp.array([0.5, 0.3])
>>> results = estimate_threshold_contour_uncertainty(
...     model, reference, n_angles=16, n_samples=200, key=jr.PRNGKey(0)
... )
>>>
>>> # Plot with uncertainty
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots(figsize=(8, 8))
>>> for i, angle in enumerate(results["angles"]):
...     mean_coord = results["threshold_mean"][i]
...     std_dist = results["threshold_std"][i]
...     ax.plot(*mean_coord, "ro", markersize=8)
...     # Plot uncertainty as error bar
...     direction = jnp.array([jnp.cos(angle), jnp.sin(angle)])
...     lower = mean_coord - 2 * std_dist * direction
...     upper = mean_coord + 2 * std_dist * direction
...     ax.plot([lower[0], upper[0]], [lower[1], upper[1]], "r-", alpha=0.3)
>>> ax.plot(*reference, "k*", markersize=20)
>>> ax.set_aspect("equal")
Source code in src/psyphy/utils/diagnostics.py
def estimate_threshold_contour_uncertainty(
    model: Model,
    reference: jnp.ndarray,
    n_angles: int = 16,
    max_distance: float = 0.5,
    n_grid_points: int = 100,
    probe_offset: float = 0.05,
    threshold_criterion: float = 0.75,
    n_samples: int = 100,
    *,
    key: Any,
) -> dict[str, Any]:
    """
    Estimate threshold contour and its uncertainty around a reference point.

    Searches radially in multiple directions to find threshold locations
    and their uncertainty.

    Parameters
    ----------
    model : Model
        Fitted model
    reference : jnp.ndarray, shape (input_dim,)
        Reference stimulus (center of contour)
    n_angles : int, default=16
        Number of directions to search
    max_distance : float, default=0.5
        Maximum search distance from reference
    n_grid_points : int, default=100
        Grid resolution per direction
    probe_offset : float, default=0.05
        Probe offset for discrimination
    threshold_criterion : float, default=0.75
        Target accuracy level
    n_samples : int, default=100
        Parameter samples for uncertainty estimation
    key : JAX random key

    Returns
    -------
    results : dict
        Dictionary with keys:
        - "angles": (n_angles,) - angles in radians
        - "threshold_mean": (n_angles, input_dim) - mean threshold coords
        - "threshold_std": (n_angles,) - std of threshold distance
        - "threshold_samples": (n_angles, n_samples) - all sample indices

    Examples
    --------
    >>> # Estimate full contour
    >>> reference = jnp.array([0.5, 0.3])
    >>> results = estimate_threshold_contour_uncertainty(
    ...     model, reference, n_angles=16, n_samples=200, key=jr.PRNGKey(0)
    ... )
    >>>
    >>> # Plot with uncertainty
    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots(figsize=(8, 8))
    >>> for i, angle in enumerate(results["angles"]):
    ...     mean_coord = results["threshold_mean"][i]
    ...     std_dist = results["threshold_std"][i]
    ...     ax.plot(*mean_coord, "ro", markersize=8)
    ...     # Plot uncertainty as error bar
    ...     direction = jnp.array([jnp.cos(angle), jnp.sin(angle)])
    ...     lower = mean_coord - 2 * std_dist * direction
    ...     upper = mean_coord + 2 * std_dist * direction
    ...     ax.plot([lower[0], upper[0]], [lower[1], upper[1]], "r-", alpha=0.3)
    >>> ax.plot(*reference, "k*", markersize=20)
    >>> ax.set_aspect("equal")
    """
    angles = jnp.linspace(0, 2 * jnp.pi, n_angles, endpoint=False)

    threshold_means = []
    threshold_stds = []
    all_samples = []

    for angle in angles:
        # Direction vector
        direction = jnp.array([jnp.cos(angle), jnp.sin(angle)])

        # Grid along this direction
        t = jnp.linspace(0, max_distance, n_grid_points)
        X_grid = reference + t[:, None] * direction
        probes = X_grid + probe_offset * direction

        # Estimate threshold
        key, subkey = jr.split(key)
        indices, mean_idx, std_idx = estimate_threshold_uncertainty(
            model,
            X_grid,
            probes,
            threshold_criterion=threshold_criterion,
            n_samples=n_samples,
            key=subkey,
        )

        # Store results
        threshold_coord = X_grid[int(mean_idx)]
        threshold_means.append(threshold_coord)
        threshold_stds.append(std_idx * (t[1] - t[0]))  # Convert to distance
        all_samples.append(indices)

    return {
        "angles": angles,
        "threshold_mean": jnp.array(threshold_means),
        "threshold_std": jnp.array(threshold_stds),
        "threshold_samples": jnp.array(all_samples),
    }

estimate_threshold_uncertainty

estimate_threshold_uncertainty(
    model: Model,
    X_grid: ndarray,
    probes: ndarray,
    threshold_criterion: float = 0.75,
    n_samples: int = 100,
    *,
    key: Any,
) -> tuple[ndarray, float, float]

Estimate threshold location and uncertainty via parameter sampling.

For each parameter sample θᵢ ~ p(θ | data), finds where the model predicts threshold_criterion accuracy. The distribution of these threshold locations gives us uncertainty about the threshold.

Parameters:

Name Type Description Default
model Model

Fitted model (must support predict_with_params)

required
X_grid (ndarray, shape(n_grid, input_dim))

Grid of test points to search over (e.g., line through stimulus space)

required
probes (ndarray, shape(n_grid, input_dim))

Probe at each grid point

required
threshold_criterion float

Target accuracy level (e.g., 0.75 for 75% correct threshold)

0.75
n_samples int

Number of parameter samples for Monte Carlo estimation

100
key JAX random key

Random key for parameter sampling

required

Returns:

Name Type Description
threshold_locations (ndarray, shape(n_samples))

Grid index of threshold for each parameter sample

threshold_mean float

Mean threshold location (as grid index)

threshold_std float

Standard deviation of threshold location (quantifies uncertainty)

Examples:

>>> # Create a line through stimulus space
>>> reference = jnp.array([0.5, 0.3])
>>> direction = jnp.array([0.1, 0.05])
>>> t = jnp.linspace(-1, 1, 200)
>>> X_grid = reference + t[:, None] * direction
>>> probes = X_grid + 0.05  # Small probe offset
>>>
>>> # Estimate threshold uncertainty
>>> indices, mean_idx, std_idx = estimate_threshold_uncertainty(
...     model,
...     X_grid,
...     probes,
...     threshold_criterion=0.75,
...     n_samples=200,
...     key=jr.PRNGKey(0),
... )
>>>
>>> # Convert to coordinates
>>> threshold_coords = X_grid[int(mean_idx)]
>>> print(f"75% threshold at: {threshold_coords}")
>>> print(f"Uncertainty: ±{std_idx * (t[1] - t[0]):.3f} stimulus units")
>>>
>>> # Plot distribution
>>> import matplotlib.pyplot as plt
>>> plt.hist(X_grid[indices, 0], bins=30, alpha=0.7)
>>> plt.axvline(
...     threshold_coords[0], color="r", linestyle="--", label="Mean threshold"
... )
>>> plt.xlabel("Threshold location (dimension 1)")
>>> plt.ylabel("Frequency")
>>> plt.title(f"{int(threshold_criterion * 100)}% Threshold Distribution")
>>> plt.legend()
Notes

This function quantifies threshold uncertainty - how uncertain we are about the threshold location given the observed data.

This is different from prediction uncertainty at a fixed location:
  • pred_post.variance tells you: "uncertainty about p(correct) at X"
  • estimate_threshold_uncertainty tells you: "uncertainty about where the threshold is"

Use this for:
  • Reporting threshold estimates with confidence intervals
  • Visualizing threshold contour uncertainty
  • Experimental design (test near uncertain thresholds)

Source code in src/psyphy/utils/diagnostics.py
def estimate_threshold_uncertainty(
    model: Model,
    X_grid: jnp.ndarray,
    probes: jnp.ndarray,
    threshold_criterion: float = 0.75,
    n_samples: int = 100,
    *,
    key: Any,
) -> tuple[jnp.ndarray, float, float]:
    """
    Estimate threshold location and uncertainty via parameter sampling.

    For each parameter sample θᵢ ~ p(θ | data), finds where the model
    predicts threshold_criterion accuracy. The distribution of these
    threshold locations gives us uncertainty about the threshold.

    Parameters
    ----------
    model : Model
        Fitted model (must support predict_with_params)
    X_grid : jnp.ndarray, shape (n_grid, input_dim)
        Grid of test points to search over (e.g., line through stimulus space)
    probes : jnp.ndarray, shape (n_grid, input_dim)
        Probe at each grid point
    threshold_criterion : float, default=0.75
        Target accuracy level (e.g., 0.75 for 75% correct threshold)
    n_samples : int, default=100
        Number of parameter samples for Monte Carlo estimation
    key : JAX random key
        Random key for parameter sampling

    Returns
    -------
    threshold_locations : jnp.ndarray, shape (n_samples,)
        Grid index of threshold for each parameter sample
    threshold_mean : float
        Mean threshold location (as grid index)
    threshold_std : float
        Standard deviation of threshold location (quantifies uncertainty)

    Examples
    --------
    >>> # Create a line through stimulus space
    >>> reference = jnp.array([0.5, 0.3])
    >>> direction = jnp.array([0.1, 0.05])
    >>> t = jnp.linspace(-1, 1, 200)
    >>> X_grid = reference + t[:, None] * direction
    >>> probes = X_grid + 0.05  # Small probe offset
    >>>
    >>> # Estimate threshold uncertainty
    >>> indices, mean_idx, std_idx = estimate_threshold_uncertainty(
    ...     model,
    ...     X_grid,
    ...     probes,
    ...     threshold_criterion=0.75,
    ...     n_samples=200,
    ...     key=jr.PRNGKey(0),
    ... )
    >>>
    >>> # Convert to coordinates
    >>> threshold_coords = X_grid[int(mean_idx)]
    >>> print(f"75% threshold at: {threshold_coords}")
    >>> print(f"Uncertainty: ±{std_idx * (t[1] - t[0]):.3f} stimulus units")
    >>>
    >>> # Plot distribution
    >>> import matplotlib.pyplot as plt
    >>> plt.hist(X_grid[indices, 0], bins=30, alpha=0.7)
    >>> plt.axvline(
    ...     threshold_coords[0], color="r", linestyle="--", label="Mean threshold"
    ... )
    >>> plt.xlabel("Threshold location (dimension 1)")
    >>> plt.ylabel("Frequency")
    >>> plt.title(f"{int(threshold_criterion * 100)}% Threshold Distribution")
    >>> plt.legend()

    Notes
    -----
    This function quantifies **threshold uncertainty** - how uncertain we are
    about the threshold location given the observed data.

    This is different from **prediction uncertainty** at a fixed location:
    - pred_post.variance tells you: "uncertainty about p(correct) at X"
    - estimate_threshold_uncertainty tells you: "uncertainty about where the threshold is"

    Use this for:
    - Reporting threshold estimates with confidence intervals
    - Visualizing threshold contour uncertainty
    - Experimental design (test near uncertain thresholds)
    """
    # Get parameter posterior and sample
    param_post = model.posterior(kind="parameter")
    param_samples = param_post.sample(n_samples, key=key)  # type: ignore[union-attr]

    threshold_indices = []

    for i in range(n_samples):
        # Extract i-th parameter sample
        params_i = {k: v[i] for k, v in param_samples.items()}

        # Evaluate model at all grid points with these specific parameters
        predictions_i = model.predict_with_params(X_grid, probes, params_i)

        # Find where this crosses threshold
        diffs = jnp.abs(predictions_i - threshold_criterion)
        threshold_idx = jnp.argmin(diffs)

        threshold_indices.append(int(threshold_idx))

    threshold_indices = jnp.array(threshold_indices)

    return (
        threshold_indices,
        float(jnp.mean(threshold_indices)),
        float(jnp.std(threshold_indices)),
    )

grid_candidates

grid_candidates(
    reference: ndarray,
    radii: list[float],
    directions: int = 16,
) -> list[Stimulus]

Generate grid-based candidate probes around a reference.

Parameters:

Name Type Description Default
reference (ndarray, shape(D))

Reference stimulus in model space.

required
radii list of float

Distances from reference to probe.

required
directions int

Number of angular directions.

16

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • MVP: probes lie on concentric circles around reference.
  • Full WPPM mode: could adaptively refine grid around regions of high posterior uncertainty.
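Illustrative usage (values are arbitrary):

import jax.numpy as jnp
from psyphy.utils.candidates import grid_candidates

reference = jnp.array([0.5, 0.3])
# 3 radii x 8 directions = 24 (reference, probe) pairs on concentric circles
candidates = grid_candidates(reference, radii=[0.05, 0.1, 0.2], directions=8)
assert len(candidates) == 24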
Source code in src/psyphy/utils/candidates.py
def grid_candidates(
    reference: jnp.ndarray, radii: list[float], directions: int = 16
) -> list[Stimulus]:
    """
    Generate grid-based candidate probes around a reference.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus in model space.
    radii : list of float
        Distances from reference to probe.
    directions : int, default=16
        Number of angular directions.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - MVP: probes lie on concentric circles around reference.
    - Full WPPM mode: could adaptively refine grid around regions of
      high posterior uncertainty.
    """
    candidates = []
    angles = jnp.linspace(0, 2 * jnp.pi, directions, endpoint=False)
    for r in radii:
        probes = [reference + r * jnp.array([jnp.cos(a), jnp.sin(a)]) for a in angles]
        candidates.extend([(reference, p) for p in probes])
    return candidates

mahalanobis_distance

mahalanobis_distance(
    x: ndarray, mean: ndarray, cov_inv: ndarray
) -> ndarray

Compute squared Mahalanobis distance between x and mean.

Parameters:

Name Type Description Default
x ndarray

Data vector, shape (D,).

required
mean ndarray

Mean vector, shape (D,).

required
cov_inv ndarray

Inverse covariance matrix, shape (D, D).

required

Returns:

Type Description
ndarray

Scalar squared Mahalanobis distance.

Notes
  • Formula: d^2 = (x - mean)^T Σ^{-1} (x - mean)
  • Used in WPPM discriminability calculations.
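A small worked example (numbers chosen for illustration): with Σ = diag(0.25, 1.0), Σ^{-1} = diag(4, 1), so for x - mean = (0.5, 1.0) the squared distance is 4·0.25 + 1·1 = 2.

import jax.numpy as jnp
from psyphy.utils.math import mahalanobis_distance

x = jnp.array([0.5, 1.0])
mean = jnp.array([0.0, 0.0])
cov_inv = jnp.linalg.inv(jnp.diag(jnp.array([0.25, 1.0])))  # diag(4.0, 1.0)

d2 = mahalanobis_distance(x, mean, cov_inv)  # -> 2.0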
Source code in src/psyphy/utils/math.py
def mahalanobis_distance(
    x: jnp.ndarray, mean: jnp.ndarray, cov_inv: jnp.ndarray
) -> jnp.ndarray:
    """
    Compute squared Mahalanobis distance between x and mean.

    Parameters
    ----------
    x : jnp.ndarray
        Data vector, shape (D,).
    mean : jnp.ndarray
        Mean vector, shape (D,).
    cov_inv : jnp.ndarray
        Inverse covariance matrix, shape (D, D).

    Returns
    -------
    jnp.ndarray
        Scalar squared Mahalanobis distance.

    Notes
    -----
    - Formula: d^2 = (x - mean)^T Σ^{-1} (x - mean)
    - Used in WPPM discriminability calculations.
    """
    delta = x - mean
    return jnp.dot(delta, cov_inv @ delta)

parameter_summary

parameter_summary(
    param_posterior: ParameterPosterior,
    n_samples: int = 1000,
    *,
    key: Any | None = None,
    quantiles: tuple[float, ...] = (
        0.025,
        0.25,
        0.5,
        0.75,
        0.975,
    ),
) -> dict[str, dict[str, ndarray]]

Compute summary statistics for all model parameters.

Parameters:

Name Type Description Default
param_posterior ParameterPosterior

Parameter posterior to summarize

required
n_samples int

Number of Monte Carlo samples

1000
key JAX PRNGKey

Random key for sampling (auto-generated if None)

None
quantiles tuple of floats

Quantiles to compute

(0.025, 0.25, 0.5, 0.75, 0.975)

Returns:

Name Type Description
summary dict[str, dict[str, ndarray]]

Dictionary with keys for each parameter, values are dicts with:
  • "mean": Mean of posterior samples
  • "std": Standard deviation
  • "quantiles": Dict mapping quantile to value

Examples:

>>> param_post = model.posterior(kind="parameter")
>>> summary = parameter_summary(param_post, n_samples=500)
>>> print(
...     f"Noise: {summary['noise_scale']['mean']:.3f} "
...     f"± {summary['noise_scale']['std']:.3f}"
... )
>>> print(
...     f"95% CI: [{summary['noise_scale']['quantiles'][0.025]:.3f}, "
...     f"{summary['noise_scale']['quantiles'][0.975]:.3f}]"
... )
Source code in src/psyphy/utils/diagnostics.py
def parameter_summary(
    param_posterior: ParameterPosterior,
    n_samples: int = 1000,
    *,
    key: Any | None = None,
    quantiles: tuple[float, ...] = (0.025, 0.25, 0.5, 0.75, 0.975),
) -> dict[str, dict[str, jnp.ndarray]]:
    """
    Compute summary statistics for all model parameters.

    Parameters
    ----------
    param_posterior : ParameterPosterior
        Parameter posterior to summarize
    n_samples : int, default=1000
        Number of Monte Carlo samples
    key : JAX PRNGKey, optional
        Random key for sampling (auto-generated if None)
    quantiles : tuple of floats, default=(0.025, 0.25, 0.5, 0.75, 0.975)
        Quantiles to compute

    Returns
    -------
    summary : dict[str, dict[str, jnp.ndarray]]
        Dictionary with keys for each parameter, values are dicts with:
        - "mean": Mean of posterior samples
        - "std": Standard deviation
        - "quantiles": Dict mapping quantile to value

    Examples
    --------
    >>> param_post = model.posterior(kind="parameter")
    >>> summary = parameter_summary(param_post, n_samples=500)
    >>> print(
    ...     f"Noise: {summary['noise_scale']['mean']:.3f} "
    ...     f"± {summary['noise_scale']['std']:.3f}"
    ... )
    >>> print(
    ...     f"95% CI: [{summary['noise_scale']['quantiles'][0.025]:.3f}, "
    ...     f"{summary['noise_scale']['quantiles'][0.975]:.3f}]"
    ... )
    """
    if key is None:
        import time

        key = jr.PRNGKey(int(time.time() * 1e6) % 2**32)

    # Sample parameters
    samples = param_posterior.sample(n_samples, key=key)

    # Compute statistics for each parameter
    summary = {}
    for param_name, param_samples in samples.items():
        summary[param_name] = {
            "mean": jnp.mean(param_samples, axis=0),
            "std": jnp.std(param_samples, axis=0),
            "quantiles": {
                q: jnp.percentile(param_samples, 100 * q, axis=0) for q in quantiles
            },
        }

    return summary

print_parameter_summary

print_parameter_summary(
    param_posterior: ParameterPosterior,
    n_samples: int = 1000,
    *,
    key: Any | None = None,
) -> None

Print a human-readable parameter summary.

Examples:

>>> param_post = model.posterior(kind="parameter")
>>> print_parameter_summary(param_post)
Parameter Summary (1000 samples):

log_diag:
  Mean: [0.12, -0.03]
  Std:  [0.05,  0.02]

Source code in src/psyphy/utils/diagnostics.py
def print_parameter_summary(
    param_posterior: ParameterPosterior,
    n_samples: int = 1000,
    *,
    key: Any | None = None,
) -> None:
    """
    Print a human-readable parameter summary.

    Examples
    --------
    >>> param_post = model.posterior(kind="parameter")
    >>> print_parameter_summary(param_post)
    Parameter Summary (1000 samples):

    log_diag:
      Mean: [0.12, -0.03]
      Std:  [0.05,  0.02]
    """
    summary = parameter_summary(param_posterior, n_samples, key=key)

    print(f"Parameter Summary ({n_samples} samples):\n")

    for param_name, stats in summary.items():
        print(f"{param_name}:")

        mean = stats["mean"]
        std = stats["std"]
        q025 = stats["quantiles"][0.025]
        q975 = stats["quantiles"][0.975]

        # Handle different shapes
        if mean.ndim == 0:  # Scalar
            print(f"  Mean: {float(mean):.3f} ± {float(std):.3f}")
            print(f"  95% CI: [{float(q025):.3f}, {float(q975):.3f}]")

        elif mean.ndim == 1:  # Vector
            print(f"  Mean: {mean}")
            print(f"  Std:  {std}")

        elif mean.ndim == 2:  # Matrix (e.g., W)
            # For matrices, report Frobenius norm
            mean_norm = jnp.linalg.norm(mean, "fro")
            std_norm = jnp.linalg.norm(std, "fro")
            q025_norm = jnp.linalg.norm(q025, "fro")
            q975_norm = jnp.linalg.norm(q975, "fro")
            print(f"  Mean (Frobenius norm): {mean_norm:.3f} ± {std_norm:.3f}")
            print(f"  95% CI (norm): [{q025_norm:.3f}, {q975_norm:.3f}]")
            print(f"  Shape: {mean.shape}")

        print()

rbf_kernel

rbf_kernel(
    x1: ndarray, x2: ndarray, lengthscale: float = 1.0
) -> ndarray

Radial Basis Function (RBF) kernel between two sets of points.

Parameters:

Name Type Description Default
x1 ndarray

First set of points, shape (N, D).

required
x2 ndarray

Second set of points, shape (M, D).

required
lengthscale float

Length-scale parameter controlling smoothness.

1.0

Returns:

Type Description
ndarray

Kernel matrix of shape (N, M).

Notes
  • RBF kernel: k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
  • Default kernel for the Gaussian-process (smooth) covariance priors used in Full WPPM mode.
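Illustrative usage (shapes and values are arbitrary): entries decay with squared distance, and a smaller lengthscale makes the decay faster.

import jax.numpy as jnp
from psyphy.utils.math import rbf_kernel

X1 = jnp.array([[0.0, 0.0], [1.0, 0.0]])               # (N=2, D=2)
X2 = jnp.array([[0.0, 0.0], [0.5, 0.5], [2.0, 0.0]])   # (M=3, D=2)

K_smooth = rbf_kernel(X1, X2, lengthscale=1.0)   # shape (2, 3)
K_rough = rbf_kernel(X1, X2, lengthscale=0.2)    # off-diagonal values shrink faster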
Source code in src/psyphy/utils/math.py
def rbf_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, lengthscale: float = 1.0
) -> jnp.ndarray:
    """
    Radial Basis Function (RBF) kernel between two sets of points.


    Parameters
    ----------
    x1 : jnp.ndarray
        First set of points, shape (N, D).
    x2 : jnp.ndarray
        Second set of points, shape (M, D).
    lengthscale : float, default=1.0
        Length-scale parameter controlling smoothness.

    Returns
    -------
    jnp.ndarray
        Kernel matrix of shape (N, M).

    Notes
    -----
    - RBF kernel: k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
    - Default kernel for the Gaussian-process (smooth) covariance priors used in Full WPPM mode.
    """
    sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sqdist / (lengthscale**2))

seed

seed(seed_value: int) -> Any

Create a new PRNG key from an integer seed.

Parameters:

Name Type Description Default
seed_value int

Seed for random number generation.

required

Returns:

Type Description
KeyArray

New PRNG key.

Source code in src/psyphy/utils/rng.py
def seed(seed_value: int) -> Any:
    """
    Create a new PRNG key from an integer seed.

    Parameters
    ----------
    seed_value : int
        Seed for random number generation.

    Returns
    -------
    jax.random.KeyArray
        New PRNG key.
    """
    return jr.PRNGKey(seed_value)

sobol_candidates

sobol_candidates(
    reference: ndarray,
    n: int,
    bounds: list[tuple[float, float]],
    seed: int = 0,
) -> list[Stimulus]

Generate Sobol quasi-random candidates within bounds.

Parameters:

Name Type Description Default
reference (ndarray, shape(D))

Reference stimulus.

required
n int

Number of candidates to generate.

required
bounds list of (low, high)

Bounds per dimension.

required
seed int

Random seed.

0

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • MVP: uniform coverage of space using low-discrepancy Sobol sequence.
  • Full WPPM mode: Sobol could be used for initialization, then hand off to posterior-aware strategies.
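Illustrative usage (bounds are arbitrary; the Sobol engine comes from scipy, which must be installed):

import jax.numpy as jnp
from psyphy.utils.candidates import sobol_candidates

reference = jnp.array([0.5, 0.3])
bounds = [(0.0, 1.0), (0.0, 1.0)]   # one (low, high) pair per dimension
pairs = sobol_candidates(reference, n=64, bounds=bounds, seed=0)
# 64 (reference, probe) pairs with probes spread quasi-uniformly over the bounds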
Source code in src/psyphy/utils/candidates.py
def sobol_candidates(
    reference: jnp.ndarray, n: int, bounds: list[tuple[float, float]], seed: int = 0
) -> list[Stimulus]:
    """
    Generate Sobol quasi-random candidates within bounds.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus.
    n : int
        Number of candidates to generate.
    bounds : list of (low, high)
        Bounds per dimension.
    seed : int, default=0
        Random seed.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - MVP: uniform coverage of space using low-discrepancy Sobol sequence.
    - Full WPPM mode: Sobol could be used for initialization,
      then hand off to posterior-aware strategies.
    """
    from scipy.stats.qmc import Sobol

    dim = len(bounds)
    engine = Sobol(d=dim, scramble=True, seed=seed)
    raw = engine.random(n)
    scaled = [low + (high - low) * raw[:, i] for i, (low, high) in enumerate(bounds)]
    probes = np.stack(scaled, axis=-1)
    return [(reference, jnp.array(p)) for p in probes]

split

split(key: Any, num: int = 2) -> Any

Split a PRNG key into multiple independent keys.

Parameters:

Name Type Description Default
key KeyArray

RNG key to split.

required
num int

Number of new keys to return.

2

Returns:

Type Description
tuple of jax.random.KeyArray

Independent new PRNG keys.

Source code in src/psyphy/utils/rng.py
def split(key: Any, num: int = 2) -> Any:
    """
    Split a PRNG key into multiple independent keys.

    Parameters
    ----------
    key : jax.random.KeyArray
        RNG key to split.
    num : int, default=2
        Number of new keys to return.

    Returns
    -------
    tuple of jax.random.KeyArray
        Independent new PRNG keys.
    """
    return jr.split(key, num=num)

RNG


rng

rng.py

Random number utilities for psyphy.

This module standardizes RNG handling across the package, which is especially important when mixing NumPy and JAX.

MVP implementation: - Wrappers around JAX PRNG keys. - Helpers for reproducibility.

Future extensions: - Experiment-wide RNG registry. - Splitting strategies for parallel adaptive placement.

Examples:

>>> import jax
>>> from psyphy.utils.rng import seed, split
>>> key = seed(0)
>>> k1, k2 = split(key)

Functions:

Name Description
seed

Create a new PRNG key from an integer seed.

split

Split a PRNG key into multiple independent keys.

seed

seed(seed_value: int) -> Any

Create a new PRNG key from an integer seed.

Parameters:

Name Type Description Default
seed_value int

Seed for random number generation.

required

Returns:

Type Description
KeyArray

New PRNG key.

Source code in src/psyphy/utils/rng.py
def seed(seed_value: int) -> Any:
    """
    Create a new PRNG key from an integer seed.

    Parameters
    ----------
    seed_value : int
        Seed for random number generation.

    Returns
    -------
    jax.random.KeyArray
        New PRNG key.
    """
    return jr.PRNGKey(seed_value)

split

split(key: Any, num: int = 2) -> Any

Split a PRNG key into multiple independent keys.

Parameters:

Name Type Description Default
key KeyArray

RNG key to split.

required
num int

Number of new keys to return.

2

Returns:

Type Description
tuple of jax.random.KeyArray

Independent new PRNG keys.

Source code in src/psyphy/utils/rng.py
def split(key: Any, num: int = 2) -> Any:
    """
    Split a PRNG key into multiple independent keys.

    Parameters
    ----------
    key : jax.random.KeyArray
        RNG key to split.
    num : int, default=2
        Number of new keys to return.

    Returns
    -------
    tuple of jax.random.KeyArray
        Independent new PRNG keys.
    """
    return jr.split(key, num=num)
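
Because these helpers are thin wrappers around JAX's functional PRNG, the usual discipline applies: the same integer seed always yields the same key, and keys should be split rather than reused. A short sketch, drawing values directly with jax.random just to demonstrate reproducibility:

import jax.numpy as jnp
import jax.random as jr
from psyphy.utils.rng import seed, split

key = seed(42)
k1, k2 = split(key)           # two independent keys (default num=2)
k_batch = split(key, num=4)   # or a batch of 4 keys at once

# Reproducibility: re-seeding with the same integer gives identical draws.
a = jr.normal(k1, shape=(3,))
b = jr.normal(split(seed(42))[0], shape=(3,))
assert bool(jnp.allclose(a, b))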

Math


math

math.py

Math utilities for psyphy.

Includes: - chebyshev_basis : compute Chebyshev polynomial basis. - mahalanobis_distance : discriminability metric used in WPPM MVP. - rbf_kernel : kernel function, useful in Full WPPM mode covariance priors.

All functions use JAX (jax.numpy) for compatibility with autodiff.

Notes
  • math.chebyshev_basis is relevant when implementing Full WPPM mode, where covariance fields are expressed in a basis expansion.
  • math.mahalanobis_distance is directly used in WPPM MVP discriminability.
  • math.rbf_kernel is a placeholder for Gaussian-process-style covariance priors.

Examples:

>>> import jax.numpy as jnp
>>> from psyphy.utils import math
>>> x = jnp.linspace(-1, 1, 5)
>>> math.chebyshev_basis(x, degree=3).shape
(5, 4)

Functions:

Name Description
chebyshev_basis

Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

mahalanobis_distance

Compute squared Mahalanobis distance between x and mean.

rbf_kernel

Radial Basis Function (RBF) kernel between two sets of points.

chebyshev_basis

chebyshev_basis(x: ndarray, degree: int) -> ndarray

Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

Parameters:

Name Type Description Default
x ndarray

Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].

required
degree int

Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.

required

Returns:

Type Description
ndarray

Array of shape (N, degree + 1) where column j contains T_j(x).

Raises:

Type Description
ValueError

If degree is negative or x is not 1-D.

Notes

Uses the three-term recurrence: T_0(x) = 1, T_1(x) = x, T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x). The Chebyshev polynomials are orthogonal on [-1, 1] with weight 1 / sqrt(1 - x^2).

Examples:

>>> import jax.numpy as jnp
>>> x = jnp.linspace(-1, 1, 5)
>>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
Source code in src/psyphy/utils/math.py
def chebyshev_basis(x: jnp.ndarray, degree: int) -> jnp.ndarray:
    """
    Construct the Chebyshev polynomial basis matrix T_0..T_degree evaluated at x.

    Parameters
    ----------
    x : jnp.ndarray
        Input points of shape (N,). For best numerical properties, values should lie in [-1, 1].
    degree : int
        Maximum polynomial degree (>= 0). The output includes columns for T_0 through T_degree.

    Returns
    -------
    jnp.ndarray
        Array of shape (N, degree + 1) where column j contains T_j(x).

    Raises
    ------
    ValueError
        If `degree` is negative or `x` is not 1-D.

    Notes
    -----
    Uses the three-term recurrence:
        T_0(x) = 1
        T_1(x) = x
        T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)
    The Chebyshev polynomials are orthogonal on [-1, 1] with weight (1 / sqrt(1 - x^2)).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> x = jnp.linspace(-1, 1, 5)
    >>> B = chebyshev_basis(x, degree=3)  # columns: T0, T1, T2, T3
    """
    if degree < 0:
        raise ValueError("degree must be >= 0")
    if x.ndim != 1:
        raise ValueError("x must be 1-D (shape (N,))")

    # Ensure a floating dtype (Chebyshev recurrences are polynomial in x)
    x = x.astype(jnp.result_type(x, 0.0))

    N = x.shape[0]

    # Handle small degrees explicitly.
    if degree == 0:
        return jnp.ones((N, 1), dtype=x.dtype)
    if degree == 1:
        return jnp.stack([jnp.ones_like(x), x], axis=1)

    # Initialize T0 and T1 columns.
    T0 = jnp.ones_like(x)
    T1 = x

    # Scan to generate T2..T_degree in a JIT-friendly way (avoids Python-side loops).
    def step(carry, _):
        # compute next Chebyshev polynomial
        Tm1, Tm = carry
        Tnext = 2.0 * x * Tm - Tm1
        return (Tm, Tnext), Tnext  # new carry, plus an output to collect

    # Jax friendly loop
    (final_Tm1_ignored, final_Tm_ignored), Ts = lax.scan(
        step, (T0, T1), xs=None, length=degree - 1
    )
    # Ts has shape (degree-1, N) and holds [T2, T3, ..., T_degree]
    B = jnp.concatenate([T0[:, None], T1[:, None], jnp.swapaxes(Ts, 0, 1)], axis=1)
    return B
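
As a usage sketch (import path assumed from the source location above), the basis can be combined with ordinary least squares to fit Chebyshev coefficients of a smooth target; the target function and degree here are arbitrary illustrations:

import jax.numpy as jnp
from psyphy.utils.math import chebyshev_basis

x = jnp.linspace(-1.0, 1.0, 50)
B = chebyshev_basis(x, degree=4)      # (50, 5): columns T0..T4

y = jnp.exp(x)                        # smooth target to approximate
coeffs, *_ = jnp.linalg.lstsq(B, y)   # ordinary least-squares fit
y_hat = B @ coeffs

print(jnp.max(jnp.abs(y - y_hat)))    # small residual at degree 4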

mahalanobis_distance

mahalanobis_distance(
    x: ndarray, mean: ndarray, cov_inv: ndarray
) -> ndarray

Compute squared Mahalanobis distance between x and mean.

Parameters:

Name Type Description Default
x ndarray

Data vector, shape (D,).

required
mean ndarray

Mean vector, shape (D,).

required
cov_inv ndarray

Inverse covariance matrix, shape (D, D).

required

Returns:

Type Description
ndarray

Scalar squared Mahalanobis distance.

Notes
  • Formula: d^2 = (x - mean)^T Σ^{-1} (x - mean)
  • Used in WPPM discriminability calculations.
Source code in src/psyphy/utils/math.py
def mahalanobis_distance(
    x: jnp.ndarray, mean: jnp.ndarray, cov_inv: jnp.ndarray
) -> jnp.ndarray:
    """
    Compute squared Mahalanobis distance between x and mean.

    Parameters
    ----------
    x : jnp.ndarray
        Data vector, shape (D,).
    mean : jnp.ndarray
        Mean vector, shape (D,).
    cov_inv : jnp.ndarray
        Inverse covariance matrix, shape (D, D).

    Returns
    -------
    jnp.ndarray
        Scalar squared Mahalanobis distance.

    Notes
    -----
    - Formula: d^2 = (x - mean)^T Σ^{-1} (x - mean)
    - Used in WPPM discriminability calculations.
    """
    delta = x - mean
    return jnp.dot(delta, cov_inv @ delta)
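
A small numeric check (import path assumed from the source location above): with an identity covariance the squared Mahalanobis distance reduces to the squared Euclidean distance, while an anisotropic covariance stretches it along the low-variance axis.

import jax.numpy as jnp
from psyphy.utils.math import mahalanobis_distance

x = jnp.array([1.0, 0.0])
mean = jnp.array([0.0, 0.0])

# Identity covariance: d^2 equals the squared Euclidean distance (1.0 here).
print(mahalanobis_distance(x, mean, jnp.eye(2)))

# Smaller variance along the first axis makes the same displacement
# correspond to a larger d^2 (i.e., more discriminable).
cov = jnp.array([[0.25, 0.0], [0.0, 1.0]])
print(mahalanobis_distance(x, mean, jnp.linalg.inv(cov)))   # 4.0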

rbf_kernel

rbf_kernel(
    x1: ndarray, x2: ndarray, lengthscale: float = 1.0
) -> ndarray

Radial Basis Function (RBF) kernel between two sets of points.

Parameters:

Name Type Description Default
x1 ndarray

First set of points, shape (N, D).

required
x2 ndarray

Second set of points, shape (M, D).

required
lengthscale float

Length-scale parameter controlling smoothness.

1.0

Returns:

Type Description
ndarray

Kernel matrix of shape (N, M).

Notes
  • RBF kernel: k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
  • Default kernel used by Gaussian processes to define smooth covariance priors in Full WPPM mode.
Source code in src/psyphy/utils/math.py
def rbf_kernel(
    x1: jnp.ndarray, x2: jnp.ndarray, lengthscale: float = 1.0
) -> jnp.ndarray:
    """
    Radial Basis Function (RBF) kernel between two sets of points.


    Parameters
    ----------
    x1 : jnp.ndarray
        First set of points, shape (N, D).
    x2 : jnp.ndarray
        Second set of points, shape (M, D).
    lengthscale : float, default=1.0
        Length-scale parameter controlling smoothness.

    Returns
    -------
    jnp.ndarray
        Kernel matrix of shape (N, M).

    Notes
    -----
    - RBF kernel: k(x, x') = exp(-||x - x'||^2 / (2 * lengthscale^2))
    - Default kernel used by Gaussian processes to define smooth covariance priors in Full WPPM mode.
    """
    sqdist = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sqdist / (lengthscale**2))


Stimulus candidates


candidates

candidates.py

Utilities for generating candidate stimulus pools.

Definition

A candidate pool is the set of all possible (reference, probe) pairs that an adaptive placement strategy may select from.

Separation of concerns
  • Candidate generation (this module) defines what stimuli are possible.
  • Trial placement strategies (e.g., GreedyMAPPlacement, InfoGainPlacement) define which of those candidates to present next.
Why this matters
  • Researchers: think of the candidate pool as the "menu" of allowable trials.
  • Developers: placement strategies should not generate candidates but only select from a given pool.
MVP implementation
  • Grid-based candidates (probes on circles around a reference).
  • Sobol sequence candidates (low-discrepancy exploration).
  • Custom user-defined candidate pools.
Full WPPM mode
  • Candidate generation could adaptively refine itself based on posterior uncertainty (e.g., dynamic grids).
  • Candidate pools could be constrained by device gamut or subject-specific calibration.

Functions:

Name Description
custom_candidates

Wrap a user-defined list of probes into candidate pairs.

grid_candidates

Generate grid-based candidate probes around a reference.

sobol_candidates

Generate Sobol quasi-random candidates within bounds.

Attributes:

Name Type Description
Stimulus

Type alias for a candidate (reference, probe) stimulus pair.

Stimulus

Stimulus = tuple[ndarray, ndarray]

custom_candidates

custom_candidates(
    reference: ndarray, probe_list: list[ndarray]
) -> list[Stimulus]

Wrap a user-defined list of probes into candidate pairs.

Parameters:

Name Type Description Default
reference ndarray, shape (D,)

Reference stimulus.

required
probe_list list of jnp.ndarray

Explicitly chosen probe vectors.

required

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • Useful when hardware constraints (monitor gamut, auditory frequencies) restrict the set of valid stimuli.
  • Full WPPM mode: this pool could be pruned or expanded dynamically depending on posterior fit quality.
Source code in src/psyphy/utils/candidates.py
def custom_candidates(
    reference: jnp.ndarray, probe_list: list[jnp.ndarray]
) -> list[Stimulus]:
    """
    Wrap a user-defined list of probes into candidate pairs.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus.
    probe_list : list of jnp.ndarray
        Explicitly chosen probe vectors.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - Useful when hardware constraints (monitor gamut, auditory frequencies)
      restrict the set of valid stimuli.
    - Full WPPM mode: this pool could be pruned or expanded dynamically
      depending on posterior fit quality.
    """
    return [(reference, probe) for probe in probe_list]
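
A usage sketch (import path assumed from the source location above); the probe values are arbitrary placeholders standing in for a hardware-validated set:

import jax.numpy as jnp
from psyphy.utils.candidates import custom_candidates

reference = jnp.array([0.5, 0.5])
# Probes pre-screened against device constraints (values are illustrative).
probes = [
    jnp.array([0.52, 0.50]),
    jnp.array([0.50, 0.55]),
    jnp.array([0.45, 0.45]),
]

pool = custom_candidates(reference, probes)
print(len(pool))        # 3 (reference, probe) pairs
ref, probe = pool[0]    # each entry is a Stimulus tuple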

grid_candidates

grid_candidates(
    reference: ndarray,
    radii: list[float],
    directions: int = 16,
) -> list[Stimulus]

Generate grid-based candidate probes around a reference.

Parameters:

Name Type Description Default
reference ndarray, shape (D,)

Reference stimulus in model space.

required
radii list of float

Distances from reference to probe.

required
directions int

Number of angular directions.

16

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • MVP: probes lie on concentric circles around reference.
  • Full WPPM mode: could adaptively refine grid around regions of high posterior uncertainty.
Source code in src/psyphy/utils/candidates.py
def grid_candidates(
    reference: jnp.ndarray, radii: list[float], directions: int = 16
) -> list[Stimulus]:
    """
    Generate grid-based candidate probes around a reference.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus in model space.
    radii : list of float
        Distances from reference to probe.
    directions : int, default=16
        Number of angular directions.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - MVP: probes lie on concentric circles around reference.
    - Full WPPM mode: could adaptively refine grid around regions of
      high posterior uncertainty.
    """
    candidates = []
    angles = jnp.linspace(0, 2 * jnp.pi, directions, endpoint=False)
    for r in radii:
        probes = [reference + r * jnp.array([jnp.cos(a), jnp.sin(a)]) for a in angles]
        candidates.extend([(reference, p) for p in probes])
    return candidates
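
A short sketch of building a ring grid around a 2-D reference (the MVP code above constructs 2-D cos/sin offsets, so the reference is taken to be 2-D; import path assumed from the source location):

import jax.numpy as jnp
from psyphy.utils.candidates import grid_candidates

reference = jnp.array([0.0, 0.0])
pool = grid_candidates(reference, radii=[0.05, 0.1, 0.2], directions=8)

print(len(pool))                            # 3 radii x 8 directions = 24 pairs
_, probe = pool[0]
print(jnp.linalg.norm(probe - reference))   # ~0.05, the innermost radius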

sobol_candidates

sobol_candidates(
    reference: ndarray,
    n: int,
    bounds: list[tuple[float, float]],
    seed: int = 0,
) -> list[Stimulus]

Generate Sobol quasi-random candidates within bounds.

Parameters:

Name Type Description Default
reference ndarray, shape (D,)

Reference stimulus.

required
n int

Number of candidates to generate.

required
bounds list of (low, high)

Bounds per dimension.

required
seed int

Random seed.

0

Returns:

Type Description
list of Stimulus

Candidate (reference, probe) pairs.

Notes
  • MVP: uniform coverage of space using low-discrepancy Sobol sequence.
  • Full WPPM mode: Sobol could be used for initialization, then hand off to posterior-aware strategies.
Source code in src/psyphy/utils/candidates.py
def sobol_candidates(
    reference: jnp.ndarray, n: int, bounds: list[tuple[float, float]], seed: int = 0
) -> list[Stimulus]:
    """
    Generate Sobol quasi-random candidates within bounds.

    Parameters
    ----------
    reference : jnp.ndarray, shape (D,)
        Reference stimulus.
    n : int
        Number of candidates to generate.
    bounds : list of (low, high)
        Bounds per dimension.
    seed : int, default=0
        Random seed.

    Returns
    -------
    list of Stimulus
        Candidate (reference, probe) pairs.

    Notes
    -----
    - MVP: uniform coverage of space using low-discrepancy Sobol sequence.
    - Full WPPM mode: Sobol could be used for initialization,
      then hand off to posterior-aware strategies.
    """
    from scipy.stats.qmc import Sobol

    dim = len(bounds)
    engine = Sobol(d=dim, scramble=True, seed=seed)
    raw = engine.random(n)
    scaled = [low + (high - low) * raw[:, i] for i, (low, high) in enumerate(bounds)]
    probes = np.stack(scaled, axis=-1)
    return [(reference, jnp.array(p)) for p in probes]
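
Finally, a usage sketch for the Sobol pool (import path assumed from the source location; requires scipy for the Sobol engine). Probes fill the bounding box with low-discrepancy coverage, independent of the reference location:

import jax.numpy as jnp
from psyphy.utils.candidates import sobol_candidates

reference = jnp.array([0.0, 0.0])
bounds = [(-1.0, 1.0), (-1.0, 1.0)]   # (low, high) per dimension

pool = sobol_candidates(reference, n=64, bounds=bounds, seed=0)
print(len(pool))                      # 64 (reference, probe) pairs

probes = jnp.stack([p for _, p in pool])   # (64, 2) probe locations
assert bool(jnp.all((probes >= -1.0) & (probes <= 1.0)))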