# Source code for jax_moseq.models.slds.log_prob

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

from jax_moseq.utils import apply_affine

from jax_moseq.models import arhmm

# Short alias for `jnp.newaxis`, used to append a broadcast axis when indexing
# (e.g. `mask[..., na]`).
na = jnp.newaxis


def scale_log_prob(s, s_0, nu_s, **kwargs):
    """
    Calculate the unnormalized log density of the noise scale `s`
    for each datapoint under the noise prior, which is a scaled
    inverse chi-squared distribution.

    The value is computed elementwise as
    ``-nu_s * s_0 / (2 * s) - (1 + nu_s / 2) * log(s)``; no
    normalizing constant is added, so the result is the log density
    only up to an additive term that does not depend on `s`.

    Parameters
    ----------
    s : jax array
        Noise scales.
    s_0 : scalar or jax array, broadcastable to `s`
        Prior on noise scale.
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    log_ps: jax array
        Unnormalized log probability of `s`.
    """
    return -nu_s * s_0 / s / 2 - (1 + nu_s / 2) * jnp.log(s)
def obs_log_prob(Y, x, s, Cd, sigmasq, **kwargs):
    """
    Calculate the log probability of the observations at each
    time-step given the latent trajectories, noise parameters,
    and observation matrix.

    The observations are modeled as a diagonal Gaussian whose mean is
    the affine transform of `x` by `Cd` and whose per-dimension
    variance is `s * sigmasq`.

    Parameters
    ----------
    Y : jax array of shape (..., obs_dim)
        Observations.
    x : jax array of shape (..., latent_dim)
        Latent trajectories.
    s : jax array of shape (..., obs_dim)
        Noise scales.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    log_pY: jax array of shape (...)
        Log probability of `Y`, with the observation dimension
        reduced by the multivariate distribution.
    """
    Y_bar = apply_affine(x, Cd)
    # `MultivariateNormalDiag` is parameterized by per-dimension standard
    # deviations (`scale_diag`), hence the square root of the variance.
    scale = jnp.sqrt(s * sigmasq)
    return tfd.MultivariateNormalDiag(loc=Y_bar, scale_diag=scale).log_prob(Y)
@jax.jit
def log_joint_likelihood(
    Y, mask, x, s, z, pi, Ab, Q, Cd, sigmasq, s_0, nu_s, **kwargs
):
    """
    Calculate the total log probability of the observations and of
    each latent state variable.

    The dictionary returned by `arhmm.log_joint_likelihood` is
    extended with the masked, summed log probabilities of the
    observations (key ``"Y"``) and of the noise scales (key ``"s"``).

    Parameters
    ----------
    Y : jax array of shape (..., T, obs_dim)
        Observations.
    mask : jax array of shape (..., T)
        Binary indicator for valid frames.
    x : jax array of shape (..., T, latent_dim)
        Latent trajectories.
    s : jax array of shape (..., T, obs_dim)
        Noise scales.
    z : jax_array of shape (..., T - n_lags)
        Discrete state sequences.
    pi : jax_array of shape (num_states, num_states)
        Transition probabilities.
    Ab : jax array of shape (num_states, latent_dim, ar_dim)
        Autoregressive transforms.
    Q : jax array of shape (num_states, latent_dim, latent_dim)
        Autoregressive noise covariances.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform.
    sigmasq : jax_array of shape obs_dim
        Unscaled noise.
    s_0 : scalar or jax array broadcastable to `Y`
        Prior on noise scale.
    nu_s : int
        Chi-squared degrees of freedom in noise prior.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    ll: dict
        Dictionary mapping the name of each state variable to its
        total log probability.
    """
    ll = arhmm.log_joint_likelihood(x, mask, z, pi, Ab, Q)

    log_pY = obs_log_prob(Y, x, s, Cd, sigmasq)
    log_ps = scale_log_prob(s, s_0, nu_s)

    # Multiplying by the mask zeroes out invalid frames before summing.
    # `log_ps` keeps the trailing obs_dim axis, so the mask needs an extra
    # axis to broadcast against it.
    ll["Y"] = (log_pY * mask).sum()
    ll["s"] = (log_ps * mask[..., na]).sum()
    return ll
def model_likelihood(data, states, params, hypparams, **kwargs):
    """
    Convenience function that invokes `log_joint_likelihood`.

    The entries of `data`, `states`, `params`, and
    ``hypparams["obs_hypparams"]`` are unpacked and passed as keyword
    arguments.

    Parameters
    ----------
    data : dict
        Data dictionary containing the observations and mask.
    states : dict
        State values for each latent variable.
    params : dict
        Values for each model parameter.
    hypparams : dict
        Values for each group of hyperparameters. Only the
        ``"obs_hypparams"`` group is forwarded.
    **kwargs : dict
        Overflow, for convenience. Ignored.

    Returns
    -------
    ll : dict
        Dictionary mapping state variable name to its
        total log probability.
    """
    return log_joint_likelihood(
        **data, **states, **params, **hypparams["obs_hypparams"]
    )