SLDS
Initialization
Functions:
init_obs_params | Initialize the observation (also known as "emission") parameters of the model using PCA.
init_continuous_stateseqs | Initialize the continuous latents by applying the inverse of the emission transform to the data.
init_states | Initialize the latent states of the SLDS from the data, parameters, and hyperparameters.
init_params | Initialize the parameters of the SLDS from the data and hyperparameters.
init_hyperparams | Format the hyperparameter dictionary of the SLDS.
init_model | Initialize an SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.
- jax_moseq.models.slds.initialize.init_obs_params(pca, Y, mask, whiten, latent_dim, **kwargs)
Initialize the observation (also known as “emission”) parameters of the model using PCA.
- Parameters:
pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.
Y (jax array of shape (N, T, obs_dim)) – Observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
whiten (bool) – Whether to whiten PCs to initialize continuous latents.
latent_dim (int) – Dimensionality of continuous latents.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Cd – Observation transform.
- Return type:
jax array of shape (obs_dim, latent_dim + 1)
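As a rough illustration of what a PCA-based initialization of Cd can look like, the sketch below computes principal axes with a NumPy SVD over the valid frames (instead of taking a fitted sklearn PCA object), uses the top latent_dim axes as the linear part C, optionally scales each axis by the standard deviation of its PC scores so the projected latents have unit variance, and appends the data mean as the offset column d. The name init_obs_params_sketch is hypothetical, and jax_moseq's actual whitening procedure may differ.

```python
import numpy as np

def init_obs_params_sketch(Y, mask, latent_dim, whiten=True):
    """Hypothetical sketch of a PCA-based affine observation transform Cd.

    Y: (N, T, obs_dim) observations; mask: (N, T) binary validity indicator.
    Returns Cd of shape (obs_dim, latent_dim + 1): the first latent_dim columns
    map latents to observations and the last column is the mean observation.
    """
    Y_valid = Y[mask > 0]                      # (n_valid, obs_dim), valid frames only
    d = Y_valid.mean(axis=0)                   # offset: mean observation
    # PCA via SVD of the centered data; rows of Vt are the principal axes
    _, svals, Vt = np.linalg.svd(Y_valid - d, full_matrices=False)
    C = Vt[:latent_dim].T                      # (obs_dim, latent_dim)
    if whiten:
        # Scale each axis by the std of its PC scores so that projecting the
        # data through the pseudo-inverse of C yields unit-variance latents
        pc_std = svals[:latent_dim] / np.sqrt(len(Y_valid) - 1)
        C = C * pc_std
    return np.hstack([C, d[:, None]])
```

With whiten=False the columns of C stay orthonormal, so the latents are plain PC scores.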
- jax_moseq.models.slds.initialize.init_continuous_stateseqs(Y, Cd, **kwargs)
Initialize the continuous latents by applying the inverse of the emission transform to the data.
- Parameters:
Y (jax array of shape (N, T, obs_dim)) – Observations.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
**kwargs (dict) – Overflow, for convenience.
- Returns:
x – Latent trajectories.
- Return type:
jax array of shape (N, T, latent_dim)
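Because Cd is affine (y = C x + d) and C is generally tall (obs_dim > latent_dim), "inverting" it amounts to a least-squares projection. A minimal sketch, assuming the last column of Cd holds the offset d as its shape (obs_dim, latent_dim + 1) suggests, uses the Moore-Penrose pseudo-inverse; init_continuous_stateseqs_sketch is a hypothetical name, not the library function.

```python
import numpy as np

def init_continuous_stateseqs_sketch(Y, Cd):
    """Hypothetical sketch: map observations to latents by inverting y = C @ x + d.

    Y: (N, T, obs_dim); Cd: (obs_dim, latent_dim + 1). Returns x: (N, T, latent_dim).
    """
    C, d = Cd[:, :-1], Cd[:, -1]
    # C is tall, so use the pseudo-inverse (least-squares solution per frame)
    return (Y - d) @ np.linalg.pinv(C).T
```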
- jax_moseq.models.slds.initialize.init_states(seed, Y, mask, params, obs_hypparams, **kwargs)
Initialize the latent states of the SLDS from the data, parameters, and hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, obs_dim)) – Observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
params (dict) – Values for each model parameter.
obs_hypparams (dict) – Observation hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
states – State values for each latent variable.
- Return type:
dict
- jax_moseq.models.slds.initialize.init_params(seed, pca, Y, mask, trans_hypparams, ar_hypparams, whiten, **kwargs)
Initialize the parameters of the SLDS from the data and hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.
Y (jax array of shape (N, T, obs_dim)) – Observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
whiten (bool) – Whether to whiten PCs to initialize continuous latents.
**kwargs (dict) – Overflow, for convenience.
- Returns:
params – Values for each model parameter.
- Return type:
dict
- jax_moseq.models.slds.initialize.init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams, **kwargs)
Format the hyperparameter dictionary of the SLDS.
- Parameters:
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
obs_hypparams (dict) – Observation hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
hypparams – Values for each group of hyperparameters.
- Return type:
dict
- jax_moseq.models.slds.initialize.init_model(data=None, states=None, params=None, hypparams=None, seed=Array([0, 0], dtype=uint32), pca=None, whiten=True, PCA_fitting_num_frames=1000000, trans_hypparams=None, ar_hypparams=None, obs_hypparams=None, verbose=False, **kwargs)
Initialize an SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.
- Parameters:
data (dict, optional) – Data dictionary containing the observations, mask, and (optionally) confidences. Must be provided if states or params are not precomputed.
states (dict, optional) – State values for each latent variable, if precomputed.
params (dict, optional) – Values for each model parameter, if precomputed.
hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, the caller must provide each argument of init_hyperparams.
seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial random seed value.
pca (sklearn.decomposition._pca.PCA, optional) – PCA object fit to observations.
whiten (bool, default=True) – Whether to whiten PCs to initialize continuous latents.
PCA_fitting_num_frames (int, default=1000000) – Maximum number of datapoints to sample for PCA fitting, if pca is not provided.
trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.
ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.
obs_hypparams (dict, optional) – Observation hyperparameters. Must be provided if hypparams is not provided.
verbose (bool, default=False) – Whether to print progress info during initialization.
**kwargs (dict, optional) – Unused. For convenience, enables the user to invoke the function by unpacking a dict that contains keys not used by the method.
- Returns:
model – Dictionary containing the hyperparameters, noise prior, and initial seed, states, and parameters of the model.
- Return type:
dict
- Raises:
ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.
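The argument requirements documented above can be summarized as a small check: data is needed unless both states and params are precomputed, and either hypparams or all three hyperparameter groups must be supplied. The helper below, check_init_inputs, is hypothetical and not part of jax_moseq; init_model's own validation and error messages may differ.

```python
def check_init_inputs(data=None, states=None, params=None, hypparams=None,
                      trans_hypparams=None, ar_hypparams=None, obs_hypparams=None):
    """Hypothetical helper mirroring the documented requirements of init_model.

    Raises ValueError when the provided arguments are insufficient.
    """
    # Data is only optional when both states and params are precomputed
    if (states is None or params is None) and data is None:
        raise ValueError("`data` is required unless both `states` and `params` are precomputed.")
    # Without a full hypparams dict, every hyperparameter group is required
    if hypparams is None:
        groups = [("trans_hypparams", trans_hypparams),
                  ("ar_hypparams", ar_hypparams),
                  ("obs_hypparams", obs_hypparams)]
        missing = [name for name, value in groups if value is None]
        if missing:
            raise ValueError(f"Without `hypparams`, these must be provided: {missing}")
```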
Gibbs sampling
Functions:
resample_continuous_stateseqs | Resample the latent trajectories x.
resample_obs_variance | Resample the observation variance sigmasq.
resample_obs_variance_from_sqerr | Resample the observation variance sigmasq using the squared error between predicted and true observations.
resample_scales | Resample the scale values s.
resample_scales_from_sqerr | Resample the scale values s using the squared error between predicted and true observations.
compute_squared_error | Compute the squared error between model-predicted and true observations.
resample_model | Resample the SLDS model given the hyperparameters, data, current states, and current parameters.
- jax_moseq.models.slds.gibbs.resample_continuous_stateseqs(seed, y, mask, z, s, Ab, Q, Cd, sigmasq, jitter=0.001, parallel_message_passing=True, **kwargs)
Resample the latent trajectories x.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
y (jax.Array of shape (n_recordings, n_timesteps, obs_dim)) – Observations.
mask (jax.Array of shape (n_recordings, n_timesteps)) – Binary indicator, 1=valid frames, 0=invalid frames.
z (jax.Array of shape (n_recordings, n_timesteps-n_lags)) – Discrete state sequences, taking integer values in [0, n_states), for timesteps [n_lags, n_timesteps).
s (jax.Array of shape (n_recordings, n_timesteps, obs_dim)) – Observation noise scales.
Ab (jax.Array of shape (n_states, latent_dim, ar_dim + 1)) – Autoregressive dynamics and bias, where ar_dim = latent_dim * n_lags.
Q (jax.Array of shape (n_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
Cd (jax.Array of shape (obs_dim, latent_dim + 1)) – Affine transform from latent_dim to obs_dim.
sigmasq (jax.Array of shape (obs_dim,)) – Unscaled noise.
jitter (float, default=1e-3) – Amount by which to boost the diagonal of the covariance matrix during backward sampling of the continuous states.
parallel_message_passing (bool, default=True) – Whether to use an associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer JIT compilation time.
**kwargs (dict) – Overflow, for convenience.
- Returns:
x – Posterior sample of latent trajectories.
- Return type:
jax.Array of shape (n_recordings, n_timesteps, latent_dim)
- jax_moseq.models.slds.gibbs.resample_obs_variance(seed, Y, mask, x, s, Cd, nu_sigma, sigmasq_0, **kwargs)
Resample the observation variance sigmasq.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, obs_dim)) – Observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
s (jax array of shape (N, T, obs_dim)) – Noise scales.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.
sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.
**kwargs (dict) – Overflow, for convenience.
- Returns:
sigmasq – Unscaled noise.
- Return type:
jax array of shape (obs_dim,)
- jax_moseq.models.slds.gibbs.resample_obs_variance_from_sqerr(seed, sqerr, mask, s, nu_sigma, sigmasq_0, **kwargs)
Resample the observation variance sigmasq using the squared error between predicted and true observations.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
sqerr (jax array of shape (N, T, obs_dim)) – Squared error between predicted and true observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
s (jax array of shape (N, T, obs_dim)) – Noise scales.
nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.
sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.
**kwargs (dict) – Overflow, for convenience.
- Returns:
sigmasq – Unscaled noise.
- Return type:
jax array of shape (obs_dim,)
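Both sigmasq resamplers place a scaled inverse chi-squared prior with parameters (nu_sigma, sigmasq_0) on sigmasq. Assuming each observation y[n, t, k] has variance s[n, t, k] * sigmasq[k], the textbook conjugate update for dimension k has nu_sigma + n_valid degrees of freedom and scale numerator nu_sigma * sigmasq_0 + sum(mask * sqerr / s), and a draw is that numerator divided by a chi-squared variate. The sketch below uses a NumPy Generator in place of a JAX PRNG key; its name is hypothetical and it is not a copy of jax_moseq's code.

```python
import numpy as np

def resample_obs_variance_from_sqerr_sketch(rng, sqerr, mask, s, nu_sigma, sigmasq_0):
    """Hypothetical sketch of a conjugate scaled-inverse-chi-squared update for sigmasq.

    Assumed model: y[n, t, k] ~ Normal(mean, s[n, t, k] * sigmasq[k]) and
    sigmasq[k] ~ ScaledInvChi2(nu_sigma, sigmasq_0).
    sqerr, s: (N, T, obs_dim); mask: (N, T). Returns sigmasq: (obs_dim,).
    """
    n_valid = mask.sum()                                 # number of valid frames
    S = (mask[..., None] * sqerr / s).sum(axis=(0, 1))   # scale-weighted squared error per dim
    nu_post = nu_sigma + n_valid                         # posterior degrees of freedom
    # A ScaledInvChi2(nu, tau2) draw equals nu * tau2 / chi2(nu); one chi2 draw per dim
    chi2 = rng.chisquare(nu_post, size=S.shape)
    return (nu_sigma * sigmasq_0 + S) / chi2
```

Frames with mask == 0 contribute nothing to either the degrees of freedom or the squared-error sum.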
- jax_moseq.models.slds.gibbs.resample_scales(seed, Y, x, Cd, sigmasq, nu_s, s_0, **kwargs)
Resample the scale values s.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, obs_dim)) – Observations.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.
**kwargs (dict) – Overflow, for convenience.
- Returns:
s – Noise scales.
- Return type:
jax array of shape (N, T, obs_dim)
- jax_moseq.models.slds.gibbs.resample_scales_from_sqerr(seed, sqerr, sigmasq, nu_s, s_0, **kwargs)
Resample the scale values s using the squared error between predicted and true observations.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
sqerr (jax array of shape (N, T, obs_dim)) – Squared error between predicted and true observations.
sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.
**kwargs (dict) – Overflow, for convenience.
- Returns:
s – Per-observation noise scales.
- Return type:
jax array of shape (N, T, obs_dim)
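The scales follow the same conjugate pattern at the level of individual datapoints. Assuming each s[n, t, k] has a scaled inverse chi-squared prior (nu_s, s_0) and scales the variance of a single observation, the posterior has nu_s + 1 degrees of freedom and scale numerator nu_s * s_0 + sqerr / sigmasq. The sketch below is hypothetical (NumPy Generator instead of a JAX key) and only illustrates that update.

```python
import numpy as np

def resample_scales_from_sqerr_sketch(rng, sqerr, sigmasq, nu_s, s_0):
    """Hypothetical sketch: independent conjugate update for each scale s[n, t, k].

    Assumed model: y[n, t, k] ~ Normal(mean, s[n, t, k] * sigmasq[k]) and
    s[n, t, k] ~ ScaledInvChi2(nu_s, s_0). Each scale sees one squared error.
    """
    nu_post = nu_s + 1                          # one observation per scale
    scaled_ss = nu_s * s_0 + sqerr / sigmasq    # sigmasq (obs_dim,) broadcasts over last axis
    return scaled_ss / rng.chisquare(nu_post, size=sqerr.shape)
```

Large residuals relative to sigmasq push the corresponding scale up, which is how the model downweights outlying datapoints.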
- jax_moseq.models.slds.gibbs.compute_squared_error(Y, x, Cd, mask=None)
Compute the squared error between model-predicted and true observations.
- Parameters:
Y (jax array of shape (..., obs_dim)) – Observations.
x (jax array of shape (..., latent_dim)) – Latent trajectories.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
mask (jax array of shape (...), optional) – Binary indicator for valid frames.
- Returns:
sqerr – Squared error between model-predicted and true observations.
- Return type:
jax array of shape (..., obs_dim)
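Since Cd packs the linear map C and the offset d into one (obs_dim, latent_dim + 1) array, the prediction is x @ C.T + d and the error is taken elementwise. In the minimal sketch below, zeroing the error on masked-out frames is an assumption about how mask is applied; the function name is hypothetical.

```python
import numpy as np

def compute_squared_error_sketch(Y, x, Cd, mask=None):
    """Sketch: elementwise squared error between Y and the affine prediction C @ x + d.

    Y: (..., obs_dim); x: (..., latent_dim); Cd: (obs_dim, latent_dim + 1);
    mask: (...), optional. Masked-out frames get zero error (an assumption).
    """
    C, d = Cd[:, :-1], Cd[:, -1]
    Y_pred = x @ C.T + d                  # (..., obs_dim)
    sqerr = (Y - Y_pred) ** 2
    if mask is not None:
        sqerr = sqerr * mask[..., None]   # broadcast frame mask over obs_dim
    return sqerr
```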
- jax_moseq.models.slds.gibbs.resample_model(data, seed, states, params, hypparams, ar_only=False, states_only=False, skip_noise=True, parallel_message_passing=False, **kwargs)
Resample the SLDS model given the hyperparameters, data, current states, and current parameters.
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
seed (jr.PRNGKey) – JAX random seed.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict) – Values for each group of hyperparameters.
ar_only (bool, default=False) – Whether to restrict sampling to ARHMM components.
states_only (bool, default=False) – Whether to restrict sampling to states.
skip_noise (bool, default=True) – Whether to exclude sigmasq and s from resampling.
parallel_message_passing (bool, default=False) – Whether to use an associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer JIT compilation time.
**kwargs (dict) – Overflow, for convenience.
- Returns:
model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.
- Return type:
dict
Log probability
Functions:
scale_log_prob | Calculate the log probability of the noise scale s for each datapoint given the noise prior, which is a scaled inverse chi-squared distribution.
obs_log_prob | Calculate the log probability of the observations at each time-step given the latent trajectories, noise parameters, and observation matrix.
log_joint_likelihood | Calculate the total log probability for each latent state.
model_likelihood | Convenience function that invokes log_joint_likelihood.
- jax_moseq.models.slds.log_prob.scale_log_prob(s, s_0, nu_s, **kwargs)
Calculate the log probability of the noise scale s for each datapoint given the noise prior, which is a scaled inverse chi-squared distribution.
- Parameters:
s (jax array) – Noise scales.
s_0 (scalar or jax array, broadcastable to s) – Prior on noise scale.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
**kwargs (dict) – Overflow, for convenience.
- Returns:
log_ps – Log probability of s.
- Return type:
jax array
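For reference, the scaled inverse chi-squared density with nu degrees of freedom and scale s_0 has log density (nu/2) log(nu s_0 / 2) - lgamma(nu/2) - (1 + nu/2) log s - nu s_0 / (2 s). The sketch below evaluates that formula elementwise with NumPy; it illustrates the distribution named above and is not claimed to be jax_moseq's implementation (which may, for instance, drop constant terms).

```python
import math
import numpy as np

def scale_log_prob_sketch(s, s_0, nu_s):
    """Log density of ScaledInvChi2(nu_s, s_0), evaluated elementwise at s > 0."""
    s = np.asarray(s, dtype=float)
    half_nu = nu_s / 2.0
    return (half_nu * np.log(half_nu * s_0)   # (nu/2) * log(nu * s_0 / 2)
            - math.lgamma(half_nu)            # normalizing Gamma function
            - (1.0 + half_nu) * np.log(s)
            - half_nu * s_0 / s)              # nu * s_0 / (2 s)
```

For example, with nu_s = 2 and s_0 = 1 the log density at s = 1 is exactly -1.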
- jax_moseq.models.slds.log_prob.obs_log_prob(Y, x, s, Cd, sigmasq, **kwargs)
Calculate the log probability of the observations at each time-step given the latent trajectories, noise parameters, and observation matrix.
- Parameters:
Y (jax array of shape (..., obs_dim)) – Observations.
x (jax array of shape (..., latent_dim)) – Latent trajectories.
s (jax array of shape (..., obs_dim)) – Noise scales.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.
**kwargs (dict) – Overflow, for convenience.
- Returns:
log_pY – Log probability of Y.
- Return type:
jax array of shape (..., obs_dim)
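Assuming the observation model is Gaussian with mean C x + d and per-element variance s * sigmasq (consistent with sigmasq being described as unscaled noise and s as noise scales), the per-dimension log probability can be sketched as below. The function name is hypothetical.

```python
import numpy as np

def obs_log_prob_sketch(Y, x, s, Cd, sigmasq):
    """Elementwise Gaussian log density of Y under y = C @ x + d with variance s * sigmasq."""
    C, d = Cd[:, :-1], Cd[:, -1]
    var = s * sigmasq                        # sigmasq (obs_dim,) broadcasts over s
    resid = Y - (x @ C.T + d)                # (..., obs_dim)
    return -0.5 * (np.log(2 * np.pi * var) + resid**2 / var)
```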
- jax_moseq.models.slds.log_prob.log_joint_likelihood(Y, mask, x, s, z, pi, Ab, Q, Cd, sigmasq, s_0, nu_s, **kwargs)
Calculate the total log probability for each latent state.
- Parameters:
Y (jax array of shape (..., T, obs_dim)) – Observations.
mask (jax array of shape (..., T)) – Binary indicator for valid frames.
x (jax array of shape (..., T, latent_dim)) – Latent trajectories.
s (jax array of shape (..., T, obs_dim)) – Noise scales.
z (jax array of shape (..., T - n_lags)) – Discrete state sequences.
pi (jax array of shape (num_states, num_states)) – Transition probabilities.
Ab (jax array of shape (num_states, latent_dim, ar_dim + 1)) – Autoregressive transforms and bias, where ar_dim = latent_dim * n_lags.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.
sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.
s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping the name of each state variable to its total log probability.
- Return type:
dict
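One of the per-variable totals in such a joint likelihood is naturally the discrete-state term, the sum of log pi[z[t-1], z[t]] over consecutive states. The sketch below computes only that term; it assumes a mask already aligned with z (for example mask[..., n_lags:]) and counts a transition only when both of its frames are valid. The function name, the mask alignment, and the omission of any initial-state term are assumptions, not a description of jax_moseq's code.

```python
import numpy as np

def transition_log_prob_sketch(z, pi, mask=None):
    """Hypothetical sketch: total log probability of discrete-state transitions.

    z: (..., T') integer states; pi: (num_states, num_states) row-stochastic;
    mask: (..., T'), optional, assumed to be aligned with z.
    """
    log_trans = np.log(pi)[z[..., :-1], z[..., 1:]]     # (..., T' - 1)
    if mask is not None:
        # Keep a transition only if both its source and target frames are valid
        valid = (mask[..., :-1] * mask[..., 1:]) > 0
        log_trans = np.where(valid, log_trans, 0.0)
    return log_trans.sum()
```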
- jax_moseq.models.slds.log_prob.model_likelihood(data, states, params, hypparams, **kwargs)
Convenience function that invokes log_joint_likelihood.
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict) – Values for each group of hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping state variable name to its total log probability.
- Return type:
dict