Source code for jax_moseq.models.slds.gibbs

import jax
import jax.numpy as jnp
import jax.random as jr
from functools import partial
from jax_moseq.utils import mixed_map, apply_affine
from jax_moseq.models import arhmm
from jax_moseq.utils.kalman import (
    kalman_sample,
    ar_to_lds_dynamics,
    ar_to_lds_emissions,
)


# Short alias for `jnp.newaxis` (i.e. `None`). In this module it is used both
# to append an axis when indexing (`mask[..., na]`) and as an entry of the
# `in_axes` tuples handed to `jax.vmap` / `mixed_map`.
na = jnp.newaxis


def resample_continuous_stateseqs(
    seed,
    y,
    mask,
    z,
    s,
    Ab,
    Q,
    Cd,
    sigmasq,
    jitter=1e-3,
    parallel_message_passing=True,
    **kwargs
):
    """Resample the latent trajectories `x`.

    The `n_lags`-order autoregressive model is rewritten as a first-order
    linear dynamical system, one Kalman sample is drawn per recording, and
    the sampled states are unpacked back into a trajectory in the original
    latent space.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed. It is split into one key per recording.
    y : jax.Array of shape (n_recordings, n_timesteps, obs_dim)
        Observations.
    mask : jax.Array of shape (n_recordings, n_timesteps)
        Binary indicator, 1=valid frames, 0=invalid frames.
    z : jax.Array of shape (n_recordings, n_timesteps-n_lags)
        Discrete state sequences, taking integer values between
        [0, n_states), for timesteps [n_lags, n_timesteps).
    s : jax.Array of shape (n_recordings, n_timesteps, obs_dim)
        Observation noise scales.
    Ab : jax.Array of shape (n_states, latent_dim, ar_dim + 1)
        Autoregressive dynamics and bias, where
        `ar_dim = latent_dim * n_lags`.
    Q : jax.Array of shape (n_states, latent_dim, latent_dim)
        Autoregressive noise covariances.
    Cd : jax.Array of shape (obs_dim, latent_dim + 1)
        Affine transform from `latent_dim` to `obs_dim`.
    sigmasq : jax.Array of shape (obs_dim,)
        Unscaled noise.
    jitter : float, default=1e-3
        Amount to boost the diagonal of the covariance matrix
        during backward-sampling of the continuous states.
    parallel_message_passing : bool, default=True
        Use associative scan for Kalman sampling, which is faster on
        a GPU but has a significantly longer jit time.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    x : jax.Array of shape (n_recordings, n_timesteps, latent_dim)
        Posterior sample of latent trajectories.
    """
    n_recordings, latent_dim, obs_dim = y.shape[0], Ab.shape[1], y.shape[-1]
    # `Ab` has `ar_dim + 1` columns (dynamics plus bias); the floor division
    # discards the bias column.
    # NOTE(review): this only recovers n_lags when latent_dim > 1 -- confirm
    # that latent_dim == 1 is never used.
    n_lags = Ab.shape[2] // latent_dim

    # TODO Parameterize these distributional hyperparameters
    # Mean and covariance of the initial (stacked) latent state.
    m0 = jnp.zeros(latent_dim * n_lags)
    S0 = 10 * jnp.eye(latent_dim * n_lags)  # TODO: hard coded constant 10
    # Noise levels substituted on masked frames (see step 3).
    masked_dynamics_noise = 10
    masked_obs_noise = 10

    # =====================================================================
    # 1. Omit the first n_lags frames of observations and associated sequences
    # =====================================================================
    # NOTE(review): `y_` and `R_` assigned in this step are reassigned by the
    # tuple unpacking in step 2 before they are read; only `mask_` from this
    # step reaches the sampler. Confirm which version is intended.
    y_ = y[:, n_lags - 1 :]
    mask_ = mask[:, n_lags - 1 :]

    # Scale unscaled observations by fitted diagonal scales
    R_ = sigmasq * s[:, n_lags - 1 :]

    # ==========================================================================
    # 2. Reformat n_lags'th-order AR dynamics in R^D to 1st-order dynamics in R^{DL}
    # ==========================================================================
    # The emission helper is vmapped over the leading (recording) axis of the
    # noise `sigmasq * s` and of `y`; the remaining arguments are not mapped.
    # NOTE(review): `m0_` and `S0_` are never read afterwards -- the sampler
    # below receives the original `m0` and `S0`. Confirm this is intended.
    C_, d_, R_, y_, m0_, S0_ = jax.vmap(
        ar_to_lds_emissions, in_axes=(na, 0, 0, na, na, na)
    )(Cd, sigmasq * s, y, m0, S0, n_lags)
    A_, b_, Q_ = ar_to_lds_dynamics(Ab, Q)

    # =============================================
    # 3. Formulate parameters for masked timesteps
    # =============================================
    ar_dim = latent_dim * n_lags

    # If masked, hold the last state, i.e. set dynamics for "unlagged" states to
    # identity matrix and all other state dynamics to 0
    eye_zero_order = jnp.zeros((ar_dim, ar_dim))
    eye_zero_order = eye_zero_order.at[-latent_dim:, -latent_dim:].set(
        jnp.eye(latent_dim)
    )

    masked_dynamics_params = {
        "weights": eye_zero_order,
        "bias": jnp.zeros(ar_dim),
        "cov": jnp.eye(ar_dim) * masked_dynamics_noise,
    }
    masked_obs_noise_diag = jnp.ones(obs_dim) * masked_obs_noise

    # ==================================================
    # 4. Apply vectorized Kalman sample to each recording
    # Shapes of time-varying parameters going into the Kalman sampler are
    #   ys:   (n_timesteps-n_lags+1, obs_dim), corresponding to timesteps [nlags-1, T)
    #   mask: (n_timesteps-n_lags+1,)
    #   zs:   (n_timesteps-n_lags,), corresponding to timesteps [n_lags, T)
    #   Rs:   (n_timesteps-n_lags+1, obs_dim)
    # ==================================================
    x = batched_kalman_sample(
        jr.split(seed, n_recordings),  # one independent key per recording
        y_,
        mask_,
        z,
        m0,
        S0,
        A_,
        b_,
        Q_,
        C_,
        d_,
        R_,
        masked_dynamics_params,
        masked_obs_noise_diag,
        jitter,
        parallel_message_passing,
    )

    # =========================================================================
    # 5. Reformat sampled trajectories back into L'th order AR dynamics in R^D
    # =========================================================================
    if n_lags > 1:
        # The first stacked state holds the n_lags - 1 lagged frames in its
        # leading entries; unpack them and prepend them to the most recent
        # (last `latent_dim`) block of every sampled state.
        x = jnp.concatenate(
            [
                x[:, 0, : (n_lags - 1) * latent_dim].reshape(
                    -1, n_lags - 1, latent_dim
                ),
                x[:, :, -latent_dim:],
            ],
            axis=1,
        )
    return x
@partial(jax.jit, static_argnames=("parallel",))
def batched_kalman_sample(
    seeds,
    y_,
    mask_,
    z,
    m0,
    S0,
    A_,
    b_,
    Q_,
    C_,
    d_,
    R_,
    masked_dynamics_params,
    masked_obs_noise_diag,
    jitter,
    parallel,
):
    """Wrapper that applies `mixed_map` to `kalman_sample` for batched inputs.

    `seeds`, `y_`, `mask_`, `z` and `R_` are mapped over their leading axis;
    every other argument is passed with a `None` axis entry. `jitter` and
    `parallel` are bound to `kalman_sample` as keyword arguments, and
    `parallel` is a static argument of the jitted function.

    Returns
    -------
    x : whatever `mixed_map` assembles from the per-item `kalman_sample`
        outputs (presumably the samples stacked along a leading axis --
        confirm against `mixed_map`).
    """
    sampler = partial(kalman_sample, jitter=jitter, parallel=parallel)

    # Each positional argument paired with its map axis (0 = mapped over the
    # leading axis, `na` = not mapped).
    arg_axis_pairs = (
        (seeds, 0),
        (y_, 0),
        (mask_, 0),
        (z, 0),
        (m0, na),
        (S0, na),
        (A_, na),
        (b_, na),
        (Q_, na),
        (C_, na),
        (d_, na),
        (R_, 0),
        (masked_dynamics_params, na),
        (masked_obs_noise_diag, na),
    )
    call_args, in_axes = zip(*arg_axis_pairs)
    return mixed_map(sampler, in_axes)(*call_args)
@jax.jit
def resample_obs_variance(seed, Y, mask, x, s, Cd, nu_sigma, sigmasq_0, **kwargs):
    """
    Resample the observation variance `sigmasq`.

    Computes the masked squared error between the observations and the
    affine image of the latent trajectories, then delegates to
    `resample_obs_variance_from_sqerr`.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    s : jax array of shape (N, T, obs_dim)
        Noise scales.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform.
    nu_sigma : float
        Chi-squared degrees of freedom in sigmasq.
    sigmasq_0 : float
        Scaled inverse chi-squared scaling parameter for sigmasq.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    """
    return resample_obs_variance_from_sqerr(
        seed,
        compute_squared_error(Y, x, Cd, mask),
        mask,
        s,
        nu_sigma,
        sigmasq_0,
    )
@jax.jit
def resample_obs_variance_from_sqerr(
    seed, sqerr, mask, s, nu_sigma, sigmasq_0, **kwargs
):
    """
    Resample the observation variance `sigmasq` using the squared
    error between predicted and true observations.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    sqerr : jax array of shape (N, T, obs_dim)
        Squared error between predicted and true observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    s : jax array of shape (N, T, obs_dim)
        Noise scales.
    nu_sigma : float
        Chi-squared degrees of freedom in sigmasq.
    sigmasq_0 : float
        Scaled inverse chi-squared scaling parameter for sigmasq.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    """
    n_features = sqerr.shape[-1]

    # Scale-normalised squared error, flattened over every leading axis and
    # summed per observation dimension: (..., k) -> (k,)
    normalised = sqerr / s
    summed_sqerr = normalised.reshape(-1, n_features).sum(0)

    # NOTE(review): the factor 3 is hard-coded -- presumably the number of
    # coordinates contributing to each entry of `sqerr`; confirm with callers.
    posterior_degs = nu_sigma + 3 * mask.sum()
    posterior_scale = nu_sigma * sigmasq_0 + summed_sqerr
    return _resample_spread(seed, posterior_degs, posterior_scale)
@jax.jit
def resample_scales(seed, Y, x, Cd, sigmasq, nu_s, s_0, **kwargs):
    """
    Resample the scale values `s`.

    Computes the (unmasked) squared error between the observations and
    the affine image of the latent trajectories, then delegates to
    `resample_scales_from_sqerr`.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    s_0 : scalar or jax array broadcastable to `Y`
        Prior on noise scale.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    s : jax array of shape (N, T, obs_dim)
        Noise scales.
    """
    return resample_scales_from_sqerr(
        seed,
        compute_squared_error(Y, x, Cd),
        sigmasq,
        nu_s,
        s_0,
    )
@jax.jit
def resample_scales_from_sqerr(seed, sqerr, sigmasq, nu_s, s_0, **kwargs):
    """
    Resample the scale values `s` using the squared error
    between predicted and true observations.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    sqerr : jax array of shape (N, T, obs_dim)
        Squared error between predicted and true observations.
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    s_0 : scalar or jax array broadcastable to `Y`
        Prior on noise scale.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    s : jax array of shape (N, T, obs_dim)
        Per observation noise scales.
    """
    data_term = sqerr / sigmasq
    prior_term = s_0 * nu_s

    # NOTE(review): the constant 3 is hard-coded -- presumably the number of
    # coordinates contributing to each entry of `sqerr`; confirm with callers.
    posterior_degs = nu_s + 3
    return _resample_spread(seed, posterior_degs, data_term + prior_term)
@jax.jit
def _resample_spread(seed, degs, variance):
    """
    Resample the noise values from the computed
    degrees of freedom and variance.

    Each entry of `variance` is divided by an independent
    Gamma(degs / 2) draw and then halved.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    degs : scalar
        Chi-squared degrees of freedom.
    variance : jax array
        Variance computed from the data.

    Returns
    -------
    spread : jax array, same shape as `variance`
        Resampled noise values.
    """
    # same as sample_scaled_inv_chi2(seed, degs, variance / degs)
    gamma_draws = jr.gamma(seed, degs / 2, shape=variance.shape)
    return (variance / gamma_draws) / 2
@jax.jit
def compute_squared_error(Y, x, Cd, mask=None):
    """
    Computes the squared error between model predicted
    and true observations.

    Parameters
    ----------
    Y : jax array of shape (..., obs_dim)
        Observations.
    x : jax array of shape (..., latent_dim)
        Latent trajectories.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform.
    mask : jax array of shape (...), optional
        Binary indicator for valid frames. When given, the squared
        error is multiplied elementwise by the mask, broadcast over
        the last axis.

    Returns
    -------
    sqerr : jax array of shape (..., obs_dim)
        Squared error between model predicted and true observations.
    """
    residual = Y - apply_affine(x, Cd)
    squared = residual**2
    if mask is None:
        return squared
    # Add a trailing axis so the per-frame mask broadcasts across obs_dim.
    return mask[..., na] * squared
def resample_model(
    data,
    seed,
    states,
    params,
    hypparams,
    ar_only=False,
    states_only=False,
    skip_noise=True,
    parallel_message_passing=False,
    **kwargs
):
    """
    Resamples the SLDS model given the hyperparameters, data,
    current states, and current parameters.

    The ARHMM components are resampled first via
    `arhmm.resample_model`. Unless `ar_only` is set, the observation
    variance `sigmasq` (optional), the continuous latent states `x`, and
    the noise scales `s` (optional) are then resampled in that order,
    each with its own PRNG subkey.

    Parameters
    ----------
    data : dict
        Data dictionary containing the observations and mask.
    seed : jr.PRNGKey
        JAX random seed.
    states : dict
        State values for each latent variable.
    params : dict
        Values for each model parameter.
    hypparams : dict
        Values for each group of hyperparameters.
    ar_only : bool, default=False
        Whether to restrict sampling to ARHMM components.
    states_only : bool, default=False
        Whether to restrict sampling to states.
    skip_noise : bool, default=True
        Whether to exclude `sigmasq` and `s` from resampling.
    parallel_message_passing : bool, default=False
        Use associative scan for Kalman sampling, which is faster on
        a GPU but has a significantly longer jit time.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    model : dict
        Dictionary containing the hyperparameters and
        updated seed, states, and parameters of the model.
    """
    model = arhmm.resample_model(data, seed, states, params, hypparams, states_only)
    if ar_only:
        return model

    # Shallow copies so the dictionaries returned by the ARHMM step are not
    # mutated in place.
    params = model["params"].copy()
    states = model["states"].copy()

    # A JAX PRNG key must not be consumed more than once: feeding the same
    # key to several samplers correlates their draws. Derive a fresh carry
    # key plus one subkey per sampling step. The split is unconditional so
    # the key used for `x` does not depend on which noise steps are enabled.
    seed, sigmasq_seed, x_seed, s_seed = jr.split(model["seed"], 4)

    if not (states_only or skip_noise):
        params["sigmasq"] = resample_obs_variance(
            sigmasq_seed, **data, **states, **params, **hypparams["obs_hypparams"]
        )

    states["x"] = resample_continuous_stateseqs(
        x_seed,
        **data,
        **states,
        **params,
        parallel_message_passing=parallel_message_passing
    )

    if not skip_noise:
        states["s"] = resample_scales(
            s_seed, **data, **states, **params, **hypparams["obs_hypparams"]
        )

    return {
        "seed": seed,
        "states": states,
        "params": params,
        "hypparams": hypparams,
    }