import jax
import jax.numpy as jnp
import jax.random as jr
from jax_moseq import utils
from jax_moseq.utils import jax_io, device_put_as_scalar, check_precision
from jax_moseq.models import arhmm, slds
from jax_moseq.models.keypoint_slds.gibbs import resample_scales
from jax_moseq.models.keypoint_slds.alignment import preprocess_for_pca
[docs]
def init_states(
    seed,
    Y,
    mask,
    params,
    noise_prior,
    obs_hypparams,
    Y_flat=None,
    v=None,
    h=None,
    fix_heading=False,
    **kwargs,
):
    """
    Initialize the latent states of the keypoint SLDS from the data,
    parameters, and hyperparameters.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    Y : jax array of shape (N, T, k, d)
        Keypoint observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    params : dict
        Values for each model parameter.
    noise_prior : scalar or jax array broadcastable to ``s``
        Prior on noise scale.
    obs_hypparams : dict
        Observation hyperparameters.
    Y_flat : jax array of shape (N, T, (k - 1) * d), optional
        Aligned and embedded keypoint observations. If omitted, it is
        computed (together with ``v`` and ``h``) by
        :py:func:`jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca`.
    v : jax array of shape (N, T, d), optional
        Initial centroid positions. Overwritten when ``Y_flat`` is None.
    h : jax array of shape (N, T), optional
        Initial heading angles. Overwritten when ``Y_flat`` is None.
    fix_heading : bool, default=False
        Whether to keep the heading angle of the pose fixed. If true,
        the heading variable ``h`` is initialized as 0.
    **kwargs : dict, optional
        Arguments to :py:func:`jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca`,
        as a substitute for ``Y_flat``, ``v``, or ``h``.

    Returns
    -------
    states : dict
        State values for each latent variable (``z``-style discrete
        states from :py:func:`arhmm.init_states` plus ``x``, ``v``,
        ``h`` and ``s``).
    """
    if Y_flat is None:
        # The positional parameter of ``preprocess_for_pca`` right after ``Y``
        # is the anterior keypoint index list, so ``fix_heading`` must be
        # passed by keyword; passing it positionally would bind the flag to
        # the keypoint indices.
        Y_flat, v, h = preprocess_for_pca(Y, fix_heading=fix_heading, **kwargs)

    # Continuous latents come from projecting the embedded pose onto Cd.
    x = slds.init_continuous_stateseqs(Y_flat, params["Cd"])
    states = arhmm.init_states(seed, x, mask, params)
    states["x"] = x
    states["v"] = v
    states["h"] = h
    # Noise scales are drawn given all other states and parameters.
    states["s"] = resample_scales(
        seed, Y, **states, **params, s_0=noise_prior, **obs_hypparams
    )
    return states
[docs]
def init_params(
    seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs
):
    """
    Build the initial parameter dictionary of the keypoint SLDS.

    The transition and autoregressive parameters are produced by
    :py:func:`arhmm.init_params`; the observation matrix ``Cd`` is then
    derived from the PCA fit, and the per-keypoint ``sigmasq`` values
    start at one.

    Parameters
    ----------
    seed : jr.PRNGKey
        JAX random seed.
    pca : sklearn.decomposition._pca.PCA
        PCA object fit to observations.
    Y_flat : jax array of shape (N, T, (k - 1) * d)
        Aligned and embedded keypoint observations.
    mask : jax array of shape (N, T)
        Binary indicator for valid frames.
    trans_hypparams : dict
        HDP transition hyperparameters.
    ar_hypparams : dict
        Autoregression hyperparameters (also forwarded as keyword
        arguments to :py:func:`slds.init_obs_params`).
    whiten : bool
        Whether to whiten PC's to initialize continuous latents.
    k : int
        Number of keypoints; sets the length of ``sigmasq``.
    **kwargs : dict
        Overflow, for convenience (ignored).

    Returns
    -------
    params : dict
        Values for each model parameter.
    """
    model_params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
    obs_matrix = slds.init_obs_params(pca, Y_flat, mask, whiten, **ar_hypparams)
    model_params["Cd"] = obs_matrix
    model_params["sigmasq"] = jnp.ones(k)
    return model_params
[docs]
def init_hyperparams(
    trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, **kwargs
):
    """
    Assemble the hyperparameter dictionary of the keypoint SLDS.

    The transition, autoregressive and observation groups are formatted
    by :py:func:`slds.init_hyperparams`; a copy of the centroid group is
    then stored under the ``"cen_hypparams"`` key.

    Parameters
    ----------
    trans_hypparams : dict
        HDP transition hyperparameters.
    ar_hypparams : dict
        Autoregression hyperparameters.
    obs_hypparams : dict
        Observation hyperparameters.
    cen_hypparams : dict
        Centroid movement hyperparameters.
    **kwargs : dict
        Overflow, for convenience (ignored).

    Returns
    -------
    hypparams : dict
        Values for each group of hyperparameters.
    """
    grouped = slds.init_hyperparams(trans_hypparams, ar_hypparams, obs_hypparams)
    # Store a copy so the caller's dict is not aliased by the model.
    centroid_group = cen_hypparams.copy()
    grouped["cen_hypparams"] = centroid_group
    return grouped
[docs]
def init_model(
    data=None,
    states=None,
    params=None,
    hypparams=None,
    noise_prior=None,
    seed=jr.PRNGKey(0),
    pca=None,
    whiten=True,
    PCA_fitting_num_frames=1000000,
    anterior_idxs=None,
    posterior_idxs=None,
    conf_threshold=0.5,
    error_estimator=None,
    trans_hypparams=None,
    ar_hypparams=None,
    obs_hypparams=None,
    cen_hypparams=None,
    verbose=False,
    exclude_outliers_for_pca=True,
    fix_heading=False,
    **kwargs,
):
    """
    Initialize a keypoint SLDS model dict containing the
    hyperparameters, noise prior, and initial seed, states,
    and parameters.

    Parameters
    ----------
    data : dict, optional
        Data dictionary containing the observations (``"Y"``), mask
        (``"mask"``), and (optionally) confidences (``"conf"``). Must be
        provided if ``states`` or ``params`` not precomputed.
    states : dict, optional
        State values for each latent variable, if precomputed.
    params : dict, optional
        Values for each model parameter, if precomputed.
    hypparams : dict, optional
        Values for each group of hyperparameters. If not provided,
        caller must provide each arg of ``init_hyperparams``. When
        provided, ``trans_hypparams``, ``ar_hypparams``,
        ``obs_hypparams`` and ``cen_hypparams`` are ignored.
    noise_prior : array or scalar, optional
        Prior on the noise for each keypoint observation, if precomputed.
        If omitted, it is estimated from the confidences via
        ``error_estimator`` when ``data`` has confidences, and set to
        1.0 otherwise.
    seed : int or jr.PRNGKey, default=jr.PRNGKey(0)
        Initial random seed value.
    pca : sklearn.decomposition.PCA, optional
        PCA object, if precomputed. If unspecified, will be
        computed from the data.
    whiten : bool, default=True
        Whether to whiten PC's to initialize continuous latents.
    PCA_fitting_num_frames : int, default=1000000
        Maximum number of datapoints to sample for PCA fitting,
        if ``pca`` is not provided.
    anterior_idxs : iterable of ints, optional
        Anterior keypoint indices for heading initialization.
        Must be provided if ``states`` or ``params`` not precomputed.
    posterior_idxs : iterable of ints, optional
        Posterior keypoint indices for heading initialization.
        Must be provided if ``states`` or ``params`` not precomputed.
    conf_threshold : float, default=0.5
        Confidence threshold below which points are interpolated
        in PCA fitting and heading/position initialization. See
        :py:func:`jax_moseq.models.keypoint_slds.alignment.preprocess_for_pca` for details.
    error_estimator : dict, optional
        Parameters used to initialize ``noise_prior``. Must be provided
        if ``data`` contains confidences and ``noise_prior`` is not
        given. See :py:func:`jax_moseq.models.keypoint_slds.initialize.estimate_error` for details.
    trans_hypparams : dict, optional
        HDP transition hyperparameters. Must be provided if
        ``hypparams`` not provided.
    ar_hypparams : dict, optional
        Autoregression hyperparameters. Must be provided if
        ``hypparams`` not provided.
    obs_hypparams : dict, optional
        Observation hyperparameters. Must be provided if
        ``hypparams`` not provided.
    cen_hypparams : dict, optional
        Centroid movement hyperparameters. Must be provided if
        ``hypparams`` not provided.
    verbose : bool, default=False
        Whether to print progress info during initialization.
    exclude_outliers_for_pca : bool, default=True
        Whether to exclude frames with low-confidence keypoints from
        PCA fitting. If False, then the low-confidence keypoint
        coordinates are linearly interpolated.
    fix_heading : bool, default=False
        Whether to keep the heading angle of the pose fixed. If true,
        the heading variable ``h`` is initialized as 0.
    **kwargs : dict, optional
        Unused. For convenience, enables user to invoke function
        by unpacking dict that contains keys not used by the method.

    Returns
    -------
    model : dict
        Dictionary containing the hyperparameters, noise prior,
        and initial seed, states, and parameters of the model.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller
        is insufficient for model initialization.
    """
    # Coerce to a real bool: the bare ``data and ...`` expression evaluates
    # to ``None`` (or ``{}``) when no data is supplied.
    has_conf = bool(data) and ("conf" in data)
    _check_init_args(
        data,
        states,
        params,
        hypparams,
        trans_hypparams,
        ar_hypparams,
        obs_hypparams,
        cen_hypparams,
        has_conf,
        noise_prior,
        error_estimator,
        anterior_idxs,
        posterior_idxs,
    )

    model = {}
    conf = data["conf"] if has_conf else None

    # Preprocessing is needed whenever states or params must be computed.
    # ``Y``, ``mask``, ``Y_flat``, ``v`` and ``h`` are only bound here; the
    # ``params is None`` and ``states is None`` branches below can only be
    # reached when this branch has run.
    if not (states and params):
        Y, mask = data["Y"], data["mask"]
        Y_flat, v, h = preprocess_for_pca(
            Y,
            anterior_idxs,
            posterior_idxs,
            conf,
            conf_threshold,
            fix_heading,
            verbose,
        )

    if isinstance(seed, int):
        seed = jr.PRNGKey(seed)
    model["seed"] = seed

    if hypparams is None:
        if verbose:
            print("Keypoint SLDS: Initializing hyperparameters")
        hypparams = init_hyperparams(
            trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams
        )
    else:
        hypparams = device_put_as_scalar(hypparams)
    model["hypparams"] = hypparams

    if noise_prior is None:
        if verbose:
            print("Keypoint SLDS: Initializing noise prior")
        if has_conf:
            noise_prior = estimate_error(conf, **error_estimator)
        else:
            noise_prior = 1.0  # TODO: magic number (unit prior without confidences)
    else:
        noise_prior = jax.device_put(noise_prior)
    model["noise_prior"] = noise_prior

    if params is None:
        if verbose:
            print("Keypoint SLDS: Initializing parameters")
        if pca is None:
            if not exclude_outliers_for_pca or conf is None:
                pca_mask = mask
            else:
                # Drop frames where any keypoint falls below the threshold.
                pca_mask = jnp.logical_and(mask, (conf > conf_threshold).all(-1))
            pca = utils.fit_pca(Y_flat, pca_mask, PCA_fitting_num_frames, verbose)
        params = init_params(
            seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
        )
    else:
        params = jax.device_put(params)
    model["params"] = params

    if states is None:
        if verbose:
            print("Keypoint SLDS: Initializing states")
        obs_hypparams = hypparams["obs_hypparams"]
        states = init_states(
            seed,
            Y,
            mask,
            params,
            noise_prior,
            obs_hypparams,
            Y_flat,
            v,
            h,
            fix_heading,
        )
    else:
        states = jax.device_put(states)
    model["states"] = states

    return model
[docs]
def estimate_error(conf, slope, intercept):
    """
    Convert keypoint confidences into a per-observation noise prior using
    the linear (log-log) fit from noise calibration.

    The result is the square of ``10 ** (slope * log10(conf) + intercept)``.

    Parameters
    ----------
    conf : jax array of shape (..., k)
        Confidence for each keypoint observation. Must be >= 0.
    slope : float
        Slope learned by noise calibration.
    intercept : float
        Intercept learned by noise calibration.

    Returns
    -------
    noise_prior : jax array of shape (..., k)
        Prior on the noise for each observation.
    """
    # The small offset keeps log10 finite when a confidence is exactly 0.
    log_conf = jnp.log10(conf + 1e-6)
    log_scale = slope * log_conf + intercept
    return 10 ** (2 * log_scale)
@check_precision
def _check_init_args(
    data,
    states,
    params,
    hypparams,
    trans_hypparams,
    ar_hypparams,
    obs_hypparams,
    cen_hypparams,
    has_conf,
    noise_prior,
    error_estimator,
    anterior_idxs,
    posterior_idxs,
):
    """
    Helper method for :py:func:`jax_moseq.models.keypoint_slds.initialize.init_model`
    that ensures a sufficient subset of the initialization arguments have
    been provided by the caller.

    Parameters
    ----------
    data : dict or None
        Data dictionary containing the observations, mask,
        and (optionally) confidences.
    states : dict or None
        State values for each latent variable.
    params : dict or None
        Values for each model parameter.
    hypparams : dict or None
        Values for each group of hyperparameters.
    trans_hypparams : dict or None
        HDP transition hyperparameters.
    ar_hypparams : dict or None
        Autoregression hyperparameters.
    obs_hypparams : dict or None
        Observation hyperparameters.
    cen_hypparams : dict or None
        Centroid movement hyperparameters.
    has_conf : bool
        Whether data is provided and includes confidences.
    noise_prior : array, scalar, or None
        Prior on the noise for each keypoint observation.
    error_estimator : dict or None
        Parameters used to initialize ``noise_prior``.
    anterior_idxs : iterable of ints or None
        Anterior keypoint indices for heading initialization.
    posterior_idxs : iterable of ints or None
        Posterior keypoint indices for heading initialization.

    Raises
    ------
    ValueError
        If the subset of the parameters provided by the caller
        is insufficient for model initialization, or if ``latent_dim``
        exceeds ``(num_keypoints - 1) * keypoint_dim`` for the data.
    """
    if not (data or (states and params)):
        raise ValueError("Must provide either `data` or " "both `states` and `params`.")

    if not (
        hypparams
        or (trans_hypparams and ar_hypparams and obs_hypparams and cen_hypparams)
    ):
        raise ValueError(
            "Must provide either `hypparams` or "
            "all of `trans_hypparams`, `ar_hypparams`, "
            "`obs_hypparams`, and `cen_hypparams`."
        )

    if has_conf and ((noise_prior is None) and (error_estimator is None)):
        raise ValueError(
            "If confidences are provided, must also provide "
            "either `error_estimator` or `noise_prior`."
        )

    if not (states and params) and (anterior_idxs is None or posterior_idxs is None):
        raise ValueError(
            "If `states` and `params` not provided, must "
            "provide `anterior_idxs` and `posterior_idxs`."
        )

    if data:
        # ``init_model`` uses ``hypparams`` whenever it is supplied and only
        # falls back to the individual groups otherwise, so validate the
        # latent dimension that will actually be used.
        if hypparams:
            latent_dim = hypparams["ar_hypparams"]["latent_dim"]
        else:
            latent_dim = ar_hypparams["latent_dim"]
        # The embedded pose has (k - 1) * d dimensions.
        max_dim = (data["Y"].shape[-2] - 1) * data["Y"].shape[-1]
        if latent_dim > max_dim:
            raise ValueError(
                "`latent_dim` must be less than or equal to `(num_keypoints - 1) * keypoint_dim`. "
                f"The current value of `latent_dim` ({latent_dim}) is above the maximum ({max_dim})"
            )