Source code for jax_moseq.models.keypoint_slds.gibbs

import jax
import jax.numpy as jnp
import jax.random as jr
from functools import partial

from jax_moseq.utils.kalman import kalman_sample
from jax_moseq.utils.distributions import sample_vonmises_fisher

from jax_moseq.models import arhmm, slds
from jax_moseq.models.keypoint_slds.alignment import (
    to_vanilla_slds,
    estimate_coordinates,
    estimate_aligned,
    apply_rotation,
    vector_to_angle,
)

# Short alias for ``jnp.newaxis`` (i.e. ``None``); used below to insert
# broadcast axes and as the "do not map" marker in ``jax.vmap`` ``in_axes``.
na = jnp.newaxis


def resample_continuous_stateseqs(
    seed,
    Y,
    mask,
    v,
    h,
    s,
    z,
    Cd,
    sigmasq,
    Ab,
    Q,
    jitter=1e-3,
    parallel_message_passing=True,
    **kwargs
):
    """
    Draw a new sample of the latent trajectories ``x``.

    The keypoint observations are first rewritten via ``to_vanilla_slds``
    and the result is handed to ``slds.resample_continuous_stateseqs``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    v : jax array of shape (N, T, d)
        Centroid positions.
    h : jax array of shape (N, T)
        Heading angles.
    s : jax array of shape (N, T, k)
        Noise scales.
    z : jax_array of shape (N, T - n_lags)
        Discrete state sequences.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape k
        Unscaled noise.
    Ab : jax array of shape (num_states, latent_dim, ar_dim)
        Autoregressive transforms.
    Q : jax array of shape (num_states, latent_dim, latent_dim)
        Autoregressive noise covariances.
    jitter : float, default=1e-3
        Amount to boost the diagonal of the covariance matrix
        during backward-sampling of the continuous states.
    parallel_message_passing : bool, default=True,
        Use associative scan for Kalman sampling, which is faster on
        a GPU but has a significantly longer jit time.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    """
    # Re-express observations, scales and emission parameters in the form
    # expected by the generic SLDS sampler.
    slds_Y, slds_s, slds_Cd, slds_sigmasq = to_vanilla_slds(
        Y, v, h, s, Cd, sigmasq
    )
    sampler_options = {
        "jitter": jitter,
        "parallel_message_passing": parallel_message_passing,
    }
    return slds.resample_continuous_stateseqs(
        seed,
        slds_Y,
        mask,
        z,
        slds_s,
        Ab,
        Q,
        slds_Cd,
        slds_sigmasq,
        **sampler_options,
    )
@jax.jit
def resample_obs_variance(
    seed, Y, mask, Cd, x, v, h, s, nu_sigma, sigmasq_0, **kwargs
):
    """
    Draw a new sample of the observation variance ``sigmasq``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    v : jax array of shape (N, T, d)
        Centroid positions.
    h : jax array of shape (N, T)
        Heading angles.
    s : jax array of shape (N, T, k)
        Noise scales.
    nu_sigma : float
        Chi-squared degrees of freedom in sigmasq.
    sigmasq_0 : float
        Scaled inverse chi-squared scaling parameter for sigmasq.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    sigmasq : jax_array of shape k
        Unscaled noise.
    """
    # Per-keypoint squared residuals, zeroed on frames where ``mask`` is 0.
    masked_sqerr = compute_squared_error(Y, x, v, h, Cd, mask)
    sigmasq = slds.resample_obs_variance_from_sqerr(
        seed, masked_sqerr, mask, s, nu_sigma, sigmasq_0
    )
    return sigmasq
@jax.jit
def resample_scales(seed, Y, x, v, h, Cd, sigmasq, nu_s, s_0, **kwargs):
    """
    Draw a new sample of the per-frame, per-keypoint noise scales ``s``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    v : jax array of shape (N, T, d)
        Centroid positions.
    h : jax array of shape (N, T)
        Heading angles.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape k
        Unscaled noise.
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    s_0 : scalar or jax array broadcastable to ``Y``
        Prior on noise scale.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    s : jax array of shape (N, T, k)
        Noise scales.
    """
    # No mask is applied here: residuals are computed for every frame.
    residual_sq = compute_squared_error(Y, x, v, h, Cd)
    scales = slds.resample_scales_from_sqerr(
        seed, residual_sq, sigmasq, nu_s, s_0
    )
    return scales
@jax.jit
def compute_squared_error(Y, x, v, h, Cd, mask=None):
    """
    Squared distance between observed and model-predicted keypoints.

    Parameters
    ----------
    Y : jax array of shape (..., k, d)
        Keypoint observations.
    x : jax array of shape (..., latent_dim)
        Latent trajectories.
    v : jax array of shape (..., d)
        Centroid positions.
    h : jax array
        Heading angles.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    mask : jax array, optional
        Binary indicator for valid frames. When given, it is broadcast
        over the keypoint axis and multiplied into the result.

    Returns
    ------
    sqerr : jax array of shape (..., k)
        Squared error between model predicted and true observations,
        summed over the last (coordinate) axis.
    """
    predicted = estimate_coordinates(x, v, h, Cd)
    residual = Y - predicted
    sqerr = (residual**2).sum(-1)

    # ``mask is None`` is resolved at trace time, so this is a plain
    # Python branch even under ``jax.jit``.
    if mask is None:
        return sqerr
    return mask[..., None] * sqerr
@jax.jit
def resample_heading(seed, Y, x, v, s, Cd, sigmasq, **kwargs):
    """
    Draw a new sample of the heading angles ``h``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    v : jax array of shape (N, T, d)
        Centroid positions.
    s : jax array of shape (N, T, k)
        Noise scales.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape k
        Unscaled noise.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    h : jax array of shape (N, T)
        Heading angles.
    """
    num_keypoints = Y.shape[-2]
    aligned = estimate_aligned(x, Cd, num_keypoints)
    centered = Y - v[..., na, :]
    noise_var = s * sigmasq

    # Variance-weighted outer products of the first two coordinates:
    # (..., t, k, 2, 1) * (..., t, k, 1, 2) / (..., t, k, 1, 1),
    # then summed over the keypoint axis -> (..., t, 2, 2).
    weighted_outer = (
        aligned[..., :2, na] * centered[..., na, :2] / noise_var[..., na, na]
    )
    cross = weighted_outer.sum(-3)

    # Trace and antisymmetric part give the cos / sin coefficients.
    kappa_cos = cross[..., 0, 0] + cross[..., 1, 1]
    kappa_sin = cross[..., 0, 1] - cross[..., 1, 0]

    direction_param = jnp.stack([kappa_cos, kappa_sin], axis=-1)
    direction = sample_vonmises_fisher(seed, direction_param)
    return vector_to_angle(direction)
@partial(jax.jit, static_argnames=("parallel_message_passing",))
def resample_location(
    seed,
    Y,
    mask,
    x,
    h,
    s,
    Cd,
    sigmasq,
    sigmasq_loc,
    parallel_message_passing=True,
    **kwargs
):
    """
    Resample the centroid positions ``v``.

    ``parallel_message_passing`` is declared static for ``jax.jit`` so that
    it reaches ``kalman_sample`` as a plain Python bool even when a caller
    passes it explicitly (otherwise it would be traced).

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    h : jax array of shape (N, T)
        Heading angles.
    s : jax array of shape (N, T, k)
        Noise scales.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape k
        Unscaled noise.
    sigmasq_loc : float
        Assumed variance in centroid displacements.
    parallel_message_passing : bool, default=True,
        Use associative scan for Kalman sampling, which is faster on
        a GPU but has a significantly longer jit time. Must be a
        hashable Python value (it is a static jit argument).
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    v : jax array of shape (N, T, d)
        Centroid positions.
    """
    k, d = Y.shape[-2:]

    # Model-predicted keypoints in the rotated (but not translated) frame.
    Y_rot = apply_rotation(estimate_aligned(x, Cd, k), h)

    # Precision-weighted average over keypoints of the residual ``Y - Y_rot``
    # gives a per-frame noisy observation ``mu`` of the centroid, with
    # variance ``gammasq``.
    variance = s * sigmasq
    gammasq = 1 / (1 / variance).sum(-1, keepdims=True)
    mu = jnp.einsum("...tkd, ...tk->...td", Y - Y_rot, gammasq / variance)

    # Apply Kalman sampling to get smooth centroid positions
    # TODO Parameterize these distributional hyperparameters

    # One PRNG key per sequence (leading axis of ``mask``).
    seed = jr.split(seed, mask.shape[0])

    # Initial state prior: zero mean with a large isotropic covariance.
    m0 = jnp.zeros(d)
    S0 = jnp.eye(d) * 1e4

    # Single-state random-walk dynamics: identity transition, zero bias.
    A = jnp.eye(d)[na]
    B = jnp.zeros(d)[na]
    Q = jnp.eye(d)[na] * sigmasq_loc

    # Identity emission with zero offset and per-frame diagonal noise.
    C = jnp.eye(d)
    D = jnp.zeros(d)
    R = jnp.repeat(gammasq, d, axis=-1)

    # All transitions use dynamics index 0.
    zz = jnp.zeros_like(mask[:, 1:], dtype=int)

    # Parameters passed to ``kalman_sample`` for masked frames.
    masked_dynamics_noise = sigmasq_loc * 10
    masked_obs_noise = sigmasq.max() * 10
    masked_dynamics_params = {
        "weights": jnp.eye(d),
        "bias": jnp.zeros(d),
        "cov": jnp.eye(d) * masked_dynamics_noise,
    }
    masked_obs_noise_diag = jnp.ones(d) * masked_obs_noise

    # Map over the leading axis for seed, mu, mask, zz and R; every other
    # argument is shared across sequences.
    in_axes = (0, 0, 0, 0, na, na, na, na, na, na, na, 0, na, na)
    sample_fn = partial(kalman_sample, parallel=parallel_message_passing)
    v = jax.vmap(sample_fn, in_axes)(
        seed,
        mu,
        mask,
        zz,
        m0,
        S0,
        A,
        B,
        Q,
        C,
        D,
        R,
        masked_dynamics_params,
        masked_obs_noise_diag,
    )
    return v
def resample_model(
    data,
    seed,
    states,
    params,
    hypparams,
    noise_prior,
    ar_only=False,
    states_only=False,
    resample_global_noise_scale=False,
    resample_local_noise_scale=True,
    fix_heading=False,
    verbose=False,
    jitter=1e-3,
    parallel_message_passing=False,
    **kwargs
):
    """
    Resamples the Keypoint SLDS model given the hyperparameters, data,
    noise prior, current states, and current parameters.

    The ARHMM components are resampled first via ``arhmm.resample_model``.
    Unless ``ar_only`` is set, the seed it returns is then split so that
    each subsequent sampling step (``sigmasq``, ``x``, ``h``, ``v``, ``s``)
    uses its own independent PRNG key, and the advanced seed is returned.

    Parameters
    ----------
    data : dict
        Data dictionary containing the observations and mask.
    seed : jr.PRNGKey
        JAX random seed.
    states : dict
        State values for each latent variable.
    params : dict
        Values for each model parameter.
    hypparams : dict
        Values for each group of hyperparameters.
    noise_prior : scalar or jax array broadcastable to ``s``
        Prior on noise scale.
    ar_only : bool, default=False
        Whether to restrict sampling to ARHMM components.
    states_only : bool, default=False
        Whether to restrict sampling to states.
    resample_global_noise_scale : bool, default=False
        Whether to resample the global noise scales (``sigmasq``)
    resample_local_noise_scale : bool, default=True
        Whether to resample the local noise scales (``s``)
    fix_heading : bool, default=False
        Whether to exclude ``h`` from resampling.
    verbose : bool, default=False
        Whether to print progress info during resampling.
    jitter : float, default=1e-3
        Amount to boost the diagonal of the covariance matrix
        during backward-sampling of the continuous states.
    parallel_message_passing : bool, default=False,
        Use associative scan for Kalman sampling, which is faster on
        a GPU but has a significantly longer jit time. Forwarded to
        the resampling of ``x`` only.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    ------
    model : dict
        Dictionary containing the hyperparameters and
        updated seed, states, and parameters of the model.
    """
    model = arhmm.resample_model(
        data, seed, states, params, hypparams, states_only, verbose=verbose
    )
    if ar_only:
        model["noise_prior"] = noise_prior
        return model

    # Work on shallow copies so the dicts returned by the ARHMM step are
    # not mutated in place.
    params = model["params"].copy()
    states = model["states"].copy()

    # Reusing one key for several draws correlates them; derive a separate
    # subkey per Gibbs step and keep the first split as the seed handed
    # back to the caller.
    seed, seed_sigmasq, seed_x, seed_h, seed_v, seed_s = jr.split(
        model["seed"], 6
    )

    if (not states_only) and resample_global_noise_scale:
        if verbose:
            print("Resampling sigmasq (global noise scales)")
        params["sigmasq"] = resample_obs_variance(
            seed_sigmasq,
            **data,
            **states,
            **params,
            s_0=noise_prior,
            **hypparams["obs_hypparams"]
        )

    if verbose:
        print("Resampling x (continuous latent states)")
    states["x"] = resample_continuous_stateseqs(
        seed_x,
        **data,
        **states,
        **params,
        jitter=jitter,
        parallel_message_passing=parallel_message_passing
    )

    if not fix_heading:
        if verbose:
            print("Resampling h (heading)")
        states["h"] = resample_heading(seed_h, **data, **states, **params)

    if verbose:
        print("Resampling v (location)")
    # NOTE(review): ``parallel_message_passing`` is not forwarded here, so
    # ``resample_location`` runs with its own default -- confirm intended.
    states["v"] = resample_location(
        seed_v, **data, **states, **params, **hypparams["cen_hypparams"]
    )

    if resample_local_noise_scale:
        if verbose:
            print("Resampling s (local noise scales)")
        states["s"] = resample_scales(
            seed_s,
            **data,
            **states,
            **params,
            s_0=noise_prior,
            **hypparams["obs_hypparams"]
        )

    return {
        "seed": seed,
        "states": states,
        "params": params,
        "hypparams": hypparams,
        "noise_prior": noise_prior,
    }