Source code for jax_moseq.models.arhmm.gibbs

import jax
import jax.numpy as jnp
import jax.random as jr

from dynamax.hidden_markov_model.inference import hmm_smoother

from jax_moseq.utils import (
    pad_affine,
    psd_solve,
    psd_inv,
    nan_check,
    mixed_map,
)

from jax_moseq.utils.distributions import sample_mniw, sample_hmm_stateseq
from jax_moseq.utils.autoregression import (
    get_lags,
    get_nlags,
    ar_log_likelihood,
)
from jax_moseq.utils.transitions import resample_hdp_transitions
from functools import partial

na = jnp.newaxis


[docs] @jax.jit def resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs): """ Resamples the discrete state sequence ``z``. Parameters ---------- seed : jr.PRNGKey JAX random seed. x : jax array of shape (N, T, latent_dim) Latent trajectories. mask : jax array of shape (N, T) Binary indicator for valid frames. Ab : jax array of shape (num_states, latent_dim, ar_dim) Autoregressive transforms. Q : jax array of shape (num_states, latent_dim, latent_dim) Autoregressive noise covariances. pi : jax_array of shape (num_states, num_states) Transition probabilities. **kwargs : dict Overflow, for convenience. Returns ------ z : jax_array of shape (N, T - n_lags) Discrete state sequences. """ nlags = get_nlags(Ab) log_likelihoods = jax.lax.map(partial(ar_log_likelihood, x), (Ab, Q)) _, z = jax.vmap(sample_hmm_stateseq, in_axes=(0, na, 0, 0))( jr.split(seed, mask.shape[0]), pi, jnp.moveaxis(log_likelihoods, 0, -1), mask.astype(float)[:, nlags:], ) return z
[docs] @jax.jit def stateseq_marginals(x, mask, Ab, Q, pi, **kwargs): """ Computes the marginal probability of each discrete state at each time step. Parameters ---------- x : jax array of shape (N, T, latent_dim) Latent trajectories. mask : jax array of shape (N, T) Binary indicator for valid frames. Ab : jax array of shape (num_states, latent_dim, ar_dim) Autoregressive transforms. Q : jax array of shape (num_states, latent_dim, latent_dim) Autoregressive noise covariances. pi : jax_array of shape (num_states, num_states) Transition probabilities. **kwargs : dict Overflow, for convenience. Returns ------ z_marginals : jax array of shape (N, T, num_states) Marginal probability of each discrete state at each time step. """ nlags = get_nlags(Ab) num_states = pi.shape[0] initial_distribution = jnp.ones(num_states) / num_states log_likelihoods = jax.lax.map(partial(ar_log_likelihood, x), (Ab, Q)) log_likelihoods = jnp.moveaxis(log_likelihoods, 0, -1) masked_log_likelihoods = log_likelihoods * mask[:, nlags:, na] smoother = lambda lls: hmm_smoother(initial_distribution, pi, lls).smoothed_probs z_marginals = mixed_map(smoother)(masked_log_likelihoods) return z_marginals
[docs] @nan_check @partial(jax.jit, static_argnames=("num_states", "nlags")) def resample_ar_params( seed, *, nlags, num_states, mask, x, z, nu_0, S_0, M_0, K_0, **kwargs ): """ Resamples the AR parameters ``Ab`` and ``Q``. Parameters ---------- seed : jr.PRNGKey JAX random seed. nlags : int Number of autoregressive lags. num_states : int Max number of HMM states. mask : jax array of shape (N, T) Binary indicator for valid frames. x : jax array of shape (N, T, latent_dim) Latent trajectories. z : jax_array of shape (N, T - n_lags) Discrete state sequences. nu_0 : int Inverse-Wishart degrees of freedom parameter for Q. S_0 : jax array of shape (latent_dim, latent_dim) Inverse-Wishart scale parameter for Q. M_0 : jax array of shape (latent_dim, ar_dim) Matrix normal expectation for Ab. K_0 : jax array of shape (ar_dim, ar_dim) Matrix normal column scale parameter for Ab. **kwargs : dict Overflow, for convenience. Returns ------ Ab : jax array of shape (num_states, latent_dim, ar_dim) Autoregressive transforms. Q : jax array of shape (num_states, latent_dim, latent_dim) Autoregressive noise covariances. """ seeds = jr.split(seed, num_states) masks = mask[..., nlags:].reshape(1, -1) * jnp.eye(num_states)[:, z.reshape(-1)] x_in = pad_affine(get_lags(x, nlags)).reshape(-1, nlags * x.shape[-1] + 1) x_out = x[..., nlags:, :].reshape(-1, x.shape[-1]) map_fun = partial(_resample_regression_params, x_in, x_out, nu_0, S_0, M_0, K_0) Ab, Q = jax.lax.map(map_fun, (seeds, masks)) return Ab, Q
@nan_check @jax.jit def _resample_regression_params(x_in, x_out, nu_0, S_0, M_0, K_0, args): """ Resamples regression parameters from a Matrix normal inverse-Wishart distribution. Parameters ---------- x_in : jax array of shape (..., in_dim) Regression input. x_out : jax array of shape (..., out_dim) Regression output. nu_0 : int Inverse-Wishart degrees of freedom parameter for Q. S_0 : jax array of shape (out_dim, out_dim) Inverse-Wishart scale parameter for Q. M_0 : jax array of shape (out_dim, in_dim) Matrix normal expectation for Ab. K_0 : jax array of shape (in_dim, in_dim) Matrix normal column scale parameter for Ab. args: tuple (seed, mask) JAX random seed and binary indicator for frames to use for calculating the sufficient statistics. Returns ------ Ab : jax array of shape (num_states, out_dim, in_dim) Regression transforms. Q : jax array of shape (num_states, out_dim, out_dim) Regression noise covariances. """ seed, mask = args S_out_out = jnp.einsum("ti,tj,t->ij", x_out, x_out, mask) S_out_in = jnp.einsum("ti,tj,t->ij", x_out, x_in, mask) S_in_in = jnp.einsum("ti,tj,t->ij", x_in, x_in, mask) K_0_inv = psd_inv(K_0) K_n_inv = K_0_inv + S_in_in K_n = psd_inv(K_n_inv) M_n = psd_solve(K_n_inv.T, K_0_inv @ M_0.T + S_out_in.T).T S_n = S_0 + S_out_out + (M_0 @ K_0_inv @ M_0.T - M_n @ K_n_inv @ M_n.T) return sample_mniw(seed, nu_0 + mask.sum(), S_n, M_n, K_n)
[docs] def resample_model( data, seed, states, params, hypparams, states_only=False, verbose=False, **kwargs ): """ Resamples the ARHMM model given the hyperparameters, data, current states, and current parameters. Parameters ---------- data : dict Data dictionary containing the observations and mask. seed : jr.PRNGKey JAX random seed. states : dict State values for each latent variable. params : dict Values for each model parameter. hypparams : dict Values for each group of hyperparameters. states_only : bool, default=False Whether to restrict sampling to states. **kwargs : dict Overflow, for convenience. Returns ------ model : dict Dictionary containing the hyperparameters and updated seed, states, and parameters of the model. """ seed = jr.split(seed)[1] params = params.copy() states = states.copy() if not states_only: if verbose: print("Resampling pi (transition matrix)") params["betas"], params["pi"] = resample_hdp_transitions( seed, **data, **states, **params, **hypparams["trans_hypparams"] ) if verbose: print("Resampling Ab,Q (AR parameters)") params["Ab"], params["Q"] = resample_ar_params( seed, **data, **states, **params, **hypparams["ar_hypparams"] ) if verbose: print("Resampling z (discrete latent states)") states["z"] = resample_discrete_stateseqs(seed, **data, **states, **params) return { "seed": seed, "states": states, "params": params, "hypparams": hypparams, }