Inference

inference

Inference engines for WPPM.

This subpackage provides different strategies for fitting model parameters to data and returning posterior objects.

MVP implementations
  • MAPOptimizer : maximum a posteriori fit with Optax optimizers.
  • LaplaceApproximation : approximate posterior covariance around MAP.
  • LangevinSampler : skeleton for sampling-based inference.
Future extensions
  • Metropolis-adjusted samplers, e.g., MALA (for asymptotically exact Bayesian posterior inference).
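All engines share the same fit(model, data) -> Posterior contract, so experiment code can swap one for another without other changes. A dependency-free sketch of that pattern (the Dummy* engines and the dict-based model/posterior are illustrative stand-ins, not psyphy classes):

```python
from abc import ABC, abstractmethod


class InferenceEngine(ABC):
    """Abstract interface: every engine turns (model, data) into a posterior."""

    @abstractmethod
    def fit(self, model, data):
        ...


class DummyMAPEngine(InferenceEngine):
    """Toy engine: 'fits' by averaging the observations."""

    def fit(self, model, data):
        return {"params": sum(data) / len(data), "model": model}


class DummyPriorEngine(InferenceEngine):
    """Toy engine: ignores data and returns the model's prior guess."""

    def fit(self, model, data):
        return {"params": model["prior_mean"], "model": model}


model = {"prior_mean": 0.0}
data = [1.0, 2.0, 3.0]

# Any engine can be dropped in where an InferenceEngine is expected.
for engine in (DummyMAPEngine(), DummyPriorEngine()):
    posterior = engine.fit(model, data)
```

The real engines differ only in how much of the posterior they recover (a point estimate, a Gaussian around it, or samples); the call site stays the same.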

Classes:

InferenceEngine — Abstract interface for inference engines.
LangevinSampler — Langevin sampler (stub).
LaplaceApproximation — Laplace approximation around MAP estimate.
MAPOptimizer — MAP (Maximum A Posteriori) optimizer.

InferenceEngine

Bases: ABC

Abstract interface for inference engines.

Methods:

fit — Fit model parameters to data and return a Posterior object.

fit

fit(model: Any, data: Any) -> Any

Fit model parameters to data.

Parameters:

model : WPPM (required)
    Psychophysical model to fit.
data : ResponseData (required)
    Observed trials.

Returns:

Posterior
    Posterior object wrapping fitted params and model reference.

Source code in src/psyphy/inference/base.py
@abstractmethod
def fit(self, model: Any, data: Any) -> Any:
    """
    Fit model parameters to data.

    Parameters
    ----------
    model : WPPM
        Psychophysical model to fit.
    data : ResponseData
        Observed trials.

    Returns
    -------
    Posterior
        Posterior object wrapping fitted params and model reference.
    """
    ...

LangevinSampler

LangevinSampler(
    steps: int = 1000,
    step_size: float = 0.001,
    temperature: float = 1.0,
)

Langevin sampler (stub).

Parameters:

steps : int, default 1000
    Number of Langevin steps.
step_size : float, default 1e-3
    Integration step size.
temperature : float, default 1.0
    Noise scale (temperature).

Methods:

fit — Fit model parameters with Langevin dynamics (stub).

Attributes:

steps, step_size, temperature
Source code in src/psyphy/inference/langevin.py
def __init__(self, steps: int = 1000, step_size: float = 1e-3, temperature: float = 1.0):
    self.steps = steps
    self.step_size = step_size
    self.temperature = temperature


fit

fit(model, data) -> Posterior

Fit model parameters with Langevin dynamics (stub).

Parameters:

model : WPPM (required)
    Model instance.
data : ResponseData (required)
    Observed trials.

Returns:

Posterior
    Posterior wrapper (MVP: params from init).

Source code in src/psyphy/inference/langevin.py
def fit(self, model, data) -> Posterior:
    """
    Fit model parameters with Langevin dynamics (stub).

    Parameters
    ----------
    model : WPPM
        Model instance.
    data : ResponseData
        Observed trials.

    Returns
    -------
    Posterior
        Posterior wrapper (MVP: params from init).
    """
    return Posterior(params=model.init_params(None), model=model)
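A non-stub fit would iterate the unadjusted Langevin update theta <- theta + eps * grad log p(theta) + sqrt(2 * eps * T) * xi. A self-contained pure-Python sketch against a standard-normal target (grad_log_post and ula_samples are hypothetical stand-ins for the model's gradient and the sampler loop, not psyphy API):

```python
import math
import random

random.seed(0)


def grad_log_post(theta):
    # Stand-in gradient: log p(theta) = -theta**2 / 2 (standard normal target).
    return -theta


def ula_samples(steps=50_000, step_size=1e-2, temperature=1.0, theta=0.0):
    """Overdamped (unadjusted) Langevin: drift up the gradient plus Gaussian noise."""
    noise_scale = math.sqrt(2.0 * step_size * temperature)
    out = []
    for _ in range(steps):
        theta = (theta
                 + step_size * grad_log_post(theta)
                 + noise_scale * random.gauss(0.0, 1.0))
        out.append(theta)
    return out


samples = ula_samples()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
```

Because the update is unadjusted, the chain has an O(step_size) bias; the Metropolis-adjusted variant (MALA) listed under future extensions removes it with an accept/reject step.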

LaplaceApproximation

Laplace approximation around MAP estimate.

Methods:

from_map — Construct a Gaussian approximation centered at MAP.

from_map

from_map(map_posterior: Posterior) -> Posterior

Return posterior approximation from MAP.

Parameters:

map_posterior : Posterior (required)
    Posterior object from MAP optimization.

Returns:

Posterior
    Same posterior object (MVP).

Source code in src/psyphy/inference/laplace.py
def from_map(self, map_posterior: Posterior) -> Posterior:
    """
    Return posterior approximation from MAP.

    Parameters
    ----------
    map_posterior : Posterior
        Posterior object from MAP optimization.

    Returns
    -------
    Posterior
        Same posterior object (MVP).
    """
    return map_posterior
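The non-stub version would set the Gaussian's variance to the inverse curvature (Hessian) of the negative log posterior at the MAP. A 1-D sketch with a finite-difference second derivative (neg_log_post is a made-up quadratic example, so the recovered sd should match its known sigma):

```python
import math

MU, SIGMA = 1.5, 0.7


def neg_log_post(theta):
    # Made-up example target: Gaussian with known mean MU and sd SIGMA.
    return 0.5 * ((theta - MU) / SIGMA) ** 2


def laplace_sd(theta_map, h=1e-4):
    """sd = H^(-1/2), with H the second derivative at the MAP (central differences)."""
    hess = (neg_log_post(theta_map + h)
            - 2.0 * neg_log_post(theta_map)
            + neg_log_post(theta_map - h)) / h**2
    return 1.0 / math.sqrt(hess)


sd = laplace_sd(MU)  # the MAP of this target is exactly MU
```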

MAPOptimizer

MAPOptimizer(
    steps: int = 500,
    learning_rate: float = 5e-05,
    momentum: float = 0.9,
    optimizer: GradientTransformation | None = None,
    *,
    track_history: bool = False,
    log_every: int = 10
)

Bases: InferenceEngine

MAP (Maximum A Posteriori) optimizer.

Parameters:

steps : int, default 500
    Number of optimization steps.
optimizer : GradientTransformation | None, default None
    Optax optimizer to use. Default: SGD with momentum.

Notes
  • Loss function = negative log posterior.
  • Gradients computed with jax.grad.

Create a MAP optimizer.

Parameters:

steps : int, default 500
    Number of optimization steps.
optimizer : GradientTransformation | None, default None
    Optax optimizer to use.
learning_rate : float, default 5e-5
    Learning rate for the default optimizer (SGD with momentum).
momentum : float, default 0.9
    Momentum for the default optimizer (SGD with momentum).
track_history : bool, default False
    When True, record loss history during fitting for plotting.
log_every : int, default 10
    Record every N steps (also records the last step).

Methods:

fit — Fit model parameters with MAP optimization.
get_history — Return (steps, losses) recorded during the last fit when tracking was enabled.

Attributes:

steps, optimizer, track_history, log_every, loss_steps (list[int]), loss_history (list[float])
Source code in src/psyphy/inference/map_optimizer.py
def __init__(
    self,
    steps: int = 500,
    learning_rate: float = 5e-5,
    momentum: float = 0.9,
    optimizer: optax.GradientTransformation | None = None,
    *,
    track_history: bool = False,
    log_every: int = 10,
):
    """Create a MAP optimizer.

    Parameters
    ----------
    steps : int
        Number of optimization steps.
    optimizer : optax.GradientTransformation | None
        Optax optimizer to use.
    learning_rate : float, optional
        Learning rate for the default optimizer (SGD with momentum).
    momentum : float, optional
        Momentum for the default optimizer (SGD with momentum).
    track_history : bool, optional
        When True, record loss history during fitting for plotting.
    log_every : int, optional
        Record every N steps (also records the last step).
    """
    self.steps = steps
    self.optimizer = optimizer or optax.sgd(learning_rate=learning_rate, momentum=momentum)
    self.track_history = track_history
    self.log_every = max(1, int(log_every))
    # Exposed after fit() when tracking is enabled
    self.loss_steps: list[int] = []
    self.loss_history: list[float] = []


fit

fit(
    model,
    data,
    init_params: dict | None = None,
    seed: int | None = None,
) -> Posterior

Fit model parameters with MAP optimization.

Parameters:

model : WPPM (required)
    Model instance.
data : ResponseData (required)
    Observed trials.
init_params : dict | None, default None
    Initial parameter PyTree to start optimization from. If provided, this takes precedence over the seed.
seed : int | None, default None
    PRNG seed used to draw initial parameters from the model's prior when init_params is not provided. If None, defaults to 0.

Returns:

Posterior
    Posterior wrapper around MAP params and model.

Source code in src/psyphy/inference/map_optimizer.py
def fit(self, model, data, init_params: dict | None = None, seed: int | None = None) -> Posterior:
    """
    Fit model parameters with MAP optimization.

    Parameters
    ----------
    model : WPPM
        Model instance.
    data : ResponseData
        Observed trials.
    init_params : dict | None, optional
        Initial parameter PyTree to start optimization from. If provided,
        this takes precedence over the seed.
    seed : int | None, optional
        PRNG seed used to draw initial parameters from the model's prior
        when init_params is not provided. If None, defaults to 0.

    Returns
    -------
    Posterior
        Posterior wrapper around MAP params and model.
    """

    def loss_fn(params):
        return -model.log_posterior_from_data(params, data)

    # Initialize parameters
    if init_params is not None:
        params = init_params
    else:
        rng_seed = 0 if seed is None else int(seed)
        params = model.init_params(jax.random.PRNGKey(rng_seed))
    opt_state = self.optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        # Ensure params and opt_state are JAX PyTrees for JIT compatibility
        loss, grads = jax.value_and_grad(loss_fn)(params)  # auto-diff
        updates, opt_state = self.optimizer.update(grads, opt_state, params)  # optimizer update
        params = optax.apply_updates(params, updates)  # apply updates
        # Only return JAX-compatible types (PyTrees of arrays, scalars)
        return params, opt_state, loss

    # clear any previous history
    if self.track_history:
        self.loss_steps.clear()
        self.loss_history.clear()

    for i in range(self.steps):
        params, opt_state, loss = step(params, opt_state)
        if self.track_history and ((i % self.log_every == 0) or (i == self.steps - 1)):
            # Pull scalar to host and record
            try:
                self.loss_steps.append(i)
                self.loss_history.append(float(loss))
            except Exception:
                # Best-effort; do not break fitting if logging fails
                pass

    return Posterior(params=params, model=model)

get_history

get_history() -> tuple[list[int], list[float]]

Return (steps, losses) recorded during the last fit when tracking was enabled.

Source code in src/psyphy/inference/map_optimizer.py
def get_history(self) -> tuple[list[int], list[float]]:
    """Return (steps, losses) recorded during the last fit when tracking was enabled."""
    return self.loss_steps, self.loss_history
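Stripped of JAX and Optax, the loop in fit() is momentum SGD on a scalar loss with periodic loss logging. A dependency-free sketch of the same structure (the quadratic loss_fn and analytic grad_fn stand in for the negative log posterior and jax.grad):

```python
def loss_fn(theta):
    # Stand-in negative log posterior with its minimum at theta = 3.
    return (theta - 3.0) ** 2


def grad_fn(theta):
    return 2.0 * (theta - 3.0)


def map_fit(steps=500, learning_rate=5e-2, momentum=0.9, log_every=10, theta=0.0):
    velocity = 0.0
    loss_steps, loss_history = [], []
    for i in range(steps):
        loss = loss_fn(theta)
        # Momentum SGD update (what optax.sgd(learning_rate, momentum) does).
        velocity = momentum * velocity - learning_rate * grad_fn(theta)
        theta += velocity
        # Record every log_every steps, plus the final step.
        if i % log_every == 0 or i == steps - 1:
            loss_steps.append(i)
            loss_history.append(loss)
    return theta, loss_steps, loss_history


theta_map, steps_log, losses = map_fit()
```

The returned (steps_log, losses) pair mirrors what get_history() exposes for plotting convergence.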

Base

base.py

Abstract base class for inference engines.

All inference engines must implement a fit(model, data) method that returns a Posterior object.

All inference engines (MAPOptimizer, LangevinSampler, LaplaceApproximation) subclass from this base.

Classes:

InferenceEngine — Abstract interface for inference engines.

MAP Optimizer

map_optimizer.py

MAP (Maximum A Posteriori) optimizer using Optax.

MVP implementation:
  • Gradient descent on the negative log posterior (equivalently, gradient ascent on the log posterior).
  • Defaults to SGD with momentum, but any Optax optimizer can be passed in.

Connections
  • Calls WPPM.log_posterior_from_data(params, data) as the objective.
  • Returns a Posterior object wrapping the MAP estimate.
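The objective handed to the optimizer is a negative log posterior: log-likelihood of the binary responses plus a log prior over parameters. A minimal stand-in illustrating that decomposition (the sigmoid link, standard-normal prior, and (x, y) trial format are assumptions for the sketch, not the WPPM likelihood):

```python
import math


def log_posterior(theta, data):
    """Illustrative log posterior: Bernoulli likelihood with a sigmoid link,
    plus a standard-normal log prior on theta (additive constants dropped)."""
    log_prior = -0.5 * theta**2
    log_lik = 0.0
    for x, y in data:  # x: signed stimulus intensity, y: 0/1 response
        p = 1.0 / (1.0 + math.exp(-theta * x))
        p = min(max(p, 1e-12), 1.0 - 1e-12)  # guard the logs
        log_lik += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return log_lik + log_prior


trials = [(0.5, 1), (1.0, 1), (-0.5, 0), (-1.0, 0)]
```

MAPOptimizer minimizes the negative of such a function; the trials above favor a positive slope, so theta = 1 scores higher than theta = -1.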

Classes:

MAPOptimizer — MAP (Maximum A Posteriori) optimizer.

Langevin Samplers

langevin.py

Langevin samplers for posterior inference.

Intended to implement:
  • Overdamped (unadjusted) Langevin algorithm (ULA).
  • Underdamped Langevin dynamics, e.g., with a BAOAB splitting scheme.

Used for posterior-aware trial placement (InfoGain).

MVP implementation:
  • Stub that returns an initial Posterior.
  • Future: implement underdamped Langevin dynamics (e.g., a BAOAB integrator).
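The planned BAOAB integrator splits each underdamped step into a half kick (B), a half drift (A), an exact Ornstein-Uhlenbeck refresh of the momentum (O), then A and B again. A pure-Python sketch on a harmonic potential, whose position marginal should be close to a standard normal (grad_U and baoab are illustrative, not psyphy API):

```python
import math
import random

random.seed(1)


def grad_U(x):
    # Stand-in potential U(x) = x**2 / 2, so the target density is standard normal.
    return x


def baoab(steps=20_000, dt=0.1, gamma=1.0, temperature=1.0, x=0.0, p=0.0):
    """One B-A-O-A-B pass per step (Leimkuhler-Matthews splitting)."""
    c1 = math.exp(-gamma * dt)                       # OU momentum decay
    c2 = math.sqrt(temperature * (1.0 - c1**2))      # OU noise scale
    xs = []
    for _ in range(steps):
        p -= 0.5 * dt * grad_U(x)                    # B: half kick
        x += 0.5 * dt * p                            # A: half drift
        p = c1 * p + c2 * random.gauss(0.0, 1.0)     # O: exact OU update
        x += 0.5 * dt * p                            # A: half drift
        p -= 0.5 * dt * grad_U(x)                    # B: half kick
        xs.append(x)
    return xs


xs = baoab()
mean = sum(xs) / len(xs)
var = sum((v - mean) ** 2 for v in xs) / len(xs)
```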

Classes:

LangevinSampler — Langevin sampler (stub).

Laplace Approximation

laplace.py

Laplace approximation to posterior.

Approximates the posterior with a Gaussian: N(mean = MAP estimate, covariance = H^-1), where H is the Hessian of the negative log posterior at the MAP.

Provides posterior.sample() cheaply. Useful for InfoGainPlacement when only a MAP fit is available.

MVP implementation:
  • Stub that just returns the MAP posterior.
  • Future: compute the covariance from the Hessian at the MAP params.
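Once the covariance H^-1 is known, posterior.sample() reduces to cheap Gaussian draws mean + L * xi, with L a Cholesky factor of H^-1. A dependency-free 2-D sketch with a made-up diagonal Hessian (so the Cholesky factor is just per-coordinate standard deviations):

```python
import math
import random

random.seed(2)

map_mean = [1.0, -2.0]
hessian = [[4.0, 0.0], [0.0, 25.0]]  # made-up neg-log-posterior curvature at the MAP

# Sigma = H^-1; for a diagonal H the Cholesky factor of Sigma is also diagonal,
# with entries 1/sqrt(H_ii), i.e. per-coordinate standard deviations.
chol = [1.0 / math.sqrt(hessian[0][0]), 1.0 / math.sqrt(hessian[1][1])]


def sample():
    """Draw one sample from N(map_mean, H^-1)."""
    return [map_mean[i] + chol[i] * random.gauss(0.0, 1.0) for i in range(2)]


draws = [sample() for _ in range(20_000)]
sd0 = (sum((d[0] - 1.0) ** 2 for d in draws) / len(draws)) ** 0.5
mean1 = sum(d[1] for d in draws) / len(draws)
```

Each draw costs only a few Gaussian variates, which is what makes the Laplace posterior attractive for sample-hungry criteria like InfoGainPlacement.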

Classes:

LaplaceApproximation — Laplace approximation around MAP estimate.