nemos.observation_models.PoissonObservations
- class nemos.observation_models.PoissonObservations(inverse_link_function=jax.numpy.exp)
Bases: Observations
Model observations as Poisson random variables.
PoissonObservations models observed spike counts as draws from a Poisson distribution with a given rate. It provides methods for computing the log-likelihood, generating samples, and computing the residual deviance of spike count data.
- inverse_link_function
A function that maps the linear predictor to the domain of the Poisson rate parameter (the positive reals). Defaults to jax.numpy.exp.
Attributes
inverse_link_function – Getter for the inverse link function for the model.
scale – Getter for the scale parameter of the model.
Methods
__init__([inverse_link_function]) – Check if the provided inverse_link_function is usable.
deviance(spike_counts, predicted_rate[, scale]) – Compute the residual deviance for a Poisson model.
estimate_scale(y, predicted_rate, dof_resid) – Assign 1 to the scale parameter of the Poisson model.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
log_likelihood(y, predicted_rate[, scale, ...]) – Compute the Poisson log-likelihood.
pseudo_r2(y, predicted_rate[, score_type, ...]) – Pseudo-\(R^2\) calculation for a GLM.
sample_generator(key, predicted_rate[, scale]) – Sample from the Poisson distribution.
set_params(**params) – Set the parameters of this estimator.
- static check_inverse_link_function(inverse_link_function)
Check if the provided inverse_link_function is usable.
This function verifies that the inverse link function:
  - is callable,
  - returns a jax.numpy.ndarray, and
  - is differentiable (via JAX).
- Parameters:
inverse_link_function (Callable) – The function to be checked.
- Raises:
TypeError – If the function is not callable, does not return a jax.numpy.ndarray, or is not differentiable.
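The three checks listed above can be sketched with plain JAX calls. The helper below, is_usable_inverse_link, is hypothetical and is not the nemos implementation: it returns a boolean instead of raising TypeError, and it assumes that a successful jax.grad call on a scalar input is an adequate differentiability test.

```python
import jax
import jax.numpy as jnp


def is_usable_inverse_link(func) -> bool:
    """Hypothetical helper mirroring the three checks described above."""
    # 1. The object must be callable.
    if not callable(func):
        return False
    try:
        # 2. Applied to a JAX array, it must return a JAX array.
        out = func(jnp.array([0.5, 1.0, 2.0]))
        if not isinstance(out, jax.Array):
            return False
        # 3. jax.grad must succeed on a scalar input (this requires a real-valued output).
        jax.grad(func)(1.0)
    except Exception:
        # Any failure while calling or differentiating counts as "not usable".
        return False
    return True


print(is_usable_inverse_link(jnp.exp))                       # True
print(is_usable_inverse_link(jax.nn.softplus))               # True
print(is_usable_inverse_link(lambda x: jnp.asarray(x) > 0))  # False: boolean output
```

Returning a boolean keeps the sketch easy to test; a version that matches the documented behavior would raise TypeError at each failed check.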
- deviance(spike_counts, predicted_rate, scale=1.0)
Compute the residual deviance for a Poisson model.
- Parameters:
spike_counts (Array) – The spike counts. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
predicted_rate (Array) – The predicted firing rates. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.
scale (Union[float, Array]) – Scale parameter of the model.
- Return type:
Array
- Returns:
The residual deviance of the model.
Notes
The deviance is a measure of the goodness of fit of a statistical model. For a Poisson model, the residual deviance is computed as:
\[\begin{aligned} D(y_{tn}, \hat{y}_{tn}) &= 2 \left[ y_{tn} \log\left(\frac{y_{tn}}{\hat{y}_{tn}}\right) - (y_{tn} - \hat{y}_{tn}) \right] \\ &= 2 \left( \text{LL}\left(y_{tn} | y_{tn}\right) - \text{LL}\left(\hat{y}_{tn} | y_{tn}\right)\right), \end{aligned}\]
where \(y\) is the observed data, \(\hat{y}\) is the predicted data, and \(\text{LL}\) is the model log-likelihood; \(\text{LL}(y_{tn} | y_{tn})\) is the log-likelihood of the saturated model, which predicts the observations exactly. The convention \(0 \log 0 = 0\) applies when \(y_{tn} = 0\). Lower values of deviance indicate a better fit.
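As a sanity check of the two lines of the formula, the sketch below computes the elementwise deviance directly and from log-likelihood differences. It relies on jax.scipy.special.xlogy to apply the \(0 \log 0 = 0\) convention; the helper names poisson_deviance and poisson_ll and the example arrays are illustrative, not part of nemos.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def poisson_deviance(y, y_hat):
    """Elementwise 2[y log(y / y_hat) - (y - y_hat)], with 0 log 0 = 0."""
    # xlogy(y, a) returns 0 wherever y == 0, which implements the convention.
    return 2.0 * (xlogy(y, y) - xlogy(y, y_hat) - (y - y_hat))


def poisson_ll(rate, y):
    """Elementwise log-likelihood LL(rate | y) = y log(rate) - rate - log(y!)."""
    return xlogy(y, rate) - rate - gammaln(y + 1.0)


y = jnp.array([0.0, 1.0, 3.0, 2.0])      # observed counts
y_hat = jnp.array([0.5, 1.0, 2.0, 2.5])  # predicted rates

d_direct = poisson_deviance(y, y_hat)
# Second line of the formula: twice the saturated-minus-model log-likelihood.
d_from_ll = 2.0 * (poisson_ll(y, y) - poisson_ll(y_hat, y))
print(jnp.allclose(d_direct, d_from_ll, atol=1e-5))  # True
```

The \(\log(y_{tn}!)\) terms cancel in the difference, which is why the first line of the formula contains no factorial.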
- estimate_scale(y, predicted_rate, dof_resid)
Assign 1 to the scale parameter of the Poisson model.
For the Poisson exponential family distribution, the scale parameter \(\phi\) is always 1. This property is consistent with the fact that the variance equals the mean in a Poisson distribution. As given in the general exponential family expression:
\[\text{var}(Y) = a(\phi)\, V(\mu),\]
and for the Poisson family it simplifies to \(\text{var}(Y) = \mu\), since \(a(\phi) = 1\) and \(V(\mu) = \mu\).
- Parameters:
y (Array) – Observed spike counts.
predicted_rate (Array) – The predicted rate values. This is not used in the Poisson model for estimating scale, but is retained for compatibility with the abstract method signature.
dof_resid (Union[float, Array]) – The degrees of freedom (DOF) of the residuals.
- Return type:
Union[float, Array]
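The statement that \(\text{var}(Y) = \mu\) for Poisson data, which is why the scale is fixed to 1, can be checked numerically. This sketch draws samples with jax.random.poisson at an arbitrary, assumed rate of 4 and computes a Pearson-style dispersion estimate with the true mean known. That estimate is a standard diagnostic, not what estimate_scale does; estimate_scale simply assigns 1.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu = 4.0  # assumed true rate for this illustration

# Draw many Poisson counts with the same mean.
y = jax.random.poisson(key, mu, shape=(200_000,)).astype(jnp.float32)

sample_mean = jnp.mean(y)
sample_var = jnp.var(y)

# Pearson dispersion with known mean: average of (y - mu)^2 / V(mu), where V(mu) = mu.
pearson_phi = jnp.mean((y - mu) ** 2 / mu)

print(sample_mean, sample_var, pearson_phi)  # mean and variance both near 4, phi near 1
```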
- get_params(deep=True)
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep
- Return type:
dict
- Returns:
- out:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
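get_params follows the scikit-learn pattern of reading parameter names from the __init__ signature. A minimal sketch of that pattern, modeled on scikit-learn's BaseEstimator.get_params, is shown below; ToyEstimator and its parameters component and alpha are hypothetical.

```python
import inspect


class ToyEstimator:
    """Hypothetical estimator used only to illustrate the get_params pattern."""

    def __init__(self, component=None, alpha=1.0):
        self.component = component
        self.alpha = alpha

    @classmethod
    def _get_param_names(cls):
        # Read parameter names from the __init__ signature, skipping `self`.
        sig = inspect.signature(cls.__init__)
        return sorted(p.name for p in sig.parameters.values() if p.name != "self")

    def get_params(self, deep=True):
        out = {}
        for name in self._get_param_names():
            value = getattr(self, name)
            # With deep=True, parameters of nested estimators appear as <name>__<param>.
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                for sub_name, sub_value in value.get_params().items():
                    out[f"{name}__{sub_name}"] = sub_value
            out[name] = value
        return out


inner = ToyEstimator(alpha=0.1)
outer = ToyEstimator(component=inner)
print(sorted(outer.get_params(deep=True)))
# ['alpha', 'component', 'component__alpha', 'component__component']
```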
- property inverse_link_function
Getter for the inverse link function for the model.
- log_likelihood(y, predicted_rate, scale=1.0, aggregate_sample_scores=<function mean>)
Compute the Poisson log-likelihood.
This computes the Poisson log-likelihood of the predicted rates for the observed spike counts, aggregated over samples by aggregate_sample_scores.
- Parameters:
y (Array) – The target spikes to compare against. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
predicted_rate (Array) – The predicted rate of the current model. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
scale (Union[float, Array]) – The scale parameter of the model.
aggregate_sample_scores (Callable) – Function that aggregates the log-likelihood of each sample.
- Returns:
The Poisson log-likelihood. Shape (1,).
Notes
The Poisson mean log-likelihood is given by
\[\begin{aligned} \text{LL}(\hat{\lambda} | y) &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} \left[y_{tn} \log(\hat{\lambda}_{tn}) - \hat{\lambda}_{tn} - \log(y_{tn}!)\right] \\ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} \left[y_{tn} \log(\hat{\lambda}_{tn}) - \hat{\lambda}_{tn} - \log\Gamma(y_{tn}+1)\right] \\ &= \frac{1}{T \cdot N} \sum_{n=1}^{N} \sum_{t=1}^{T} \left[y_{tn} \log(\hat{\lambda}_{tn}) - \hat{\lambda}_{tn}\right] + \text{const}, \end{aligned}\]
where \(T\) is the number of time bins and \(N\) the number of neurons. The second line uses the identity \(\Gamma(k+1)=k!\) (see Wikipedia for an explanation).
The \(\log(y_{tn}!)\) term does not depend on the model parameters, so it can be disregarded when computing the loss function; this is why it is absorbed into the const term.
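The sketch below evaluates the mean log-likelihood from the Notes with and without the \(\log(y_{tn}!)\) term, and compares the full version against jax.scipy.stats.poisson.logpmf. The helper names and example arrays are illustrative assumptions; the point is that the gap between the two versions depends only on \(y\), so dropping it does not move the optimum over the rates.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson


def mean_ll(y, rate):
    """Mean of y*log(rate) - rate - log(y!) over all time bins and neurons."""
    return jnp.mean(y * jnp.log(rate) - rate - gammaln(y + 1.0))


def mean_ll_up_to_const(y, rate):
    """Same average with the parameter-free log(y!) term dropped."""
    return jnp.mean(y * jnp.log(rate) - rate)


# Illustrative data: T = 3 time bins, N = 2 neurons (rates must be positive).
y = jnp.array([[0.0, 2.0], [1.0, 4.0], [3.0, 0.0]])
rate_a = jnp.array([[0.5, 1.5], [1.0, 3.0], [2.5, 0.2]])
rate_b = 1.2 * rate_a

full = mean_ll(y, rate_a)
reference = jnp.mean(poisson.logpmf(y, rate_a))  # full log-pmf computed by JAX

# The offset introduced by dropping log(y!) is the same for both sets of rates.
gap_a = full - mean_ll_up_to_const(y, rate_a)
gap_b = mean_ll(y, rate_b) - mean_ll_up_to_const(y, rate_b)
```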
- pseudo_r2(y, predicted_rate, score_type='pseudo-r2-McFadden', scale=1.0, aggregate_sample_scores=<function mean>)
Pseudo-\(R^2\) calculation for a GLM.
Compute the pseudo-\(R^2\) metric for the GLM, as defined by McFadden et al. [1] or by Cohen et al. [2].
This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a constant mean for the observations. While the pseudo-\(R^2\) is bounded between 0 and 1 for the training set, it can yield negative values on out-of-sample data, indicating potential over-fitting.
- Parameters:
y (Array) – The neural activity. Expected shape: (n_time_bins, )
predicted_rate (Array) – The mean neural activity. Expected shape: (n_time_bins, )
score_type (Literal['pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – The pseudo-\(R^2\) type.
scale (Union[float, Array, NDArray]) – The scale parameter of the model.
aggregate_sample_scores (Callable) – Function that aggregates the log-likelihood of each sample.
- Return type:
Array
- Returns:
The pseudo-\(R^2\) of the model. A value closer to 1 indicates a better model fit, whereas a value closer to 0 suggests that the model doesn’t improve much over the null model.
Notes
The McFadden pseudo-\(R^2\) is given by:
\[R^2_{\text{mcf}} = 1 - \frac{\log(L_{M})}{\log(L_0)}.\]
This is equivalent to statsmodels GLMResults.pseudo_rsquared(kind='mcf').
The Cohen pseudo-\(R^2\) is given by:
\[\begin{aligned} R^2_{\text{Cohen}} &= \frac{D_0 - D_M}{D_0} \\ &= 1 - \frac{\log(L_s) - \log(L_M)}{\log(L_s)-\log(L_0)}, \end{aligned}\]
where \(L_M\), \(L_0\) and \(L_s\) are the likelihoods of the fitted model, the null model (a model with only the intercept term), and the saturated model (a model with one parameter per sample, whose likelihood is the maximum achievable), respectively. \(D_M\) and \(D_0\) are the model and null deviances, \(D_i = 2 \left[ \log(L_s) - \log(L_i) \right]\) for \(i = M, 0\).
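A sketch of both pseudo-\(R^2\) variants follows, assuming the null model predicts the constant mean of y and the saturated model predicts y itself. The helper names are hypothetical and this is not the nemos implementation; it uses summed log-likelihoods, which give the same ratios as means taken over the same samples.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def poisson_ll(y, rate):
    """Total Poisson log-likelihood, including the log(y!) term (0 log 0 = 0)."""
    return jnp.sum(xlogy(y, rate) - rate - gammaln(y + 1.0))


def pseudo_r2_sketch(y, predicted_rate):
    """Return (McFadden, Cohen) pseudo-R^2 following the formulas above."""
    ll_model = poisson_ll(y, predicted_rate)
    ll_null = poisson_ll(y, jnp.full_like(y, jnp.mean(y)))  # constant-mean model
    ll_sat = poisson_ll(y, y)                               # one parameter per sample
    mcfadden = 1.0 - ll_model / ll_null
    d_model = 2.0 * (ll_sat - ll_model)
    d_null = 2.0 * (ll_sat - ll_null)
    cohen = (d_null - d_model) / d_null
    return mcfadden, cohen


y = jnp.array([0.0, 1.0, 3.0, 2.0, 5.0, 1.0])
null_scores = pseudo_r2_sketch(y, jnp.full_like(y, jnp.mean(y)))  # the null model itself
perfect_scores = pseudo_r2_sketch(y, y)                            # the saturated model
partial_scores = pseudo_r2_sketch(y, 0.5 * (y + jnp.mean(y)))      # somewhere in between
```

Because the \(\log(y!)\) terms cancel in the deviances, Cohen's value does not depend on whether that constant is included, whereas McFadden's does; with the constant included, McFadden's value for the saturated model stays below 1 in this example.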
References
- sample_generator(key, predicted_rate, scale=1.0)
Sample from the Poisson distribution.
This method generates random numbers from a Poisson distribution based on the given predicted_rate.
- Parameters:
key (Array) – Random key used for the generation of random numbers in JAX.
predicted_rate (Array) – Expected rate (lambda) of the Poisson distribution. Shape (n_time_bins, ), or (n_time_bins, n_neurons).
scale (Union[float, Array]) – Scale parameter. For the Poisson model, it should equal 1.
- Returns:
Random numbers generated from the Poisson distribution based on the predicted_rate.
- Return type:
jnp.ndarray
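Sampling of this kind can be sketched directly with jax.random.poisson, which draws one count per entry of the rate array when no shape is given. The rate values and shapes below are arbitrary assumptions, and this is not necessarily how sample_generator is implemented internally.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(123)

# Assumed predicted rates: 1000 time bins, 3 neurons with different mean counts per bin.
predicted_rate = jnp.tile(jnp.array([0.5, 2.0, 8.0]), (1000, 1))

# One Poisson draw per (time bin, neuron); the output shape follows predicted_rate.
spike_counts = jax.random.poisson(key, predicted_rate)

# Empirical mean count per neuron, which should be close to the rates above.
print(jnp.mean(spike_counts, axis=0))
```

Reusing the same key reproduces the same draws; split the key with jax.random.split to obtain independent samples.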
- property scale
Getter for the scale parameter of the model.
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
**params (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance
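The <component>__<parameter> routing described above can be sketched as follows. ToyComponent is hypothetical, and this simplified version only illustrates how keys are split and forwarded; scikit-learn's actual set_params performs additional validation.

```python
class ToyComponent:
    """Hypothetical estimator illustrating '<component>__<parameter>' routing."""

    def __init__(self, **params):
        for name, value in params.items():
            setattr(self, name, value)

    def set_params(self, **params):
        nested = {}
        for key, value in params.items():
            name, delim, sub_key = key.partition("__")
            if delim:
                # '<component>__<parameter>': collect it for the nested component.
                nested.setdefault(name, {})[sub_key] = value
            else:
                if not hasattr(self, name):
                    raise ValueError(f"Invalid parameter {name!r}")
                setattr(self, name, value)
        # Forward the collected parameters to each nested component.
        for name, sub_params in nested.items():
            getattr(self, name).set_params(**sub_params)
        return self


inner = ToyComponent(alpha=1.0)
outer = ToyComponent(model=inner, verbose=False)
outer.set_params(verbose=True, model__alpha=0.5)
print(outer.verbose, inner.alpha)  # True 0.5
```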