AR-HMM

Initialization

Functions:

init_ar_params(seed, *, num_states, nu_0, ...)

Initialize the autoregression parameters by sampling from an MNIW distribution.

init_states(seed, x, mask, params, **kwargs)

Initialize the latent states of the ARHMM from the data and parameters.

init_params(seed, trans_hypparams, ...)

Initialize the parameters of the ARHMM from the hyperparameters.

init_hyperparams(trans_hypparams, ...)

Formats the hyperparameter dictionary of the ARHMM.

init_model([data, states, params, ...])

Initialize an ARHMM model dict containing the hyperparameters and initial seed, states, and parameters.

jax_moseq.models.arhmm.initialize.init_ar_params(seed, *, num_states, nu_0, S_0, M_0, K_0, **kwargs)

Initialize the autoregression parameters by sampling from an MNIW distribution. In the shapes below, ar_dim = latent_dim * num_lags + 1 (num_lags lagged frames plus a constant term).
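As an illustration of where ar_dim comes from, the NumPy sketch below builds the regressor for each predictable frame by stacking the num_lags preceding frames and appending a constant 1, so that a transform Ab of shape (latent_dim, ar_dim) maps it to a prediction of the next frame. The helper name lagged_inputs and the oldest-first ordering of the lags are assumptions, not necessarily what jax_moseq does internally.

```python
import numpy as np

def lagged_inputs(x, nlags):
    """Regressors of shape (T - nlags, latent_dim * nlags + 1) for one sequence x of shape (T, latent_dim).

    Row t holds x[t], ..., x[t + nlags - 1] (oldest first, an assumption) followed by 1,
    and is used to predict x[t + nlags].
    """
    T, _ = x.shape
    lags = np.concatenate([x[i:T - nlags + i] for i in range(nlags)], axis=1)
    ones = np.ones((T - nlags, 1))  # constant column that multiplies the offset b in Ab
    return np.concatenate([lags, ones], axis=1)

x = np.arange(12.0).reshape(6, 2)  # T = 6, latent_dim = 2
X = lagged_inputs(x, nlags=3)
print(X.shape)  # (3, 7): T - nlags rows, ar_dim = 2 * 3 + 1 columns
```

With a single state's Ab, the predictions for all frames are then X @ Ab.T.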

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • num_states (int) – Max number of HMM states.

  • nu_0 (int) – Inverse-Wishart degrees of freedom parameter for Q.

  • S_0 (jax array of shape (latent_dim, latent_dim)) – Inverse-Wishart scale parameter for Q.

  • M_0 (jax array of shape (latent_dim, ar_dim)) – Matrix normal expectation for Ab.

  • K_0 (jax array of shape (ar_dim, ar_dim)) – Matrix normal column scale parameter for Ab.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
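The prior draw described above can be sketched in NumPy as follows: for each state, Q is drawn from an inverse-Wishart with nu_0 degrees of freedom and scale S_0, and Ab is then drawn from a matrix normal with mean M_0, row covariance Q and column covariance K_0. This is a minimal sketch, assuming an integer nu_0 >= latent_dim and a NumPy Generator in place of a JAX PRNGKey; sample_mniw and sample_ar_params are hypothetical names.

```python
import numpy as np

def sample_mniw(rng, nu_0, S_0, M_0, K_0):
    """Draw Q ~ IW(nu_0, S_0), then Ab | Q ~ MN(M_0, Q, K_0). Assumes integer nu_0 >= latent_dim."""
    d = S_0.shape[0]
    # W = Z^T Z ~ Wishart(nu_0, inv(S_0)), hence inv(W) ~ InvWishart(nu_0, S_0)
    Z = rng.multivariate_normal(np.zeros(d), np.linalg.inv(S_0), size=nu_0)
    Q = np.linalg.inv(Z.T @ Z)
    # Matrix normal draw: M_0 + L_Q G L_K^T has row covariance Q and column covariance K_0
    G = rng.standard_normal(M_0.shape)
    Ab = M_0 + np.linalg.cholesky(Q) @ G @ np.linalg.cholesky(K_0).T
    return Ab, Q

def sample_ar_params(rng, num_states, nu_0, S_0, M_0, K_0):
    """Independent prior draws for every state, stacked along the first axis."""
    draws = [sample_mniw(rng, nu_0, S_0, M_0, K_0) for _ in range(num_states)]
    return np.stack([a for a, _ in draws]), np.stack([q for _, q in draws])

latent_dim, nlags, num_states = 2, 3, 4
ar_dim = latent_dim * nlags + 1
rng = np.random.default_rng(0)
Ab, Q = sample_ar_params(rng, num_states, nu_0=latent_dim + 2,
                         S_0=np.eye(latent_dim),
                         M_0=np.zeros((latent_dim, ar_dim)),
                         K_0=np.eye(ar_dim))
print(Ab.shape, Q.shape)  # (4, 2, 7) (4, 2, 2)
```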

jax_moseq.models.arhmm.initialize.init_states(seed, x, mask, params, **kwargs)

Initialize the latent states of the ARHMM from the data and parameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • params (dict) – Values for each model parameter.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

states – State values for each latent variable.

Return type:

dict

jax_moseq.models.arhmm.initialize.init_params(seed, trans_hypparams, ar_hypparams, **kwargs)

Initialize the parameters of the ARHMM from the hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

params – Values for each model parameter.

Return type:

dict
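The HDP transition prior is not spelled out here. A common choice is the weak-limit approximation of the sticky HDP-HMM, sketched below in NumPy: global state weights beta ~ Dirichlet(gamma/K, ..., gamma/K), then each row pi_k ~ Dirichlet(alpha * beta + kappa * e_k), where kappa adds extra mass to self-transitions. The names alpha, gamma and kappa are assumptions about what trans_hypparams holds, and this is not presented as the library's implementation.

```python
import numpy as np

def sample_sticky_hdp_transitions(rng, num_states, alpha, gamma, kappa):
    """Weak-limit sticky HDP-HMM prior draw of global weights beta and transition matrix pi."""
    # Global state weights shared by every row of pi
    beta = rng.dirichlet(np.full(num_states, gamma / num_states))
    # Row k gets extra concentration kappa on its own state (self-transition bias)
    pi = np.stack([
        rng.dirichlet(alpha * beta + kappa * np.eye(num_states)[k])
        for k in range(num_states)
    ])
    return beta, pi

rng = np.random.default_rng(0)
beta, pi = sample_sticky_hdp_transitions(rng, num_states=5, alpha=5.0, gamma=5.0, kappa=50.0)
print(pi.shape, np.allclose(pi.sum(axis=1), 1.0))  # (5, 5) True
```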

jax_moseq.models.arhmm.initialize.init_hyperparams(trans_hypparams, ar_hypparams, **kwargs)

Formats the hyperparameter dictionary of the ARHMM.

Parameters:
  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • **kwargs (dict, optional) – Overflow, for convenience.

Returns:

hypparams – Values for each group of hyperparameters.

Return type:

dict

jax_moseq.models.arhmm.initialize.init_model(data=None, states=None, params=None, hypparams=None, seed=Array([0, 0], dtype=uint32), trans_hypparams=None, ar_hypparams=None, verbose=False, **kwargs)

Initialize an ARHMM model dict containing the hyperparameters and initial seed, states, and parameters.

Parameters:
  • data (dict, optional) – Data dictionary containing the observations and mask. Must be provided if states are not precomputed.

  • states (dict, optional) – State values for each latent variable, if precomputed.

  • params (dict, optional) – Values for each model parameter, if precomputed.

  • hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, the caller must provide each argument of init_hyperparams().

  • seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial JAX random seed.

  • trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.

  • ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.

  • verbose (bool, default=False) – Whether to print progress info during initialization.

  • **kwargs (dict, optional) – Overflow, for convenience.

Returns:

model – Dictionary containing the hyperparameters and initial seed, states, and parameters of the model.

Return type:

dict

Raises:

ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.
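The requirements listed above amount to a simple check of which arguments were supplied. A hypothetical helper expressing them (check_init_args is not part of jax_moseq, and the real validation may cover more cases) might look like:

```python
def check_init_args(data=None, states=None, hypparams=None,
                    trans_hypparams=None, ar_hypparams=None):
    """Raise ValueError if the supplied arguments cannot initialize a model (hypothetical helper)."""
    # Hyperparameters: either the formatted dict or both raw groups
    if hypparams is None and (trans_hypparams is None or ar_hypparams is None):
        raise ValueError("Provide hypparams, or both trans_hypparams and ar_hypparams.")
    # States can only be initialized from data
    if states is None and data is None:
        raise ValueError("Provide data when states are not precomputed.")

check_init_args(data={}, trans_hypparams={}, ar_hypparams={})  # sufficient: no error
try:
    check_init_args(trans_hypparams={}, ar_hypparams={})       # neither data nor states
except ValueError as err:
    print(err)  # Provide data when states are not precomputed.
```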

Gibbs sampling

Functions:

resample_discrete_stateseqs(seed, x, mask, ...)

Resamples the discrete state sequence z.

stateseq_marginals(x, mask, Ab, Q, pi, **kwargs)

Computes the marginal probability of each discrete state at each time step.

resample_ar_params(seed, *, nlags, ...)

Resamples the AR parameters Ab and Q.

resample_model(data, seed, states, params, ...)

Resamples the ARHMM model given the hyperparameters, data, current states, and current parameters.

jax_moseq.models.arhmm.gibbs.resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs)

Resamples the discrete state sequence z.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

z – Discrete state sequences.

Return type:

jax array of shape (N, T - n_lags)
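Resampling z in an HMM is typically done with forward filtering, backward sampling (FFBS). The NumPy sketch below handles a single sequence, given a (T, K) matrix of per-frame AR log-likelihoods and the transition matrix pi; it assumes a uniform initial state distribution and ignores the mask. The function name ffbs is illustrative, and this is not claimed to be how jax_moseq implements the step.

```python
import numpy as np

def ffbs(rng, log_likes, pi):
    """Sample z_{1:T} from p(z | x) for one sequence (uniform initial distribution)."""
    T, K = log_likes.shape
    filtered = np.zeros((T, K))  # filtered[t] = p(z_t | x_{1:t})
    pred = np.full(K, 1.0 / K)   # p(z_t | x_{1:t-1})
    for t in range(T):
        w = pred * np.exp(log_likes[t] - log_likes[t].max())  # rescaled for stability
        filtered[t] = w / w.sum()
        pred = filtered[t] @ pi
    z = np.zeros(T, dtype=int)
    z[-1] = rng.choice(K, p=filtered[-1])
    for t in range(T - 2, -1, -1):
        # p(z_t | z_{t+1}, x_{1:t}) is proportional to filtered[t] * pi[:, z_{t+1}]
        w = filtered[t] * pi[:, z[t + 1]]
        z[t] = rng.choice(K, p=w / w.sum())
    return z

rng = np.random.default_rng(0)
pi = np.array([[0.9, 0.1], [0.1, 0.9]])
# Frames 0-4 strongly favor state 0 and frames 5-9 strongly favor state 1
log_likes = np.where(np.arange(10)[:, None] < 5, [0.0, -20.0], [-20.0, 0.0])
z = ffbs(rng, log_likes, pi)
print(z)  # expected: [0 0 0 0 0 1 1 1 1 1]
```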

jax_moseq.models.arhmm.gibbs.stateseq_marginals(x, mask, Ab, Q, pi, **kwargs)

Computes the marginal probability of each discrete state at each time step.

Parameters:
  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

z_marginals – Marginal probability of each discrete state at each time step.

Return type:

jax array of shape (N, T, num_states)
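Marginals of this kind are usually computed with the forward-backward algorithm. A minimal single-sequence NumPy sketch, assuming a uniform initial distribution, no mask, and per-frame log-likelihoods log_likes of shape (T, K) (forward_backward is an illustrative name), is:

```python
import numpy as np

def forward_backward(log_likes, pi):
    """Return p(z_t = k | x_{1:T}) with shape (T, K) for one sequence."""
    T, K = log_likes.shape
    # Per-frame rescaling cancels once the marginals are normalized
    lik = np.exp(log_likes - log_likes.max(axis=1, keepdims=True))
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    a = lik[0] / K                      # uniform initial distribution
    alpha[0] = a / a.sum()
    for t in range(1, T):               # forward pass
        a = (alpha[t - 1] @ pi) * lik[t]
        alpha[t] = a / a.sum()
    for t in range(T - 2, -1, -1):      # backward pass
        b = pi @ (lik[t + 1] * beta[t + 1])
        beta[t] = b / b.sum()
    post = alpha * beta
    return post / post.sum(axis=1, keepdims=True)

pi = np.array([[0.9, 0.1], [0.1, 0.9]])
log_likes = np.where(np.arange(10)[:, None] < 5, [0.0, -20.0], [-20.0, 0.0])
marginals = forward_backward(log_likes, pi)
print(marginals.shape)  # (10, 2)
```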

jax_moseq.models.arhmm.gibbs.resample_ar_params(seed, *, nlags, num_states, mask, x, z, nu_0, S_0, M_0, K_0, **kwargs)

Resamples the AR parameters Ab and Q.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • nlags (int) – Number of autoregressive lags.

  • num_states (int) – Max number of HMM states.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • z (jax array of shape (N, T - n_lags)) – Discrete state sequences.

  • nu_0 (int) – Inverse-Wishart degrees of freedom parameter for Q.

  • S_0 (jax array of shape (latent_dim, latent_dim)) – Inverse-Wishart scale parameter for Q.

  • M_0 (jax array of shape (latent_dim, ar_dim)) – Matrix normal expectation for Ab.

  • K_0 (jax array of shape (ar_dim, ar_dim)) – Matrix normal column scale parameter for Ab.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
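Because the prior on (Ab, Q) is matrix-normal inverse-Wishart, the update for each state is conjugate. The sketch below computes the posterior hyperparameters from the lagged inputs X and targets Y of the frames currently assigned to one state; a posterior draw then samples Q ~ IW(nu_n, S_n) and Ab | Q ~ MN(M_n, Q, K_n). The function name mniw_posterior and the column-per-frame layout of X and Y are assumptions of this sketch.

```python
import numpy as np

def mniw_posterior(X, Y, nu_0, S_0, M_0, K_0):
    """Conjugate MNIW update for Y = Ab @ X + noise, with noise columns ~ N(0, Q).

    X: (ar_dim, n) lagged inputs (trailing row of ones) for frames assigned to one state.
    Y: (latent_dim, n) corresponding target frames.
    """
    K_0_inv = np.linalg.inv(K_0)
    K_n = np.linalg.inv(K_0_inv + X @ X.T)
    M_n = (M_0 @ K_0_inv + Y @ X.T) @ K_n
    S_n = S_0 + Y @ Y.T + M_0 @ K_0_inv @ M_0.T - M_n @ np.linalg.inv(K_n) @ M_n.T
    nu_n = nu_0 + X.shape[1]
    return nu_n, S_n, M_n, K_n

# Synthetic check: data generated from a known Ab with small isotropic noise
rng = np.random.default_rng(0)
d, ar_dim, n = 2, 3, 5000
Ab_true = rng.standard_normal((d, ar_dim))
X = np.vstack([rng.standard_normal((ar_dim - 1, n)), np.ones((1, n))])
Y = Ab_true @ X + 0.1 * rng.standard_normal((d, n))
nu_n, S_n, M_n, K_n = mniw_posterior(X, Y, nu_0=d + 2, S_0=np.eye(d),
                                     M_0=np.zeros((d, ar_dim)), K_0=10 * np.eye(ar_dim))
print(np.abs(M_n - Ab_true).max())  # small: the posterior mean recovers Ab_true
```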

jax_moseq.models.arhmm.gibbs.resample_model(data, seed, states, params, hypparams, states_only=False, verbose=False, **kwargs)

Resamples the ARHMM model given the hyperparameters, data, current states, and current parameters.

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • seed (jr.PRNGKey) – JAX random seed.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict) – Values for each group of hyperparameters.

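  • states_only (bool, default=False) – Whether to resample only the states, leaving the parameters fixed.

  • verbose (bool, default=False) – Whether to print progress info during resampling.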

  • **kwargs (dict) – Overflow, for convenience.

Returns:

model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.

Return type:

dict

Log probability

Functions:

discrete_stateseq_log_prob(z, pi, **kwargs)

Calculate the log probability of a discrete state sequence at each timestep given a matrix of transition probabilities.

continuous_stateseq_log_prob(x, z, Ab, Q, ...)

Calculate the log probability of the trajectory x at each time step, given switching autoregressive (AR) parameters.

log_joint_likelihood(x, mask, z, pi, Ab, Q, ...)

Calculate the total log probability for each latent variable.

model_likelihood(data, states, params[, ...])

Convenience function that invokes jax_moseq.models.arhmm.log_prob.log_joint_likelihood().

state_cross_likelihoods(params, states, ...)

Calculate log likelihoods of frames assigned to each state, given the dynamics of each other state.

marginal_log_likelihood(mask, x, Ab, Q, pi, ...)

Compute the marginal log likelihood of the continuous latents given the model parameters.

jax_moseq.models.arhmm.log_prob.discrete_stateseq_log_prob(z, pi, **kwargs)

Calculate the log probability of a discrete state sequence at each timestep given a matrix of transition probabilities.

Parameters:
  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

log_pz – Log probability of z.

Return type:

jax array of shape (..., T - n_lags - 1)
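Under a first-order Markov chain, the log probability at each step is the log of the transition entry for the consecutive pair (z_{t-1}, z_t), which is why the output has one fewer entry than z along the time axis. A one-line NumPy sketch (the name discrete_log_prob is hypothetical):

```python
import numpy as np

def discrete_log_prob(z, pi):
    """log pi[z_{t-1}, z_t] for every consecutive pair along the last axis of z."""
    return np.log(pi[z[..., :-1], z[..., 1:]])

pi = np.array([[0.9, 0.1], [0.2, 0.8]])
z = np.array([[0, 0, 1, 1]])
log_pz = discrete_log_prob(z, pi)
print(np.exp(log_pz))  # [[0.9 0.1 0.8]]
```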

jax_moseq.models.arhmm.log_prob.continuous_stateseq_log_prob(x, z, Ab, Q, **kwargs)

Calculate the log probability of the trajectory x at each time step, given switching autoregressive (AR) parameters.

Parameters:
  • x (jax array of shape (..., T, latent_dim)) – Latent trajectories.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

log_px – Log probability of x.

Return type:

jax array of shape (..., T - n_lags)
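Each frame's log probability is a multivariate Gaussian density centered on the AR prediction of the state active at that frame, with that state's noise covariance. The NumPy sketch below computes it for one sequence; the oldest-first lag ordering, the explicit nlags argument and the name continuous_log_prob are assumptions, and the mask is ignored.

```python
import numpy as np

def continuous_log_prob(x, z, Ab, Q, nlags):
    """Per-frame log N(x_t | Ab[z_t] @ lagged_t, Q[z_t]) for one sequence x of shape (T, d)."""
    T, d = x.shape
    lags = np.concatenate([x[i:T - nlags + i] for i in range(nlags)], axis=1)
    X = np.concatenate([lags, np.ones((T - nlags, 1))], axis=1)  # (T - nlags, ar_dim)
    resid = x[nlags:] - np.einsum('tij,tj->ti', Ab[z], X)        # target minus AR prediction
    Qz = Q[z]                                                    # (T - nlags, d, d)
    sol = np.linalg.solve(Qz, resid[..., None])[..., 0]          # Q^{-1} resid per frame
    _, logdet = np.linalg.slogdet(Qz)
    return -0.5 * (np.sum(resid * sol, axis=-1) + logdet + d * np.log(2 * np.pi))

rng = np.random.default_rng(0)
T, d, nlags, K = 8, 2, 2, 3
x = rng.standard_normal((T, d))
Ab = np.zeros((K, d, d * nlags + 1))
Q = np.tile(np.eye(d), (K, 1, 1))
z = np.zeros(T - nlags, dtype=int)
lp = continuous_log_prob(x, z, Ab, Q, nlags)
# With Ab = 0 and Q = I this reduces to the standard-normal log density of x[nlags:]
expected = -0.5 * (np.sum(x[nlags:] ** 2, axis=1) + d * np.log(2 * np.pi))
print(np.allclose(lp, expected))  # True
```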

jax_moseq.models.arhmm.log_prob.log_joint_likelihood(x, mask, z, pi, Ab, Q, **kwargs)

Calculate the total log probability for each latent variable.

Parameters:
  • x (jax array of shape (..., T, latent_dim)) – Latent trajectories.

  • mask (jax array) – Binary indicator for which data points are valid.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping state variable name to its total log probability.

Return type:

dict

jax_moseq.models.arhmm.log_prob.model_likelihood(data, states, params, hypparams=None, **kwargs)

Convenience function that invokes jax_moseq.models.arhmm.log_prob.log_joint_likelihood().

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict, optional) – Values for each group of hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping state variable name to its total log probability.

Return type:

dict

jax_moseq.models.arhmm.log_prob.state_cross_likelihoods(params, states, mask, **kwargs)

Calculate log likelihoods of frames assigned to each state, given the dynamics of each other state. See page 33 of the supplement (Wiltschko et al., 2015) for a formal definition.
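The formal definition is in the cited supplement and is not reproduced here. One way to read the description, sketched below under that assumption, is to tabulate, for every pair of states (i, j), the average per-frame log-likelihood of the frames assigned to state i when scored under state j's AR parameters. The averaging, the precomputed (T, num_states) input and the name cross_likelihoods are all assumptions of this sketch.

```python
import numpy as np

def cross_likelihoods(frame_ll, z, num_states):
    """cross[i, j]: mean log-likelihood of frames with z == i under state j's dynamics.

    frame_ll: (T, num_states) per-frame log-likelihood under each state's AR parameters.
    Rows for states with no assigned frames are left as NaN.
    """
    cross = np.full((num_states, num_states), np.nan)
    for i in range(num_states):
        frames = frame_ll[z == i]
        if len(frames):
            cross[i] = frames.mean(axis=0)
    return cross

frame_ll = np.array([[-1.0, -5.0], [-2.0, -4.0], [-6.0, -1.0]])
z = np.array([0, 0, 1])
cross = cross_likelihoods(frame_ll, z, 2)
print(cross)  # cross[0] = [-1.5, -4.5], cross[1] = [-6.0, -1.0]
```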

jax_moseq.models.arhmm.log_prob.marginal_log_likelihood(mask, x, Ab, Q, pi, **kwargs)

Compute the marginal log likelihood of the continuous latents given the model parameters.

Parameters:
  • mask (jax array) – Binary indicator for which data points are valid.

  • x (jax array of shape (..., T, latent_dim)) – Latent trajectories.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

Returns:

ml – Marginal log likelihood.

Return type:

float
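The marginal likelihood sums the joint probability over all discrete state sequences, which the forward algorithm does in O(T K^2) time instead of enumerating K^T paths. The NumPy sketch below works in log space for numerical stability, assumes a uniform initial distribution and a single unmasked sequence, and checks itself against brute-force enumeration on a tiny example. The names logsumexp and marginal_ll are illustrative.

```python
import itertools
import numpy as np

def logsumexp(a, axis=None):
    # Numerically stable log(sum(exp(a))) along the given axis
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def marginal_ll(log_likes, pi):
    """log p(x_{1:T}) summed over all state paths (uniform initial distribution)."""
    T, K = log_likes.shape
    log_pi = np.log(pi)
    log_alpha = log_likes[0] - np.log(K)
    for t in range(1, T):
        # log_alpha[j] = log sum_i exp(log_alpha[i] + log pi[i, j]) + log p(x_t | z_t = j)
        log_alpha = logsumexp(log_alpha[:, None] + log_pi, axis=0) + log_likes[t]
    return float(logsumexp(log_alpha))

# Brute-force check on a tiny problem: enumerate all K**T = 8 state paths
rng = np.random.default_rng(0)
log_likes = rng.standard_normal((3, 2))
pi = np.array([[0.7, 0.3], [0.4, 0.6]])
brute = np.log(sum(
    0.5 * pi[z[0], z[1]] * pi[z[1], z[2]]
    * np.exp(log_likes[0, z[0]] + log_likes[1, z[1]] + log_likes[2, z[2]])
    for z in itertools.product(range(2), repeat=3)
))
print(np.isclose(marginal_ll(log_likes, pi), brute))  # True
```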