AR-HMM
Initialization
Functions:
init_ar_params – Initialize the autoregression parameters by sampling from an MNIW distribution.
init_states – Initialize the latent states of the ARHMM from the data and parameters.
init_params – Initialize the parameters of the ARHMM from the hyperparameters.
init_hyperparams – Formats the hyperparameter dictionary of the ARHMM.
init_model – Initialize an ARHMM model dict containing the hyperparameters and initial seed, states, and parameters.
- jax_moseq.models.arhmm.initialize.init_ar_params(seed, *, num_states, nu_0, S_0, M_0, K_0, **kwargs)
Initialize the autoregression parameters by sampling from an MNIW distribution. In the shapes below, ar_dim = latent_dim * nlags + 1.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
num_states (int) – Max number of HMM states.
nu_0 (int) – Inverse-Wishart degrees of freedom parameter for Q.
S_0 (jax array of shape (latent_dim, latent_dim)) – Inverse-Wishart scale parameter for Q.
M_0 (jax array of shape (latent_dim, ar_dim)) – Matrix normal expectation for Ab.
K_0 (jax array of shape (ar_dim, ar_dim)) – Matrix normal column scale parameter for Ab.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
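As a rough illustration of this sampling step, the NumPy sketch below draws one (Ab, Q) pair per state. It assumes the common parameterization Q ~ IW(S_0, nu_0) and Ab | Q ~ MN(M_0, Q, K_0), with Q as the row covariance and K_0 as the column covariance of Ab, and an integer nu_0 >= latent_dim. The helper names sample_mniw and init_ar_params_sketch are hypothetical; this is not the jax_moseq implementation, which works with JAX random keys.

```python
import numpy as np


def sample_mniw(rng, nu_0, S_0, M_0, K_0):
    """Draw (Ab, Q) with Q ~ IW(S_0, nu_0) and Ab | Q ~ MN(M_0, Q, K_0).

    Assumed parameterization; requires an integer nu_0 >= latent_dim.
    """
    latent_dim = S_0.shape[0]
    # Wishart draw with scale inv(S_0): sum of nu_0 outer products of
    # N(0, inv(S_0)) vectors. Its inverse is distributed as IW(S_0, nu_0).
    X = rng.multivariate_normal(np.zeros(latent_dim), np.linalg.inv(S_0), size=nu_0)
    Q = np.linalg.inv(X.T @ X)
    Q = (Q + Q.T) / 2  # symmetrize against round-off
    # Matrix normal draw: M_0 + L_Q Z L_K^T has vec-covariance kron(K_0, Q).
    Z = rng.standard_normal(M_0.shape)
    Ab = M_0 + np.linalg.cholesky(Q) @ Z @ np.linalg.cholesky(K_0).T
    return Ab, Q


def init_ar_params_sketch(seed, num_states, nu_0, S_0, M_0, K_0):
    """Hypothetical analogue of init_ar_params: one MNIW draw per state."""
    rng = np.random.default_rng(seed)
    draws = [sample_mniw(rng, nu_0, S_0, M_0, K_0) for _ in range(num_states)]
    Ab = np.stack([ab for ab, _ in draws])  # (num_states, latent_dim, ar_dim)
    Q = np.stack([q for _, q in draws])     # (num_states, latent_dim, latent_dim)
    return Ab, Q


latent_dim, nlags, num_states = 3, 2, 5
ar_dim = latent_dim * nlags + 1
Ab, Q = init_ar_params_sketch(
    seed=0,
    num_states=num_states,
    nu_0=latent_dim + 2,
    S_0=np.eye(latent_dim),
    M_0=np.zeros((latent_dim, ar_dim)),
    K_0=np.eye(ar_dim),
)
print(Ab.shape, Q.shape)  # (5, 3, 7) (5, 3, 3)
```

The Wishart-via-outer-products trick keeps the sketch dependent on NumPy alone; a library inverse-Wishart sampler would work equally well and would also handle non-integer nu_0.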
- jax_moseq.models.arhmm.initialize.init_states(seed, x, mask, params, **kwargs)
Initialize the latent states of the ARHMM from the data and parameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
params (dict) – Values for each model parameter.
**kwargs (dict) – Overflow, for convenience.
- Returns:
states – State values for each latent variable.
- Return type:
dict
- jax_moseq.models.arhmm.initialize.init_params(seed, trans_hypparams, ar_hypparams, **kwargs)
Initialize the parameters of the ARHMM from the hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
params – Values for each model parameter.
- Return type:
dict
- jax_moseq.models.arhmm.initialize.init_hyperparams(trans_hypparams, ar_hypparams, **kwargs)
Formats the hyperparameter dictionary of the ARHMM.
- Parameters:
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
**kwargs (dict, optional) – Overflow, for convenience.
- Returns:
hypparams – Values for each group of hyperparameters.
- Return type:
dict
- jax_moseq.models.arhmm.initialize.init_model(data=None, states=None, params=None, hypparams=None, seed=Array([0, 0], dtype=uint32), trans_hypparams=None, ar_hypparams=None, verbose=False, **kwargs)
Initialize an ARHMM model dict containing the hyperparameters and initial seed, states, and parameters.
- Parameters:
data (dict, optional) – Data dictionary containing the observations and mask. Must be provided if states is not precomputed.
states (dict, optional) – State values for each latent variable, if precomputed.
params (dict, optional) – Values for each model parameter, if precomputed.
hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, the caller must provide each argument of init_hyperparams.
seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial JAX random seed.
trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.
ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.
verbose (bool, default=False) – Whether to print progress info during initialization.
**kwargs (dict, optional) – Overflow, for convenience.
- Returns:
model – Dictionary containing the hyperparameters and initial seed, states, and parameters of the model.
- Return type:
dict
- Raises:
ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.
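The Parameters and Raises entries above boil down to two requirements on the inputs. The sketch below restates them as a standalone check. check_init_inputs is a hypothetical helper written for illustration, not part of jax_moseq, and the real function may check more or check differently.

```python
def check_init_inputs(data=None, states=None, hypparams=None,
                      trans_hypparams=None, ar_hypparams=None):
    """Hypothetical restatement of the documented init_model requirements."""
    # States that are not precomputed can only be initialized from observations.
    if states is None and data is None:
        raise ValueError("`data` must be provided if `states` is not precomputed.")
    # Hyperparameters come either preformatted or as the two inputs
    # that init_hyperparams formats.
    if hypparams is None and (trans_hypparams is None or ar_hypparams is None):
        raise ValueError(
            "Provide `hypparams`, or both `trans_hypparams` and `ar_hypparams`."
        )


# Sufficient: observations plus both hyperparameter groups.
check_init_inputs(data={}, trans_hypparams={}, ar_hypparams={})
```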
Gibbs sampling
Functions:
resample_discrete_stateseqs – Resamples the discrete state sequence z.
stateseq_marginals – Computes the marginal probability of each discrete state at each time step.
resample_ar_params – Resamples the AR parameters Ab and Q.
resample_model – Resamples the ARHMM model given the hyperparameters, data, current states, and current parameters.
- jax_moseq.models.arhmm.gibbs.resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs)
Resamples the discrete state sequence z.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
**kwargs (dict) – Overflow, for convenience.
- Returns:
z – Discrete state sequences.
- Return type:
jax array of shape (N, T - nlags)
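Resampling z given fixed AR parameters is typically done with forward filtering, backward sampling (FFBS). The NumPy sketch below shows that technique for one sequence. It takes a (T', num_states) array of per-state log-likelihoods; in an ARHMM, entry [t, k] would be the log density of frame t under state k's AR dynamics. The sketch assumes a uniform initial state distribution and ignores the mask. ffbs and logsumexp are hypothetical helper names, and jax_moseq's own sampler may be organized differently.

```python
import numpy as np


def logsumexp(a, axis=None, keepdims=False):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)


def ffbs(rng, log_likes, pi):
    """Sample a state sequence z ~ p(z | x) for one HMM sequence.

    log_likes: (T, K) per-state log-likelihoods; pi: (K, K) transitions.
    Assumes a uniform distribution over the initial state.
    """
    T, K = log_likes.shape
    log_pi = np.log(pi)
    # Forward filtering: log_alpha[t, k] = log p(x_0..x_t, z_t = k).
    log_alpha = np.empty((T, K))
    log_alpha[0] = log_likes[0] - np.log(K)
    for t in range(1, T):
        log_alpha[t] = log_likes[t] + logsumexp(log_alpha[t - 1][:, None] + log_pi, axis=0)

    def draw(logits):
        p = np.exp(logits - logsumexp(logits))
        return rng.choice(K, p=p / p.sum())

    # Backward sampling: z_{T-1} from the last filter, then
    # p(z_t = j | z_{t+1} = k, x_0..x_t) is proportional to alpha_t(j) * pi[j, k].
    z = np.empty(T, dtype=int)
    z[T - 1] = draw(log_alpha[T - 1])
    for t in range(T - 2, -1, -1):
        z[t] = draw(log_alpha[t] + log_pi[:, z[t + 1]])
    return z


# Sharply peaked likelihoods make the sample effectively deterministic.
true_z = np.array([0, 0, 2, 2, 1, 1])
log_likes = np.full((6, 3), -50.0)
log_likes[np.arange(6), true_z] = 0.0
z = ffbs(np.random.default_rng(0), log_likes, np.full((3, 3), 1 / 3))
print(z)  # [0 0 2 2 1 1]
```

In JAX, both Python loops would typically be written with jax.lax.scan so they can be compiled and vmapped over the N sequences.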
- jax_moseq.models.arhmm.gibbs.stateseq_marginals(x, mask, Ab, Q, pi, **kwargs)
Computes the marginal probability of each discrete state at each time step.
- Parameters:
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
**kwargs (dict) – Overflow, for convenience.
- Returns:
z_marginals – Marginal probability of each discrete state at each time step.
- Return type:
jax array of shape (N, T, num_states)
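Per-time-step state marginals of an HMM are usually computed with the forward-backward algorithm. The NumPy sketch below does this for one sequence of per-state log-likelihoods. As before, it assumes a uniform initial state distribution and ignores the mask; hmm_marginals is a hypothetical name, not the library function.

```python
import numpy as np


def logsumexp(a, axis=None, keepdims=False):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)


def hmm_marginals(log_likes, pi):
    """Posterior p(z_t = k | x) for one sequence via forward-backward.

    log_likes: (T, K) per-state log-likelihoods; pi: (K, K) transitions.
    """
    T, K = log_likes.shape
    log_pi = np.log(pi)
    log_alpha = np.empty((T, K))
    log_beta = np.zeros((T, K))  # log_beta[T - 1] = log(1)
    # Forward pass: log p(x_0..x_t, z_t = k), uniform initial state assumed.
    log_alpha[0] = log_likes[0] - np.log(K)
    for t in range(1, T):
        log_alpha[t] = log_likes[t] + logsumexp(log_alpha[t - 1][:, None] + log_pi, axis=0)
    # Backward pass: log p(x_{t+1}..x_{T-1} | z_t = j).
    for t in range(T - 2, -1, -1):
        log_beta[t] = logsumexp(log_pi + (log_likes[t + 1] + log_beta[t + 1])[None, :], axis=1)
    # Combine and normalize each time step.
    log_gamma = log_alpha + log_beta
    return np.exp(log_gamma - logsumexp(log_gamma, axis=1, keepdims=True))


rng = np.random.default_rng(1)
log_likes = rng.normal(size=(4, 2))
pi = np.array([[0.9, 0.1], [0.2, 0.8]])
marginals = hmm_marginals(log_likes, pi)
print(marginals.sum(axis=1))  # every row sums to 1
```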
- jax_moseq.models.arhmm.gibbs.resample_ar_params(seed, *, nlags, num_states, mask, x, z, nu_0, S_0, M_0, K_0, **kwargs)
Resamples the AR parameters Ab and Q.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
nlags (int) – Number of autoregressive lags.
num_states (int) – Max number of HMM states.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
z (jax array of shape (N, T - nlags)) – Discrete state sequences.
nu_0 (int) – Inverse-Wishart degrees of freedom parameter for Q.
S_0 (jax array of shape (latent_dim, latent_dim)) – Inverse-Wishart scale parameter for Q.
M_0 (jax array of shape (latent_dim, ar_dim)) – Matrix normal expectation for Ab.
K_0 (jax array of shape (ar_dim, ar_dim)) – Matrix normal column scale parameter for Ab.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
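Because the MNIW prior is conjugate to a Gaussian AR likelihood, the update for each state only needs the frames currently assigned to it. The NumPy sketch below computes the posterior MNIW parameters for one sequence, assuming Q ~ IW(S_0, nu_0) and Ab | Q ~ MN(M_0, Q, K_0). A Gibbs step would then draw a new (Ab, Q) from each state's posterior. The lag ordering in lagged_design and all function names here are assumptions for illustration.

```python
import numpy as np


def lagged_design(x, nlags):
    """Rows [x_{t-nlags}, ..., x_{t-1}, 1] for t = nlags..T-1 (assumed lag order)."""
    T = x.shape[0]
    lags = [x[l:T - nlags + l] for l in range(nlags)]
    return np.concatenate(lags + [np.ones((T - nlags, 1))], axis=1)


def mniw_posterior(X, Y, nu_0, S_0, M_0, K_0):
    """Conjugate MNIW update for Y[t] = Ab @ X[t] + noise, noise ~ N(0, Q)."""
    K0_inv = np.linalg.inv(K_0)
    Kn_inv = K0_inv + X.T @ X
    K_n = np.linalg.inv(Kn_inv)
    M_n = (M_0 @ K0_inv + Y.T @ X) @ K_n
    S_n = S_0 + Y.T @ Y + M_0 @ K0_inv @ M_0.T - M_n @ Kn_inv @ M_n.T
    nu_n = nu_0 + X.shape[0]
    return nu_n, (S_n + S_n.T) / 2, M_n, K_n


def ar_posteriors_by_state(x, z, nlags, num_states, nu_0, S_0, M_0, K_0):
    """Posterior (nu_n, S_n, M_n, K_n) per state; z has length T - nlags."""
    X, Y = lagged_design(x, nlags), x[nlags:]
    return [mniw_posterior(X[z == k], Y[z == k], nu_0, S_0, M_0, K_0)
            for k in range(num_states)]
```

With a vague prior (large K_0), M_n approaches the least-squares fit of Ab to the frames assigned to that state.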
- jax_moseq.models.arhmm.gibbs.resample_model(data, seed, states, params, hypparams, states_only=False, verbose=False, **kwargs)
Resamples the ARHMM model given the hyperparameters, data, current states, and current parameters.
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
seed (jr.PRNGKey) – JAX random seed.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict) – Values for each group of hyperparameters.
states_only (bool, default=False) – Whether to restrict sampling to states.
**kwargs (dict) – Overflow, for convenience.
- Returns:
model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.
- Return type:
dict
Log probability
Functions:
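discrete_stateseq_log_prob – Calculate the log probability of a discrete state sequence at each timestep given a matrix of transition probabilities.
continuous_stateseq_log_prob – Calculate the log probability of the trajectory x at each time step, given switching autoregressive (AR) parameters.
log_joint_likelihood – Calculate the total log probability of each latent variable.
model_likelihood – Convenience function that invokes jax_moseq.models.arhmm.log_prob.log_joint_likelihood().
state_cross_likelihoods – Calculate log likelihoods of frames assigned to each state, given the dynamics of each other state.
marginal_log_likelihood – Marginal log likelihood of continuous latents given model parameters.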
- jax_moseq.models.arhmm.log_prob.discrete_stateseq_log_prob(z, pi, **kwargs)
Calculate the log probability of a discrete state sequence at each timestep given a matrix of transition probabilities.
- Parameters:
z (jax array of shape (..., T - nlags)) – Discrete state sequences.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
**kwargs (dict) – Overflow, for convenience.
- Returns:
log_pz – Log probability of z.
- Return type:
jax array of shape (..., T - nlags - 1)
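If the per-timestep term is the log transition probability between consecutive states, with no separate initial-state term (an assumption), the computation reduces to fancy indexing into pi. The result then has one fewer entry than z along the time axis. A NumPy sketch with a hypothetical function name:

```python
import numpy as np


def discrete_log_prob_sketch(z, pi):
    """log pi[z[t-1], z[t]] for each consecutive pair along the last axis."""
    # Works for any leading batch shape because indexing broadcasts over it.
    return np.log(pi[z[..., :-1], z[..., 1:]])


pi = np.array([[0.9, 0.1], [0.5, 0.5]])
z = np.array([0, 0, 1, 1])
print(np.exp(discrete_log_prob_sketch(z, pi)))  # [0.9 0.1 0.5]
```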
- jax_moseq.models.arhmm.log_prob.continuous_stateseq_log_prob(x, z, Ab, Q, **kwargs)
Calculate the log probability of the trajectory x at each time step, given switching autoregressive (AR) parameters.
- Parameters:
x (jax array of shape (..., T, latent_dim)) – Latent trajectories.
z (jax array of shape (..., T - nlags)) – Discrete state sequences.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
**kwargs (dict) – Overflow, for convenience.
- Returns:
log_px – Log probability of x.
- Return type:
jax array of shape (..., T - nlags)
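Under a Gaussian AR model, each frame t after the first nlags is scored as x_t ~ N(Ab[k] @ r_t, Q[k]), where k is the state assigned to frame t and r_t stacks the previous nlags frames followed by a 1 (the bias column implied by ar_dim = latent_dim * nlags + 1). The NumPy sketch below evaluates that density for a single trajectory. The lag ordering, the alignment of z with frames nlags..T-1, the notation r_t, and the function names are assumptions, not jax_moseq's code.

```python
import numpy as np


def lagged_design(x, nlags):
    """Rows [x_{t-nlags}, ..., x_{t-1}, 1] for t = nlags..T-1 (assumed lag order)."""
    T = x.shape[0]
    lags = [x[l:T - nlags + l] for l in range(nlags)]
    return np.concatenate(lags + [np.ones((T - nlags, 1))], axis=1)


def continuous_log_prob_sketch(x, z, Ab, Q):
    """Gaussian AR log density of frames nlags..T-1 of one trajectory x (T, D)."""
    D = x.shape[-1]
    nlags = (Ab.shape[-1] - 1) // D  # from ar_dim = latent_dim * nlags + 1
    X, Y = lagged_design(x, nlags), x[nlags:]
    resid = Y - np.einsum("tij,tj->ti", Ab[z], X)  # x_t minus its AR prediction
    Qz = Q[z]                                       # (T - nlags, D, D)
    mahal = np.sum(resid * np.linalg.solve(Qz, resid[..., None])[..., 0], axis=-1)
    _, logdet = np.linalg.slogdet(Qz)
    return -0.5 * (mahal + logdet + D * np.log(2 * np.pi))


# One-dimensional example with nlags = 1 (ar_dim = 2).
Ab = np.array([[[0.5, 0.0]], [[1.0, 1.0]]])  # state 0: 0.5 * x; state 1: x + 1
Q = np.array([[[1.0]], [[0.25]]])
x = np.array([[0.0], [1.0], [0.5]])
z = np.array([1, 0])
print(continuous_log_prob_sketch(x, z, Ab, Q).shape)  # (2,)
```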
- jax_moseq.models.arhmm.log_prob.log_joint_likelihood(x, mask, z, pi, Ab, Q, **kwargs)
Calculate the total log probability of each latent variable.
- Parameters:
x (jax array of shape (..., T, latent_dim)) – Latent trajectories.
mask (jax array) – Binary indicator for which data points are valid.
z (jax array of shape (..., T - nlags)) – Discrete state sequences.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping state variable name to its total log probability.
- Return type:
dict
- jax_moseq.models.arhmm.log_prob.model_likelihood(data, states, params, hypparams=None, **kwargs)
Convenience function that invokes jax_moseq.models.arhmm.log_prob.log_joint_likelihood().
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict, optional) – Values for each group of hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping state variable name to its total log probability.
- Return type:
dict
- jax_moseq.models.arhmm.log_prob.state_cross_likelihoods(params, states, mask, **kwargs)
Calculate log likelihoods of frames assigned to each state, given the dynamics of each other state. See page 33 of the supplement of Wiltschko et al. (2015) for a formal definition.
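The formal definition is deferred to the supplement. To illustrate the idea only, the NumPy sketch below fills a (num_states, num_states) matrix whose entry [i, j] is the average log-likelihood of the frames assigned to state i, scored under state j's AR parameters. The averaging, the lag ordering, the NaN convention for unused states, and the function names are all assumptions and may not match jax_moseq's definition or output format.

```python
import numpy as np


def lagged_design(x, nlags):
    """Rows [x_{t-nlags}, ..., x_{t-1}, 1] for t = nlags..T-1 (assumed lag order)."""
    T = x.shape[0]
    lags = [x[l:T - nlags + l] for l in range(nlags)]
    return np.concatenate(lags + [np.ones((T - nlags, 1))], axis=1)


def frame_log_likes(x, Ab_k, Q_k, nlags):
    """log N(x_t | Ab_k @ regressor_t, Q_k) for frames t = nlags..T-1."""
    D = x.shape[-1]
    resid = x[nlags:] - lagged_design(x, nlags) @ Ab_k.T
    mahal = np.sum(resid * np.linalg.solve(Q_k, resid.T).T, axis=-1)
    _, logdet = np.linalg.slogdet(Q_k)
    return -0.5 * (mahal + logdet + D * np.log(2 * np.pi))


def cross_likelihoods_sketch(x, z, Ab, Q):
    """C[i, j]: mean log-likelihood of frames in state i under state j's dynamics."""
    num_states, D = Ab.shape[0], x.shape[-1]
    nlags = (Ab.shape[-1] - 1) // D
    ll = np.stack([frame_log_likes(x, Ab[j], Q[j], nlags) for j in range(num_states)], axis=1)
    C = np.full((num_states, num_states), np.nan)  # NaN row: state never used
    for i in range(num_states):
        if np.any(z == i):
            C[i] = ll[z == i].mean(axis=0)
    return C


# 1-D example: state 0 predicts 0, state 1 predicts 5, both with unit variance.
Ab = np.array([[[0.0, 0.0]], [[0.0, 5.0]]])
Q = np.array([[[1.0]], [[1.0]]])
x = np.array([[0.0], [0.0], [0.0], [5.0], [5.0], [5.0]])
z = np.array([0, 0, 1, 1, 1])  # states for frames 1..5
print(np.round(cross_likelihoods_sketch(x, z, Ab, Q), 2))
```

In this toy case, the diagonal holds the best-fitting scores, and each off-diagonal entry is lower by the squared prediction gap divided by 2 (12.5 here).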
- jax_moseq.models.arhmm.log_prob.marginal_log_likelihood(mask, x, Ab, Q, pi, **kwargs)
Marginal log likelihood of continuous latents given model parameters.
- Parameters:
mask (jax array) – Binary indicator for which data points are valid.
x (jax array of shape (..., T, latent_dim)) – Latent trajectories.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
- Returns:
ml – Marginal log likelihood.
- Return type:
float
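Marginalizing over z is the job of the forward algorithm, which sums over all state paths in O(T * num_states^2) time instead of enumerating num_states^T of them. The NumPy sketch below takes one sequence of per-state log-likelihoods; in an ARHMM, these would be each frame's log density under each state's AR dynamics. It assumes a uniform initial state distribution and ignores the mask; forward_log_marginal is a hypothetical name.

```python
import numpy as np


def logsumexp(a, axis=None, keepdims=False):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)


def forward_log_marginal(log_likes, pi):
    """log p(x) = log sum_z p(x, z) for one sequence via the forward algorithm."""
    T, K = log_likes.shape
    log_pi = np.log(pi)
    log_alpha = log_likes[0] - np.log(K)  # uniform p(z_0) assumed
    for t in range(1, T):
        # log_alpha[k] = log p(x_0..x_t, z_t = k)
        log_alpha = log_likes[t] + logsumexp(log_alpha[:, None] + log_pi, axis=0)
    return float(logsumexp(log_alpha))


rng = np.random.default_rng(2)
log_likes = rng.normal(size=(5, 3))
pi = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]])
ml = forward_log_marginal(log_likes, pi)
print(ml)
```

Working in log space with logsumexp avoids the underflow that multiplying raw probabilities would cause for long sequences.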