
Acquisition Functions

Acquisition functions for Bayesian optimization and active learning in psychophysical experiments.

Package Overview


acquisition

Acquisition functions for adaptive experimental design.

This module provides:
  • AcquisitionFunction: Protocol for acquisition functions
  • optimize_acqf(): Functional interface for optimization
  • Common acquisition functions: Expected Improvement (EI), Upper Confidence Bound (UCB), Information Gain (IG)

Design

Unlike BoTorch's class-based approach, we use a functional style:

    acq_fn = lambda X: expected_improvement(model.posterior(X), best_f)
    X_next = optimize_acqf(acq_fn, bounds, q=1)

This is simpler and more composable than inheritance hierarchies.
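
Because acquisition functions are plain callables, combining them is ordinary function composition. A minimal sketch (the blend helper and weight are illustrative, not part of the package):

    def blend(acq_a, acq_b, w=0.5):
        # a weighted sum of two acquisition functions is itself an acquisition function
        return lambda X: w * acq_a(X) + (1.0 - w) * acq_b(X)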

Available Functions
  • expected_improvement: Maximize expected improvement over best observation
  • upper_confidence_bound: Balance exploration vs exploitation
  • probability_of_improvement: Maximize probability of improvement
  • mutual_information: Maximize information gain (Bayesian Active Learning by Disagreement (BALD))
Optimization Methods
  • optimize_acqf_discrete: Exhaustive search over candidate set
  • optimize_acqf: Gradient-based optimization (Optax)
  • optimize_acqf_random: Random search baseline

Examples:

>>> # Discrete optimization
>>> candidates = jnp.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
>>> acq_values = expected_improvement(model.posterior(candidates), best_f=0.5)
>>> X_next = candidates[jnp.argmax(acq_values)]

>>> # Continuous optimization with gradient descent
>>> def acq_fn(X):
...     return expected_improvement(model.posterior(X), best_f=0.5)
>>> X_next = optimize_acqf(
...     acq_fn,
...     bounds=jnp.array([[0.0, 1.0], [0.0, 1.0]]),
...     q=1,
...     method="gradient",
... )

Classes:

Name Description
AcquisitionFunction

Protocol for acquisition functions.

Functions:

Name Description
log_expected_improvement

Log expected improvement for numerical stability.

optimize_acqf

Optimize acquisition function over continuous domain.

optimize_acqf_discrete

Optimize acquisition function over discrete candidate set.

optimize_acqf_random

Optimize acquisition function via random search.

Base


base

base.py

Base protocol for acquisition functions.

Design

We use Protocol (not ABC) for maximum flexibility. An acquisition function is just any callable that:
  1. Takes test points X
  2. Returns scalar scores (higher = better)

This enables functional composition without inheritance.

Classes:

Name Description
AcquisitionFunction

Protocol for acquisition functions.

Attributes:

Name Description
AcqFn

Type alias for acquisition function callables.

AcqFn

AcqFn = Callable[[ndarray], ndarray]
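
A short illustration of the alias in use (a sketch; the import path follows this page's source layout, and mean_score is a toy scorer):

    import jax.numpy as jnp
    from psyphy.acquisition.base import AcqFn

    def mean_score(X: jnp.ndarray) -> jnp.ndarray:
        # toy acquisition: score each candidate by its first coordinate
        return X[:, 0]

    acq: AcqFn = mean_score  # satisfies Callable[[ndarray], ndarray]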

AcquisitionFunction

Bases: Protocol

Protocol for acquisition functions.

An acquisition function scores candidate points for selection in adaptive experimental design. Higher scores indicate more valuable points.

Examples:

>>> # Function-based acquisition
>>> def my_acquisition(X):
...     posterior = model.posterior(X)
...     return posterior.mean + 2.0 * jnp.sqrt(posterior.variance)
>>>
>>> X_next = optimize_acqf(my_acquisition, bounds, q=1)

>>> # Lambda-based acquisition
>>> acq_fn = lambda X: -posterior.variance  # Minimize uncertainty
>>> X_next = optimize_acqf(acq_fn, bounds, q=1)
Notes

We deliberately do NOT use a class hierarchy (unlike BoTorch's AcquisitionFunction base class). Functional composition is simpler and more flexible for research code.

If you need stateful acquisition (e.g., caching), use a callable class:

    class CachedAcquisition:
        def __init__(self, model):
            self.model = model
            self._cache = {}

        def __call__(self, X):
            # Use cache...
            return scores
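
A runnable version of that pattern, for illustration (the byte-string cache key and the UCB-style score are assumptions, not psyphy API):

    import numpy as np
    import jax.numpy as jnp

    class CachedAcquisition:
        """Callable acquisition that memoizes scores per candidate array."""

        def __init__(self, model):
            self.model = model
            self._cache = {}

        def __call__(self, X):
            key = np.asarray(X).tobytes()  # hashable key for this candidate set
            if key not in self._cache:
                post = self.model.posterior(X)
                # any scoring rule works; a UCB-style score is shown here
                self._cache[key] = post.mean + 2.0 * jnp.sqrt(post.variance)
            return self._cache[key]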

Expected Improvement


expected_improvement

expected_improvement.py

Expected Improvement (EI) acquisition function.

The most popular acquisition function for Bayesian optimization. Balances exploration (high uncertainty) and exploitation (high mean).

References

Mockus, J., Tiesis, V., & Zilinskas, A. (1978). The application of Bayesian methods for seeking the extremum. Towards Global Optimization, 2, 117-129.

Functions:

Name Description
expected_improvement

Expected improvement acquisition function.

log_expected_improvement

Log expected improvement for numerical stability.

expected_improvement

expected_improvement(
    posterior: PredictivePosterior,
    best_f: float,
    maximize: bool = True,
) -> ndarray

Expected improvement acquisition function.

Computes E[max(0, f(x) - best_f)] for each candidate point.

Parameters:

posterior : PredictivePosterior
    Predictive posterior p(f(X*) | data) (required)
best_f : float
    Best observed value so far (required)
maximize : bool, default=True
    If True, maximize (higher is better). If False, minimize (lower is better).

Returns:

ndarray, shape (n_candidates,)
    EI values for each candidate

Examples:

>>> # Basic usage
>>> posterior = model.posterior(X_candidates, probes=probes)
>>> best_f = jnp.max(y_observed)
>>> ei = expected_improvement(posterior, best_f)
>>> X_next = X_candidates[jnp.argmax(ei)]

>>> # Minimization
>>> best_f = jnp.min(y_observed)
>>> ei = expected_improvement(posterior, best_f, maximize=False)

>>> # With optimization
>>> def acq_fn(X):
...     return expected_improvement(model.posterior(X), best_f=0.8, maximize=True)
>>> X_next, ei_val = optimize_acqf(acq_fn, bounds, q=1)
Notes

For psychophysics:
  • best_f is typically the highest accuracy observed
  • Use maximize=True to find points that maximize accuracy
  • EI naturally balances exploration (high variance) and exploitation (high mean)

Mathematical Details

Let μ(x), σ(x) be the posterior mean and std at x. Let u = (μ(x) - best_f) / σ(x) (standardized improvement).

Then:
    EI(x) = σ(x) * [u * Φ(u) + φ(u)]

where Φ is the standard normal CDF and φ is the PDF.

When σ(x) = 0 (no uncertainty), EI(x) = max(0, μ(x) - best_f).
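
As a numeric sanity check of that closed form, evaluated standalone in plain JAX (independent of psyphy):

    import jax.numpy as jnp
    from jax.scipy import stats

    mean, std, best_f = jnp.array([0.6]), jnp.array([0.2]), 0.5
    u = (mean - best_f) / std                               # u = 0.5
    ei = std * (u * stats.norm.cdf(u) + stats.norm.pdf(u))
    print(ei)  # ~0.1396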

Source code in src/psyphy/acquisition/expected_improvement.py
def expected_improvement(
    posterior: PredictivePosterior,
    best_f: float,
    maximize: bool = True,
) -> jnp.ndarray:
    """
    Expected improvement acquisition function.

    Computes E[max(0, f(x) - best_f)] for each candidate point.

    Parameters
    ----------
    posterior : PredictivePosterior
        Predictive posterior p(f(X*) | data)
    best_f : float
        Best observed value so far
    maximize : bool, default=True
        If True, maximize (higher is better).
        If False, minimize (lower is better).

    Returns
    -------
    jnp.ndarray, shape (n_candidates,)
        EI values for each candidate

    Examples
    --------
    >>> # Basic usage
    >>> posterior = model.posterior(X_candidates, probes=probes)
    >>> best_f = jnp.max(y_observed)
    >>> ei = expected_improvement(posterior, best_f)
    >>> X_next = X_candidates[jnp.argmax(ei)]

    >>> # Minimization
    >>> best_f = jnp.min(y_observed)
    >>> ei = expected_improvement(posterior, best_f, maximize=False)

    >>> # With optimization
    >>> def acq_fn(X):
    ...     return expected_improvement(model.posterior(X), best_f=0.8, maximize=True)
    >>> X_next, ei_val = optimize_acqf(acq_fn, bounds, q=1)

    Notes
    -----
    For psychophysics:
    - `best_f` is typically the highest accuracy observed
    - Use `maximize=True` to find points that maximize accuracy
    - EI naturally balances exploration (high variance) and
      exploitation (high mean)

    Mathematical Details
    --------------------
    Let μ(x), σ(x) be the posterior mean and std at x.
    Let u = (μ(x) - best_f) / σ(x) (standardized improvement).

    Then:
        EI(x) = σ(x) * [u * \Phi(u) + \varphi(u)]

    where \Phi is the standard normal CDF and \varphi is the PDF.

    When σ(x) = 0 (no uncertainty), EI(x) = max(0, μ(x) - best_f).
    """
    mean = posterior.mean
    std = jnp.sqrt(posterior.variance)

    if not maximize:
        # For minimization, flip the improvement
        u = (best_f - mean) / (std + 1e-9)  # Numerical stability
    else:
        u = (mean - best_f) / (std + 1e-9)

    # EI formula: σ * [u * \Phi(u) + \varphi(u)]
    normal_cdf = stats.norm.cdf(u)
    normal_pdf = stats.norm.pdf(u)

    ei = std * (u * normal_cdf + normal_pdf)

    # Handle numerical issues: EI should be non-negative
    ei = jnp.maximum(ei, 0.0)

    return ei

log_expected_improvement

log_expected_improvement(
    posterior: PredictivePosterior,
    best_f: float,
    maximize: bool = True,
) -> ndarray

Log expected improvement for numerical stability.

Useful when EI values span many orders of magnitude.

Parameters:

posterior : PredictivePosterior
    Predictive posterior (required)
best_f : float
    Best observed value (required)
maximize : bool, default=True
    If True, maximize. If False, minimize.

Returns:

ndarray, shape (n_candidates,)
    log(EI) values

Examples:

>>> # When EI values are very small, log(EI) is more stable
>>> log_ei = log_expected_improvement(posterior, best_f)
>>> X_next = X_candidates[jnp.argmax(log_ei)]
Notes

Since we only care about ranking, log(EI) preserves the order: argmax EI(x) = argmax log(EI(x))

This is numerically more stable when EI values are near zero.
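
A tiny numeric illustration of that rank preservation, using the same 1e-25 offset as the implementation below (the EI values are made up):

    import jax.numpy as jnp

    ei = jnp.array([1e-12, 5e-10, 2e-11])
    log_ei = jnp.log(ei + 1e-25)
    print(int(jnp.argmax(ei)), int(jnp.argmax(log_ei)))  # 1 1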

Source code in src/psyphy/acquisition/expected_improvement.py
def log_expected_improvement(
    posterior: PredictivePosterior,
    best_f: float,
    maximize: bool = True,
) -> jnp.ndarray:
    """
    Log expected improvement for numerical stability.

    Useful when EI values span many orders of magnitude.

    Parameters
    ----------
    posterior : PredictivePosterior
        Predictive posterior
    best_f : float
        Best observed value
    maximize : bool, default=True
        If True, maximize. If False, minimize.

    Returns
    -------
    jnp.ndarray, shape (n_candidates,)
        log(EI) values

    Examples
    --------
    >>> # When EI values are very small, log(EI) is more stable
    >>> log_ei = log_expected_improvement(posterior, best_f)
    >>> X_next = X_candidates[jnp.argmax(log_ei)]

    Notes
    -----
    Since we only care about ranking, log(EI) preserves the order:
        argmax EI(x) = argmax log(EI(x))

    This is numerically more stable when EI values are near zero.
    """
    ei = expected_improvement(posterior, best_f, maximize=maximize)

    # Add small constant for log stability
    log_ei = jnp.log(ei + 1e-25)

    return log_ei

Mutual Information


mutual_information

mutual_information.py

Mutual information (information gain) acquisition function.

Also known as BALD (Bayesian Active Learning by Disagreement). Selects points that maximize information gain about model parameters.

References

Houlsby, N., Huszár, F., Ghahramani, Z., & Lengyel, M. (2011). Bayesian active learning for classification and preference learning. arXiv preprint arXiv:1112.5745.

Functions:

Name Description
binary_entropy

Binary entropy H(p) = -p log(p) - (1-p) log(1-p).

mutual_information

Mutual information between parameters and observations.

binary_entropy

binary_entropy(p: ndarray) -> ndarray

Binary entropy H(p) = -p log(p) - (1-p) log(1-p).

Parameters:

p : ndarray
    Probabilities in [0, 1] (required)

Returns:

ndarray
    Entropy values

Source code in src/psyphy/acquisition/mutual_information.py
def binary_entropy(p: jnp.ndarray) -> jnp.ndarray:
    """
    Binary entropy H(p) = -p log(p) - (1-p) log(1-p).

    Parameters
    ----------
    p : jnp.ndarray
        Probabilities in [0, 1]

    Returns
    -------
    jnp.ndarray
        Entropy values
    """
    # Clip for numerical stability
    p = jnp.clip(p, 1e-10, 1 - 1e-10)

    entropy = -p * jnp.log(p) - (1 - p) * jnp.log(1 - p)

    return entropy
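
A minimal usage sketch (entropy in nats, so the maximum is ln 2 ≈ 0.693 at p = 0.5; the import path follows this page's source location):

    import jax.numpy as jnp
    from psyphy.acquisition.mutual_information import binary_entropy

    p = jnp.array([0.01, 0.5, 0.99])
    print(binary_entropy(p))  # ~[0.056, 0.693, 0.056]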

mutual_information

mutual_information(
    param_posterior: ParameterPosterior,
    X: ndarray,
    probes: ndarray | None = None,
    n_samples: int = 100,
    key: Any = None,
) -> ndarray

Mutual information between parameters and observations.

Computes I(θ; y | X, data) = H[p(y | X, data)] - E_θ[H[p(y | θ, X)]]

This measures how much we expect to learn about parameters θ from observing response y at location X.

Parameters:

param_posterior : ParameterPosterior
    Posterior over model parameters p(θ | data) (required)
X : ndarray, shape (n_candidates, input_dim)
    Candidate reference stimuli (required)
probes : ndarray, shape (n_candidates, input_dim) | None, default=None
    Candidate probe stimuli. Required for discrimination tasks.
n_samples : int, default=100
    Number of posterior samples for MC approximation
key : jax.random.KeyArray | None, default=None
    PRNG key for sampling

Returns:

ndarray, shape (n_candidates,)
    Mutual information scores (higher = more informative)

Examples:

>>> # Basic usage
>>> param_post = model.posterior(kind="parameter")
>>> X_candidates = jnp.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
>>> probes = X_candidates + 0.1
>>>
>>> mi = mutual_information(param_post, X_candidates, probes, n_samples=200)
>>> X_next = X_candidates[jnp.argmax(mi)]

>>> # With optimization (discrete)
>>> def acq_fn(X):
...     return mutual_information(param_post, X, probes=None, n_samples=100)
>>> X_next, mi_val = optimize_acqf_discrete(acq_fn, candidates, q=1)
Notes

For psychophysics, mutual information is ideal for:
  • Threshold estimation: Find stimuli that maximally reduce uncertainty about perceptual thresholds
  • Model selection: Distinguish between competing perceptual models
  • Efficient design: Minimize trials needed for desired precision

Computational Cost

Requires n_samples posterior samples and forward passes through the model. For large candidate sets, use optimize_acqf_discrete() for efficiency.

Mathematical Details

Let θ ~ p(θ | data_observed) be the current parameter posterior. Let y be the hypothetical response at candidate X.

Mutual information:
    I(θ; y | X) = H[p(y | X)] - E_θ[H[p(y | θ, X)]]

where:
  • H[p(y | X)] is the predictive entropy (uncertainty before observing y)
  • E_θ[H[p(y | θ, X)]] is the expected conditional entropy (average uncertainty given a parameter sample)

This is approximated via MC:
    p(y | X) ≈ (1/N) Σ_i p(y | θ_i, X),  where θ_i ~ p(θ | data)

BALD Interpretation

BALD (Bayesian Active Learning by Disagreement) selects points where different parameter samples θ_i disagree most about the prediction. High disagreement → high information gain.
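
A self-contained sketch of this MC estimator, with random Bernoulli predictions standing in for the model's per-sample probabilities (everything here is illustrative):

    import jax
    import jax.numpy as jnp

    def binary_entropy(p):
        p = jnp.clip(p, 1e-10, 1 - 1e-10)
        return -p * jnp.log(p) - (1 - p) * jnp.log1p(-p)

    # toy stand-in: p(y=1 | θ_i, x) for 100 posterior samples, 3 candidates
    probs = jax.random.uniform(jax.random.PRNGKey(0), (100, 3))

    pred_entropy = binary_entropy(probs.mean(axis=0))   # H[p(y | X)]
    expected_cond = binary_entropy(probs).mean(axis=0)  # E_θ[H[p(y | θ, X)]]
    mi = jnp.maximum(pred_entropy - expected_cond, 0.0)
    print(mi)  # higher = more disagreement across parameter samples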

Source code in src/psyphy/acquisition/mutual_information.py
def mutual_information(
    param_posterior: ParameterPosterior,
    X: jnp.ndarray,
    probes: jnp.ndarray | None = None,
    n_samples: int = 100,
    key: Any = None,
) -> jnp.ndarray:
    """
    Mutual information between parameters and observations.

    Computes I(θ; y | X, data) = H[p(y | X, data)] - E_θ[H[p(y | θ, X)]]

    This measures how much we expect to learn about parameters θ
    from observing response y at location X.

    Parameters
    ----------
    param_posterior : ParameterPosterior
        Posterior over model parameters p(θ | data)
    X : jnp.ndarray, shape (n_candidates, input_dim)
        Candidate reference stimuli
    probes : jnp.ndarray, shape (n_candidates, input_dim) | None
        Candidate probe stimuli. Required for discrimination tasks.
    n_samples : int, default=100
        Number of posterior samples for MC approximation
    key : jax.random.KeyArray | None
        PRNG key for sampling

    Returns
    -------
    jnp.ndarray, shape (n_candidates,)
        Mutual information scores (higher = more informative)

    Examples
    --------
    >>> # Basic usage
    >>> param_post = model.posterior(kind="parameter")
    >>> X_candidates = jnp.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    >>> probes = X_candidates + 0.1
    >>>
    >>> mi = mutual_information(param_post, X_candidates, probes, n_samples=200)
    >>> X_next = X_candidates[jnp.argmax(mi)]

    >>> # With optimization (discrete)
    >>> def acq_fn(X):
    ...     return mutual_information(param_post, X, probes=None, n_samples=100)
    >>> X_next, mi_val = optimize_acqf_discrete(acq_fn, candidates, q=1)

    Notes
    -----
    For psychophysics, mutual information is ideal for:
    - **Threshold estimation**: Find stimuli that maximally reduce
      uncertainty about perceptual thresholds
    - **Model selection**: Distinguish between competing perceptual models
    - **Efficient design**: Minimize trials needed for desired precision

    Computational Cost
    ------------------
    Requires n_samples posterior samples and forward passes through the model.
    For large candidate sets, use optimize_acqf_discrete() for efficiency.

    Mathematical Details
    --------------------
    Let θ ~ p(θ | data_observed) be the current parameter posterior.
    Let y be the hypothetical response at candidate X.

    Mutual information:
        I(θ; y | X) = H[p(y | X)] - E_θ[H[p(y | θ, X)]]

    where:
    - H[p(y | X)] is the predictive entropy (uncertainty before observing y)
    - E_θ[H[p(y | θ, X)]] is the expected conditional entropy (average
      uncertainty given a parameter sample)

    This is approximated via MC:
        p(y | X) ≈ (1/N) Σ_i p(y | θ_i, X)  where θ_i ~ p(θ | data)

    BALD Interpretation
    -------------------
    BALD (Bayesian Active Learning by Disagreement) selects points where
    different parameter samples θ_i **disagree** most about the prediction.
    High disagreement → high information gain.
    """
    if key is None:
        key = jr.PRNGKey(0)

    # Sample from parameter posterior
    # TODO: Currently not using samples - need model.predict_prob_from_params()
    _param_samples = param_posterior.sample(n_samples, key=key)

    # Get model from posterior (assuming WPPM structure)
    # TODO: Need model.predict_prob_from_params() method
    # model = param_posterior.model  # Not yet implemented

    # Compute predictive probabilities for each parameter sample
    # Shape: (n_samples, n_candidates)
    prob_correct = []

    for _ in range(n_samples):
        # Extract parameters for this sample
        # TODO: Use params_i to compute p(correct | θ_i, X)
        # params_i = {k: v[i] for k, v in param_samples.items()}

        # Compute p(correct | θ_i, X)
        if probes is not None:
            # Discrimination task: compute prob of correct response
            # This requires model-specific logic
            # For WPPM: prob_correct = model.predict_prob(params_i, X, probes)

            # Placeholder: use predictive posterior mean as proxy
            # TODO: Implement model.predict_prob_from_params(params, X, probes)
            prob_i = jnp.ones(X.shape[0]) * 0.5  # Stub
        else:
            # Threshold task
            prob_i = jnp.ones(X.shape[0]) * 0.5  # Stub

        prob_correct.append(prob_i)

    # Convert list to array
    prob_correct_array = jnp.array(prob_correct)  # Shape: (n_samples, n_candidates)

    # Compute predictive entropy: H[p(y | X)]
    # p(y=1 | X) ≈ mean over samples
    pred_prob_mean = jnp.mean(prob_correct_array, axis=0)  # Shape: (n_candidates,)
    pred_entropy = binary_entropy(pred_prob_mean)

    # Compute expected conditional entropy: E_θ[H[p(y | θ, X)]]
    conditional_entropies = binary_entropy(
        prob_correct_array
    )  # Shape: (n_samples, n_candidates)
    expected_cond_entropy = jnp.mean(conditional_entropies, axis=0)

    # Mutual information = H[p(y|X)] - E_θ[H[p(y|θ,X)]]
    mi = pred_entropy - expected_cond_entropy

    # Should be non-negative (numerical stability)
    mi = jnp.maximum(mi, 0.0)

    return mi

Upper Confidence Bound


upper_confidence_bound

upper_confidence_bound.py

Upper Confidence Bound (UCB) acquisition function.

Balances exploration and exploitation via a tunable parameter β.

References

Srinivas, N., Krause, A., Kakade, S. M., & Seeger, M. (2009). Gaussian process optimization in the bandit setting: No regret and experimental design. arXiv preprint arXiv:0912.3995.

Functions:

Name Description
lower_confidence_bound

Lower confidence bound (LCB) for minimization.

upper_confidence_bound

Upper confidence bound acquisition function.

lower_confidence_bound

lower_confidence_bound(
    posterior: PredictivePosterior, beta: float = 2.0
) -> ndarray

Lower confidence bound (LCB) for minimization.

Alias for upper_confidence_bound(..., maximize=False).

Parameters:

posterior : PredictivePosterior
    Predictive posterior (required)
beta : float, default=2.0
    Exploration parameter

Returns:

ndarray
    LCB values

Source code in src/psyphy/acquisition/upper_confidence_bound.py
def lower_confidence_bound(
    posterior: PredictivePosterior,
    beta: float = 2.0,
) -> jnp.ndarray:
    """
    Lower confidence bound (LCB) for minimization.

    Alias for upper_confidence_bound(..., maximize=False).

    Parameters
    ----------
    posterior : PredictivePosterior
        Predictive posterior
    beta : float, default=2.0
        Exploration parameter

    Returns
    -------
    jnp.ndarray
        LCB values
    """
    return upper_confidence_bound(posterior, beta=beta, maximize=False)

upper_confidence_bound

upper_confidence_bound(
    posterior: PredictivePosterior,
    beta: float = 2.0,
    maximize: bool = True,
) -> ndarray

Upper confidence bound acquisition function.

Computes μ(x) + β * σ(x) for maximization. Computes μ(x) - β * σ(x) for minimization.

Parameters:

posterior : PredictivePosterior
    Predictive posterior p(f(X*) | data) (required)
beta : float, default=2.0
    Exploration-exploitation trade-off parameter.
      • β = 0: Pure exploitation (greedy selection)
      • β = 1: Balanced
      • β > 2: Aggressive exploration
maximize : bool, default=True
    If True, maximize (higher is better). If False, minimize (lower is better).

Returns:

ndarray, shape (n_candidates,)
    UCB values for each candidate

Examples:

>>> # Balanced exploration-exploitation
>>> posterior = model.posterior(X_candidates, probes=probes)
>>> ucb = upper_confidence_bound(posterior, beta=2.0)
>>> X_next = X_candidates[jnp.argmax(ucb)]

>>> # Aggressive exploration (high β)
>>> ucb = upper_confidence_bound(posterior, beta=5.0)

>>> # Pure exploitation (β = 0)
>>> ucb = upper_confidence_bound(posterior, beta=0.0)  # Just selects max mean

>>> # With optimization
>>> def acq_fn(X):
...     return upper_confidence_bound(model.posterior(X), beta=2.0, maximize=True)
>>> X_next, ucb_val = optimize_acqf(acq_fn, bounds, q=1)
Notes

For psychophysics:
  • Use β ∈ [1, 3] for typical experiments
  • Larger β explores uncertain regions more
  • Smaller β focuses on high-accuracy regions

UCB is often faster to compute than EI (no need for CDF/PDF), but theoretically less well-motivated for finite-sample regret.

Adaptive β

Theoretically, β should grow with the number of trials:

    β_t = sqrt(2 * log(input_dim * t² * π² / (6δ)))

where t is the trial number and δ is a confidence parameter. In practice, a fixed β ∈ [1, 3] works well.
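
A sketch of that schedule (ucb_beta is a hypothetical helper; the constants follow the formula above, and in practice a fixed β is usually fine):

    import jax.numpy as jnp

    def ucb_beta(t: int, input_dim: int, delta: float = 0.1) -> jnp.ndarray:
        # GP-UCB-style schedule: β grows only logarithmically with trial t
        return jnp.sqrt(2.0 * jnp.log(input_dim * t**2 * jnp.pi**2 / (6.0 * delta)))

    print([round(float(ucb_beta(t, input_dim=2)), 2) for t in (1, 10, 100)])
    # ≈ [2.64, 4.02, 5.04]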

Source code in src/psyphy/acquisition/upper_confidence_bound.py
def upper_confidence_bound(
    posterior: PredictivePosterior,
    beta: float = 2.0,
    maximize: bool = True,
) -> jnp.ndarray:
    r"""
    Upper confidence bound acquisition function.

    Computes μ(x) + β * \sigma(x) for maximization.
    Computes μ(x) - β * \sigma(x) for minimization.

    Parameters
    ----------
    posterior : PredictivePosterior
        Predictive posterior p(f(X*) | data)
    beta : float, default=2.0
        Exploration-exploitation trade-off parameter.
        - β = 0: Pure exploitation (greedy selection)
        - β = 1: Balanced
        - β > 2: Aggressive exploration
    maximize : bool, default=True
        If True, maximize (higher is better).
        If False, minimize (lower is better).

    Returns
    -------
    jnp.ndarray, shape (n_candidates,)
        UCB values for each candidate

    Examples
    --------
    >>> # Balanced exploration-exploitation
    >>> posterior = model.posterior(X_candidates, probes=probes)
    >>> ucb = upper_confidence_bound(posterior, beta=2.0)
    >>> X_next = X_candidates[jnp.argmax(ucb)]

    >>> # Aggressive exploration (high β)
    >>> ucb = upper_confidence_bound(posterior, beta=5.0)

    >>> # Pure exploitation (β = 0)
    >>> ucb = upper_confidence_bound(posterior, beta=0.0)  # Just selects max mean

    >>> # With optimization
    >>> def acq_fn(X):
    ...     return upper_confidence_bound(model.posterior(X), beta=2.0, maximize=True)
    >>> X_next, ucb_val = optimize_acqf(acq_fn, bounds, q=1)

    Notes
    -----
    For psychophysics:
    - Use β ∈ [1, 3] for typical experiments
    - Larger β explores uncertain regions more
    - Smaller β focuses on high-accuracy regions

    UCB is often faster to compute than EI (no need for CDF/PDF),
    but theoretically less well-motivated for finite-sample regret.

    Adaptive β
    ----------
    Theoretically, β should grow with the number of trials:
        β_t = sqrt(2 * log(input_dim * t^2 * \pi^2 / (6 \delta)))

    where t is the trial number and \delta is a confidence parameter.
    In practice, a fixed β ∈ [1, 3] works well.
    """
    mean = posterior.mean
    std = jnp.sqrt(posterior.variance)

    ucb = mean + beta * std if maximize else mean - beta * std

    return ucb

Optimization


optimize

optimize.py

Optimization utilities for acquisition functions.

Provides a functional interface for maximizing acquisition functions:
  • optimize_acqf_discrete: Exhaustive search over candidate set
  • optimize_acqf: Gradient-based optimization (continuous)
  • optimize_acqf_random: Random search baseline

Design

Following BoTorch's API:

    X_next, acq_value = optimize_acqf(acq_fn, bounds, q=1)

But adapted for psyphy:
  • Support for both continuous and discrete optimization
  • JAX-based gradient descent (Optax)
  • Batch acquisition (q > 1) support
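
An end-to-end sketch of that API shape with a stand-in acquisition (the quadratic acq_fn is illustrative, not a psyphy model; the import path follows this page's source location):

    import jax.numpy as jnp
    from psyphy.acquisition.optimize import optimize_acqf

    def acq_fn(X):
        # smooth toy acquisition peaked at (0.25, 0.75), so gradients exist
        return -jnp.sum((X - jnp.array([0.25, 0.75])) ** 2, axis=-1)

    bounds = jnp.array([[0.0, 1.0], [0.0, 1.0]])
    X_next, acq_value = optimize_acqf(acq_fn, bounds, q=1)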

Functions:

Name Description
optimize_acqf

Optimize acquisition function over continuous domain.

optimize_acqf_discrete

Optimize acquisition function over discrete candidate set.

optimize_acqf_random

Optimize acquisition function via random search.

optimize_acqf

optimize_acqf(
    acq_fn: Callable[[ndarray], ndarray],
    bounds: ndarray,
    q: int = 1,
    *,
    method: Literal["gradient", "random"] = "gradient",
    num_restarts: int = 10,
    raw_samples: int = 100,
    optim_steps: int = 100,
    lr: float = 0.01,
    key: Any = None,
) -> tuple[ndarray, ndarray]

Optimize acquisition function over continuous domain.

Uses multi-start gradient descent to search for the global optimum.

Parameters:

acq_fn : callable
    Acquisition function. Takes (n_points, input_dim) array, returns (n_points,) scores. (required)
bounds : ndarray, shape (input_dim, 2)
    Box constraints [[x1_min, x1_max], [x2_min, x2_max], ...] (required)
q : int, default=1
    Batch size (number of points to select)
method : {"gradient", "random"}, default="gradient"
    Optimization method
num_restarts : int, default=10
    Number of random restarts for gradient descent
raw_samples : int, default=100
    Number of random samples to initialize restarts
optim_steps : int, default=100
    Number of optimization steps per restart
lr : float, default=0.01
    Learning rate for gradient descent
key : jax.random.KeyArray | None, default=None
    PRNG key for random initialization

Returns:

X_next : ndarray, shape (q, input_dim)
    Optimized points
acq_values : ndarray, shape (q,)
    Acquisition values at X_next

Examples:

>>> # Simple continuous optimization
>>> def acq_fn(X):
...     return expected_improvement(model.posterior(X), best_f=0.5)
>>>
>>> bounds = jnp.array([[0.0, 1.0], [0.0, 1.0]])  # 2D unit square
>>> X_next, acq_val = optimize_acqf(acq_fn, bounds, q=1, method="gradient")

>>> # Batch acquisition with multiple restarts
>>> X_batch, acq_vals = optimize_acqf(
...     acq_fn,
...     bounds,
...     q=3,
...     num_restarts=20,
...     optim_steps=200,
... )
Notes

For gradient-based optimization, ensure your acquisition function is differentiable through JAX. Use jax.grad() or jax.value_and_grad().

For non-differentiable acquisition functions, use method="random" or optimize_acqf_discrete() with a candidate grid.
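
A quick way to check differentiability before choosing method="gradient" (a sketch; the smooth acq_fn is a stand-in):

    import jax
    import jax.numpy as jnp

    def acq_fn(X):
        # smooth toy acquisition: negative squared distance from 0.5
        return -jnp.sum((X - 0.5) ** 2, axis=-1)

    X0 = jnp.zeros((1, 2))
    grads = jax.grad(lambda X: jnp.sum(acq_fn(X)))(X0)
    print(grads)  # finite values -> safe for gradient-based optimization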

Source code in src/psyphy/acquisition/optimize.py
def optimize_acqf(
    acq_fn: Callable[[jnp.ndarray], jnp.ndarray],
    bounds: jnp.ndarray,
    q: int = 1,
    *,
    method: Literal["gradient", "random"] = "gradient",
    num_restarts: int = 10,
    raw_samples: int = 100,
    optim_steps: int = 100,
    lr: float = 0.01,
    key: Any = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Optimize acquisition function over continuous domain.

    Uses multi-start gradient descent to find global optimum.

    Parameters
    ----------
    acq_fn : callable
        Acquisition function. Takes (n_points, input_dim) array,
        returns (n_points,) scores.
    bounds : jnp.ndarray, shape (input_dim, 2)
        Box constraints [[x1_min, x1_max], [x2_min, x2_max], ...]
    q : int, default=1
        Batch size (number of points to select)
    method : {"gradient", "random"}, default="gradient"
        Optimization method
    num_restarts : int, default=10
        Number of random restarts for gradient descent
    raw_samples : int, default=100
        Number of random samples to initialize restarts
    optim_steps : int, default=100
        Number of optimization steps per restart
    lr : float, default=0.01
        Learning rate for gradient descent
    key : jax.random.KeyArray | None
        PRNG key for random initialization

    Returns
    -------
    X_next : jnp.ndarray, shape (q, input_dim)
        Optimized points
    acq_values : jnp.ndarray, shape (q,)
        Acquisition values at X_next

    Examples
    --------
    >>> # Simple continuous optimization
    >>> def acq_fn(X):
    ...     return expected_improvement(model.posterior(X), best_f=0.5)
    >>>
    >>> bounds = jnp.array([[0.0, 1.0], [0.0, 1.0]])  # 2D unit square
    >>> X_next, acq_val = optimize_acqf(acq_fn, bounds, q=1, method="gradient")

    >>> # Batch acquisition with multiple restarts
    >>> X_batch, acq_vals = optimize_acqf(
    ...     acq_fn,
    ...     bounds,
    ...     q=3,
    ...     num_restarts=20,
    ...     optim_steps=200,
    ... )

    Notes
    -----
    For gradient-based optimization, ensure your acquisition function
    is differentiable through JAX. Use jax.grad() or jax.value_and_grad().

    For non-differentiable acquisition functions, use method="random"
    or optimize_acqf_discrete() with a candidate grid.
    """
    if key is None:
        key = jr.PRNGKey(0)

    if method == "random":
        return optimize_acqf_random(
            acq_fn, bounds, q=q, num_samples=raw_samples, key=key
        )
    elif method == "gradient":
        return _optimize_acqf_gradient(
            acq_fn,
            bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            optim_steps=optim_steps,
            lr=lr,
            key=key,
        )
    else:
        raise ValueError(f"Unknown method: {method}. Use 'gradient' or 'random'.")

optimize_acqf_discrete

optimize_acqf_discrete(
    acq_fn: Callable[[ndarray], ndarray],
    candidates: ndarray,
    q: int = 1,
) -> tuple[ndarray, ndarray]

Optimize acquisition function over discrete candidate set.

This is the simplest and most common approach for psychophysics, where stimulus space is often discretized.

Parameters:

acq_fn : callable
    Acquisition function. Takes (n_candidates, input_dim) array, returns (n_candidates,) scores. (required)
candidates : ndarray, shape (n_candidates, input_dim)
    Discrete candidate points to evaluate (required)
q : int, default=1
    Batch size (number of points to select)

Returns:

X_next : ndarray, shape (q, input_dim)
    Selected candidate points
acq_values : ndarray, shape (q,)
    Acquisition values of selected points

Examples:

>>> # Simple discrete optimization
>>> candidates = jnp.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
>>> def acq_fn(X):
...     return expected_improvement(model.posterior(X), best_f=0.5)
>>>
>>> X_next, acq_val = optimize_acqf_discrete(acq_fn, candidates, q=1)

>>> # Batch acquisition
>>> X_batch, acq_vals = optimize_acqf_discrete(acq_fn, candidates, q=3)
Source code in src/psyphy/acquisition/optimize.py
def optimize_acqf_discrete(
    acq_fn: Callable[[jnp.ndarray], jnp.ndarray],
    candidates: jnp.ndarray,
    q: int = 1,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Optimize acquisition function over discrete candidate set.

    This is the simplest and most common approach for psychophysics,
    where stimulus space is often discretized.

    Parameters
    ----------
    acq_fn : callable
        Acquisition function. Takes (n_candidates, input_dim) array,
        returns (n_candidates,) scores.
    candidates : jnp.ndarray, shape (n_candidates, input_dim)
        Discrete candidate points to evaluate
    q : int, default=1
        Batch size (number of points to select)

    Returns
    -------
    X_next : jnp.ndarray, shape (q, input_dim)
        Selected candidate points
    acq_values : jnp.ndarray, shape (q,)
        Acquisition values of selected points

    Examples
    --------
    >>> # Simple discrete optimization
    >>> candidates = jnp.array([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    >>> def acq_fn(X):
    ...     return expected_improvement(model.posterior(X), best_f=0.5)
    >>>
    >>> X_next, acq_val = optimize_acqf_discrete(acq_fn, candidates, q=1)

    >>> # Batch acquisition
    >>> X_batch, acq_vals = optimize_acqf_discrete(acq_fn, candidates, q=3)
    """
    # Evaluate all candidates
    acq_values = acq_fn(candidates)

    # Select top-q by acquisition value
    top_indices = jnp.argsort(acq_values)[-q:][::-1]  # Descending order
    X_next = candidates[top_indices]
    selected_values = acq_values[top_indices]

    return X_next, selected_values

optimize_acqf_random

optimize_acqf_random(
    acq_fn: Callable[[ndarray], ndarray],
    bounds: ndarray,
    q: int = 1,
    *,
    num_samples: int = 1000,
    key: Any = None,
) -> tuple[ndarray, ndarray]

Optimize acquisition function via random search.

Simple baseline: sample random points, evaluate, select best.

Parameters:

acq_fn : callable
    Acquisition function (required)
bounds : ndarray, shape (input_dim, 2)
    Box constraints (required)
q : int, default=1
    Batch size
num_samples : int, default=1000
    Number of random samples to evaluate
key : jax.random.KeyArray | None, default=None
    PRNG key

Returns:

X_next : ndarray, shape (q, input_dim)
    Best random samples
acq_values : ndarray, shape (q,)
    Acquisition values

Examples:

>>> X_next, acq_val = optimize_acqf_random(acq_fn, bounds, q=1, num_samples=5000)
Source code in src/psyphy/acquisition/optimize.py
def optimize_acqf_random(
    acq_fn: Callable[[jnp.ndarray], jnp.ndarray],
    bounds: jnp.ndarray,
    q: int = 1,
    *,
    num_samples: int = 1000,
    key: Any = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Optimize acquisition function via random search.

    Simple baseline: sample random points, evaluate, select best.

    Parameters
    ----------
    acq_fn : callable
        Acquisition function
    bounds : jnp.ndarray, shape (input_dim, 2)
        Box constraints
    q : int, default=1
        Batch size
    num_samples : int, default=1000
        Number of random samples to evaluate
    key : jax.random.KeyArray | None
        PRNG key

    Returns
    -------
    X_next : jnp.ndarray, shape (q, input_dim)
        Best random samples
    acq_values : jnp.ndarray, shape (q,)
        Acquisition values

    Examples
    --------
    >>> X_next, acq_val = optimize_acqf_random(acq_fn, bounds, q=1, num_samples=5000)
    """
    if key is None:
        key = jr.PRNGKey(0)

    # Generate random samples in bounds
    input_dim = bounds.shape[0]
    key, subkey = jr.split(key)
    random_samples = jr.uniform(subkey, (num_samples, input_dim))

    # Scale to bounds
    lower = bounds[:, 0]
    upper = bounds[:, 1]
    random_samples = lower + random_samples * (upper - lower)

    # Evaluate and select best
    return optimize_acqf_discrete(acq_fn, random_samples, q=q)