Model observations as Categorical random variables.
The CategoricalObservations class models an observed categorical variable with a Categorical
distribution parameterized by the probability of each category.
It provides methods for computing the negative log-likelihood,
generating samples, and computing the residual deviance for the given categorical observations.
This distribution is equivalent to a multinomial with n=1.
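The equivalence can be checked numerically. The JAX sketch below uses illustrative values. For a one-hot vector with \(n = 1\), the multinomial coefficient \(\log(n!) - \sum_k \log(y_k!)\) is zero, so the multinomial log-pmf reduces to the Categorical log-pmf \(\sum_k y_k \log(p_k)\).

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import gammaln

# Illustrative logits for three categories; log_softmax yields log-probabilities,
# the form expected for predicted_rate.
logits = jnp.array([0.5, -1.0, 2.0])
log_p = log_softmax(logits)

# One-hot observation: category 2 was observed.
y = jnp.array([0.0, 0.0, 1.0])

# Categorical log-pmf: selects log p of the observed category.
categorical_logpmf = jnp.sum(y * log_p)

# Multinomial log-pmf with n trials: log(n!) - sum_k log(y_k!) + sum_k y_k log p_k.
n = jnp.sum(y)  # n = 1 for a one-hot vector
multinomial_logpmf = gammaln(n + 1.0) - jnp.sum(gammaln(y + 1.0)) + jnp.sum(y * log_p)
```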
This uses PEP-487 [1] to set the set_{method}_request methods. It
looks for the default request values, which are either set using
__metadata_request__* class attributes or inferred from method signatures.
The __metadata_request__* class attributes are used when a method
does not explicitly accept metadata through its arguments, or when the
developer wants to specify a request value other than the default None
for that metadata.
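The mechanism described above can be sketched in plain Python. This is a simplified illustration of the idea, not scikit-learn's actual implementation. It generates `set_{method}_request` methods in `__init_subclass__`, infers defaults from method signatures, and lets `__metadata_request__*` attributes override them. The classes `RequesterMixin` and `Estimator`, the helpers `get_requests` and `_make_setter`, and the choice of `fit`/`score` as routed methods are all hypothetical. Python name-mangles an attribute such as `__metadata_request__fit` defined in a class body, so the sketch matches these attributes by suffix in `vars(cls)`.

```python
import inspect


class RequesterMixin:
    """Hypothetical, simplified sketch of PEP 487-based request methods."""

    _methods = ("fit", "score")  # methods that may receive metadata (assumption)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._default_requests = {}
        for method in cls._methods:
            func = getattr(cls, method, None)
            if func is None:
                continue  # subclass does not define this method
            # 1) Infer defaults from the signature: each extra named argument
            #    (besides self, X, y) is metadata with request value None.
            defaults = {
                name: None
                for name, p in inspect.signature(func).parameters.items()
                if name not in ("self", "X", "y")
                and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
            }
            # 2) Apply __metadata_request__<method> attributes defined on this
            #    class. Python name-mangles them, so match by suffix.
            suffix = f"__metadata_request__{method}"
            for attr, value in vars(cls).items():
                if attr.endswith(suffix):
                    defaults.update(value)
            cls._default_requests[method] = defaults
            # 3) Attach the generated set_<method>_request method.
            setattr(cls, f"set_{method}_request", _make_setter(method))

    def get_requests(self, method):
        """Return the effective request values for `method`."""
        overrides = self.__dict__.get("_requests", {}).get(method, {})
        return {**type(self)._default_requests[method], **overrides}


def _make_setter(method):
    def setter(self, **requests):
        known = type(self)._default_requests[method]
        unknown = set(requests) - set(known)
        if unknown:
            raise TypeError(f"Unexpected metadata for {method}: {sorted(unknown)}")
        self.__dict__.setdefault("_requests", {}).setdefault(method, {}).update(requests)
        return self

    setter.__name__ = f"set_{method}_request"
    return setter


class Estimator(RequesterMixin):
    # Explicit, non-None default for metadata that fit() does not accept directly.
    __metadata_request__fit = {"groups": True}

    def fit(self, X, y, sample_weight=None):
        return self


est = Estimator().set_fit_request(sample_weight=True)
print(est.get_requests("fit"))  # {'sample_weight': True, 'groups': True}
print(hasattr(Estimator, "set_score_request"))  # False: no score() defined
```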
Compute the residual deviance for a Categorical model.
Parameters:
observations (Array) – One-hot encoded categories. Shape (n_time_bins,n_categories) or
(n_time_bins,n_neurons,n_categories).
predicted_rate (Array) – The log-probabilities of each category (output of log_softmax).
Shape (n_time_bins,n_categories) or (n_time_bins,n_neurons,n_categories).
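scale (Union[float, Array]) – Scale parameter of the model. For the Categorical model it should be equal to 1.
The residual deviance is computed as
\[
D = -2 \sum_{k} y_k \log(\hat{p}_k) = -2\,\text{LL},
\]
where \(y_k\) is the one-hot encoded observed category (1 if category \(k\) was observed, 0 otherwise),
\(\hat{p}_{k}\) is the predicted probability for category \(k\),
and \(\text{LL}\) is the model log-likelihood.
For the Categorical model, the saturated model has log-likelihood 0: it assigns probability 1 to the observed category, and \(\log(1) = 0\). The deviance \(2(0 - \text{LL})\) therefore reduces to \(-2\,\text{LL}\).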
Lower values of deviance indicate a better fit.
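The deviance formula can be sketched in JAX. The sketch assumes one-hot `y` and log-probabilities `log_p` with categories on the last axis. `categorical_residual_deviance` is a hypothetical helper, not the library's method. Returning one value per time bin (and per neuron) is also an assumption about the output shape.

```python
import jax.numpy as jnp
from jax.nn import log_softmax, one_hot


def categorical_residual_deviance(y, log_p):
    """Hypothetical helper: per-sample deviance -2 * sum_k y_k * log(p_k).

    y:     one-hot observations, categories on the last axis.
    log_p: log-probabilities (e.g. output of log_softmax), same shape as y.
    The saturated log-likelihood is 0, so D = 2 * (0 - LL) = -2 * LL.
    """
    return -2.0 * jnp.sum(y * log_p, axis=-1)


# Two time bins, three categories.
logits = jnp.array([[2.0, 0.0, -1.0],
                    [0.0, 0.0, 0.0]])
log_p = log_softmax(logits, axis=-1)
y = one_hot(jnp.array([0, 2]), 3)  # observed categories 0 and 2

dev = categorical_residual_deviance(y, log_p)
# The first bin puts high probability on the observed category, so its
# deviance is smaller than that of the second (uniform) bin, which is 2*log(3).
```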
Assign 1 to the scale parameter of the Categorical model.
For the Categorical (Multinomial with n=1) exponential family distribution, the scale parameter
\(\phi\) is always 1.
Parameters:
y (Array) – One-hot encoded categories. Shape (n_time_bins,n_categories) or
(n_time_bins,n_neurons,n_categories).
predicted_rate (Array) – The predicted log-probabilities. This is not used in the Categorical model for estimating
scale, but is retained for compatibility with the abstract method signature.
dof_resid (Union[float, Array]) – The DOF of the residuals.
This computes the likelihood of the predicted category probabilities
for the observed one-hot encoded categories.
Parameters:
y (Array) – One-hot encoded categories. Shape (n_time_bins,n_categories) or
(n_time_bins,n_neurons,n_categories).
predicted_rate (Array) – The log-probabilities for each category (output of log_softmax).
Shape (n_time_bins,n_categories) or (n_time_bins,n_neurons,n_categories).
scale (Union[float, Array]) – The scale parameter of the model. For the Categorical model it should be equal to 1.
aggregate_sample_scores (Callable) – Function that aggregates the likelihood of each sample.
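For one-hot data, the per-sample likelihood \(\prod_k p_k^{y_k}\) is simply the probability assigned to the observed category. The sketch below illustrates this. The helper `categorical_likelihood` is hypothetical. It uses `jnp.mean` as a stand-in for `aggregate_sample_scores` and omits `scale`, since that is fixed at 1.

```python
import jax.numpy as jnp
from jax.nn import one_hot


def categorical_likelihood(y, log_p, aggregate_sample_scores=jnp.mean):
    """Hypothetical helper; scale is omitted because it is fixed at 1."""
    # For one-hot y, exp(sum_k y_k log p_k) = prod_k p_k**y_k, i.e. the
    # probability assigned to the observed category in each sample.
    per_sample = jnp.exp(jnp.sum(y * log_p, axis=-1))
    return aggregate_sample_scores(per_sample)


probs = jnp.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
log_p = jnp.log(probs)
y = one_hot(jnp.array([0, 2]), 3)  # observed categories 0 and 2

lik = categorical_likelihood(y, log_p)  # mean of 0.7 and 0.6
```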
This computes the Categorical log-likelihood of the predicted category probabilities
for the observed one-hot encoded categories.
Parameters:
y (Array) – One-hot encoded categories. Shape (n_time_bins,n_categories) or
(n_time_bins,n_neurons,n_categories).
predicted_rate (Array) – The log-probabilities for each category (output of log_softmax).
Shape (n_time_bins,n_categories) or (n_time_bins,n_neurons,n_categories).
scale (Union[float, Array]) – The scale parameter of the model. For the Categorical model it should be equal to 1.
aggregate_sample_scores (Callable) – Function that aggregates the log-likelihood of each sample.
Returns:
The Categorical log-likelihood. Shape (1,).
Notes
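The formula for the Categorical mean log-likelihood is the following,
\[
\text{LL} = \frac{1}{T \cdot N} \sum_{t=1}^{T} \sum_{n=1}^{N} \sum_{k=1}^{K} y_{tnk} \log(p_{tnk}),
\]
where \(T\), \(N\) and \(K\) are the numbers of time bins, neurons and categories (\(N = 1\) for 2-D inputs),
\(p_{tnk}\) is the predicted probability of category \(k\) for neuron
\(n\) at time \(t\), \(y_{tnk}\) is the one-hot encoding
(1 if category \(k\) was observed, 0 otherwise), and the predicted_rate input contains
\(\log(p_{tnk})\).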
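The formula above takes only a few lines of JAX. `categorical_mean_log_likelihood` is a hypothetical name. Averaging over time bins and neurons assumes that aggregation is a mean.

```python
import jax.numpy as jnp
from jax.nn import log_softmax, one_hot


def categorical_mean_log_likelihood(y, log_p):
    """Hypothetical helper: mean over samples of sum_k y_k * log(p_k)."""
    # Summing over the category axis picks log p of the observed category for
    # every (t, n); jnp.mean then divides by T * N (or T for 2-D inputs).
    return jnp.mean(jnp.sum(y * log_p, axis=-1))


# T = 2 time bins, N = 2 neurons, K = 3 categories.
y = one_hot(jnp.array([[0, 1], [2, 0]]), 3)  # shape (2, 2, 3)

# Uniform predictions: every observed category has probability 1/3.
uniform = log_softmax(jnp.zeros((2, 2, 3)), axis=-1)
ll_uniform = categorical_mean_log_likelihood(y, uniform)  # equals -log(3)

# Same non-uniform probabilities for every (t, n).
probs = jnp.broadcast_to(jnp.array([0.5, 0.3, 0.2]), (2, 2, 3))
ll = categorical_mean_log_likelihood(y, jnp.log(probs))
```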
Compute the pseudo-\(R^2\) metric for the GLM, as defined by McFadden et al. [2]
or by Cohen et al. [3].
This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a
constant mean for the observations. While the pseudo-\(R^2\) is bounded between 0 and 1 for the
training set, it can yield negative values on out-of-sample data, indicating potential over-fitting.
Parameters:
y (Array) – The neural activity. Expected shape: (n_time_bins,)
predicted_rate (Array) – The mean neural activity. Expected shape: (n_time_bins,)
score_type (Literal['pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – The pseudo-\(R^2\) type.
Returns:
The pseudo-\(R^2\) of the model. A value closer to 1 indicates a better model fit,
whereas a value closer to 0 suggests that the model does not improve much over the null model.
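The McFadden pseudo-\(R^2\) is
\[
R^2_{\text{McF}} = 1 - \frac{\log(L_M)}{\log(L_0)},
\]
and the Cohen pseudo-\(R^2\) is
\[
R^2_{\text{Cohen}} = \frac{D_0 - D_M}{D_0} = 1 - \frac{\log(L_s) - \log(L_M)}{\log(L_s) - \log(L_0)},
\]
where \(L_M\), \(L_0\) and \(L_s\) are the likelihood of the fitted model, the null model (a
model with only the intercept term), and the saturated model (a model with one parameter per
sample, i.e. the maximum value that the likelihood could possibly achieve). \(D_M\) and \(D_0\) are
the model and the null deviance, \(D_i = 2 \left[ \log(L_s) - \log(L_i) \right]\) for \(i=M,0\).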
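For Categorical observations, the saturated model assigns probability 1 to each observed category, so \(\log(L_s) = 0\). The Cohen formula then reduces to \(1 - \log(L_M)/\log(L_0)\), and the two variants coincide. The JAX sketch below illustrates this under several assumptions:

- `y` is one-hot, rather than the `(n_time_bins,)` shape listed above.
- Mean log-likelihoods stand in for \(\log(L)\), which leaves both ratios unchanged.
- The null model predicts the empirical category frequencies, i.e. the intercept-only maximum-likelihood fit.

`pseudo_r2_categorical` is a hypothetical helper.

```python
import jax.numpy as jnp
from jax.nn import one_hot


def mean_log_likelihood(y, log_p):
    return jnp.mean(jnp.sum(y * log_p, axis=-1))


def pseudo_r2_categorical(y, log_p, score_type="pseudo-r2-McFadden"):
    """Hypothetical helper; y is one-hot with categories on the last axis."""
    # Null model: constant probabilities equal to the empirical category
    # frequencies, broadcast to every time bin. Clipping avoids log(0) for
    # categories that never occur (their y entries are 0, so they add nothing).
    freq = jnp.mean(y, axis=0)
    log_p_null = jnp.broadcast_to(jnp.log(jnp.clip(freq, 1e-12, 1.0)), y.shape)

    ll_model = mean_log_likelihood(y, log_p)
    ll_null = mean_log_likelihood(y, log_p_null)
    ll_sat = 0.0  # saturated model: probability 1 on each observed category

    if score_type == "pseudo-r2-McFadden":
        return 1.0 - ll_model / ll_null
    if score_type == "pseudo-r2-Cohen":
        dev_model = 2.0 * (ll_sat - ll_model)
        dev_null = 2.0 * (ll_sat - ll_null)
        return (dev_null - dev_model) / dev_null
    raise ValueError(f"Unknown score_type: {score_type}")


y = one_hot(jnp.array([0, 0, 1, 2]), 3)
# Illustrative predictions (not a fitted model): probability 0.8 on the
# observed category and 0.1 on each of the others.
log_p = jnp.log(0.1 + 0.7 * y)

r2_mcf = pseudo_r2_categorical(y, log_p, "pseudo-r2-McFadden")
r2_cohen = pseudo_r2_categorical(y, log_p, "pseudo-r2-Cohen")
```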
This method generates random category indices from a Categorical distribution based on the given
log-probabilities. Note that this returns category indices, not one-hot encodings.
Parameters:
key (Array) – Random key used for the generation of random numbers in JAX.
predicted_rate (Array) – Log-probabilities for each category (output of log_softmax).
Shape (n_time_bins,n_categories) or (n_time_bins,n_neurons,n_categories).
scale (Union[float, Array]) – Scale parameter. For the Categorical model it should be equal to 1.
Returns:
Random category indices sampled from the Categorical distribution.
Shape (n_time_bins,) or (n_time_bins,n_neurons).
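In JAX, `jax.random.categorical` draws indices from unnormalized logits along a given axis. Log-probabilities are valid logits, so the sampling step could look like the sketch below. This is an assumption about how such a method might be implemented, not its actual code. The input shapes follow the parameter description above.

```python
import jax
import jax.numpy as jnp
from jax.nn import log_softmax, one_hot

key = jax.random.PRNGKey(0)
key, key_2d, key_3d = jax.random.split(key, 3)

# 2-D input: (n_time_bins, n_categories) -> indices of shape (n_time_bins,).
log_p = log_softmax(jnp.tile(jnp.array([2.0, 0.0, -1.0]), (1000, 1)), axis=-1)
idx = jax.random.categorical(key_2d, log_p, axis=-1)

# 3-D input: (n_time_bins, n_neurons, n_categories) -> (n_time_bins, n_neurons).
log_p_3d = log_softmax(jnp.zeros((5, 4, 3)), axis=-1)
idx_3d = jax.random.categorical(key_3d, log_p_3d, axis=-1)

# Convert indices to one-hot if they are needed as observations.
y = one_hot(idx, log_p.shape[-1])
```

The last step matters because the sampler returns indices, while the other methods expect one-hot observations.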
The method works on simple estimators as well as on nested objects
(such as Pipeline). The latter have
parameters of the form <component>__<parameter> so that it’s
possible to update each component of a nested object.
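The `<component>__<parameter>` convention can be illustrated with a self-contained sketch. It uses hypothetical `Component` and `Pipeline` classes and a hypothetical `set_nested_params` helper. It mimics the routing idea only and is not the implementation used by scikit-learn or by this library.

```python
class Component:
    """Hypothetical leaf estimator with one parameter."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class Pipeline:
    """Hypothetical container that stores named components as attributes."""

    def __init__(self, **steps):
        for name, step in steps.items():
            setattr(self, name, step)


def set_nested_params(obj, **params):
    """Hypothetical helper: route '<component>__<parameter>' keys to sub-objects."""
    for key, value in params.items():
        component, sep, sub_key = key.partition("__")
        if sep:
            # Nested key: recurse into the named component with the remainder,
            # so 'a__b__alpha' reaches obj.a.b.alpha.
            set_nested_params(getattr(obj, component), **{sub_key: value})
        elif hasattr(obj, key):
            setattr(obj, key, value)
        else:
            raise ValueError(f"Invalid parameter {key!r} for {type(obj).__name__}")
    return obj


pipe = Pipeline(scaler=Component(alpha=1.0), glm=Component(alpha=0.1))
set_nested_params(pipe, glm__alpha=0.5)
print(pipe.glm.alpha, pipe.scaler.alpha)  # 0.5 1.0
```

Splitting only at the first `__` lets the same rule handle arbitrarily deep nesting through recursion.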