SLDS

Initialization

Functions:

init_obs_params(pca, Y, mask, whiten, ...)

Initialize the observation (also known as "emission") parameters of the model using PCA.

init_continuous_stateseqs(Y, Cd, **kwargs)

Initialize the continuous latents by applying the inverse of the emission transform to the data.

init_states(seed, Y, mask, params, ...)

Initialize the latent states of the SLDS from the data, parameters, and hyperparameters.

init_params(seed, pca, Y, mask, ...)

Initialize the parameters of the SLDS from the data and hyperparameters.

init_hyperparams(trans_hypparams, ...)

Format the hyperparameter dictionary of the SLDS.

init_model([data, states, params, ...])

Initialize an SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.

jax_moseq.models.slds.initialize.init_obs_params(pca, Y, mask, whiten, latent_dim, **kwargs)

Initialize the observation (also known as “emission”) parameters of the model using PCA.

Parameters:
  • pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.

  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • whiten (bool) – Whether to whiten PCs to initialize continuous latents.

  • latent_dim (int) – Dimensionality of continuous latents.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

Cd – Observation transform.

Return type:

jax array of shape (obs_dim, latent_dim + 1)
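One way to picture PCA-based initialization is to build Cd directly from a decomposition of the valid frames. The sketch below is an illustration under stated assumptions, not the jax_moseq implementation: it replaces the sklearn PCA object with an SVD in jax.numpy, assumes Cd = [C | d] with d equal to the mean of the valid frames, and assumes that whitening means scaling each principal axis by the square root of its explained variance. `pca_emission_sketch` is a hypothetical helper.

```python
import jax.numpy as jnp


def pca_emission_sketch(Y, mask, latent_dim, whiten=True):
    """Hypothetical sketch: build an affine emission matrix Cd = [C | d] via PCA.

    Assumptions (not taken from jax_moseq): d is the mean of the valid frames,
    the columns of C are the top `latent_dim` principal axes, and whitening
    scales each axis by the square root of its explained variance.
    """
    obs_dim = Y.shape[-1]
    # Keep only valid frames (mask > 0); boolean indexing works outside jit.
    frames = Y.reshape(-1, obs_dim)[mask.reshape(-1) > 0]
    d = frames.mean(axis=0)
    centered = frames - d
    # Rows of Vt are the principal axes of the centered data.
    _, svals, Vt = jnp.linalg.svd(centered, full_matrices=False)
    C = Vt[:latent_dim].T  # shape (obs_dim, latent_dim)
    if whiten:
        # Sample variance explained by each retained component.
        var = svals[:latent_dim] ** 2 / (frames.shape[0] - 1)
        C = C * jnp.sqrt(var)
    # Append the offset as the last column: shape (obs_dim, latent_dim + 1).
    return jnp.concatenate([C, d[:, None]], axis=1)
```

With whiten=True, latents obtained by inverting this Cd have unit sample variance per dimension; with whiten=False, the columns of C are orthonormal.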

jax_moseq.models.slds.initialize.init_continuous_stateseqs(Y, Cd, **kwargs)

Initialize the continuous latents by applying the inverse of the emission transform to the data.

Parameters:
  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

x – Latent trajectories.

Return type:

jax array of shape (N, T, latent_dim)
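Assuming the emission is the affine map y = C x + d with Cd = [C | d], "applying the inverse" can be read as a least-squares pseudo-inverse, since C is usually tall (obs_dim > latent_dim) and has no true inverse. A minimal sketch under that assumption; `invert_emission_sketch` is a hypothetical name, not a jax_moseq function.

```python
import jax.numpy as jnp


def invert_emission_sketch(Y, Cd):
    """Hypothetical sketch: recover latents from y = C x + d, where Cd = [C | d].

    Uses the Moore-Penrose pseudo-inverse, i.e. the least-squares solution
    x = pinv(C) (y - d), applied to every frame along the leading axes.
    """
    C, d = Cd[:, :-1], Cd[:, -1]
    return (Y - d) @ jnp.linalg.pinv(C).T
```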

jax_moseq.models.slds.initialize.init_states(seed, Y, mask, params, obs_hypparams, **kwargs)

Initialize the latent states of the SLDS from the data, parameters, and hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • params (dict) – Values for each model parameter.

  • obs_hypparams (dict) – Observation hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

states – State values for each latent variable.

Return type:

dict

jax_moseq.models.slds.initialize.init_params(seed, pca, Y, mask, trans_hypparams, ar_hypparams, whiten, **kwargs)

Initialize the parameters of the SLDS from the data and hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.

  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • whiten (bool) – Whether to whiten PCs to initialize continuous latents.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

params – Values for each model parameter.

Return type:

dict

jax_moseq.models.slds.initialize.init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams, **kwargs)

Format the hyperparameter dictionary of the SLDS.

Parameters:
  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • obs_hypparams (dict) – Observation hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

hypparams – Values for each group of hyperparameters.

Return type:

dict

jax_moseq.models.slds.initialize.init_model(data=None, states=None, params=None, hypparams=None, seed=Array([0, 0], dtype=uint32), pca=None, whiten=True, PCA_fitting_num_frames=1000000, trans_hypparams=None, ar_hypparams=None, obs_hypparams=None, verbose=False, **kwargs)

Initialize an SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.

Parameters:
  • data (dict, optional) – Data dictionary containing the observations, mask, and (optionally) confidences. Must be provided if states or params are not precomputed.

  • states (dict, optional) – State values for each latent variable, if precomputed.

  • params (dict, optional) – Values for each model parameter, if precomputed.

  • hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, the caller must provide each argument of init_hyperparams.

  • seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial random seed value.

  • pca (sklearn.decomposition._pca.PCA, optional) – PCA object fit to observations.

  • whiten (bool, default=True) – Whether to whiten PCs to initialize continuous latents.

  • PCA_fitting_num_frames (int, default=1000000) – Maximum number of datapoints to sample for PCA fitting, if pca is not provided.

  • trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.

  • ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.

  • obs_hypparams (dict, optional) – Observation hyperparameters. Must be provided if hypparams is not provided.

  • verbose (bool, default=False) – Whether to print progress info during initialization.

  • **kwargs (dict, optional) – Unused. For convenience, lets the caller invoke the function by unpacking a dict that contains keys the function does not use.

Returns:

model – Dictionary containing the hyperparameters, noise prior, and initial seed, states, and parameters of the model.

Return type:

dict

Raises:

ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.
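The sufficiency rules listed in the parameter descriptions (data is needed unless both states and params are precomputed, and the three hyperparameter groups are needed unless hypparams is given) can be restated as a small check. This is a hypothetical sketch of when a ValueError would be expected, not the library's actual validation code; `check_init_inputs_sketch` is an illustrative name.

```python
def check_init_inputs_sketch(data=None, states=None, params=None, hypparams=None,
                             trans_hypparams=None, ar_hypparams=None,
                             obs_hypparams=None):
    """Hypothetical restatement of the documented input requirements."""
    # States and params can only be computed from data.
    if data is None and (states is None or params is None):
        raise ValueError(
            "`data` is required unless both `states` and `params` are precomputed."
        )
    # Without a full hypparams dict, every hyperparameter group is needed.
    groups = (trans_hypparams, ar_hypparams, obs_hypparams)
    if hypparams is None and any(g is None for g in groups):
        raise ValueError(
            "Provide `hypparams` or all of `trans_hypparams`, `ar_hypparams`, "
            "and `obs_hypparams`."
        )
```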

Gibbs sampling

Functions:

resample_continuous_stateseqs(seed, y, mask, ...)

Resample the latent trajectories x.

resample_obs_variance(seed, Y, mask, x, s, ...)

Resample the observation variance sigmasq.

resample_obs_variance_from_sqerr(seed, ...)

Resample the observation variance sigmasq using the squared error between predicted and true observations.

resample_scales(seed, Y, x, Cd, sigmasq, ...)

Resample the scale values s.

resample_scales_from_sqerr(seed, sqerr, ...)

Resample the scale values s using the squared error between predicted and true observations.

compute_squared_error(Y, x, Cd[, mask])

Compute the squared error between model-predicted and true observations.

resample_model(data, seed, states, params, ...)

Resample the SLDS model given the hyperparameters, data, current states, and current parameters.

jax_moseq.models.slds.gibbs.resample_continuous_stateseqs(seed, y, mask, z, s, Ab, Q, Cd, sigmasq, jitter=0.001, parallel_message_passing=True, **kwargs)

Resample the latent trajectories x.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • y (jax.Array of shape (n_recordings, n_timesteps, obs_dim)) – Observations.

  • mask (jax.Array of shape (n_recordings, n_timesteps)) – Binary indicator, 1=valid frames, 0=invalid frames.

  • z (jax.Array of shape (n_recordings, n_timesteps-n_lags)) – Discrete state sequences, taking integer values in [0, n_states), for timesteps [n_lags, n_timesteps).

  • s (jax.Array of shape (n_recordings, n_timesteps, obs_dim)) – Observation noise scales.

  • Ab (jax.Array of shape (n_states, latent_dim, ar_dim + 1)) – Autoregressive dynamics and bias, where ar_dim = latent_dim * n_lags.

  • Q (jax.Array of shape (n_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • Cd (jax.Array of shape (obs_dim, latent_dim + 1)) – Affine transform from latent_dim to obs_dim.

  • sigmasq (jax.Array of shape (obs_dim,)) – Unscaled noise.

  • jitter (float, default=1e-3) – Amount to boost the diagonal of the covariance matrix during backward sampling of the continuous states.

  • parallel_message_passing (bool, default=True) – Whether to use an associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

x – Posterior sample of latent trajectories.

Return type:

jax.Array of shape (n_recordings, n_timesteps, latent_dim)
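The shapes of Ab and z are easier to read with the autoregressive mean written out. A minimal sketch, assuming the lagged input for timestep t is the concatenation [x[t-n_lags], ..., x[t-1]] (oldest first, so ar_dim = latent_dim * n_lags) and that the last column of Ab is the bias. The lag ordering and the helper name `ar_predict_sketch` are assumptions, not taken from jax_moseq.

```python
import jax.numpy as jnp


def ar_predict_sketch(x, z, Ab, n_lags):
    """Hypothetical sketch: AR mean of x[t] for t in [n_lags, T).

    x:  (..., T, latent_dim) latent trajectories
    z:  (..., T - n_lags) discrete states
    Ab: (n_states, latent_dim, latent_dim * n_lags + 1) dynamics and bias
    """
    T = x.shape[-2]
    # Lagged windows, oldest lag first: shape (..., T - n_lags, latent_dim * n_lags).
    lagged = jnp.concatenate(
        [x[..., lag : T - n_lags + lag, :] for lag in range(n_lags)], axis=-1
    )
    Ab_t = Ab[z]  # per-timestep parameters: (..., T - n_lags, latent_dim, ar_dim + 1)
    A, b = Ab_t[..., :-1], Ab_t[..., -1]
    return jnp.einsum("...ij,...j->...i", A, lagged) + b
```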

jax_moseq.models.slds.gibbs.resample_obs_variance(seed, Y, mask, x, s, Cd, nu_sigma, sigmasq_0, **kwargs)

Resample the observation variance sigmasq.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • s (jax array of shape (N, T, obs_dim)) – Noise scales.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.

  • sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

sigmasq – Unscaled noise.

Return type:

jax array of shape (obs_dim,)

jax_moseq.models.slds.gibbs.resample_obs_variance_from_sqerr(seed, sqerr, mask, s, nu_sigma, sigmasq_0, **kwargs)

Resample the observation variance sigmasq using the squared error between predicted and true observations.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • sqerr (jax array of shape (N, T, obs_dim)) – Squared error between predicted and true observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • s (jax array of shape (N, T, obs_dim)) – Noise scales.

  • nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.

  • sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

sigmasq – Unscaled noise.

Return type:

jax array of shape (obs_dim,)
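Resampling sigmasq is a conjugate update if the prior is the scaled inverse chi-squared distribution parameterized by nu_sigma and sigmasq_0. The sketch below assumes each observation entry is Gaussian with variance s * sigmasq (one sigmasq per obs_dim), so the posterior for each dimension is an inverse gamma that can be sampled as b / Gamma(a). It is a hypothetical illustration (`resample_sigmasq_sketch` is not a library function), not the actual implementation.

```python
import jax.numpy as jnp
import jax.random as jr


def resample_sigmasq_sketch(seed, sqerr, mask, s, nu_sigma, sigmasq_0):
    """Hypothetical conjugate update for sigmasq, assuming
        Y[n, t, i] ~ Normal(prediction, s[n, t, i] * sigmasq[i])
        sigmasq[i] ~ ScaledInvChiSq(nu_sigma, sigmasq_0).
    Shapes: sqerr and s are (N, T, obs_dim), mask is (N, T).
    """
    m = mask[..., None]  # broadcast the frame mask over obs_dim
    n_valid = jnp.broadcast_to(m, sqerr.shape).sum(axis=(0, 1))  # valid frames per dim
    weighted = (m * sqerr / s).sum(axis=(0, 1))  # scale-normalized squared error
    # Posterior is InvGamma(a, b) with a = (nu_sigma + n) / 2 and
    # b = (nu_sigma * sigmasq_0 + weighted) / 2; if g ~ Gamma(a, 1), b / g ~ InvGamma(a, b).
    a = (nu_sigma + n_valid) / 2
    b = (nu_sigma * sigmasq_0 + weighted) / 2
    return b / jr.gamma(seed, a)
```

Dividing sqerr by s is what makes sigmasq the "unscaled" noise: the per-entry scales are factored out before the shared per-dimension variance is estimated.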

jax_moseq.models.slds.gibbs.resample_scales(seed, Y, x, Cd, sigmasq, nu_s, s_0, **kwargs)

Resample the scale values s.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, obs_dim)) – Observations.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

s – Noise scales.

Return type:

jax array of shape (N, T, obs_dim)

jax_moseq.models.slds.gibbs.resample_scales_from_sqerr(seed, sqerr, sigmasq, nu_s, s_0, **kwargs)

Resample the scale values s using the squared error between predicted and true observations.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • sqerr (jax array of shape (N, T, obs_dim)) – Squared error between predicted and true observations.

  • sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • s_0 (scalar or jax array broadcastable to sqerr) – Prior on noise scale.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

s – Per-observation noise scales.

Return type:

jax array of shape (N, T, obs_dim)
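Under a similar assumption, each entry of s has its own scaled inverse chi-squared prior ScaledInvChiSq(nu_s, s_0) and sees a single Gaussian residual with variance s * sigmasq, so its posterior has nu_s + 1 degrees of freedom. A minimal sketch, assuming s_0 is broadcastable to sqerr; `resample_scales_sketch` is a hypothetical name, not the library's implementation.

```python
import jax.numpy as jnp
import jax.random as jr


def resample_scales_sketch(seed, sqerr, sigmasq, nu_s, s_0):
    """Hypothetical per-entry conjugate update for the noise scales s.

    Posterior for each entry: InvGamma(a, b) with
        a = (nu_s + 1) / 2,  b = (nu_s * s_0 + sqerr / sigmasq) / 2.
    """
    a = (nu_s + 1) / 2
    b = (nu_s * s_0 + sqerr / sigmasq) / 2
    # If g ~ Gamma(a, 1), then b / g ~ InvGamma(a, b).
    return b / jr.gamma(seed, a, shape=sqerr.shape)
```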

jax_moseq.models.slds.gibbs.compute_squared_error(Y, x, Cd, mask=None)

Compute the squared error between model-predicted and true observations.

Parameters:
  • Y (jax array of shape (..., obs_dim)) – Observations.

  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • mask (jax array of shape (...), optional) – Binary indicator for valid frames.

Returns:

sqerr – Squared error between model predicted and true observations.

Return type:

jax array of shape (..., obs_dim)
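A minimal sketch of the squared error, assuming the prediction is x @ C.T + d with Cd = [C | d] and that masked-out frames contribute zero error. `squared_error_sketch` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def squared_error_sketch(Y, x, Cd, mask=None):
    """Hypothetical sketch: elementwise (Y - (C x + d))**2, zeroed where mask == 0."""
    C, d = Cd[:, :-1], Cd[:, -1]
    Y_pred = x @ C.T + d  # shape (..., obs_dim)
    sqerr = (Y - Y_pred) ** 2
    if mask is not None:
        # Broadcast the per-frame mask over the observation dimension.
        sqerr = sqerr * mask[..., None]
    return sqerr
```

For example, with obs_dim = 2 and latent_dim = 1, Cd = [[1, 0], [2, 1]] maps x = 3 to the prediction [3, 7], so an observation [0, 7] has squared error [9, 0].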

jax_moseq.models.slds.gibbs.resample_model(data, seed, states, params, hypparams, ar_only=False, states_only=False, skip_noise=True, parallel_message_passing=False, **kwargs)

Resample the SLDS model given the hyperparameters, data, current states, and current parameters.

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • seed (jr.PRNGKey) – JAX random seed.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict) – Values for each group of hyperparameters.

  • ar_only (bool, default=False) – Whether to restrict sampling to ARHMM components.

  • states_only (bool, default=False) – Whether to restrict sampling to states.

  • skip_noise (bool, default=True) – Whether to exclude sigmasq and s from resampling.

  • parallel_message_passing (bool, default=False) – Whether to use an associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.

Return type:

dict

Log probability

Functions:

scale_log_prob(s, s_0, nu_s, **kwargs)

Calculate the log probability of the noise scale s for each datapoint given the noise prior, which is a scaled inverse chi-squared distribution.

obs_log_prob(Y, x, s, Cd, sigmasq, **kwargs)

Calculate the log probability of the observations at each time-step given the latent trajectories, noise parameters, and observation matrix.

log_joint_likelihood(Y, mask, x, s, z, pi, ...)

Calculate the total log probability for each latent state.

model_likelihood(data, states, params, ...)

Convenience function that invokes log_joint_likelihood.

jax_moseq.models.slds.log_prob.scale_log_prob(s, s_0, nu_s, **kwargs)

Calculate the log probability of the noise scale s for each datapoint given the noise prior, which is a scaled inverse chi-squared distribution.

Parameters:
  • s (jax array) – Noise scales.

  • s_0 (scalar or jax array, broadcastable to s) – Prior on noise scale.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

log_ps – Log probability of s.

Return type:

jax array
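For reference, the scaled inverse chi-squared density with nu_s degrees of freedom and scale s_0 is equivalent to InvGamma(nu_s / 2, nu_s * s_0 / 2). The JAX sketch below evaluates its log density elementwise, keeping all normalizing constants; `scaled_inv_chisq_logpdf` is a hypothetical name, and the library function may be organized differently.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def scaled_inv_chisq_logpdf(s, s_0, nu_s):
    """Log density of ScaledInvChiSq(nu_s, s_0) at s, broadcasting elementwise."""
    half_nu = nu_s / 2
    return (
        half_nu * jnp.log(half_nu)
        - gammaln(half_nu)
        + half_nu * jnp.log(s_0)
        - (half_nu + 1) * jnp.log(s)
        - nu_s * s_0 / (2 * s)
    )
```

As a quick check, with s = 1, s_0 = 1, and nu_s = 2 every term vanishes except the last, giving a log density of -1.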

jax_moseq.models.slds.log_prob.obs_log_prob(Y, x, s, Cd, sigmasq, **kwargs)

Calculate the log probability of the observations at each time-step given the latent trajectories, noise parameters, and observation matrix.

Parameters:
  • Y (jax array of shape (..., obs_dim)) – Observations.

  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • s (jax array of shape (..., obs_dim)) – Noise scales.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

log_pY – Log probability of Y.

Return type:

jax array of shape (..., obs_dim)
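Assuming the observation model is Y[..., i] ~ Normal((C x + d)[i], s[..., i] * sigmasq[i]) with Cd = [C | d], the per-dimension log probability is a Gaussian log density. A minimal sketch under that assumption; `obs_log_prob_sketch` is hypothetical.

```python
import jax.numpy as jnp


def obs_log_prob_sketch(Y, x, s, Cd, sigmasq):
    """Hypothetical sketch: elementwise Gaussian log density of Y, shape (..., obs_dim)."""
    C, d = Cd[:, :-1], Cd[:, -1]
    var = s * sigmasq  # per-entry variance: scale times unscaled noise
    resid = Y - (x @ C.T + d)
    return -0.5 * (jnp.log(2 * jnp.pi * var) + resid ** 2 / var)
```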

jax_moseq.models.slds.log_prob.log_joint_likelihood(Y, mask, x, s, z, pi, Ab, Q, Cd, sigmasq, s_0, nu_s, **kwargs)

Calculate the total log probability for each latent state.

Parameters:
  • Y (jax array of shape (..., T, obs_dim)) – Observations.

  • mask (jax array of shape (..., T)) – Binary indicator for valid frames.

  • x (jax array of shape (..., T, latent_dim)) – Latent trajectories.

  • s (jax array of shape (..., T, obs_dim)) – Noise scales.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim + 1)) – Autoregressive transforms and bias, where ar_dim = latent_dim * n_lags.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • Cd (jax array of shape (obs_dim, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (obs_dim,)) – Unscaled noise.

  • s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping the name of each state variable to its total log probability.

Return type:

dict
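One of the terms such a dictionary could hold is the log probability of the discrete state sequence under the transition matrix pi. A minimal sketch of that single term, assuming it sums log pi[z[t-1], z[t]] over consecutive timesteps and ignores the initial state and the mask; the helper name `transition_log_prob_sketch` and these simplifications are assumptions, not jax_moseq behavior.

```python
import jax.numpy as jnp


def transition_log_prob_sketch(z, pi):
    """Hypothetical sketch: total log pi[z[t-1], z[t]] over all recordings and steps."""
    # Pair each state with its successor along the last (time) axis.
    log_p = jnp.log(pi[z[..., :-1], z[..., 1:]])
    return log_p.sum()
```

For pi = [[0.9, 0.1], [0.2, 0.8]] and z = [0, 0, 1, 1], the transitions are 0→0, 0→1, and 1→1, so the total is log 0.9 + log 0.1 + log 0.8.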

jax_moseq.models.slds.log_prob.model_likelihood(data, states, params, hypparams, **kwargs)

Convenience function that invokes log_joint_likelihood.

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict) – Values for each group of hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping state variable name to its total log probability.

Return type:

dict