import jax
import jax.numpy as jnp
import numpy as np
from jax_moseq import utils
from jax_moseq.utils import jax_io, apply_affine
na = jnp.newaxis
[docs]
def to_vanilla_slds(Y, v, h, s, Cd, sigmasq, **kwargs):
"""
Given the empirical keypoint positions, position/heading
estimates, isotropic noise estimates, and emission parameters,
this function returns the (relevant subset of the) observations,
states, and params for an equivalent SLDS that directly maps the
latent trajectories to flattened and aligned keypoint observations.
Parameters
----------
Y : jax array of shape (..., k, d)
Keypoint observations.
v : jax array of shape (..., d)
Centroid positions.
h : jax array of shape (..., T)
Heading angles.
s : jax array of shape (..., k)
Noise scales.
Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
Observation transform.
sigmasq : jax_array of shape k
Unscaled noise.
**kwargs : dict
Overflow, for convenience.
Returns
-------
Y : jax array of shape (..., k * d)
Flattened and aligned keypoint observations.
s : jax array of shape (..., k * d)
Noise scales repeated along spatial dimension.
Cd : jax array of shape (k * d, latent_dim + 1)
Emission parameters, accounting for effect of embedding.
sigmasq : jax_array of shape k * d
Unscaled noise repeated along spatial dimension.
"""
# d denotes keypoint dim, not emission bias
batch_shape = Y.shape[:-2]
k, d = Y.shape[-2:]
# Obtain aligned and flattened estimates of keypoint positions
Y = inverse_rigid_transform(Y, v, h).reshape(*batch_shape, -1)
# Reformat Cd to map x -> (k * d)-dimensional flattened, aligned coordinate space
Gamma = center_embedding(k)
Cd = jnp.kron(Gamma, jnp.eye(d)) @ Cd
# Repeat the isotropic noise estimates along the collapsed
# keypoint/coordinate axis.
s = jnp.repeat(s, d, axis=-1)
sigmasq = jnp.repeat(sigmasq, d, axis=-1)
return Y, s, Cd, sigmasq
[docs]
def estimate_coordinates(x, v, h, Cd, **kwargs):
"""
Estimate keypoint coordinates obtained from projecting the
latent state ``x`` into keypoint-space (via ``Cd``) and then
rotating and translating by ``h`` and ``v`` respectively
Parameters
----------
x : jax array of shape (..., latent_dim)
Latent trajectories.
v : jax array of shape (..., d)
Centroid positions.
h : jax array
Heading angles.
Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
Observation transform.
**kwargs : dict
Overflow, for convenience.
Returns
-------
Y_bar : jax array of shape (..., k, d), Estimated coordinates
"""
batch_shape = x.shape[:-2]
d = v.shape[-1]
k = Cd.shape[0] // d + 1
Y_bar = estimate_aligned(x, Cd, k)
Y_est = rigid_transform(Y_bar, v, h)
return Y_est
[docs]
def estimate_aligned(x, Cd, k):
"""
Compute estimated positions of aligned keypoints
(i.e. prior to applying the rigid transform).
Parameters
----------
x : jax array of shape (..., latent_dim)
Latent trajectories.
Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
Observation transform.
k : int
Number of keypoints.
Returns
------
Y_aligned : jax array of shape (..., k * d)
Aligned keypoint positions estimated from latents.
"""
# Apply emissions to obtain flattened,
# centered, keypoint observation
y = apply_affine(x, Cd)
# Reshape keypoints
batch_shape = x.shape[:-1]
y = y.reshape(*batch_shape, k - 1, -1)
Gamma = center_embedding(k)
return Gamma @ y
[docs]
def center_embedding(n):
"""
Generates a matrix ``Gamma`` that maps from a (n-1)-dimensional
vector space to the space of k-tuples with zero mean
Parameters
----------
n : int
Number of keypoints.
Returns
-------
Gamma: jax array of shape (n, n - 1)
Matrix to map to centered embedded space.
"""
X = jnp.tril(jnp.ones((n, n)), k=-1)[1:]
X = jnp.eye(n)[1:] - X / X.sum(1)[:, na]
X = X / jnp.sqrt((X**2).sum(1))[:, na]
return X.T
[docs]
def apply_rotation(Y, h):
"""
Rotate ``Y`` by ``h`` radians.
Parameters
----------
Y : jax array of shape (..., k, d)
Keypoint observations.
h : jax array
Heading angles.
Returns
------
Y_rot : jax array of shape (..., k, d)
Rotated keypoint observations.
"""
d = Y.shape[-1]
rot_matrix = angle_to_rotation_matrix(h, d)
return jnp.einsum("...kj,...ij->...ki", Y, rot_matrix)
[docs]
def angle_to_rotation_matrix(h, d=3):
"""
Create rotation matrices from an array of angles. If
``d > 2`` then rotation is performed in the first two dims.
Parameters
----------
h : jax array of shape (N, T)
Heading angles.
d : int, default=3
Keypoint dimensionality (either 2 or 3).
Returns
------
m: jax array of shape (..., d, d)
Rotation matrices.
"""
m = jnp.tile(jnp.eye(d), (*h.shape, 1, 1))
m = m.at[..., 0, 0].set(jnp.cos(h))
m = m.at[..., 1, 1].set(jnp.cos(h))
m = m.at[..., 0, 1].set(-jnp.sin(h))
m = m.at[..., 1, 0].set(jnp.sin(h))
return m
[docs]
def vector_to_angle(V):
"""
Convert 2D vectors to angles in [-pi, pi]. The vector (1,0)
corresponds to angle of 0. If V is multidimensional, the first
n-1 dimensions are treated as batch dims.
Parameters
----------
V : jax array of shape (..., 2)
Batch of 2D vectors.
Returns
------
h : jax array
Rotation angles in radians.
"""
return jnp.arctan2(V[..., 1], V[..., 0])
[docs]
def fit_pca(
Y,
mask,
anterior_idxs=None,
posterior_idxs=None,
conf=None,
conf_threshold=0.5,
verbose=False,
PCA_fitting_num_frames=1000000,
exclude_outliers_for_pca=False,
fix_heading=False,
**kwargs,
):
"""
Fit a PCA model to transformed keypoint coordinates. If ``conf`` is
not None, perform linear interpolation over outliers defined by
``conf < conf_threshold``.
Parameters
----------
Y : jax array of shape (..., k, d)
Keypoint observations.
mask : jax array
Binary indicator for valid frames.
anterior_idxs : iterable of ints
Anterior keypoint indices for heading initialization.
posterior_idxs : iterable of ints
Posterior keypoint indices for heading initialization.
conf : jax array of shape (..., k), optional
Confidence for each keypoint observation. Must be >= 0.
conf_threshold : float, default=0.5
Confidence threshold for interpolation.
verbose : bool, default=False
Whether to print progress updates.
PCA_fitting_num_frames : int, default=1000000
Maximum number of frames for PCA fitting.
exclude_outliers_for_pca : bool, default=False
Whether to exclude frames with low-confidence keypoints.
If False, then the low-confidence keypoint coordinates are
linearly interpolated.
fix_heading : bool, default=False
Whether keep the heading angle fixed. If true, the
heading ``h`` is set to 0 and keypoints are not rotated.
**kwargs : dict
Overflow, for convenience.
Returns
-------
pca, sklearn.decomposition._pca.PCA
PCA object fit to observations.
"""
Y_flat = preprocess_for_pca(
Y,
anterior_idxs,
posterior_idxs,
conf,
conf_threshold,
fix_heading,
verbose,
)[0]
if (not exclude_outliers_for_pca) or (conf is None):
pca_mask = mask
else:
pca_mask = jnp.logical_and(mask, (conf > conf_threshold).all(-1))
assert pca_mask.sum() >= (Y.shape[-1] * Y.shape[-2]), (
"Not enough frames for PCA fitting. Make sure "
"`exclude_outliers_for_pca=False` 'or decrease `conf_threshold`."
)
return utils.fit_pca(Y_flat, pca_mask, PCA_fitting_num_frames, verbose)
[docs]
def preprocess_for_pca(
Y,
anterior_idxs,
posterior_idxs,
conf=None,
conf_threshold=0.5,
fix_heading=False,
verbose=False,
**kwargs,
):
"""
Prepare keypoint coordinates for PCA by performing egocentric
alignment (optional), changing basis using ``center_embedding(k)``,
and reshaping to a single flat vector per frame.
Parameters
----------
Y : jax array of shape (..., k, d)
Keypoint observations.
anterior_idxs : iterable of ints
Anterior keypoint indices for heading initialization.
posterior_idxs : iterable of ints
Posterior keypoint indices for heading initialization.
conf : jax array of shape (..., k), optional
Confidence for each keypoint observation. Must be >= 0.
conf_threshold : float, default=.5
Confidence threshold for interpolation.
fix_heading : bool, default=False
Whether keep the heading angle fixed. If true, the
heading ``h`` is set to 0 and keypoints are not rotated.
verbose : bool, default=False
Whether to print progress updates.
**kwargs : dict
Overflow, for convenience.
Returns
-------
Y_flat : jax array of shape (..., (k - 1) * d), optional
Aligned and embedded keypoint observations.
"""
if conf is not None:
outliers = conf < conf_threshold
if verbose:
n = outliers.sum()
pct = outliers.mean() * 100
print(f"Interpolating {n} ({pct:.1f}%) low-confidence keypoints")
Y = interpolate(Y, outliers).block_until_ready()
if fix_heading:
v = Y.mean(-2)
h = jnp.zeros(Y.shape[:-2])
Y_aligned = Y - v[..., na, :]
else:
Y_aligned, v, h = align_egocentric(Y, anterior_idxs, posterior_idxs)
dims = Y.shape[:-2]
k, d = Y.shape[-2:]
Gamma_inv = center_embedding(k).T
Y_embedded = Gamma_inv @ Y_aligned
Y_flat = Y_embedded.reshape(*dims, (k - 1) * d)
return Y_flat, v, h
[docs]
def align_egocentric(Y, anterior_idxs, posterior_idxs, **kwargs):
"""
Perform egocentric alignment of keypoints by translating the
centroid to the origin and rotatating so that the vector pointing
from the posterior bodyparts toward the anterior bodyparts is
proportional to (1,0).
Parameters
----------
Y : jax array of shape (..., k, d)
Keypoint observations.
anterior_idxs : iterable of ints
Anterior keypoint indices for heading initialization.
posterior_idxs : iterable of ints
Posterior keypoint indices for heading initialization.
**kwargs : dict
Overflow, for convenience.
Returns
-------
Y_aligned : jax array of shape (..., k, d)
Aligned keypoint coordinates.
v : jax array of shape (..., d)
Centroid positions that were used for alignment.
h : jax array
Heading angles that were used for alignment.
"""
posterior_loc = Y[..., posterior_idxs, :2].mean(-2)
anterior_loc = Y[..., anterior_idxs, :2].mean(-2)
h = vector_to_angle(anterior_loc - posterior_loc)
v = Y.mean(-2)
Y_aligned = inverse_rigid_transform(Y, v, h)
return Y_aligned, v, h
@jax_io
def interpolate(Y, outliers, axis=1):
"""
Use linear interpolation to impute the coordinates of outliers.
Parameters
----------
Y : jax array of shape (N, T, k, d)
Keypoint observations.
outliers : jax array of shape (..., T, k)
Binary indicator whose true entries are outlier points.
axis : int, default=1
Axis to interpolate along.
Returns
-------
Y_interp : jax array, shape (..., T, k, d)
Copy of ``Y`` where outliers have been replaced by
linearly interpolated values.
"""
Y = np.moveaxis(Y, axis, 0)
init_shape = Y.shape
Y = Y.reshape(init_shape[0], -1)
outliers = np.moveaxis(outliers, axis, 0)
outliers = np.repeat(outliers[..., None], init_shape[-1], axis=-1)
outliers = outliers.reshape(init_shape[0], -1)
interp = lambda x, xp, fp: (
np.ones_like(x) * x.mean() if len(xp) == 0 else np.interp(x, xp, fp)
)
Y = np.stack(
[
interp(
np.arange(init_shape[0]),
np.nonzero(~outliers[:, i])[0],
Y[:, i][~outliers[:, i]],
)
for i in range(Y.shape[1])
],
axis=1,
)
return np.moveaxis(Y.reshape(init_shape), 0, axis)