Source code for jax_moseq.utils.utils

import numpy as np
import jax
import jax.numpy as jnp
from sklearn.decomposition import PCA
from jax.scipy.linalg import cho_factor, cho_solve
from textwrap import fill
import functools
from math import ceil

_MIXED_MAP_ITERS = 1
_MIXED_MAP_GPUS = 1


[docs] def concatenate_stateseqs(stateseqs, mask=None): """ Concatenate state sequences, optionally applying a mask. Parameters ---------- stateseqs: ndarray of shape (..., t), or dict or list of such arrays Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays. mask: ndarray of shape (..., >=t), default=None Binary indicator for which elements of `stateseqs` are valid, used in the case where `stateseqs` is an ndarray. If `mask` contains more time-points than `stateseqs`, the initial extra time-points will be ignored. Returns ------- stateseqs_flat: ndarray 1d array containing all state sequences """ if isinstance(stateseqs, dict): stateseq_flat = np.hstack(list(stateseqs.values())) elif isinstance(stateseqs, list): stateseq_flat = np.hstack(stateseqs) elif mask is not None: stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0] else: stateseq_flat = stateseqs.flatten() return stateseq_flat
[docs] def get_durations(stateseqs, mask=None): """ Get durations for a batch of state sequences. Parameters ---------- stateseqs: ndarray of shape (..., t), or dict or list of such arrays Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays. mask: ndarray of shape (..., >=t), default=None Binary indicator for which elements of `stateseqs` are valid, used in the case where `stateseqs` is an ndarray. If `mask` contains more time-points than `stateseqs`, the initial extra time-points will be ignored. Returns ------- durations: 1d array The duration of each each state (across all state sequences) Examples -------- >>> stateseqs = { 'name1': np.array([1, 1, 2, 2, 2, 3]), 'name2': np.array([0, 0, 0, 1]) } >>> get_durations(stateseqs) array([2, 3, 1, 3, 1]) """ stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int) stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]]) changepoints = np.diff(stateseq_padded).nonzero()[0] return changepoints[1:] - changepoints[:-1]
[docs] def get_frequencies(stateseqs, mask=None, num_states=None, runlength=True): """ Get state frequencies for a batch of state sequences. Parameters ---------- stateseqs: ndarray of shape (..., t), or dict or list of such arrays Batch of state sequences where the last dim indexes time, or a dict/list containing state sequences as 1d arrays. mask: ndarray of shape (..., >=t), default=None Binary indicator for which elements of `stateseqs` are valid, used in the case where `stateseqs` is an ndarray. If `mask` contains more time-points than `stateseqs`, the initial extra time-points will be ignored. num_states: int, default=None Number of different states. If None, the number of states will be set to `max(stateseqs)+1`. runlength: bool, default=True Whether to count frequency by the number of instances of each state (True), or by the number of frames in each state (False). Returns ------- frequencies: 1d array Frequency of each state across all state sequences Examples -------- >>> stateseqs = { 'name1': np.array([1, 1, 2, 2, 2, 3]), 'name2': np.array([0, 0, 0, 1])} >>> get_frequencies(stateseqs, runlength=True) array([0.2, 0.4, 0.2, 0.2]) >>> get_frequencies(stateseqs, runlength=False) array([0.3, 0.3, 0.3, 0.1]) """ stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int) if runlength: state_onsets = np.pad(np.diff(stateseq_flat).nonzero()[0] + 1, (1, 0)) stateseq_flat = stateseq_flat[state_onsets] counts = np.bincount(stateseq_flat, minlength=num_states) frequencies = counts / counts.sum() return frequencies
[docs] def symmetrize(A): """Symmetrize a matrix.""" return (A + A.swapaxes(-1, -2)) / 2
[docs] def psd_solve(A, B, diagonal_boost=1e-6): """ Solves the linear system Ax=B, assuming A is positive semi-definite. Uses Cholesky decomposition for improved numerical stability and efficiency. A is symmetrized and diagonal elements are boosted by ``diagonal_boost`` to ensure positive definiteness. Parameters ---------- A: jax array, shape (n,n) A positive semi-definite matrix b: jax array, shape (...,n) Returns ------- x: jax array, shape (...,n) Solution of the linear system Ax=b """ A = symmetrize(A) + diagonal_boost * jnp.eye(A.shape[-1]) L, lower = cho_factor(A, lower=True) x = cho_solve((L, lower), B) return x
[docs] def psd_inv(A, diagonal_boost=1e-6): """ Invert a positive semi-definite matrix. Uses :py:func:`jax_moseq.utils.psd_solve` for numerical stability and ensures that the inverse matrix is symmetric. Parameters ---------- A: jax array, shape (n,n) A positive semi-definite matrix Returns ------- Ainv: jax array, shape (n,n) The inverse of A """ Ainv = psd_solve(A, jnp.eye(A.shape[-1]), diagonal_boost=diagonal_boost) return symmetrize(Ainv)
[docs] def jax_io(fn): """ Converts a function involving numpy arrays to one that inputs and outputs jax arrays. """ return lambda *args, **kwargs: jax.device_put( fn(*jax.device_get(args), **jax.device_get(kwargs)) )
def device_put_as_scalar(x): as_scalar = lambda arr: arr.item() if arr.shape == () else arr return jax.tree_util.tree_map(as_scalar, jax.device_put(x)) def apply_affine(x, Ab): return jnp.einsum("...ij, ...j->...i", Ab, pad_affine(x))
[docs] def pad_affine(x): """ Pad ``x`` with 1's so that it can be affine transformed with matrix multiplication. """ padding = jnp.ones((*x.shape[:-1], 1)) xpadded = jnp.concatenate((x, padding), axis=-1) return xpadded
[docs] def fit_pca(Y, mask, PCA_fitting_num_frames=1000000, verbose=False, **kwargs): """ Fit a PCA model to transformed keypoint coordinates. Parameters ---------- Y: jax array, shape (..., d) Keypoint coordinates mask: jax array Binary indicator for which elements of ``Y`` are valid PCA_fitting_num_frames: int, default=1000000 Maximum number of frames to use for PCA. Frames will be sampled randomly if the input data exceed this size. verbose: bool, default=False Whether to print the number of sampled frames. Returns ------- pca, sklearn.decomposition._pca.PCA An sklearn PCA model fit to Y """ Y_flat = Y[mask > 0] N = Y_flat.shape[0] N_sample = min(PCA_fitting_num_frames, N) rng = np.random.default_rng(42) sample = rng.choice(N, N_sample, replace=False) Y_sample = np.array(Y_flat)[sample] if verbose: print(f"PCA: Fitting PCA model to {N_sample} data points") pca = PCA(random_state=42).fit(Y_sample) return pca
[docs] def wrap_angle(x): """ Wrap an angle to the range [-pi, pi]. """ return (x + jnp.pi) % (2 * jnp.pi) - jnp.pi
[docs] def pad_along_axis(arr, pad_widths, axis=0, value=0): """ Pad an array along a single axis Parameters ------- arr: ndarray, Array to be padded pad_widths: tuple (int,int), Amount of padding on either end axis: int, Axis along which to add padding value: float, Value of padded array elements Returns _______ padded_arr: ndarray """ pad_left_shape = list(arr.shape) pad_right_shape = list(arr.shape) pad_left_shape[axis] = pad_widths[0] pad_right_shape[axis] = pad_widths[1] padding_left = jnp.ones(pad_left_shape, dtype=arr.dtype) * value padding_right = jnp.ones(pad_right_shape, dtype=arr.dtype) * value padded_arr = jnp.concatenate([padding_left, arr, padding_right], axis=axis) return padded_arr
[docs] def unbatch(data, keys, bounds): """ Invert :py:func:`jax_moseq.utils.batch` Parameters ---------- data: ndarray, shape (num_segs, seg_length, ...) Stack of segmented time-series keys: list or array of str, length num_segs Name of the time-series that each segment came from bounds: ndarray, shape (num_segs, 2) Start and end times for each segment, reflecting how the segments were extracted from the original time-series. Returns ------- data_dict: dict Dictionary mapping names to reconstructed time-series """ data_dict = {} for key in set(list(keys)): length = bounds[keys == key, 1].max() seq = np.zeros((int(length), *data.shape[2:]), dtype=data.dtype) for (s, e), d in zip(bounds[keys == key], data[keys == key]): seq[s:e] = d[: e - s] data_dict[key] = seq return data_dict
[docs] def batch(data_dict, keys=None, seg_length=None, seg_overlap=30): """ Stack time-series data of different lengths into a single array for batch processing, optionally breaking up the data into fixed length segments. The data is padded so that the stacked array isn't ragged. The padding repeats the last frame of each time-series until the end of the segment. Parameters ---------- data_dict: dict {str : ndarray} Dictionary mapping names to ndarrays, where the first dim represents time. All data arrays must have the same shape except for the first dim. keys: list of str, default=None Optional list of names specifying which datasets to include in the output and what order to put them in. Each name must be a key in ``data_dict``. If ``keys=None``, names will be sorted alphabetically. seg_length: int, default=None Break each time-series into segments of this length. If ``seg_length=None``, the final stacked array will be as long as the longest time-series. seg_overlap: int, default=30 Amount of overlap between segments. For example, setting ``seg_length=N`` and ``seg_overlap=M`` will result in segments with start/end times (0, N+M), (N, 2*N+M), (2*N, 3*N+M),... Returns ------- data: ndarray, shape (N, seg_length, ...) Stacked data array mask: ndarray, shape (N, seg_length) Binary indicator specifying which elements of ``data`` are not padding (``mask==0`` in padded locations) metadata: tuple (keys, bounds) Metadata for the rows of `data`, as a tuple with an array of keys and an array of (start,end) times. """ if keys is None: keys = sorted(data_dict.keys()) Ns = [len(data_dict[key]) for key in keys] if seg_length is None: seg_length = np.max(Ns) stack, mask, keys_out, bounds = [], [], [], [] for key, N in zip(keys, Ns): for start in range(0, N, seg_length): arr = data_dict[key] end = min(start + seg_length + seg_overlap, N) pad_length = seg_length + seg_overlap - (end - start) padding = np.repeat(arr[end - 1 : end], pad_length, axis=0) mask.append(np.hstack([np.ones(end - start), np.zeros(pad_length)])) stack.append(np.concatenate([arr[start:end], padding], axis=0)) keys_out.append(key) bounds.append((start, end)) stack = np.stack(stack) mask = np.stack(mask) metadata = (np.array(keys_out), np.array(bounds)) return stack, mask, metadata
[docs] def get_mixed_map_iters(): """Get the number of iterations to use for jax.lax.map in :py:func:`jax_moseq.utils.mixed_map`.""" return _MIXED_MAP_ITERS
[docs] def set_mixed_map_iters(iters): """Set the number of iterations to use for jax.lax.map in :py:func:`jax_moseq.utils.mixed_map`.""" global _MIXED_MAP_ITERS _MIXED_MAP_ITERS = iters
[docs] def get_mixed_map_gpus(): """Get the number of GPUs to use for jax.pmap in :py:func:`jax_moseq.utils.mixed_map`.""" return _MIXED_MAP_GPUS
[docs] def set_mixed_map_gpus(gpus): """Set the number of GPUs to use for jax.pmap in :py:func:`jax_moseq.utils.mixed_map`.""" global _MIXED_MAP_GPUS _MIXED_MAP_GPUS = gpus
def _reshape_args(args, axes): """Reshape args to (pmap dim, lax.map dim, vmap dim, [other dims])""" n_iters = get_mixed_map_iters() n_gpus = get_mixed_map_gpus() axis_size = args[0].shape[axes[0]] vmap_size = ceil(axis_size / n_iters / n_gpus) pmap_size = ceil(axis_size / vmap_size / n_iters) lmap_size = ceil(axis_size / vmap_size / pmap_size) padding = vmap_size * pmap_size * lmap_size - axis_size def _reshape(a, axis): if axis > 0: a = jnp.moveaxis(a, axis, 0) if padding > 0: padding_array = jnp.zeros((padding, *a.shape[1:]), dtype=a.dtype) a = jnp.concatenate((a, padding_array)) return a.reshape(lmap_size, pmap_size, vmap_size, *a.shape[1:]) args = [_reshape(arg, axis) for arg, axis in zip(args, axes)] return args, axis_size def _reshape_outputs(outputs, axes, axis_size): """Reshape outputs from (lax.map dim, vmap dim, [other dims])""" def _reshape(a, axis): a = a.reshape(-1, *a.shape[3:])[:axis_size] if axis > 0: a = jnp.moveaxis(a, 0, axis) return a outputs = [_reshape(out, axis) for out, axis in zip(outputs, axes)] if len(outputs) == 1: outputs = outputs[0] return outputs def _partial(fun, other_args, mapped_argnums, other_argnums): def partial_fun(mapped_args): args = {} for i, arg in zip(mapped_argnums, mapped_args): args[i] = arg for i, arg in zip(other_argnums, other_args): args[i] = arg args = [args[i] for i in range(len(args))] return fun(*args) return partial_fun def _sort_args(args, in_axes): """Sort arguments into mapped and unmapped arguments.""" mapped_args, mapped_argnums = [], [] other_args, other_argnums = [], [] for i, (arg, axis) in enumerate(zip(args, in_axes)): if axis is not None: mapped_args.append(arg) mapped_argnums.append(i) else: other_args.append(arg) other_argnums.append(i) return mapped_args, mapped_argnums, other_args, other_argnums
[docs] def mixed_map(fun, in_axes=None, out_axes=None): """ Combine jax.pmap, jax.vmap and jax.lax.map for parallelization. This function is similar to `jax.vmap`, except that it mixes together `jax.pmap`, `jax.vmap` and `jax.lax.map` to prevent OOM errors and allow for parallelization across multiple GPUs. The behavior is determined by the global variables `_MIXED_MAP_ITERS` and `_MIXED_MAP_GPUS`, which can be set using :py:func:`jax_moseq.utils.set_mixed_map_iters` and py:func:`jax_moseq.utils.set_mixed_map_gpus` respectively. Given an axis size of N to map, the data is padded such that the axis size is a multiple of the number of `_MIXED_MAP_ITERS * _MIXED_MAP_GPUS`. The data is then processed serially chunks, where the number of chunks is determined by `_MIXED_MAP_ITERS`. Each chunk is processed in parallel using jax.pmap to distribute across `_MIXED_MAP_GPUS` devices and jax.vmap within each device. """ @functools.wraps(fun) def mixed_map_f(*args): nonlocal in_axes nonlocal out_axes if in_axes is None: in_axes = tuple([0] * len(args)) else: assert len(in_axes) == len( args ), "`in_axes` should be a tuple with the same length as the number of arguments" mapped_args, mapped_argnums, other_args, other_argnums = _sort_args( args, in_axes ) mapped_args, axis_size = _reshape_args( mapped_args, [in_axes[i] for i in mapped_argnums] ) f = _partial(fun, other_args, mapped_argnums, other_argnums) outputs = jax.lax.map(jax.pmap(jax.vmap(f)), mapped_args) if not isinstance(outputs, tuple) or isinstance(outputs, list): outputs = (outputs,) if out_axes is None: out_axes = tuple([0] * len(outputs)) else: assert len(out_axes) == len( outputs ), "`out_axes` should be a tuple with the same length as the number of function outputs" outputs = _reshape_outputs(outputs, out_axes, axis_size) return outputs return mixed_map_f