"""Random sampling utilities (module ``jax_moseq.utils.distributions``)."""

import jax
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.distributions as tfd
from dynamax.hidden_markov_model.inference import hmm_posterior_sample
from jax_moseq.utils import convert_data_precision

# Short alias for ``jnp.newaxis``: inserts a length-1 axis when used in an
# index expression, e.g. ``x[..., na]`` to enable broadcasting.
na = jnp.newaxis


def sample_vonmises(seed, theta, kappa):
    """Draw a sample from a von Mises distribution.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    theta: float or jax array
        Passed as the ``loc`` (mean angle) of ``tfd.VonMises``.
    kappa: float or jax array
        Passed as the ``concentration`` of ``tfd.VonMises``.

    Returns
    -------
    sample: jax array
        Draw(s) with the broadcast shape of ``theta`` and ``kappa``.
    """
    distribution = tfd.VonMises(theta, kappa)
    return distribution.sample(seed=seed)


def sample_vonmises_fisher(seed, direction):
    """Draw a von Mises-Fisher sample from an unnormalized direction.

    The Euclidean norm of ``direction`` over its last axis is used as the
    concentration, and the unit-normalized vector as the mean direction.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    direction: jax array, shape (..., dim)
        Unnormalized mean direction; its length sets the concentration.

    Returns
    -------
    sample: jax array, shape (..., dim)
        Unit-vector sample(s).
    """
    # NOTE(review): an all-zero ``direction`` gives a zero norm and therefore
    # a division by zero (NaN mean direction) below.
    concentration = jnp.sqrt(jnp.sum(direction**2, axis=-1))
    mean_direction = direction / jnp.expand_dims(concentration, -1)
    return tfd.VonMisesFisher(mean_direction, concentration).sample(seed=seed)


def sample_gamma(seed, a, b):
    """Sample from a Gamma distribution with shape ``a`` and rate ``b``.

    ``jr.gamma`` draws from Gamma(a, 1). Dividing by ``b`` rescales the draw
    to rate ``b`` (equivalently, scale ``1 / b``).

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    a: float or jax array
        Shape parameter; the output takes the shape of ``a``.
    b: float or jax array
        Rate parameter (broadcast against the draw).

    Returns
    -------
    sample: jax array
    """
    unit_rate_draw = jr.gamma(seed, a)
    return unit_rate_draw / b


def sample_inv_gamma(seed, a, b):
    """Sample from an inverse-Gamma distribution with shape ``a``, scale ``b``.

    Implemented as the reciprocal of a Gamma(``a``, rate=``b``) draw from
    ``sample_gamma``.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    a: float or jax array
        Shape parameter.
    b: float or jax array
        Scale parameter of the inverse-Gamma.

    Returns
    -------
    sample: jax array
    """
    gamma_draw = sample_gamma(seed, a, b)
    return 1 / gamma_draw


def sample_scaled_inv_chi2(seed, degs, variance):
    """Sample from a scaled inverse chi-squared distribution.

    Uses the identity
    Scaled-Inv-chi2(degs, variance) = Inv-Gamma(degs / 2, degs * variance / 2).

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    degs: float or jax array
        Degrees of freedom.
    variance: float or jax array
        Scale (the "s^2" parameter).

    Returns
    -------
    sample: jax array
    """
    shape = degs / 2
    scale = degs * variance / 2
    return sample_inv_gamma(seed, shape, scale)


def sample_chi2(seed, degs):
    """Sample from a chi-squared distribution with ``degs`` degrees of freedom.

    Uses the identity chi2(k) = 2 * Gamma(k / 2, 1).

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    degs: float or jax array
        Degrees of freedom; the output takes the shape of ``degs``.

    Returns
    -------
    sample: jax array
    """
    half_degs = degs / 2
    return 2 * jr.gamma(seed, half_degs)


def sample_mn(seed, M, U, V):
    """Sample from a matrix-normal distribution MN(M, U, V).

    A standard-normal matrix ``Z`` with the shape of ``M`` is drawn, and the
    result is ``M + chol(U) Z chol(V)^T``. Here ``U`` acts on rows (row
    covariance) and ``V`` on columns (column covariance).

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key.
    M: jax array, shape (n, p)
        Mean matrix.
    U: jax array, shape (n, n)
        Row covariance (must admit a Cholesky factorization).
    V: jax array, shape (p, p)
        Column covariance (must admit a Cholesky factorization).

    Returns
    -------
    sample: jax array, shape (n, p)
    """
    row_factor = jnp.linalg.cholesky(U)
    col_factor = jnp.linalg.cholesky(V)
    noise = jr.normal(seed, M.shape)
    return M + jnp.dot(jnp.dot(row_factor, noise), col_factor.T)


def sample_invwishart(seed, S, nu):
    """Sample from an inverse-Wishart distribution IW(S, nu).

    Construction (Bartlett decomposition):

    - Build an upper-triangular matrix with ``sqrt(chi2(nu - i))`` on the
      diagonal (i = 0..n-1) and standard normals above it.
    - Take its QR ``R`` factor.
    - Return ``T T^T`` with ``T = chol(S) R^{-1}``.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key (split internally).
    S: jax array, shape (n, n)
        Scale matrix (must admit a Cholesky factorization).
    nu: float
        Degrees of freedom. Presumably must exceed ``n - 1`` so that every
        chi-squared degree of freedom is positive.

    Returns
    -------
    sigma: jax array, shape (n, n)
    """
    dim = S.shape[0]
    key_diag, key_offdiag = jr.split(seed)

    diag_scales = jnp.sqrt(sample_chi2(key_diag, nu - jnp.arange(dim)))
    bartlett = jnp.diag(diag_scales)
    upper_rows, upper_cols = jnp.triu_indices_from(bartlett, 1)
    offdiag = jr.normal(key_offdiag, (dim * (dim - 1) // 2,))
    bartlett = bartlett.at[upper_rows, upper_cols].set(offdiag)
    r_factor = jnp.linalg.qr(bartlett, "r")

    scale_chol = jnp.linalg.cholesky(S)
    # Solving R^T X = chol(S)^T yields X = (chol(S) R^{-1})^T.
    t_factor = jax.scipy.linalg.solve_triangular(
        r_factor.T, scale_chol.T, lower=True
    ).T
    return jnp.dot(t_factor, t_factor.T)


def sample_niw(seed, mu, lam, nu, S):
    """Sample from a normal-inverse-Wishart distribution.

    Draws ``sigma ~ IW(S, nu)`` and then ``mu ~ N(mu, sigma / lam)``.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key (split internally into independent sub-keys).
    mu: jax array, shape (n,)
        Prior mean of the normal component.
    lam: float
        Scaling of the covariance for the mean draw.
    nu: float
        Degrees of freedom of the inverse-Wishart component.
    S: jax array, shape (n, n)
        Scale matrix of the inverse-Wishart component.

    Returns
    -------
    mu: jax array, shape (n,)
        Sampled mean.
    sigma: jax array, shape (n, n)
        Sampled covariance.
    """
    # Use separate keys for the covariance draw and the mean draw. Using
    # ``seed`` for both would reuse a PRNG key, which JAX forbids; it can
    # correlate the two supposedly independent random draws.
    sigma_seed, mu_seed = jr.split(seed)
    sigma = sample_invwishart(sigma_seed, S, nu)
    mu = jr.multivariate_normal(mu_seed, mu, sigma / lam)
    return mu, sigma


def sample_mniw(seed, nu, S, M, K):
    """Sample from a matrix-normal-inverse-Wishart distribution.

    Draws ``sigma ~ IW(S, nu)`` and then ``A ~ MN(M, sigma, K)``. Here
    ``sigma`` is the row covariance and ``K`` the column covariance.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random key (split internally into independent sub-keys).
    nu: float
        Degrees of freedom of the inverse-Wishart component.
    S: jax array, shape (n, n)
        Scale matrix of the inverse-Wishart component.
    M: jax array, shape (n, p)
        Mean matrix of the matrix-normal component.
    K: jax array, shape (p, p)
        Column covariance of the matrix-normal component.

    Returns
    -------
    A: jax array, shape (n, p)
        Sampled matrix.
    sigma: jax array, shape (n, n)
        Sampled covariance.
    """
    # Use separate keys for the covariance draw and the matrix draw. Passing
    # ``seed`` to both would reuse a PRNG key, which JAX forbids.
    sigma_seed, A_seed = jr.split(seed)
    sigma = sample_invwishart(sigma_seed, S, nu)
    A = sample_mn(A_seed, M, sigma, K)
    return A, sigma


def sample_hmm_stateseq(seed, transition_matrix, log_likelihoods, mask):
    """Sample state sequences in a Markov chain.

    Parameters
    ----------
    seed: jax.random.PRNGKey
        Random seed
    transition_matrix: jax array, shape (num_states, num_states)
        Transition matrix
    log_likelihoods: jax array, shape (num_timesteps, num_states)
        Sequence of log likelihoods of emissions given hidden state and
        parameters
    mask: jax array, shape (num_timesteps,)
        Sequence indicating whether to use an emission (1) or not (0)

    Returns
    -------
    log_norm: float
        Posterior marginal log likelihood
    states: jax array, shape (num_timesteps,)
        Sequence of sampled states
    """
    num_states = transition_matrix.shape[0]
    # Uniform distribution over the initial state.
    initial_distribution = jnp.ones(num_states) / num_states

    # Masked-out timesteps get a log likelihood of 0, so they carry no
    # evidence about the hidden state. Select with ``where`` instead of
    # relying on multiplication alone: ``-inf * 0`` and ``nan * 0`` are NaN.
    # A NaN would poison the whole posterior even though those frames are
    # meant to be ignored. Unmasked entries keep the original product.
    mask_2d = mask[:, None]
    masked_log_likelihoods = jnp.where(
        mask_2d != 0, log_likelihoods * mask_2d, 0.0
    )

    log_norm, states = hmm_posterior_sample(
        seed, initial_distribution, transition_matrix, masked_log_likelihoods
    )
    states = convert_data_precision(states)
    return log_norm, states