nemos.observation_models.Observations#

class nemos.observation_models.Observations(inverse_link_function, **kwargs)[source]#

Bases: Base, ABC

Abstract observation model class for neural data processing.

This is an abstract base class used to implement observation models for neural data. Specific observation models that inherit from this class should implement their own versions of the abstract methods, such as log_likelihood(), sample_generator(), and deviance().

Parameters:

inverse_link_function (Callable)

A function that transforms a set of predictors to the domain of the model parameter.

See also

PoissonObservations

A specific implementation of an observation model using the Poisson distribution.

GammaObservations

A specific implementation of an observation model using the Gamma distribution.
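
Example (a minimal sketch using the concrete PoissonObservations subclass; the softplus inverse link and the toy predictor values are illustrative choices, not requirements):

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> # build a Poisson observation model with a softplus inverse link
>>> obs = nmo.observation_models.PoissonObservations(
...     inverse_link_function=jax.nn.softplus
... )
>>> # map (illustrative) linear predictors to strictly positive rates
>>> rates = obs.inverse_link_function(jnp.array([-1.0, 0.0, 1.0]))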

Attributes

inverse_link_function

Getter for the inverse link function for the model.

scale

Getter for the scale parameter of the model.

__init__(inverse_link_function, **kwargs)[source]#
Parameters:

inverse_link_function (Callable)

Methods

__init__(inverse_link_function, **kwargs)

check_inverse_link_function(...)

Check if the provided inverse_link_function is usable.

deviance(spike_counts, predicted_rate[, scale])

Compute the residual deviance for the observation model.

estimate_scale(y, predicted_rate, dof_resid)

Estimate the scale parameter for the model.

get_params([deep])

From scikit-learn, get parameters by inspecting init.

log_likelihood(y, predicted_rate[, scale, ...])

Compute the observation model log-likelihood.

pseudo_r2(y, predicted_rate[, score_type, ...])

Pseudo-\(R^2\) calculation for a GLM.

sample_generator(key, predicted_rate[, scale])

Sample from the estimated distribution.

set_params(**params)

Set the parameters of this estimator.

check_inverse_link_function(inverse_link_function)[source]#

Check if the provided inverse_link_function is usable.

This function verifies if the inverse link function:

  1. Is callable

  2. Returns a jax.numpy.ndarray

  3. Is differentiable (via jax)

Parameters:

inverse_link_function (Callable) – The function to be checked.

Raises:

TypeError – If the function is not callable, does not return a jax.numpy.ndarray, or is not differentiable.
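
Example (a hedged sketch; the check is invoked on an instance of the concrete PoissonObservations subclass, and the inputs are illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> # callable, returns a jax.numpy.ndarray, and is differentiable: no error raised
>>> obs.check_inverse_link_function(jax.nn.softplus)
>>> # a non-callable input fails the first check
>>> try:
...     obs.check_inverse_link_function("not a function")
... except TypeError as err:
...     message = str(err)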

abstract deviance(spike_counts, predicted_rate, scale=1.0)[source]#

Compute the residual deviance for the observation model.

Parameters:
  • spike_counts (Array) – The spike counts. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.

  • predicted_rate (Array) – The predicted firing rates. Shape (n_time_bins, ) or (n_time_bins, n_neurons) for population models.

  • scale (Union[float, Array]) – Scale parameter of the model.

Returns:

The residual deviance of the model.
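
Example (a hedged sketch with the Poisson implementation and toy data; the Poisson model has a fixed unit dispersion, so the default scale=1.0 is used):

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> counts = jnp.array([0.0, 1.0, 3.0, 2.0])  # observed spike counts per bin
>>> rate = jnp.array([0.5, 1.2, 2.5, 2.0])    # predicted firing rates
>>> dev = obs.deviance(counts, rate)          # residual deviance under the Poisson model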

abstract estimate_scale(y, predicted_rate, dof_resid)[source]#

Estimate the scale parameter for the model.

This method estimates the scale parameter, often denoted as \(\phi\), which determines the dispersion of an exponential family distribution. The probability density function (pdf) for such a distribution is generally expressed as \(f(y; \theta, \phi) \propto \exp \left(a(\phi)\left( y\theta - \mathcal{k}(\theta) \right)\right)\).

The relationship between variance and the scale parameter is given by:

\[\text{var}(Y) = \frac{V(\mu)}{a(\phi)}.\]

The scale parameter, \(\phi\), is necessary for capturing the variance of the data accurately.
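
For instance, under the common choice \(a(\phi) = 1/\phi\) (a generic illustration, not tied to any particular subclass), the relation above reduces to

\[\text{var}(Y) = \phi\, V(\mu), \qquad \text{Poisson: } V(\mu) = \mu,\ \phi = 1 \Rightarrow \text{var}(Y) = \mu, \qquad \text{Gamma: } V(\mu) = \mu^2 \Rightarrow \text{var}(Y) = \phi\,\mu^2.\]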

Parameters:
  • y (Array) – Observed activity.

  • predicted_rate (Array) – The predicted rate values.

  • dof_resid (Union[float, Array]) – The DOF of the residual.

Return type:

Union[float, Array]
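
Example (a hedged sketch with the Gamma implementation and toy data; the softplus inverse link is an arbitrary positive-valued choice, and the residual degrees of freedom would normally come from the fitted GLM):

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.GammaObservations(inverse_link_function=jax.nn.softplus)
>>> y = jnp.array([0.8, 1.5, 2.2, 3.1])      # observed positive-valued activity
>>> rate = jnp.array([1.0, 1.4, 2.0, 3.0])   # predicted rates
>>> dof = y.shape[0] - 2                     # e.g. n_samples minus n_coefficients
>>> phi = obs.estimate_scale(y, rate, dof_resid=dof)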

get_params(deep=True)#

From scikit-learn, get parameters by inspecting init.

Parameters:

deep (bool, default: True) – If True, return the parameters of this estimator and of any contained sub-objects that are estimators.

Return type:

dict

Returns:

out:

A dictionary containing the parameters. Keys are the parameter names, values are the parameter values.
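
Example (a hedged sketch; it assumes the constructor argument inverse_link_function is exposed as a parameter, as suggested by the init inspection described above):

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> params = obs.get_params()                 # dict of constructor parameters
>>> link = params["inverse_link_function"]    # the callable passed at construction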

property inverse_link_function#

Getter for the inverse link function for the model.

abstract log_likelihood(y, predicted_rate, scale=1.0, aggregate_sample_scores=<function mean>)[source]#

Compute the observation model log-likelihood.

This computes the log-likelihood of the predicted rates for the observed neural activity, including the normalization constant.

Parameters:
  • y (Array) – The target activity to compare against. Shape (n_time_bins, ), or (n_time_bins, n_neurons).

  • predicted_rate (Array) – The predicted rate of the current model. Shape (n_time_bins, ), or (n_time_bins, n_neurons).

  • scale (Union[float, Array]) – The scale parameter of the model.

  • aggregate_sample_scores (Callable) – Function that aggregates the log-likelihood of each sample.

Returns:

The log-likelihood. Shape (1,).
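
Example (a hedged sketch with the Poisson implementation and toy data, showing the default mean aggregation and a sum aggregation):

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> y = jnp.array([0.0, 2.0, 1.0, 4.0])
>>> rate = jnp.array([0.5, 1.8, 1.1, 3.5])
>>> ll_mean = obs.log_likelihood(y, rate)                                  # averaged over time bins (default)
>>> ll_sum = obs.log_likelihood(y, rate, aggregate_sample_scores=jnp.sum)  # summed instead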

pseudo_r2(y, predicted_rate, score_type='pseudo-r2-McFadden', scale=1.0, aggregate_sample_scores=<function mean>)[source]#

Pseudo-\(R^2\) calculation for a GLM.

Compute the pseudo-\(R^2\) metric for the GLM, as defined by McFadden et al. [1] or by Cohen et al. [2].

This metric evaluates the goodness-of-fit of the model relative to a null (baseline) model that assumes a constant mean for the observations. While the pseudo-\(R^2\) is bounded between 0 and 1 for the training set, it can yield negative values on out-of-sample data, indicating potential over-fitting.

Parameters:
  • y (Array) – The neural activity. Expected shape: (n_time_bins, )

  • predicted_rate (Array) – The mean neural activity. Expected shape: (n_time_bins, )

  • score_type (Literal['pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – The pseudo-\(R^2\) type.

  • scale (Union[float, Array, ndarray]) – The scale parameter of the model.

  • aggregate_sample_scores (Callable)

Return type:

Array

Returns:

The pseudo-\(R^2\) of the model. A value closer to 1 indicates a better model fit, whereas a value closer to 0 suggests that the model doesn’t improve much over the null model.

Notes

  • The McFadden pseudo-\(R^2\) is given by:

    \[R^2_{\text{mcf}} = 1 - \frac{\log(L_{M})}{\log(L_0)}.\]

    Equivalent to statsmodels GLMResults.pseudo_rsquared(kind='mcf').

  • The Cohen pseudo-\(R^2\) is given by:

    \[\begin{split}\begin{aligned} R^2_{\text{Cohen}} &= \frac{D_0 - D_M}{D_0} \\ &= 1 - \frac{\log(L_s) - \log(L_M)}{\log(L_s)-\log(L_0)}, \end{aligned}\end{split}\]

    where \(L_M\), \(L_0\) and \(L_s\) are the likelihood of the fitted model, the null model (a model with only the intercept term), and the saturated model (a model with one parameter per sample, i.e. the maximum value that the likelihood could possibly achieve). \(D_M\) and \(D_0\) are the model and the null deviance, \(D_i = -2 \left[ \log(L_s) - \log(L_i) \right]\) for \(i=M,0\).
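
Example (a hedged sketch with the Poisson implementation and toy data, comparing the two score types):

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> y = jnp.array([0.0, 2.0, 1.0, 4.0, 3.0])
>>> rate = jnp.array([0.6, 1.8, 1.1, 3.5, 2.9])
>>> r2_mcf = obs.pseudo_r2(y, rate)                                  # McFadden (default)
>>> r2_cohen = obs.pseudo_r2(y, rate, score_type="pseudo-r2-Cohen")  # Cohen, deviance-based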

References

abstract sample_generator(key, predicted_rate, scale=1.0)[source]#

Sample from the estimated distribution.

This method generates random numbers from the desired distribution based on the given predicted_rate.

Parameters:
  • key (Array) – Random key used for the generation of random numbers in JAX.

  • predicted_rate (Array) – Expected rate of the distribution. Shape (n_time_bins, ), or (n_time_bins, n_neurons).

  • scale (Union[float, Array]) – Scale parameter for the distribution.

Return type:

Array

Returns:

Random numbers generated from the observation model with predicted_rate.
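
Example (a hedged sketch with the Poisson implementation; the key seed and the constant rate are illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> key = jax.random.PRNGKey(123)
>>> rate = jnp.full((100,), 2.0)              # constant expected rate of 2 events per bin
>>> counts = obs.sample_generator(key, rate)  # Poisson draws, one per time bin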

property scale#

Getter for the scale parameter of the model.

set_params(**params)#

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:

**params (dict) – Estimator parameters.

Returns:

self – Estimator instance.

Return type:

estimator instance
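
Example (a hedged sketch; it assumes inverse_link_function is a settable estimator parameter, consistent with get_params() above):

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> obs = nmo.observation_models.PoissonObservations(inverse_link_function=jnp.exp)
>>> # swap the inverse link on the existing model; set_params returns the instance
>>> obs = obs.set_params(inverse_link_function=jax.nn.softplus)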