Keypoint-SLDS

Initialization

Functions:

init_states(seed, Y, mask, params, ...[, ...])

Initialize the latent states of the keypoint SLDS from the data, parameters, and hyperparameters.

init_params(seed, pca, Y_flat, mask, ...)

Initialize the parameters of the keypoint SLDS from the data and hyperparameters.

init_hyperparams(trans_hypparams, ...)

Formats the hyperparameter dictionary of the keypoint SLDS.

init_model([data, states, params, ...])

Initialize a keypoint SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.

estimate_error(conf, slope, intercept)

Using the provided keypoint confidences and parameters learned from the noise calibration, returns a prior on the noise for each datapoint.

jax_moseq.models.keypoint_slds.initialize.init_states(seed, Y, mask, params, noise_prior, obs_hypparams, Y_flat=None, v=None, h=None, fix_heading=False, **kwargs)

Initialize the latent states of the keypoint SLDS from the data, parameters, and hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • params (dict) – Values for each model parameter.

  • noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.

  • obs_hypparams (dict) – Observation hyperparameters.

  • Y_flat (jax array of shape (N, T, (k - 1) * d), optional) – Aligned and embedded keypoint observations.

  • v (jax array of shape (N, T, d), optional) – Initial centroid positions.

  • h (jax array of shape (N, T), optional) – Initial heading angles.

  • fix_heading (bool, default=False) – Whether to keep the heading angle of the pose fixed. If True, the heading variable h is initialized as 0.

  • **kwargs (dict, optional) – Arguments to jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca(), as a substitute for Y_flat, v, or h.

Returns:

states – State values for each latent variable.

Return type:

dict

jax_moseq.models.keypoint_slds.initialize.init_params(seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs)

Initialize the parameters of the keypoint SLDS from the data and hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.

  • Y_flat (jax array of shape (N, T, (k - 1) * d)) – Aligned and embedded keypoint observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • whiten (bool) – Whether to whiten PCs to initialize continuous latents.

  • k (int) – Number of keypoints.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

params – Values for each model parameter.

Return type:

dict

jax_moseq.models.keypoint_slds.initialize.init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, **kwargs)

Formats the hyperparameter dictionary of the keypoint SLDS.

Parameters:
  • trans_hypparams (dict) – HDP transition hyperparameters.

  • ar_hypparams (dict) – Autoregression hyperparameters.

  • obs_hypparams (dict) – Observation hyperparameters.

  • cen_hypparams (dict) – Centroid movement hyperparameters.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

hypparams – Values for each group of hyperparameters.

Return type:

dict

jax_moseq.models.keypoint_slds.initialize.init_model(data=None, states=None, params=None, hypparams=None, noise_prior=None, seed=Array([0, 0], dtype=uint32), pca=None, whiten=True, PCA_fitting_num_frames=1000000, anterior_idxs=None, posterior_idxs=None, conf_threshold=0.5, error_estimator=None, trans_hypparams=None, ar_hypparams=None, obs_hypparams=None, cen_hypparams=None, verbose=False, exclude_outliers_for_pca=True, fix_heading=False, **kwargs)

Initialize a keypoint SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.

Parameters:
  • data (dict, optional) – Data dictionary containing the observations, mask, and (optionally) confidences. Must be provided if states or params are not precomputed.

  • states (dict, optional) – State values for each latent variable, if precomputed.

  • params (dict, optional) – Values for each model parameter, if precomputed.

  • hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, the caller must provide each argument of init_hyperparams().

  • noise_prior (array or scalar, optional) – Prior on the noise for each keypoint observation, if precomputed.

  • seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial random seed value.

  • pca (sklearn.decomposition.PCA, optional) – PCA object, if precomputed. If unspecified, will be computed from the data.

  • whiten (bool, default=True) – Whether to whiten PCs to initialize continuous latents.

  • PCA_fitting_num_frames (int, default=1000000) – Maximum number of datapoints to sample for PCA fitting, if pca is not provided.

  • anterior_idxs (iterable of ints, optional) – Anterior keypoint indices for heading initialization. Must be provided if states or params are not precomputed.

  • posterior_idxs (iterable of ints, optional) – Posterior keypoint indices for heading initialization. Must be provided if states or params are not precomputed.

  • conf_threshold (float, default=0.5) – Confidence threshold below which points are interpolated in PCA fitting and heading/position initialization. See jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca() for details.

  • error_estimator (dict, optional) – Parameters used to initialize noise_prior. Must be provided if data contains confidences. See jax_moseq.models.keypoint_slds.initialize.estimate_error() for details.

  • trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.

  • ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.

  • obs_hypparams (dict, optional) – Observation hyperparameters. Must be provided if hypparams is not provided.

  • cen_hypparams (dict, optional) – Centroid movement hyperparameters. Must be provided if hypparams is not provided.

  • verbose (bool, default=False) – Whether to print progress info during initialization.

  • exclude_outliers_for_pca (bool, default=True) – Whether to exclude frames with low-confidence keypoints. If False, then the low-confidence keypoint coordinates are linearly interpolated.

  • fix_heading (bool, default=False) – Whether to keep the heading angle of the pose fixed. If True, the heading variable h is initialized as 0.

  • **kwargs (dict, optional) – Unused. For convenience, enables user to invoke function by unpacking dict that contains keys not used by the method.

Returns:

model – Dictionary containing the hyperparameters, noise prior, and initial seed, states, and parameters of the model.

Return type:

dict

Raises:

ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.

jax_moseq.models.keypoint_slds.initialize.estimate_error(conf, slope, intercept)

Using the provided keypoint confidences and parameters learned from the noise calibration, returns a prior on the noise for each datapoint.

Parameters:
  • conf (jax array of shape (..., k)) – Confidence for each keypoint observation. Must be >= 0.

  • slope (float) – Slope learned by noise calibration.

  • intercept (float) – Intercept learned by noise calibration.

Returns:

noise_prior – Prior on the noise for each observation.

Return type:

jax array of shape (…, k)
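The functional form of the calibration is not documented here. As a purely hypothetical illustration of how a learned slope and intercept could turn confidences into a per-observation noise prior, the sketch below assumes a straight-line fit of log10(error) against log10(confidence) and returns the squared predicted error. The name estimate_error_sketch and the eps offset are illustrative assumptions, not the library's code.

```python
import numpy as np

def estimate_error_sketch(conf, slope, intercept, eps=1e-6):
    """Hypothetical confidence-to-noise-prior mapping (not the library's code).

    Assumes log10(keypoint error) = slope * log10(conf) + intercept and returns
    the squared predicted error, one value per keypoint observation (..., k).
    """
    conf = np.asarray(conf, dtype=float)
    # eps keeps log10 finite for zero-confidence keypoints (conf must be >= 0)
    log_err = slope * np.log10(conf + eps) + intercept
    return (10.0 ** log_err) ** 2

# With slope < 0, lower confidence yields a larger prior on the noise
prior = estimate_error_sketch(np.array([[1.0, 0.1]]), slope=-1.0, intercept=0.0)
```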

Gibbs sampling

Functions:

resample_continuous_stateseqs(seed, Y, mask, ...)

Resamples the latent trajectories x.

resample_obs_variance(seed, Y, mask, Cd, x, ...)

Resample the observation variance sigmasq.

resample_scales(seed, Y, x, v, h, Cd, ...)

Resample the scale values s.

compute_squared_error(Y, x, v, h, Cd[, mask])

Computes the squared error between model predicted and true observations.

resample_heading(seed, Y, x, v, s, Cd, ...)

Resample the heading angles h.

resample_location(seed, Y, mask, x, h, s, ...)

Resample the centroid positions v.

resample_model(data, seed, states, params, ...)

Resamples the Keypoint SLDS model given the hyperparameters, data, noise prior, current states, and current parameters.

jax_moseq.models.keypoint_slds.gibbs.resample_continuous_stateseqs(seed, Y, mask, v, h, s, z, Cd, sigmasq, Ab, Q, jitter=0.001, parallel_message_passing=True, **kwargs)

Resamples the latent trajectories x.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • v (jax array of shape (N, T, d)) – Centroid positions.

  • h (jax array of shape (N, T)) – Heading angles.

  • s (jax array of shape (N, T, k)) – Noise scales.

  • z (jax array of shape (N, T - n_lags)) – Discrete state sequences.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • jitter (float, default=1e-3) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.

  • parallel_message_passing (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

x – Latent trajectories.

Return type:

jax array of shape (N, T, latent_dim)

jax_moseq.models.keypoint_slds.gibbs.resample_obs_variance(seed, Y, mask, Cd, x, v, h, s, nu_sigma, sigmasq_0, **kwargs)

Resample the observation variance sigmasq.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • v (jax array of shape (N, T, d)) – Centroid positions.

  • h (jax array of shape (N, T)) – Heading angles.

  • s (jax array of shape (N, T, k)) – Noise scales.

  • nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.

  • sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

sigmasq – Unscaled noise.

Return type:

jax array of shape (k,)
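Because sigmasq has a scaled inverse chi-squared prior (nu_sigma, sigmasq_0), its Gibbs update can be written in closed form. The NumPy sketch below shows the standard conjugate update under the assumption that each keypoint's d coordinates are Gaussian with variance s * sigmasq around the model prediction. To stay self-contained it takes precomputed squared errors (summed over the spatial dimensions) instead of Y, x, v, h, and Cd; its name and the NumPy rng argument are hypothetical.

```python
import numpy as np

def resample_obs_variance_sketch(rng, sqerr, s, mask, nu_sigma, sigmasq_0, d):
    """Conjugate Gibbs draw of per-keypoint variances sigmasq (illustrative).

    Assumes Y[n, t, i] ~ N(Y_bar[n, t, i], s[n, t, i] * sigmasq[i] * I_d) and a
    scaled inverse chi-squared prior with parameters (nu_sigma, sigmasq_0).

    sqerr: (N, T, k) squared error summed over the d spatial dimensions
    s:     (N, T, k) noise scales
    mask:  (N, T) binary indicator for valid frames
    """
    # Each valid frame contributes d scalar residuals per keypoint
    n_obs = d * mask.sum()
    # Scale-normalized sum of squared errors over valid frames, per keypoint
    weighted_sse = (mask[..., None] * sqerr / s).sum(axis=(0, 1))  # (k,)
    nu_post = nu_sigma + n_obs
    scale_post = (nu_sigma * sigmasq_0 + weighted_sse) / nu_post
    # A draw from Scaled-Inv-chi2(nu, tau2) is nu * tau2 / chi2(nu)
    return nu_post * scale_post / rng.chisquare(nu_post, size=scale_post.shape)
```

Frames with mask == 0 contribute neither residuals nor degrees of freedom, so invalid frames cannot inflate the variance estimate.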

jax_moseq.models.keypoint_slds.gibbs.resample_scales(seed, Y, x, v, h, Cd, sigmasq, nu_s, s_0, **kwargs)

Resample the scale values s.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • v (jax array of shape (N, T, d)) – Centroid positions.

  • h (jax array of shape (N, T)) – Heading angles.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

s – Noise scales.

Return type:

jax array of shape (N, T, k)

jax_moseq.models.keypoint_slds.gibbs.compute_squared_error(Y, x, v, h, Cd, mask=None)

Computes the squared error between model predicted and true observations.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array) – Heading angles.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • mask (jax array, optional) – Binary indicator for valid frames.

Returns:

sqerr – Squared error between model predicted and true observations.

Return type:

jax array of shape (…, k)

jax_moseq.models.keypoint_slds.gibbs.resample_heading(seed, Y, x, v, s, Cd, sigmasq, **kwargs)

Resample the heading angles h.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • v (jax array of shape (N, T, d)) – Centroid positions.

  • s (jax array of shape (N, T, k)) – Noise scales.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

h – Heading angles.

Return type:

jax array of shape (N, T)

jax_moseq.models.keypoint_slds.gibbs.resample_location(seed, Y, mask, x, h, s, Cd, sigmasq, sigmasq_loc, parallel_message_passing=True, **kwargs)

Resample the centroid positions v.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • Y (jax array of shape (N, T, k, d)) – Keypoint observations.

  • mask (jax array of shape (N, T)) – Binary indicator for valid frames.

  • x (jax array of shape (N, T, latent_dim)) – Latent trajectories.

  • h (jax array of shape (N, T)) – Heading angles.

  • s (jax array of shape (N, T, k)) – Noise scales.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • sigmasq_loc (float) – Assumed variance in centroid displacements.

  • parallel_message_passing (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

v – Centroid positions.

Return type:

jax array of shape (N, T, d)

jax_moseq.models.keypoint_slds.gibbs.resample_model(data, seed, states, params, hypparams, noise_prior, ar_only=False, states_only=False, resample_global_noise_scale=False, resample_local_noise_scale=True, fix_heading=False, verbose=False, jitter=0.001, parallel_message_passing=False, **kwargs)

Resamples the Keypoint SLDS model given the hyperparameters, data, noise prior, current states, and current parameters.

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • seed (jr.PRNGKey) – JAX random seed.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict) – Values for each group of hyperparameters.

  • noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.

  • ar_only (bool, default=False) – Whether to restrict sampling to ARHMM components.

  • states_only (bool, default=False) – Whether to restrict sampling to states.

  • resample_global_noise_scale (bool, default=False) – Whether to resample the global noise scales (sigmasq).

  • resample_local_noise_scale (bool, default=True) – Whether to resample the local noise scales (s).

  • fix_heading (bool, default=False) – Whether to exclude h from resampling.

  • jitter (float, default=1e-3) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.

  • verbose (bool, default=False) – Whether to print progress info during resampling.

  • parallel_message_passing (bool, default=False) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

Returns:

model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.

Return type:

dict

Log probability

Functions:

location_log_prob(v, sigmasq_loc)

Calculate the log probability of the centroid location at each time-step, given the prior on centroid movement.

obs_log_prob(Y, x, v, h, s, Cd, sigmasq, ...)

Calculate the log probability of keypoint coordinates at each time-step, given continuous latent trajectories, centroids, heading angles, noise scales, and observation parameters.

log_joint_likelihood(Y, mask, x, v, h, s, z, ...)

Calculate the total log probability for each latent state.

model_likelihood(data, states, params, ...)

Convenience function that invokes log_joint_likelihood.

jax_moseq.models.keypoint_slds.log_prob.location_log_prob(v, sigmasq_loc)

Calculate the log probability of the centroid location at each time-step, given the prior on centroid movement.

Parameters:
  • v (jax array of shape (..., T, d)) – Centroid positions.

  • sigmasq_loc (float) – Assumed variance in centroid displacements.

Returns:

log_pv – Log probability of v.

Return type:

jax array of shape (…, T - 1)
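The (…, T - 1) output shape reflects one term per frame-to-frame displacement. A minimal NumPy sketch, assuming the prior on centroid movement is an isotropic Gaussian random walk with variance sigmasq_loc per spatial dimension (the function name is hypothetical):

```python
import numpy as np

def location_log_prob_sketch(v, sigmasq_loc):
    """Log density of each centroid step under an isotropic Gaussian random walk.

    Assumes v[..., t, :] - v[..., t - 1, :] ~ N(0, sigmasq_loc * I_d).
    v: (..., T, d) centroid positions; returns an array of shape (..., T - 1).
    """
    d = v.shape[-1]
    dv = np.diff(v, axis=-2)  # (..., T - 1, d) displacements
    return (-0.5 * (dv ** 2).sum(axis=-1) / sigmasq_loc
            - 0.5 * d * np.log(2 * np.pi * sigmasq_loc))
```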

jax_moseq.models.keypoint_slds.log_prob.obs_log_prob(Y, x, v, h, s, Cd, sigmasq, **kwargs)

Calculate the log probability of keypoint coordinates at each time-step, given continuous latent trajectories, centroids, heading angles, noise scales, and observation parameters.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array) – Heading angles.

  • s (jax array of shape (..., k)) – Noise scales.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

log_pY – Log probability of Y.

Return type:

jax array of shape (…, k)
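For a concrete picture of the observation model, the sketch below assumes each keypoint is isotropic Gaussian around its predicted position with variance s * sigmasq, and sums over the d spatial dimensions to give one value per keypoint, matching the (…, k) return shape. To stay self-contained it takes the predicted coordinates Y_bar directly instead of computing them from x, v, h, and Cd; the function name and that simplification are assumptions.

```python
import numpy as np

def obs_log_prob_sketch(Y, Y_bar, s, sigmasq):
    """Per-keypoint Gaussian log-likelihood (illustrative sketch).

    Y, Y_bar: (..., k, d) observed and predicted keypoints
    s:        (..., k) noise scales
    sigmasq:  (k,) unscaled noise
    Assumes variance s * sigmasq, shared by all d coordinates of a keypoint.
    """
    d = Y.shape[-1]
    var = s * sigmasq                          # (..., k)
    sqerr = ((Y - Y_bar) ** 2).sum(axis=-1)    # (..., k)
    return -0.5 * sqerr / var - 0.5 * d * np.log(2 * np.pi * var)
```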

jax_moseq.models.keypoint_slds.log_prob.log_joint_likelihood(Y, mask, x, v, h, s, z, pi, Ab, Q, Cd, sigmasq, sigmasq_loc, s_0, nu_s, **kwargs)

Calculate the total log probability for each latent state.

Parameters:
  • Y (jax array of shape (..., T, k, d)) – Keypoint observations.

  • mask (jax array of shape (..., T)) – Binary indicator for valid frames.

  • x (jax array of shape (..., T, latent_dim)) – Latent trajectories.

  • v (jax array of shape (..., T, d)) – Centroid positions.

  • h (jax array of shape (..., T)) – Heading angles.

  • s (jax array of shape (..., T, k)) – Noise scales.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • pi (jax array of shape (num_states, num_states)) – Transition probabilities.

  • Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.

  • Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • sigmasq_loc (float) – Assumed variance in centroid displacements.

  • s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.

  • nu_s (int) – Chi-squared degrees of freedom in noise prior.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping the name of each state variable to its total log probability.

Return type:

dict

jax_moseq.models.keypoint_slds.log_prob.model_likelihood(data, states, params, hypparams, noise_prior, **kwargs)

Convenience function that invokes log_joint_likelihood.

Parameters:
  • data (dict) – Data dictionary containing the observations and mask.

  • states (dict) – State values for each latent variable.

  • params (dict) – Values for each model parameter.

  • hypparams (dict) – Values for each group of hyperparameters.

  • noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

ll – Dictionary mapping state variable name to its total log probability.

Return type:

dict

Geometric transforms

Functions:

to_vanilla_slds(Y, v, h, s, Cd, sigmasq, ...)

Given the empirical keypoint positions, position/heading estimates, isotropic noise estimates, and emission parameters, this function returns the (relevant subset of the) observations, states, and params for an equivalent SLDS that directly maps the latent trajectories to flattened and aligned keypoint observations.

estimate_coordinates(x, v, h, Cd, **kwargs)

Estimate keypoint coordinates obtained from projecting the latent state x into keypoint-space (via Cd) and then rotating and translating by h and v, respectively.

estimate_aligned(x, Cd, k)

Compute estimated positions of aligned keypoints (i.e. prior to applying the rigid transform).

rigid_transform(Y, v, h)

Apply the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.

inverse_rigid_transform(Y, v, h)

Apply the inverse of the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.

center_embedding(n)

Generates a matrix Gamma that maps from an (n-1)-dimensional vector space to the space of n-tuples with zero mean.

apply_rotation(Y, h)

Rotate Y by h radians.

angle_to_rotation_matrix(h[, d])

Create rotation matrices from an array of angles.

vector_to_angle(V)

Convert 2D vectors to angles in [-pi, pi].

fit_pca(Y, mask[, anterior_idxs, ...])

Fit a PCA model to transformed keypoint coordinates.

preprocess_for_pca(Y, anterior_idxs, ...[, ...])

Prepare keypoint coordinates for PCA by performing egocentric alignment (optional), changing basis using center_embedding(k), and reshaping to a single flat vector per frame.

align_egocentric(Y, anterior_idxs, ...)

Perform egocentric alignment of keypoints by translating the centroid to the origin and rotating so that the vector pointing from the posterior bodyparts toward the anterior bodyparts is proportional to (1,0).

jax_moseq.models.keypoint_slds.alignment.to_vanilla_slds(Y, v, h, s, Cd, sigmasq, **kwargs)

Given the empirical keypoint positions, position/heading estimates, isotropic noise estimates, and emission parameters, this function returns the (relevant subset of the) observations, states, and params for an equivalent SLDS that directly maps the latent trajectories to flattened and aligned keypoint observations.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array of shape (...)) – Heading angles.

  • s (jax array of shape (..., k)) – Noise scales.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • sigmasq (jax array of shape (k,)) – Unscaled noise.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

  • Y (jax array of shape (…, k * d)) – Flattened and aligned keypoint observations.

  • s (jax array of shape (…, k * d)) – Noise scales repeated along spatial dimension.

  • Cd (jax array of shape (k * d, latent_dim + 1)) – Emission parameters, accounting for effect of embedding.

  • sigmasq (jax array of shape (k * d,)) – Unscaled noise repeated along spatial dimension.

jax_moseq.models.keypoint_slds.alignment.estimate_coordinates(x, v, h, Cd, **kwargs)

Estimate keypoint coordinates obtained from projecting the latent state x into keypoint-space (via Cd) and then rotating and translating by h and v, respectively.

Parameters:
  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array) – Heading angles.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

Y_bar – Estimated coordinates.

Return type:

jax array of shape (…, k, d)

jax_moseq.models.keypoint_slds.alignment.estimate_aligned(x, Cd, k)

Compute estimated positions of aligned keypoints (i.e. prior to applying the rigid transform).

Parameters:
  • x (jax array of shape (..., latent_dim)) – Latent trajectories.

  • Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.

  • k (int) – Number of keypoints.

Returns:

Y_aligned – Aligned keypoint positions estimated from latents.

Return type:

jax array of shape (…, k * d)
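A NumPy sketch of the latent-to-keypoint mapping, under two assumptions that the reference above does not spell out: the last column of Cd is an offset, so Cd @ [x, 1] gives the (k - 1) * d embedded pose, and the embedded pose is lifted back to k zero-mean keypoints with an orthonormal center embedding built by QR. The output is flattened keypoint-major to match the listed (…, k * d) shape; the function name is hypothetical.

```python
import numpy as np

def estimate_aligned_sketch(x, Cd, k):
    """Map latents x (..., latent_dim) to aligned keypoints, flattened to (..., k * d).

    Assumptions: Cd's last column is an offset, and the center embedding is an
    orthonormal basis of zero-mean k-vectors obtained by QR.
    """
    d = Cd.shape[0] // (k - 1)
    batch = x.shape[:-1]
    # Affine projection Cd @ [x, 1] -> embedded pose with (k - 1) * d entries
    x1 = np.concatenate([x, np.ones(batch + (1,))], axis=-1)
    y_emb = (x1 @ Cd.T).reshape(batch + (k - 1, d))
    # Center embedding Gamma (k, k - 1) lifts the pose back to k keypoints
    A = np.eye(k)[:, : k - 1] - 1.0 / k
    Gamma, _ = np.linalg.qr(A)
    Y_aligned = Gamma @ y_emb  # (..., k, d), zero mean over the k keypoints
    return Y_aligned.reshape(batch + (k * d,))
```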

jax_moseq.models.keypoint_slds.alignment.rigid_transform(Y, v, h)

Apply the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array) – Heading angles.

Returns:

Y_transformed – Rigidly transformed positions.

Return type:

jax array of shape (…, k, d)

jax_moseq.models.keypoint_slds.alignment.inverse_rigid_transform(Y, v, h)

Apply the inverse of the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • v (jax array of shape (..., d)) – Centroid positions.

  • h (jax array) – Heading angles.

Returns:

Y_transformed – Rigidly transformed positions.

Return type:

jax array of shape (…, k, d)
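The transform and its inverse can be sketched in NumPy as follows, assuming a counterclockwise rotation by h in the first two dimensions with any third dimension left unchanged (consistent with the description of angle_to_rotation_matrix()). Because rotation is applied before translation, the inverse subtracts v first and then rotates by -h; for keypoints stored as row vectors, multiplying on the right by R applies R.T, which is R's inverse. All names here are hypothetical.

```python
import numpy as np

def rotation_matrices(h, d):
    """Counterclockwise rotation by h in the first two dims; identity elsewhere."""
    h = np.asarray(h, dtype=float)
    R = np.broadcast_to(np.eye(d), h.shape + (d, d)).copy()
    c, s = np.cos(h), np.sin(h)
    R[..., 0, 0] = c
    R[..., 0, 1] = -s
    R[..., 1, 0] = s
    R[..., 1, 1] = c
    return R

def rigid_transform_sketch(Y, v, h):
    """Rotate keypoints Y (..., k, d) by h, then translate by v (..., d)."""
    R = rotation_matrices(h, Y.shape[-1])
    # Row-vector form: y @ R.T applies R to each keypoint
    return Y @ np.swapaxes(R, -1, -2) + v[..., None, :]

def inverse_rigid_transform_sketch(Y, v, h):
    """Undo rigid_transform_sketch: subtract v, then rotate by -h."""
    R = rotation_matrices(h, Y.shape[-1])
    # y @ R applies R.T (the inverse rotation) to each keypoint
    return (Y - v[..., None, :]) @ R

# Rotating (1, 0) by pi/2 gives approximately (0, 1)
Y_rot = rigid_transform_sketch(np.array([[1.0, 0.0]]), np.zeros(2), np.pi / 2)
```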

jax_moseq.models.keypoint_slds.alignment.center_embedding(n)

Generates a matrix Gamma that maps from an (n-1)-dimensional vector space to the space of n-tuples with zero mean.

Parameters:

n (int) – Number of keypoints.

Returns:

Gamma – Matrix to map to centered embedded space.

Return type:

jax array of shape (n, n - 1)
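One way to build such a Gamma is to orthonormalize n - 1 independent zero-mean vectors. The sketch below is a hypothetical construction (the library's basis may differ by an orthogonal change of basis). With orthonormal, zero-sum columns, Gamma.T maps a configuration to n - 1 coordinates, and Gamma @ Gamma.T projects any configuration onto its centered version.

```python
import numpy as np

def center_embedding_sketch(n):
    """Return Gamma (n, n - 1) whose columns form an orthonormal basis of zero-mean n-vectors."""
    # Columns e_i - (1/n) * ones for i < n - 1 are linearly independent and zero-mean
    A = np.eye(n)[:, : n - 1] - 1.0 / n
    Gamma, _ = np.linalg.qr(A)  # reduced QR: Gamma has shape (n, n - 1)
    return Gamma

Gamma = center_embedding_sketch(5)
Y = np.arange(10.0).reshape(5, 2)
# Projection onto zero-mean configurations: equals Y - Y.mean(axis=0)
Y_centered = Gamma @ (Gamma.T @ Y)
```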

jax_moseq.models.keypoint_slds.alignment.apply_rotation(Y, h)

Rotate Y by h radians.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • h (jax array) – Heading angles.

Returns:

Y_rot – Rotated keypoint observations.

Return type:

jax array of shape (…, k, d)

jax_moseq.models.keypoint_slds.alignment.angle_to_rotation_matrix(h, d=3)

Create rotation matrices from an array of angles. If d > 2 then rotation is performed in the first two dims.

Parameters:
  • h (jax array of shape (N, T)) – Heading angles.

  • d (int, default=3) – Keypoint dimensionality (either 2 or 3).

Returns:

m – Rotation matrices.

Return type:

jax array of shape (…, d, d)

jax_moseq.models.keypoint_slds.alignment.vector_to_angle(V)

Convert 2D vectors to angles in [-pi, pi]. The vector (1,0) corresponds to angle of 0. If V is multidimensional, the first n-1 dimensions are treated as batch dims.

Parameters:

V (jax array of shape (..., 2)) – Batch of 2D vectors.

Returns:

h – Rotation angles in radians.

Return type:

jax array
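Given the convention above ((1, 0) maps to 0, results lie in [-pi, pi], and leading dimensions are treated as a batch), a minimal NumPy sketch reduces to a single arctan2 call; the function name is hypothetical.

```python
import numpy as np

def vector_to_angle_sketch(V):
    """Angle of 2D vectors V (..., 2) in [-pi, pi]; the vector (1, 0) maps to 0."""
    # arctan2 takes (y, x) and handles all four quadrants
    return np.arctan2(V[..., 1], V[..., 0])

V = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
angles = vector_to_angle_sketch(V)  # [0, pi/2, pi]
```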

jax_moseq.models.keypoint_slds.alignment.fit_pca(Y, mask, anterior_idxs=None, posterior_idxs=None, conf=None, conf_threshold=0.5, verbose=False, PCA_fitting_num_frames=1000000, exclude_outliers_for_pca=False, fix_heading=False, **kwargs)

Fit a PCA model to transformed keypoint coordinates. If conf is not None, perform linear interpolation over outliers defined by conf < conf_threshold.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • mask (jax array) – Binary indicator for valid frames.

  • anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.

  • posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.

  • conf (jax array of shape (..., k), optional) – Confidence for each keypoint observation. Must be >= 0.

  • conf_threshold (float, default=0.5) – Confidence threshold for interpolation.

  • verbose (bool, default=False) – Whether to print progress updates.

  • PCA_fitting_num_frames (int, default=1000000) – Maximum number of frames for PCA fitting.

  • exclude_outliers_for_pca (bool, default=False) – Whether to exclude frames with low-confidence keypoints. If False, then the low-confidence keypoint coordinates are linearly interpolated.

  • fix_heading (bool, default=False) – Whether to keep the heading angle fixed. If True, the heading h is set to 0 and keypoints are not rotated.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

pca – PCA object fit to observations.

Return type:

sklearn.decomposition._pca.PCA

jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca(Y, anterior_idxs, posterior_idxs, conf=None, conf_threshold=0.5, fix_heading=False, verbose=False, **kwargs)

Prepare keypoint coordinates for PCA by performing egocentric alignment (optional), changing basis using center_embedding(k), and reshaping to a single flat vector per frame.

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.

  • posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.

  • conf (jax array of shape (..., k), optional) – Confidence for each keypoint observation. Must be >= 0.

  • conf_threshold (float, default=0.5) – Confidence threshold for interpolation.

  • fix_heading (bool, default=False) – Whether to keep the heading angle fixed. If True, the heading h is set to 0 and keypoints are not rotated.

  • verbose (bool, default=False) – Whether to print progress updates.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

Y_flat – Aligned and embedded keypoint observations.

Return type:

jax array of shape (…, (k - 1) * d)
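The embedding and flattening steps can be sketched as follows for keypoints that have already been aligned. Multiplying by Gamma.T removes the centroid and leaves k - 1 coordinates per spatial dimension, which gives the documented (…, (k - 1) * d) shape. The QR-based Gamma and the function name are assumptions for illustration.

```python
import numpy as np

def embed_and_flatten_sketch(Y_aligned):
    """Map aligned keypoints (..., k, d) to flat vectors (..., (k - 1) * d)."""
    k, d = Y_aligned.shape[-2:]
    # Orthonormal basis for zero-mean k-vectors (one possible center embedding)
    A = np.eye(k)[:, : k - 1] - 1.0 / k
    Gamma, _ = np.linalg.qr(A)  # (k, k - 1)
    # Gamma.T discards the centroid and changes basis: (..., k - 1, d)
    Y_emb = Gamma.T @ Y_aligned
    return Y_emb.reshape(Y_aligned.shape[:-2] + ((k - 1) * d,))
```

Since the columns of Gamma sum to zero, the result is unchanged if every keypoint is shifted by the same offset, and because they are orthonormal the flat vector has the same norm as the centered configuration.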

jax_moseq.models.keypoint_slds.alignment.align_egocentric(Y, anterior_idxs, posterior_idxs, **kwargs)

Perform egocentric alignment of keypoints by translating the centroid to the origin and rotating so that the vector pointing from the posterior bodyparts toward the anterior bodyparts is proportional to (1,0).

Parameters:
  • Y (jax array of shape (..., k, d)) – Keypoint observations.

  • anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.

  • posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.

  • **kwargs (dict) – Overflow, for convenience.

Returns:

  • Y_aligned (jax array of shape (…, k, d)) – Aligned keypoint coordinates.

  • v (jax array of shape (…, d)) – Centroid positions that were used for alignment.

  • h (jax array) – Heading angles that were used for alignment.
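A NumPy sketch of the alignment, assuming the centroid is the plain mean of all keypoints and the heading is the angle of mean(anterior) - mean(posterior) in the first two dimensions; the library may compute either quantity differently (for example, when handling low-confidence points). After alignment the posterior-to-anterior vector lies along +x, i.e. it is proportional to (1,0). The function name is hypothetical.

```python
import numpy as np

def align_egocentric_sketch(Y, anterior_idxs, posterior_idxs):
    """Egocentric alignment of keypoints Y (..., k, d) with d = 2 or 3.

    Assumptions: the centroid v is the mean over all keypoints, and the heading
    h is the angle of mean(anterior) - mean(posterior) in the first two dims.
    Returns (Y_aligned, v, h).
    """
    ant, post = np.asarray(anterior_idxs), np.asarray(posterior_idxs)
    v = Y.mean(axis=-2)  # (..., d)
    axis_vec = Y[..., ant, :2].mean(axis=-2) - Y[..., post, :2].mean(axis=-2)
    h = np.arctan2(axis_vec[..., 1], axis_vec[..., 0])  # (...)
    # Translate the centroid to the origin, then rotate by -h in the x-y plane
    Yc = Y - v[..., None, :]
    c, s = np.cos(h)[..., None], np.sin(h)[..., None]  # (..., 1)
    Y_aligned = Yc.copy()
    Y_aligned[..., 0] = c * Yc[..., 0] + s * Yc[..., 1]
    Y_aligned[..., 1] = -s * Yc[..., 0] + c * Yc[..., 1]
    return Y_aligned, v, h
```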