Source code for jax_moseq.utils.kalman

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from functools import partial

from dynamax.linear_gaussian_ssm.parallel_inference import (
    lgssm_posterior_sample as parallel_lgssm_sample,
)
from dynamax.linear_gaussian_ssm.inference import (
    lgssm_posterior_sample as serial_lgssm_sample,
    _condition_on,
    ParamsLGSSM,
    ParamsLGSSMInitial,
    ParamsLGSSMDynamics,
    ParamsLGSSMEmissions,
)

from jax_moseq.utils.autoregression import get_nlags

na = jnp.newaxis


[docs] def kalman_sample( seed, ys, mask, zs, m0, S0, A, B, Q, C, D, Rs, masked_dynamics_params, masked_obs_noise, jitter=0, parallel=True, ): """Run forward-filtering and backward-sampling to draw samples from posterior of a 1st-order dynamic system. Parameters ---------- seed: jr.PRNGKey. ys: jax.Array with shape (T, obs_dim) Continuous observation sequence mask: jax.Array with shape (T,) Indicator of observation validity zs: jax.Array with shape (T-1,) Discrete state sequence, taking integer values [0, n_states). mu0: jax.Array with shape (ar_dim,) Initial continuous state mean S0: jax.Array with shape (ar_dim, ar_dim) Initial continuous state covariance A: jax.Array with shape (n_states, ar_dim, ar_dim) State dynamics matrix B: jax.Array with shape (n_states, ar_dim) State input matrix Q: jax.Array with shape (n_states, ar_dim, ar_dim) State noise matrix C: jax.Array with shape (obs_dim, ar_dim) Observation transform matrix D: jax.Array with shape (obs_dim,) Observation input matrix Rs: jax.Array with shape (T, obs_dim) Observation noise scales (diagonal entries of covariance) masked_dynamics_params: dict with key-value pairs - weights: jax.Array with shape (ar_dim, ar_dim) - bias: jax.Array with shape (ar_dim,) - cov: jax.Array with shape (ar_dim, ar_dim) Dynamics parameters, for masked timesteps masked_obs_noise: jax.Array with shape (obs_dim,) Diagonal observation noise scale, for masked timesteps. jitter : float, default=0 Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states. parallel : bool, default=True, Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time. Returns ------- xs: jax.Array with shape (T, ar_dim) Sampled continuous state sequence. """ n_states, _, ar_dim, obs_dim = *A.shape, Rs.shape[-1] initial_params = ParamsLGSSMInitial(mean=m0, cov=S0) A_and_mask = jnp.concatenate((A, masked_dynamics_params["weights"][na])) B_and_mask = jnp.concatenate((B, masked_dynamics_params["bias"][na])) Q_and_mask = jnp.concatenate((Q, masked_dynamics_params["cov"][na])) zs_masked = jnp.where(mask[:-1], zs, n_states).astype(int) dynamics_params = ParamsLGSSMDynamics( weights=lambda t: A_and_mask[zs_masked[t]], bias=lambda t: B_and_mask[zs_masked[t]], cov=lambda t: Q_and_mask[zs_masked[t]], input_weights=jnp.zeros((ar_dim, 0)), ) emissions_params = ParamsLGSSMEmissions( weights=C, bias=D, input_weights=jnp.zeros((obs_dim, 0)), cov=jnp.where(mask[:, None], Rs, masked_obs_noise), ) params = ParamsLGSSM( initial=initial_params, dynamics=dynamics_params, emissions=emissions_params, ) if parallel: return parallel_lgssm_sample(seed, params, ys) else: return serial_lgssm_sample(seed, params, ys, jitter=jitter)
[docs] def ar_to_lds_emissions(Cd, R, y, m0, S0, nlags): """ Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the emission terms and initia state distribution for a system with 1st-order dynamics in R^(D*L) Parameters ---------- Cd: jax array, shape (D_obs, D+1) Observation affine transformation R: jax array, shape (T, D_obs) Dimension-wise observation covariances y: jax array, shape (T, D_obs) Observations m0: jax array, shape (D) Initial state distribution mean S0: jax array, shape (D, D) Initial state distribution cov nlags: Number of autoregressive lags Returns ------- C_: jax array, shape (D_obs, D*L) d_: jax array, shape (D_obs) R_: jax array, shape (T, D_obs) y_: jax array, shape (T, D_obs) m0_: jax array, shape (D*L) S0_: jax array, shape (D*L, D*L) """ obs_dim = y.shape[-1] latent_dim = Cd.shape[-1] - 1 lds_dim = latent_dim * nlags C = Cd[:, :-1] C_ = jnp.zeros((obs_dim, lds_dim)) C_ = C_.at[:, -latent_dim:].set(C) d_ = Cd[:, -1] R_ = R[nlags - 1 :] y_ = y[nlags - 1 :] C0 = jnp.zeros((obs_dim * (nlags - 1), lds_dim)) for l in range(nlags - 1): C0 = C0.at[ l * obs_dim : (l + 1) * obs_dim, l * latent_dim : (l + 1) * latent_dim ].set(C) d0 = jnp.tile(d_, (nlags - 1,)) D0 = jnp.zeros((obs_dim * (nlags - 1), 2)) u0 = jnp.zeros( 2, ) R0 = jnp.diag(R[: nlags - 1, :].reshape(-1)) y0 = y[: nlags - 1, :].reshape(-1) m0_, S0_ = _condition_on(m0, S0, C0, D0, d0, R0, u0, y0) return C_, d_, R_, y_, m0_, S0_
[docs] def ar_to_lds_dynamics(Ab, Q): """ Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the dynamics terms for a system with 1st-order dynamics in R^(D*L) Parameters ---------- Ab: jax array, shape (..., D, D*L + 1) AR affine transform Q: jax array, shape (..., D, D) AR noise covariance Returns ------- A_: jax array, shape (..., D*L, D*L) b_: jax array, shape (..., D*L) Q_: jax array, shape (..., D*L, D*L) """ nlags = get_nlags(Ab) latent_dim = Ab.shape[-2] lds_dim = latent_dim * nlags eye = jnp.eye(latent_dim * (nlags - 1)) A = Ab[..., :-1] dims = A.shape[:-2] A_ = jnp.zeros((*dims, lds_dim, lds_dim)) A_ = A_.at[..., :-latent_dim, latent_dim:].set(eye) A_ = A_.at[..., -latent_dim:, :].set(A) b = Ab[..., -1] b_ = jnp.zeros((*dims, lds_dim)) b_ = b_.at[..., -latent_dim:].set(b) dims = Q.shape[:-2] Q_ = jnp.zeros((*dims, lds_dim, lds_dim)) Q_ = Q_.at[..., :-latent_dim, :-latent_dim].set(eye * 1e-2) Q_ = Q_.at[..., -latent_dim:, -latent_dim:].set(Q) return A_, b_, Q_