import jax
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.distributions as tfd
from jax_moseq.utils import apply_affine
from jax_moseq.utils.distributions import sample_mniw
# Module-level shorthand for `jnp.newaxis`.
na = jnp.newaxis
def apply_ar_params(x, Ab):
    """
    Apply autoregressive (AR) parameters to a time series.

    The number of lags is derived from the shape of ``Ab`` (see
    ``get_nlags``), the lagged inputs are built with ``get_lags``, and the
    result of ``apply_affine`` on those inputs and ``Ab`` is returned.

    Parameters
    ----------
    x: jax array, shape (..., t, d)
        Batch of d-dimensional time series.
    Ab: jax array
        AR parameters. Presumably an affine map of shape
        (..., d, d*nlags + 1) -- TODO confirm against ``apply_affine``.

    Returns
    -------
    Output of ``apply_affine`` on the lagged inputs; presumably one row per
    time step from ``nlags`` onward -- TODO confirm.
    """
    lagged_inputs = get_lags(x, get_nlags(Ab))
    return apply_affine(lagged_inputs, Ab)
def ar_log_likelihood(x, params):
    """
    Evaluate the log-probability of a time series under a Gaussian AR model.

    The mean for each scored time step comes from ``apply_ar_params``; the
    first ``nlags`` time steps of ``x`` are dropped before scoring so that
    the observations line up with those means.

    Parameters
    ----------
    x: jax array, shape (..., t, d)
        Batch of d-dimensional time series.
    params: tuple (Ab, Q)
        ``Ab`` are the AR parameters passed to ``apply_ar_params``; ``Q`` is
        passed as the covariance of the multivariate normal. Presumably
        ``Q`` has shape (..., d, d) -- TODO confirm.

    Returns
    -------
    Log-probabilities returned by the distribution's ``log_prob``;
    presumably shape (..., t-nlags) -- TODO confirm.
    """
    ar_weights, noise_cov = params
    n_skip = get_nlags(ar_weights)
    predicted_mean = apply_ar_params(x, ar_weights)
    # Keep only the time steps that have a full lag history.
    observed = x[..., n_skip:, :]
    likelihood = tfd.MultivariateNormalFullCovariance(predicted_mean, noise_cov)
    return likelihood.log_prob(observed)
# [docs]  (Sphinx source-link artifact; not part of the code)
def get_lags(x, nlags):
    """
    Build the lagged inputs of a multivariate time series.

    For each time step ``s >= nlags``, the ``nlags`` preceding observations
    ``x[s-nlags], ..., x[s-1]`` are concatenated along the last axis in
    time order (oldest first).

    Parameters
    ----------
    x: jax array, shape (..., t, d)
        Batch of d-dimensional time series
    nlags: int
        Number of lags

    Returns
    -------
    x_lagged: jax array, shape (..., t-nlags, d*nlags)
        Empty along the time axis when ``t <= nlags``.
    """
    n_steps = x.shape[-2]
    # Clamp at zero so a series shorter than `nlags` yields an empty result
    # rather than letting a negative stop index wrap around.
    n_out = max(n_steps - nlags, 0)
    # Window `start` holds lag `nlags - start`, so iterating upward puts the
    # oldest lag first.
    windows = [
        x[..., start : start + n_out, :]
        for start in range(nlags)
    ]
    return jnp.concatenate(windows, axis=-1)
def get_nlags(Ab):
    """
    Infer the number of lags from the shape of the AR parameters.

    Returns the size of the last axis of ``Ab`` floor-divided by the size of
    its second-to-last axis. Presumably ``Ab`` has shape
    (..., d, d*nlags + 1), so the trailing bias column is discarded by the
    floor division -- TODO confirm.

    Parameters
    ----------
    Ab: array with at least two dimensions

    Returns
    -------
    nlags: int
    """
    n_cols = Ab.shape[-1]
    n_rows = Ab.shape[-2]
    return n_cols // n_rows