nemos.glm.GLM

class nemos.glm.GLM(observation_model=PoissonObservations(inverse_link_function=exp), regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None)

Bases: BaseRegressor

Generalized Linear Model (GLM) for neural activity data.

This GLM implementation allows users to model neural activity based on a combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where the relationship between predictors and the response variable might be non-linear, and the residuals don’t follow a normal distribution. Below is a table listing the default and available solvers for each regularizer.

Regularizer      Default Solver      Available Solvers
---------------  ------------------  -----------------------------------------------------------
UnRegularized    GradientDescent     GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Ridge            GradientDescent     GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Lasso            ProximalGradient    ProximalGradient
GroupLasso       ProximalGradient    ProximalGradient
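The regularizer/solver pairing in the table can be read as a lookup plus a validation step. The sketch below is a hypothetical helper, not nemos's internal code: `SOLVER_TABLE` and `pick_solver` are illustrative names, and the helper only knows the solvers listed in the table.

```python
from typing import Optional

# Hypothetical helper mirroring the regularizer/solver table (illustration only).
_SMOOTH_SOLVERS = ("GradientDescent", "BFGS", "LBFGS", "NonlinearCG", "ProximalGradient")

# regularizer name -> (default solver, allowed solvers)
SOLVER_TABLE = {
    "UnRegularized": ("GradientDescent", _SMOOTH_SOLVERS),
    "Ridge": ("GradientDescent", _SMOOTH_SOLVERS),
    "Lasso": ("ProximalGradient", ("ProximalGradient",)),
    "GroupLasso": ("ProximalGradient", ("ProximalGradient",)),
}


def pick_solver(regularizer: str, solver_name: Optional[str] = None) -> str:
    """Return the solver to use, validating the regularizer/solver pairing."""
    if regularizer not in SOLVER_TABLE:
        raise ValueError(f"Unknown regularizer {regularizer!r}")
    default, allowed = SOLVER_TABLE[regularizer]
    if solver_name is None:
        # No solver requested: fall back to the default for this regularizer.
        return default
    if solver_name not in allowed:
        raise ValueError(
            f"Solver {solver_name!r} is not available for {regularizer!r}; "
            f"choose one of {allowed}"
        )
    return solver_name


print(pick_solver("Ridge"))           # GradientDescent
print(pick_solver("Ridge", "LBFGS"))  # LBFGS
print(pick_solver("Lasso"))           # ProximalGradient
```

Non-smooth penalties (Lasso, GroupLasso) only admit the proximal solver because their penalty has no gradient at zero, which is why the lookup rejects, for example, `pick_solver("Lasso", "LBFGS")`.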

Fitting Large Models

For very large models, you may consider using the Stochastic Variance Reduced Gradient nemos.solvers._svrg.SVRG or its proximal variant nemos.solvers._svrg.ProxSVRG solver, which take advantage of batched computation. You can change the solver by passing "SVRG" as solver_name at model initialization.

The performance of the SVRG solver depends critically on the choice of batch_size and stepsize hyperparameters. These parameters control the size of the mini-batches used for gradient computations and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations, we provide batch_size and stepsize default values that are theoretically guaranteed to ensure fast convergence.

Below is a list of the configurations for which we can provide guaranteed default hyperparameters:

GLM / PopulationGLM Configuration      Stepsize    Batch Size
-----------------------------------    --------    ----------
Poisson + soft-plus + UnRegularized
Poisson + soft-plus + Ridge
Poisson + soft-plus + Lasso
Poisson + soft-plus + GroupLasso
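To make the roles of stepsize and batch_size concrete, here is a minimal NumPy sketch of the standard SVRG update for a Poisson GLM. It is not nemos.solvers._svrg.SVRG: the function names, the exp inverse link, and the hand-picked stepsize=0.05 and batch_size=32 are assumptions for illustration, not the guaranteed defaults referred to above (which cover the soft-plus configurations).

```python
import numpy as np


def poisson_nll_and_grad(theta, Xa, y):
    """Mean Poisson negative log-likelihood (dropping log(y!)) and its gradient.

    Xa is the design matrix with a leading column of ones for the intercept;
    the exp inverse link is an assumption of this sketch.
    """
    eta = Xa @ theta
    rate = np.exp(eta)
    loss = np.mean(rate - y * eta)
    grad = Xa.T @ (rate - y) / y.shape[0]
    return loss, grad


def svrg(Xa, y, theta0, stepsize=0.05, batch_size=32, n_epochs=20, seed=0):
    """Plain SVRG: one full-gradient snapshot per epoch, then
    variance-reduced mini-batch steps."""
    rng = np.random.default_rng(seed)
    n = y.shape[0]
    theta = theta0.copy()
    for _ in range(n_epochs):
        # Snapshot the parameters and the full-data gradient at the snapshot.
        theta_snap = theta.copy()
        _, full_grad = poisson_nll_and_grad(theta_snap, Xa, y)
        for _ in range(n // batch_size):
            idx = rng.choice(n, size=batch_size, replace=False)
            _, g_cur = poisson_nll_and_grad(theta, Xa[idx], y[idx])
            _, g_snap = poisson_nll_and_grad(theta_snap, Xa[idx], y[idx])
            # Unbiased estimate of the full gradient with reduced variance.
            theta = theta - stepsize * (g_cur - g_snap + full_grad)
    return theta


# Synthetic data: 3 features plus an intercept column.
rng = np.random.default_rng(1)
X = 0.5 * rng.normal(size=(2000, 3))
Xa = np.column_stack([np.ones(len(X)), X])
true_theta = np.array([0.5, 0.3, -0.2, 0.1])
y = rng.poisson(np.exp(Xa @ true_theta))
theta_hat = svrg(Xa, y, theta0=np.zeros(4))
```

Each inner step touches only batch_size samples, while the correction term `full_grad - g_snap` keeps the step's noise shrinking as the iterates approach the optimum. A stepsize that is too large for the curvature of the loss is what causes the divergence mentioned above.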

Parameters:
  • observation_model (Observations) – Observation model to use. The model describes the distribution of the neural activity. Default is the Poisson model.

  • regularizer (Union[str, Regularizer]) – Regularization to use for model optimization. Defines the regularization scheme and related parameters. Default is UnRegularized regression.

  • regularizer_strength (Optional[float]) – Regularization strength. Default is None. If no value is passed and the chosen regularizer requires one, a warning is raised and the strength defaults to 1.0.

  • solver_name (str) – Solver to use for model optimization. Defines the optimization scheme and related parameters. The solver must be an appropriate match for the chosen regularizer. Default is None, in which case a solver is chosen based on the regularizer; see the table above for regularizer/solver pairings.

  • solver_kwargs (dict) –

    Optional dictionary for keyword arguments that are passed to the solver when instantiated. E.g. stepsize, acceleration, value_and_grad, etc.

    See the jaxopt documentation for details on each solver’s kwargs: https://jaxopt.github.io/stable/
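The regularizer_strength fallback and the way a strength enters the objective can be sketched as follows. `resolve_strength` and `ridge_poisson_loss` are hypothetical names; the 0.5 * strength * ||coef||^2 scaling and leaving the intercept unpenalized are assumptions made for illustration, and nemos may scale its penalty differently.

```python
import warnings
from typing import Optional

import numpy as np


def resolve_strength(regularizer: str, strength: Optional[float]) -> Optional[float]:
    """Hypothetical helper: warn and fall back to 1.0 when a penalized
    regularizer is requested without an explicit strength."""
    if regularizer == "UnRegularized":
        return None  # no penalty term, so no strength is needed
    if strength is None:
        warnings.warn(f"{regularizer} needs a strength; defaulting to 1.0", UserWarning)
        return 1.0
    return float(strength)


def ridge_poisson_loss(coef, intercept, X, y, strength):
    """Mean exp-link Poisson NLL (dropping log(y!)) plus 0.5 * strength * ||coef||^2.

    The penalty scaling and the unpenalized intercept are assumptions.
    """
    eta = X @ coef + intercept
    nll = np.mean(np.exp(eta) - y * eta)
    return nll + 0.5 * strength * np.sum(coef ** 2)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    strength = resolve_strength("Ridge", None)
print(strength, len(caught))  # 1.0 1
```

Larger strengths pull the coefficients toward zero at the cost of a worse data fit; the intercept is usually left out of the penalty so that the baseline rate is not biased.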

intercept_

Intercept of the model, i.e. the baseline firing rate on the link scale. For example, if the link is the logarithm, the baseline firing rate is jnp.exp(model.intercept_).

coef_

Basis coefficients for the model.

solver_state_

State of the solver after fitting. May include details like optimization error.

scale_

Scale parameter for the model. The scale parameter is the constant \(\Phi\) for which \(\text{Var} \left( y \right) = \Phi V(\mu)\), where \(V(\mu)\) is the variance function of the observation model. This parameter, together with the estimate of the mean \(\mu\), fully specifies the distribution of the activity \(y\).

dof_resid_

Degrees of freedom for the residuals.
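One common way to estimate \(\Phi\) from residuals is the Pearson moment estimator: sum the squared Pearson residuals and divide by the residual degrees of freedom. The sketch below assumes this estimator and dof_resid = n_samples - n_params; nemos may compute scale_ differently, and `pearson_scale` is a hypothetical name. For brevity, the true mean stands in for a fitted one.

```python
import numpy as np


def pearson_scale(y, mu, variance_fn, dof_resid):
    """Moment estimate of Phi in Var(y) = Phi * V(mu):
    sum of squared Pearson residuals over the residual degrees of freedom."""
    pearson_sq = (y - mu) ** 2 / variance_fn(mu)
    return np.sum(pearson_sq) / dof_resid


rng = np.random.default_rng(0)
n_samples, n_params = 20_000, 3      # e.g. 2 coefficients + 1 intercept
mu = np.exp(rng.normal(0.0, 0.3, size=n_samples))  # true means (no fitting here)
y = rng.poisson(mu)
dof_resid = n_samples - n_params     # residual degrees of freedom
# Poisson variance function: V(mu) = mu, so the true scale is 1.
phi_hat = pearson_scale(y, mu, variance_fn=lambda m: m, dof_resid=dof_resid)
print(round(phi_hat, 2))  # close to 1 for Poisson data
```

For Poisson observations \(\Phi = 1\) by construction; an estimate well above 1 on real spike counts would point to overdispersion.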

Raises:

TypeError – If provided regularizer or observation_model are not valid.

Examples

>>> import jax
>>> import nemos as nmo
>>> # define single neuron GLM model
>>> model = nmo.glm.GLM()
>>> model
GLM(
    observation_model=PoissonObservations(inverse_link_function=exp),
    regularizer=UnRegularized(),
    solver_name='GradientDescent'
)
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type:  <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model:  <class 'nemos.observation_models.PoissonObservations'>
>>> # define a GLM with Poisson observations and a soft-plus inverse link
>>> observation_model = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(observation_model=observation_model, solver_name="LBFGS")
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type:  <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model:  <class 'nemos.observation_models.PoissonObservations'>

Attributes

observation_model

Getter for the observation_model attribute.

regularizer

Getter for the regularizer attribute.

regularizer_strength

Regularizer strength getter.

solver_init_state

Provides the initialization function for the solver's state.

solver_kwargs

Getter for the solver_kwargs attribute.

solver_name

Getter for the solver_name attribute.

solver_run

Provides the function to execute the solver's optimization process.

solver_update

Provides the function for updating the solver's state during the optimization process.

__init__(observation_model=PoissonObservations(inverse_link_function=exp), regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None)
Parameters:
  • observation_model (Observations)

  • regularizer (str | Regularizer)

  • regularizer_strength (float | None)

  • solver_name (str)

  • solver_kwargs (dict)

Methods

__init__([observation_model, regularizer, ...])

fit(X, y[, init_params])

Fit GLM to neural activity.

get_params([deep])

Get the model parameters by inspecting __init__ (adapted from scikit-learn).

initialize_params(X, y[, init_params])

Initialize the model parameters for the optimization process.

initialize_state(X, y, init_params)

Initialize the solver by instantiating its init_state, update, and run methods.

instantiate_solver(*args[, solver_kwargs])

Instantiate the solver with the provided loss function.

predict(X)

Predict rates based on fit parameters.

score(X, y[, score_type, ...])

Evaluate the goodness-of-fit of the model to the observed neural data.

set_params(**params)

Set the model parameters, managing warnings in case of multiple parameter settings.

simulate(random_key, feedforward_input)

Simulate neural activity in response to a feed-forward input.

update(params, opt_state, X, y, *args[, ...])

Update the model parameters and solver state.

__sklearn_clone__()

Clone the GLM, returning a new unfitted instance with the same initialization parameters.

Return type:

GLM

fit(X, y, init_params=None)

Fit GLM to neural activity.

Fit and store the model parameters as attributes coef_ and intercept_.

Parameters:
  • X (Union[Array, FeaturePytree, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or pytree of the same shape.

  • y (ArrayLike) – Target neural activity arranged in a vector, shape (n_time_bins, ).

  • init_params (Optional[Tuple[Union[dict, ArrayLike], ArrayLike]]) – 2-tuple of initial parameter values: (coefficients, intercepts). If None, coefficients are initialized with zeros and intercepts with the log of the mean neural activity. coefficients is an array of shape (n_features,) or a pytree of the same; intercepts is an array of shape (1, ).
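The default initialization described for init_params can be sketched in a few lines of NumPy. `default_init_params` is a hypothetical name, and the sketch assumes the default log link and a positive mean activity (the log of a zero mean is undefined). With zero coefficients, the initial rate equals the mean of y in every bin.

```python
import numpy as np


def default_init_params(X, y):
    """Zero coefficients and an intercept matching the mean rate under a log link."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    coef = np.zeros(X.shape[1])                     # shape (n_features,)
    intercept = np.log(np.atleast_1d(y.mean()))     # shape (1,)
    return coef, intercept


X = np.random.default_rng(0).normal(size=(10, 2))
y = np.array([0, 1, 2, 0, 1, 3, 0, 0, 1, 2])
coef, intercept = default_init_params(X, y)
print(coef.shape, intercept.shape)  # (2,) (1,)
```

Starting from the mean rate places the optimizer near a sensible baseline, so the early iterations are spent on the coefficients rather than on fixing the overall firing level.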

Raises:
  • ValueError – If init_params is not of length two.

  • ValueError – If the dimensionality of init_params is not correct.

  • ValueError – If X is not two-dimensional.

  • ValueError – If y is not one-dimensional.

  • ValueError – If the solver returns at least one NaN parameter, which means it found an invalid solution. Try tuning optimization hyperparameters.

  • TypeError – If init_params are not array-like.

  • TypeError – If any init_params[i] cannot be converted to jnp.ndarray.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> # get model weights and intercept
>>> model_weights = model.coef_
>>> model_intercept = model.intercept_

get_params(deep=True)

Get the model parameters by inspecting __init__ (adapted from scikit-learn).

Parameters:

deep (bool) – If True, also return the parameters of contained sub-objects that are estimators.

Return type:

dict

Returns:

out:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

initialize_params(X, y, init_params=None)

Initialize the model parameters for the optimization process.

This method initializes the model parameters if they are not provided. It is typically called before starting the optimization process to ensure that all necessary components and states are correctly configured.

Parameters:
  • X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process. This can include feature matrices or other structures compatible with the model’s design.

  • y (Array) – The response variables or outputs corresponding to the predictors. Used to initialize parameters when they are not provided.

  • init_params (Optional[Tuple[Array, Array]]) – Initial parameters for the model. If not provided, they will be initialized based on the input data X and y.

Returns:

The initialized model parameters

Return type:

ModelParams

Raises:
  • ValueError – If init_params is not of length two.

  • ValueError – If the dimensionality of init_params is not correct.

  • ValueError – If X is not two-dimensional.

  • ValueError – If y is not correct (1D for GLM, 2D for PopulationGLM).

  • TypeError – If init_params are not array-like when provided.

  • TypeError – If any init_params[i] cannot be converted to jnp.ndarray.
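The checks in the Raises list above can be mirrored by a small validator. This is a hypothetical sketch for the single-neuron case with array (not pytree) coefficients; `check_init_params` is an illustrative name, not a nemos function.

```python
import numpy as np


def check_init_params(init_params, n_features):
    """Hypothetical checks mirroring the documented errors for a single-neuron GLM."""
    if len(init_params) != 2:
        raise ValueError("init_params must be a (coefficients, intercept) 2-tuple")
    try:
        coef, intercept = (np.asarray(p, dtype=float) for p in init_params)
    except (TypeError, ValueError) as err:
        # Entries that cannot be converted to a float array are rejected.
        raise TypeError("init_params entries must be array-like") from err
    if coef.ndim != 1 or coef.shape[0] != n_features:
        raise ValueError(f"coefficients must have shape ({n_features},)")
    if intercept.shape != (1,):
        raise ValueError("intercept must have shape (1,)")
    return coef, intercept


coef, intercept = check_init_params((np.zeros(2), np.array([0.1])), n_features=2)
try:
    check_init_params((np.zeros(2),), n_features=2)
except ValueError as err:
    print(err)  # init_params must be a (coefficients, intercept) 2-tuple
```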

Examples

>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_state(X, y, params)
>>> # Now ready to run optimization or update steps

initialize_state(X, y, init_params)

Initialize the solver by instantiating its init_state, update, and run methods.

This method also prepares the solver’s state by using the initialized model parameters and data. This setup is ready to be used for running the solver’s optimization routines.
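The init_state / update / run triple follows a common solver pattern: init_state builds the starting state, update performs one step, and run loops update until a stopping rule fires. The sketch below uses plain gradient descent on a least-squares loss for brevity; `GDState`, `make_gradient_descent`, and the tolerance rule are hypothetical and only illustrate the pattern, not jaxopt's or nemos's implementation.

```python
from typing import Callable, NamedTuple

import numpy as np


class GDState(NamedTuple):
    """Minimal solver state: iteration count and norm of the last evaluated gradient."""
    iter_num: int
    error: float


def make_gradient_descent(grad_fn: Callable, stepsize: float = 0.1,
                          maxiter: int = 500, tol: float = 1e-8):
    """Build hypothetical init_state / update / run callables around grad_fn."""

    def init_state(params, X, y):
        return GDState(iter_num=0, error=np.inf)

    def update(params, state, X, y):
        grad = grad_fn(params, X, y)
        new_params = params - stepsize * grad
        return new_params, GDState(state.iter_num + 1, float(np.linalg.norm(grad)))

    def run(params, X, y):
        state = init_state(params, X, y)
        while state.iter_num < maxiter and state.error > tol:
            params, state = update(params, state, X, y)
        return params, state

    return init_state, update, run


def grad_least_squares(w, X, y):
    """Gradient of 0.5 * mean((X w - y)^2)."""
    return X.T @ (X @ w - y) / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = X @ np.array([1.0, -2.0])
init_state, update, run = make_gradient_descent(grad_least_squares, stepsize=0.5)
w_hat, state = run(np.zeros(2), X, y)
```

Calling update directly gives step-by-step control (as GLM.update does), while run hides the loop (as fit does).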

Parameters:
  • X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process. This can include feature matrices or other structures compatible with the model’s design.

  • y (Array) – The response variables or outputs corresponding to the predictors. Used, together with X, to initialize the solver state.

  • init_params – Initial parameters for the model.

Returns:

The initialized solver state

Return type:

NamedTuple

Examples

>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_state(X, y, params)
>>> # Now ready to run optimization or update steps

instantiate_solver(*args, solver_kwargs=None)

Instantiate the solver with the provided loss function.

Instantiate the solver with the provided loss function, and store callable functions that initialize the solver state, update the model parameters, and run the optimization as attributes.

This method creates a solver instance from nemos.solvers or the jaxopt library, tailored to the specific loss function and regularization approach defined by the Regularizer instance. It also handles the proximal operator if required for the optimization method. The returned functions are directly usable in optimization loops, simplifying the syntax by pre-setting common arguments like regularization strength and other hyperparameters.
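For Lasso, the proximal operator mentioned above is element-wise soft-thresholding, and proximal gradient alternates a gradient step on the smooth loss with that operator. The NumPy sketch below is the textbook ISTA iteration on a least-squares loss (chosen for brevity instead of the Poisson likelihood); the function names are hypothetical and this is not jaxopt's ProximalGradient.

```python
import numpy as np


def soft_threshold(x, threshold):
    """Proximal operator of threshold * ||x||_1 (element-wise soft-thresholding)."""
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)


def proximal_gradient_lasso(X, y, strength, stepsize, n_iter=500):
    """ISTA for 0.5 * mean((X w - y)^2) + strength * ||w||_1."""
    w = np.zeros(X.shape[1])
    n = len(y)
    for _ in range(n_iter):
        grad = X.T @ (X @ w - y) / n                                  # smooth part
        w = soft_threshold(w - stepsize * grad, stepsize * strength)  # prox step
    return w


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
w_true = np.array([2.0, 0.0, 0.0, -1.0, 0.0])
y = X @ w_true                       # noise-free for a clear picture
w_hat = proximal_gradient_lasso(X, y, strength=0.1, stepsize=0.5)
print(np.round(w_hat, 2))
```

The prox step sets small coefficients exactly to zero, which a plain gradient method on the non-smooth penalty cannot do; this is why the proximal variant is required for Lasso and GroupLasso.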

Parameters:
  • *args – Positional arguments for the jaxopt solver.run method, e.g. the regularization strength for proximal gradient methods.

  • solver_kwargs (Optional[dict]) – Optional dictionary with the solver kwargs. If nothing is provided, it defaults to self.solver_kwargs.

Return type:

BaseRegressor

Returns:

The instance itself for method chaining.

property observation_model: None | Observations

Getter for the observation_model attribute.

predict(X)

Predict rates based on fit parameters.

Parameters:

X (Union[Array, FeaturePytree]) – Predictors, array of shape (n_time_bins, n_features) or pytree of same.

Return type:

Array

Returns:

The predicted rates with shape (n_time_bins, ).

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If params is not a JAX pytree of size two.

  • ValueError – If weights and bias terms in params don’t have the expected dimensions.

  • ValueError – If X is not two-dimensional.

  • ValueError – If there’s an inconsistent number of features between spike basis coefficients and X.
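The prediction itself is the inverse link applied to the linear predictor. The sketch below is a hypothetical NumPy version (`predict_rates` is not a nemos function) that also reproduces two of the shape checks listed above; it assumes array coefficients rather than a pytree.

```python
import numpy as np


def predict_rates(X, coef, intercept, inverse_link=np.exp):
    """Rates = inverse_link(X @ coef + intercept), shape (n_time_bins,)."""
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be two-dimensional (n_time_bins, n_features)")
    if X.shape[1] != coef.shape[0]:
        raise ValueError("Inconsistent number of features between coef and X")
    return inverse_link(X @ coef + intercept)


Xnew = np.random.default_rng(0).normal(size=(20, 2))
# With zero coefficients every bin gets the baseline rate exp(intercept) = 2.
rates = predict_rates(Xnew, coef=np.zeros(2), intercept=np.array([np.log(2.0)]))
print(rates.shape)  # (20,)
```

A soft-plus inverse link can be passed as `inverse_link=lambda z: np.logaddexp(0.0, z)`, which computes log(1 + exp(z)) stably.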

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # predict firing rates for new inputs
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_rates = model.predict(Xnew)

See also

nemos.glm.GLM.score()

Score predicted rates against target spike counts.

nemos.glm.GLM.simulate()

Simulate neural activity in response to a feed-forward input (feed-forward only).

nemos.simulation.simulate_recurrent()

Simulate neural activity in response to a feed-forward input using the GLM as a recurrent network (feed-forward + coupling).

property regularizer: None | Regularizer

Getter for the regularizer attribute.

property regularizer_strength: float

Regularizer strength getter.

score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)

Evaluate the goodness-of-fit of the model to the observed neural data.

This method computes the goodness-of-fit score, which can be either the mean log-likelihood or one of two versions of the pseudo-\(R^2\). The scoring process validates that the model has been fitted and that the input data are compatible with the model’s parameters and appropriate for scoring. A higher score indicates a better fit of the model to the observed data.

Parameters:
  • X (Union[Array, FeaturePytree, ArrayLike]) – The exogenous variables. Shape (n_time_bins, n_features).

  • y (ArrayLike) – Neural activity. Shape (n_time_bins, ).

  • score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – Type of scoring: either log-likelihood or pseudo-\(R^2\).

  • aggregate_sample_scores (Callable) – Function that aggregates the score of all samples.

Returns:

The log-likelihood or the pseudo-\(R^2\) of the current model.

Return type:

Array

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If the structure of X doesn’t match the params, or if X and y have a different number of samples.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')

Notes

The log-likelihood is not on a standard scale: its value is influenced by many factors, among them the number of model parameters. The log-likelihood can take both positive and negative values.

The pseudo-\(R^2\) is not equivalent to the \(R^2\) value in linear regression. While both provide a measure of model fit and typically take values in the [0, 1] range, their computation and interpretation can differ. The pseudo-\(R^2\) is particularly useful for generalized linear models when the interpretation of the \(R^2\) as explained variance does not apply (i.e., when the observations are not Gaussian-distributed).
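As an illustration, the textbook McFadden pseudo-\(R^2\) compares the log-likelihood of the model with that of a null model predicting the mean of y everywhere: \(1 - \log L_{\text{model}} / \log L_{\text{null}}\). The NumPy sketch below implements that textbook formula for Poisson data; the function names are hypothetical, and nemos's implementation may differ in details such as how the null model is defined.

```python
import math

import numpy as np


def poisson_log_likelihood(y, mu):
    """Total Poisson log-likelihood, including the log(y!) term."""
    log_factorial = np.array([math.lgamma(k + 1.0) for k in y])
    return np.sum(y * np.log(mu) - mu - log_factorial)


def pseudo_r2_mcfadden(y, mu):
    """1 - LL(model) / LL(null), with the null model predicting mean(y)."""
    ll_model = poisson_log_likelihood(y, mu)
    ll_null = poisson_log_likelihood(y, np.full_like(mu, y.mean(), dtype=float))
    return 1.0 - ll_model / ll_null


rng = np.random.default_rng(0)
mu_true = np.exp(rng.normal(0.5, 0.8, size=5000))
y = rng.poisson(mu_true)
r2_true = pseudo_r2_mcfadden(y, mu_true)                 # positive: beats the null
r2_null = pseudo_r2_mcfadden(y, np.full(5000, y.mean()))  # 0 for the null itself
print(r2_true, r2_null)
```

Because it is a ratio of log-likelihoods, the McFadden score is unchanged whether total or mean log-likelihoods are used.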

Why is the traditional \(R^2\) usually a poor measure of performance in GLMs?

  1. In the context of GLMs, the variance and the mean of the observations are related. Ignoring the relation between them can result in underestimating the model performance; for instance, when we model a Poisson variable with a large mean, we expect an equally large variance. In this scenario, even if our model perfectly captures the mean, the high variance will result in large residuals and a low \(R^2\). Additionally, when the mean of the observations varies, the variance will vary too. This violates the “homoscedasticity” assumption necessary for interpreting the \(R^2\) as variance explained.

  2. The \(R^2\) captures the variance explained when the relationship between the observations and the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors and the mean of the observations, compromising the interpretation of the \(R^2\).
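Point 1 can be checked numerically. The sketch below simulates Poisson counts whose true mean varies between 1 and 3, then scores a "perfect" model that predicts the true mean of every bin with the traditional \(R^2\). Since Var(y) = E[Var(y | mu)] + Var(mu) = 2 + 1/3, the expected \(R^2\) is only about (1/3) / (7/3) ≈ 0.14, even though the model is exactly right.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
# True rates vary modestly around a mean of 2 counts per bin.
mu = rng.uniform(1.0, 3.0, size=n)
y = rng.poisson(mu)

# "Perfect" model: predict exactly the true mean of every bin.
residuals = y - mu
r2 = 1.0 - np.sum(residuals ** 2) / np.sum((y - y.mean()) ** 2)
print(round(r2, 2))
```

The residuals here are pure Poisson noise, whose variance grows with the mean, so the traditional \(R^2\) penalizes the model for variability it cannot and should not explain.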

Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional to the model standard deviation (i.e., Pearson residuals). This “rescaled” residual distribution, however, deviates substantially from normality for count data with a low mean (common for spike counts). Therefore, even Pearson residuals perform poorly as a measure of fit quality, especially for GLMs modeling count data.
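The non-normality of Pearson residuals at low means is easy to see: for a Poisson variable with mean mu, the residual (y - mu) / sqrt(mu) has unit variance by construction, but its skewness equals that of the Poisson distribution, 1 / sqrt(mu). The NumPy sketch below uses mu = 0.1, where the skewness is about 3.2 instead of the 0 of a normal distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = 0.1                                  # low mean, typical of sparse spike counts
y = rng.poisson(mu, size=200_000)
pearson = (y - mu) / np.sqrt(mu)          # Pearson residuals: unit variance by construction

# Sample skewness; a normal distribution would give roughly 0.
skew = np.mean((pearson - pearson.mean()) ** 3) / np.std(pearson) ** 3
print(round(pearson.var(), 2), round(skew, 2))
```

Most residuals sit at -sqrt(0.1) ≈ -0.32 (bins with no spike), with a long right tail from the rare spikes, so a fit measure built on them is dominated by that tail.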

Refer to the nmo.observation_models.Observations concrete subclasses for the likelihood and pseudo-\(R^2\) equations.

set_params(**params)

Set the model parameters, managing warnings in case of multiple parameter settings.

Parameters:

params (Any)

simulate(random_key, feedforward_input)

Simulate neural activity in response to a feed-forward input.

Parameters:
  • random_key (Array) – jax.random.key for seeding the simulation.

  • feedforward_input (Union[Array, FeaturePytree]) – External input matrix to the model, representing factors like convolved currents, light intensities, etc. Array of shape (n_time_bins, n_basis_input) or pytree of same.

Return type:

Tuple[Array, Array]

Returns:

  • simulated_activity – Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins, ).

  • firing_rates – Simulated rates for the neuron over time. Shape: (n_time_bins, ).
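For a Poisson GLM, simulation amounts to computing the rates as in predict and then drawing one Poisson count per time bin. The NumPy sketch below illustrates this; `simulate_poisson` is a hypothetical name, and a NumPy Generator stands in for the jax.random.key that GLM.simulate actually takes.

```python
import numpy as np


def simulate_poisson(rng, X, coef, intercept, inverse_link=np.exp):
    """Return (spike_counts, rates), each of shape (n_time_bins,)."""
    rates = inverse_link(np.asarray(X, dtype=float) @ coef + intercept)
    spikes = rng.poisson(rates)   # one Poisson draw per time bin
    return spikes, rates


rng = np.random.default_rng(123)   # NumPy generator standing in for jax.random.key
Xnew = rng.normal(size=(20, 2))
spikes, rates = simulate_poisson(rng, Xnew, np.array([0.3, -0.2]), np.array([0.0]))
print(spikes.shape, rates.shape)  # (20,) (20,)
```

The rates are deterministic given the inputs and parameters; only the counts depend on the random key, which is why the same key reproduces the same spike train.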

Raises:
  • NotFittedError

    • If the model hasn’t been fitted prior to calling this method.

  • ValueError

    • If the instance has not been previously fitted.

Examples

>>> # example input
>>> import jax
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> spikes, rates = model.simulate(random_key, Xnew)

See also

nemos.glm.GLM.predict()

Method to predict rates based on the model’s parameters.

property solver_init_state: None | Callable[[Any, Array, Array], NamedTuple]

Provides the initialization function for the solver’s state.

This function is responsible for initializing the solver’s state, necessary for the start of the optimization process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.

Returns:

The function to initialize the state of the solver, if available; otherwise, None if the solver has not yet been instantiated.

property solver_kwargs

Getter for the solver_kwargs attribute.

property solver_name: str

Getter for the solver_name attribute.

property solver_run: None | Callable[[Any, Array, Array], OptStep]

Provides the function to execute the solver’s optimization process.

This function runs the solver using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.

Returns:

The function to run the solver’s optimization process, if available; otherwise, None if the solver has not yet been instantiated.

property solver_update: None | Callable[[Any, NamedTuple, Array, Array], OptStep]

Provides the function for updating the solver’s state during the optimization process.

This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimization step is necessary, such as in online learning or complex optimization scenarios.

Returns:

The function to update the solver’s state, if available; otherwise, None if the solver has not yet been instantiated.

update(params, opt_state, X, y, *args, n_samples=None, **kwargs)

Update the model parameters and solver state.

This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. This method is particularly useful for iterative model fitting, especially in scenarios where model parameters need to be updated incrementally, such as online learning or when dealing with very large datasets that do not fit into memory at once.

Parameters:
  • params (Tuple[Array, Array]) – The current model parameters, typically a tuple of coefficients and intercepts.

  • opt_state (NamedTuple) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.

  • X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process, which may include feature matrices or nemos.pytrees.FeaturePytree objects.

  • y (Array) – The response variable or output data corresponding to the predictors, used in the model fitting process.

  • *args – Additional positional arguments to be passed to the solver’s update method.

  • n_samples (Optional[int]) – The total number of samples. It is usually larger than the number of samples in an individual batch and is used to estimate the scale parameter of the GLM.

  • **kwargs – Additional keyword arguments to be passed to the solver’s update method.

Returns:

A tuple containing the updated parameters and optimization state. This tuple is typically used to continue the optimization process in subsequent steps.

Return type:

jaxopt.OptStep

Raises:

ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.

Examples

>>> import nemos as nmo
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> glm_instance = nmo.glm.GLM().fit(X, y)
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state_
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)