import numpy as np
from numba import njit, prange
import jax
import jax.numpy as jnp
import jax.random as jr
eps = jnp.finfo(jnp.float32).tiny
from functools import partial
[docs]
@partial(jax.jit, static_argnames=("num_states"))
def count_transitions(num_states, stateseqs, mask):
"""
Count the number of transitions between each pair of
states ``i`` and ``j`` in the unmasked entries of ``stateseqs``,
including self transitions (i.e. i == j).
Parameters
----------
num_states: int
Total number of states (must exceed ``max(stateseqs)``).
stateseqs: jax int array of shape (..., T)
Batch of state sequences where the last dim indexes time.
All entries
mask: jax array of shape (..., T + num_lags)
Binary indicator for which elements of ``stateseqs`` are valid.
If ``num_lags > 0``, the first ``num_lags`` time points of the mask
are ignored (ensures time alignment with the AR process).
Returns
-------
transition_counts: jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
"""
T = stateseqs.shape[-1]
mask = mask[..., -T + 1 :]
start_states = stateseqs[..., :-1]
end_states = stateseqs[..., 1:]
transition_counts = jnp.zeros((num_states, num_states))
transition_counts = transition_counts.at[start_states, end_states].add(
mask
)
return transition_counts
@njit
def _sample_loyal_crf_table_counts(
rng, customer_counts, dish_ratings, loyalty
):
"""
In a Chinese restaurant franchise (CRF) process with loyal customers,
``table_counts[i, j]`` represents the number of tables in restaurant ``i``
that were served dish ``j``, which is a random variable that depends on:
(1) the observed number of patrons in the restaurant eating
the dish (``customer_counts[i, j]``),
(2) the franchise-wide popularity of the dish
(``dish_ratings[j]``), and
(3) the bias towards each restaurant's specialty dish
(i.e. the dish that shares its index ``i``), which is
encoded by ``loyalty``.
This function samples that value for each restaurant/dish pair. In
brief, each restaurant is a row of the transition matrix, each instance
of a customer in restaurant ``i`` eating dish ``j`` represents a transition
from ``i`` to ``j``, and the number of tables that served dish ``j``
throughout the franchise is used (after a correction step) for the resampling
of the franchise-wide ``dish_ratings`` (analogous to ``betas`` scaled by ``alpha``).
For a more thorough overview of the analogy and its relevance to the HDP-HMM
Gibbs sampling algorithm, see the reference (where `table_counts` corresponds
to the auxillary parameter m).
Parameters
----------
rng : instance of numpy.random.Generator
Seeded random number generator
customer_counts : numpy array of shape (N, N)
Number of customers for each restaurant/dish pair.
dish_ratings : numpy array of shape N
Parameter representing franchise-wide popularity of each dish.
loyalty : scalar
Non-negative scalar representing customers' bias for their
restaurant's specialty dish.
Returns
-------
table_counts : numpy array of shape (N, N)
Number of tables in each restaurant served each dish.
References
----------
See the supplement to Fox et al. 2011 at
<http://dx.doi.org/10.1214/10-AOAS395SUPP>.
"""
N = len(dish_ratings) # num restaurants/dishes
# Sample counts without considering loyalty factor
table_counts = np.zeros_like(customer_counts)
for i in prange(N):
for j in range(N):
# Sample counts by simulating table dish
# assignment process.
for k in range(customer_counts[i, j]):
dish_rating = dish_ratings[j]
if i == j:
# Account for loyalty factor
dish_rating += loyalty
p = dish_rating / (k + dish_rating)
bernoulli_sample = rng.random() < p
table_counts[i, j] += bernoulli_sample
return table_counts
def _sample_beta_suffient_stats(
seed, transition_counts, betas, alpha, kappa, gamma
):
"""
Compute the sufficient statistics for the Gibbs resampling
of ``betas`` using the auxillary parameter scheme devised
by Fox et al. for the Sticky HDP-HMM.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
transition_counts : jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
betas : jax array of shape num_states
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
gamma : scalar
Usage uniformity hyperparameter.
Returns
-------
sufficient_stats : jax array of shape num_states
Sufficient statistics for resampling `betas`.
"""
num_states = len(betas)
# Sample table counts (uses numpy/numba)
rng = np.random.default_rng(seed[0].item())
concentrations = np.array(alpha * betas)
transition_counts = np.array(transition_counts, dtype=np.int32)
# m in Fox et al.
table_counts = _sample_loyal_crf_table_counts(
rng, transition_counts, concentrations, kappa
)
# Downweight the influence of self transitions,
# which are less informative about state usages
auxillary_param = table_counts # corresponds to mbar in Fox et al.
diagonal_counts = np.diag(auxillary_param)
p = concentrations / (concentrations + kappa)
binomial_samples = rng.binomial(diagonal_counts, p)
np.fill_diagonal(auxillary_param, binomial_samples)
# Compute sufficient statistics
sufficient_stats = auxillary_param.sum(0) + (gamma / num_states)
sufficient_stats = jax.device_put(sufficient_stats)
return sufficient_stats
[docs]
def sample_betas(seed, transition_counts, betas, alpha, kappa, gamma):
"""
Sample the state usages ``betas`` given the observed transition
counts and the model hyperparameters.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
transition_counts : jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
betas : jax array of shape num_states
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
gamma : scalar
Usage uniformity hyperparameter.
Returns
-------
betas : jax array of shape num_states
Resampled state usages.
"""
sufficient_stats = _sample_beta_suffient_stats(
seed, transition_counts, betas, alpha, kappa, gamma
)
betas = jr.dirichlet(seed, sufficient_stats)
return betas
[docs]
def sample_pi(seed, transition_counts, betas, alpha, kappa):
"""
Sample the transition matrix ``pi`` given the observed transition
counts, state usages, and model hyperparameters.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
transition_counts : jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
betas : jax array of shape num_states
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
Returns
-------
pi : jax_array of shape (num_states, num_states)
Resampled transition probabilities.
"""
num_states = len(betas)
sufficient_stats = (
transition_counts + alpha * betas + kappa * jnp.eye(num_states)
)
pi = jr.dirichlet(seed, sufficient_stats)
return pi
[docs]
def sample_hdp_transitions(
seed, transition_counts, betas, alpha, kappa, gamma
):
"""
Sample the transition parameters of the HDP-HMM given
the observed transition counts, the current usage estimates,
and the model hyperparameters.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
transition_counts : jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
betas : jax array of shape num_states
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
gamma : scalar
Usage uniformity hyperparameter.
Returns
-------
betas : jax array of shape num_states
Resampled state usages.
pi : jax_array of shape (num_states, num_states)
Resampled transition probabilities.
"""
seeds = jr.split(seed)
betas = sample_betas(
seeds[0], transition_counts, betas, alpha, kappa, gamma
)
pi = sample_pi(seeds[1], transition_counts, betas, alpha, kappa)
return betas, pi
[docs]
def resample_hdp_transitions(
seed, z, mask, betas, alpha, kappa, gamma, **kwargs
):
"""
Resample the transition parameters of the HDP-HMM.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
z : jax_array of shape (..., T - n_lags)
Discrete state sequences.
mask : jax array of shape (..., T)
Binary indicator for which data points are valid.
betas : jax array of shape (num_states,)
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
gamma : scalar
Usage uniformity hyperparameter.
kwargs : dict
Overflow, for convenience.
Returns
-------
betas : jax array of shape (num_states,)
Resampled state usages.
pi : jax_array of shape (num_states, num_states)
Resampled transition probabilities.
"""
num_states = len(betas)
transition_counts = count_transitions(num_states, z, mask)
betas, pi = sample_hdp_transitions(
seed, transition_counts, betas, alpha, kappa, gamma
)
return betas, pi
[docs]
def init_hdp_transitions(seed, num_states, alpha, kappa, gamma, **kwargs):
"""
Initialize the transition parameters of the HDP-HMM.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
num_states : int
Max number of HMM states.
betas : jax array of shape (num_states,)
State usages.
alpha : scalar
State usage influence hyperparameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
gamma : scalar
Usage uniformity hyperparameter.
kwargs : dict
Overflow, for convenience.
Returns
-------
betas : jax array of shape (num_states,)
Initial state usages.
pi : jax_array of shape (num_states, num_states)
Initial transition probabilities.
"""
seeds = jr.split(seed)
betas_init = jr.dirichlet(
seeds[0], jnp.full(num_states, gamma / num_states)
)
pseudo_counts = jnp.zeros((num_states, num_states))
betas, pi = sample_hdp_transitions(
seeds[1], pseudo_counts, betas_init, alpha, kappa, gamma
)
# pseudocount for numerical stability
pi = (pi + eps) / (pi + eps).sum(1)[:, None]
return betas, pi
[docs]
def sample_dir_transitions(seed, transition_counts, beta, kappa):
"""
Sample a transition matrix using a sticky Dirichlet prior.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
transition_counts : jax array of shape (num_states, num_states)
The number of transitions between every pair of states.
beta : scalar
Dirichlet prior concentration parameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
Returns
-------
pi : jax_array of shape (num_states, num_states)
Transition probabilities.
"""
num_states = transition_counts.shape[0]
conc = (beta + transition_counts) + kappa * jnp.eye(num_states)
pi = jax.vmap(jr.dirichlet)(jr.split(seed, num_states), conc)
return pi
[docs]
def resample_dir_transitions(seed, num_states, z, mask, beta, kappa, **kwargs):
"""
Resample Markov transition probabilities using a sticky Dirichlet prior.
Parameters
----------
seed : jr.PRNGKey
JAX random seed.
num_states : int
Max number of HMM states.
z : jax_array of shape (..., T - n_lags)
Discrete state sequences.
mask : jax array of shape (..., T)
Binary indicator for which data points are valid.
beta : scalar
Dirichlet prior concentration parameter.
kappa : scalar
State persistence (i.e. "stickiness") hyperparameter.
Returns
-------
pi : jax_array of shape (num_states, num_states)
Resampled transition probabilities.
"""
transition_counts = count_transitions(num_states, z, mask)
pi = sample_dir_transitions(seed, transition_counts, beta, kappa)
return pi