
psyphy.model

Model-layer API: everything model-related in one place.

Includes
  • WPPM (core model)
  • Priors (Prior)
  • Tasks (TaskLikelihood base, OddityTask, TwoAFC)
  • Noise models (GaussianNoise, StudentTNoise)

All functions/classes use JAX arrays (jax.numpy as jnp) for autodiff and optimization with Optax.

Typical usage
from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise
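A fuller end-to-end sketch (the synthetic stimuli, responses, and inference settings below are illustrative, not canonical defaults):

import jax.numpy as jnp
import jax.random as jr

from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise

# Build a 2D model; MVP mode since basis_degree is left unset.
model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2),
    task=OddityTask(),
    noise=GaussianNoise(sigma=0.03),
)

# Synthetic (reference, probe) pairs: shape (n_trials, 2, input_dim).
key = jr.PRNGKey(0)
X = jr.normal(key, (50, 2, 2))
y = jnp.ones((50,), dtype=jnp.int32)  # placeholder responses, shape (n_trials,)

model = model.fit(X, y, inference="map", inference_config={"steps": 100})
pred_post = model.posterior(X[:, 0, :], probes=X[:, 1, :])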

Classes:

Name Description

GaussianNoise

Gaussian observer noise model.

Model

Abstract base class for psychophysical models.

OddityTask

Three-alternative forced-choice oddity task.

OnlineConfig

Configuration for online learning and memory management.

Prior

Prior distribution over WPPM parameters.

StudentTNoise

Student-t observer noise model.

TaskLikelihood

Abstract base class for task likelihoods.

TwoAFC

2-alternative forced-choice task (MVP placeholder).

WPPM

Wishart Process Psychophysical Model (WPPM).

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float

Log-probability of a residual (MVP placeholder; ignores the residual and returns a constant).
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)
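A short usage sketch; note that sample_standard always draws from the standard normal, with sigma stored separately on the instance:

import jax.random as jr

from psyphy.model import GaussianNoise

noise = GaussianNoise(sigma=0.03)
z = noise.sample_standard(jr.PRNGKey(0), (1000, 2))  # standard-normal draws, shape (1000, 2)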

Model

Model(*, online_config: OnlineConfig | None = None)

Bases: ABC

Abstract base class for psychophysical models.

Provides an API that mimics the BoTorch style:

  • fit(X, y) --> train model
  • posterior(X) --> get predictions
  • condition_on_observations(X, y) --> online updates

Subclasses must implement (as sketched below):

  • init_params(key) --> sample initial parameters
  • log_likelihood_from_data(params, data) --> compute likelihood
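A minimal subclass sketch satisfying these two requirements; ToyModel and its single-parameter Bernoulli likelihood are illustrative only, not part of the library:

import jax
import jax.numpy as jnp
import jax.random as jr

from psyphy.model import Model


class ToyModel(Model):
    """Illustrative subclass with one scalar parameter."""

    def init_params(self, key) -> dict:
        # Sample initial parameters from a standard-normal "prior".
        return {"theta": jr.normal(key, ())}

    def log_likelihood_from_data(self, params: dict, data) -> jnp.ndarray:
        # ResponseData.to_numpy() yields (refs, comparisons, responses).
        _, _, responses = data.to_numpy()
        p = jax.nn.sigmoid(params["theta"])  # one shared P(correct) for all trials
        eps = 1e-9
        return jnp.sum(
            jnp.where(responses == 1, jnp.log(p + eps), jnp.log(1.0 - p + eps))
        )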

Parameters:

Name Type Description Default
online_config OnlineConfig | None

Configuration for online learning. If None, uses default (unbounded memory).

None

Attributes:

Name Type Description
_posterior ParameterPosterior | None

Cached parameter posterior from last fit

_inference_engine InferenceEngine | None

Cached inference engine for warm-start refitting

_data_buffer ResponseData | None

Data buffer managed according to online_config

_n_updates int

Number of condition_on_observations calls

online_config OnlineConfig

Online learning configuration

Initialize model.

Parameters:

Name Type Description Default
online_config OnlineConfig | None

Online learning configuration. If None, uses default settings.

None

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

fit

Fit model to data.

init_params

Sample initial parameters from prior.

log_likelihood_from_data

Compute log p(data | params).

posterior

Return posterior distribution.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Source code in src/psyphy/model/base.py
def __init__(self, *, online_config: OnlineConfig | None = None):
    """
    Initialize model.

    Parameters
    ----------
    online_config : OnlineConfig | None
        Online learning configuration. If None, uses default settings.
    """
    self._posterior: ParameterPosterior | None = None
    self._inference_engine: InferenceEngine | None = None
    self._data_buffer: ResponseData | None = None
    self._n_updates: int = 0
    self.online_config = online_config or OnlineConfig()

online_config

online_config = online_config or OnlineConfig()

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:

  • "full": Accumulate all data, refit periodically
  • "sliding_window": Keep only recent window_size trials
  • "reservoir": Random sampling of window_size trials
  • "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).

Parameters:

Name Type Description Default
X ndarray

New stimuli

required
y ndarray

New responses

required

Returns:

Type Description
Model

Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines
        # For now, just refit from cached params
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name Type Description Default
X ndarray

Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only

required
y ndarray

Responses, shape (n_trials,)

required
inference InferenceEngine | str

Inference engine or string key ("map", "laplace", "langevin")

"laplace"
inference_config dict | None

Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

None

Returns:

Type Description
Model

Self for method chaining

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)

>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))

>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self

init_params

init_params(key: Any) -> dict

Sample initial parameters from prior.

Parameters:

Name Type Description Default
key KeyArray

PRNG key

required

Returns:

Type Description
dict

Parameter PyTree

Source code in src/psyphy/model/base.py
@abstractmethod
def init_params(self, key: Any) -> dict:  # jax.random.KeyArray
    """
    Sample initial parameters from prior.

    Parameters
    ----------
    key : jax.random.KeyArray
        PRNG key

    Returns
    -------
    dict
        Parameter PyTree
    """
    ...

log_likelihood_from_data

log_likelihood_from_data(
    params: dict, data: ResponseData
) -> ndarray

Compute log p(data | params).

Parameters:

Name Type Description Default
params dict

Model parameters

required
data ResponseData

Observed trials

required

Returns:

Type Description
ndarray

Log-likelihood (scalar)

Source code in src/psyphy/model/base.py
@abstractmethod
def log_likelihood_from_data(self, params: dict, data: ResponseData) -> jnp.ndarray:
    """
    Compute log p(data | params).

    Parameters
    ----------
    params : dict
        Model parameters
    data : ResponseData
        Observed trials

    Returns
    -------
    jnp.ndarray
        Log-likelihood (scalar)
    """
    ...

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name Type Description Default
X ndarray | None

Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.

None
probes ndarray | None

Test probes, shape (n_test, input_dim). Required for predictive posteriors.

None
kind ('predictive', 'parameter')

Type of posterior to return:

  • "predictive": PredictivePosterior over f(X*) [for acquisitions]
  • "parameter": ParameterPosterior over θ [for diagnostics]

"predictive"

Returns:

Type Description
PredictivePosterior | ParameterPosterior

Posterior distribution

Raises:

Type Description
RuntimeError

If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance

>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:

  • Threshold uncertainty estimation (evaluate at sampled parameters)
  • Parameter sensitivity analysis
  • Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name Type Description Default
X (ndarray, shape(n_test, input_dim))

Test stimuli (references)

required
probes (ndarray, shape(n_test, input_dim))

Probe stimuli (for discrimination tasks)

required
params dict[str, ndarray]

Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

required

Returns:

Name Type Description
predictions (ndarray, shape(n_test))

Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)

OddityTask

OddityTask(slope: float = 1.5)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task.

In an oddity task, the observer is presented with three stimuli: two identical references and one comparison. The task is to identify which stimulus is the "odd one out" (the comparison). Performance depends on the discriminability between reference and comparison.

This class provides two likelihood computation methods:

  1. Analytical approximation (MVP mode):
     • predict(): maps discriminability to P(correct) via tanh
     • loglik(): Bernoulli likelihood using analytical predictions
     • Fast, differentiable, suitable for gradient-based optimization

  2. Monte Carlo simulation (Full WPPM mode):
     • loglik_mc(): simulates the full 3-stimulus oddity task
     • Samples three internal representations per trial (z0, z1, z2)
     • Uses the proper oddity decision rule with three pairwise distances
     • More accurate for complex covariance structures
     • Suitable for validation and benchmarking

Parameters:

Name Type Description Default
slope float

Slope parameter for analytical tanh mapping in predict(). Controls steepness of discriminability -> performance relationship.

1.5

Attributes:

Name Type Description
chance_level float

Chance performance for oddity task (1/3)

performance_range float

Range from chance to perfect performance (2/3)

Notes

The analytical approximation in predict() uses: P(correct) = 1/3 + 2/3 * (1 + tanh(slope * d)) / 2

MC simulation in loglik_mc() (full 3-stimulus oddity):

  1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref) and z_comparison ~ N(comparison, Σ_comparison)
  2. Compute the average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
  3. Compute three pairwise Mahalanobis distances:
     • d^2(z_ref, z_refprime): distance between the two reference samples
     • d^2(z_ref, z_comparison): distance from ref to comparison
     • d^2(z_refprime, z_comparison): distance from reference_prime to comparison
  4. Apply the oddity decision rule: delta = min(d^2(z_ref, z_comparison), d^2(z_refprime, z_comparison)) - d^2(z_ref, z_refprime)
  5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
  6. Average over samples
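A quick numeric illustration of the analytical mapping (standalone sketch; slope and d values are arbitrary):

import jax.numpy as jnp

slope = 1.5
d = jnp.array([0.0, 0.5, 1.0, 2.0])  # illustrative discriminability values

# P(correct) = 1/3 + 2/3 * (1 + tanh(slope * d)) / 2
p_correct = 1.0 / 3.0 + (2.0 / 3.0) * 0.5 * (jnp.tanh(slope * d) + 1.0)
# At d = 0 this mapping gives 2/3 (the midpoint of [1/3, 1]); it approaches 1 for large d.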

Examples:

>>> from psyphy.model.task import OddityTask
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model
>>> task = OddityTask(slope=1.5)
>>> model = WPPM(
...     input_dim=2, prior=Prior(input_dim=2), task=task, noise=GaussianNoise()
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Analytical prediction
>>> ref = jnp.array([0.0, 0.0])
>>> comparison = jnp.array([0.5, 0.5])
>>> p_correct = task.predict(params, (ref, comparison), model, model.noise)
>>> print(f"P(correct) pprox {p_correct:.3f}")
>>>
>>> # MC simulation (more accurate)
>>> from psyphy.data.dataset import ResponseData
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = task.loglik_mc(
...     params, data, model, model.noise, num_samples=1000, key=jr.PRNGKey(42)
... )
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik

Compute log-likelihood using analytical predictions.

loglik_mc

Compute log-likelihood via Monte Carlo observer simulation.

predict

Predict probability of correct response using analytical approximation.

Source code in src/psyphy/model/task.py
def __init__(self, slope: float = 1.5) -> None:
    self.slope = float(slope)
    self.chance_level: float = 1.0 / 3.0
    self.performance_range: float = 1.0 - self.chance_level

chance_level

chance_level: float = 1.0 / 3.0

performance_range

performance_range: float = 1.0 - chance_level

slope

slope = float(slope)

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray

Compute log-likelihood using analytical predictions.

Parameters:

Name Type Description Default
params dict

Model parameters

required
data ResponseData

Trial data containing refs, comparisons, responses

required
model WPPM

Model instance

required
noise NoiseModel

Observer noise model

required

Returns:

Type Description
ndarray

Scalar sum of log-likelihoods over all trials

Notes

Uses Bernoulli log-likelihood:
    LL = Σ [y * log(p) + (1-y) * log(1-p)]
where p comes from predict() (analytical approximation).

Source code in src/psyphy/model/task.py
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    """
    Compute log-likelihood using analytical predictions.

    Parameters
    ----------
    params : dict
        Model parameters
    data : ResponseData
        Trial data containing refs, comparisons, responses
    model : WPPM
        Model instance
    noise : NoiseModel
        Observer noise model

    Returns
    -------
    jnp.ndarray
        Scalar sum of log-likelihoods over all trials

    Notes
    -----
    Uses Bernoulli log-likelihood:
        LL = Σ [y * log(p) + (1-y) * log(1-p)]
    where p comes from predict() (analytical approximation)
    """
    refs, comparisons, responses = data.to_numpy()
    ps = jnp.array(
        [
            self.predict(params, (r, p), model, noise)
            for r, p in zip(refs, comparisons)
        ]
    )
    eps = 1e-9
    return jnp.sum(
        jnp.where(responses == 1, jnp.log(ps + eps), jnp.log(1.0 - ps + eps))
    )

loglik_mc

loglik_mc(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    num_samples: int = 1000,
    bandwidth: float = 0.01,
    key: Any = None,
) -> ndarray

Compute log-likelihood via Monte Carlo observer simulation.

This method implements the full 3-stimulus oddity task. Instead of using an analytical approximation, we:

  1. Sample three internal noisy representations per trial:
     • z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
     • z_comparison ~ N(comparison, Σ_comparison)  [one sample from comparison]
  2. Compute three pairwise Mahalanobis distances
  3. Apply the oddity decision rule: the comparison is odd if it is farther from BOTH z_ref and z_refprime
  4. Apply logistic smoothing to approximate P(correct)
  5. Average over MC samples

Parameters:

Name Type Description Default
params dict

Model parameters (must contain 'W' for WPPM basis coefficients)

required
data ResponseData

Trial data with refs, comparisons, and responses

required
model WPPM

Model instance providing compute_U() for covariance computation

required
noise NoiseModel

Observer noise model (provides sigma for diagonal noise term)

required
num_samples int

Number of Monte Carlo samples per trial.
  • Use 1000-5000 for accurate likelihood estimation
  • Larger values reduce MC variance but increase compute time

1000
bandwidth float

Smoothing parameter for the logistic CDF approximation.
  • Smaller values -> sharper transition (closer to a step function)
  • Larger values -> smoother approximation
  • Typical range: [1e-3, 5e-2]

1e-2
key PRNGKey

Random key for reproducible sampling. If None, uses PRNGKey(0) (deterministic but not recommended for production)

None

Returns:

Type Description
ndarray

Scalar sum of log-likelihoods over all trials. Same shape and interpretation as loglik().

Raises:

Type Description
ValueError

If num_samples <= 0

Notes

Full 3-stimulus oddity task algorithm:

For each trial (ref, comparison, response):

  1. Compute covariances:
     • Σ_ref = U_ref @ U_ref.T + σ^2 I
     • Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
     • Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

  2. Sample three internal representations:
     • z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
     • z_comparison ~ N(comparison, Σ_comparison)  [1 sample from comparison, num_samples times]

  3. Compute three pairwise Mahalanobis distances:
     • d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
     • d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
     • d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

  4. Apply the oddity decision rule:
     • delta = min(d^2(z_ref, z_comparison), d^2(z_refprime, z_comparison)) - d^2(z_ref, z_refprime)
     • delta > 0 means the comparison is farther from BOTH ref and reference_prime -> correct identification

  5. Apply logistic smoothing:
     • P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

  6. Bernoulli log-likelihood:
     • LL = Σ [y * log(p) + (1-y) * log(1-p)]

Performance:

  • Time complexity: O(n_trials * num_samples * input_dim³)
  • Memory: O(num_samples * input_dim) per trial
  • Vectorized across trials using jax.vmap for GPU acceleration
  • Can be JIT-compiled for additional speed (future optimization)

Comparison to analytical:

  • MC implements the full 3-stimulus oddity (more realistic)
  • MC is more accurate for complex Σ(x) structures
  • Analytical is faster and differentiable
  • Use MC for validation and benchmarking, analytical for optimization

Examples:

>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.noise import GaussianNoise
>>> from psyphy.data.dataset import ResponseData
>>>
>>> # Setup
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),
...     task=OddityTask(),
...     noise=GaussianNoise(sigma=0.03),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Create trial data
>>> data = ResponseData()
>>> data.add_trial(
...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
... )
>>>
>>> # Compare analytical vs MC
>>> ll_analytical = model.task.loglik(params, data, model, model.noise)
>>> ll_mc = model.task.loglik_mc(
...     params,
...     data,
...     model,
...     model.noise,
...     num_samples=5000,
...     bandwidth=1e-3,
...     key=jr.PRNGKey(42),
... )
>>> print(f"Analytical: {ll_analytical:.4f}")
>>> print(f"MC (N=5000): {ll_mc:.4f}")
>>> print(f"Difference: {abs(ll_mc - ll_analytical):.4f}")
See Also

loglik : Analytical log-likelihood (faster, differentiable)
predict : Analytical prediction for single trial

Source code in src/psyphy/model/task.py
def loglik_mc(
    self,
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    num_samples: int = 1000,
    bandwidth: float = 1e-2,
    key: Any = None,
) -> jnp.ndarray:
    """
    Compute log-likelihood via Monte Carlo observer simulation.

    This method implements the FULL 3-stimulus oddity task. Instead of using
    an analytical approximation, we:
    1. Sample three internal noisy representations per trial:
       - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
       - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
    2. Compute three pairwise Mahalanobis distances
    3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
    4. Apply logistic smoothing to approximate P(correct)
    5. Average over MC samples

    Parameters
    ----------
    params : dict
        Model parameters (must contain 'W' for WPPM basis coefficients)
    data : ResponseData
        Trial data with refs, comparisons, and responses
    model : WPPM
        Model instance providing compute_U() for covariance computation
    noise : NoiseModel
        Observer noise model (provides sigma for diagonal noise term)
    num_samples : int, default=1000
        Number of Monte Carlo samples per trial.
        - Use 1000-5000 for accurate likelihood estimation
        - Larger values reduce MC variance but increase compute time
    bandwidth : float, default=1e-2
        Smoothing parameter for logistic CDF approximation.
        - Smaller values -> sharper transition (closer to step function)
        - Larger values -> smoother approximation
        - Typical range: [1e-3, 5e-2]
    key : jax.random.PRNGKey, optional
        Random key for reproducible sampling.
        If None, uses PRNGKey(0) (deterministic but not recommended for production)

    Returns
    -------
    jnp.ndarray
        Scalar sum of log-likelihoods over all trials.
        Same shape and interpretation as loglik().

    Raises
    ------
    ValueError
        If num_samples <= 0

    Notes
    -----
    **Full 3-stimulus oddity task algorithm:**

    For each trial (ref, comparison, response):
    1. Compute covariances:
       - Σ_ref = U_ref @ U_ref.T + σ^2 I
       - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
       - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

    2. Sample three internal representations:
       - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
       - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

    3. Compute three pairwise Mahalanobis distances:
       - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
       - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
       - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

    4. Apply oddity decision rule:
       - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
       - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

    5. Apply logistic smoothing:
       - P(correct) \approx mean(logistic.cdf(delta / bandwidth))

    6. Bernoulli log-likelihood:
       - LL = Σ [y * log(p) + (1-y) * log(1-p)]

    Performance:
    - Time complexity: O(n_trials * num_samples * input_dim³)
    - Memory: O(num_samples * input_dim) per trial
    - Vectorized across trials using jax.vmap for GPU acceleration
    - Can be JIT-compiled for additional speed (future optimization)

    Comparison to analytical:
    - MC implements full 3-stimulus oddity (more realistic)
    - MC is more accurate for complex Σ(x) structures
    - Analytical is faster and differentiable
    - Use MC for validation and benchmarking, analytical for optimization

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> from psyphy.model import WPPM, Prior
    >>> from psyphy.model.task import OddityTask
    >>> from psyphy.model.noise import GaussianNoise
    >>> from psyphy.data.dataset import ResponseData
    >>>
    >>> # Setup
    >>> model = WPPM(
    ...     input_dim=2,
    ...     prior=Prior(input_dim=2, basis_degree=3),
    ...     task=OddityTask(),
    ...     noise=GaussianNoise(sigma=0.03),
    ... )
    >>> params = model.init_params(jr.PRNGKey(0))
    >>>
    >>> # Create trial data
    >>> data = ResponseData()
    >>> data.add_trial(
    ...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
    ... )
    >>>
    >>> # Compare analytical vs MC
    >>> ll_analytical = model.task.loglik(params, data, model, model.noise)
    >>> ll_mc = model.task.loglik_mc(
    ...     params,
    ...     data,
    ...     model,
    ...     model.noise,
    ...     num_samples=5000,
    ...     bandwidth=1e-3,
    ...     key=jr.PRNGKey(42),
    ... )
    >>> print(f"Analytical: {ll_analytical:.4f}")
    >>> print(f"MC (N=5000): {ll_mc:.4f}")
    >>> print(f"Difference: {abs(ll_mc - ll_analytical):.4f}")

    See Also
    --------
    loglik : Analytical log-likelihood (faster, differentiable)
    predict : Analytical prediction for single trial



    """
    # Validate inputs
    if num_samples <= 0:
        raise ValueError(f"num_samples must be > 0, got {num_samples}")

    # Default key for reproducibility (deterministic; not recommended for production)
    if key is None:
        key = jr.PRNGKey(0)

    # Unpack trial data
    refs, comparisons, responses = data.to_numpy()
    n_trials = len(refs)

    # Split keys for each trial (ensures independent sampling)
    trial_keys = jr.split(key, n_trials)

    # Vectorized computation of P(correct) for all trials
    # This processes all trials in parallel using jax.vmap
    # Note: probabilities are already clipped in _simulate_trial_mc()
    probs = self._simulate_trials_mc_vectorized(
        params=params,
        refs=refs,
        comparisons=comparisons,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        trial_keys=trial_keys,
    )

    # Bernoulli log-likelihood: LL = Σ [y log(p) + (1-y) log(1-p)]
    # Probabilities are already clipped to [eps, 1-eps] so log is safe
    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),  # Correct response
        jnp.log(1.0 - probs),  # Incorrect response
    )

    return jnp.sum(log_likelihoods)

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response using analytical approximation.

Parameters:

Name Type Description Default
params dict

Model parameters (e.g., W for WPPM)

required
stimuli tuple[ndarray, ndarray]

(reference, comparison) stimulus pair

required
model WPPM

Model instance providing discriminability()

required
noise NoiseModel

Observer noise model (currently unused in analytical version)

required

Returns:

Type Description
ndarray

Scalar probability of correct response, in range [1/3, 1]

Notes

Uses the tanh mapping: P(correct) = 1/3 + 2/3 * (1 + tanh(slope * d)) / 2, where d is the discriminability from model.discriminability().

Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """
    Predict probability of correct response using analytical approximation.

    Parameters
    ----------
    params : dict
        Model parameters (e.g., W for WPPM)
    stimuli : tuple[jnp.ndarray, jnp.ndarray]
        (reference, comparison) stimulus pair
    model : WPPM
        Model instance providing discriminability()
    noise : NoiseModel
        Observer noise model (currently unused in analytical version)

    Returns
    -------
    jnp.ndarray
        Scalar probability of correct response, in range [1/3, 1]

    Notes
    -----
    Uses tanh mapping: P(correct) = 1/3 + 2/3 * (1 + tanh(slope * d)) / 2
    where d is discriminability from model.discriminability()
    """
    d = model.discriminability(params, stimuli)
    g = 0.5 * (jnp.tanh(self.slope * d) + 1.0)
    return self.chance_level + self.performance_range * g

OnlineConfig

OnlineConfig(
    strategy: Literal[
        "full", "sliding_window", "reservoir", "none"
    ] = "full",
    window_size: int | None = None,
    refit_interval: int = 1,
    warm_start: bool = True,
)

Configuration for online learning and memory management.

Attributes:

Name Type Description
strategy {'full', 'sliding_window', 'reservoir', 'none'}

Data retention strategy:

  • "full": Keep all data (unbounded memory)
  • "sliding_window": Keep only last N trials (FIFO)
  • "reservoir": Reservoir sampling for uniform coverage
  • "none": No caching, refit from scratch each time

window_size int | None

Maximum number of trials to retain (for sliding_window/reservoir). Required for sliding_window and reservoir strategies.

refit_interval int

Refit model every N updates (1=always, 10=batch every 10 trials). Trades off accuracy vs. computational cost.

warm_start bool

If True, initialize refitting from cached parameters. Speeds up convergence for small updates.

Examples:

>>> # Unbounded memory (default)
>>> config = OnlineConfig(strategy="full")

>>> # Sliding window: keep last 10K trials
>>> config = OnlineConfig(
...     strategy="sliding_window",
...     window_size=10_000,
...     refit_interval=10,  # Batch every 10 trials
... )

>>> # Reservoir sampling: uniform coverage with 5K trials
>>> config = OnlineConfig(
...     strategy="reservoir",
...     window_size=5_000,
... )

refit_interval

refit_interval: int = 1

strategy

strategy: Literal[
    "full", "sliding_window", "reservoir", "none"
] = "full"

warm_start

warm_start: bool = True

window_size

window_size: int | None = None

Prior

Prior(
    input_dim: int,
    scale: float = 0.5,
    basis_degree: int | None = None,
    variance_scale: float = 1.0,
    decay_rate: float = 0.5,
    lengthscale: float = 1.0,
    extra_embedding_dims: int = 0,
)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

required
scale float

Stddev of Gaussian prior for log_diag entries (MVP only).

0.5
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If None, uses MVP mode with log_diag parameters. If set, uses Wishart mode with W coefficients.

None
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

1.0
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate → stronger smoothness prior.

0.5
lengthscale float

Alias for decay_rate (kept for backward compatibility). If both specified, decay_rate takes precedence.

1.0
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

0

Methods:

Name Description
default

Convenience constructor with MVP defaults.

log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int | None
decay_rate float
extra_embedding_dims int
input_dim int
lengthscale float
scale float
variance_scale float

basis_degree

basis_degree: int | None = None

decay_rate

decay_rate: float = 0.5

extra_embedding_dims

extra_embedding_dims: int = 0

input_dim

input_dim: int

lengthscale

lengthscale: float = 1.0

scale

scale: float = 0.5

variance_scale

variance_scale: float = 1.0

default

default(input_dim: int, scale: float = 0.5) -> Prior

Convenience constructor with MVP defaults.

Source code in src/psyphy/model/prior.py
@classmethod
def default(cls, input_dim: int, scale: float = 0.5) -> Prior:
    """Convenience constructor with MVP defaults."""
    return cls(input_dim=input_dim, scale=scale)

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

MVP mode:
    Isotropic Gaussian on log_diag

Wishart mode:
    Gaussian prior on W with smoothness via decay_rate:
    log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2), where σ_ij^2 is the prior variance

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    MVP mode:
        Isotropic Gaussian on log_diag

    Wishart mode:
        Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """
    if "log_diag" in params:
        # MVP mode
        log_diag = params["log_diag"]
        var = self.scale**2
        return -0.5 * jnp.sum((log_diag**2) / var)

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))

    raise ValueError("params must contain either 'log_diag' (MVP) or 'W' (Wishart)")

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

MVP mode (basis_degree=None):
    Returns {"log_diag": shape (input_dim,)}

Wishart mode (basis_degree set):
    Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
    for 2D, where embedding_dim = input_dim + extra_embedding_dims.

    Note: The 3rd dimension is input_dim (output space dimension).
    This matches the einsum in _compute_U:
    U = einsum("ijde,ij->de", W, phi), where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.

    MVP mode (basis_degree=None):
        Returns {"log_diag": shape (input_dim,)}

    Wishart mode (basis_degree set):
        Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
        for 2D, where embedding_dim = input_dim + extra_embedding_dims

        Note: The 3rd dimension is input_dim (output space dimension).
        This matches the einsum in _compute_U:
        U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        # MVP mode: simple diagonal covariance
        log_diag = jr.normal(key, shape=(self.input_dim,)) * self.scale
        return {"log_diag": log_diag}

    # Wishart mode: basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float

Log-probability of a residual (MVP placeholder; ignores the residual and returns a constant).
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods

Methods:

Name Description
loglik

Compute log-likelihood of observed responses under this task

predict

Predict probability of correct response for a stimulus.

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray

Compute log-likelihood of observed responses under this task

Source code in src/psyphy/model/task.py
@abstractmethod
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    """Compute log-likelihood of observed responses under this task"""
    ...

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response for a stimulus.

Source code in src/psyphy/model/task.py
@abstractmethod
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict probability of correct response for a stimulus."""
    ...

TwoAFC

TwoAFC(slope: float = 2.0)

Bases: TaskLikelihood

2-alternative forced-choice task (MVP placeholder).

Methods:

Name Description
loglik
predict

Attributes:

Name Type Description
chance_level float
performance_range float
slope
Source code in src/psyphy/model/task.py
def __init__(self, slope: float = 2.0) -> None:
    self.slope = float(slope)
    self.chance_level: float = 0.5
    self.performance_range: float = 1.0 - self.chance_level

chance_level

chance_level: float = 0.5

performance_range

performance_range: float = 1.0 - chance_level

slope

slope = float(slope)

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray
Source code in src/psyphy/model/task.py
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    refs, comparisons, responses = data.to_numpy()
    ps = jnp.array(
        [
            self.predict(params, (r, p), model, noise)
            for r, p in zip(refs, comparisons)
        ]
    )
    eps = 1e-9
    return jnp.sum(
        jnp.where(responses == 1, jnp.log(ps + eps), jnp.log(1.0 - ps + eps))
    )

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray
Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    d = model.discriminability(params, stimuli)
    return self.chance_level + self.performance_range * jnp.tanh(self.slope * d)

WPPM

WPPM(
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    lengthscale: float = 1.0,
    diag_term: float = 1e-06,
    **kwargs,
)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and probe live in R^{input_dim}.

required
prior Prior

Prior distribution over model parameters. Controls basis_degree for Wishart mode (basis expansion) vs MVP mode (diagonal covariance). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
task TaskLikelihood

Psychophysical task mapping that defines how discriminability translates to p(correct) and how log-likelihood of responses is computed. (e.g., OddityTask, TwoAFC)

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise). Not used in MVP mapping but passed to the task interface for future MC sims.

None
Forward-compatible hyperparameters

extra_dims : int, default=0
    Additional embedding dimensions for basis expansions (beyond input_dim). In Wishart mode, embedding_dim = input_dim + extra_dims.
variance_scale : float, default=1.0
    Global scaling factor for covariance magnitude (unused in MVP).
lengthscale : float, default=1.0
    Smoothness/length-scale for spatial covariance variation (unused in MVP; formerly "decay_rate").
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability. MVP uses this in matrix solves; the research model will also use it.
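For reference, a minimal construction sketch exercising these keyword-only hyperparameters (all values illustrative):

from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise

model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=3),  # Wishart mode via the prior
    task=OddityTask(slope=1.5),
    noise=GaussianNoise(sigma=0.03),
    extra_dims=1,      # embedding_dim = input_dim + extra_dims
    diag_term=1e-6,    # added to the covariance diagonal for stability
)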

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

discriminability

Compute scalar discriminability d >= 0 for a (reference, probe) pair

fit

Fit model to data.

init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood

Compute the log-likelihood for arrays of trials.

log_likelihood_from_data

Compute log-likelihood directly from a ResponseData object.

log_posterior_from_data

Convenience helper if you want log posterior in one call (MVP).

posterior

Return posterior distribution.

predict_prob

Predict probability of a correct response for a single stimulus.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

diag_term
embedding_dim int

Dimension of the embedding space (perceptual space).

extra_dims
input_dim
lengthscale
noise
online_config
prior
task
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    lengthscale: float = 1.0,
    diag_term: float = 1e-6,
    **kwargs,  # Accept online_config from model base
) -> None:
    # Initialize Model base class
    super().__init__(**kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree
    self.task = task  # task mapping and likelihood
    self.noise = noise  # noise model

    # --- forward-compatible hyperparameters (stubs in MVP) ---
    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.lengthscale = float(lengthscale)
    self.diag_term = float(diag_term)

basis_degree

basis_degree: int | None

Chebyshev polynomial degree for Wishart process basis expansion.

This property delegates to self.prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

Returns:

Type Description
int | None

Degree of Chebyshev polynomial basis (0 = constant, 1 = linear, etc.) None indicates MVP mode (no basis expansion)

Notes

WPPM gets its basis_degree parameter from Prior.basis_degree.

diag_term

diag_term = float(diag_term)

embedding_dim

embedding_dim: int

Dimension of the embedding space (perceptual space).

embedding_dim = input_dim + extra_dims. This represents the full perceptual space, where:

  • the first input_dim dimensions correspond to observable stimulus features
  • the remaining extra_dims are latent dimensions

Returns:

Type Description
int

input_dim + extra_dims (in Wishart mode) input_dim (in MVP mode, extra_dims ignored)

Notes

This is a computed property, not a constructor parameter.

extra_dims

extra_dims = int(extra_dims)

input_dim

input_dim = int(input_dim)

lengthscale

lengthscale = float(lengthscale)

noise

noise = noise

online_config

online_config = online_config or OnlineConfig()

prior

prior = prior

task

task = task

variance_scale

variance_scale = float(variance_scale)

condition_on_observations

condition_on_observations(X: ndarray, y: ndarray) -> Model

Update model with new observations (online learning).

Behavior depends on self.online_config.strategy:

  • "full": Accumulate all data, refit periodically
  • "sliding_window": Keep only recent window_size trials
  • "reservoir": Random sampling of window_size trials
  • "none": Refit from scratch (no caching)

Returns a NEW model instance (immutable update).

Parameters:

Name Type Description Default
X ndarray

New stimuli

required
y ndarray

New responses

required

Returns:

Type Description
Model

Updated model (new instance)

Examples:

>>> # Online learning loop
>>> model = WPPM(...).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)
...     # Model automatically manages memory and refitting
Source code in src/psyphy/model/base.py
def condition_on_observations(self, X: jnp.ndarray, y: jnp.ndarray) -> Model:
    """
    Update model with new observations (online learning).

    Behavior depends on self.online_config.strategy:
    - "full": Accumulate all data, refit periodically
    - "sliding_window": Keep only recent window_size trials
    - "reservoir": Random sampling of window_size trials
    - "none": Refit from scratch (no caching)

    Returns a NEW model instance (immutable update).

    Parameters
    ----------
    X : jnp.ndarray
        New stimuli
    y : jnp.ndarray
        New responses

    Returns
    -------
    Model
        Updated model (new instance)

    Examples
    --------
    >>> # Online learning loop
    >>> model = WPPM(...).fit(X_init, y_init)
    >>> for X_new, y_new in stream:
    ...     model = model.condition_on_observations(X_new, y_new)
    ...     # Model automatically manages memory and refitting
    """
    from psyphy.data import ResponseData

    # Convert new data
    new_data = ResponseData.from_arrays(X, y)

    # Update data buffer according to strategy
    if self.online_config.strategy == "none":
        data_to_fit = new_data

    elif self.online_config.strategy == "full":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "sliding_window":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()
        self._data_buffer.merge(new_data)

        # Keep only last N trials
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for sliding_window"
        if len(self._data_buffer) > window_size:
            self._data_buffer = self._data_buffer.tail(window_size)
        data_to_fit = self._data_buffer

    elif self.online_config.strategy == "reservoir":
        if self._data_buffer is None:
            self._data_buffer = ResponseData()

        # Reservoir sampling
        window_size = self.online_config.window_size
        assert window_size is not None, "window_size must be set for reservoir"
        self._data_buffer = self._reservoir_update(
            self._data_buffer,
            new_data,
            window_size,
        )
        data_to_fit = self._data_buffer
    else:
        raise ValueError(f"Unknown strategy: {self.online_config.strategy}")

    # Decide whether to refit
    self._n_updates += 1
    should_refit = self._n_updates % self.online_config.refit_interval == 0

    if not should_refit:
        # Return clone with updated buffer but old posterior
        new_model = self._clone()
        new_model._data_buffer = data_to_fit
        return new_model

    # Refit with optional warm start
    inference = self._inference_engine
    assert inference is not None, (
        "Model must be fit before condition_on_observations"
    )

    if self.online_config.warm_start and self._posterior is not None:
        # TODO: Add warm-start support to inference engines
        # For now, just refit from cached params
        pass

    new_model = self._clone()
    new_model._data_buffer = data_to_fit
    new_model._posterior = inference.fit(new_model, data_to_fit)
    new_model._inference_engine = inference

    return new_model
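
The strategies above are selected via OnlineConfig. A hedged sketch: the field names strategy, window_size, and refit_interval match the attributes read in the source, but the keyword constructor and the free variables X_init, y_init, and stream are assumptions:

>>> from psyphy.model import WPPM, Prior, OddityTask, OnlineConfig
>>> cfg = OnlineConfig(strategy="sliding_window", window_size=200, refit_interval=5)
>>> model = WPPM(input_dim=2, prior=Prior.default(2), task=OddityTask(),
...              online_config=cfg).fit(X_init, y_init)
>>> for X_new, y_new in stream:
...     model = model.condition_on_observations(X_new, y_new)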

discriminability

discriminability(
    params: Params, stimulus: Stimulus
) -> ndarray

Compute scalar discriminability d >= 0 for a (reference, probe) pair

MVP mode:
    d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
with Σ(ref) the local covariance at the reference in stimulus space.

Wishart mode (rectangular U design, extra_dims > 0):
    d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
where Σ(ref) is computed directly in stimulus space (input_dim, input_dim) via U(x) @ U(x)^T with U rectangular.

The discrimination task only depends on observable stimulus dimensions. The rectangular U design means local_covariance() already returns the stimulus covariance - no block extraction needed.

Future (full WPPM mode): d is implicit via Monte Carlo simulation of internal noisy responses under the task's decision rule (no closed form). In that case, tasks will directly implement predict/loglik with MC, and this method may be used only for diagnostics.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
stimulus tuple

(reference, probe) arrays of shape (input_dim,).

required

Returns:

Name Type Description
d ndarray

Nonnegative scalar discriminability.

Source code in src/psyphy/model/wppm.py
def discriminability(self, params: Params, stimulus: Stimulus) -> jnp.ndarray:
    """
    Compute scalar discriminability d >= 0 for a (reference, probe) pair

    MVP mode:
        d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
        with Σ(ref) the local covariance at the reference in stimulus space.

    Wishart mode (rectangular U design) if extra_dims > 0:
        d = sqrt( (probe - ref)^T Σ(ref)^{-1} (probe - ref) )
        where Σ(ref) is directly computed in stimulus space (input_dim, input_dim)
        via U(x) @ U(x)^T with U rectangular.

    The discrimination task only depends on observable stimulus dimensions.
    The rectangular U design means local_covariance() already returns
    the stimulus covariance - no block extraction needed.

    Future (full WPPM mode):
        d is implicit via Monte Carlo simulation of internal noisy responses
        under the task's decision rule (no closed form). In that case, tasks
        will directly implement predict/loglik with MC, and this method may be
        used only for diagnostics.

    Parameters
    ----------
    params : dict
        Model parameters.
    stimulus : tuple
        (reference, probe) arrays of shape (input_dim,).

    Returns
    -------
    d : jnp.ndarray
        Nonnegative scalar discriminability.
    """
    ref, probe = stimulus

    # Delta is in stimulus space (input_dim)
    delta = probe - ref

    # Get stimulus covariance at reference
    # (rectangular U design: already returns (input_dim, input_dim))
    Sigma = self.local_covariance(params, ref)

    # Add jitter for stable solve; diag_term is configurable
    jitter = self.diag_term * jnp.eye(self.input_dim)

    # Solve (Σ + jitter)^{-1} delta using a PD-aware solver
    x = jax.scipy.linalg.solve(Sigma + jitter, delta, assume_a="pos")
    d2 = jnp.dot(delta, x)  # quadratic form

    # Guard against tiny negative values from numerical error
    return jnp.sqrt(jnp.maximum(d2, 0.0))
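
To make the quadratic form concrete: with Σ = diag(0.25, 1.0) at the reference and a probe offset of (0.5, 0), d = sqrt(0.5^2 / 0.25) = 1.0 up to the diag_term jitter. A hedged sketch, again assuming OddityTask() takes no required arguments:

>>> import jax.numpy as jnp
>>> from psyphy.model import WPPM, Prior, OddityTask
>>> model = WPPM(input_dim=2, prior=Prior.default(2), task=OddityTask())
>>> params = {"log_diag": jnp.log(jnp.array([0.25, 1.0]))}
>>> ref, probe = jnp.zeros(2), jnp.array([0.5, 0.0])
>>> d = model.discriminability(params, (ref, probe))  # ≈ 1.0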

fit

fit(
    X: ndarray,
    y: ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model

Fit model to data.

Parameters:

Name Type Description Default
X ndarray

Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs or (n_trials, input_dim) for references only

required
y ndarray

Responses, shape (n_trials,)

required
inference InferenceEngine | str

Inference engine or string key ("map", "laplace", "langevin")

"laplace"
inference_config dict | None

Hyperparameters for string-based inference. Examples: {"steps": 500, "lr": 1e-3} for MAP

None

Returns:

Type Description
Model

Self for method chaining

Examples:

>>> # Simple: use defaults
>>> model.fit(X, y)
>>> # Explicit optimizer
>>> from psyphy.inference import MAPOptimizer
>>> model.fit(X, y, inference=MAPOptimizer(steps=500))
>>> # String + config (for experiment tracking)
>>> model.fit(X, y, inference="map", inference_config={"steps": 500})
Source code in src/psyphy/model/base.py
def fit(
    self,
    X: jnp.ndarray,
    y: jnp.ndarray,
    *,
    inference: InferenceEngine | str = "laplace",
    inference_config: dict | None = None,
) -> Model:
    """
    Fit model to data.

    Parameters
    ----------
    X : jnp.ndarray
        Stimuli, shape (n_trials, 2, input_dim) for (ref, probe) pairs
        or (n_trials, input_dim) for references only
    y : jnp.ndarray
        Responses, shape (n_trials,)
    inference : InferenceEngine | str, default="laplace"
        Inference engine or string key ("map", "laplace", "langevin")
    inference_config : dict | None
        Hyperparameters for string-based inference.
        Examples: {"steps": 500, "lr": 1e-3} for MAP

    Returns
    -------
    Model
        Self for method chaining

    Examples
    --------
    >>> # Simple: use defaults
    >>> model.fit(X, y)

    >>> # Explicit optimizer
    >>> from psyphy.inference import MAPOptimizer
    >>> model.fit(X, y, inference=MAPOptimizer(steps=500))

    >>> # String + config (for experiment tracking)
    >>> model.fit(X, y, inference="map", inference_config={"steps": 500})
    """
    from psyphy.data import ResponseData
    from psyphy.inference import INFERENCE_ENGINES, InferenceEngine

    # Resolve inference engine
    is_string_inference = isinstance(inference, str)

    if is_string_inference:
        config = inference_config or {}
        inference_key: str = inference  # type: ignore[assignment]
        if inference_key not in INFERENCE_ENGINES:
            available = ", ".join(INFERENCE_ENGINES.keys())
            raise ValueError(
                f"Unknown inference: '{inference}'. Available: {available}"
            )
        inference_engine: InferenceEngine = INFERENCE_ENGINES[inference_key](
            **config
        )
    elif isinstance(inference, InferenceEngine):
        inference_engine = inference
    else:
        raise TypeError(
            f"inference must be InferenceEngine or str, got {type(inference)}"
        )

    if inference_config is not None and not is_string_inference:
        raise ValueError(
            "Cannot pass inference_config with InferenceEngine instance"
        )

    # Convert data
    data = ResponseData.from_arrays(X, y)

    # Fit
    self._posterior = inference_engine.fit(self, data)
    self._inference_engine = inference_engine
    self._data_buffer = data  # Initialize buffer
    return self

init_params

init_params(key: KeyArray) -> Params

Sample initial parameters from the prior.

MVP parameters: {"log_diag": shape (input_dim,)} which defines a constant diagonal covariance across the space.

Returns:

Name Type Description
params dict[str, ndarray]
Source code in src/psyphy/model/wppm.py
def init_params(self, key: jr.KeyArray) -> Params:
    """
    Sample initial parameters from the prior.

    MVP parameters:
        {"log_diag": shape (input_dim,)}
    which defines a constant diagonal covariance across the space.

    Returns
    -------
    params : dict[str, jnp.ndarray]
    """
    return self.prior.sample_params(key)
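
A small usage sketch (MVP mode, with Prior.default as documented below):

>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior, OddityTask
>>> model = WPPM(input_dim=2, prior=Prior.default(2), task=OddityTask())
>>> params = model.init_params(jr.PRNGKey(0))
>>> params["log_diag"].shape
(2,)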

local_covariance

local_covariance(params: Params, x: ndarray) -> ndarray

Return local covariance Σ(x) at stimulus location x.

MVP mode (basis_degree=None):
    Σ(x) = diag(exp(log_diag)), constant across x.
    - Positive-definite because exp(log_diag) > 0.

Wishart mode (basis_degree set):
    Σ(x) = U(x) @ U(x)^T + diag_term * I
    where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
    - Varies smoothly with x
    - Guaranteed positive-definite
    - Returns the stimulus covariance directly (input_dim, input_dim)

Parameters:

Name Type Description Default
params dict

Model parameters:
- MVP: {"log_diag": (input_dim,)}
- Wishart: {"W": (degree+1, ..., input_dim, embedding_dim)}

required
x (ndarray, shape(input_dim))

Stimulus location

required

Returns:

Type Description
Σ : jnp.ndarray, shape (input_dim, input_dim)

Covariance matrix in stimulus space.

Source code in src/psyphy/model/wppm.py
def local_covariance(self, params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """
    Return local covariance Σ(x) at stimulus location x.

    MVP mode (basis_degree=None):
        Σ(x) = diag(exp(log_diag)), constant across x.
        - Positive-definite because exp(log_diag) > 0.

    Wishart mode (basis_degree set):
        Σ(x) = U(x) @ U(x)^T + diag_term * I
        where U(x) is rectangular (input_dim, embedding_dim) if extra_dims > 0.
        - Varies smoothly with x
        - Guaranteed positive-definite
        - Returns stimulus covariance directly (input_dim, input_dim)

    Parameters
    ----------
    params : dict
        Model parameters:
        - MVP: {"log_diag": (input_dim,)}
        - Wishart: {"W": (degree+1, ..., input_dim, embedding_dim)}
    x : jnp.ndarray, shape (input_dim,)
        Stimulus location

    Returns
    -------
    Σ : jnp.ndarray, shape (input_dim, input_dim)
        Covariance matrix in stimulus space.
    """
    # MVP mode: constant diagonal covariance
    if "log_diag" in params:
        log_diag = params["log_diag"]
        diag = jnp.exp(log_diag)
        return jnp.diag(diag)

    # Wishart mode: spatially-varying covariance
    if "W" in params:
        U = self._compute_U(params, x)  # (input_dim, embedding_dim)
        # Σ(x) = U(x) @ U(x)^T + diag_term * I
        # Result is (input_dim, input_dim)
        Sigma = U @ U.T + self.diag_term * jnp.eye(self.input_dim)
        return Sigma

    raise ValueError("params must contain either 'log_diag' (MVP) or 'W' (Wishart)")
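
In MVP mode the returned Σ is just the exponentiated diagonal, constant in x. A hedged sketch reusing the toy model from init_params above:

>>> import jax.numpy as jnp
>>> params = {"log_diag": jnp.log(jnp.array([1.0, 4.0]))}
>>> Sigma = model.local_covariance(params, x=jnp.array([0.2, 0.7]))
>>> # Sigma ≈ diag(1.0, 4.0), independent of x in MVP mode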

log_likelihood

log_likelihood(
    params: Params,
    refs: ndarray,
    probes: ndarray,
    responses: ndarray,
) -> ndarray

Compute the log-likelihood for arrays of trials.

IMPORTANT: We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MVP) or MC likelihood logic in multiple places.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
refs (ndarray, shape(N, input_dim))
required
probes (ndarray, shape(N, input_dim))
required
responses (ndarray, shape(N))

Typically 0/1; task may support richer encodings.

required

Returns:

Name Type Description
loglik ndarray

Scalar log-likelihood (task-only; add prior outside if needed)

Source code in src/psyphy/model/wppm.py
def log_likelihood(
    self,
    params: Params,
    refs: jnp.ndarray,
    probes: jnp.ndarray,
    responses: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the log-likelihood for arrays of trials.

    IMPORTANT:
        We delegate to the TaskLikelihood to avoid duplicating Bernoulli (MVP)
        or MC likelihood logic in multiple places.

    Parameters
    ----------
    params : dict
        Model parameters.
    refs : jnp.ndarray, shape (N, input_dim)
    probes : jnp.ndarray, shape (N, input_dim)
    responses : jnp.ndarray, shape (N,)
        Typically 0/1; task may support richer encodings.

    Returns
    -------
    loglik : jnp.ndarray
        Scalar log-likelihood (task-only; add prior outside if needed)
    """
    # We need a ResponseData-like object. To keep this method usable from
    # array inputs, we construct one on the fly. If you already have a
    # ResponseData instance, prefer `log_likelihood_from_data`.
    from psyphy.data.dataset import ResponseData  # local import to avoid cycles

    data = ResponseData()
    # ResponseData.add_trial(ref, probe, resp)
    for r, p, y in zip(refs, probes, responses):
        data.add_trial(r, p, int(y))
    return self.task.loglik(params, data, self, self.noise)
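
A hedged sketch with three toy trials in a 2-D space (model and params as in the sketches above); the result is the task log-likelihood only:

>>> import jax.numpy as jnp
>>> refs = jnp.zeros((3, 2))
>>> probes = jnp.array([[0.1, 0.0], [0.0, 0.2], [0.3, 0.3]])
>>> responses = jnp.array([1, 0, 1])
>>> ll = model.log_likelihood(params, refs, probes, responses)  # scalar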

log_likelihood_from_data

log_likelihood_from_data(
    params: Params, data: Any
) -> ndarray

Compute log-likelihood directly from a ResponseData object.

Why delegate to the task?
- The task knows the decision rule (oddity, 2AFC, ...).
- The task can use the model (this WPPM) to fetch discriminabilities.
- The task can use the noise model if it needs MC simulation.

Parameters:

Name Type Description Default
params dict

Model parameters.

required
data ResponseData

Collected trial data.

required

Returns:

Name Type Description
loglik ndarray

scalar log-likelihood (task-only; add prior outside if needed)

Source code in src/psyphy/model/wppm.py
def log_likelihood_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """
    Compute log-likelihood directly from a ResponseData object.

    Why delegate to the task?
        - The task knows the decision rule (oddity, 2AFC, ...).
        - The task can use the model (this WPPM) to fetch discriminabilities
        - and the task can use the noise model if it needs MC simulation

    Parameters
    ----------
    params : dict
        Model parameters.
    data : ResponseData
        Collected trial data.

    Returns
    -------
    loglik : jnp.ndarray
        scalar log-likelihood (task-only; add prior outside if needed)
    """
    return self.task.loglik(params, data, self, self.noise)

log_posterior_from_data

log_posterior_from_data(
    params: Params, data: Any
) -> ndarray

Convenience helper if you want log posterior in one call (MVP).

This simply adds the prior log-probability to the task log-likelihood. Inference engines (e.g., MAP optimizer) typically optimize this quantity.

Returns:

Type Description
jnp.ndarray : scalar log posterior = loglik(params | data) + log_prior(params)
Source code in src/psyphy/model/wppm.py
def log_posterior_from_data(self, params: Params, data: Any) -> jnp.ndarray:
    """
    Convenience helper if you want log posterior in one call (MVP).

    This simply adds the prior log-probability to the task log-likelihood.
    Inference engines (e.g., MAP optimizer) typically optimize this quantity.

    Returns
    -------
    jnp.ndarray : scalar log posterior = loglik(params | data) + log_prior(params)
    """
    return self.log_likelihood_from_data(params, data) + self.prior.log_prob(params)
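
As a hedged illustration of how an inference engine might use this quantity, here is a minimal MAP-style loop with JAX and Optax; params from init_params and a populated ResponseData instance named data are assumed, and the optimizer choice is arbitrary:

>>> import jax
>>> import optax
>>> def neg_log_post(params):
...     # MAP objective: minimize the negative log posterior
...     return -model.log_posterior_from_data(params, data)
>>> opt = optax.adam(1e-2)
>>> state = opt.init(params)
>>> for _ in range(100):
...     grads = jax.grad(neg_log_post)(params)
...     updates, state = opt.update(grads, state)
...     params = optax.apply_updates(params, updates)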

posterior

posterior(
    X: ndarray | None = None,
    *,
    probes: ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior

Return posterior distribution.

Parameters:

Name Type Description Default
X ndarray | None

Test stimuli (references), shape (n_test, input_dim). Required for predictive posteriors, optional for parameter posteriors.

None
probes ndarray | None

Test probes, shape (n_test, input_dim). Required for predictive posteriors.

None
kind ('predictive', 'parameter')

Type of posterior to return:
- "predictive": PredictivePosterior over f(X*) [for acquisitions]
- "parameter": ParameterPosterior over θ [for diagnostics]

"predictive"

Returns:

Type Description
PredictivePosterior | ParameterPosterior

Posterior distribution

Raises:

Type Description
RuntimeError

If model has not been fit yet

Examples:

>>> # For acquisition functions
>>> pred_post = model.posterior(X_candidates, probes=X_probes)
>>> mean = pred_post.mean
>>> var = pred_post.variance
>>> # For diagnostics
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(42))
Source code in src/psyphy/model/base.py
def posterior(
    self,
    X: jnp.ndarray | None = None,
    *,
    probes: jnp.ndarray | None = None,
    kind: str = "predictive",
) -> PredictivePosterior | ParameterPosterior:
    """
    Return posterior distribution.

    Parameters
    ----------
    X : jnp.ndarray | None
        Test stimuli (references), shape (n_test, input_dim).
        Required for predictive posteriors, optional for parameter posteriors.
    probes : jnp.ndarray | None
        Test probes, shape (n_test, input_dim).
        Required for predictive posteriors.
    kind : {"predictive", "parameter"}
        Type of posterior to return:
        - "predictive": PredictivePosterior over f(X*) [for acquisitions]
        - "parameter": ParameterPosterior over θ [for diagnostics]

    Returns
    -------
    PredictivePosterior | ParameterPosterior
        Posterior distribution

    Raises
    ------
    RuntimeError
        If model has not been fit yet

    Examples
    --------
    >>> # For acquisition functions
    >>> pred_post = model.posterior(X_candidates, probes=X_probes)
    >>> mean = pred_post.mean
    >>> var = pred_post.variance

    >>> # For diagnostics
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(42))
    """
    if self._posterior is None:
        raise RuntimeError("Must call fit() before posterior()")

    if kind == "parameter":
        return self._posterior
    elif kind == "predictive":
        if X is None:
            raise ValueError("X is required for predictive posteriors")
        from psyphy.posterior import WPPMPredictivePosterior

        return WPPMPredictivePosterior(self._posterior, X, probes=probes)
    else:
        raise ValueError(
            f"Unknown kind: '{kind}'. Use 'predictive' or 'parameter'."
        )

predict_prob

predict_prob(params: Params, stimulus: Stimulus) -> ndarray

Predict probability of a correct response for a single stimulus.

Design choice: WPPM computes discriminability & covariance; the TASK defines how that translates to performance. We therefore delegate to:
    task.predict(params, stimulus, model=self, noise=self.noise)

Parameters:

Name Type Description Default
params dict
required
stimulus (reference, probe)
required

Returns:

Name Type Description
p_correct ndarray
Source code in src/psyphy/model/wppm.py
def predict_prob(self, params: Params, stimulus: Stimulus) -> jnp.ndarray:
    """
    Predict probability of a correct response for a single stimulus.

    Design choice:
        WPPM computes discriminability & covariance; the TASK defines how
        that translates to performance. We therefore delegate to:
            task.predict(params, stimulus, model=self, noise=self.noise)

    Parameters
    ----------
    params : dict
    stimulus : (reference, probe)

    Returns
    -------
    p_correct : jnp.ndarray
    """
    return self.task.predict(params, stimulus, self, self.noise)
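
A short usage sketch (toy MVP model and params as in the examples above):

>>> import jax.numpy as jnp
>>> p = model.predict_prob(params, (jnp.zeros(2), jnp.array([0.3, 0.0])))
>>> # p is a scalar probability; the mapping d -> p(correct) is owned by the task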

predict_with_params

predict_with_params(
    X: ndarray,
    probes: ndarray | None,
    params: dict[str, ndarray],
) -> ndarray

Evaluate model at specific parameter values (no marginalization).

This is useful for:
- Threshold uncertainty estimation (evaluate at sampled parameters)
- Parameter sensitivity analysis
- Debugging and diagnostics

NOT for making predictions (use .posterior() instead, which marginalizes over parameter uncertainty).

Parameters:

Name Type Description Default
X (ndarray, shape(n_test, input_dim))

Test stimuli (references)

required
probes (ndarray, shape(n_test, input_dim))

Probe stimuli (for discrimination tasks)

required
params dict[str, ndarray]

Specific parameter values to evaluate at. Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

required

Returns:

Name Type Description
predictions (ndarray, shape(n_test))

Predicted probabilities at each test point, given these parameters

Examples:

>>> # Sample parameters and evaluate
>>> param_post = model.posterior(kind="parameter")
>>> samples = param_post.sample(100, key=jr.PRNGKey(0))
>>>
>>> # Evaluate at first parameter sample
>>> params_0 = {k: v[0] for k, v in samples.items()}
>>> predictions = model.predict_with_params(X_test, probes, params_0)
>>>
>>> # Use for threshold uncertainty estimation
>>> threshold_locs = []
>>> for i in range(100):
...     params_i = {k: v[i] for k, v in samples.items()}
...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
...     threshold_locs.append(threshold_idx)
Notes

This bypasses the posterior marginalization. For acquisition functions, always use .posterior() which properly accounts for parameter uncertainty.

Source code in src/psyphy/model/base.py
def predict_with_params(
    self,
    X: jnp.ndarray,
    probes: jnp.ndarray | None,
    params: dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """
    Evaluate model at specific parameter values (no marginalization).

    This is useful for:
    - Threshold uncertainty estimation (evaluate at sampled parameters)
    - Parameter sensitivity analysis
    - Debugging and diagnostics

    NOT for making predictions (use .posterior() instead, which
    marginalizes over parameter uncertainty).

    Parameters
    ----------
    X : jnp.ndarray, shape (n_test, input_dim)
        Test stimuli (references)
    probes : jnp.ndarray, shape (n_test, input_dim), optional
        Probe stimuli (for discrimination tasks)
    params : dict[str, jnp.ndarray]
        Specific parameter values to evaluate at.
        Keys and shapes depend on the model (e.g., WPPM has "W", "noise_scale", etc.)

    Returns
    -------
    predictions : jnp.ndarray, shape (n_test,)
        Predicted probabilities at each test point, given these parameters

    Examples
    --------
    >>> # Sample parameters and evaluate
    >>> param_post = model.posterior(kind="parameter")
    >>> samples = param_post.sample(100, key=jr.PRNGKey(0))
    >>>
    >>> # Evaluate at first parameter sample
    >>> params_0 = {k: v[0] for k, v in samples.items()}
    >>> predictions = model.predict_with_params(X_test, probes, params_0)
    >>>
    >>> # Use for threshold uncertainty estimation
    >>> threshold_locs = []
    >>> for i in range(100):
    ...     params_i = {k: v[i] for k, v in samples.items()}
    ...     preds_i = model.predict_with_params(X_grid, probes_grid, params_i)
    ...     threshold_idx = jnp.argmin(jnp.abs(preds_i - 0.75))
    ...     threshold_locs.append(threshold_idx)

    Notes
    -----
    This bypasses the posterior marginalization. For acquisition functions,
    always use .posterior() which properly accounts for parameter uncertainty.
    """
    return self._forward(X, probes, params)

Wishart Process Psychophysical Model (WPPM)


wppm

wppm.py

Wishart Process Psychophysical Model (WPPM) — MVP-style implementation with forward-compatible hooks for the full WPPM model.

Goals

1) MVP that runs today:
- Local covariance Σ(x) is diagonal and constant across the space.
- Discriminability is Mahalanobis distance under Σ(reference).
- Task mapping (e.g., Oddity, 2AFC) converts discriminability -> p(correct).
- Likelihood is delegated to the TaskLikelihood (no Bernoulli code here).

2) Forward compatibility with the full WPPM model:
- Expose the hyperparameters needed to, for example, reproduce the model configuration used in Hong et al.:
  * extra_dims: embedding size for basis expansions (unused in MVP)
  * variance_scale: global covariance scale (unused in MVP)
  * lengthscale: smoothness/length-scale for the covariance field (unused in MVP)
  * diag_term: numerical stabilizer added to covariance diagonals (used in MVP)
- Later, replace local_covariance with a basis-expansion Wishart process and swap discriminability/likelihood with MC observer simulation.

All numerics use JAX (jax.numpy as jnp) to support autodiff and Optax optimizers.
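
A hedged end-to-end sketch of the MVP path described above; that OddityTask() takes no required arguments, and the free variables X, y, X_test, and probes_test, are assumptions:

>>> from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise
>>> model = WPPM(input_dim=2, prior=Prior.default(2), task=OddityTask(),
...              noise=GaussianNoise(sigma=1.0))
>>> model = model.fit(X, y, inference="map", inference_config={"steps": 500})
>>> post = model.posterior(X_test, probes=probes_test)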

Classes:

Name Description
WPPM

Wishart Process Psychophysical Model (WPPM).

Attributes:

Name Type Description
Params
Stimulus

Params

Params = dict[str, ndarray]

Stimulus

Stimulus = tuple[ndarray, ndarray]

WPPM

WPPM(
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    lengthscale: float = 1.0,
    diag_term: float = 1e-06,
    **kwargs,
)

Bases: Model

Wishart Process Psychophysical Model (WPPM).

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input stimulus space (e.g., 2 for isoluminant plane, 3 for RGB). Both reference and probe live in R^{input_dim}.

required
prior Prior

Prior distribution over model parameters. Controls basis_degree for Wishart mode (basis expansion) vs MVP mode (diagonal covariance). The WPPM delegates to prior.basis_degree to ensure consistency between parameter sampling and basis evaluation.

required
task TaskLikelihood

Psychophysical task mapping that defines how discriminability translates to p(correct) and how log-likelihood of responses is computed. (e.g., OddityTask, TwoAFC)

required
noise Any

Noise model describing internal representation noise (e.g., GaussianNoise). Not used in MVP mapping but passed to the task interface for future MC sims.

None
Forward-compatible hyperparameters

extra_dims : int, default=0
    Additional embedding dimensions for basis expansions (beyond input_dim). In Wishart mode, embedding_dim = input_dim + extra_dims.
variance_scale : float, default=1.0
    Global scaling factor for covariance magnitude (unused in MVP).
lengthscale : float, default=1.0
    Smoothness/length-scale for spatial covariance variation (unused in MVP; formerly "decay_rate").
diag_term : float, default=1e-6
    Small positive value added to the covariance diagonal for numerical stability. MVP uses this in matrix solves; the research model will also use it.

Methods:

Name Description
condition_on_observations

Update model with new observations (online learning).

discriminability

Compute scalar discriminability d >= 0 for a (reference, probe) pair

fit

Fit model to data.

init_params

Sample initial parameters from the prior.

local_covariance

Return local covariance Σ(x) at stimulus location x.

log_likelihood

Compute the log-likelihood for arrays of trials.

log_likelihood_from_data

Compute log-likelihood directly from a ResponseData object.

log_posterior_from_data

Convenience helper if you want log posterior in one call (MVP).

posterior

Return posterior distribution.

predict_prob

Predict probability of a correct response for a single stimulus.

predict_with_params

Evaluate model at specific parameter values (no marginalization).

Attributes:

Name Type Description
basis_degree int | None

Chebyshev polynomial degree for Wishart process basis expansion.

diag_term
embedding_dim int

Dimension of the embedding space (perceptual space).

extra_dims
input_dim
lengthscale
noise
online_config
prior
task
variance_scale
Source code in src/psyphy/model/wppm.py
def __init__(
    self,
    input_dim: int,
    prior: Prior,
    task: TaskLikelihood,
    noise: Any | None = None,
    *,  # everything after here is keyword-only
    extra_dims: int = 0,
    variance_scale: float = 1.0,
    lengthscale: float = 1.0,
    diag_term: float = 1e-6,
    **kwargs,  # Accept online_config from model base
) -> None:
    # Initialize Model base class
    super().__init__(**kwargs)

    # --- core components ---
    self.input_dim = int(input_dim)  # stimulus-space dimensionality
    self.prior = prior  # prior over parameter PyTree
    self.task = task  # task mapping and likelihood
    self.noise = noise  # noise model

    # --- forward-compatible hyperparameters (stubs in MVP) ---
    self.extra_dims = int(extra_dims)
    self.variance_scale = float(variance_scale)
    self.lengthscale = float(lengthscale)
    self.diag_term = float(diag_term)


Priors


prior

prior.py

Prior distributions for WPPM parameters

MVP implementation:

  • Gaussian prior over diagonal log-variances

Forward compatibility (Full WPPM mode):

  • Exposes hyperparameters that will be used when the full Wishart Process covariance field is implemented:
    * variance_scale : global scaling factor for covariance magnitude
    * lengthscale : smoothness/length-scale controlling spatial variation
    * extra_embedding_dims : embedding dimension for basis expansions

Connections
  • WPPM calls Prior.sample_params() to initialize model parameters
  • WPPM adds Prior.log_prob(params) to task log-likelihoods to form the log posterior
  • In Full WPPM mode, Prior will generate structured parameters for basis expansions and lengthscale-controlled smooth covariance fields
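
A minimal sketch of how these pieces connect (uses only the API documented on this page; MVP mode):

>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior, OddityTask, GaussianNoise
>>>
>>> prior = Prior(input_dim=2, scale=0.5)
>>> params = prior.sample_params(jr.PRNGKey(0))  # {"log_diag": shape (2,)}
>>> lp = prior.log_prob(params)  # log prior density (up to a constant)
>>>
>>> # WPPM adds this log prior to the task log-likelihood to form the log posterior
>>> model = WPPM(input_dim=2, prior=prior, task=OddityTask(), noise=GaussianNoise())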

Classes:

Name Description
Prior

Prior distribution over WPPM parameters

Attributes:

Name Type Description
Params

Params

Params = dict[str, ndarray]

Prior

Prior(
    input_dim: int,
    scale: float = 0.5,
    basis_degree: int | None = None,
    variance_scale: float = 1.0,
    decay_rate: float = 0.5,
    lengthscale: float = 1.0,
    extra_embedding_dims: int = 0,
)

Prior distribution over WPPM parameters

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the model space (same as WPPM.input_dim)

required
scale float

Stddev of Gaussian prior for log_diag entries (MVP only).

0.5
basis_degree int | None

Degree of Chebyshev basis for Wishart process. If None, uses MVP mode with log_diag parameters. If set, uses Wishart mode with W coefficients.

None
variance_scale float

Prior variance for degree-0 (constant) coefficient in Wishart mode. Controls overall scale of covariances.

1.0
decay_rate float

Geometric decay rate for prior variance over higher-degree coefficients. Prior variance for degree-d coefficient = variance_scale * (decay_rate^d). Smaller decay_rate → stronger smoothness prior.

0.5
lengthscale float

Alias for decay_rate (kept for backward compatibility). If both specified, decay_rate takes precedence.

1.0
extra_embedding_dims int

Additional latent dimensions in U matrices beyond input dimensions. Allows richer ellipsoid shapes in Wishart mode.

0
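
To make the decay concrete, a small worked example of the variance formula above (plain Python; values are illustrative):

>>> variance_scale, decay_rate = 1.0, 0.5
>>> [variance_scale * decay_rate**d for d in range(4)]  # [1.0, 0.5, 0.25, 0.125] for degrees 0..3
>>> # higher-degree (less smooth) coefficients are shrunk harder toward zero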

Methods:

Name Description
default

Convenience constructor with MVP defaults.

log_prob

Compute log prior density (up to a constant)

sample_params

Sample initial parameters from the prior.

Attributes:

Name Type Description
basis_degree int | None
decay_rate float
extra_embedding_dims int
input_dim int
lengthscale float
scale float
variance_scale float

basis_degree

basis_degree: int | None = None

decay_rate

decay_rate: float = 0.5

extra_embedding_dims

extra_embedding_dims: int = 0

input_dim

input_dim: int

lengthscale

lengthscale: float = 1.0

scale

scale: float = 0.5

variance_scale

variance_scale: float = 1.0

default

default(input_dim: int, scale: float = 0.5) -> Prior

Convenience constructor with MVP defaults.

Source code in src/psyphy/model/prior.py
@classmethod
def default(cls, input_dim: int, scale: float = 0.5) -> Prior:
    """Convenience constructor with MVP defaults."""
    return cls(input_dim=input_dim, scale=scale)

log_prob

log_prob(params: Params) -> ndarray

Compute log prior density (up to a constant)

MVP mode: Isotropic Gaussian on log_diag

Wishart mode: Gaussian prior on W with smoothness via decay_rate log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

Parameters:

Name Type Description Default
params dict

Parameter dictionary

required

Returns:

Name Type Description
log_prob float

Log prior probability (up to normalizing constant)

Source code in src/psyphy/model/prior.py
def log_prob(self, params: Params) -> jnp.ndarray:
    """
    Compute log prior density (up to a constant)

    MVP mode:
        Isotropic Gaussian on log_diag

    Wishart mode:
        Gaussian prior on W with smoothness via decay_rate
        log p(W) = Σ_ij log N(W_ij | 0, σ_ij^2) where σ_ij^2 = prior variance

    Parameters
    ----------
    params : dict
        Parameter dictionary

    Returns
    -------
    log_prob : float
        Log prior probability (up to normalizing constant)
    """
    if "log_diag" in params:
        # MVP mode
        log_diag = params["log_diag"]
        var = self.scale**2
        return -0.5 * jnp.sum((log_diag**2) / var)

    if "W" in params:
        # Wishart mode
        W = params["W"]
        variances = self._compute_W_prior_variances()

        # Gaussian log probability for each entry
        # log N(x | 0, σ^2) = -0.5 * (x^2/σ^2 + log(2πσ^2))
        # Up to constant: -0.5 * x^2/σ^2

        if self.input_dim == 2:
            # Each W[i,j,:,:] ~ Normal(0, variance[i,j] * I)
            return -0.5 * jnp.sum((W**2) / (variances[:, :, None, None] + 1e-10))
        elif self.input_dim == 3:
            return -0.5 * jnp.sum((W**2) / (variances[:, :, :, None, None] + 1e-10))

    raise ValueError("params must contain either 'log_diag' (MVP) or 'W' (Wishart)")
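
A quick sanity check of the MVP branch above; the values follow directly from -0.5 * Σ(log_diag² / scale²):

>>> import jax.numpy as jnp
>>> from psyphy.model import Prior
>>>
>>> prior = Prior(input_dim=2, scale=0.5)
>>> prior.log_prob({"log_diag": jnp.zeros(2)})  # 0 at the prior mode
>>> prior.log_prob({"log_diag": jnp.array([0.5, 0.5])})  # -0.5 * (0.25 + 0.25) / 0.25 = -1.0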

sample_params

sample_params(key: Any) -> Params

Sample initial parameters from the prior.

MVP mode (basis_degree=None): Returns {"log_diag": shape (input_dim,)}

Wishart mode (basis_degree set): Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)} for 2D, where embedding_dim = input_dim + extra_embedding_dims

Note: The 3rd dimension is input_dim (output space dimension).
This matches the einsum in _compute_U:
U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

Parameters:

Name Type Description Default
key JAX random key
required

Returns:

Name Type Description
params dict

Parameter dictionary

Source code in src/psyphy/model/prior.py
def sample_params(self, key: Any) -> Params:
    """
    Sample initial parameters from the prior.

    MVP mode (basis_degree=None):
        Returns {"log_diag": shape (input_dim,)}

    Wishart mode (basis_degree set):
        Returns {"W": shape (degree+1, degree+1, input_dim, embedding_dim)}
        for 2D, where embedding_dim = input_dim + extra_embedding_dims

        Note: The 3rd dimension is input_dim (output space dimension).
        This matches the einsum in _compute_U:
        U = einsum("ijde,ij->de", W, phi) where d indexes input_dim.

    Parameters
    ----------
    key : JAX random key

    Returns
    -------
    params : dict
        Parameter dictionary
    """
    if self.basis_degree is None:
        # MVP mode: simple diagonal covariance
        log_diag = jr.normal(key, shape=(self.input_dim,)) * self.scale
        return {"log_diag": log_diag}

    # Wishart mode: basis function coefficients W
    variances = self._compute_W_prior_variances()
    embedding_dim = self.input_dim + self.extra_embedding_dims

    if self.input_dim == 2:
        # Sample W ~ Normal(0, variances) for each matrix entry
        # Shape: (degree+1, degree+1, input_dim, embedding_dim)
        # Note: degree+1 to match number of basis functions [T_0, ..., T_degree]
        W = jnp.sqrt(variances)[:, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    elif self.input_dim == 3:
        # Shape: (degree+1, degree+1, degree+1, input_dim, embedding_dim)
        W = jnp.sqrt(variances)[:, :, :, None, None] * jr.normal(
            key,
            shape=(
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.basis_degree + 1,
                self.input_dim,
                embedding_dim,
            ),
        )
    else:
        raise NotImplementedError(
            f"Wishart process only supports 2D and 3D. Got input_dim={self.input_dim}"
        )

    return {"W": W}

Noise


noise

Classes:

Name Description
GaussianNoise
StudentTNoise

GaussianNoise

GaussianNoise(sigma: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Gaussian (mean=0, var=1).

Attributes:

Name Type Description
sigma float

sigma

sigma: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Gaussian (mean=0, var=1).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Gaussian (mean=0, var=1)."""
    return jr.normal(key, shape)

StudentTNoise

StudentTNoise(df: float = 3.0, scale: float = 1.0)

Methods:

Name Description
log_prob
sample_standard

Sample from standard Student-t (df=self.df).

Attributes:

Name Type Description
df float
scale float

df

df: float = 3.0

scale

scale: float = 1.0

log_prob

log_prob(residual: float) -> float
Source code in src/psyphy/model/noise.py
def log_prob(self, residual: float) -> float:
    _ = residual
    return -0.5

sample_standard

sample_standard(
    key: Array, shape: tuple[int, ...]
) -> Array

Sample from standard Student-t (df=self.df).

Source code in src/psyphy/model/noise.py
def sample_standard(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Sample from standard Student-t (df=self.df)."""
    return jr.t(key, self.df, shape)
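
Both noise models expose the same two-method interface, so they are drop-in replacements for each other. A minimal usage sketch:

>>> import jax.random as jr
>>> from psyphy.model import GaussianNoise, StudentTNoise
>>>
>>> key = jr.PRNGKey(0)
>>> z_gauss = GaussianNoise(sigma=0.03).sample_standard(key, (1000, 2))  # standard normal draws
>>> z_heavy = StudentTNoise(df=3.0).sample_standard(key, (1000, 2))  # heavier tails
>>> z_gauss.shape, z_heavy.shape  # ((1000, 2), (1000, 2))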

Tasks


task

task.py

Task likelihoods for different psychophysical experiments.

Each TaskLikelihood defines:

  • predict(params, stimuli, model, noise)
    Map discriminability (computed by model) to probability of correct response.

  • loglik(params, data, model, noise)
    Compute log-likelihood of observed responses under this task.

  • loglik_mc(params, data, model, noise, num_samples, bandwidth, key)
    [Optional] Compute log-likelihood via Monte Carlo observer simulation.

MVP implementation:

  • OddityTask (3AFC) and TwoAFC.
  • Both use simple sigmoid-like mappings of discriminability -> performance.
  • loglik implemented as Bernoulli log-prob with these predictions.

Full WPPM mode:

  • OddityTask.loglik_mc() provides Monte Carlo likelihood computation:
    * Implements the full 3-stimulus oddity task (two references, one comparison)
    * Sample three internal noisy representations: z0, z1 ~ N(ref, Σ_ref), z2 ~ N(comparison, Σ_comparison)
    * Compute three pairwise Mahalanobis distances: d^2(z0,z1), d^2(z0,z2), d^2(z1,z2)
    * Decision rule: comparison z2 is the odd one out if min[d^2(z0,z2), d^2(z1,z2)] > d^2(z0,z1)
    * Apply logistic CDF smoothing with bandwidth parameter
    * Average over MC samples to estimate P(correct)

Connections
  • WPPM delegates to task.predict and task.loglik (never re-implements likelihood)
  • Noise model is passed through from WPPM so tasks can simulate responses.
  • We can define new tasks by subclassing TaskLikelihood and implementing predict() and loglik().
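
As a sketch of that last point, a hypothetical YesNoTask (the name and mapping are illustrative, not part of the package) that mirrors the built-in TwoAFC implementation shown further down this page:

import jax.numpy as jnp
from psyphy.model.task import TaskLikelihood

class YesNoTask(TaskLikelihood):
    """Hypothetical yes/no task with chance level 0.5 (illustrative only)."""

    def predict(self, params, stimuli, model, noise):
        # Map discriminability to P(correct); saturates from 0.5 toward 1.0
        d = model.discriminability(params, stimuli)
        return 0.5 + 0.5 * jnp.tanh(d)

    def loglik(self, params, data, model, noise):
        # Bernoulli log-likelihood over trials, as in the built-in tasks
        refs, comparisons, responses = data.to_numpy()
        ps = jnp.array(
            [self.predict(params, (r, c), model, noise) for r, c in zip(refs, comparisons)]
        )
        eps = 1e-9
        return jnp.sum(
            jnp.where(responses == 1, jnp.log(ps + eps), jnp.log(1.0 - ps + eps))
        )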

Classes:

Name Description
OddityTask

Three-alternative forced-choice oddity task.

TaskLikelihood

Abstract base class for task likelihoods

TwoAFC

2-alternative forced-choice task (MVP placeholder).

Attributes:

Name Type Description
Stimulus

Stimulus

Stimulus = tuple[ndarray, ndarray]

OddityTask

OddityTask(slope: float = 1.5)

Bases: TaskLikelihood

Three-alternative forced-choice oddity task.

In an oddity task, the observer is presented with three stimuli: two identical references and one comparison. The task is to identify which stimulus is the "odd one out" (the comparison). Performance depends on the discriminability between reference and comparison.

This class provides two likelihood computation methods:

  1. Analytical approximation (MVP mode):
     • predict(): maps discriminability to P(correct) via tanh
     • loglik(): Bernoulli likelihood using analytical predictions
     • Fast, differentiable, suitable for gradient-based optimization

  2. Monte Carlo simulation (Full WPPM mode):
     • loglik_mc(): simulates the full 3-stimulus oddity task
     • Samples three internal representations per trial (z0, z1, z2)
     • Uses proper oddity decision rule with three pairwise distances
     • More accurate for complex covariance structures
     • Suitable for validation and benchmarking

Parameters:

Name Type Description Default
slope float

Slope parameter for analytical tanh mapping in predict(). Controls steepness of discriminability -> performance relationship.

1.5

Attributes:

Name Type Description
chance_level float

Chance performance for oddity task (1/3)

performance_range float

Range from chance to perfect performance (2/3)

Notes

The analytical approximation in predict() uses: P(correct) = 1/3 + 2/3 * (1 + tanh(slope * d)) / 2

MC simulation in loglik_mc() (Full 3-stimulus oddity):

  1. Sample three internal representations: z_ref, z_refprime ~ N(ref, Σ_ref), z_comparison ~ N(comparison, Σ_comparison)
  2. Compute average covariance: Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison
  3. Compute three pairwise Mahalanobis distances:
     • d^2(z_ref, z_refprime) = distance between two reference samples
     • d^2(z_ref, z_comparison) = distance from ref to comparison
     • d^2(z_refprime, z_comparison) = distance from reference_prime to comparison
  4. Apply oddity decision rule: delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
  5. Logistic smoothing: P(correct) ≈ logistic.cdf(delta / bandwidth)
  6. Average over samples

Examples:

>>> from psyphy.model.task import OddityTask
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.noise import GaussianNoise
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> # Create task and model
>>> task = OddityTask(slope=1.5)
>>> model = WPPM(
...     input_dim=2, prior=Prior(input_dim=2), task=task, noise=GaussianNoise()
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Analytical prediction
>>> ref = jnp.array([0.0, 0.0])
>>> comparison = jnp.array([0.5, 0.5])
>>> p_correct = task.predict(params, (ref, comparison), model, model.noise)
>>> print(f"P(correct) pprox {p_correct:.3f}")
>>>
>>> # MC simulation (more accurate)
>>> from psyphy.data.dataset import ResponseData
>>> data = ResponseData()
>>> data.add_trial(ref, comparison, resp=1)
>>> ll_mc = task.loglik_mc(
...     params, data, model, model.noise, num_samples=1000, key=jr.PRNGKey(42)
... )
>>> print(f"Log-likelihood (MC): {ll_mc:.4f}")

Methods:

Name Description
loglik

Compute log-likelihood using analytical predictions.

loglik_mc

Compute log-likelihood via Monte Carlo observer simulation.

predict

Predict probability of correct response using analytical approximation.

Source code in src/psyphy/model/task.py
def __init__(self, slope: float = 1.5) -> None:
    self.slope = float(slope)
    self.chance_level: float = 1.0 / 3.0
    self.performance_range: float = 1.0 - self.chance_level

chance_level

chance_level: float = 1.0 / 3.0

performance_range

performance_range: float = 1.0 - chance_level

slope

slope = float(slope)

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray

Compute log-likelihood using analytical predictions.

Parameters:

Name Type Description Default
params dict

Model parameters

required
data ResponseData

Trial data containing refs, comparisons, responses

required
model WPPM

Model instance

required
noise NoiseModel

Observer noise model

required

Returns:

Type Description
ndarray

Scalar sum of log-likelihoods over all trials

Notes

Uses Bernoulli log-likelihood:
    LL = Σ [y * log(p) + (1-y) * log(1-p)]
where p comes from predict() (analytical approximation)

Source code in src/psyphy/model/task.py
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    """
    Compute log-likelihood using analytical predictions.

    Parameters
    ----------
    params : dict
        Model parameters
    data : ResponseData
        Trial data containing refs, comparisons, responses
    model : WPPM
        Model instance
    noise : NoiseModel
        Observer noise model

    Returns
    -------
    jnp.ndarray
        Scalar sum of log-likelihoods over all trials

    Notes
    -----
    Uses Bernoulli log-likelihood:
        LL = Σ [y * log(p) + (1-y) * log(1-p)]
    where p comes from predict() (analytical approximation)
    """
    refs, comparisons, responses = data.to_numpy()
    ps = jnp.array(
        [
            self.predict(params, (r, p), model, noise)
            for r, p in zip(refs, comparisons)
        ]
    )
    eps = 1e-9
    return jnp.sum(
        jnp.where(responses == 1, jnp.log(ps + eps), jnp.log(1.0 - ps + eps))
    )

loglik_mc

loglik_mc(
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    num_samples: int = 1000,
    bandwidth: float = 0.01,
    key: Any = None,
) -> ndarray

Compute log-likelihood via Monte Carlo observer simulation.

This method implements the FULL 3-stimulus oddity task. Instead of using an analytical approximation, we:

  1. Sample three internal noisy representations per trial:
     • z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
     • z_comparison ~ N(comparison, Σ_comparison)  [one sample from comparison]
  2. Compute three pairwise Mahalanobis distances
  3. Apply oddity decision rule: comparison is odd if it is farther from BOTH ref and reference_prime
  4. Apply logistic smoothing to approximate P(correct)
  5. Average over MC samples

Parameters:

Name Type Description Default
params dict

Model parameters (must contain 'W' for WPPM basis coefficients)

required
data ResponseData

Trial data with refs, comparisons, and responses

required
model WPPM

Model instance providing compute_U() for covariance computation

required
noise NoiseModel

Observer noise model (provides sigma for diagonal noise term)

required
num_samples int

Number of Monte Carlo samples per trial. - Use 1000-5000 for accurate likelihood estimation - Larger values reduce MC variance but increase compute time

1000
bandwidth float

Smoothing parameter for logistic CDF approximation. - Smaller values -> sharper transition (closer to step function) - Larger values -> smoother approximation - Typical range: [1e-3, 5e-2]

1e-2
key PRNGKey

Random key for reproducible sampling. If None, uses PRNGKey(0) (deterministic but not recommended for production)

None

Returns:

Type Description
ndarray

Scalar sum of log-likelihoods over all trials. Same shape and interpretation as loglik().

Raises:

Type Description
ValueError

If num_samples <= 0

Notes

Full 3-stimulus oddity task algorithm:

For each trial (ref, comparison, response):

  1. Compute covariances:
     • Σ_ref = U_ref @ U_ref.T + σ^2 I
     • Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
     • Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

  2. Sample three internal representations:
     • z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
     • z_comparison ~ N(comparison, Σ_comparison)  [1 sample from comparison, num_samples times]

  3. Compute three pairwise Mahalanobis distances:
     • d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
     • d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
     • d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

  4. Apply oddity decision rule:
     • delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
     • delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

  5. Apply logistic smoothing:
     • P(correct) ≈ mean(logistic.cdf(delta / bandwidth))

  6. Bernoulli log-likelihood:
     • LL = Σ [y * log(p) + (1-y) * log(1-p)]

Performance:

  • Time complexity: O(n_trials * num_samples * input_dim³)
  • Memory: O(num_samples * input_dim) per trial
  • Vectorized across trials using jax.vmap for GPU acceleration
  • Can be JIT-compiled for additional speed (future optimization)

Comparison to analytical:

  • MC implements full 3-stimulus oddity (more realistic)
  • MC is more accurate for complex Σ(x) structures
  • Analytical is faster and differentiable
  • Use MC for validation and benchmarking, analytical for optimization
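
To make steps 2-5 concrete, here is a minimal, self-contained sketch of the per-trial computation, assuming the two covariances Sigma_ref and Sigma_cmp are already in hand (the actual implementation lives in _simulate_trials_mc_vectorized and is vmapped across trials; the function below is illustrative, not part of the package):

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import logistic

def oddity_p_correct(key, ref, cmp_, Sigma_ref, Sigma_cmp,
                     num_samples=1000, bandwidth=1e-2):
    # Average covariance, weighted by stimulus frequency (2 refs, 1 comparison)
    Sigma_avg = (2.0 / 3.0) * Sigma_ref + (1.0 / 3.0) * Sigma_cmp
    Linv = jnp.linalg.inv(jnp.linalg.cholesky(Sigma_avg))  # whitening transform

    # Sample three internal representations per MC draw
    k0, k1, k2 = jr.split(key, 3)
    L_ref = jnp.linalg.cholesky(Sigma_ref)
    L_cmp = jnp.linalg.cholesky(Sigma_cmp)
    n = ref.shape[0]
    z_ref = ref + jr.normal(k0, (num_samples, n)) @ L_ref.T
    z_refprime = ref + jr.normal(k1, (num_samples, n)) @ L_ref.T
    z_cmp = cmp_ + jr.normal(k2, (num_samples, n)) @ L_cmp.T

    def d2(a, b):
        # Squared Mahalanobis distance under Sigma_avg
        w = (a - b) @ Linv.T
        return jnp.sum(w * w, axis=-1)

    # Comparison is the odd one out if it is farther from BOTH reference samples
    delta = jnp.minimum(d2(z_ref, z_cmp), d2(z_refprime, z_cmp)) - d2(z_ref, z_refprime)
    return jnp.mean(logistic.cdf(delta / bandwidth))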

Examples:

>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> from psyphy.model import WPPM, Prior
>>> from psyphy.model.task import OddityTask
>>> from psyphy.model.noise import GaussianNoise
>>> from psyphy.data.dataset import ResponseData
>>>
>>> # Setup
>>> model = WPPM(
...     input_dim=2,
...     prior=Prior(input_dim=2, basis_degree=3),
...     task=OddityTask(),
...     noise=GaussianNoise(sigma=0.03),
... )
>>> params = model.init_params(jr.PRNGKey(0))
>>>
>>> # Create trial data
>>> data = ResponseData()
>>> data.add_trial(
...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
... )
>>>
>>> # Compare analytical vs MC
>>> ll_analytical = model.task.loglik(params, data, model, model.noise)
>>> ll_mc = model.task.loglik_mc(
...     params,
...     data,
...     model,
...     model.noise,
...     num_samples=5000,
...     bandwidth=1e-3,
...     key=jr.PRNGKey(42),
... )
>>> print(f"Analytical: {ll_analytical:.4f}")
>>> print(f"MC (N=5000): {ll_mc:.4f}")
>>> print(f"Difference: {abs(ll_mc - ll_analytical):.4f}")
See Also

loglik : Analytical log-likelihood (faster, differentiable) predict : Analytical prediction for single trial

Source code in src/psyphy/model/task.py
def loglik_mc(
    self,
    params: Any,
    data: Any,
    model: Any,
    noise: Any,
    num_samples: int = 1000,
    bandwidth: float = 1e-2,
    key: Any = None,
) -> jnp.ndarray:
    """
    Compute log-likelihood via Monte Carlo observer simulation.

    This method implements the FULL 3-stimulus oddity task. Instead of using
    an analytical approximation, we:
    1. Sample three internal noisy representations per trial:
       - z_ref, z_refprime ~ N(ref, Σ_ref)  [two samples from reference]
       - z_comparison ~ N(comparison, Σ_comparison)           [one sample from comparison]
    2. Compute three pairwise Mahalanobis distances
    3. Apply oddity decision rule: comparison is odd if it's farther from BOTH ref and reference_prime
    4. Apply logistic smoothing to approximate P(correct)
    5. Average over MC samples

    Parameters
    ----------
    params : dict
        Model parameters (must contain 'W' for WPPM basis coefficients)
    data : ResponseData
        Trial data with refs, comparisons, and responses
    model : WPPM
        Model instance providing compute_U() for covariance computation
    noise : NoiseModel
        Observer noise model (provides sigma for diagonal noise term)
    num_samples : int, default=1000
        Number of Monte Carlo samples per trial.
        - Use 1000-5000 for accurate likelihood estimation
        - Larger values reduce MC variance but increase compute time
    bandwidth : float, default=1e-2
        Smoothing parameter for logistic CDF approximation.
        - Smaller values -> sharper transition (closer to step function)
        - Larger values -> smoother approximation
        - Typical range: [1e-3, 5e-2]
    key : jax.random.PRNGKey, optional
        Random key for reproducible sampling.
        If None, uses PRNGKey(0) (deterministic but not recommended for production)

    Returns
    -------
    jnp.ndarray
        Scalar sum of log-likelihoods over all trials.
        Same shape and interpretation as loglik().

    Raises
    ------
    ValueError
        If num_samples <= 0

    Notes
    -----
    **Full 3-stimulus oddity task algorithm:**

    For each trial (ref, comparison, response):
    1. Compute covariances:
       - Σ_ref = U_ref @ U_ref.T + σ^2 I
       - Σ_comparison = U_comparison @ U_comparison.T + σ^2 I
       - Σ_avg = (2/3) Σ_ref + (1/3) Σ_comparison  [weighted by stimulus frequency]

    2. Sample three internal representations:
       - z_ref, z_refprime ~ N(ref, Σ_ref)  [2 samples from reference, num_samples times each]
       - z_comparison ~ N(comparison, Σ_comparison)           [1 sample from comparison, num_samples times]

    3. Compute three pairwise Mahalanobis distances:
       - d^2(z_ref, z_refprime) = (z_ref - z_refprime).T @ Σ_avg^{-1} @ (z_ref - z_refprime)  [ref vs reference_prime]
       - d^2(z_ref, z_comparison) = (z_ref - z_comparison).T @ Σ_avg^{-1} @ (z_ref - z_comparison)  [ref vs comparison]
       - d^2(z_refprime, z_comparison) = (z_refprime - z_comparison).T @ Σ_avg^{-1} @ (z_refprime - z_comparison)  [reference_prime vs comparison]

    4. Apply oddity decision rule:
       - delta = min(d^2(z_ref,z_comparison), d^2(z_refprime,z_comparison)) - d^2(z_ref,z_refprime)
       - delta > 0 means comparison is farther from BOTH ref and reference_prime -> correct identification

    5. Apply logistic smoothing:
       - P(correct) \approx mean(logistic.cdf(delta / bandwidth))

    6. Bernoulli log-likelihood:
       - LL = Σ [y * log(p) + (1-y) * log(1-p)]

    Performance:
    - Time complexity: O(n_trials * num_samples * input_dim³)
    - Memory: O(num_samples * input_dim) per trial
    - Vectorized across trials using jax.vmap for GPU acceleration
    - Can be JIT-compiled for additional speed (future optimization)

    Comparison to analytical:
    - MC implements full 3-stimulus oddity (more realistic)
    - MC is more accurate for complex Σ(x) structures
    - Analytical is faster and differentiable
    - Use MC for validation and benchmarking, analytical for optimization

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import jax.random as jr
    >>> from psyphy.model import WPPM, Prior
    >>> from psyphy.model.task import OddityTask
    >>> from psyphy.model.noise import GaussianNoise
    >>> from psyphy.data.dataset import ResponseData
    >>>
    >>> # Setup
    >>> model = WPPM(
    ...     input_dim=2,
    ...     prior=Prior(input_dim=2, basis_degree=3),
    ...     task=OddityTask(),
    ...     noise=GaussianNoise(sigma=0.03),
    ... )
    >>> params = model.init_params(jr.PRNGKey(0))
    >>>
    >>> # Create trial data
    >>> data = ResponseData()
    >>> data.add_trial(
    ...     ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.2]), resp=1
    ... )
    >>>
    >>> # Compare analytical vs MC
    >>> ll_analytical = model.task.loglik(params, data, model, model.noise)
    >>> ll_mc = model.task.loglik_mc(
    ...     params,
    ...     data,
    ...     model,
    ...     model.noise,
    ...     num_samples=5000,
    ...     bandwidth=1e-3,
    ...     key=jr.PRNGKey(42),
    ... )
    >>> print(f"Analytical: {ll_analytical:.4f}")
    >>> print(f"MC (N=5000): {ll_mc:.4f}")
    >>> print(f"Difference: {abs(ll_mc - ll_analytical):.4f}")

    See Also
    --------
    loglik : Analytical log-likelihood (faster, differentiable)
    predict : Analytical prediction for single trial
    """
    # Validate inputs
    if num_samples <= 0:
        raise ValueError(f"num_samples must be > 0, got {num_samples}")

    # Default key for reproducibility (deterministic; not recommended for production)
    if key is None:
        key = jr.PRNGKey(0)

    # Unpack trial data
    refs, comparisons, responses = data.to_numpy()
    n_trials = len(refs)

    # Split keys for each trial (ensures independent sampling)
    trial_keys = jr.split(key, n_trials)

    # Vectorized computation of P(correct) for all trials
    # This processes all trials in parallel using jax.vmap
    # Note: probabilities are already clipped in _simulate_trial_mc()
    probs = self._simulate_trials_mc_vectorized(
        params=params,
        refs=refs,
        comparisons=comparisons,
        model=model,
        noise=noise,
        num_samples=num_samples,
        bandwidth=bandwidth,
        trial_keys=trial_keys,
    )

    # Bernoulli log-likelihood: LL = Σ [y log(p) + (1-y) log(1-p)]
    # Probabilities are already clipped to [eps, 1-eps] so log is safe
    log_likelihoods = jnp.where(
        responses == 1,
        jnp.log(probs),  # Correct response
        jnp.log(1.0 - probs),  # Incorrect response
    )

    return jnp.sum(log_likelihoods)

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response using analytical approximation.

Parameters:

Name Type Description Default
params dict

Model parameters (e.g., W for WPPM)

required
stimuli tuple[ndarray, ndarray]

(reference, comparison) stimulus pair

required
model WPPM

Model instance providing discriminability()

required
noise NoiseModel

Observer noise model (currently unused in analytical version)

required

Returns:

Type Description
ndarray

Scalar probability of correct response, in range [1/3, 1]

Notes

Uses tanh mapping: P(correct) = 1/3 + (2/3) * (1 + tanh(slope * d)) / 2, where d is discriminability from model.discriminability()

Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """
    Predict probability of correct response using analytical approximation.

    Parameters
    ----------
    params : dict
        Model parameters (e.g., W for WPPM)
    stimuli : tuple[jnp.ndarray, jnp.ndarray]
        (reference, comparison) stimulus pair
    model : WPPM
        Model instance providing discriminability()
    noise : NoiseModel
        Observer noise model (currently unused in analytical version)

    Returns
    -------
    jnp.ndarray
        Scalar probability of correct response, in range [1/3, 1]

    Notes
    -----
    Uses tanh mapping: P(correct) = 1/3 + (2/3) * (1 + tanh(slope * d)) / 2
    where d is discriminability from model.discriminability()
    """
    d = model.discriminability(params, stimuli)
    g = 0.5 * (jnp.tanh(self.slope * d) + 1.0)
    return self.chance_level + self.performance_range * g

TaskLikelihood

Bases: ABC

Abstract base class for task likelihoods

Methods:

Name Description
loglik

Compute log-likelihood of observed responses under this task

predict

Predict probability of correct response for a stimulus.

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray

Compute log-likelihood of observed responses under this task

Source code in src/psyphy/model/task.py
@abstractmethod
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    """Compute log-likelihood of observed responses under this task"""
    ...

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray

Predict probability of correct response for a stimulus.

Source code in src/psyphy/model/task.py
@abstractmethod
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    """Predict probability of correct response for a stimulus."""
    ...

TwoAFC

TwoAFC(slope: float = 2.0)

Bases: TaskLikelihood

2-alternative forced-choice task (MVP placeholder).

Methods:

Name Description
loglik
predict

Attributes:

Name Type Description
chance_level float
performance_range float
slope
Source code in src/psyphy/model/task.py
def __init__(self, slope: float = 2.0) -> None:
    self.slope = float(slope)
    self.chance_level: float = 0.5
    self.performance_range: float = 1.0 - self.chance_level

chance_level

chance_level: float = 0.5

performance_range

performance_range: float = 1.0 - chance_level

slope

slope = float(slope)

loglik

loglik(
    params: Any, data: Any, model: Any, noise: Any
) -> ndarray
Source code in src/psyphy/model/task.py
def loglik(self, params: Any, data: Any, model: Any, noise: Any) -> jnp.ndarray:
    refs, comparisons, responses = data.to_numpy()
    ps = jnp.array(
        [
            self.predict(params, (r, p), model, noise)
            for r, p in zip(refs, comparisons)
        ]
    )
    eps = 1e-9
    return jnp.sum(
        jnp.where(responses == 1, jnp.log(ps + eps), jnp.log(1.0 - ps + eps))
    )

predict

predict(
    params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> ndarray
Source code in src/psyphy/model/task.py
def predict(
    self, params: Any, stimuli: Stimulus, model: Any, noise: Any
) -> jnp.ndarray:
    d = model.discriminability(params, stimuli)
    return self.chance_level + self.performance_range * jnp.tanh(self.slope * d)
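
For reference, the TwoAFC mapping can be traced directly from the attributes above; at d = 0 performance sits at chance (0.5) and saturates toward 1.0 as discriminability grows (a quick sketch; the d values are illustrative):

>>> import jax.numpy as jnp
>>> from psyphy.model.task import TwoAFC
>>>
>>> task = TwoAFC(slope=2.0)
>>> d = jnp.linspace(0.0, 2.0, 5)  # discriminability values
>>> p = task.chance_level + task.performance_range * jnp.tanh(task.slope * d)
>>> # p runs from 0.5 at d=0 toward ~1.0 (roughly 0.5, 0.88, 0.98, 0.998, 1.0)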