# Source code for jax_moseq.models.keypoint_slds.log_prob

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

from jax_moseq.models import arhmm, slds
from jax_moseq.models.keypoint_slds.alignment import estimate_coordinates

# Short alias for `jnp.newaxis`, used when indexing to append a singleton
# axis (e.g. `mask[..., na]`) so that arrays broadcast against each other.
na = jnp.newaxis


def location_log_prob(v, sigmasq_loc):
    """
    Score each centroid step under the random-walk prior on centroid
    movement.

    Every position `v[..., t + 1, :]` is evaluated under a diagonal
    Gaussian centered on the preceding position `v[..., t, :]`, with
    standard deviation `sqrt(sigmasq_loc)` along every coordinate.

    Parameters
    ----------
    v : jax array of shape (..., T, d)
        Centroid positions.
    sigmasq_loc : float
        Assumed variance in centroid displacements.

    Returns
    -------
    log_pv : jax array of shape (..., T - 1)
        Log probability of each transition between consecutive
        centroid positions in `v`.
    """
    prev_pos = v[..., :-1, :]
    next_pos = v[..., 1:, :]

    # Same standard deviation for every coordinate, expanded to the full
    # shape of the means so it can serve as the diagonal scale.
    step_std = jnp.sqrt(sigmasq_loc) * jnp.ones_like(prev_pos)

    step_dist = tfd.MultivariateNormalDiag(loc=prev_pos, scale_diag=step_std)
    return step_dist.log_prob(next_pos)
def obs_log_prob(Y, x, v, h, s, Cd, sigmasq, **kwargs):
    """
    Score the observed keypoints against the coordinates predicted from
    the latent trajectories, centroids and heading angles.

    The predicted coordinates come from `estimate_coordinates(x, v, h, Cd)`.
    Each keypoint is then evaluated under a diagonal Gaussian around its
    prediction, whose variance is the product of the noise scale `s` and
    the unscaled noise `sigmasq`, shared across that keypoint's
    coordinates.

    Parameters
    ----------
    Y : jax array of shape (..., k, d)
        Keypoint observations.
    x : jax array of shape (..., latent_dim)
        Latent trajectories.
    v : jax array of shape (..., d)
        Centroid positions.
    h : jax array
        Heading angles.
    s : jax array of shape (..., k)
        Noise scales.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax array of shape k
        Unscaled noise.
    **kwargs : dict
        Overflow, for convenience; ignored.

    Returns
    -------
    log_pY : jax array of shape (..., k)
        Log probability of `Y`, one value per keypoint.
    """
    expected_Y = estimate_coordinates(x, v, h, Cd)

    # Per-keypoint standard deviation; the trailing singleton axis lets it
    # be repeated across the coordinate axis of `Y`.
    keypoint_std = jnp.sqrt(s * sigmasq)
    scale_diag = jnp.broadcast_to(keypoint_std[..., na], Y.shape)

    obs_dist = tfd.MultivariateNormalDiag(loc=expected_Y, scale_diag=scale_diag)
    return obs_dist.log_prob(Y)
@jax.jit
def log_joint_likelihood(
    Y,
    mask,
    x,
    v,
    h,
    s,
    z,
    pi,
    Ab,
    Q,
    Cd,
    sigmasq,
    sigmasq_loc,
    s_0,
    nu_s,
    **kwargs
):
    """
    Compute the total, mask-weighted log probability of each state
    variable in the keypoint SLDS.

    The dictionary returned by `arhmm.log_joint_likelihood` is extended
    with three entries computed here: ``"Y"`` (keypoint observations),
    ``"s"`` (noise scales) and ``"v"`` (centroid locations).

    Parameters
    ----------
    Y : jax array of shape (..., T, k, d)
        Keypoint observations.
    mask : jax array of shape (..., T)
        Binary indicator for valid frames.
    x : jax array of shape (..., T, latent_dim)
        Latent trajectories.
    v : jax array of shape (..., T, d)
        Centroid positions.
    h : jax array of shape (..., T)
        Heading angles.
    s : jax array of shape (..., T, k)
        Noise scales.
    z : jax array of shape (..., T - n_lags)
        Discrete state sequences.
    pi : jax array of shape (num_states, num_states)
        Transition probabilities.
    Ab : jax array of shape (num_states, latent_dim, ar_dim)
        Autoregressive transforms.
    Q : jax array of shape (num_states, latent_dim, latent_dim)
        Autoregressive noise covariances.
    Cd : jax array of shape ((k - 1) * d, latent_dim + 1)
        Observation transform.
    sigmasq : jax array of shape k
        Unscaled noise.
    sigmasq_loc : float
        Assumed variance in centroid displacements.
    s_0 : scalar or jax array
        Prior on noise scale (presumably broadcastable to `s`; it is
        forwarded to `slds.scale_log_prob` together with `s`).
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    **kwargs : dict
        Overflow, for convenience; ignored.

    Returns
    -------
    ll : dict
        Dictionary mapping the name of each state variable to its total
        log probability.
    """
    # Start from the terms supplied by the autoregressive HMM.
    totals = arhmm.log_joint_likelihood(x, mask, z, pi, Ab, Q)

    # Frame mask with a trailing singleton axis so that it broadcasts over
    # the per-keypoint axis of the observation and noise-scale terms.
    frame_mask = mask[..., na]

    obs_terms = obs_log_prob(Y, x, v, h, s, Cd, sigmasq)
    scale_terms = slds.scale_log_prob(s, s_0, nu_s)
    loc_terms = location_log_prob(v, sigmasq_loc)

    totals["Y"] = jnp.sum(obs_terms * frame_mask)
    totals["s"] = jnp.sum(scale_terms * frame_mask)
    # Location terms describe transitions between consecutive frames, so
    # there is one fewer of them; each is weighted by the mask of the
    # frame it arrives at.
    totals["v"] = jnp.sum(loc_terms * mask[..., 1:])
    return totals
def model_likelihood(data, states, params, hypparams, noise_prior, **kwargs):
    """
    Convenience function that invokes `log_joint_likelihood`.

    The entries of `data`, `states`, `params` and the ``"obs_hypparams"``
    and ``"cen_hypparams"`` groups of `hypparams` are unpacked as keyword
    arguments, and `noise_prior` is passed as `s_0`.

    Parameters
    ----------
    data : dict
        Data dictionary containing the observations and mask.
    states : dict
        State values for each latent variable.
    params : dict
        Values for each model parameter.
    hypparams : dict
        Values for each group of hyperparameters. Only the
        ``"obs_hypparams"`` and ``"cen_hypparams"`` groups are read.
    noise_prior : scalar or jax array broadcastable to `s`
        Prior on noise scale.
    **kwargs : dict
        Overflow, for convenience; ignored.

    Returns
    -------
    ll : dict
        Dictionary mapping state variable name to its total log
        probability.
    """
    obs_hypparams = hypparams["obs_hypparams"]
    cen_hypparams = hypparams["cen_hypparams"]

    # The dictionaries are unpacked separately rather than merged first, so
    # a key that appears in more than one of them is reported as an error
    # instead of being silently overridden.
    return log_joint_likelihood(
        **data,
        **states,
        **params,
        **obs_hypparams,
        **cen_hypparams,
        s_0=noise_prior,
    )