Utilities
Distributions
Functions:
| sample_hmm_stateseq | Sample state sequences in a Markov chain. |
- jax_moseq.utils.distributions.sample_hmm_stateseq(seed, transition_matrix, log_likelihoods, mask)
Sample state sequences in a Markov chain.
- Parameters:
seed (jax.random.PRNGKey) – Random seed
transition_matrix (jax array, shape (num_states, num_states)) – Transition matrix
log_likelihoods (jax array, shape (num_timesteps, num_states)) – Sequence of log likelihoods of emissions given hidden state and parameters
mask (jax array, shape (num_timesteps,)) – Sequence indicating whether to use an emission (1) or not (0)
- Returns:
log_norm (float) – Posterior marginal log likelihood
states (jax array, shape (num_timesteps,)) – Sequence of sampled states
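One standard way to draw such a sample is forward-filtering backward-sampling. The NumPy sketch below illustrates that recursion; it is not the jax_moseq implementation. `sample_hmm_stateseq_sketch` and `logsumexp` are hypothetical helpers, a NumPy `Generator` stands in for the JAX key, the uniform initial state distribution is an assumption, and masked timesteps are assumed to contribute no emission term.

```python
import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def sample_hmm_stateseq_sketch(rng, transition_matrix, log_likelihoods, mask):
    """Forward-filtering backward-sampling for a discrete-state HMM.

    Assumptions (not taken from jax_moseq): a uniform initial state
    distribution, and masked timesteps (mask == 0) contribute no emission term.
    """
    T, K = log_likelihoods.shape
    ll = log_likelihoods * mask[:, None]  # zero out masked emissions
    log_P = np.log(transition_matrix)

    # Forward filtering: log_alpha[t, k] = log p(y_0..y_t, z_t = k)
    log_alpha = np.zeros((T, K))
    log_alpha[0] = ll[0] - np.log(K)
    for t in range(1, T):
        log_alpha[t] = logsumexp(log_alpha[t - 1][:, None] + log_P, axis=0) + ll[t]
    log_norm = logsumexp(log_alpha[-1])  # log marginal likelihood of all emissions

    # Backward sampling: z_{T-1} from the last filter, then z_t given z_{t+1}
    states = np.zeros(T, dtype=int)
    for t in range(T - 1, -1, -1):
        logits = log_alpha[t] if t == T - 1 else log_alpha[t] + log_P[:, states[t + 1]]
        probs = np.exp(logits - logsumexp(logits))
        states[t] = rng.choice(K, p=probs / probs.sum())
    return log_norm, states
```

Because each backward step conditions on the state already sampled at t + 1, the result is a joint draw from the posterior over whole sequences rather than independent per-timestep draws.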
Transitions
Functions:
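| count_transitions | Count the number of transitions between each pair of states i and j in the unmasked entries of stateseqs, including self transitions. |
| sample_betas | Sample the state usages betas given the observed transition counts and the model hyperparameters. |
| sample_pi | Sample the transition matrix pi given the observed transition counts, state usages, and model hyperparameters. |
| sample_hdp_transitions | Sample the transition parameters of the HDP-HMM given the observed transition counts, the current usage estimates, and the model hyperparameters. |
| resample_hdp_transitions | Resample the transition parameters of the HDP-HMM. |
| init_hdp_transitions | Initialize the transition parameters of the HDP-HMM. |
| sample_dir_transitions | Sample a transition matrix using a sticky Dirichlet prior. |
| resample_dir_transitions | Resample Markov transition probabilities using a sticky Dirichlet prior. |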
- jax_moseq.utils.transitions.count_transitions(num_states, stateseqs, mask)
Count the number of transitions between each pair of states i and j in the unmasked entries of stateseqs, including self transitions (i.e. i == j).
- Parameters:
num_states (int) – Total number of states (must exceed max(stateseqs)).
stateseqs (jax int array of shape (..., T)) – Batch of state sequences where the last dim indexes time. All entries
mask (jax array of shape (..., T + num_lags)) – Binary indicator for which elements of stateseqs are valid. If num_lags > 0, the first num_lags time points of the mask are ignored (ensures time alignment with the AR process).
- Returns:
transition_counts – The number of transitions between every pair of states.
- Return type:
jax array of shape (num_states, num_states)
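A NumPy sketch of the counting logic, under stated assumptions: `count_transitions_sketch` is a hypothetical name, the leading num_lags entries of `mask` are dropped to align it with `stateseqs`, and a transition is assumed to count only when both of its endpoints are unmasked (the library may treat that boundary differently).

```python
import numpy as np


def count_transitions_sketch(num_states, stateseqs, mask):
    """Count i -> j transitions in the unmasked parts of a batch of state sequences.

    Assumptions (not from jax_moseq): the leading extra entries of `mask`
    (num_lags of them) are dropped so that mask and stateseqs align in time,
    and a transition is counted only if both of its endpoints are unmasked.
    """
    stateseqs = np.asarray(stateseqs)
    T = stateseqs.shape[-1]
    mask = np.asarray(mask)[..., -T:].astype(bool)  # ignore the first num_lags time points
    start, end = stateseqs[..., :-1], stateseqs[..., 1:]
    valid = mask[..., :-1] & mask[..., 1:]
    counts = np.zeros((num_states, num_states), dtype=int)
    np.add.at(counts, (start[valid], end[valid]), 1)  # accumulates repeated (i, j) pairs
    return counts
```

`np.add.at` is used instead of `counts[i, j] += 1` with index arrays because fancy-index assignment applies a repeated (i, j) pair only once, while `add.at` accumulates every occurrence.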
- jax_moseq.utils.transitions.sample_betas(seed, transition_counts, betas, alpha, kappa, gamma)
Sample the state usages betas given the observed transition counts and the model hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.
betas (jax array of shape (num_states,)) – State usages.
alpha (scalar) – State usage influence hyperparameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
gamma (scalar) – Usage uniformity hyperparameter.
- Returns:
betas – Resampled state usages.
- Return type:
jax array of shape (num_states,)
- jax_moseq.utils.transitions.sample_pi(seed, transition_counts, betas, alpha, kappa)
Sample the transition matrix pi given the observed transition counts, state usages, and model hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.
betas (jax array of shape (num_states,)) – State usages.
alpha (scalar) – State usage influence hyperparameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
- Returns:
pi – Resampled transition probabilities.
- Return type:
jax array of shape (num_states, num_states)
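In the standard weak-limit sticky HDP-HMM, each row of pi has a Dirichlet conditional whose concentration combines the usages, a stickiness bonus on the diagonal, and the observed counts. The NumPy sketch below assumes that form, Dirichlet(alpha * betas + kappa * e_i + transition_counts[i]); it is not confirmed here as jax_moseq's exact parameterization. `sample_pi_sketch` is a hypothetical name, and a NumPy `Generator` replaces the JAX key.

```python
import numpy as np


def sample_pi_sketch(rng, transition_counts, betas, alpha, kappa):
    """Sample each row of a transition matrix from a sticky HDP-HMM conditional.

    Assumption (standard weak-limit sticky HDP-HMM, not confirmed for
    jax_moseq): row i is drawn from
        Dirichlet(alpha * betas + kappa * e_i + transition_counts[i]),
    where e_i is the i-th standard basis vector.
    """
    num_states = len(betas)
    concentrations = (
        alpha * betas[None, :]            # shared usage-weighted prior
        + kappa * np.eye(num_states)      # extra mass on self transitions
        + transition_counts               # observed transitions (conjugate update)
    )
    return np.stack([rng.dirichlet(c) for c in concentrations])
```

With many observed transitions the counts dominate and each row approaches the empirical transition frequencies; with few, alpha * betas and kappa shape the draw.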
- jax_moseq.utils.transitions.sample_hdp_transitions(seed, transition_counts, betas, alpha, kappa, gamma)
Sample the transition parameters of the HDP-HMM given the observed transition counts, the current usage estimates, and the model hyperparameters.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.
betas (jax array of shape (num_states,)) – State usages.
alpha (scalar) – State usage influence hyperparameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
gamma (scalar) – Usage uniformity hyperparameter.
- Returns:
betas (jax array of shape (num_states,)) – Resampled state usages.
pi (jax array of shape (num_states, num_states)) – Resampled transition probabilities.
- jax_moseq.utils.transitions.resample_hdp_transitions(seed, z, mask, betas, alpha, kappa, gamma, **kwargs)
Resample the transition parameters of the HDP-HMM.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
z (jax array of shape (..., T - n_lags)) – Discrete state sequences.
mask (jax array of shape (..., T)) – Binary indicator for which data points are valid.
betas (jax array of shape (num_states,)) – State usages.
alpha (scalar) – State usage influence hyperparameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
gamma (scalar) – Usage uniformity hyperparameter.
kwargs (dict) – Overflow, for convenience.
- Returns:
betas (jax array of shape (num_states,)) – Resampled state usages.
pi (jax array of shape (num_states, num_states)) – Resampled transition probabilities.
- jax_moseq.utils.transitions.init_hdp_transitions(seed, num_states, alpha, kappa, gamma, **kwargs)
Initialize the transition parameters of the HDP-HMM.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
num_states (int) – Max number of HMM states.
alpha (scalar) – State usage influence hyperparameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
gamma (scalar) – Usage uniformity hyperparameter.
kwargs (dict) – Overflow, for convenience.
- Returns:
betas (jax array of shape (num_states,)) – Initial state usages.
pi (jax array of shape (num_states, num_states)) – Initial transition probabilities.
- jax_moseq.utils.transitions.sample_dir_transitions(seed, transition_counts, beta, kappa)
Sample a transition matrix using a sticky Dirichlet prior.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.
beta (scalar) – Dirichlet prior concentration parameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
- Returns:
pi – Transition probabilities.
- Return type:
jax array of shape (num_states, num_states)
- jax_moseq.utils.transitions.resample_dir_transitions(seed, num_states, z, mask, beta, kappa, **kwargs)
Resample Markov transition probabilities using a sticky Dirichlet prior.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
num_states (int) – Max number of HMM states.
z (jax array of shape (..., T - n_lags)) – Discrete state sequences.
mask (jax array of shape (..., T)) – Binary indicator for which data points are valid.
beta (scalar) – Dirichlet prior concentration parameter.
kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.
- Returns:
pi – Resampled transition probabilities.
- Return type:
jax array of shape (num_states, num_states)
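Resampling under a sticky Dirichlet prior amounts to counting transitions and then drawing each row from its conjugate Dirichlet posterior. A minimal NumPy sketch, assuming row i has prior Dirichlet(beta * 1 + kappa * e_i), that `mask` is aligned with `z` by dropping its leading entries, and that a transition counts only when both endpoints are unmasked. `resample_dir_transitions_sketch` is a hypothetical name, and a NumPy `Generator` replaces the JAX key.

```python
import numpy as np


def resample_dir_transitions_sketch(rng, num_states, z, mask, beta, kappa):
    """Count transitions, then draw each row of pi from a sticky Dirichlet posterior.

    Assumptions (not from jax_moseq): mask is aligned with z by dropping its
    leading entries, a transition counts only if both endpoints are unmasked,
    and row i has prior Dirichlet(beta * 1 + kappa * e_i).
    """
    z = np.asarray(z)
    m = np.asarray(mask)[..., -z.shape[-1]:].astype(bool)
    valid = m[..., :-1] & m[..., 1:]
    counts = np.zeros((num_states, num_states))
    np.add.at(counts, (z[..., :-1][valid], z[..., 1:][valid]), 1)
    # Dirichlet prior + multinomial counts -> Dirichlet posterior, row by row
    concentrations = beta + kappa * np.eye(num_states) + counts
    return np.stack([rng.dirichlet(c) for c in concentrations])
```

Under this assumed form, kappa = 0 reduces to a symmetric Dirichlet(beta) prior on each row, and a larger kappa adds prior mass to self transitions.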
Kalman smoothing
Functions:
| kalman_sample | Run forward-filtering and backward-sampling to draw samples from the posterior of a 1st-order dynamical system with autoregressive dynamics of order n_lags. |
| ar_to_lds_emissions | Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the emission terms and initial state distribution for a system with 1st-order dynamics in R^(D*L). |
| ar_to_lds_dynamics | Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the dynamics terms for a system with 1st-order dynamics in R^(D*L). |
- jax_moseq.utils.kalman.kalman_sample(seed, ys, mask, zs, m0, S0, A, B, Q, C, D, Rs, masked_dynamics_params, masked_obs_noise, jitter=0, parallel=True)
Run forward-filtering and backward-sampling to draw samples from the posterior of a 1st-order dynamical system with autoregressive dynamics of order n_lags.
- Parameters:
seed (jr.PRNGKey) – JAX random seed.
ys (jax.Array with shape (T, obs_dim)) – Continuous observations, minus first L+1 frames.
mask (jax.Array with shape (T,)) – Indicator of observation validity, for timesteps [L-1, T)
zs (jax.Array with shape (T-n_lags,)) – Discrete state sequence, taking integer values [0, n_states).
m0 (jax.Array with shape (ar_dim,)) – Initial continuous state mean
S0 (jax.Array with shape (ar_dim, ar_dim)) – Initial continuous state covariance
A (jax.Array with shape (n_states, ar_dim, ar_dim)) – State dynamics matrix
B (jax.Array with shape (n_states, ar_dim)) – State input matrix
Q (jax.Array with shape (n_states, ar_dim, ar_dim)) – State noise matrix
C (jax.Array with shape (obs_dim, ar_dim)) – Observation transform matrix
D (jax.Array with shape (obs_dim,)) – Observation input matrix
Rs (jax.Array with shape (T, obs_dim)) – Observation noise scales (diagonal entries of covariance)
masked_dynamics_params (dict with key-value pairs) –
weights: jax.Array with shape (ar_dim, ar_dim)
bias: jax.Array with shape (ar_dim,)
cov: jax.Array with shape (ar_dim, ar_dim)
Dynamics parameters, for masked timesteps
masked_obs_noise (jax.Array with shape (obs_dim,)) – Diagonal observation noise scale, for masked timesteps.
jitter (float, default=0) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.
parallel (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.
- Returns:
xs – Sampled continuous state sequence.
- Return type:
jax.Array with shape (T, ar_dim)
- jax_moseq.utils.kalman.ar_to_lds_emissions(Cd, R, y, m0, S0, nlags)
Given a linear dynamical system with L’th-order autoregressive dynamics in R^D, returns the emission terms and initial state distribution for a system with 1st-order dynamics in R^(D*L).
- Parameters:
Cd (jax array, shape (D_obs, D+1)) – Observation affine transformation
R (jax array, shape (T, D_obs)) – Dimension-wise observation covariances
y (jax array, shape (T, D_obs)) – Observations
m0 (jax array, shape (D)) – Initial state distribution mean
S0 (jax array, shape (D, D)) – Initial state distribution cov
nlags (int) – Number of autoregressive lags (L)
- Returns:
C_ (jax array, shape (D_obs, D*L))
d_ (jax array, shape (D_obs))
R_ (jax array, shape (T, D_obs))
y_ (jax array, shape (T, D_obs))
m0_ (jax array, shape (D*L))
S0_ (jax array, shape (D*L, D*L))
- jax_moseq.utils.kalman.ar_to_lds_dynamics(Ab, Q)
Given a linear dynamical system with L’th-order autoregressive dynamics in R^D, returns the dynamics terms for a system with 1st-order dynamics in R^(D*L)
- Parameters:
Ab (jax array, shape (..., D, D*L + 1)) – AR affine transform
Q (jax array, shape (..., D, D)) – AR noise covariance
- Returns:
A_ (jax array, shape (…, D*L, D*L))
b_ (jax array, shape (…, D*L))
Q_ (jax array, shape (…, D*L, D*L))
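Rewriting an order-L autoregression as a 1st-order system is usually done with a companion (block-shift) matrix. The NumPy sketch below shows that construction under stated assumptions: a single unbatched system; the final column of Ab is the bias; the lifted state stacks the last L frames oldest-first, [x_{t-L+1}, ..., x_t], matching the time-ordered lag convention documented for get_lags; and the shifted blocks get zero noise. jax_moseq may order the blocks differently or add jitter there. `ar_to_lds_dynamics_sketch` is a hypothetical name.

```python
import numpy as np


def ar_to_lds_dynamics_sketch(Ab, Q):
    """Rewrite x_t = A_1 x_{t-L} + ... + A_L x_{t-1} + b + noise as a 1st-order LDS.

    Assumptions (not from jax_moseq): unbatched Ab of shape (D, D*L + 1) whose
    last column is the bias; lifted state [x_{t-L+1}, ..., x_t] oldest-first;
    zero noise on the shifted blocks.
    """
    D = Ab.shape[0]
    L = (Ab.shape[1] - 1) // D
    A_ = np.zeros((D * L, D * L))
    A_[:-D, D:] = np.eye(D * (L - 1))   # each older frame moves up one block
    A_[-D:, :] = Ab[:, :-1]              # newest frame comes from the AR weights
    b_ = np.zeros(D * L)
    b_[-D:] = Ab[:, -1]                  # bias only drives the newest frame
    Q_ = np.zeros((D * L, D * L))
    Q_[-D:, -D:] = Q                     # noise enters only the newest frame
    return A_, b_, Q_
```

One step of the lifted system reproduces one AR step in its last block and simply shifts the remaining frames.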
Autoregression
Functions:
| get_lags | Get lags of a multivariate time series. |
- jax_moseq.utils.autoregression.get_lags(x, nlags)
Get lags of a multivariate time series. Lags are concatenated along the last dim in time-order.
- Parameters:
nlags (int) – Number of lags
x (jax array, shape (..., t, d)) – Batch of d-dimensional time series
- Returns:
x_lagged
- Return type:
jax array, shape (…, t-nlags, d*nlags)
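A NumPy sketch of the lag stacking, assuming that row i of the output collects frames i through i + nlags - 1 (oldest first), so that it can serve as the regressor for frame i + nlags. `get_lags_sketch` is a hypothetical name.

```python
import numpy as np


def get_lags_sketch(x, nlags):
    """Stack nlags consecutive frames along the last dim, oldest frame first.

    Row i of the output holds [x[i], x[i+1], ..., x[i+nlags-1]], i.e. the
    nlags frames that precede time i + nlags.
    """
    t = x.shape[-2]
    lags = [x[..., i : t - nlags + i, :] for i in range(nlags)]
    return np.concatenate(lags, axis=-1)  # shape (..., t - nlags, d * nlags)
```

For a series with t = 5 frames of dimension d = 2 and nlags = 2, the result has shape (3, 4), and its first row is frames 0 and 1 concatenated.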
Odds and ends
Functions:
| concatenate_stateseqs | Concatenate state sequences, optionally applying a mask. |
| get_durations | Get durations for a batch of state sequences. |
| get_frequencies | Get state frequencies for a batch of state sequences. |
| | Symmetrize a matrix. |
| psd_solve | Solves the linear system Ax=B, assuming A is positive semi-definite. |
| psd_inv | Invert a positive semi-definite matrix. |
| jax_io | Converts a function involving numpy arrays to one that inputs and outputs jax arrays. |
| pad_affine | Pad x with 1’s so that it can be affine transformed with matrix multiplication. |
| fit_pca | Fit a PCA model to transformed keypoint coordinates. |
| | Wrap an angle to the range [-pi, pi]. |
| pad_along_axis | Pad an array along a single axis. |
| unbatch | Invert jax_moseq.utils.batch(). |
| batch | Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments. |
| get_mixed_map_iters | Get the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map(). |
| set_mixed_map_iters | Set the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map(). |
| get_mixed_map_gpus | Get the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map(). |
| set_mixed_map_gpus | Set the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map(). |
| mixed_map | Combine jax.pmap, jax.vmap and jax.lax.map for parallelization. |
- jax_moseq.utils.utils.concatenate_stateseqs(stateseqs, mask=None)
Concatenate state sequences, optionally applying a mask.
- Parameters:
stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.
mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.
- Returns:
stateseqs_flat – 1d array containing all state sequences
- Return type:
ndarray
- jax_moseq.utils.utils.get_durations(stateseqs, mask=None)
Get durations for a batch of state sequences.
- Parameters:
stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.
mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.
- Returns:
durations – The duration of each each state (across all state sequences)
- Return type:
1d array
Examples
>>> import numpy as np
>>> from jax_moseq.utils.utils import get_durations
>>> stateseqs = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1])}
>>> get_durations(stateseqs)
array([2, 3, 1, 3, 1])
- jax_moseq.utils.utils.get_frequencies(stateseqs, mask=None, num_states=None, runlength=True)
Get state frequencies for a batch of state sequences.
- Parameters:
stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.
mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.
num_states (int, default=None) – Number of different states. If None, the number of states will be set to max(stateseqs)+1.
runlength (bool, default=True) – Whether to count frequency by the number of instances of each state (True), or by the number of frames in each state (False).
- Returns:
frequencies – Frequency of each state across all state sequences
- Return type:
1d array
Examples
>>> import numpy as np
>>> from jax_moseq.utils.utils import get_frequencies
>>> stateseqs = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1])}
>>> get_frequencies(stateseqs, runlength=True)
array([0.2, 0.4, 0.2, 0.2])
>>> get_frequencies(stateseqs, runlength=False)
array([0.3, 0.3, 0.3, 0.1])
- jax_moseq.utils.utils.psd_solve(A, B, diagonal_boost=1e-06)
Solves the linear system Ax=B, assuming A is positive semi-definite.
Uses Cholesky decomposition for improved numerical stability and efficiency. A is symmetrized and diagonal elements are boosted by diagonal_boost to ensure positive definiteness.
- Parameters:
A (jax array, shape (n,n)) – A positive semi-definite matrix
B (jax array, shape (...,n)) – Right-hand side of the linear system
diagonal_boost (float, default=1e-06) – Amount added to the diagonal of A to ensure positive definiteness
- Returns:
x – Solution of the linear system Ax=B
- Return type:
jax array, shape (…,n)
- jax_moseq.utils.utils.psd_inv(A, diagonal_boost=1e-06)
Invert a positive semi-definite matrix.
Uses jax_moseq.utils.psd_solve() for numerical stability and ensures that the inverse matrix is symmetric.
- Parameters:
A (jax array, shape (n,n)) – A positive semi-definite matrix
diagonal_boost (float, default=1e-06) – Amount added to the diagonal of A to ensure positive definiteness
- Returns:
Ainv – The inverse of A
- Return type:
jax array, shape (n,n)
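The steps described for psd_solve and psd_inv (symmetrize, boost the diagonal, factor with Cholesky, solve, then symmetrize the inverse) can be sketched in NumPy. The names are hypothetical, each row of B is treated as one right-hand side (one reading of the (..., n) shape), and the two triangular solves use the general `np.linalg.solve` for brevity where a dedicated triangular solver would normally be used.

```python
import numpy as np


def psd_solve_sketch(A, B, diagonal_boost=1e-6):
    """Solve A x = b for each row b of B via a Cholesky factorization.

    Symmetrize A, add diagonal_boost to its diagonal, factor A = L L^T,
    then solve L y = b and L^T x = y.
    """
    n = A.shape[-1]
    A = (A + A.T) / 2 + diagonal_boost * np.eye(n)
    L = np.linalg.cholesky(A)
    rhs = np.reshape(B, (-1, n)).T          # one right-hand side per column
    y = np.linalg.solve(L, rhs)             # forward solve: L y = b
    x = np.linalg.solve(L.T, y)             # backward solve: L^T x = y
    return x.T.reshape(np.shape(B))


def psd_inv_sketch(A, diagonal_boost=1e-6):
    """Invert a PSD matrix by solving A X = I, then symmetrize the result."""
    Ainv = psd_solve_sketch(A, np.eye(A.shape[-1]), diagonal_boost)
    return (Ainv + Ainv.T) / 2
```

The final symmetrization in `psd_inv_sketch` removes the small asymmetries that floating-point round-off leaves in a numerically computed inverse.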
- jax_moseq.utils.utils.jax_io(fn)
Converts a function involving numpy arrays to one that inputs and outputs jax arrays.
- jax_moseq.utils.utils.pad_affine(x)
Pad x with 1’s so that it can be affine transformed with matrix multiplication.
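A NumPy version of the padding, under the hypothetical name `pad_affine_sketch`. Appending a 1 lets a weight matrix whose last column is a bias apply an affine map as a single matrix product, which is presumably why parameters such as Ab (shape (..., D, D*L + 1)) and Cd (shape (D_obs, D+1)) carry one extra column.

```python
import numpy as np


def pad_affine_sketch(x):
    """Append a trailing 1 to the last dim of x, so that
    W @ pad(x) == A @ x + b when W = [A | b]."""
    ones = np.ones(x.shape[:-1] + (1,), dtype=x.dtype)
    return np.concatenate([x, ones], axis=-1)
```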
- jax_moseq.utils.utils.fit_pca(Y, mask, PCA_fitting_num_frames=1000000, verbose=False, **kwargs)
Fit a PCA model to transformed keypoint coordinates.
- Parameters:
Y (jax array, shape (..., d)) – Keypoint coordinates
mask (jax array) – Binary indicator for which elements of Y are valid
PCA_fitting_num_frames (int, default=1000000) – Maximum number of frames to use for PCA. Frames will be sampled randomly if the input data exceed this size.
verbose (bool, default=False) – Whether to print the number of sampled frames.
- Returns:
pca – An sklearn PCA model fit to Y
- Return type:
sklearn.decomposition.PCA
- jax_moseq.utils.utils.pad_along_axis(arr, pad_widths, axis=0, value=0)
Pad an array along a single axis.
- Parameters:
arr (ndarray) – Array to be padded
pad_widths (tuple (int, int)) – Amount of padding on either end
axis (int, default=0) – Axis along which to add padding
value (float, default=0) – Value of padded array elements
- Returns:
padded_arr
- Return type:
ndarray
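The same behavior can be sketched with `np.pad` by giving every other axis zero padding. `pad_along_axis_sketch` is a hypothetical name, and constant-value padding is assumed.

```python
import numpy as np


def pad_along_axis_sketch(arr, pad_widths, axis=0, value=0):
    """Pad `arr` with `value` by pad_widths = (before, after) along one axis only."""
    pad_widths_all = [(0, 0)] * arr.ndim       # no padding on other axes
    pad_widths_all[axis] = tuple(pad_widths)   # negative axis indices also work
    return np.pad(arr, pad_widths_all, mode="constant", constant_values=value)
```

For example, padding a (2, 3) array with (1, 2) along axis 1 yields a (2, 6) array with one padded column before the data and two after.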
- jax_moseq.utils.utils.unbatch(data, keys, bounds)
Invert jax_moseq.utils.batch().
- Parameters:
data (ndarray, shape (num_segs, seg_length, ...)) – Stack of segmented time-series
keys (list or array of str, length num_segs) – Name of the time-series that each segment came from
bounds (ndarray, shape (num_segs, 2)) – Start and end times for each segment, reflecting how the segments were extracted from the original time-series.
- Returns:
data_dict – Dictionary mapping names to reconstructed time-series
- Return type:
dict
- jax_moseq.utils.utils.batch(data_dict, keys=None, seg_length=None, seg_overlap=30)
Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments. The data is padded so that the stacked array isn’t ragged. The padding repeats the last frame of each time-series until the end of the segment.
- Parameters:
data_dict (dict {str : ndarray}) – Dictionary mapping names to ndarrays, where the first dim represents time. All data arrays must have the same shape except for the first dim.
keys (list of str, default=None) – Optional list of names specifying which datasets to include in the output and what order to put them in. Each name must be a key in data_dict. If keys=None, names will be sorted alphabetically.
seg_length (int, default=None) – Break each time-series into segments of this length. If seg_length=None, the final stacked array will be as long as the longest time-series.
seg_overlap (int, default=30) – Amount of overlap between segments. For example, setting seg_length=N and seg_overlap=M will result in segments with start/end times (0, N+M), (N, 2*N+M), (2*N, 3*N+M), …
- Returns:
data (ndarray, shape (num_segs, seg_length, …)) – Stacked data array
mask (ndarray, shape (num_segs, seg_length)) – Binary indicator specifying which elements of data are not padding (mask==0 in padded locations)
metadata (tuple (keys, bounds)) – Metadata for the rows of data, as a tuple with an array of keys and an array of (start, end) times.
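The segment start/end pattern described for seg_overlap can be made concrete with a small helper. `segment_bounds` is hypothetical, and clipping the final end time to the series length is an assumption; frames past that end would be the padding that repeats the last frame and is marked with mask == 0.

```python
def segment_bounds(length, seg_length, seg_overlap=30):
    """Start/end times of segments that begin every seg_length frames and run
    seg_overlap frames past the next segment's start.

    Follows the documented pattern (0, N+M), (N, 2*N+M), ...; clipping the
    last end at `length` is an assumption, not confirmed for jax_moseq.
    """
    return [
        (start, min(start + seg_length + seg_overlap, length))
        for start in range(0, length, seg_length)
    ]
```

For a 250-frame series with seg_length=100 and seg_overlap=30, this gives (0, 130), (100, 230), (200, 250).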
- jax_moseq.utils.utils.get_mixed_map_iters()
Get the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().
- jax_moseq.utils.utils.set_mixed_map_iters(iters)
Set the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().
- jax_moseq.utils.utils.get_mixed_map_gpus()
Get the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().
- jax_moseq.utils.utils.set_mixed_map_gpus(gpus)
Set the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().
- jax_moseq.utils.utils.mixed_map(fun, in_axes=None, out_axes=None)
Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.
This function is similar to jax.vmap, except that it mixes together jax.pmap, jax.vmap and jax.lax.map to prevent OOM errors and allow for parallelization across multiple GPUs. The behavior is determined by the global variables _MIXED_MAP_ITERS and _MIXED_MAP_GPUS, which can be set using jax_moseq.utils.set_mixed_map_iters() and jax_moseq.utils.set_mixed_map_gpus(), respectively.
Given an axis size of N to map, the data is padded such that the axis size is a multiple of _MIXED_MAP_ITERS * _MIXED_MAP_GPUS. The data is then processed serially in chunks, where the number of chunks is determined by _MIXED_MAP_ITERS. Each chunk is processed in parallel, using jax.pmap to distribute across _MIXED_MAP_GPUS devices and jax.vmap within each device.
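The padding and chunking scheme described for mixed_map can be sketched on the CPU with NumPy. In this hypothetical `mixed_map_sketch`, the serial Python loop stands in for jax.lax.map, the per-device list comprehension stands in for jax.pmap, and `fun` is assumed to be already vectorized over the leading axis (the role jax.vmap plays). Padding by repeating the last element is an assumption; the padded results are dropped before returning.

```python
import numpy as np


def mixed_map_sketch(fun, x, iters=1, gpus=1):
    """Map a vectorized `fun` over axis 0 of `x` in (iters, gpus, per_device) chunks."""
    n = x.shape[0]
    block = iters * gpus
    n_pad = -(-n // block) * block              # round n up to a multiple of iters * gpus
    pad = np.repeat(x[-1:], n_pad - n, axis=0)  # assumption: repeat the last element
    x_pad = np.concatenate([x, pad], axis=0)
    # Axis 0: serial chunks; axis 1: devices; axis 2: items vectorized per device
    chunks = x_pad.reshape(iters, gpus, n_pad // block, *x.shape[1:])
    outs = []
    for chunk in chunks:                                      # serial, like jax.lax.map
        per_device = [fun(device_batch) for device_batch in chunk]  # like jax.pmap
        outs.append(np.stack(per_device))
    out = np.stack(outs).reshape(n_pad, *outs[0].shape[2:])
    return out[:n]                                            # drop padded entries
```

Raising `iters` lowers the number of items processed at once (peak memory), while `gpus` sets how many devices share each chunk; the output matches a plain vectorized call either way.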