Build your own optimizer with JAX
This example is for users who want to build their own optimizers with JAX.
For that purpose, we expose how the MAP optimizer is implemented in psyphy.
You can run the toy example with this from-scratch implementation yourself with the following script:

```bash
python docs/examples/mvp/offline_fit_mvp_with_map_optimizer.py
```
Model

```python
# Here we fit a new WPPM to the simulated data.
# For each trial i with response y_i, the model evaluates
#     log p(y_i | ref_i, probe_i, θ, task),
# where p(y = 1 | ·) is obtained with the same closed-form mapping as in the simulation,
# but with parameters θ (to be estimated).
import jax

from psyphy.model.prior import Prior
from psyphy.model.noise import GaussianNoise
from psyphy.model.wppm import WPPM

prior = Prior.default(input_dim=2, scale=0.5)  # Gaussian prior on log_diag
noise = GaussianNoise(sigma=1.0)               # additive isotropic Gaussian noise

# `task` is the task object defined earlier in this example (also used for simulation)
model = WPPM(input_dim=2, prior=prior, task=task, noise=noise)
init_params = model.init_params(jax.random.PRNGKey(42))
```
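To make the objective concrete: the quantity the MAP optimizer works with is the log posterior, i.e. the sum of the per-trial log likelihoods described in the comments above plus the log prior on θ. The sketch below is only conceptual; `per_trial_log_lik` and `log_prior` are hypothetical stand-ins rather than psyphy functions, and in the code that follows this role is played by `model.log_posterior_from_data`.

```python
def log_posterior_sketch(params, data, per_trial_log_lik, log_prior):
    """Conceptual MAP objective: sum of per-trial log likelihoods plus log prior."""
    # sum_i log p(y_i | ref_i, probe_i, params, task)
    log_lik = sum(per_trial_log_lik(params, trial) for trial in data)
    # + log p(params): the Gaussian prior acts as a regularizer on the estimate
    return log_lik + log_prior(params)
```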
Training / Fitting
Here, we define the training loop that minimizes the model's negative log posterior using stochastic gradient descent with momentum, both with psyphy and from scratch.
Fitting with psyphy

```python
# Optimizer hyperparameters
steps = 1000
lr = 2e-2
momentum = 0.9

from psyphy.inference.map_optimizer import MAPOptimizer

optimizer = MAPOptimizer(
    steps=steps,
    learning_rate=lr,
    momentum=momentum,
    track_history=True,
    log_every=10,
)

# [Optional] Initialize parameters explicitly
# (otherwise fit() falls back to a prior sample with seed=0)
init_params = model.init_params(jax.random.PRNGKey(42))

# Fit the model to the data; this returns a Posterior wrapper around the fitted params and model.
# To see the training loop that is used, check the source of MAPOptimizer.fit()
# or the from-scratch snippet below.
posterior = optimizer.fit(model, data, init_params=init_params)
```
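As noted in the comment above, `fit()` falls back to a prior sample drawn with seed 0 when no `init_params` are passed. If you want to experiment with different starting points, you can draw them explicitly; printing the leaf shapes is a quick sanity check that the parameters are an ordinary JAX PyTree. The snippet below only reuses calls already shown in this example (the exact leaf names depend on the WPPM implementation).

```python
import jax

# Draw an alternative starting point from the prior (any integer seed works)
init_params_alt = model.init_params(jax.random.PRNGKey(7))

# Parameters are a JAX PyTree (nested containers of arrays); inspect the leaf shapes
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, init_params_alt))
```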
Implementing optimizers with JAX
-- exposing psyphy's MAP implementation in JAX
Below we illustrate the implementation of MAP, one of the optimizers implemented in psyphy's inference module. In particular, we point out why and how we use JAX and Optax.
A note on JAX:
The key feature here is JAX’s Just-In-Time (JIT) compilation, which transforms our Python function into a single, optimized computation graph that runs efficiently on CPU, GPU, or TPU.
To make this work, we represent parameters and optimizer states as PyTrees (nested dictionaries or tuples of arrays) — a core JAX data structure that supports efficient vectorization and differentiation.
This approach lets us scale optimization and inference routines from small CPU experiments to large GPU-accelerated Bayesian models with minimal code changes.
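As a standalone illustration of these two ideas (independent of psyphy), the toy snippet below differentiates and JIT-compiles a scalar function of a small parameter PyTree; the leaf names are made up for the example. The gradients come back with exactly the same PyTree structure as the parameters, which is what keeps the MAP loop below so short.

```python
import jax
import jax.numpy as jnp

# Toy PyTree of parameters (leaf names are illustrative only)
toy_params = {"log_diag": jnp.zeros(2), "bias": jnp.array(0.5)}

def toy_loss(p):
    # Any scalar function of the PyTree leaves can be differentiated
    return jnp.sum(jnp.exp(p["log_diag"])) + p["bias"] ** 2

# value_and_grad and jit operate directly on the PyTree
loss, grads = jax.jit(jax.value_and_grad(toy_loss))(toy_params)
print(float(loss))                                       # scalar loss value
print(jax.tree_util.tree_map(lambda g: g.shape, grads))  # grads mirror the parameter PyTree
```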
From scratch: a training loop exposing psyphy's MAP implementation in JAX

```python
import jax
import optax

# Optimizer hyperparameters
steps = 100
lr = 2e-2
momentum = 0.9

# Use SGD + momentum from Optax
opt = optax.sgd(learning_rate=lr, momentum=momentum)

# Define the loss as the negative log posterior (we minimize it)
def _loss_fn(params):
    return -model.log_posterior_from_data(params, data)

# Start from the prior initialization
params = init_params          # PyTree of parameters (dict of arrays)
opt_state = opt.init(params)  # PyTree of optimizer state

# Perform a single optimization step.
# Each step computes gradients via automatic differentiation (jax.value_and_grad),
# updates the parameters, and returns the new ones, all as JAX PyTrees:
# lightweight nested structures of arrays that JAX can efficiently traverse and transform.
@jax.jit
def _step(params, opt_state):
    # params and opt_state must be JAX PyTrees for JIT compatibility
    # (e.g., dicts of arrays, not custom Python objects)
    loss, grads = jax.value_and_grad(_loss_fn)(params)         # auto-diff
    updates, opt_state = opt.update(grads, opt_state, params)  # optimizer update
    params = optax.apply_updates(params, updates)              # apply updates
    # Only return JAX-compatible types (PyTrees of arrays, scalars)
    return params, opt_state, loss

# Training loop: run `steps` iterations of SGD + momentum
# and track the loss every 10 steps
loss_iters: list[int] = []
loss_values: list[float] = []
for i in range(steps):
    params, opt_state, loss = _step(params, opt_state)  # single JIT-compiled step
    if (i % 10 == 0) or (i == steps - 1):
        loss_iters.append(i)
        loss_values.append(float(loss))

fitted_params = params  # maximum a posteriori (MAP) estimate of θ after training
```
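A common JAX variation on the loop above (not how psyphy implements MAPOptimizer, just a sketch under the same assumptions) is to express the whole loop with `jax.lax.scan`, so the entire optimization, not just a single step, is compiled as one program and the per-step losses are returned as an array. `fit_with_scan` is a hypothetical helper name.

```python
import jax
import optax

def fit_with_scan(loss_fn, init_params, steps=100, lr=2e-2, momentum=0.9):
    # Hypothetical helper, not part of psyphy
    opt = optax.sgd(learning_rate=lr, momentum=momentum)

    def step(carry, _):
        params, opt_state = carry
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), loss  # carry the state forward, stack the loss

    (params, _), losses = jax.lax.scan(
        step, (init_params, opt.init(init_params)), xs=None, length=steps
    )
    return params, losses  # losses has shape (steps,)

# Usage with the objects defined above:
# fitted_params, losses = fit_with_scan(_loss_fn, init_params, steps=steps, lr=lr, momentum=momentum)
```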