"""GLM core module."""
# required to get ArrayLike to render correctly
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Tuple, Union
import equinox as eqx
import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike
from sklearn.utils import InputTags, TargetTags
from .. import observation_models as obs
from .. import tree_utils, validation
from .._observation_model_builder import instantiate_observation_model
from ..base_regressor import BaseRegressor, strip_metadata
from ..exceptions import NotFittedError
from ..inverse_link_function_utils import resolve_inverse_link_function
from ..pytrees import FeaturePytree
from ..regularizer import ElasticNet, GroupLasso, Lasso, Regularizer, Ridge
from ..solvers._compute_defaults import glm_compute_optimal_stepsize_configs
from ..type_casting import cast_to_jax, support_pynapple
from ..typing import DESIGN_INPUT_TYPE, SolverState, StepResult
from ..utils import format_repr
from .initialize_parameters import initialize_intercept_matching_mean_rate
from .params import GLMParams, GLMUserParams
from .validation import (
GLMValidator,
PopulationGLMValidator,
)
__all__ = ["GLM", "PopulationGLM"]
# Type alias for the ``observation_model`` argument of the regression GLMs:
# either an instance of one of the observation classes listed below, or the
# corresponding string alias.
REGRESSION_GLM_TYPES = Union[
obs.BernoulliObservations,
obs.GammaObservations,
obs.GaussianObservations,
obs.NegativeBinomialObservations,
obs.PoissonObservations,
Literal[
"Poisson",
"Gamma",
"Bernoulli",
"NegativeBinomial",
"Gaussian",
],
]
[docs]
class GLM(BaseRegressor[GLMUserParams, GLMParams, GLMValidator]):
r"""Generalized Linear Model (GLM) for neural activity data.
This GLM implementation allows users to model neural activity based on a combination of exogenous inputs
(like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where
the relationship between predictors and the response variable might be non-linear, and the residuals
don't follow a normal distribution.
Below is a table of the default inverse link function for the available observation models.
+---------------------+---------------------------------+
| Observation Model | Default Inverse Link Function |
+=====================+=================================+
| Poisson | :math:`e^x` |
+---------------------+---------------------------------+
| Gamma | :math:`1/x` |
+---------------------+---------------------------------+
| Bernoulli | :math:`1 / (1 + e^{-x})` |
+---------------------+---------------------------------+
| NegativeBinomial | :math:`e^x` |
+---------------------+---------------------------------+
| Gaussian | :math:`x` |
+---------------------+---------------------------------+
Below is a table listing the default and available solvers for each regularizer.
+---------------+------------------+-------------------------------------------------------------+
| Regularizer | Default Solver | Available Solvers |
+===============+==================+=============================================================+
| UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| Lasso | ProximalGradient | ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| GroupLasso | ProximalGradient | ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
:class:`nemos.solvers._svrg.SVRG` or its proximal variant
:class:`nemos.solvers._svrg.ProxSVRG` solver,
which take advantage of batched computation. You can change the solver by passing
``"SVRG"`` as ``solver_name`` at model initialization.
The performance of the SVRG solver depends critically on the choice of ``batch_size`` and ``stepsize``
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
To assist with this, for certain GLM configurations, we provide ``batch_size`` and ``stepsize`` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed default hyperparameters:
+---------------------------------------+-----------+-------------+
| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
+=======================================+===========+=============+
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + Ridge | ✅ | ✅ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + Lasso | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
Parameters
----------
observation_model :
Observation model to use. The model describes the distribution of the neural activity.
Default is the Poisson model. Alternatives are "Gamma", "Bernoulli", "NegativeBinomial" and "Gaussian".
inverse_link_function :
A function that maps the linear combination of predictors into a firing rate. The default depends
on the observation model, see the table above.
regularizer :
Regularization to use for model optimization. Defines the regularization scheme
and related parameters.
Default is UnRegularized regression.
regularizer_strength :
Typically a float. Default is None. Sets the regularizer strength.
If a user does not pass a value, and it is needed for regularization,
a warning will be raised and the strength will default to 1.0.
For finer control, the user can pass a pytree that matches the
parameter structure to regularize parameters differentially.
solver_name :
Solver to use for model optimization. Defines the optimization scheme and related parameters.
The solver must be an appropriate match for the chosen regularizer.
Default is ``None``. If no solver specified, one will be chosen based on the regularizer.
Please see table above for regularizer/optimizer pairings.
solver_kwargs :
Optional dictionary for keyword arguments that are passed to the solver when instantiated.
E.g. stepsize, tol, acceleration, etc.
For details on each solver's kwargs, see `get_accepted_arguments` and `get_solver_documentation`.
Attributes
----------
intercept_ :
Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline
firing rate will be ``jnp.exp(model.intercept_)``.
coef_ :
Basis coefficients for the model.
solver_state_ :
State of the solver after fitting. May include details like optimization error.
scale_:
Scale parameter for the model. The scale parameter is the constant :math:`\Phi`, for which
:math:`\text{Var} \left( y \right) = \Phi V(\mu)`. This parameter, together with the estimate
of the mean :math:`\mu` fully specifies the distribution of the activity :math:`y`.
dof_resid_:
Degrees of freedom for the residuals.
Raises
------
TypeError
If provided ``regularizer`` or ``observation_model`` are not valid.
Examples
--------
**Fit a GLM**
Basic model fitting with default Poisson observation model:
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.normal(size=(100, 5))
>>> y = np.random.poisson(size=100)
>>> model = nmo.glm.GLM().fit(X, y)
>>> model.coef_.shape
(5,)
**Customize the Observation Model**
Specify the observation model as a string:
>>> model = nmo.glm.GLM(observation_model="Gamma")
>>> model.observation_model
GammaObservations()
Or pass the observation model object directly:
>>> model = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations())
>>> model.observation_model
GammaObservations()
**Customize the Inverse Link Function**
Use a soft-plus inverse link function instead of the default exponential:
>>> model = nmo.glm.GLM(inverse_link_function=jax.nn.softplus)
>>> model.inverse_link_function.__name__
'softplus'
**Use Regularization**
Fit with Ridge regularization:
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> model.regularizer
Ridge()
Fit with Lasso regularization for sparse coefficients:
>>> model = nmo.glm.GLM(regularizer="Lasso", regularizer_strength=0.01)
>>> model = model.fit(X, y)
>>> model.regularizer
Lasso()
**Select a Solver**
Use the BFGS solver instead of the default LBFGS:
>>> model = nmo.glm.GLM(solver_name="BFGS").fit(X, y)
>>> model.solver_name
'BFGS'
**Use a Pytree of arrays as Input**
Features can be passed as any JAX pytree of 2-D arrays; the fitted
``coef_`` will share the same pytree structure:
>>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
>>> model = nmo.glm.GLM().fit(X_dict, y)
>>> # The coefficient structure will match the input.
>>> type(model.coef_)
<class 'dict'>
"""
_invalid_observation_types = (obs.CategoricalObservations,)
_validator_class = GLMValidator
[docs]
def __init__(
    self,
    observation_model: REGRESSION_GLM_TYPES = "Poisson",
    inverse_link_function: Optional[Callable] = None,
    regularizer: Optional[Union[str, Regularizer]] = None,
    regularizer_strength: Any = None,
    solver_name: Optional[str] = None,
    solver_kwargs: Optional[dict] = None,
):
    """Configure an unfitted model; see the class docstring for the parameters."""
    # Regularizer and solver configuration is handled by the base class.
    super().__init__(
        regularizer=regularizer,
        regularizer_strength=regularizer_strength,
        solver_name=solver_name,
        solver_kwargs=solver_kwargs,
    )

    # Order matters: the ``inverse_link_function`` setter reads
    # ``self._observation_model`` to resolve its default, so the
    # observation model has to be assigned first.
    self.observation_model = observation_model
    self.inverse_link_function = inverse_link_function

    self._validator = self._validator_class(
        extra_params=self._get_validator_extra_params()
    )

    # Fit outputs: ``None`` until ``fit`` populates them.
    self.intercept_ = None
    self.coef_ = None
    self.solver_state_ = None
    self.scale_ = None
    self.dof_resid_ = None
    self.aux_ = None

    # Solver instance, created when the optimizer is initialized.
    self._solver = None
@property
def solver(self):
"""Getter for the solver class."""
return self._solver
@classmethod
def _validate_observation_class(cls, observation: obs.Observations):
    """Reject observation models that this model class does not support.

    The check is an exact class match against
    ``cls._invalid_observation_types``.

    Parameters
    ----------
    observation :
        The observation model instance to check.

    Raises
    ------
    TypeError
        If the class of ``observation`` is listed in
        ``cls._invalid_observation_types``. The message suggests which
        model class to instantiate instead.
    """
    obs_cls = observation.__class__
    if obs_cls not in cls._invalid_observation_types:
        return

    obs_name = obs_cls.__name__
    model_name = cls.__name__
    is_population = issubclass(cls, PopulationGLM)

    if isinstance(observation, obs.CategoricalObservations):
        # categorical data belongs to the classifier variants
        suggested = "ClassifierPopulationGLM" if is_population else "ClassifierGLM"
        hint = (
            f" To use a GLM for classification instantiate a ``{suggested}`` "
            f"object."
        )
    else:
        # any other rejected type belongs to the regression variants
        suggested = "PopulationGLM" if is_population else "GLM"
        hint = (
            f" To use a GLM for regression with ``{obs_name}`` instantiate a ``{suggested}`` "
            f"object."
        )

    raise TypeError(
        f"The ``{obs_name}`` observation type is not supported for ``{model_name}`` models."
        + hint
    )
@property
def inverse_link_function(self):
"""Inverse link function mapping the linear predictor to the response space.
Always a callable. If ``None`` was passed at construction time, this is
resolved to the observation model's default (e.g. ``jnp.exp`` for Poisson,
``1 / x`` for Gamma, ``jax.nn.sigmoid`` for Bernoulli).
"""
return self._inverse_link_function
@inverse_link_function.setter
def inverse_link_function(self, inverse_link_function: Callable):
"""Validate and set the inverse link function.
Parameters
----------
inverse_link_function :
One of:
- ``None`` — use the observation model's default inverse link.
- ``str`` — name of a built-in (e.g. ``"identity"``, ``"log"``,
``"logit"``); resolved by
:func:`nemos.inverse_link_function_utils.resolve_inverse_link_function`.
- ``Callable`` — a custom function. Must be JAX-traceable
(differentiable) and return a ``jax.numpy.ndarray`` or scalar
when called on a JAX array.
Raises
------
TypeError
If the value is neither callable nor a string.
ValueError
If a callable is non-differentiable or returns an unsupported type.
"""
self._inverse_link_function = resolve_inverse_link_function(
inverse_link_function, self._observation_model
)
@property
def observation_model(self) -> Union[None, obs.Observations]:
"""The observation model governing the conditional distribution of ``y``.
Always an instance of an :class:`~nemos.observation_models.Observations`
subclass. If a string alias was passed at construction time it is
resolved to the corresponding instance here.
"""
return self._observation_model
@observation_model.setter
def observation_model(self, observation: obs.Observations):
"""Validate and set the observation model.
Parameters
----------
observation :
Either an :class:`~nemos.observation_models.Observations` instance,
or a string alias from
``{"Poisson", "Gamma", "Gaussian", "Bernoulli", "NegativeBinomial"}``.
String aliases are instantiated via
:func:`nemos.observation_models.instantiate_observation_model`.
Raises
------
AttributeError, TypeError
If the instance does not implement the
:class:`~nemos.observation_models.Observations` interface (checked
via :func:`nemos.observation_models.check_observation_model`).
ValueError
If the resolved observation class is not allowed for this model
(e.g. ``CategoricalObservations`` is rejected by ``GLM``).
"""
if isinstance(observation, str):
self._observation_model = instantiate_observation_model(observation)
self._validate_observation_class(self.observation_model)
return
# check that the model has the required attributes
# and that the attribute can be called
obs.check_observation_model(observation)
self._observation_model = observation
self._validate_observation_class(self.observation_model)
def _check_is_fit(self):
"""Ensure the instance has been fitted."""
if (self.coef_ is None) or (self.intercept_ is None):
raise NotFittedError(
"This GLM instance is not fitted yet. Call 'fit' with appropriate arguments."
)
def _predict(
    self, params: GLMParams, X: Union[dict[str, jnp.ndarray], jnp.ndarray]
) -> jnp.ndarray:
    """
    Compute the predicted rate from parameters and predictors, without validation.

    This is the lightweight counterpart of ``GLM.predict`` used inside the
    optimization routines (it is called by ``_compute_loss``). Inputs are
    assumed to be already validated.

    Parameters
    ----------
    params :
        GLMParams containing the spike basis coefficients and bias terms.
    X :
        Predictors.

    Returns
    -------
    :
        The predicted rates. Shape (n_time_bins, ).
    """
    # Contract the feature axis of every leaf of ``X`` with the matching
    # coefficients, then add up the contributions of all leaves.
    weighted_sum = tree_utils.pytree_map_and_reduce(
        lambda feats, weights: jnp.einsum("tj, j...->t...", feats, weights),
        sum,
        X,
        params.coef,
    )
    linear_predictor = weighted_sum + params.intercept
    # Map the linear predictor to the response space.
    return self._inverse_link_function(linear_predictor)
[docs]
@support_pynapple(conv_type="jax")
def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray:
    """Predict rates based on fit parameters.

    Parameters
    ----------
    X :
        Predictors, array of shape ``(n_time_bins, n_features)`` or a pytree
        of arrays of the same shape.

    Returns
    -------
    :
        The predicted rates with shape ``(n_time_bins, )``.

    Raises
    ------
    NotFittedError
        If ``fit`` has not been called first with this instance.
    ValueError
        If ``X`` fails the validator's input checks (e.g. unexpected
        dimensionality).
    ValueError
        If the fitted parameters and ``X`` are inconsistent, e.g. a
        different number of features between coefficients and ``X``.

    Examples
    --------
    >>> # example input
    >>> import numpy as np
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> # define and fit a GLM
    >>> import nemos as nmo
    >>> model = nmo.glm.GLM()
    >>> model = model.fit(X, y)
    >>> # predict new spike data
    >>> Xnew = np.random.normal(size=(20, X.shape[1]))
    >>> predicted_spikes = model.predict(Xnew)

    See Also
    --------
    :meth:`nemos.glm.GLM.score`
        Score predicted rates against target spike counts.
    :meth:`nemos.glm.GLM.simulate`
        Simulate neural activity in response to a feed-forward input (feed-forward only).
    :func:`nemos.simulation.simulate_recurrent`
        Simulate neural activity in response to a feed-forward input
        using the GLM as a recurrent network (feed-forward + coupling).
    """
    # a fitted model is required
    self._check_is_fit()
    fitted_params = self._get_model_params()

    # pre-process the predictors; NaN samples are kept (``drop_nans=False``)
    design, _ = self._preprocess_inputs(X, drop_nans=False)

    # input checks, then parameter/input consistency checks
    validator = self._validator
    validator.validate_inputs(design)
    validator.validate_consistency(fitted_params, X=design)
    feature_mask = getattr(self, "_feature_mask", None)
    validator.feature_mask_consistency(feature_mask, fitted_params)

    predicted_rate = self._predict(fitted_params, design)
    return predicted_rate
def _compute_loss(
    self,
    params: GLMParams,
    X: DESIGN_INPUT_TYPE,
    y: jnp.ndarray,
    *args,
    **kwargs,
) -> jnp.ndarray:
    r"""Negative log-likelihood of the activity given the predicted rate.

    The value is computed up to a constant term. No parameter or input
    checks are performed (unlike ``score``): this function is handed to
    the solver as the optimization objective.

    Parameters
    ----------
    params :
        GLMParams holding the coefficients and the intercept.
    X :
        Predictors.
    y :
        Target neural activity.

    Returns
    -------
    :
        The model negative log-likelihood, as returned by the observation
        model.
    """
    # Extra positional/keyword arguments are accepted but not used here.
    rate = self._predict(params, X)
    neg_log_lik = self._observation_model._negative_log_likelihood(y, rate)
    return neg_log_lik
[docs]
@cast_to_jax
def score(
    self,
    X: Union[DESIGN_INPUT_TYPE, ArrayLike],
    y: ArrayLike,
    score_type: Literal[
        "log-likelihood", "pseudo-r2-McFadden", "pseudo-r2-Cohen"
    ] = "log-likelihood",
    aggregate_sample_scores: Callable = jnp.mean,
) -> jnp.ndarray:
    r"""Evaluate the goodness-of-fit of the model to the observed neural data.

    The score is either the log-likelihood (aggregated over samples, by
    default with the mean) or one of two versions of the
    pseudo-:math:`R^2`. Before scoring, the method checks that the model is
    fitted and that the inputs are compatible with the fitted parameters.
    A higher score indicates a better fit of the model to the observed data.

    Parameters
    ----------
    X :
        Predictors, array of shape ``(n_time_bins, n_features)`` or a pytree
        of arrays of the same shape.
    y :
        Neural activity. Shape ``(n_time_bins, )``.
    score_type :
        Type of scoring: either log-likelihood or pseudo-:math:`R^2`.
    aggregate_sample_scores :
        Function that aggregates the score of all samples.

    Returns
    -------
    score :
        The log-likelihood or the pseudo-:math:`R^2` of the current model.

    Raises
    ------
    NotFittedError
        If ``fit`` has not been called first with this instance.
    ValueError
        If X structure doesn't match the params, or if X and y have a different
        number of samples.
    NotImplementedError
        If ``score_type`` is neither ``"log-likelihood"`` nor starts with
        ``"pseudo-r2"``.

    Examples
    --------
    >>> # example input
    >>> import numpy as np
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> import nemos as nmo
    >>> model = nmo.glm.GLM()
    >>> model = model.fit(X, y)
    >>> # get model score
    >>> log_likelihood_score = model.score(X, y)
    >>> # get a pseudo-R2 score
    >>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')

    Notes
    -----
    The log-likelihood is not on a standard scale, its value is influenced by many factors,
    among which the number of model parameters. The log-likelihood can assume both positive
    and negative values.

    The Pseudo-:math:`R^2` is not equivalent to the :math:`R^2` value in linear regression. While both
    provide a measure of model fit, and assume values in the [0,1] range, the methods and
    interpretations can differ. The Pseudo-:math:`R^2` is particularly useful for generalized linear
    models when the interpretation of the :math:`R^2` as explained variance does not apply
    (i.e., when the observations are not Gaussian distributed).

    Why is the traditional :math:`R^2` usually a poor measure of performance in GLMs?

    1. In the context of GLMs the variance and the mean of the observations are related.
       Ignoring the relation between them can result in underestimating the model
       performance; for instance, when we model a Poisson variable with large mean we expect an
       equally large variance. In this scenario, even if our model perfectly captures the mean,
       the high variance will result in large residuals and low :math:`R^2`.
       Additionally, when the mean of the observations varies, the variance will vary too. This
       violates the "homoscedasticity" assumption, necessary for interpreting the :math:`R^2` as
       variance explained.
    2. The :math:`R^2` captures the variance explained when the relationship between the observations and
       the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors
       and the mean of the observations, compromising the interpretation of the :math:`R^2`.

    Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional
    to the model standard deviation (i.e. Pearson residuals). This "rescaled" residual distribution however
    deviates substantially from normality for counting data with low mean (common for spike counts).
    Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially
    for GLMs modeling counting data.

    Refer to the ``nmo.observation_models.Observations`` concrete subclasses for the likelihood and
    pseudo-:math:`R^2` equations.
    """
    self._check_is_fit()
    fitted_params = self._get_model_params()

    # validate the raw inputs, then drop the samples containing NaNs
    validator = self._validator
    validator.validate_inputs(X, y)
    X, y = self._preprocess_inputs(X, y, drop_nans=True)
    validator.validate_consistency(fitted_params, X, y)
    validator.feature_mask_consistency(
        getattr(self, "_feature_mask", None), fitted_params
    )

    # reject unsupported score types before evaluating the model
    use_log_likelihood = score_type == "log-likelihood"
    if not use_log_likelihood and not score_type.startswith("pseudo-r2"):
        raise NotImplementedError(
            f"Scoring method {score_type} not implemented! "
            "`score_type` must be either 'log-likelihood', 'pseudo-r2-McFadden', "
            "or 'pseudo-r2-Cohen'."
        )

    predicted_rate = self._predict(fitted_params, X)
    observation_model = self._observation_model

    if use_log_likelihood:
        return observation_model.log_likelihood(
            y,
            predicted_rate,
            self.scale_,
            aggregate_sample_scores=aggregate_sample_scores,
        )

    return observation_model.pseudo_r2(
        y,
        predicted_rate,
        score_type=score_type,
        scale=self.scale_,
        aggregate_sample_scores=aggregate_sample_scores,
    )
def _model_specific_initialization(
    self,
    X: DESIGN_INPUT_TYPE,
    y: jnp.ndarray,
    **kwargs,
) -> GLMParams:
    """Build the initial parameters from the structure and dimensions of X and y.

    An empty parameter container is requested from the validator; its
    coefficients are replaced by zeros of matching shape (so a pytree ``X``
    yields a pytree of zero arrays with the same structure), and its
    intercept by the value returned by
    ``initialize_intercept_matching_mean_rate`` for the current inverse
    link function and ``y``.

    Parameters
    ----------
    X :
        The input data, either a pytree of arrays with leaves of shape
        ``(n_timebins, n_features)``, or a simple ndarray of shape ``(n_timebins, n_features)``.
    y :
        The target data array of shape ``(n_timebins, )``, representing
        the neuron firing rates or similar metrics.

    Returns
    -------
    GLMParams
        The initialized parameters:

        - ``coef``: zeros, either a pytree of arrays or an ndarray matching
          the structure of X.
        - ``intercept``: the initial intercept (bias term).
    """
    # unwrap the feature container, if any
    features = X.data if isinstance(X, FeaturePytree) else X

    template = self._validator.get_empty_params(features, y)

    intercept0 = initialize_intercept_matching_mean_rate(
        self._inverse_link_function, y
    )
    coef0 = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros(leaf.shape), template.coef
    )

    # swap the placeholders for the initial values
    initialized = eqx.tree_at(
        lambda p: (p.coef, p.intercept),
        template,
        (coef0, intercept0),
    )

    feature_mask = getattr(self, "_feature_mask", None)
    self._validator.feature_mask_consistency(feature_mask, initialized)
    return initialized
[docs]
@cast_to_jax
def fit(
    self,
    X: Union[DESIGN_INPUT_TYPE, ArrayLike],
    y: ArrayLike,
    init_params: Optional[GLMUserParams] = None,
):
    """Fit GLM to neural activity.

    Fit and store the model parameters as attributes
    ``coef_`` and ``intercept_``. The residual degrees of freedom
    (``dof_resid_``), the scale (``scale_``), the final solver state
    (``solver_state_``) and the solver auxiliary output (``aux_``) are
    stored as well.

    Parameters
    ----------
    X :
        Predictors, array of shape (n_time_bins, n_features) or pytree of the same
        shape.
    y :
        Target neural activity arranged in a matrix, shape (n_time_bins, ).
    init_params :
        2-tuple of initial parameter values: (coefficients, intercepts). If
        None, the parameters are initialized by
        ``_model_specific_initialization`` (zero coefficients). coefficients
        is an array of shape (n_features,) or pytree of same, intercepts is
        an array of shape (1, )

    Returns
    -------
    :
        The fitted model (``self``).

    Raises
    ------
    ValueError
        If ``init_params`` is not of length two.
    ValueError
        If dimensionality of ``init_params`` are not correct.
    ValueError
        If ``X`` is not two-dimensional.
    ValueError
        If ``y`` is not one-dimensional.
    ValueError
        If solver returns at least one NaN parameter, which means it found
        an invalid solution. Try tuning optimization hyperparameters.
    TypeError
        If ``init_params`` are not array-like
    TypeError
        If ``init_params[i]`` cannot be converted to ``jnp.ndarray`` for all ``i``

    Examples
    --------
    >>> # example input
    >>> import numpy as np
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> # fit a ridge regression Poisson GLM
    >>> import nemos as nmo
    >>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
    >>> model = model.fit(X, y)
    >>> # get model weights and intercept
    >>> model_weights = model.coef_
    >>> model_intercept = model.intercept_
    """
    self._validator.validate_inputs(X, y)

    # filter for non-nans, grab data if needed
    data, y = self._preprocess_inputs(X, y)

    # initialize params if no params are provided
    if init_params is None:
        init_params = self._model_specific_initialization(X, y)
    else:
        init_params = self._validator.validate_and_cast_params(init_params)
        self._validator.validate_consistency(init_params, X=X, y=y)

    self._validator.feature_mask_consistency(
        getattr(self, "_feature_mask", None), init_params
    )

    self._initialize_optimizer_and_state(init_params, data, y)

    params, state, aux = self._optimizer_run(init_params, data, y)

    if tree_utils.pytree_map_and_reduce(
        lambda x: jnp.any(jnp.isnan(x)), any, params
    ):
        raise ValueError(
            "Solver returned at least one NaN parameter, so solution is invalid!"
            " Try tuning optimization hyperparameters, specifically try decreasing the `stepsize` "
            "and/or setting `acceleration=False`."
        )

    if hasattr(state, "stats") and hasattr(state.stats, "converged"):
        converged = state.stats.converged
    elif hasattr(state, "converged"):
        # try if the custom defined solver has a convergence flag directly
        converged = state.converged
    else:
        # custom solver with potentially undefined convergence state
        converged = True
        warnings.warn(
            f"Solver state {state} does not have a ``.converged`` nor a ``.stats.converged`` "
            f"attribute. Convergence state is unknown; assuming converged. "
            f"To assess the optimization manually, "
            f"inspect the ``solver_state_`` attribute of the model.",
            UserWarning,
        )

    if not converged:
        warnings.warn(
            "The fit did not converge. "
            "Consider the following:"
            "\n1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` "
            "\n2) Increase the max number of iterations or increase tolerance (if reasonable). "
            "These parameters can be specified by providing a ``solver_kwargs`` dictionary. "
            "For the available options see the ``self.solver.__init__`` docstrings.",
            RuntimeWarning,
        )

    self._set_model_params(params)

    # Use the pre-processed (NaN-filtered) predictors, i.e. the samples the
    # solver actually saw: the raw ``X`` would count dropped samples and
    # feed NaNs to the rank computation.
    self.dof_resid_ = self._estimate_resid_degrees_of_freedom(data)
    self.scale_ = self.observation_model.estimate_scale(
        y, self._predict(params, data), dof_resid=self.dof_resid_
    )

    # note that this will include an error value, which is not the same as
    # the output of loss. I believe it's the output of
    # solver.l2_optimality_error
    self.solver_state_ = state
    self.aux_ = aux
    return self
def _get_model_params(self):
    """Bundle ``coef_`` and ``intercept_`` into a ``GLMParams`` pytree.

    Override this method if the parameter structure changes, or when a new
    regression model uses a different parameter structure.
    """
    coef, intercept = self.coef_, self.intercept_
    return GLMParams(coef, intercept)
def _set_model_params(self, params: GLMParams):
"""Unpack and store params pytree to coef_ and intercept_.
This method should be overwritten in case the parameter structure changes,
or if new regression models will have a different parameter structure.
"""
# Store parameters
self.coef_: DESIGN_INPUT_TYPE = params.coef
self.intercept_: jnp.ndarray = params.intercept
[docs]
@support_pynapple(conv_type="jax")
def simulate(
    self,
    random_key: jax.Array,
    feedforward_input: DESIGN_INPUT_TYPE,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate neural activity in response to a feed-forward input.

    Parameters
    ----------
    random_key :
        jax.random.key for seeding the simulation.
    feedforward_input :
        External input predictors to the model, representing factors like convolved currents,
        light intensities, etc.
        Array of shape (n_time_bins, n_basis_input) or pytree with leaves of the same shape.

    Returns
    -------
    simulated_activity :
        Simulated activity (spike counts for Poisson GLMs) for the neuron over time.
        Shape: ``(n_time_bins, )``.
    firing_rates :
        Simulated rates for the neuron over time. Shape, ``(n_time_bins, )``.

    Raises
    ------
    NotFittedError
        If the model hasn't been fitted prior to calling this method.
    ValueError
        If the input fails the validation checks, e.g. it is inconsistent
        with the fitted parameters.

    Examples
    --------
    >>> # example input
    >>> import jax
    >>> import numpy as np
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> # define and fit model
    >>> import nemos as nmo
    >>> model = nmo.glm.GLM()
    >>> model = model.fit(X, y)
    >>> # generate spikes and rates
    >>> random_key = jax.random.key(123)
    >>> Xnew = np.random.normal(size=(20, X.shape[1]))
    >>> spikes, rates = model.simulate(random_key, Xnew)

    See Also
    --------
    :meth:`nemos.glm.GLM.predict`
        Method to predict rates based on the model's parameters.
    """
    # a fitted model is required
    self._check_is_fit()
    fitted_params = self._get_model_params()

    # fail early when every sample of the input is invalid
    validation.error_all_invalid(feedforward_input)

    # dimensionality and parameter/input consistency checks
    validator = self._validator
    validator.validate_inputs(X=feedforward_input)
    validator.validate_consistency(fitted_params, X=feedforward_input)
    validator.feature_mask_consistency(
        getattr(self, "_feature_mask", None), fitted_params
    )

    # pre-process, keeping NaN samples (``drop_nans=False``)
    feedforward_input, _ = self._preprocess_inputs(
        X=feedforward_input, drop_nans=False
    )

    rate = self._predict(fitted_params, feedforward_input)
    # draw the activity from the observation model at the predicted rate
    activity = self._observation_model.sample_generator(
        key=random_key, predicted_rate=rate, scale=self.scale_
    )
    return activity, rate
def _estimate_resid_degrees_of_freedom(
self, X: DESIGN_INPUT_TYPE, n_samples: Optional[int] = None
) -> jnp.ndarray:
"""
Estimate the degrees of freedom of the residuals.
Parameters
----------
self :
A fitted GLM model.
X :
The design matrix.
n_samples :
The number of samples observed. If not provided, n_samples is set to ``X.shape[0]``. If the fit is
batched, the n_samples could be larger than ``X.shape[0]``.
Returns
-------
:
An estimate of the degrees of freedom of the residuals.
"""
# Convert a pytree to a design-matrix with pytrees
X = jnp.hstack(jax.tree_util.tree_leaves(X))
if n_samples is None:
n_samples = X.shape[0]
else:
if not isinstance(n_samples, int):
raise TypeError(
"`n_samples` must be `None` or of type `int`. Type {type(n_sample)} provided "
"instead!"
)
params = self._get_model_params()
# if the regularizer is lasso use the non-zero
# coeff as an estimate of the dof
# see https://arxiv.org/abs/0712.0881
if isinstance(self.regularizer, (GroupLasso, Lasso, ElasticNet)):
resid_dof = tree_utils.pytree_map_and_reduce(
lambda x: ~jnp.isclose(x, jnp.zeros_like(x)),
lambda x: sum([jnp.sum(i, axis=0) for i in x]),
params.coef,
)
return n_samples - resid_dof - 1
elif isinstance(self.regularizer, Ridge):
# for Ridge, use the tot parameters (X.shape[1] + intercept)
return (n_samples - X.shape[1] - 1) * jnp.ones_like(params.intercept)
else:
# for UnRegularized, use the rank
rank = jnp.linalg.matrix_rank(X)
return (n_samples - rank - 1) * jnp.ones_like(params.intercept)
def _initialize_optimizer_and_state(
    self,
    init_params: GLMParams,
    X: dict[str, jnp.ndarray] | jnp.ndarray,
    y: jnp.ndarray,
) -> SolverState:
    """Instantiate the solver and build its initial state.

    The solver is created from the model loss and the (possibly optimized) solver
    keyword arguments; its ``init_state``, ``update`` and ``run`` methods are stored on
    the model. The initial solver state is then computed from the initial parameters
    and the data, so that the solver is ready for optimization or single update steps.

    Parameters
    ----------
    init_params :
        Initial parameters for the model.
    X :
        The predictors used in the model fitting process. This can include feature matrices or other structures
        compatible with the model's design.
    y :
        The response variables or outputs corresponding to the predictors. Passed, together with ``X``,
        to the solver-parameter optimization and to the solver ``init_state``.

    Returns
    -------
    SolverState
        The initialized solver state

    Examples
    --------
    >>> import numpy as np
    >>> import nemos as nmo
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> model = nmo.glm.GLM()
    >>> params = model.initialize_params(X, y)
    >>> opt_state = model.initialize_optimizer_and_state(params, X, y)
    >>> # Now ready to run optimization or update steps
    """
    # data-dependent solver settings are computed before the solver is built
    tuned_solver_kwargs = self._optimize_solver_params(X, y)

    solver = self._instantiate_solver(
        self._compute_loss,
        init_params=init_params,
        solver_kwargs=tuned_solver_kwargs,
    )

    # expose the solver and its init/update/run entry points on the model
    self._solver = solver
    self._optimizer_init_state = solver.init_state
    self._optimizer_update = solver.update
    self._optimizer_run = solver.run

    return self._optimizer_init_state(init_params, X, y)
[docs]
@cast_to_jax
def update(
    self,
    params: GLMUserParams,
    opt_state: SolverState,
    X: DESIGN_INPUT_TYPE,
    y: jnp.ndarray,
    *args,
    n_samples: Optional[int] = None,
    **kwargs,
) -> StepResult:
    """Run a single solver step and refresh the model state.

    One optimization step is performed with the model's current solver, starting from
    the provided parameters and optimizer state. The updated coefficients and intercept,
    the solver state, the residual degrees of freedom and the scale are stored on the
    model. Stepping manually is useful for iterative fitting, e.g. online learning or
    datasets that are too large to fit into memory at once.

    Parameters
    ----------
    params
        The current model parameters, typically a tuple of coefficients and intercepts.
    opt_state
        The current state of the optimizer, encapsulating information necessary for the
        optimization algorithm to continue from the current state. This includes gradients,
        step sizes, and other optimizer-specific metrics.
    X
        The predictors used in the model fitting process, which may include feature matrices
        or a pytree of arrays. Shape ``(n_time_bins, n_features)``.
    y
        The response variable or output data corresponding to the predictors. Shape ``(n_time_bins,)``.
    *args
        Additional positional arguments to be passed to the solver's update method.
    n_samples
        The total number of samples. Usually larger than the samples of an individual batch,
        the ``n_samples`` are used to estimate the scale parameter of the GLM.
    **kwargs
        Additional keyword arguments to be passed to the solver's update method.

    Returns
    -------
    params
        Updated model parameters (coefficients, intercepts).
    state
        Updated optimizer state.

    Raises
    ------
    ValueError
        If the solver has not been instantiated or if the solver returns NaN values
        indicating an invalid update step, typically due to numerical instabilities
        or inappropriate solver configurations.

    Examples
    --------
    >>> import nemos as nmo
    >>> import numpy as np
    >>> import jax
    >>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
    >>> glm_instance = nmo.glm.GLM()
    >>> params = glm_instance.initialize_params(X, y)
    >>> opt_state = glm_instance.initialize_optimizer_and_state(params, X, y)
    >>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
    """
    # remove the NaNs from predictors and responses
    X, y = tree_utils.drop_nans(X, y)

    # the solver works on the raw data, not on the FeaturePytree wrapper
    if isinstance(X, FeaturePytree):
        data = X.data
    else:
        data = X

    # wrap into GLM params; the structure of ``params`` is assumed to be valid,
    # which holds when the initialization went through
    # `initialize_optimizer_and_state`
    model_params = self._validator.to_model_params(params)

    # single optimization step
    new_params, new_state, aux = self._optimizer_update(
        model_params, opt_state, data, y, *args, **kwargs
    )

    # keep the outcome of the step on the model
    self._set_model_params(new_params)
    self.solver_state_ = new_state
    self.aux_ = aux

    # degrees of freedom first, they are needed by the scale estimate
    self.dof_resid_ = self._estimate_resid_degrees_of_freedom(
        X, n_samples=n_samples
    )
    predicted_rate = self._predict(new_params, data)
    self.scale_ = self.observation_model.estimate_scale(
        y, predicted_rate, dof_resid=self.dof_resid_
    )

    user_params = self._validator.from_model_params(new_params)
    return user_params, new_state
def _get_optimal_solver_params_config(self):
    """Return the functions used to compute the solver's default step and batch size."""
    stepsize_configs = glm_compute_optimal_stepsize_configs(self)
    return stepsize_configs
[docs]
def __repr__(self):
    """Return the multi-line string representation of the GLM class."""
    repr_options = {
        "multiline": True,
        # the inverse link function is displayed by name
        "use_name_keys": ["inverse_link_function"],
    }
    return format_repr(self, **repr_options)
[docs]
def __sklearn_clone__(self) -> GLM:
    """Clone the GLM by re-instantiating its class with the shallow parameters."""
    init_kwargs = self.get_params(deep=False)
    return self.__class__(**init_kwargs)
[docs]
def save_params(self, filename: Union[str, Path]):
    """
    Save GLM model parameters to a .npz file.

    Saving makes the model parameters reusable: the stored file can be loaded back
    into a GLM instance with ``nemos.load_model`` (see the examples below). The solver
    state is not part of the saved content.

    Parameters
    ----------
    filename :
        The name of the file where the model parameters will be saved. The file will be saved in `.npz` format.

    Examples
    --------
    >>> import nemos as nmo
    >>> # Create a GLM model with specified parameters
    >>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
    >>> model = nmo.glm.GLM(
    ...     regularizer="Ridge",
    ...     regularizer_strength=0.1,
    ...     observation_model="Gamma",
    ...     solver_name="BFGS",
    ...     solver_kwargs=solver_args,
    ... )
    >>> for key, value in model.get_params().items():
    ...     print(f"{key}: {value}")
    inverse_link_function: <function one_over_x at ...>
    observation_model: GammaObservations()
    regularizer: Ridge()
    regularizer_strength: 0.1...
    solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
    solver_name: BFGS
    >>> # Save the model parameters to a file
    >>> model.save_params("model_params.npz")
    >>> # Load the model from the saved file
    >>> model = nmo.load_model("model_params.npz")
    >>> # Model has the same parameters before and after load
    >>> for key, value in model.get_params().items(): # doctest: +ELLIPSIS
    ...     print(f"{key}: {value}")
    inverse_link_function: <function one_over_x at ...>
    observation_model: GammaObservations()
    regularizer: Ridge()
    regularizer_strength: 0.1
    solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
    solver_name: BFGS
    >>> # Saving and loading a custom inverse link function
    >>> model = nmo.glm.GLM(
    ...     observation_model="Poisson",
    ...     inverse_link_function=lambda x: x**2
    ... )
    >>> model.save_params("model_params.npz")
    >>> # Provide a mapping for the custom link function when loading.
    >>> mapping_dict = {
    ...     "inverse_link_function": lambda x: x**2,
    ... }
    >>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
    >>> # The loaded model uses the inverse link function provided in the mapping
    >>> for key, value in loaded_model.get_params().items():
    ...     print(f"{key}: {value}")
    inverse_link_function: <function <lambda> at ...>
    observation_model: PoissonObservations()
    regularizer: UnRegularized()
    regularizer_strength: None
    solver_kwargs: {}
    solver_name: LBFGS
    """
    # collect the fit state and leave the solver state out of the saved content
    fit_state = self._get_fit_state()
    fit_state.pop("solver_state_")

    # the inverse link function is listed among the string attributes
    self._save_params(filename, fit_state, ["inverse_link_function"])
[docs]
class PopulationGLM(GLM):
"""
Population Generalized Linear Model.
This class implements a Generalized Linear Model for a neural population.
This GLM implementation allows users to model the activity of a population of neurons based on a
combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model.
It is suitable for scenarios where the relationship between predictors and the response
variable might be non-linear, and the residuals don't follow a normal distribution. The predictors must be
stored in tabular format, shape (n_timebins, num_features) or as a pytree of arrays of the same shape.
Below is a table listing the default and available solvers for each regularizer.
+---------------+------------------+-------------------------------------------------------------+
| Regularizer | Default Solver | Available Solvers |
+===============+==================+=============================================================+
| UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| Lasso | ProximalGradient | ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
| GroupLasso | ProximalGradient | ProximalGradient |
+---------------+------------------+-------------------------------------------------------------+
**Fitting Large Models**
For very large models, you may consider using the Stochastic Variance Reduced Gradient
:class:`nemos.solvers._svrg.SVRG` or its proximal variant
(:class:`nemos.solvers._svrg.ProxSVRG`) solver,
which take advantage of batched computation. You can change the solver by passing
``"SVRG"`` or ``"ProxSVRG"`` as ``solver_name`` at model initialization.
The performance of the SVRG solver depends critically on the choice of ``batch_size`` and ``stepsize``
hyperparameters. These parameters control the size of the mini-batches used for gradient computations
and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow
convergence or even divergence of the optimization process.
To assist with this, for certain GLM configurations, we provide ``batch_size`` and ``stepsize`` default
values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed hyperparameters:
+---------------------------------------+-----------+-------------+
| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
+=======================================+===========+=============+
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + Ridge | ✅ | ✅ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + Lasso | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |
+---------------------------------------+-----------+-------------+
Parameters
----------
observation_model :
Observation model to use. The model describes the distribution of the neural activity.
Default is the Poisson model.
inverse_link_function :
A function that maps the linear combination of predictors into a firing rate. The default depends
on the observation model; see the table of default inverse link functions in the :class:`GLM` docstring.
regularizer :
Regularization to use for model optimization. Defines the regularization scheme
and related parameters.
Default is UnRegularized regression.
regularizer_strength :
Typically a float. Default is None. Sets the regularizer strength.
If a user does not pass a value, and it is needed for regularization,
a warning will be raised and the strength will default to 1.0.
For finer control, the user can pass a pytree that matches the
parameter structure to regularize parameters differentially.
solver_name :
Solver to use for model optimization. Defines the optimization scheme and related parameters.
The solver must be an appropriate match for the chosen regularizer.
Default is ``None``. If no solver specified, one will be chosen based on the regularizer.
Please see table above for regularizer/optimizer pairings.
solver_kwargs :
Optional dictionary for keyword arguments that are passed to the solver when instantiated.
E.g. stepsize, tol, acceleration, etc.
For details on each solver's kwargs, see `get_accepted_arguments` and `get_solver_documentation`.
feature_mask :
Either a matrix of shape (num_features, num_neurons) or a PyTree of 0s and 1s, with
leaves of shape (num_neurons, ).
The mask will be used to select which features are used as predictors for which neuron.
Attributes
----------
intercept_ :
Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline
firing rate will be ``jnp.exp(model.intercept_)``.
coef_ :
Basis coefficients for the model.
solver_state_ :
State of the solver after fitting. May include details like optimization error.
Raises
------
TypeError
If provided ``regularizer`` or ``observation_model`` are not valid.
TypeError
If provided ``feature_mask`` is not an array-like of dimension two.
Examples
--------
**Fit a PopulationGLM**
Basic model fitting for a population of neurons:
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[0.5, 0.0], [-0.5, -0.5], [0.0, 1.0]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> model = nmo.glm.PopulationGLM().fit(X, y)
>>> model.coef_.shape
(3, 2)
**Mask Coefficients with an Array**
Use a feature mask to specify which features predict each neuron.
The mask has shape ``(num_features, num_neurons)``:
>>> feature_mask = np.array([[1, 0], [1, 1], [0, 1]])
>>> model = nmo.glm.PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> model.coef_
Array(...)
**Use a Dict of Arrays as Input**
Features can be passed as a dict (or any JAX pytree). The feature mask
should mirror the same structure, with one 1-D entry per leaf:
>>> feature_1 = np.random.normal(size=(num_samples, 2))
>>> feature_2 = np.random.normal(size=(num_samples, 1))
>>> X_dict = {"feature_1": feature_1, "feature_2": feature_2}
>>> weights = dict(
... feature_1=jnp.array([[0.0, 0.5], [0.0, -0.5]]),
... feature_2=jnp.array([[1.0, 0.0]])
... )
>>> rate = np.exp(
... X_dict["feature_1"].dot(weights["feature_1"]) +
... X_dict["feature_2"].dot(weights["feature_2"])
... )
>>> y = np.random.poisson(rate)
>>> feature_mask = {
... "feature_1": jnp.array([0, 1], dtype=jnp.int32),
... "feature_2": jnp.array([1, 0], dtype=jnp.int32)
... }
>>> model = nmo.glm.PopulationGLM(feature_mask=feature_mask).fit(X_dict, y)
>>> model.coef_
{...}
**Customize the Observation Model**
Use a Gamma observation model for continuous positive data:
>>> model = nmo.glm.PopulationGLM(observation_model="Gamma")
>>> model.observation_model
GammaObservations()
**Use Regularization**
Fit with Ridge regularization:
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[0.5, 0.0], [-0.5, -0.5], [0.0, 1.0]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> model = nmo.glm.PopulationGLM(
... regularizer="Ridge",
... regularizer_strength=0.1
... ).fit(X, y)
>>> model.regularizer
Ridge()
"""
_validator_class = PopulationGLMValidator
[docs]
def __init__(
    self,
    observation_model: (
        REGRESSION_GLM_TYPES
        | Literal["Poisson", "Gamma", "Gaussian", "Bernoulli", "NegativeBinomial"]
    ) = "Poisson",
    inverse_link_function: Optional[Callable] = None,
    regularizer: Union[str, Regularizer] = "UnRegularized",
    regularizer_strength: Any = None,
    solver_name: str = None,
    solver_kwargs: dict = None,
    feature_mask: Optional[jnp.ndarray] = None,
    **kwargs,
):
    """Configure the population GLM; the parameters are described in the class docstring."""
    # everything but the feature mask is forwarded to the parent initializer
    glm_settings = dict(
        observation_model=observation_model,
        inverse_link_function=inverse_link_function,
        regularizer=regularizer,
        regularizer_strength=regularizer_strength,
        solver_name=solver_name,
        solver_kwargs=solver_kwargs,
    )
    super().__init__(**glm_settings, **kwargs)

    self._metadata = None
    # assigned through the property setter, which validates and casts the mask
    self.feature_mask = feature_mask
@property
def feature_mask(self) -> Union[jnp.ndarray, dict[str, jnp.ndarray]]:
    """
    Mask selecting the features used as predictors for each neuron.

    The tree structure of the feature mask matches the one of the coefficients (``coef_``):

    - **Array input**: Shape ``(n_features, n_neurons)``. Each entry ``[i, j]``
      indicates whether feature ``i`` is used for neuron ``j`` (1 = used, 0 = masked).
    - **Pytree**: A pytree with structure matching that of ``coef_``.
      Each leaf array has shape ``(n_neurons,)``, indicating whether that feature
      group is used for each neuron.

    Returns
    -------
    jnp.ndarray or dict[str, jnp.ndarray]
        The feature mask, or None if not set.
    """
    stored_mask = self._feature_mask
    return stored_mask
@feature_mask.setter
def feature_mask(self, feature_mask: Union[DESIGN_INPUT_TYPE, dict]):
    """Validate, cast and store the feature mask; re-assignment after fitting is refused."""
    # the model counts as fitted once both coefficients and intercept are available
    already_fit = not (self.coef_ is None or self.intercept_ is None)
    if already_fit:
        raise AttributeError(
            "property 'feature_mask' of 'populationGLM' cannot be set after fitting."
        )

    validated_mask = self._validator.validate_and_cast_feature_mask(feature_mask)
    self._feature_mask = validated_mask
[docs]
@strip_metadata(arg_num=1, arg_name="y")
def fit(
    self,
    X: Union[DESIGN_INPUT_TYPE, ArrayLike],
    y: ArrayLike,
    init_params: Optional[GLMUserParams] = None,
):
    """Fit GLM to the activity of a population of neurons.

    The fitted parameters are stored as the attributes ``coef_`` and ``intercept_``.
    Predictors can differ across neurons: the ``feature_mask`` determines which
    feature is used for which neuron. See the notes below for more information on
    the ``feature_mask``.

    Parameters
    ----------
    X :
        Predictors, array of shape (n_timebins, n_features) or pytree of the same
        shape.
    y :
        Target neural activity arranged in a matrix, shape (n_timebins, n_neurons).
    init_params :
        2-tuple of initial parameter values: (coefficients, intercepts). If
        None, we initialize coefficients with zeros, intercepts with the
        log of the mean neural activity. coefficients is an array of shape
        (n_features, n_neurons) or pytree of the same shape, intercepts is an array
        of shape (n_neurons, )

    Raises
    ------
    ValueError
        If ``init_params`` is not of length two.
    ValueError
        If dimensionality of ``init_params`` are not correct.
    ValueError
        If ``X`` is not two-dimensional.
    ValueError
        If ``y`` is not two-dimensional.
    ValueError
        If the ``feature_mask`` is not of the right shape.
    ValueError
        If solver returns at least one NaN parameter, which means it found
        an invalid solution. Try tuning optimization hyperparameters.
    TypeError
        If ``init_params`` are not array-like
    TypeError
        If ``init_params[i]`` cannot be converted to jnp.ndarray for all i

    Notes
    -----
    The ``feature_mask`` is used to select features for each neuron, and it is
    an NDArray or a PyTree of 0s and 1s. In particular,

    - If the mask is in array format, feature ``i`` is a predictor for neuron ``j`` if
      ``feature_mask[i, j] == 1``.
    - If the mask is a PyTree, then
      a leaf is a predictor of neuron ``j`` if the matching leaf in ``feature_mask``
      is equal to 1.

    Examples
    --------
    >>> # Generate sample data
    >>> import jax.numpy as jnp
    >>> import numpy as np
    >>> from nemos.glm import PopulationGLM
    >>> np.random.seed(0)
    >>> # Define predictors (X), weights, and neural activity (y)
    >>> num_samples, num_features, num_neurons = 100, 3, 2
    >>> X = np.random.normal(size=(num_samples, num_features))
    >>> # Weights is defined by how each feature influences the output, shape (num_features, num_neurons)
    >>> weights = np.array([[ 0.5,  0. ], [-0.5, -0.5], [ 0. ,  1. ]])
    >>> # Output y simulates a Poisson distribution based on a linear model between features X and weights
    >>> y = np.random.poisson(np.exp(X.dot(weights)))
    >>> # Define a feature mask, shape (num_features, num_neurons)
    >>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
    >>> # Create and fit the model
    >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
    >>> print(model.coef_.shape)
    (3, 2)
    """
    # the fitting itself is delegated to the parent class
    fitted_model = super().fit(X, y, init_params)
    return fitted_model
def _predict(self, params: GLMParams, X: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the predicted firing rates, applying the feature mask when one is set.

    This is the streamlined prediction used internally within optimization routines.
    Unlike the ``GLM.predict`` method, it does not perform any input validation,
    assuming that the inputs are pre-validated.

    Without a feature mask the computation is delegated to the parent ``_predict``.
    Otherwise the coefficients are first element-wise multiplied with the mask, then
    the canonical linear-non-linear GLM map is applied.

    Parameters
    ----------
    params :
        GLMParams containing the spike basis coefficients and bias terms.
    X :
        Predictors.

    Returns
    -------
    :
        The predicted rates. Shape (n_timebins, n_neurons).
    """
    if self._feature_mask is None:
        return super()._predict(params, X)

    def _masked_projection(x, w, m):
        # multiply each feature by its (masked) coefficient and sum over the features
        return jnp.einsum("ti, i...->t...", x, w * m)

    # contributions of the individual leaves are summed, then the intercept is added
    linear_predictor = (
        tree_utils.pytree_map_and_reduce(
            _masked_projection,
            sum,
            X,
            params.coef,
            self._feature_mask,
        )
        + params.intercept
    )
    # map the linear combination to a rate through the inverse link function
    return self.inverse_link_function(linear_predictor)
[docs]
def __sklearn_clone__(self) -> PopulationGLM:
    """Clone the PopulationGLM and re-attach its metadata to the new instance."""
    init_kwargs = self.get_params(deep=False)
    cloned = self.__class__(**init_kwargs)
    # the metadata is not a constructor parameter, so it is copied over explicitly
    cloned._metadata = self._metadata
    return cloned