import jax
import jax.numpy as jnp
import jax.random as jr
from jax_moseq.utils import device_put_as_scalar, check_precision
from jax_moseq.utils.transitions import init_hdp_transitions
from jax_moseq.utils.distributions import sample_mniw
from jax_moseq.models.arhmm.gibbs import resample_discrete_stateseqs
# Short alias for ``jnp.newaxis`` (which is ``None``). In this module it is
# used inside ``in_axes`` for ``jax.vmap`` to mark arguments that are not
# mapped over.
na = jnp.newaxis
[docs]
def init_ar_params(seed, *, num_states, nu_0, S_0, M_0, K_0, **kwargs):
    """
    Draw initial autoregressive parameters for every state from a
    matrix normal inverse-Wishart (MNIW) prior.

    One independent draw is made per state; the prior hyperparameters
    are shared across states. Here ``ar_dim = latent_dim * num_lags + 1``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed. It is split into ``num_states`` keys, one per draw.
    num_states : int
        Max number of HMM states.
    nu_0 : int
        Inverse-Wishart degrees of freedom parameter for Q.
    S_0 : jax array of shape (latent_dim, latent_dim)
        Inverse-Wishart scale parameter for Q.
    M_0 : jax array of shape (latent_dim, ar_dim)
        Matrix normal expectation for Ab.
    K_0 : jax array of shape (ar_dim, ar_dim)
        Matrix normal column scale parameter for Ab.
    **kwargs : dict
        Overflow, for convenience (ignored).

    Returns
    -------
    Ab : jax array of shape (num_states, latent_dim, ar_dim)
        Autoregressive transforms.
    Q : jax array of shape (num_states, latent_dim, latent_dim)
        Autoregressive noise covariances.
    """
    state_seeds = jr.split(seed, num_states)

    # Only the per-state keys are mapped over (axis 0); the prior
    # hyperparameters are passed unchanged to every draw.
    shared_hypparams = (nu_0, S_0, M_0, K_0)
    sample_per_state = jax.vmap(
        sample_mniw, in_axes=(0,) + (na,) * len(shared_hypparams)
    )

    Ab, Q = sample_per_state(state_seeds, *shared_hypparams)
    return Ab, Q
[docs]
def init_states(seed, x, mask, params, **kwargs):
    """
    Build the initial ARHMM latent states by resampling the discrete
    state sequences given the data and the current parameters.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    params : dict
        Values for each model parameter; unpacked as keyword arguments
        to ``resample_discrete_stateseqs``.
    **kwargs : dict
        Overflow, for convenience (ignored).

    Returns
    -------
    states : dict
        Dictionary with a single entry ``"z"`` holding the sampled
        discrete state sequences.
    """
    return {"z": resample_discrete_stateseqs(seed, x, mask, **params)}
[docs]
def init_params(seed, trans_hypparams, ar_hypparams, **kwargs):
    """
    Initialize the parameters of the ARHMM from the hyperparameters.

    The transition parameters and the autoregressive parameters are
    sampled with separate subkeys derived from ``seed``.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed. Split internally into one key per sampler.
    trans_hypparams : dict
        HDP transition hyperparameters; unpacked as keyword arguments
        to ``init_hdp_transitions``.
    ar_hypparams : dict
        Autoregression hyperparameters; unpacked as keyword arguments
        to ``init_ar_params``.
    **kwargs : dict
        Overflow, for convenience (ignored).

    Returns
    -------
    params : dict
        Values for each model parameter, with keys ``"betas"``,
        ``"pi"``, ``"Ab"`` and ``"Q"``.
    """
    # A JAX PRNG key should be consumed only once: handing the same key
    # to two samplers can feed them the same random bits, so derive an
    # independent subkey for each.
    trans_seed, ar_seed = jr.split(seed)

    betas, pi = init_hdp_transitions(trans_seed, **trans_hypparams)
    Ab, Q = init_ar_params(ar_seed, **ar_hypparams)
    return {"betas": betas, "pi": pi, "Ab": Ab, "Q": Q}
[docs]
def init_hyperparams(trans_hypparams, ar_hypparams, **kwargs):
    """
    Assemble the hyperparameter dictionary of the ARHMM.

    Both input dictionaries are copied (the caller's objects are not
    modified), and the autoregression copy is extended with the derived
    MNIW prior quantities.

    Parameters
    ----------
    trans_hypparams : dict
        HDP transition hyperparameters. Must contain ``"num_states"``.
    ar_hypparams : dict
        Autoregression hyperparameters. Must contain ``"latent_dim"``,
        ``"nlags"``, ``"S_0_scale"`` and ``"K_0_scale"``.
    **kwargs : dict, optional
        Overflow, for convenience (ignored).

    Returns
    -------
    hypparams : dict
        Dictionary with keys ``"trans_hypparams"`` and ``"ar_hypparams"``.
        The latter additionally holds ``"S_0"``, ``"K_0"``, ``"M_0"``,
        ``"num_states"`` and ``"nu_0"``.
    """
    trans_out = trans_hypparams.copy()
    ar_out = ar_hypparams.copy()

    latent_dim = ar_out["latent_dim"]
    num_lags = ar_out["nlags"]
    ar_dim = latent_dim * num_lags + 1

    # M_0 is an identity block preceded by zero columns for the earlier
    # lags and followed by one zero column, giving (latent_dim, ar_dim).
    column_padding = ((num_lags - 1) * latent_dim, 1)
    identity = jnp.eye(latent_dim)

    ar_out.update(
        {
            "S_0": ar_out["S_0_scale"] * identity,
            "K_0": ar_out["K_0_scale"] * jnp.eye(ar_dim),
            "M_0": jnp.pad(jnp.eye(latent_dim), ((0, 0), column_padding)),
            "num_states": trans_out["num_states"],
            "nu_0": latent_dim + 2,
        }
    )
    return {"trans_hypparams": trans_out, "ar_hypparams": ar_out}
[docs]
def init_model(
    data=None,
    states=None,
    params=None,
    hypparams=None,
    seed=None,
    trans_hypparams=None,
    ar_hypparams=None,
    verbose=False,
    **kwargs
):
    """
    Initialize an ARHMM model dict containing the hyperparameters
    and initial seed, states, and parameters.

    Parameters
    ----------
    data : dict, optional
        Data dictionary containing the observations (``"x"``) and
        mask (``"mask"``). Must be provided if ``states`` not precomputed.
    states : dict, optional
        State values for each latent variable, if precomputed.
    params : dict, optional
        Values for each model parameter, if precomputed.
    hypparams : dict, optional
        Values for each group of hyperparameters. If not provided,
        caller must provide each arg of ``init_hyperparams``.
    seed : int or jr.PRNGKey, optional
        Initial JAX random seed. An ``int`` is converted with
        ``jr.PRNGKey``. If omitted (``None``), ``jr.PRNGKey(0)`` is used.
    trans_hypparams : dict, optional
        HDP transition hyperparameters. Must be provided if
        ``hypparams`` not provided.
    ar_hypparams : dict, optional
        Autoregression hyperparameters. Must be provided if
        ``hypparams`` not provided.
    verbose : bool, default=False
        Whether to print progress info during initialization.
    **kwargs : dict, optional
        Overflow, for convenience (ignored).

    Returns
    -------
    model : dict
        Dictionary with keys ``"seed"``, ``"hypparams"``, ``"params"``
        and ``"states"``.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller
        is insufficient for model initialization.
    """
    _check_init_args(
        data, states, params, hypparams, trans_hypparams, ar_hypparams
    )

    model = {}

    # The data are only needed when the states must be sampled here.
    if states is None:
        x, mask = data["x"], data["mask"]

    # The default key is built at call time, not in the signature: a
    # default of ``jr.PRNGKey(0)`` would run JAX while this module is
    # being imported, before the caller has had a chance to configure JAX.
    if seed is None:
        seed = jr.PRNGKey(0)
    elif isinstance(seed, int):
        seed = jr.PRNGKey(seed)
    model["seed"] = seed

    if hypparams is None:
        if verbose:
            print("ARHMM: Initializing hyperparameters")
        hypparams = init_hyperparams(trans_hypparams, ar_hypparams)
    else:
        hypparams = device_put_as_scalar(hypparams)
    model["hypparams"] = hypparams

    if params is None:
        if verbose:
            print("ARHMM: Initializing parameters")
        params = init_params(seed, **hypparams)
    else:
        params = jax.device_put(params)
    model["params"] = params

    if states is None:
        if verbose:
            print("ARHMM: Initializing states")
        states = init_states(seed, x, mask, params)
    else:
        states = jax.device_put(states)
    model["states"] = states

    return model
@check_precision
def _check_init_args(
    data, states, params, hypparams, trans_hypparams, ar_hypparams
):
    """
    Validate that ``init_model`` received enough arguments to build
    a model.

    Two conditions are enforced, using the truthiness of each argument:
    at least one of ``data`` / ``states`` is given, and either
    ``hypparams`` or both of ``trans_hypparams`` and ``ar_hypparams``
    are given. ``params`` is accepted but not inspected here.

    Parameters
    ----------
    data : dict or None
        Data dictionary containing the observations, mask,
        and (optionally) confidences.
    states : dict or None
        State values for each latent variable.
    params : dict or None
        Values for each model parameter.
    hypparams : dict or None
        Values for each group of hyperparameters.
    trans_hypparams : dict or None
        HDP transition hyperparameters.
    ar_hypparams : dict or None
        Autoregression hyperparameters.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller
        is insufficient for model initialization.
    """
    if not data and not states:
        raise ValueError("Must provide either `data` or `states`.")

    if not hypparams and (not trans_hypparams or not ar_hypparams):
        message = (
            "Must provide either `hypparams` or "
            "both `trans_hypparams` and `ar_hypparams`."
        )
        raise ValueError(message)