
Covariance Field: Visualizing Spatially-Varying Thresholds

This example demonstrates how to work with the WPPMCovarianceField abstraction for visualizing spatially-varying perceptual thresholds in the Wishart Process Psychophysical Model (WPPM).

The covariance field provides a clean, functional interface for:

  • Evaluating covariance matrices at different stimulus locations

  • Visualizing threshold contours (as ellipses)

  • Understanding how perceptual noise varies across stimulus space

You can run this complete example yourself with:

python docs/examples/covariance_field/covariance_field_demo.py


Mathematical Background

A covariance field \(\Sigma(x)\) maps stimulus locations to covariance matrices, representing how perceptual thresholds vary across space.

Wishart Process Covariance

The full WPPM uses a Wishart Process to model spatially-varying, full covariance:

\[ \Sigma(x) = U(x) \, U(x)^\top + \lambda I \]

where:

\[ U(x) = \sum_{ij} W_{ij} \, \phi_{ij}(x) \]
  • \(\phi_{ij}(x)\): Chebyshev basis functions
  • \(W_{ij}\): Learned coefficients
  • \(\lambda\): Numerical stabilizer (diag_term)
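The formula can be written out in a few lines of JAX. This is a sketch under stated assumptions, not psyphy's implementation: φ_ij(x) is taken as the product T_i(x_1) T_j(x_2) of 1-D Chebyshev polynomials, stimuli in [0, 1] are rescaled to [-1, 1], each W_ij is a 3×3 matrix (embedding dimension 2 + 1, matching extra_dims=1 in Example 1), and W is drawn at random rather than from the prior. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def chebyshev_1d(t, degree):
    """Chebyshev polynomials T_0..T_degree at t, via T_{k+1} = 2t T_k - T_{k-1}."""
    polys = [jnp.ones_like(t), t]
    for _ in range(2, degree + 1):
        polys.append(2.0 * t * polys[-1] - polys[-2])
    return jnp.stack(polys[: degree + 1])


def basis_2d(x, degree):
    """phi_ij(x) = T_i(x_0) * T_j(x_1), shape (degree + 1, degree + 1)."""
    t = 2.0 * x - 1.0  # assumption: stimuli live in [0, 1], rescaled to [-1, 1]
    return jnp.outer(chebyshev_1d(t[0], degree), chebyshev_1d(t[1], degree))


def wishart_cov(W, x, diag_term=1e-3):
    """Sigma(x) = U(x) U(x)^T + lambda * I with U(x) = sum_ij W_ij phi_ij(x)."""
    degree = W.shape[0] - 1
    phi = basis_2d(x, degree)               # (deg+1, deg+1) basis values at x
    U = jnp.einsum("ijab,ij->ab", W, phi)   # weighted sum of coefficient matrices
    return U @ U.T + diag_term * jnp.eye(U.shape[0])


degree, emb = 4, 3  # 5x5 basis; embedding dimension 3 is an assumption
W = 0.1 * jr.normal(jr.PRNGKey(0), (degree + 1, degree + 1, emb, emb))
Sigma = wishart_cov(W, jnp.array([0.5, 0.5]))
print(Sigma.shape)  # (3, 3)
```

The U U^T form makes every Σ(x) symmetric positive semidefinite by construction, and the λI term keeps it strictly positive definite.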

Imports

Required Imports
import jax
import jax.numpy as jnp
import jax.random as jr

from psyphy.model.covariance_field import WPPMCovarianceField
from psyphy.model.noise import GaussianNoise
from psyphy.model.prior import Prior
from psyphy.model.task import OddityTask
from psyphy.model.wppm import WPPM

Example 1: Single Point Threshold Ellipse

Creating a Wishart Covariance Field

First, we create a WPPM model in Wishart mode with 5×5 basis functions (basis_degree=4) and sample a covariance field from the prior:

Create Wishart Model and Sample Field
# Create Wishart model
model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=4, variance_scale=0.03, decay_rate=0.3),
    task=OddityTask(),
    noise=GaussianNoise(sigma=0.1),
    basis_degree=4,  # Wishart mode with 5x5 basis functions
    extra_dims=1,
    diag_term=1e-3,
)

# Sample covariance field from prior
key = jr.PRNGKey(42)
field = WPPMCovarianceField.from_prior(model, key)

# Evaluate at a single reference point
x_ref = jnp.array([0.5, 0.5])
Sigma_full = field(x_ref)  # callable interface!
# Extract the 2x2 block (input dimensions only)
Sigma_ref = Sigma_full[:2, :2]

Key parameters:

  • basis_degree=4: Activates Wishart mode with 5×5 Chebyshev basis functions (degrees 0-4)

  • extra_dims=1: Adds embedding dimensions for richer covariance structure

  • variance_scale=0.03, decay_rate=0.3: Control spatial smoothness

  • field(x): The callable interface lets you evaluate the field like a mathematical function

Visualizing a Single Threshold Ellipse

We can visualize the threshold at a single point as an ellipse:

Threshold ellipse at point (0.5, 0.5). The ellipse encodes the covariance structure: its orientation shows the direction of correlation, its elongation the anisotropy, and its size the threshold magnitude.
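Drawing the ellipse comes down to an eigendecomposition of the 2×2 block. A minimal sketch, assuming the contour is the set of points at Mahalanobis distance scale from the reference point (the plotting code behind the figure may use a different scale); the helper threshold_ellipse and the demo matrix are hypothetical.

```python
import jax.numpy as jnp


def threshold_ellipse(center, Sigma2, scale=1.0, n_points=100):
    """Return (2, n_points) points on {p : (p - c)^T Sigma2^{-1} (p - c) = scale**2}."""
    evals, evecs = jnp.linalg.eigh(Sigma2)      # eigendecomposition of the symmetric 2x2 block
    L = evecs * jnp.sqrt(evals)                 # column k scaled by sqrt(eval k), so L @ L.T == Sigma2
    theta = jnp.linspace(0.0, 2.0 * jnp.pi, n_points)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)])  # unit circle, shape (2, n_points)
    return center[:, None] + scale * (L @ circle)          # map the circle onto the ellipse


# Hypothetical 2x2 block standing in for Sigma_ref
Sigma_demo = jnp.array([[0.02, 0.01], [0.01, 0.015]])
pts = threshold_ellipse(jnp.array([0.5, 0.5]), Sigma_demo)
```

With matplotlib, plt.plot(pts[0], pts[1]) draws the contour; passing x_ref and Sigma_ref from Example 1 instead of the demo values would draw that ellipse.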


Example 2: Covariance Field Grid

Batch Evaluation Across Stimulus Space

To see how thresholds vary across space, we evaluate the covariance field at multiple points:

Create Grid and Evaluate
# Create grid of reference points
n_grid = 5
x_vals = jnp.linspace(0.15, 0.85, n_grid)
y_vals = jnp.linspace(0.15, 0.85, n_grid)
X_grid, Y_grid = jnp.meshgrid(x_vals, y_vals)
grid_points = jnp.stack([X_grid.ravel(), Y_grid.ravel()], axis=1)

# Evaluate covariance at all points using batch method
Sigmas_full = field.cov_batch(grid_points)
# Extract 2x2 blocks
Sigmas_grid = Sigmas_full[:, :2, :2]

# Also works with JAX vmap on the callable!
Sigmas_vmap_full = jax.vmap(field)(grid_points)
Sigmas_vmap = Sigmas_vmap_full[:, :2, :2]

Batch evaluation methods:

  1. field.cov_batch(X): explicit batch method
  2. jax.vmap(field)(X): JAX vmap over the callable interface

Both return the same stack of covariance matrices and avoid a Python-level loop over points.
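The equivalence is easy to check on any pure function of one point. In this sketch toy_field is a hypothetical stand-in for field, so psyphy is not needed to run it: a Python loop and jax.vmap produce the same stacked result, but vmap traces the function once and evaluates all points in one vectorized call (and can be wrapped in jax.jit).

```python
import jax
import jax.numpy as jnp


def toy_field(x):
    """Hypothetical stand-in for field(x): maps a (2,) point to a (2, 2) covariance."""
    U = 0.1 * jnp.stack([jnp.stack([1.0 + x[0], x[1]]), jnp.array([0.0, 1.0])])
    return U @ U.T + 1e-3 * jnp.eye(2)


xs = jnp.linspace(0.15, 0.85, 5)
X, Y = jnp.meshgrid(xs, xs)
grid_points = jnp.stack([X.ravel(), Y.ravel()], axis=1)       # (25, 2)

loop_result = jnp.stack([toy_field(x) for x in grid_points])  # one Python call per point
vmap_result = jax.vmap(toy_field)(grid_points)                # single vectorized call
print(vmap_result.shape)  # (25, 2, 2)
print(bool(jnp.allclose(loop_result, vmap_result)))  # True
```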

Visualizing the Ellipse Field

5×5 grid of uncertainty ellipses showing spatial variation. Note how ellipse size, shape, and orientation vary smoothly across the stimulus space.
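To quantify how size, shape, and orientation vary across the grid, each 2×2 block can be reduced to its semi-axes and the angle of its major axis. A minimal sketch for a batch such as Sigmas_grid, assuming a 1-sigma contour (the figure may scale ellipses differently); ellipse_params and the demo matrices are hypothetical.

```python
import jax.numpy as jnp


def ellipse_params(Sigmas2):
    """Semi-axis lengths and major-axis angle for a batch of 2x2 covariances, shape (n, 2, 2)."""
    evals, evecs = jnp.linalg.eigh(Sigmas2)   # batched; eigenvalues ascending along the last axis
    major = jnp.sqrt(evals[:, 1])             # semi-major axis of the 1-sigma contour
    minor = jnp.sqrt(evals[:, 0])             # semi-minor axis
    # The major-axis direction is eigenvector column 1; the angle is defined modulo pi
    angle = jnp.arctan2(evecs[:, 1, 1], evecs[:, 0, 1])
    return major, minor, angle


# Hypothetical blocks: one stretched along x, one stretched along y
Sigmas_demo = jnp.array([[[0.04, 0.0], [0.0, 0.01]],
                         [[0.01, 0.0], [0.0, 0.04]]])
major, minor, angle = ellipse_params(Sigmas_demo)
```

The ratio major / minor measures elongation, and π · major · minor gives the ellipse area, a single number for threshold magnitude at each grid point.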


Example 3: Custom Covariance Fields

Sampling Different Fields

You can sample multiple fields from the prior or create fields from custom parameters:

Custom Field from Different Prior Sample
key_custom = jr.PRNGKey(999)
custom_field = WPPMCovarianceField.from_prior(model, key_custom)
# Evaluate at center point
x_center = jnp.array([0.5, 0.5])
Sigma_custom_full = custom_field(x_center)
Sigma_custom = Sigma_custom_full[:2, :2]

Visualizing Custom Fields

A different prior sample produces a different spatial pattern.
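Each key yields a different draw of the coefficients W_ij, and hence a different field. The sketch below assumes a prior in which every coefficient is Gaussian with variance variance_scale * decay_rate**(i + j), so higher-order Chebyshev terms shrink toward zero and the field stays smooth. This prior form and the helper sample_weights are assumptions for illustration, not necessarily what psyphy's Prior does.

```python
import jax.numpy as jnp
import jax.random as jr


def sample_weights(key, degree=4, emb=3, variance_scale=0.03, decay_rate=0.3):
    """Draw W with W[i, j] ~ N(0, variance_scale * decay_rate**(i + j)) (assumed prior form)."""
    i = jnp.arange(degree + 1)
    var = variance_scale * decay_rate ** (i[:, None] + i[None, :])  # (deg+1, deg+1) variances
    noise = jr.normal(key, (degree + 1, degree + 1, emb, emb))
    return jnp.sqrt(var)[:, :, None, None] * noise  # scale each coefficient matrix by its std


W_a = sample_weights(jr.PRNGKey(42))
W_b = sample_weights(jr.PRNGKey(999))
print(W_a.shape)  # (5, 5, 3, 3)
```

Reusing a key reproduces the same field exactly, which is why the examples pass an explicit jr.PRNGKey.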


API Reference

Construction Methods

| Method | Description | Use case |
|---|---|---|
| from_prior(model, key) | Sample from prior | Initialization, visualization |
| from_posterior(posterior) | Extract fitted field | Analysis after fitting |
| from_params(model, params) | Custom parameters | Testing, or loading fitted parameters (e.g., from a different subject) |

Evaluation Methods

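| Method | Input shape | Output shape | Description |
|---|---|---|---|
| field(x) | (d_in,) | (d, d) | Callable interface |
| field.cov(x) | (d_in,) | (d, d) | Explicit covariance |
| field.cov_batch(X) | (n, d_in) | (n, d, d) | Batch evaluation |
| jax.vmap(field)(X) | (n, d_in) | (n, d, d) | JAX vmap |

where:

  • d_in = input_dim (stimulus dimensions)

  • d = embedding_dim (Wishart mode), which includes the extra_dims embedding dimensions; this is why the examples slice [:2, :2] to recover the input block

  • n = number of evaluation points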
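Because an ellipsis index covers any number of leading batch axes, one slicing helper can extract the input block from both single and batched outputs. The helper input_block is hypothetical, and identity matrices stand in for real field outputs with an assumed embedding_dim of 3.

```python
import jax.numpy as jnp


def input_block(Sigma, input_dim=2):
    """Top-left (input_dim, input_dim) block; works for (d, d) and (n, d, d) arrays."""
    return Sigma[..., :input_dim, :input_dim]


single = jnp.eye(3)                    # stand-in for field(x)
batch = jnp.stack([jnp.eye(3)] * 25)   # stand-in for field.cov_batch(X)
print(input_block(single).shape)  # (2, 2)
print(input_block(batch).shape)   # (25, 2, 2)
```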

Properties

| Property | Returns | Description |
|---|---|---|
| field.is_wishart_mode | bool | True if the field is spatially varying |
| field.model | WPPM | Associated model |
| field.params | dict | Parameter dictionary |

Integration with WPPM Workflow

Here's how the covariance field would integrate into the full WPPM workflow. This is an outline rather than a standalone script: it assumes key and grid_points from the examples above, a dataset data, an imported MAPOptimizer, and a user-supplied plotting helper plot_ellipse_field.

# 1. Create model
model = WPPM(
    input_dim=2,
    prior=Prior(input_dim=2, basis_degree=5),
    task=OddityTask(),
    basis_degree=5,
)

# 2. Visualize prior uncertainty (optional)
field_prior = WPPMCovarianceField.from_prior(model, key)
plot_ellipse_field(field_prior, grid_points)  # user-defined plotting helper

# 3. Fit model to data
model.fit(data, inference=MAPOptimizer(steps=500))

# 4. Extract learned covariance field
posterior = model.posterior(kind="parameter")
field_fitted = WPPMCovarianceField.from_posterior(posterior)

# 5. Visualize learned uncertainty
plot_ellipse_field(field_fitted, grid_points)

Best Practices

# Use the callable interface for single points
Sigma = field(x)

# Use batch methods for multiple points (grid_points has shape (n, 2))
Sigmas = field.cov_batch(grid_points)

# Alternatively, use vmap for functional composition
Sigmas = jax.vmap(field)(grid_points)

Avoid

# Don't loop over individual points (inefficient)
for x in grid_points:
    Sigma = field(x)  # bad! Use field.cov_batch(grid_points)

# Don't pass a batch to the single-point method
field.cov(grid_points)  # error! Use cov_batch()
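If calling code may receive either a single point or a batch, a small dispatcher keeps the shapes straight. This is a hypothetical helper (evaluate is not part of psyphy), shown with a toy stand-in field so it runs on its own; with a real field you would pass field itself.

```python
import jax
import jax.numpy as jnp


def evaluate(field_fn, X):
    """Evaluate field_fn on one point (shape (d_in,)) or a batch (shape (n, d_in))."""
    X = jnp.asarray(X)
    if X.ndim == 1:
        return field_fn(X)             # single point
    if X.ndim == 2:
        return jax.vmap(field_fn)(X)   # vectorized over rows, no Python loop
    raise ValueError(f"expected shape (d_in,) or (n, d_in), got {X.shape}")


def toy_field(x):
    """Hypothetical stand-in for field(x): a (2,) point -> (2, 2) isotropic covariance."""
    return 0.01 * (1.0 + x[0] + x[1]) * jnp.eye(2)


print(evaluate(toy_field, jnp.array([0.5, 0.5])).shape)  # (2, 2)
print(evaluate(toy_field, jnp.zeros((25, 2))).shape)     # (25, 2, 2)
```

Raising on any other rank turns the silent shape mistake shown above into an immediate, descriptive error.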