nemos.glm.GLM#
- class nemos.glm.GLM(observation_model=PoissonObservations(inverse_link_function=exp), regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None)[source]#
Bases: BaseRegressor
Generalized Linear Model (GLM) for neural activity data.
This GLM implementation allows users to model neural activity based on a combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where the relationship between predictors and the response variable might be non-linear, and the residuals don’t follow a normal distribution. Below is a table listing the default and available solvers for each regularizer.
| Regularizer   | Default Solver   | Available Solvers                                           |
|---------------|------------------|-------------------------------------------------------------|
| UnRegularized | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge         | GradientDescent  | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso         | ProximalGradient | ProximalGradient                                            |
| GroupLasso    | ProximalGradient | ProximalGradient                                            |
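For example, a minimal sketch of picking a regularizer/solver pair consistent with the table above (the strength values are illustrative):

>>> import nemos as nmo
>>> # Ridge admits any solver in its "Available Solvers" row, so LBFGS is valid.
>>> ridge_lbfgs = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1, solver_name="LBFGS")
>>> # Omitting solver_name falls back to the regularizer's default solver.
>>> lasso = nmo.glm.GLM(regularizer="Lasso", regularizer_strength=0.1)
>>> print(lasso.solver_name)
ProximalGradient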
Fitting Large Models

For very large models, you may consider using the Stochastic Variance Reduced Gradient solver nemos.solvers._svrg.SVRG or its proximal variant nemos.solvers._svrg.ProxSVRG, which take advantage of batched computation. You can change the solver by passing "SVRG" as solver_name at model initialization; a configuration sketch follows the table below.

The performance of the SVRG solver depends critically on the choice of the batch_size and stepsize hyperparameters. These parameters control the size of the mini-batches used for gradient computations and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations we provide batch_size and stepsize default values that are theoretically guaranteed to ensure fast convergence. Below is a list of the configurations for which we can provide guaranteed default hyperparameters:
| GLM / PopulationGLM Configuration   | Stepsize | Batch Size |
|-------------------------------------|----------|------------|
| Poisson + soft-plus + UnRegularized | ✅       | ❌         |
| Poisson + soft-plus + Ridge         | ✅       | ✅         |
| Poisson + soft-plus + Lasso         | ✅       | ❌         |
| Poisson + soft-plus + GroupLasso    | ✅       | ❌         |
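A configuration sketch, as referenced above; the batch_size and stepsize values are illustrative (not the theoretically guaranteed defaults), and passing them through solver_kwargs is an assumption about how these hyperparameters are forwarded to the solver:

>>> import jax
>>> import nemos as nmo
>>> # Poisson + soft-plus + Ridge: per the table, both hyperparameters also
>>> # have guaranteed defaults, so batch_size/stepsize could be omitted.
>>> obs = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(
...     observation_model=obs,
...     regularizer="Ridge",
...     regularizer_strength=1.0,
...     solver_name="SVRG",
...     solver_kwargs={"batch_size": 32, "stepsize": 1e-3},
... )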
- Parameters:
  - observation_model (Observations) – Observation model to use. The model describes the distribution of the neural activity. Default is the Poisson model.
  - regularizer (Union[str, Regularizer]) – Regularization to use for model optimization. Defines the regularization scheme and related parameters. Default is UnRegularized regression.
  - regularizer_strength (Optional[float]) – Strength of the regularizer; defaults to None. If a value is needed for the chosen regularization and none is passed, a warning is raised and the strength defaults to 1.0.
  - solver_name (str) – Solver to use for model optimization. Defines the optimization scheme and related parameters. The solver must be an appropriate match for the chosen regularizer. Default is None; if no solver is specified, one is chosen based on the regularizer. See the table above for regularizer/solver pairings.
  - solver_kwargs (dict) – Optional dictionary of keyword arguments passed to the solver when instantiated, e.g. stepsize, acceleration, value_and_grad. See the jaxopt documentation for details on each solver's kwargs: https://jaxopt.github.io/stable/
- intercept_#
Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline firing rate will be jnp.exp(model.intercept_).
- coef_#
Basis coefficients for the model.
- solver_state_#
State of the solver after fitting. May include details like optimization error.
- scale_#
Scale parameter for the model. The scale parameter is the constant \(\Phi\), for which \(\text{Var} \left( y \right) = \Phi V(\mu)\). This parameter, together with the estimate of the mean \(\mu\) fully specifies the distribution of the activity \(y\).
- dof_resid_#
Degrees of freedom for the residuals.
- Raises:
  TypeError – If the provided regularizer or observation_model are not valid.
Examples
>>> import jax
>>> import nemos as nmo

>>> # define single neuron GLM model
>>> model = nmo.glm.GLM()
>>> model
GLM(
    observation_model=PoissonObservations(inverse_link_function=exp),
    regularizer=UnRegularized(),
    solver_name='GradientDescent'
)
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model: <class 'nemos.observation_models.PoissonObservations'>

>>> # define a Poisson GLM with a soft-plus nonlinearity
>>> observation_models = nmo.observation_models.PoissonObservations(jax.nn.softplus)
>>> model = nmo.glm.GLM(observation_model=observation_models, solver_name="LBFGS")
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Observation model: <class 'nemos.observation_models.PoissonObservations'>
Attributes
observation_model – Getter for the observation_model attribute.
regularizer – Getter for the regularizer attribute.
regularizer_strength – Regularizer strength getter.
solver_init_state – Provides the initialization function for the solver's state.
solver_kwargs – Getter for the solver_kwargs attribute.
solver_name – Getter for the solver_name attribute.
solver_run – Provides the function to execute the solver's optimization process.
solver_update – Provides the function for updating the solver's state during the optimization process.
- __init__(observation_model=PoissonObservations(inverse_link_function=exp), regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None)[source]#
- Parameters:
observation_model (Observations)
regularizer (str | Regularizer)
regularizer_strength (float | None)
solver_name (str)
solver_kwargs (dict)
Methods
__init__([observation_model, regularizer, ...])
fit(X, y[, init_params]) – Fit GLM to neural activity.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
initialize_params(X, y[, init_params]) – Initialize the model parameters for the optimization process.
initialize_state(X, y, init_params) – Initialize the solver by instantiating its init_state, update, and run methods.
instantiate_solver(*args[, solver_kwargs]) – Instantiate the solver with the provided loss function.
predict(X) – Predict rates based on fit parameters.
score(X, y[, score_type, ...]) – Evaluate the goodness-of-fit of the model to the observed neural data.
set_params(**params) – Manage warnings in case of multiple parameter settings.
simulate(random_key, feedforward_input) – Simulate neural activity in response to a feed-forward input.
update(params, opt_state, X, y, *args[, ...]) – Update the model parameters and solver state.
- fit(X, y, init_params=None)[source]#
Fit GLM to neural activity.
Fit and store the model parameters as attributes coef_ and intercept_.

- Parameters:
  - X (Union[Array, FeaturePytree, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or pytree of the same shape.
  - y (ArrayLike) – Target neural activity arranged in a matrix, shape (n_time_bins, ).
  - init_params (Optional[Tuple[Union[dict, ArrayLike], ArrayLike]]) – 2-tuple of initial parameter values: (coefficients, intercepts). If None, coefficients are initialized with zeros and intercepts with the log of the mean neural activity. coefficients is an array of shape (n_features,) or pytree of same; intercepts is an array of shape (1, ). See the warm-start sketch after the examples below.
- Raises:
  - ValueError – If init_params is not of length two.
  - ValueError – If the dimensionality of init_params is not correct.
  - ValueError – If X is not two-dimensional.
  - ValueError – If y is not one-dimensional.
  - ValueError – If the solver returns at least one NaN parameter, which means it found an invalid solution. Try tuning optimization hyperparameters.
  - TypeError – If init_params are not array-like.
  - TypeError – If init_params[i] cannot be converted to jnp.ndarray for all i.
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)

>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)

>>> # get model weights and intercept
>>> model_weights = model.coef_
>>> model_intercept = model.intercept_
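As noted in the init_params description, a warm start can be passed as a (coefficients, intercepts) tuple; a minimal sketch (the particular values are illustrative):

>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # coefficients: shape (n_features,); intercepts: shape (1,)
>>> warm_start = (np.zeros(X.shape[1]), np.log(np.atleast_1d(y.mean()) + 1e-8))
>>> model = nmo.glm.GLM().fit(X, y, init_params=warm_start)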
- get_params(deep=True)#
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep
- Return type:
dict
- Returns:
- out:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
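A minimal usage sketch (the exact key listing depends on the installed version):

>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> # deep=True also expands nested estimator parameters, scikit-learn style.
>>> params = model.get_params(deep=True)
>>> names = sorted(params)  # parameter names; listing varies by version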
- initialize_params(X, y, init_params=None)[source]#
Initialize the model parameters for the optimization process.
This method initializes the model parameters if they are not provided. It is typically called before starting the optimization process to ensure that all necessary components and states are correctly configured.
- Parameters:
  - X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process. This can include feature matrices or other structures compatible with the model's design.
  - y (Array) – The response variables or outputs corresponding to the predictors. Used to initialize parameters when they are not provided.
  - init_params (Optional[Tuple[Array, Array]]) – Initial parameters for the model. If not provided, they will be initialized based on the input data X and y.
- Returns:
The initialized model parameters
- Return type:
ModelParams
- Raises:
  - ValueError – If params is not of length two.
  - ValueError – If the dimensionality of init_params is not correct.
  - ValueError – If X is not two-dimensional.
  - ValueError – If y is not correct (1D for GLM, 2D for PopulationGLM).
  - TypeError – If params are not array-like when provided.
  - TypeError – If init_params[i] cannot be converted to jnp.ndarray for all i.
Examples
>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_state(X, y, params)
>>> # Now ready to run optimization or update steps
- initialize_state(X, y, init_params)[source]#
Initialize the solver by instantiating its init_state, update, and run methods.
This method also prepares the solver’s state by using the initialized model parameters and data. This setup is ready to be used for running the solver’s optimization routines.
- Parameters:
  - X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process. This can include feature matrices or other structures compatible with the model's design.
  - y (Array) – The response variables or outputs corresponding to the predictors. Used to initialize parameters when they are not provided.
  - init_params – Initial parameters for the model.
- Returns:
The initialized solver state
- Return type:
NamedTuple
Examples
>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_state(X, y, params)
>>> # Now ready to run optimization or update steps
- instantiate_solver(*args, solver_kwargs=None)#
Instantiate the solver with the provided loss function.
Instantiate the solver with the provided loss function, and store callable functions that initialize the solver state, update the model parameters, and run the optimization as attributes.
This method creates a solver instance from nemos.solvers or the jaxopt library, tailored to the specific loss function and regularization approach defined by the Regularizer instance. It also handles the proximal operator if required for the optimization method. The returned functions are directly usable in optimization loops, simplifying the syntax by pre-setting common arguments like regularization strength and other hyperparameters.
- Parameters:
  - *args – Positional arguments for the jaxopt solver's run method, e.g. the regularizing strength for proximal gradient methods.
  - solver_kwargs (Optional[dict]) – Optional dictionary with the solver kwargs. If nothing is provided, it defaults to self.solver_kwargs.
- Return type:
BaseRegressor
- Returns:
The instance itself for method chaining.
- property observation_model: None | Observations#
Getter for the observation_model attribute.
- predict(X)[source]#
Predict rates based on fit parameters.
- Parameters:
  X (Union[Array, FeaturePytree]) – Predictors, array of shape (n_time_bins, n_features) or pytree of same.
- Return type:
  Array
- Returns:
  The predicted rates with shape (n_time_bins, ).
- Raises:
  - NotFittedError – If fit has not been called first with this instance.
  - ValueError – If params is not a JAX pytree of size two.
  - ValueError – If weights and bias terms in params don't have the expected dimensions.
  - ValueError – If X is not three-dimensional.
  - ValueError – If there's an inconsistent number of features between spike basis coefficients and X.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)

>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # predict new spike data
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_spikes = model.predict(Xnew)
See also
nemos.glm.GLM.score()
Score predicted rates against target spike counts.
nemos.glm.GLM.simulate()
Simulate neural activity in response to a feed-forward input (feed-forward only).
nemos.simulation.simulate_recurrent()
Simulate neural activity in response to a feed-forward input using the GLM as a recurrent network (feed-forward + coupling).
- property regularizer: None | Regularizer#
Getter for the regularizer attribute.
- property regularizer_strength: float#
Regularizer strength getter.
- score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)[source]#
Evaluate the goodness-of-fit of the model to the observed neural data.
This method computes the goodness-of-fit score, which can be either the mean log-likelihood or one of two versions of the pseudo-\(R^2\). The scoring process includes validation of input compatibility with the model's parameters, ensuring that the model has been previously fitted and the input data are appropriate for scoring. A higher score indicates a better fit of the model to the observed data.
- Parameters:
  - X (Union[Array, FeaturePytree, ArrayLike]) – The exogenous variables. Shape (n_time_bins, n_features).
  - y (ArrayLike) – Neural activity. Shape (n_time_bins, ).
  - score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – Type of scoring: either log-likelihood or pseudo-\(R^2\).
  - aggregate_sample_scores (Callable) – Function that aggregates the score of all samples.
- Returns:
  score – The log-likelihood or the pseudo-\(R^2\) of the current model.
- Raises:
  - NotFittedError – If fit has not been called first with this instance.
  - ValueError – If the X structure doesn't match the params, or if X and y have a different number of samples.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # get model score
>>> log_likelihood_score = model.score(X, y)

>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')
Notes
The log-likelihood is not on a standard scale; its value is influenced by many factors, among them the number of model parameters. The log-likelihood can assume both positive and negative values.

The pseudo-\(R^2\) is not equivalent to the \(R^2\) value in linear regression. While both provide a measure of model fit and assume values in the [0, 1] range, the methods and interpretations can differ. The pseudo-\(R^2\) is particularly useful for generalized linear models, where the interpretation of the \(R^2\) as explained variance does not apply (i.e., when the observations are not Gaussian distributed).
Why is the traditional \(R^2\) usually a poor measure of performance in GLMs?

In the context of GLMs, the variance and the mean of the observations are related. Ignoring the relation between them can result in underestimating the model performance; for instance, when we model a Poisson variable with large mean, we expect an equally large variance. In this scenario, even if our model perfectly captures the mean, the high variance will result in large residuals and a low \(R^2\). Additionally, when the mean of the observations varies, the variance will vary too. This violates the "homoscedasticity" assumption, necessary for interpreting the \(R^2\) as variance explained.

The \(R^2\) captures the variance explained when the relationship between the observations and the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors and the mean of the observations, compromising the interpretation of the \(R^2\).

Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional to the model standard deviation (i.e. Pearson residuals). This "rescaled" residual distribution, however, deviates substantially from normality for count data with low mean (common for spike counts). Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially for GLMs modeling count data.
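A minimal numpy sketch of the first point (plain numpy, not the nemos API): even a model that recovers the true Poisson rate exactly retains residual variance equal to the rate, which caps the classical \(R^2\) well below 1.

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> t = np.linspace(0, 10 * np.pi, 2000)
>>> rate = 5 + 4 * np.sin(t)   # true firing rate, known exactly
>>> y = rng.poisson(rate)      # Poisson counts: Var(y | rate) = rate
>>> # Residuals of the *perfect* mean model still carry the Poisson noise,
>>> # whose variance grows with the rate.
>>> ss_res = np.sum((y - rate) ** 2)      # about sum(rate)
>>> ss_tot = np.sum((y - y.mean()) ** 2)  # signal variance + noise variance
>>> r2 = 1 - ss_res / ss_tot              # roughly 0.6 here, despite a perfect fit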
Refer to the nmo.observation_models.Observations concrete subclasses for the likelihood and pseudo-\(R^2\) equations.
- set_params(**params)#
Manage warnings in case of multiple parameter settings.
- Parameters:
params (Any)
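A minimal sketch, following the scikit-learn set_params convention (the specific values are illustrative):

>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> # Setting the regularizer together with its strength avoids the
>>> # default-strength warning described in the class parameters above.
>>> model = model.set_params(regularizer="Ridge", regularizer_strength=0.1)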
- simulate(random_key, feedforward_input)[source]#
Simulate neural activity in response to a feed-forward input.
- Parameters:
  - random_key (Array) – jax.random.key for seeding the simulation.
  - feedforward_input (Union[Array, FeaturePytree]) – External input matrix to the model, representing factors like convolved currents, light intensities, etc. When not provided, the simulation is done with coupling-only. Array of shape (n_time_bins, n_basis_input) or pytree of same.
- Return type:
  Tuple[Array, Array]
- Returns:
  - simulated_activity – Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins, ).
  - firing_rates – Simulated rates for the neuron over time. Shape: (n_time_bins, ).
- Raises:
  - NotFittedError – If the model hasn't been fitted prior to calling this method.
  - ValueError – If the instance has not been previously fitted.
Examples
>>> # example input
>>> import numpy as np
>>> import jax
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)

>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)

>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> spikes, rates = model.simulate(random_key, Xnew)
See also
nemos.glm.GLM.predict()
Method to predict rates based on the model’s parameters.
- property solver_init_state: None | Callable[[Any, Array, Array], NamedTuple]#
Provides the initialization function for the solver’s state.
This function is responsible for initializing the solver’s state, necessary for the start of the optimization process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.
- Returns:
The function to initialize the state of the solver, if available; otherwise, None if the solver has not yet been instantiated.
- property solver_kwargs#
Getter for the solver_kwargs attribute.
- property solver_name: str#
Getter for the solver_name attribute.
- property solver_run: None | Callable[[Any, Array, Array], OptStep]#
Provides the function to execute the solver’s optimization process.
This function runs the solver using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.
- Returns:
The function to run the solver’s optimization process, if available; otherwise, None if the solver has not yet been instantiated.
- property solver_update: None | Callable[[Any, NamedTuple, Array, Array], OptStep]#
Provides the function for updating the solver’s state during the optimization process.
This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimization step is necessary, such as in online learning or complex optimization scenarios.
- Returns:
The function to update the solver’s state, if available; otherwise, None if the solver has not yet been instantiated.
- update(params, opt_state, X, y, *args, n_samples=None, **kwargs)[source]#
Update the model parameters and solver state.
This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. This method is particularly useful for iterative model fitting, especially in scenarios where model parameters need to be updated incrementally, such as online learning or when dealing with very large datasets that do not fit into memory at once.
- Parameters:
  - params (Tuple[Array, Array]) – The current model parameters, typically a tuple of coefficients and intercepts.
  - opt_state (NamedTuple) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.
  - X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process, which may include feature matrices or nemos.pytrees.FeaturePytree objects.
  - y (Array) – The response variable or output data corresponding to the predictors, used in the model fitting process.
  - *args – Additional positional arguments to be passed to the solver's update method.
  - n_samples (Optional[int]) – The total number of samples. Usually larger than the number of samples in an individual batch; n_samples is used to estimate the scale parameter of the GLM.
  - **kwargs – Additional keyword arguments to be passed to the solver's update method.
- Returns:
A tuple containing the updated parameters and optimization state. This tuple is typically used to continue the optimization process in subsequent steps.
- Return type:
jaxopt.OptStep
- Raises:
ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
>>> glm_instance = nmo.glm.GLM().fit(X, y)
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state_
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
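For incremental fitting, the update step can be wrapped in a mini-batch loop; a minimal sketch using only the methods documented above (batch size and epoch count are illustrative):

>>> import numpy as np
>>> import nemos as nmo
>>> X, y = np.random.normal(size=(100, 2)), np.random.poisson(size=100)
>>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_state(X, y, params)
>>> for epoch in range(5):
...     for start in range(0, X.shape[0], 20):
...         sl = slice(start, start + 20)
...         # n_samples is the full dataset size, used to estimate the scale parameter.
...         params, opt_state = model.update(
...             params, opt_state, X[sl], y[sl], n_samples=X.shape[0]
...         )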