from psyphy.data.dataset import ResponseData
from psyphy.model import WPPM, Prior, TwoAFC
from psyphy.inference.map_optimizer import MAPOptimizer
import optax
import jax.numpy as jnp
# --- Data -------------------------------------------------------------------
# ResponseData collects trials of the form (reference, probe, response):
#   ref   -- reference stimulus, shape (input_dim,)
#   probe -- probe stimulus, same shape as ref
#   resp  -- binary response in {0, 1}; TwoAFC log-likelihood treats 1 as
#            "correct"
data = ResponseData()
ref0 = jnp.array([0.0, 0.0])  # both trials share the same reference point

# Probe offset along the first axis; the subject responded 1 ("correct").
data.add_trial(ref=ref0, probe=jnp.array([0.1, 0.0]), resp=1)

# Probe offset along the second axis; the subject responded 0 ("incorrect").
data.add_trial(ref=ref0, probe=jnp.array([0.0, 0.1]), resp=0)
# --- Model ------------------------------------------------------------------
# Single source of truth for stimulus dimensionality. The model, its default
# prior, and the stimulus arrays used in this script (length 2) must all
# agree on it, so it is spelled once instead of being repeated.
INPUT_DIM = 2
model = WPPM(
    input_dim=INPUT_DIM,
    prior=Prior.default(INPUT_DIM),
    task=TwoAFC(),  # task object supplying the response log-likelihood
)
# --- Fit --------------------------------------------------------------------
# Plain SGD with momentum (learning rate 5e-4, momentum 0.9) drives the
# MAPOptimizer for a fixed 500 steps; .fit returns the fitted posterior.
opt = optax.sgd(learning_rate=5e-4, momentum=0.9)
fitter = MAPOptimizer(steps=500, optimizer=opt)
posterior = fitter.fit(model, data)
# --- Predictions ------------------------------------------------------------
reference = jnp.array([0.0, 0.0])
probe = jnp.array([0.05, 0.05])

# Predicted probability for one (reference, probe) pair, passed as a tuple.
# NOTE(review): presumably P(resp == 1) under the TwoAFC task -- confirm.
p = posterior.predict_prob((reference, probe))

# Threshold contour around the same reference stimulus.
contour = posterior.predict_thresholds(reference=reference)