Source code for jax_moseq.models.slds.initialize

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

from jax_moseq.utils import (
    jax_io,
    device_put_as_scalar,
    fit_pca,
    check_precision,
)

from jax_moseq.models import arhmm
from jax_moseq.models.slds.gibbs import resample_scales

# Short alias for ``jnp.newaxis``; used below to insert singleton axes when
# indexing (e.g. ``d[:, na]`` turns a vector into a column).
na = jnp.newaxis


def init_obs_params(pca, Y, mask, whiten, latent_dim, **kwargs):
    """
    Initialize the observation (also known as "emission") parameters
    of the model using PCA.

    Parameters
    ----------
    pca : sklearn.decomposition._pca.PCA
        PCA object fit to observations.
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    whiten : bool
        Whether to whiten PC's to initialize continuous latents.
    latent_dim : int
        Dimensionality of continuous latents.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    -------
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform. The first ``latent_dim`` columns are the
        (possibly whitened) loading matrix; the last column is the offset.
    """
    C = jnp.array(pca.components_[:latent_dim])
    d = jnp.array(pca.mean_)

    if whiten:
        Y_flat = Y[mask > 0]
        latents_flat = jax_io(pca.transform)(Y_flat)[:, :latent_dim]
        # ``jnp.cov`` squeezes its result, so a single latent dimension would
        # produce a 0-d array that ``cholesky`` rejects; keep it 2-D.
        cov = jnp.atleast_2d(jnp.cov(latents_flat.T))
        W = jnp.linalg.cholesky(cov)
        # With z = W x (x white), Y ~ C^T z + d = (W^T C)^T x + d, so the
        # whitened loading matrix is W^T C.
        C = W.T @ C

    Cd = jnp.hstack([C.T, d[:, na]])
    return Cd
def init_continuous_stateseqs(Y, Cd, **kwargs):
    """
    Initialize the continuous latents by inverting the emission transform
    (via a pseudo-inverse) and applying it to the data.

    Parameters
    ----------
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    Cd : jax array of shape (obs_dim, latent_dim + 1)
        Observation transform; last column is the offset.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    -------
    x : jax array of shape (N, T, latent_dim)
        Latent trajectories.
    """
    emission_matrix = Cd[:, :-1]
    offset = Cd[:, -1]
    # The pseudo-inverse is computed with NumPy, then moved back into JAX.
    pseudo_inverse = jnp.array(np.linalg.pinv(emission_matrix))
    centered = Y - offset
    return centered @ pseudo_inverse.T
def init_states(seed, Y, mask, params, obs_hypparams, **kwargs):
    """
    Initialize the latent states of the SLDS from the data, parameters,
    and hyperparameters.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    params : dict
        Values for each model parameter (must contain ``"Cd"``).
    obs_hypparams : dict
        Observation hyperparameters.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    -------
    states : dict
        State values for each latent variable; the ARHMM states plus
        ``"x"`` (continuous latents) and ``"s"`` (scales).
    """
    latents = init_continuous_stateseqs(Y, params["Cd"])
    # NOTE(review): the same key is passed to both samplers below; consider
    # whether ``jr.split`` was intended (changing it would alter outputs).
    scales = resample_scales(seed, Y, latents, **params, **obs_hypparams)

    # Discrete latents come from the ARHMM initializer run on the latents.
    states = arhmm.init_states(seed, latents, mask, params)
    states["x"], states["s"] = latents, scales
    return states
def init_params(
    seed, pca, Y, mask, trans_hypparams, ar_hypparams, whiten, **kwargs
):
    """
    Initialize the parameters of the SLDS from the data and hyperparameters.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    pca : sklearn.decomposition._pca.PCA
        PCA object fit to observations.
    Y : jax array of shape (N, T, obs_dim)
        Observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    trans_hypparams : dict
        HDP transition hyperparameters.
    ar_hypparams : dict
        Autoregression hyperparameters (must contain ``"latent_dim"``).
    whiten : bool
        Whether to whiten PC's to initialize continuous latents.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    -------
    params : dict
        Values for each model parameter: the ARHMM parameters plus
        ``"Cd"`` (observation transform) and ``"sigmasq"``.
    """
    # Transition and autoregressive parameters are delegated to the ARHMM.
    params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
    params["Cd"] = init_obs_params(
        pca, Y, mask, whiten, ar_hypparams["latent_dim"]
    )
    # NOTE(review): unit observation noise per dimension is a placeholder
    # (marked TODO in the original); confirm the intended initialization.
    obs_dim = Y.shape[-1]
    params["sigmasq"] = jnp.ones(obs_dim)
    return params
def init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams, **kwargs):
    """
    Format the hyperparameter dictionary of the SLDS.

    Parameters
    ----------
    trans_hypparams : dict
        HDP transition hyperparameters.
    ar_hypparams : dict
        Autoregression hyperparameters.
    obs_hypparams : dict
        Observation hyperparameters.
    **kwargs : dict
        Overflow, for convenience.

    Returns
    -------
    hypparams : dict
        The ARHMM hyperparameter groups plus a shallow copy of the
        observation hyperparameters under ``"obs_hypparams"``.
    """
    formatted = arhmm.init_hyperparams(trans_hypparams, ar_hypparams)
    # Copy so later edits to the caller's dict do not leak into the model.
    obs_copy = obs_hypparams.copy()
    formatted["obs_hypparams"] = obs_copy
    return formatted
def init_model(
    data=None,
    states=None,
    params=None,
    hypparams=None,
    seed=jr.PRNGKey(0),
    pca=None,
    whiten=True,
    PCA_fitting_num_frames=1000000,
    trans_hypparams=None,
    ar_hypparams=None,
    obs_hypparams=None,
    verbose=False,
    **kwargs
):
    """
    Initialize a SLDS model dict containing the hyperparameters, noise
    prior, and initial seed, states, and parameters.

    Parameters
    ----------
    data : dict, optional
        Data dictionary containing the observations, mask, and
        (optionally) confidences. Must be provided if `states` or
        `params` not precomputed.
    states : dict, optional
        State values for each latent variable, if precomputed.
    params : dict, optional
        Values for each model parameter, if precomputed.
    hypparams : dict, optional
        Values for each group of hyperparameters. If not provided,
        caller must provide each arg of `init_hyperparams`.
    seed : int or jr.PRNGKey, default=jr.PRNGKey(0)
        Initial random seed value. Python and NumPy integers are
        converted with ``jr.PRNGKey``.
    pca : sklearn.decomposition._pca.PCA, optional
        PCA object fit to observations.
    whiten : bool, default=True
        Whether to whiten PC's to initialize continuous latents.
    PCA_fitting_num_frames : int, default=1000000
        Maximum number of datapoints to sample for PCA fitting,
        if `pca` is not provided.
    trans_hypparams : dict, optional
        HDP transition hyperparameters. Must be provided if
        `hypparams` not provided.
    ar_hypparams : dict, optional
        Autoregression hyperparameters. Must be provided if
        `hypparams` not provided.
    obs_hypparams : dict, optional
        Observation hyperparameters. Must be provided if
        `hypparams` not provided.
    verbose : bool, default=False
        Whether to print progress info during initialization.
    **kwargs : dict, optional
        Unused. For convenience, enables user to invoke function
        by unpacking dict that contains keys not used by the method.

    Returns
    -------
    model : dict
        Dictionary containing the hyperparameters, noise prior, and
        initial seed, states, and parameters of the model.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller is
        insufficient for model initialization.
    """
    _check_init_args(
        data,
        states,
        params,
        hypparams,
        trans_hypparams,
        ar_hypparams,
        obs_hypparams,
    )

    model = {}

    # Observations are only needed when states or params must be inferred.
    if not (states and params):
        Y, mask = data["Y"], data["mask"]

    # Accept NumPy integer seeds (e.g. np.int64) as well as Python ints.
    if isinstance(seed, (int, np.integer)):
        seed = jr.PRNGKey(seed)
    model["seed"] = seed

    if hypparams is None:
        if verbose:
            print("SLDS: Initializing hyperparameters")
        hypparams = init_hyperparams(
            trans_hypparams, ar_hypparams, obs_hypparams
        )
    else:
        hypparams = device_put_as_scalar(hypparams)
    model["hypparams"] = hypparams

    if params is None:
        if verbose:
            print("SLDS: Initializing parameters")
        if pca is None:
            pca = fit_pca(Y, mask, PCA_fitting_num_frames, verbose)
        params = init_params(seed, pca, Y, mask, **hypparams, whiten=whiten)
    else:
        params = jax.device_put(params)
    model["params"] = params

    if states is None:
        if verbose:
            print("SLDS: Initializing states")
        states = init_states(seed, Y, mask, params, **hypparams)
    else:
        states = jax.device_put(states)
    model["states"] = states

    return model
@check_precision
def _check_init_args(
    data,
    states,
    params,
    hypparams,
    trans_hypparams,
    ar_hypparams,
    obs_hypparams,
):
    """
    Validate that `init_model` received enough arguments to build a model.

    Either `data` or both precomputed `states` and `params` must be given,
    and either `hypparams` or all three hyperparameter groups must be given.
    Arguments are tested by truthiness, so empty dicts count as missing.

    Parameters
    ----------
    data : dict or None
        Data dictionary containing the observations, mask, and
        (optionally) confidences.
    states : dict or None
        State values for each latent variable.
    params : dict or None
        Values for each model parameter.
    hypparams : dict or None
        Values for each group of hyperparameters.
    trans_hypparams : dict or None
        HDP transition hyperparameters.
    ar_hypparams : dict or None
        Autoregression hyperparameters.
    obs_hypparams : dict or None
        Observation hyperparameters.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller is
        insufficient for model initialization.
    """
    if not data and not (states and params):
        raise ValueError(
            "Must provide either `data` or " "both `states` and `params`."
        )

    if not hypparams and not (
        trans_hypparams and ar_hypparams and obs_hypparams
    ):
        raise ValueError(
            "Must provide either `hypparams` or "
            "all of `trans_hypparams`, `ar_hypparams`, "
            "and `obs_hypparams`."
        )