Utilities

Distributions

Functions:

sample_hmm_stateseq(seed, transition_matrix, ...)

Sample state sequences in a Markov chain.

jax_moseq.utils.distributions.sample_hmm_stateseq(seed, transition_matrix, log_likelihoods, mask)[source]

Sample state sequences in a Markov chain.

Parameters:
  • seed (jax.random.PRNGKey) – Random seed

  • transition_matrix (jax array, shape (num_states, num_states)) – Transition matrix

  • log_likelihoods (jax array, shape (num_timesteps, num_states)) – Sequence of log likelihoods of emissions given hidden state and parameters

  • mask (jax array, shape (num_timesteps,)) – Sequence indicating whether to use an emission (1) or not (0)

Returns:

  • log_norm (float) – Posterior marginal log likelihood

  • states (jax array, shape (num_timesteps,)) – Sequence of sampled states
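As an illustration only, the following NumPy sketch shows forward filtering followed by backward sampling, a standard way to draw a state sequence from an HMM posterior. It is not the library's implementation: the helper names `_logsumexp` and `ffbs_sketch`, the uniform initial state distribution, and the handling of masked timesteps (their log likelihoods are treated as zero) are all assumptions made for this example.

```python
import numpy as np

def _logsumexp(a, axis=None):
    # Numerically stable log(sum(exp(a))) along the given axis.
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)

def ffbs_sketch(rng, transition_matrix, log_likelihoods, mask):
    """Hypothetical helper: sample states from an HMM posterior."""
    T, K = log_likelihoods.shape
    log_P = np.log(transition_matrix)
    # Assumption: masked timesteps contribute no evidence.
    ll = np.where(mask[:, None] > 0, log_likelihoods, 0.0)

    # Forward pass: log_alpha[t, k] = log p(y_1..y_t, z_t = k).
    log_alpha = np.zeros((T, K))
    log_alpha[0] = ll[0] - np.log(K)  # assumed uniform initial distribution
    for t in range(1, T):
        log_alpha[t] = ll[t] + _logsumexp(log_alpha[t - 1][:, None] + log_P, axis=0)
    log_norm = _logsumexp(log_alpha[-1])

    # Backward pass: sample the last state, then each state given its successor.
    states = np.zeros(T, dtype=int)
    logits = log_alpha[-1]
    for t in range(T - 1, -1, -1):
        if t < T - 1:
            logits = log_alpha[t] + log_P[:, states[t + 1]]
        probs = np.exp(logits - _logsumexp(logits))
        states[t] = rng.choice(K, p=probs)
    return float(log_norm), states

rng = np.random.default_rng(0)
P = np.array([[0.9, 0.1], [0.1, 0.9]])
# Emissions that strongly favor state 0 for three steps, then state 1.
ll = np.array([[0.0, -50.0]] * 3 + [[-50.0, 0.0]] * 3)
log_norm, states = ffbs_sketch(rng, P, ll, np.ones(6))
```

With emissions this decisive, the sampled sequence follows the favored states and log_norm is dominated by that single path.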

Transitions

Functions:

count_transitions(num_states, stateseqs, mask)

Count the number of transitions between each pair of states i and j in the unmasked entries of stateseqs, including self transitions (i.e. i == j).

sample_betas(seed, transition_counts, betas, ...)

Sample the state usages betas given the observed transition counts and the model hyperparameters.

sample_pi(seed, transition_counts, betas, ...)

Sample the transition matrix pi given the observed transition counts, state usages, and model hyperparameters.

sample_hdp_transitions(seed, ...)

Sample the transition parameters of the HDP-HMM given the observed transition counts, the current usage estimates, and the model hyperparameters.

resample_hdp_transitions(seed, z, mask, ...)

Resample the transition parameters of the HDP-HMM.

init_hdp_transitions(seed, num_states, ...)

Initialize the transition parameters of the HDP-HMM.

sample_dir_transitions(seed, ...)

Sample a transition matrix using a sticky Dirichlet prior.

resample_dir_transitions(seed, num_states, ...)

Resample Markov transition probabilities using a sticky Dirichlet prior.

jax_moseq.utils.transitions.count_transitions(num_states, stateseqs, mask)[source]

Count the number of transitions between each pair of states i and j in the unmasked entries of stateseqs, including self transitions (i.e. i == j).

Parameters:
  • num_states (int) – Total number of states (must exceed max(stateseqs)).

  • stateseqs (jax int array of shape (..., T)) – Batch of state sequences where the last dim indexes time. All entries

  • mask (jax array of shape (..., T + num_lags)) – Binary indicator for which elements of stateseqs are valid. If num_lags > 0, the first num_lags time points of the mask are ignored (ensures time alignment with the AR process).

Returns:

transition_counts – The number of transitions between every pair of states.

Return type:

jax array of shape (num_states, num_states)
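The counting step can be pictured with a small NumPy sketch. This is an illustration rather than the library code: the name `count_transitions_sketch` is hypothetical, the mask is assumed to have the same shape as stateseqs (that is, num_lags = 0), and a transition is assumed to count only when both of its endpoints are unmasked.

```python
import numpy as np

def count_transitions_sketch(num_states, stateseqs, mask):
    """Hypothetical NumPy analogue; assumes mask.shape == stateseqs.shape."""
    stateseqs = np.asarray(stateseqs)
    mask = np.asarray(mask)
    start, end = stateseqs[..., :-1], stateseqs[..., 1:]
    # Assumption: a transition is valid only if both endpoints are unmasked.
    valid = (mask[..., :-1] * mask[..., 1:]).astype(bool)
    counts = np.zeros((num_states, num_states), dtype=int)
    # np.add.at accumulates repeated (i, j) index pairs correctly.
    np.add.at(counts, (start[valid], end[valid]), 1)
    return counts

z = np.array([[0, 0, 1, 1, 2], [2, 2, 0, 1, 1]])
m = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
counts = count_transitions_sketch(3, z, m)
```

Here the second sequence contributes only its first two transitions (2 to 2 and 2 to 0), because the remaining ones touch masked entries.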

jax_moseq.utils.transitions.sample_betas(seed, transition_counts, betas, alpha, kappa, gamma)[source]

Sample the state usages betas given the observed transition counts and the model hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.

  • betas (jax array of shape num_states) – State usages.

  • alpha (scalar) – State usage influence hyperparameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

  • gamma (scalar) – Usage uniformity hyperparameter.

Returns:

betas – Resampled state usages.

Return type:

jax array of shape num_states

jax_moseq.utils.transitions.sample_pi(seed, transition_counts, betas, alpha, kappa)[source]

Sample the transition matrix pi given the observed transition counts, state usages, and model hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.

  • betas (jax array of shape num_states) – State usages.

  • alpha (scalar) – State usage influence hyperparameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

Returns:

pi – Resampled transition probabilities.

Return type:

jax array of shape (num_states, num_states)

jax_moseq.utils.transitions.sample_hdp_transitions(seed, transition_counts, betas, alpha, kappa, gamma)[source]

Sample the transition parameters of the HDP-HMM given the observed transition counts, the current usage estimates, and the model hyperparameters.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.

  • betas (jax array of shape num_states) – State usages.

  • alpha (scalar) – State usage influence hyperparameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

  • gamma (scalar) – Usage uniformity hyperparameter.

Returns:

  • betas (jax array of shape num_states) – Resampled state usages.

  • pi (jax array of shape (num_states, num_states)) – Resampled transition probabilities.

jax_moseq.utils.transitions.resample_hdp_transitions(seed, z, mask, betas, alpha, kappa, gamma, **kwargs)[source]

Resample the transition parameters of the HDP-HMM.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • mask (jax array of shape (..., T)) – Binary indicator for which data points are valid.

  • betas (jax array of shape (num_states,)) – State usages.

  • alpha (scalar) – State usage influence hyperparameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

  • gamma (scalar) – Usage uniformity hyperparameter.

  • kwargs (dict) – Overflow, for convenience.

Returns:

  • betas (jax array of shape (num_states,)) – Resampled state usages.

  • pi (jax array of shape (num_states, num_states)) – Resampled transition probabilities.

jax_moseq.utils.transitions.init_hdp_transitions(seed, num_states, alpha, kappa, gamma, **kwargs)[source]

Initialize the transition parameters of the HDP-HMM.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • num_states (int) – Max number of HMM states.

  • betas (jax array of shape (num_states,)) – State usages.

  • alpha (scalar) – State usage influence hyperparameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

  • gamma (scalar) – Usage uniformity hyperparameter.

  • kwargs (dict) – Overflow, for convenience.

Returns:

  • betas (jax array of shape (num_states,)) – Initial state usages.

  • pi (jax array of shape (num_states, num_states)) – Initial transition probabilities.

jax_moseq.utils.transitions.sample_dir_transitions(seed, transition_counts, beta, kappa)[source]

Sample a transition matrix using a sticky Dirichlet prior.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • transition_counts (jax array of shape (num_states, num_states)) – The number of transitions between every pair of states.

  • beta (scalar) – Dirichlet prior concentration parameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

Returns:

pi – Transition probabilities.

Return type:

jax array of shape (num_states, num_states)
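A sticky Dirichlet update of this kind can be sketched in NumPy as follows. The exact parameterization used by the library is not stated here, so the concentration below (counts plus beta for every entry, plus kappa on the diagonal) is an assumption based on the usual conjugate update, and `sticky_dirichlet_sketch` is a hypothetical name.

```python
import numpy as np

def sticky_dirichlet_sketch(rng, transition_counts, beta, kappa):
    """Hypothetical helper: sample each row of pi from a Dirichlet posterior."""
    num_states = transition_counts.shape[0]
    # Assumed concentration: counts + beta everywhere, plus kappa on the diagonal.
    conc = transition_counts + beta + kappa * np.eye(num_states)
    # Each row of the transition matrix is an independent Dirichlet draw.
    return np.stack([rng.dirichlet(row) for row in conc])

rng = np.random.default_rng(0)
counts = np.array([[50, 2, 1], [3, 40, 2], [1, 1, 60]])
pi = sticky_dirichlet_sketch(rng, counts, beta=1.0, kappa=10.0)
```

Larger values of kappa add more prior mass to self transitions, which is what makes the prior "sticky".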

jax_moseq.utils.transitions.resample_dir_transitions(seed, num_states, z, mask, beta, kappa, **kwargs)[source]

Resample Markov transition probabilities using a sticky Dirichlet prior.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • num_states (int) – Max number of HMM states.

  • z (jax array of shape (..., T - n_lags)) – Discrete state sequences.

  • mask (jax array of shape (..., T)) – Binary indicator for which data points are valid.

  • beta (scalar) – Dirichlet prior concentration parameter.

  • kappa (scalar) – State persistence (i.e. “stickiness”) hyperparameter.

Returns:

pi – Resampled transition probabilities.

Return type:

jax array of shape (num_states, num_states)

Kalman smoothing

Functions:

kalman_sample(seed, ys, mask, zs, m0, S0, A, ...)

Run forward-filtering and backward-sampling to draw samples from the posterior of a 1st-order dynamic system with autoregressive dynamics of order n_lags.

ar_to_lds_emissions(Cd, R, y, m0, S0, nlags)

Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the emission terms and initial state distribution for a system with 1st-order dynamics in R^(D*L).

ar_to_lds_dynamics(Ab, Q)

Given a linear dynamical system with L'th-order autoregressive dynamics in R^D, returns the dynamics terms for a system with 1st-order dynamics in R^(D*L).

jax_moseq.utils.kalman.kalman_sample(seed, ys, mask, zs, m0, S0, A, B, Q, C, D, Rs, masked_dynamics_params, masked_obs_noise, jitter=0, parallel=True)[source]

Run forward-filtering and backward-sampling to draw samples from the posterior of a 1st-order dynamic system with autoregressive dynamics of order n_lags.

Parameters:
  • seed (jr.PRNGKey) – JAX random seed.

  • ys (jax.Array with shape (T, obs_dim)) – Continuous observations, minus first L+1 frames.

  • mask (jax.Array with shape (T,)) – Indicator of observation validity, for timesteps [L-1, T)

  • zs (jax.Array with shape (T-n_lags,)) – Discrete state sequence, taking integer values [0, n_states).

  • m0 (jax.Array with shape (ar_dim,)) – Initial continuous state mean

  • S0 (jax.Array with shape (ar_dim, ar_dim)) – Initial continuous state covariance

  • A (jax.Array with shape (n_states, ar_dim, ar_dim)) – State dynamics matrix

  • B (jax.Array with shape (n_states, ar_dim)) – State input matrix

  • Q (jax.Array with shape (n_states, ar_dim, ar_dim)) – State noise matrix

  • C (jax.Array with shape (obs_dim, ar_dim)) – Observation transform matrix

  • D (jax.Array with shape (obs_dim,)) – Observation input matrix

  • Rs (jax.Array with shape (T, obs_dim)) – Observation noise scales (diagonal entries of covariance)

  • masked_dynamics_params (dict with key-value pairs) –

    • weights: jax.Array with shape (ar_dim, ar_dim)

    • bias: jax.Array with shape (ar_dim,)

    • cov: jax.Array with shape (ar_dim, ar_dim)

    Dynamics parameters, for masked timesteps

  • masked_obs_noise (jax.Array with shape (obs_dim,)) – Diagonal observation noise scale, for masked timesteps.

  • jitter (float, default=0) – Amount to boost the diagonal of the covariance matrix during backward-sampling of the continuous states.

  • parallel (bool, default=True) – Use associative scan for Kalman sampling, which is faster on a GPU but has a significantly longer jit time.

Returns:

xs – Sampled continuous state sequence.

Return type:

jax.Array with shape (T, ar_dim)

jax_moseq.utils.kalman.ar_to_lds_emissions(Cd, R, y, m0, S0, nlags)[source]

Given a linear dynamical system with L’th-order autoregressive dynamics in R^D, returns the emission terms and initial state distribution for a system with 1st-order dynamics in R^(D*L).

Parameters:
  • Cd (jax array, shape (D_obs, D+1)) – Observation affine transformation

  • R (jax array, shape (T, D_obs)) – Dimension-wise observation covariances

  • y (jax array, shape (T, D_obs)) – Observations

  • m0 (jax array, shape (D)) – Initial state distribution mean

  • S0 (jax array, shape (D, D)) – Initial state distribution cov

  • nlags (int) – Number of autoregressive lags

Returns:

  • C_ (jax array, shape (D_obs, D*L))

  • d_ (jax array, shape (D_obs))

  • R_ (jax array, shape (T, D_obs))

  • y_ (jax array, shape (T, D_obs))

  • m0_ (jax array, shape (D*L))

  • S0_ (jax array, shape (D*L, D*L))

jax_moseq.utils.kalman.ar_to_lds_dynamics(Ab, Q)[source]

Given a linear dynamical system with L’th-order autoregressive dynamics in R^D, returns the dynamics terms for a system with 1st-order dynamics in R^(D*L).

Parameters:
  • Ab (jax array, shape (..., D, D*L + 1)) – AR affine transform

  • Q (jax array, shape (..., D, D)) – AR noise covariance

Returns:

  • A_ (jax array, shape (…, D*L, D*L))

  • b_ (jax array, shape (…, D*L))

  • Q_ (jax array, shape (…, D*L, D*L))
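The conversion from L'th-order AR dynamics to 1st-order dynamics is the usual companion-form construction, sketched below in NumPy for a single (unbatched) set of parameters. Several details are assumptions for illustration: the name `ar_to_lds_dynamics_sketch`, the bias occupying the last column of Ab, the lags being stacked oldest first, and the noise covariance being exactly zero outside the block of the newest lag (a real implementation may regularize those entries).

```python
import numpy as np

def ar_to_lds_dynamics_sketch(Ab, Q):
    """Hypothetical helper: companion-form dynamics for an AR(L) process."""
    D = Ab.shape[0]
    L = (Ab.shape[1] - 1) // D
    A, b = Ab[:, :-1], Ab[:, -1]  # assumed layout: weights, then bias column
    A_ = np.zeros((D * L, D * L))
    # Top rows shift the stacked state forward by one lag.
    A_[: D * (L - 1), D:] = np.eye(D * (L - 1))
    # Bottom rows apply the AR weights to produce the newest x.
    A_[D * (L - 1):, :] = A
    b_ = np.zeros(D * L)
    b_[D * (L - 1):] = b
    # Only the newest lag receives process noise in this sketch.
    Q_ = np.zeros((D * L, D * L))
    Q_[D * (L - 1):, D * (L - 1):] = Q
    return A_, b_, Q_

rng = np.random.default_rng(0)
D, L = 2, 3
Ab = rng.normal(size=(D, D * L + 1))
A_, b_, Q_ = ar_to_lds_dynamics_sketch(Ab, np.eye(D))

s = rng.normal(size=D * L)           # stacked lags, oldest first
x_next = Ab[:, :-1] @ s + Ab[:, -1]  # one noiseless AR step
s_next = A_ @ s + b_                 # the same step in companion form
```

The last D entries of s_next equal x_next, and the remaining entries are the previous lags shifted by one position.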

Autoregression

Functions:

get_lags(x, nlags)

Get lags of a multivariate time series.

jax_moseq.utils.autoregression.get_lags(x, nlags)[source]

Get lags of a multivariate time series. Lags are concatenated along the last dim in time-order.

Parameters:
  • nlags (int) – Number of lags

  • x (jax array, shape (..., t, d)) – Batch of d-dimensional time series

Returns:

x_lagged

Return type:

jax array, shape (…, t-nlags, d*nlags)
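The output shape can be checked with a short NumPy sketch. It assumes that "time-order" means the oldest lag comes first along the last dimension; `get_lags_sketch` is a hypothetical stand-in, not the library function.

```python
import numpy as np

def get_lags_sketch(x, nlags):
    """Hypothetical helper: stack nlags consecutive frames along the last dim."""
    t = x.shape[-2]
    # Window k covers frames k .. k + (t - nlags) - 1; k = 0 is the oldest lag.
    windows = [x[..., k : t - nlags + k, :] for k in range(nlags)]
    return np.concatenate(windows, axis=-1)

x = np.arange(10).reshape(5, 2)  # t = 5, d = 2
x_lagged = get_lags_sketch(x, nlags=2)
```

Row i of x_lagged holds the nlags frames that precede frame i + nlags of x.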

Odds and ends

Functions:

concatenate_stateseqs(stateseqs[, mask])

Concatenate state sequences, optionally applying a mask.

get_durations(stateseqs[, mask])

Get durations for a batch of state sequences.

get_frequencies(stateseqs[, mask, ...])

Get state frequencies for a batch of state sequences.

symmetrize(A)

Symmetrize a matrix.

psd_solve(A, B[, diagonal_boost])

Solves the linear system Ax=B, assuming A is positive semi-definite.

psd_inv(A[, diagonal_boost])

Invert a positive semi-definite matrix.

jax_io(fn)

Converts a function involving numpy arrays to one that inputs and outputs jax arrays.

pad_affine(x)

Pad x with 1's so that it can be affine transformed with matrix multiplication.

fit_pca(Y, mask[, PCA_fitting_num_frames, ...])

Fit a PCA model to transformed keypoint coordinates.

wrap_angle(x)

Wrap an angle to the range [-pi, pi].

pad_along_axis(arr, pad_widths[, axis, value])

Pad an array along a single axis.

unbatch(data, keys, bounds)

Invert jax_moseq.utils.batch().

batch(data_dict[, keys, seg_length, seg_overlap])

Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments.

get_mixed_map_iters()

Get the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().

set_mixed_map_iters(iters)

Set the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().

get_mixed_map_gpus()

Get the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().

set_mixed_map_gpus(gpus)

Set the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().

mixed_map(fun[, in_axes, out_axes])

Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.

jax_moseq.utils.utils.concatenate_stateseqs(stateseqs, mask=None)[source]

Concatenate state sequences, optionally applying a mask.

Parameters:
  • stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.

  • mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.

Returns:

stateseqs_flat – 1d array containing all state sequences

Return type:

ndarray

jax_moseq.utils.utils.get_durations(stateseqs, mask=None)[source]

Get durations for a batch of state sequences.

Parameters:
  • stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.

  • mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.

Returns:

durations – The duration of each state (across all state sequences)

Return type:

1d array

Examples

>>> import numpy as np
>>> stateseqs = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1])
... }
>>> get_durations(stateseqs)
array([2, 3, 1, 3, 1])

jax_moseq.utils.utils.get_frequencies(stateseqs, mask=None, num_states=None, runlength=True)[source]

Get state frequencies for a batch of state sequences.

Parameters:
  • stateseqs (ndarray of shape (..., t), or dict or list of such arrays) – Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays.

  • mask (ndarray of shape (..., >=t), default=None) – Binary indicator for which elements of stateseqs are valid, used in the case where stateseqs is an ndarray. If mask contains more time-points than stateseqs, the initial extra time-points will be ignored.

  • num_states (int, default=None) – Number of different states. If None, the number of states will be set to max(stateseqs)+1.

  • runlength (bool, default=True) – Whether to count frequency by the number of instances of each state (True), or by the number of frames in each state (False).

Returns:

frequencies – Frequency of each state across all state sequences

Return type:

1d array

Examples

>>> import numpy as np
>>> stateseqs = {
...     'name1': np.array([1, 1, 2, 2, 2, 3]),
...     'name2': np.array([0, 0, 0, 1])
... }
>>> get_frequencies(stateseqs, runlength=True)
array([0.2, 0.4, 0.2, 0.2])
>>> get_frequencies(stateseqs, runlength=False)
array([0.3, 0.3, 0.3, 0.1])

jax_moseq.utils.utils.symmetrize(A)[source]

Symmetrize a matrix.

jax_moseq.utils.utils.psd_solve(A, B, diagonal_boost=1e-06)[source]

Solves the linear system Ax=B, assuming A is positive semi-definite.

Uses Cholesky decomposition for improved numerical stability and efficiency. A is symmetrized and diagonal elements are boosted by diagonal_boost to ensure positive definiteness.

Parameters:
  • A (jax array, shape (n,n)) – A positive semi-definite matrix

  • B (jax array, shape (...,n)) – Right-hand side of the linear system

Returns:

x – Solution of the linear system Ax=B

Return type:

jax array, shape (…,n)
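The description above (symmetrize, boost the diagonal, then solve via Cholesky) can be sketched in NumPy as follows. This is an illustrative stand-in named `psd_solve_sketch`, restricted to a single right-hand side of shape (n,). It uses the general-purpose np.linalg.solve for both triangular systems, whereas a production version would more likely call dedicated triangular solvers.

```python
import numpy as np

def psd_solve_sketch(A, B, diagonal_boost=1e-6):
    """Hypothetical helper: solve A x = B for a positive semi-definite A."""
    # Symmetrize and boost the diagonal so the Cholesky factorization succeeds.
    A = (A + A.T) / 2 + diagonal_boost * np.eye(A.shape[0])
    L = np.linalg.cholesky(A)  # A = L @ L.T with L lower triangular
    # Solve L y = B, then L.T x = y.
    y = np.linalg.solve(L, B)
    return np.linalg.solve(L.T, y)

A = np.array([[4.0, 1.0], [1.0, 3.0]])
B = np.array([1.0, 2.0])
x = psd_solve_sketch(A, B)
```

Because of the diagonal boost, x solves a slightly perturbed system, so A @ x matches B only up to an error on the order of diagonal_boost.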

jax_moseq.utils.utils.psd_inv(A, diagonal_boost=1e-06)[source]

Invert a positive semi-definite matrix.

Uses jax_moseq.utils.psd_solve() for numerical stability and ensures that the inverse matrix is symmetric.

Parameters:

A (jax array, shape (n,n)) – A positive semi-definite matrix

Returns:

Ainv – The inverse of A

Return type:

jax array, shape (n,n)

jax_moseq.utils.utils.jax_io(fn)[source]

Converts a function involving numpy arrays to one that inputs and outputs jax arrays.

jax_moseq.utils.utils.pad_affine(x)[source]

Pad x with 1’s so that it can be affine transformed with matrix multiplication.
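To illustrate why this padding is useful, the NumPy sketch below appends a column of ones so that a weight matrix and a bias can be applied in a single matrix multiplication. The helper name `pad_affine_sketch` and the choice to pad along the last axis are assumptions made for the example.

```python
import numpy as np

def pad_affine_sketch(x):
    """Hypothetical helper: append a constant 1 along the last axis."""
    ones = np.ones(x.shape[:-1] + (1,), dtype=x.dtype)
    return np.concatenate([x, ones], axis=-1)

# Weights A (first two columns) and bias b (last column) stored side by side.
Ab = np.array([[2.0, 0.0, 1.0], [0.0, 3.0, -1.0]])
x = np.array([[1.0, 1.0], [2.0, 0.5]])
y = pad_affine_sketch(x) @ Ab.T  # equals x @ A.T + b
```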

jax_moseq.utils.utils.fit_pca(Y, mask, PCA_fitting_num_frames=1000000, verbose=False, **kwargs)[source]

Fit a PCA model to transformed keypoint coordinates.

Parameters:
  • Y (jax array, shape (..., d)) – Keypoint coordinates

  • mask (jax array) – Binary indicator for which elements of Y are valid

  • PCA_fitting_num_frames (int, default=1000000) – Maximum number of frames to use for PCA. Frames will be sampled randomly if the input data exceed this size.

  • verbose (bool, default=False) – Whether to print the number of sampled frames.

Returns:

pca – An sklearn PCA model fit to Y

Return type:

sklearn.decomposition._pca.PCA

jax_moseq.utils.utils.wrap_angle(x)[source]

Wrap an angle to the range [-pi, pi].
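One common way to implement such wrapping is with a shifted modulo, sketched below in NumPy. The formula is an assumption about how the wrapping could be done, not a statement of the library's code. Note that this particular version maps an input of exactly pi to -pi.

```python
import numpy as np

def wrap_angle_sketch(x):
    """Hypothetical helper: wrap angles (in radians) with a shifted modulo."""
    # Shift by pi, wrap with modulo 2*pi, then shift back.
    return (x + np.pi) % (2 * np.pi) - np.pi

wrapped = wrap_angle_sketch(np.array([0.0, 3 * np.pi / 2, -3 * np.pi / 2]))
```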

jax_moseq.utils.utils.pad_along_axis(arr, pad_widths, axis=0, value=0)[source]

Pad an array along a single axis.

Parameters:
  • arr (ndarray) – Array to be padded

  • pad_widths (tuple (int,int)) – Amount of padding on either end

  • axis (int) – Axis along which to add padding

  • value (float) – Value of padded array elements

Returns:

padded_arr

Return type:

ndarray
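A minimal NumPy sketch of single-axis padding is shown below. It builds the per-axis width list that np.pad expects, with zero padding on every axis except the requested one. The name `pad_along_axis_sketch` is hypothetical.

```python
import numpy as np

def pad_along_axis_sketch(arr, pad_widths, axis=0, value=0):
    """Hypothetical helper: pad one axis of arr with a constant value."""
    # No padding on any axis except the requested one.
    widths = [(0, 0)] * arr.ndim
    widths[axis] = pad_widths
    return np.pad(arr, widths, mode="constant", constant_values=value)

padded = pad_along_axis_sketch(np.ones((2, 3)), (1, 2), axis=1, value=0)
```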

jax_moseq.utils.utils.unbatch(data, keys, bounds)[source]

Invert jax_moseq.utils.batch().

Parameters:
  • data (ndarray, shape (num_segs, seg_length, ...)) – Stack of segmented time-series

  • keys (list or array of str, length num_segs) – Name of the time-series that each segment came from

  • bounds (ndarray, shape (num_segs, 2)) – Start and end times for each segment, reflecting how the segments were extracted from the original time-series.

Returns:

data_dict – Dictionary mapping names to reconstructed time-series

Return type:

dict

jax_moseq.utils.utils.batch(data_dict, keys=None, seg_length=None, seg_overlap=30)[source]

Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments. The data is padded so that the stacked array isn’t ragged. The padding repeats the last frame of each time-series until the end of the segment.

Parameters:
  • data_dict (dict {str : ndarray}) – Dictionary mapping names to ndarrays, where the first dim represents time. All data arrays must have the same shape except for the first dim.

  • keys (list of str, default=None) – Optional list of names specifying which datasets to include in the output and what order to put them in. Each name must be a key in data_dict. If keys=None, names will be sorted alphabetically.

  • seg_length (int, default=None) – Break each time-series into segments of this length. If seg_length=None, the final stacked array will be as long as the longest time-series.

  • seg_overlap (int, default=30) – Amount of overlap between segments. For example, setting seg_length=N and seg_overlap=M will result in segments with start/end times (0, N+M), (N, 2*N+M), (2*N, 3*N+M),…

Returns:

  • data (ndarray, shape (N, seg_length, …)) – Stacked data array

  • mask (ndarray, shape (N, seg_length)) – Binary indicator specifying which elements of data are not padding (mask==0 in padded locations)

  • metadata (tuple (keys, bounds)) – Metadata for the rows of data, as a tuple with an array of keys and an array of (start,end) times.
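The segmentation and padding scheme can be sketched in NumPy as follows. This is an illustration under assumptions, not the library code: `batch_sketch` is a hypothetical name, every segment is given length seg_length + seg_overlap (following the start/end times described for seg_overlap), and segment starts are placed at every multiple of seg_length.

```python
import numpy as np

def batch_sketch(data_dict, seg_length, seg_overlap=0):
    """Hypothetical helper: cut time series into padded, overlapping segments."""
    keys, bounds, segs, masks = [], [], [], []
    full = seg_length + seg_overlap
    for key in sorted(data_dict):  # names sorted alphabetically
        x = data_dict[key]
        for start in range(0, len(x), seg_length):
            end = min(start + full, len(x))
            seg = x[start:end]
            n_pad = full - len(seg)
            # Pad by repeating the last frame; the mask is 0 on padded frames.
            pad = np.repeat(seg[-1:], n_pad, axis=0)
            segs.append(np.concatenate([seg, pad], axis=0))
            masks.append(np.concatenate([np.ones(len(seg)), np.zeros(n_pad)]))
            keys.append(key)
            bounds.append((start, end))
    return np.stack(segs), np.stack(masks), (np.array(keys), np.array(bounds))

data_dict = {"a": np.arange(5.0)[:, None], "b": np.arange(3.0)[:, None]}
data, mask, (keys, bounds) = batch_sketch(data_dict, seg_length=3, seg_overlap=1)
```

The keys and bounds returned as metadata are the information needed to map each segment back to its position in the original time series.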

jax_moseq.utils.utils.get_mixed_map_iters()[source]

Get the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().

jax_moseq.utils.utils.set_mixed_map_iters(iters)[source]

Set the number of iterations to use for jax.lax.map in jax_moseq.utils.mixed_map().

jax_moseq.utils.utils.get_mixed_map_gpus()[source]

Get the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().

jax_moseq.utils.utils.set_mixed_map_gpus(gpus)[source]

Set the number of GPUs to use for jax.pmap in jax_moseq.utils.mixed_map().

jax_moseq.utils.utils.mixed_map(fun, in_axes=None, out_axes=None)[source]

Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.

This function is similar to jax.vmap, except that it mixes together jax.pmap, jax.vmap and jax.lax.map to prevent OOM errors and allow for parallelization across multiple GPUs. The behavior is determined by the global variables _MIXED_MAP_ITERS and _MIXED_MAP_GPUS, which can be set using jax_moseq.utils.set_mixed_map_iters() and jax_moseq.utils.set_mixed_map_gpus(), respectively.

Given an axis size of N to map, the data is padded such that the axis size is a multiple of _MIXED_MAP_ITERS * _MIXED_MAP_GPUS. The data is then processed serially in chunks, where the number of chunks is determined by _MIXED_MAP_ITERS. Each chunk is processed in parallel, using jax.pmap to distribute across _MIXED_MAP_GPUS devices and jax.vmap within each device.
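The padding and chunking arithmetic can be illustrated without any accelerator using plain NumPy. In the sketch below, a Python loop over chunks plays the role of jax.lax.map and a per-element loop stands in for the combined jax.pmap and jax.vmap step. The name `chunked_map_sketch` and the choice to pad by repeating the last element are assumptions for illustration only.

```python
import numpy as np

def chunked_map_sketch(fun, xs, iters=2, gpus=1):
    """Hypothetical helper: emulate pad, chunk, map, and unpad."""
    n = xs.shape[0]
    block = iters * gpus
    # Pad the mapped axis up to the next multiple of iters * gpus.
    n_pad = (-n) % block
    padded = np.concatenate([xs, np.repeat(xs[-1:], n_pad, axis=0)], axis=0)
    outs = []
    # Process `iters` chunks one after another (the role of jax.lax.map).
    for chunk in np.split(padded, iters, axis=0):
        # Within a chunk, apply fun to every element (stand-in for pmap + vmap).
        outs.append(np.stack([fun(x) for x in chunk]))
    # Drop the outputs that correspond to padding.
    return np.concatenate(outs, axis=0)[:n]

xs = np.arange(10.0).reshape(5, 2)
out = chunked_map_sketch(lambda x: x.sum(), xs, iters=2, gpus=2)
```

With 5 inputs, iters=2 and gpus=2, the axis is padded to 8 and split into 2 chunks of 4, and the 3 padded outputs are discarded at the end.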