Keypoint-SLDS
Initialization
Functions:
init_states – Initialize the latent states of the keypoint SLDS from the data, parameters, and hyperparameters.
init_params – Initialize the parameters of the keypoint SLDS from the data and hyperparameters.
init_hyperparams – Formats the hyperparameter dictionary of the keypoint SLDS.
init_model – Initialize a keypoint SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.
estimate_error – Using the provided keypoint confidences and parameters learned from the noise calibration, returns prior on the noise for each datapoint.
- jax_moseq.models.keypoint_slds.initialize.init_states(seed, Y, mask, params, noise_prior, obs_hypparams, Y_flat=None, v=None, h=None, fix_heading=False, **kwargs)[source]
Initialize the latent states of the keypoint SLDS from the data, parameters, and hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
params (dict) – Values for each model parameter.
noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.
obs_hypparams (dict) – Observation hyperparameters.
Y_flat (jax array of shape (N, T, (k - 1) * d), optional) – Aligned and embedded keypoint observations.
v (jax array of shape (N, T, d), optional) – Initial centroid positions.
h (jax array of shape (N, T), optional) – Initial heading angles.
fix_heading (bool, default=False) – Whether to keep the heading angle of the pose fixed. If true, the heading variable h is initialized as 0.
**kwargs (dict, optional) – Arguments to jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca(), as a substitute for Y_flat, v, or h.
- Returns:
states – State values for each latent variable.
- Return type:
dict
- jax_moseq.models.keypoint_slds.initialize.init_params(seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs)[source]
Initialize the parameters of the keypoint SLDS from the data and hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
pca (sklearn.decomposition._pca.PCA) – PCA object fit to observations.
Y_flat (jax array of shape (N, T, (k - 1) * d)) – Aligned and embedded keypoint observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
whiten (bool) – Whether to whiten PCs to initialize continuous latents.
k (int) – Number of keypoints.
**kwargs (dict) – Overflow, for convenience.
- Returns:
params – Values for each model parameter.
- Return type:
dict
- jax_moseq.models.keypoint_slds.initialize.init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, **kwargs)[source]
Formats the hyperparameter dictionary of the keypoint SLDS.
- Parameters:
trans_hypparams (dict) – HDP transition hyperparameters.
ar_hypparams (dict) – Autoregression hyperparameters.
obs_hypparams (dict) – Observation hyperparameters.
cen_hypparams (dict) – Centroid movement hyperparameters.
**kwargs (dict) – Overflow, for convenience.
- Returns:
hypparams – Values for each group of hyperparameters.
- Return type:
dict
- jax_moseq.models.keypoint_slds.initialize.init_model(data=None, states=None, params=None, hypparams=None, noise_prior=None, seed=Array([0, 0], dtype=uint32), pca=None, whiten=True, PCA_fitting_num_frames=1000000, anterior_idxs=None, posterior_idxs=None, conf_threshold=0.5, error_estimator=None, trans_hypparams=None, ar_hypparams=None, obs_hypparams=None, cen_hypparams=None, verbose=False, exclude_outliers_for_pca=True, fix_heading=False, **kwargs)[source]
Initialize a keypoint SLDS model dict containing the hyperparameters, noise prior, and initial seed, states, and parameters.
- Parameters:
data (dict, optional) – Data dictionary containing the observations, mask, and (optionally) confidences. Must be provided if states or params are not precomputed.
states (dict, optional) – State values for each latent variable, if precomputed.
params (dict, optional) – Values for each model parameter, if precomputed.
hypparams (dict, optional) – Values for each group of hyperparameters. If not provided, caller must provide each arg of init_hyperparams.
noise_prior (array or scalar, optional) – Prior on the noise for each keypoint observation, if precomputed.
seed (int or jr.PRNGKey, default=jr.PRNGKey(0)) – Initial random seed value.
pca (sklearn.decomposition.PCA, optional) – PCA object, if precomputed. If unspecified, will be computed from the data.
whiten (bool, default=True) – Whether to whiten PCs to initialize continuous latents.
PCA_fitting_num_frames (int, default=1000000) – Maximum number of datapoints to sample for PCA fitting, if pca is not provided.
anterior_idxs (iterable of ints, optional) – Anterior keypoint indices for heading initialization. Must be provided if states or params are not precomputed.
posterior_idxs (iterable of ints, optional) – Posterior keypoint indices for heading initialization. Must be provided if states or params are not precomputed.
conf_threshold (float, default=0.5) – Confidence threshold below which points are interpolated in PCA fitting and heading/position initialization. See jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca() for details.
error_estimator (dict, optional) – Parameters used to initialize noise_prior. Must be provided if data contains confidences. See jax_moseq.models.keypoint_slds.initialize.estimate_error() for details.
trans_hypparams (dict, optional) – HDP transition hyperparameters. Must be provided if hypparams is not provided.
ar_hypparams (dict, optional) – Autoregression hyperparameters. Must be provided if hypparams is not provided.
obs_hypparams (dict, optional) – Observation hyperparameters. Must be provided if hypparams is not provided.
cen_hypparams (dict, optional) – Centroid movement hyperparameters. Must be provided if hypparams is not provided.
verbose (bool, default=False) – Whether to print progress info during initialization.
exclude_outliers_for_pca (bool, default=True) – Whether to exclude frames with low-confidence keypoints. If False, then the low-confidence keypoint coordinates are linearly interpolated.
fix_heading (bool, default=False) – Whether to keep the heading angle of the pose fixed. If true, the heading variable h is initialized as 0.
**kwargs (dict, optional) – Unused. For convenience, enables user to invoke function by unpacking dict that contains keys not used by the method.
- Returns:
model – Dictionary containing the hyperparameters, noise prior, and initial seed, states, and parameters of the model.
- Return type:
dict
- Raises:
ValueError – If the subset of the parameters provided by the caller is insufficient for model initialization.
- jax_moseq.models.keypoint_slds.initialize.estimate_error(conf, slope, intercept)[source]
Using the provided keypoint confidences and parameters learned from the noise calibration, returns prior on the noise for each datapoint.
- Parameters:
conf (jax array of shape (..., k)) – Confidence for each keypoint observation. Must be >= 0.
slope (float) – Slope learned by noise calibration.
intercept (float) – Intercept learned by noise calibration.
- Returns:
noise_prior – Prior on the noise for each observation.
- Return type:
jax array of shape (…, k)
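The reference above does not state the formula that maps confidences to a noise prior. As a rough illustration only, the sketch below assumes a log-linear calibration, in which `slope` and `intercept` describe a line fit between log10 confidence and log10 error, and the squared predicted error serves as the prior. The function name `estimate_error_sketch` and the `eps` offset are hypothetical, and plain NumPy is used instead of JAX to keep the example self-contained.

```python
import numpy as np

def estimate_error_sketch(conf, slope, intercept, eps=1e-6):
    """Hypothetical log-linear calibration (an assumption, not the library code):
    log10(error) = slope * log10(conf + eps) + intercept."""
    conf = np.asarray(conf, dtype=float)
    # eps keeps log10 finite when a confidence is exactly 0
    predicted_error = 10.0 ** (slope * np.log10(conf + eps) + intercept)
    # squaring turns an error magnitude into a variance-like quantity
    return predicted_error ** 2

conf = np.array([[0.9, 0.5, 0.1]])  # shape (..., k) with k = 3 keypoints
noise_prior = estimate_error_sketch(conf, slope=-0.5, intercept=0.25)
```

With a negative slope, lower-confidence keypoints receive a larger noise prior, which is the qualitative behavior one would expect from a confidence-based calibration.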
Gibbs sampling
Functions:
resample_continuous_stateseqs – Resamples the latent trajectories x.
resample_obs_variance – Resample the observation variance sigmasq.
resample_scales – Resample the scale values s.
compute_squared_error – Computes the squared error between model predicted and true observations.
resample_heading – Resample the heading angles h.
resample_location – Resample the centroid positions v.
resample_model – Resamples the Keypoint SLDS model given the hyperparameters, data, noise prior, current states, and current parameters.
- jax_moseq.models.keypoint_slds.gibbs.resample_continuous_stateseqs(seed, Y, mask, v, h, s, z, Cd, sigmasq, Ab, Q, jitter=0.001, parallel_message_passing=True, **kwargs)[source]
Resamples the latent trajectories x.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
v (jax array of shape (N, T, d)) – Centroid positions.
h (jax array of shape (N, T)) – Heading angles.
s (jax array of shape (N, T, k)) – Noise scales.
z (jax_array of shape (N, T - n_lags)) – Discrete state sequences.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
jitter (float, default=1e-3) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.
parallel_message_passing (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.
**kwargs (dict) – Overflow, for convenience.
- Returns:
x – Latent trajectories.
- Return type:
jax array of shape (N, T, latent_dim)
- jax_moseq.models.keypoint_slds.gibbs.resample_obs_variance(seed, Y, mask, Cd, x, v, h, s, nu_sigma, sigmasq_0, **kwargs)[source]
Resample the observation variance sigmasq.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
v (jax array of shape (N, T, d)) – Centroid positions.
h (jax array of shape (N, T)) – Heading angles.
s (jax array of shape (N, T, k)) – Noise scales.
nu_sigma (float) – Chi-squared degrees of freedom in sigmasq.
sigmasq_0 (float) – Scaled inverse chi-squared scaling parameter for sigmasq.
**kwargs (dict) – Overflow, for convenience.
- Returns:
sigmasq – Unscaled noise.
- Return type:
jax_array of shape k
- jax_moseq.models.keypoint_slds.gibbs.resample_scales(seed, Y, x, v, h, Cd, sigmasq, nu_s, s_0, **kwargs)[source]
Resample the scale values s.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
v (jax array of shape (N, T, d)) – Centroid positions.
h (jax array of shape (N, T)) – Heading angles.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.
**kwargs (dict) – Overflow, for convenience.
- Returns:
s – Noise scales.
- Return type:
jax array of shape (N, T, k)
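The parameter names `nu_s` (chi-squared degrees of freedom) and `s_0` (prior scale) suggest a scaled inverse chi-squared prior on each noise scale. The sketch below shows the textbook conjugate update for that prior, given a per-keypoint squared error such as the one returned by `compute_squared_error`. It is an illustration under assumptions: the posterior degrees of freedom `nu_s + d`, the helper names, and the use of NumPy's random generator are not taken from the library.

```python
import numpy as np

def sample_scaled_inv_chi2(rng, dof, scale_sum):
    """Draw scale_sum / chi2(dof) elementwise, i.e. a scaled inverse chi-squared sample."""
    return scale_sum / rng.chisquare(dof, size=np.shape(scale_sum))

def resample_scales_sketch(rng, sqerr, sigmasq, nu_s, s_0, d):
    """Textbook conjugate update (assumed form): each keypoint contributes d
    Gaussian residuals with variance s * sigmasq."""
    dof = nu_s + d                             # prior dof plus d observed coordinates
    scale_sum = nu_s * s_0 + sqerr / sigmasq   # prior term plus normalized squared error
    return sample_scaled_inv_chi2(rng, dof, scale_sum)

rng = np.random.default_rng(0)
sqerr = rng.random((2, 5, 4))   # (N, T, k) squared errors
sigmasq = np.ones(4)            # unscaled noise, one value per keypoint
s = resample_scales_sketch(rng, sqerr, sigmasq, nu_s=5, s_0=1.0, d=2)
```

Frames with a large squared error relative to `sigmasq` push the sampled scale upward, so poorly explained keypoints are treated as noisier.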
- jax_moseq.models.keypoint_slds.gibbs.compute_squared_error(Y, x, v, h, Cd, mask=None)[source]
Computes the squared error between model predicted and true observations.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
x (jax array of shape (..., latent_dim)) – Latent trajectories.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array) – Heading angles.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
mask (jax array, optional) – Binary indicator for valid frames.
- Returns:
sqerr – Squared error between model predicted and true observations.
- Return type:
jax array of shape (…, k)
- jax_moseq.models.keypoint_slds.gibbs.resample_heading(seed, Y, x, v, s, Cd, sigmasq, **kwargs)[source]
Resample the heading angles h.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
v (jax array of shape (N, T, d)) – Centroid positions.
s (jax array of shape (N, T, k)) – Noise scales.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
**kwargs (dict) – Overflow, for convenience.
- Returns:
h – Heading angles.
- Return type:
jax array of shape (N, T)
- jax_moseq.models.keypoint_slds.gibbs.resample_location(seed, Y, mask, x, h, s, Cd, sigmasq, sigmasq_loc, parallel_message_passing=True, **kwargs)[source]
Resample the centroid positions v.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
Y (jax array of shape (N, T, k, d)) – Keypoint observations.
mask (jax array of shape (N, T)) – Binary indicator for valid frames.
x (jax array of shape (N, T, latent_dim)) – Latent trajectories.
h (jax array of shape (N, T)) – Heading angles.
s (jax array of shape (N, T, k)) – Noise scales.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
sigmasq_loc (float) – Assumed variance in centroid displacements.
parallel_message_passing (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.
**kwargs (dict) – Overflow, for convenience.
- Returns:
v – Centroid positions.
- Return type:
jax array of shape (N, T, d)
- jax_moseq.models.keypoint_slds.gibbs.resample_model(data, seed, states, params, hypparams, noise_prior, ar_only=False, states_only=False, resample_global_noise_scale=False, resample_local_noise_scale=True, fix_heading=False, verbose=False, jitter=0.001, parallel_message_passing=False, **kwargs)[source]
Resamples the Keypoint SLDS model given the hyperparameters, data, noise prior, current states, and current parameters.
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
seed (jr.PRNGKey) – JAX random seed.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict) – Values for each group of hyperparameters.
noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.
ar_only (bool, default=False) – Whether to restrict sampling to ARHMM components.
states_only (bool, default=False) – Whether to restrict sampling to states.
resample_global_noise_scale (bool, default=False) – Whether to resample the global noise scales (sigmasq).
resample_local_noise_scale (bool, default=True) – Whether to resample the local noise scales (s).
fix_heading (bool, default=False) – Whether to exclude h from resampling.
jitter (float, default=1e-3) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.
verbose (bool, default=False) – Whether to print progress info during resampling.
parallel_message_passing (bool, default=False) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.
- Returns:
model – Dictionary containing the hyperparameters and updated seed, states, and parameters of the model.
- Return type:
dict
Log probability
Functions:
location_log_prob – Calculate the log probability of the centroid location at each time-step, given the prior on centroid movement.
obs_log_prob – Calculate the log probability of keypoint coordinates at each time-step, given continuous latent trajectories, centroids, heading angles, noise scales, and observation parameters.
log_joint_likelihood – Calculate the total log probability for each latent state.
model_likelihood – Convenience function that invokes log_joint_likelihood.
- jax_moseq.models.keypoint_slds.log_prob.location_log_prob(v, sigmasq_loc)[source]
Calculate the log probability of the centroid location at each time-step, given the prior on centroid movement.
- Parameters:
v (jax array of shape (..., T, d)) – Centroid positions.
sigmasq_loc (float) – Assumed variance in centroid displacements.
- Returns:
log_pv – Log probability of v.
- Return type:
jax array of shape (…, T - 1)
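The return shape (…, T - 1) and the description of `sigmasq_loc` as the variance of centroid displacements are consistent with a Gaussian random-walk prior on `v`. The sketch below evaluates that prior in plain NumPy, assuming isotropic Gaussian steps with variance `sigmasq_loc` in each spatial dimension. The isotropy and the function name are assumptions made for illustration.

```python
import numpy as np

def location_log_prob_sketch(v, sigmasq_loc):
    """Log density of each step v[t+1] - v[t] under N(0, sigmasq_loc * I)."""
    v = np.asarray(v, dtype=float)
    d = v.shape[-1]
    steps = v[..., 1:, :] - v[..., :-1, :]   # (..., T - 1, d) displacements
    sq_norm = (steps ** 2).sum(axis=-1)      # squared step length
    return -0.5 * sq_norm / sigmasq_loc - 0.5 * d * np.log(2 * np.pi * sigmasq_loc)

# Three trajectories of ten frames in 2D, each moving by (1, 1) per frame
v = np.cumsum(np.ones((3, 10, 2)), axis=1)
log_pv = location_log_prob_sketch(v, sigmasq_loc=0.5)
```

Because there is one displacement per pair of consecutive frames, a trajectory of T frames yields T - 1 log-probability terms.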
- jax_moseq.models.keypoint_slds.log_prob.obs_log_prob(Y, x, v, h, s, Cd, sigmasq, **kwargs)[source]
Calculate the log probability of keypoint coordinates at each time-step, given continuous latent trajectories, centroids, heading angles, noise scales, and observation parameters.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
x (jax array of shape (..., latent_dim)) – Latent trajectories.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array) – Heading angles.
s (jax array of shape (..., k)) – Noise scales.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
**kwargs (dict) – Overflow, for convenience.
- Returns:
log_pY – Log probability of Y.
- Return type:
jax array of shape (…, k)
- jax_moseq.models.keypoint_slds.log_prob.log_joint_likelihood(Y, mask, x, v, h, s, z, pi, Ab, Q, Cd, sigmasq, sigmasq_loc, s_0, nu_s, **kwargs)[source]
Calculate the total log probability for each latent state.
- Parameters:
Y (jax array of shape (..., T, k, d)) – Keypoint observations.
mask (jax array of shape (..., T)) – Binary indicator for valid frames.
x (jax array of shape (..., T, latent_dim)) – Latent trajectories.
v (jax array of shape (..., T, d)) – Centroid positions.
h (jax array of shape (..., T)) – Heading angles.
s (jax array of shape (..., T, k)) – Noise scales.
z (jax_array of shape (..., T - n_lags)) – Discrete state sequences.
pi (jax_array of shape (num_states, num_states)) – Transition probabilities.
Ab (jax array of shape (num_states, latent_dim, ar_dim)) – Autoregressive transforms.
Q (jax array of shape (num_states, latent_dim, latent_dim)) – Autoregressive noise covariances.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
sigmasq_loc (float) – Assumed variance in centroid displacements.
s_0 (scalar or jax array broadcastable to Y) – Prior on noise scale.
nu_s (int) – Chi-squared degrees of freedom in noise prior.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping the name of each state variable to its total log probability.
- Return type:
dict
- jax_moseq.models.keypoint_slds.log_prob.model_likelihood(data, states, params, hypparams, noise_prior, **kwargs)[source]
Convenience function that invokes log_joint_likelihood.
- Parameters:
data (dict) – Data dictionary containing the observations and mask.
states (dict) – State values for each latent variable.
params (dict) – Values for each model parameter.
hypparams (dict) – Values for each group of hyperparameters.
noise_prior (scalar or jax array broadcastable to s) – Prior on noise scale.
**kwargs (dict) – Overflow, for convenience.
- Returns:
ll – Dictionary mapping state variable name to its total log probability.
- Return type:
dict
Geometric transforms
Functions:
to_vanilla_slds – Given the empirical keypoint positions, position/heading estimates, isotropic noise estimates, and emission parameters, this function returns the (relevant subset of the) observations, states, and params for an equivalent SLDS that directly maps the latent trajectories to flattened and aligned keypoint observations.
estimate_coordinates – Estimate keypoint coordinates obtained from projecting the latent state x into keypoint-space (via Cd) and then rotating and translating by h and v respectively.
estimate_aligned – Compute estimated positions of aligned keypoints (i.e. prior to applying the rigid transform).
rigid_transform – Apply the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.
inverse_rigid_transform – Apply the inverse of the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.
center_embedding – Generates a matrix Gamma that maps from a (n-1)-dimensional vector space to the space of n-tuples with zero mean.
apply_rotation – Rotate Y by h radians.
angle_to_rotation_matrix – Create rotation matrices from an array of angles.
vector_to_angle – Convert 2D vectors to angles in [-pi, pi].
fit_pca – Fit a PCA model to transformed keypoint coordinates.
preprocess_for_pca – Prepare keypoint coordinates for PCA by performing egocentric alignment (optional), changing basis using center_embedding(k), and reshaping to a single flat vector per frame.
align_egocentric – Perform egocentric alignment of keypoints by translating the centroid to the origin and rotating so that the vector pointing from the posterior bodyparts toward the anterior bodyparts is proportional to (1,0).
- jax_moseq.models.keypoint_slds.alignment.to_vanilla_slds(Y, v, h, s, Cd, sigmasq, **kwargs)[source]
Given the empirical keypoint positions, position/heading estimates, isotropic noise estimates, and emission parameters, this function returns the (relevant subset of the) observations, states, and params for an equivalent SLDS that directly maps the latent trajectories to flattened and aligned keypoint observations.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array of shape (..., T)) – Heading angles.
s (jax array of shape (..., k)) – Noise scales.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
sigmasq (jax_array of shape k) – Unscaled noise.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Y (jax array of shape (…, k * d)) – Flattened and aligned keypoint observations.
s (jax array of shape (…, k * d)) – Noise scales repeated along spatial dimension.
Cd (jax array of shape (k * d, latent_dim + 1)) – Emission parameters, accounting for effect of embedding.
sigmasq (jax_array of shape k * d) – Unscaled noise repeated along spatial dimension.
- jax_moseq.models.keypoint_slds.alignment.estimate_coordinates(x, v, h, Cd, **kwargs)[source]
Estimate keypoint coordinates obtained from projecting the latent state x into keypoint-space (via Cd) and then rotating and translating by h and v respectively.
- Parameters:
x (jax array of shape (..., latent_dim)) – Latent trajectories.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array) – Heading angles.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Y_bar – Estimated coordinates.
- Return type:
jax array of shape (…, k, d)
- jax_moseq.models.keypoint_slds.alignment.estimate_aligned(x, Cd, k)[source]
Compute estimated positions of aligned keypoints (i.e. prior to applying the rigid transform).
- Parameters:
x (jax array of shape (..., latent_dim)) – Latent trajectories.
Cd (jax array of shape ((k - 1) * d, latent_dim + 1)) – Observation transform.
k (int) – Number of keypoints.
- Returns:
Y_aligned – Aligned keypoint positions estimated from latents.
- Return type:
jax array of shape (…, k * d)
- jax_moseq.models.keypoint_slds.alignment.rigid_transform(Y, v, h)[source]
Apply the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array) – Heading angles.
- Returns:
Y_transformed – Rigidly transformed positions.
- Return type:
jax array of shape (…, k, d)
- jax_moseq.models.keypoint_slds.alignment.inverse_rigid_transform(Y, v, h)[source]
Apply the inverse of the rigid transform consisting of rotation by h and translation by v to a set of keypoint observations.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
v (jax array of shape (..., d)) – Centroid positions.
h (jax array) – Heading angles.
- Returns:
Y_transformed – Rigidly transformed positions.
- Return type:
jax array of shape (…, k, d)
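To make the relationship between `rigid_transform` and `inverse_rigid_transform` concrete, the sketch below implements both for 2D keypoints in plain NumPy and checks that one undoes the other. It assumes counterclockwise rotation about the origin followed by translation. The helper names and the 2D restriction are choices made for this illustration, not the library's implementation.

```python
import numpy as np

def rotation_matrix_2d(h):
    """Counterclockwise 2D rotation matrices for angles h, shape (..., 2, 2)."""
    c, s = np.cos(h), np.sin(h)
    return np.stack([np.stack([c, -s], -1), np.stack([s, c], -1)], -2)

def rigid_transform_sketch(Y, v, h):
    """Rotate each keypoint by h, then translate by v."""
    R = rotation_matrix_2d(h)
    return np.einsum('...ij,...kj->...ki', R, Y) + v[..., None, :]

def inverse_rigid_transform_sketch(Y, v, h):
    """Subtract v, then rotate by -h (multiply by the transposed rotation)."""
    R = rotation_matrix_2d(h)
    return np.einsum('...ji,...kj->...ki', R, Y - v[..., None, :])

rng = np.random.default_rng(0)
Y = rng.normal(size=(2, 5, 4, 2))            # (N, T, k, d)
v = rng.normal(size=(2, 5, 2))               # centroids
h = rng.uniform(-np.pi, np.pi, size=(2, 5))  # headings
Y_world = rigid_transform_sketch(Y, v, h)
Y_back = inverse_rigid_transform_sketch(Y_world, v, h)  # recovers Y
```

The inverse applies the two steps in reverse order: translation is undone first, then rotation.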
- jax_moseq.models.keypoint_slds.alignment.center_embedding(n)[source]
Generates a matrix Gamma that maps from a (n-1)-dimensional vector space to the space of n-tuples with zero mean.
- Parameters:
n (int) – Number of keypoints.
- Returns:
Gamma – Matrix to map to centered embedded space.
- Return type:
jax array of shape (n, n - 1)
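One way to build a matrix with the stated property is a Helmert-style construction, sketched below in NumPy. Each column sums to zero, so any vector `Gamma @ u` has zero mean across the n entries. The sketch additionally makes the columns orthonormal. Whether the library uses this exact construction or normalization is an assumption.

```python
import numpy as np

def center_embedding_sketch(n):
    """Return an (n, n - 1) matrix whose columns are orthonormal and sum to zero."""
    X = np.tril(np.ones((n, n)), k=-1)[1:]                # row i has ones in columns 0..i-1
    X = np.eye(n)[1:] - X / X.sum(axis=1, keepdims=True)  # e_i minus the mean of earlier basis vectors
    X = X / np.sqrt((X ** 2).sum(axis=1, keepdims=True))  # normalize each row to unit length
    return X.T

k, d = 5, 2
Gamma = center_embedding_sketch(k)
# Map an embedded pose of shape (k - 1, d) to k zero-mean keypoints
embedded = np.arange((k - 1) * d, dtype=float).reshape(k - 1, d)
keypoints = Gamma @ embedded                              # shape (k, d)
```

This explains why aligned observations are stored with (k - 1) * d entries: once the centroid is removed, k keypoints have only k - 1 degrees of freedom per spatial dimension.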
- jax_moseq.models.keypoint_slds.alignment.apply_rotation(Y, h)[source]
Rotate Y by h radians.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
h (jax array) – Heading angles.
- Returns:
Y_rot – Rotated keypoint observations.
- Return type:
jax array of shape (…, k, d)
- jax_moseq.models.keypoint_slds.alignment.angle_to_rotation_matrix(h, d=3)[source]
Create rotation matrices from an array of angles. If d > 2 then rotation is performed in the first two dims.
- Parameters:
h (jax array of shape (N, T)) – Heading angles.
d (int, default=3) – Keypoint dimensionality (either 2 or 3).
- Returns:
m – Rotation matrices.
- Return type:
jax array of shape (…, d, d)
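The statement that rotation is performed in the first two dims when d > 2 can be illustrated as follows: start from a d × d identity and overwrite the top-left 2 × 2 block with a planar rotation, leaving the remaining axes untouched. The NumPy sketch below assumes a counterclockwise sign convention, and the function name is hypothetical.

```python
import numpy as np

def angle_to_rotation_matrix_sketch(h, d=3):
    """Rotation matrices of shape h.shape + (d, d) acting on the first two dims."""
    h = np.asarray(h, dtype=float)
    m = np.zeros(h.shape + (d, d))
    m[...] = np.eye(d)            # identity on every axis to start
    m[..., 0, 0] = np.cos(h)      # overwrite the top-left 2 x 2 block
    m[..., 0, 1] = -np.sin(h)
    m[..., 1, 0] = np.sin(h)
    m[..., 1, 1] = np.cos(h)
    return m

h = np.linspace(-np.pi, np.pi, 8).reshape(2, 4)  # (N, T) heading angles
m = angle_to_rotation_matrix_sketch(h, d=3)      # (N, T, 3, 3)
```

For 3D keypoints this amounts to a rotation about the third axis, so heights are unchanged by the heading.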
- jax_moseq.models.keypoint_slds.alignment.vector_to_angle(V)[source]
Convert 2D vectors to angles in [-pi, pi]. The vector (1,0) corresponds to angle of 0. If V is multidimensional, the first n-1 dimensions are treated as batch dims.
- Parameters:
V (jax array of shape (..., 2)) – Batch of 2D vectors.
- Returns:
h – Rotation angles in radians.
- Return type:
jax array
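The mapping described here, in which (1,0) corresponds to an angle of 0 and the range is [-pi, pi], matches the two-argument arctangent. A minimal NumPy sketch, with a hypothetical function name:

```python
import numpy as np

def vector_to_angle_sketch(V):
    """Angle of each 2D vector in V, measured counterclockwise from (1, 0)."""
    V = np.asarray(V, dtype=float)
    return np.arctan2(V[..., 1], V[..., 0])

V = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
angles = vector_to_angle_sketch(V)  # one angle per row of V
```

All leading dimensions of `V` are treated as batch dimensions, so the output shape is `V.shape[:-1]`.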
- jax_moseq.models.keypoint_slds.alignment.fit_pca(Y, mask, anterior_idxs=None, posterior_idxs=None, conf=None, conf_threshold=0.5, verbose=False, PCA_fitting_num_frames=1000000, exclude_outliers_for_pca=False, fix_heading=False, **kwargs)[source]
Fit a PCA model to transformed keypoint coordinates. If conf is not None, perform linear interpolation over outliers defined by conf < conf_threshold.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
mask (jax array) – Binary indicator for valid frames.
anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.
posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.
conf (jax array of shape (..., k), optional) – Confidence for each keypoint observation. Must be >= 0.
conf_threshold (float, default=0.5) – Confidence threshold for interpolation.
verbose (bool, default=False) – Whether to print progress updates.
PCA_fitting_num_frames (int, default=1000000) – Maximum number of frames for PCA fitting.
exclude_outliers_for_pca (bool, default=False) – Whether to exclude frames with low-confidence keypoints. If False, then the low-confidence keypoint coordinates are linearly interpolated.
fix_heading (bool, default=False) – Whether to keep the heading angle fixed. If true, the heading h is set to 0 and keypoints are not rotated.
**kwargs (dict) – Overflow, for convenience.
- Returns:
pca – PCA object fit to observations.
- Return type:
sklearn.decomposition._pca.PCA
- jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca(Y, anterior_idxs, posterior_idxs, conf=None, conf_threshold=0.5, fix_heading=False, verbose=False, **kwargs)[source]
Prepare keypoint coordinates for PCA by performing egocentric alignment (optional), changing basis using center_embedding(k), and reshaping to a single flat vector per frame.
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.
posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.
conf (jax array of shape (..., k), optional) – Confidence for each keypoint observation. Must be >= 0.
conf_threshold (float, default=.5) – Confidence threshold for interpolation.
fix_heading (bool, default=False) – Whether to keep the heading angle fixed. If true, the heading h is set to 0 and keypoints are not rotated.
verbose (bool, default=False) – Whether to print progress updates.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Y_flat – Aligned and embedded keypoint observations.
- Return type:
jax array of shape (…, (k - 1) * d), optional
- jax_moseq.models.keypoint_slds.alignment.align_egocentric(Y, anterior_idxs, posterior_idxs, **kwargs)[source]
Perform egocentric alignment of keypoints by translating the centroid to the origin and rotating so that the vector pointing from the posterior bodyparts toward the anterior bodyparts is proportional to (1,0).
- Parameters:
Y (jax array of shape (..., k, d)) – Keypoint observations.
anterior_idxs (iterable of ints) – Anterior keypoint indices for heading initialization.
posterior_idxs (iterable of ints) – Posterior keypoint indices for heading initialization.
**kwargs (dict) – Overflow, for convenience.
- Returns:
Y_aligned (jax array of shape (…, k, d)) – Aligned keypoint coordinates.
v (jax array of shape (…, d)) – Centroid positions that were used for alignment.
h (jax array) – Heading angles that were used for alignment.
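The alignment described above can be sketched in three steps: take the centroid as the mean over keypoints, take the heading as the angle of the vector from the mean posterior keypoint to the mean anterior keypoint, then subtract the centroid and rotate by the negative heading. The NumPy version below is illustrative only. Using the unweighted mean of all keypoints as the centroid and the function name are assumptions.

```python
import numpy as np

def align_egocentric_sketch(Y, anterior_idxs, posterior_idxs):
    """Center keypoints and rotate so posterior -> anterior points along (1, 0)."""
    Y = np.asarray(Y, dtype=float)
    v = Y.mean(axis=-2)                                   # centroid, shape (..., d)
    anterior = Y[..., np.asarray(anterior_idxs), :2].mean(axis=-2)
    posterior = Y[..., np.asarray(posterior_idxs), :2].mean(axis=-2)
    direction = anterior - posterior
    h = np.arctan2(direction[..., 1], direction[..., 0])  # heading angle
    c, s = np.cos(h)[..., None], np.sin(h)[..., None]
    centered = Y - v[..., None, :]
    Y_aligned = centered.copy()                           # any dims past the first two are kept as-is
    Y_aligned[..., 0] = c * centered[..., 0] + s * centered[..., 1]   # rotate by -h
    Y_aligned[..., 1] = -s * centered[..., 0] + c * centered[..., 1]
    return Y_aligned, v, h

rng = np.random.default_rng(0)
Y = rng.normal(size=(6, 4, 2))                            # (T, k, d)
Y_aligned, v, h = align_egocentric_sketch(Y, anterior_idxs=[0, 1], posterior_idxs=[2, 3])
```

Applying a rigid transform with the returned `v` and `h` to `Y_aligned` would recover the original coordinates, which is why both are returned alongside the aligned keypoints.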