nemos.glm.GLM

class nemos.glm.GLM(observation_model='Poisson', inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)

Bases: BaseRegressor

Generalized Linear Model (GLM) for neural activity data.

This GLM implementation allows users to model neural activity based on a combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where the relationship between predictors and the response variable might be non-linear, and the residuals don’t follow a normal distribution.
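In symbols, the model can be summarized as follows for time bin \(t\) with predictors \(x_t\). This notation is ours and is only meant as a compact reading aid: \(f\) is inverse_link_function, \(w\) and \(b\) correspond to coef_ and intercept_, and \(\Phi\) is the scale parameter scale_.

```latex
\mu_t = \mathbb{E}\left[ y_t \mid x_t \right] = f\left( x_t^\top w + b \right),
\qquad y_t \sim p\left( y \mid \mu_t, \Phi \right)
```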

Below is a table of the default inverse link function for each available observation model.

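=================  =============================
Observation Model  Default Inverse Link Function
=================  =============================
Poisson            \(e^x\)
Gamma              \(1/x\)
Bernoulli          \(1 / (1 + e^{-x})\)
NegativeBinomial   \(e^x\)
Gaussian           \(x\)
=================  =============================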
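As a reading aid, the table can be sketched as a plain dictionary of NumPy callables that map a linear predictor to the mean of each observation model. The name DEFAULT_INVERSE_LINKS is hypothetical and this is not how nemos stores its defaults internally.

```python
import numpy as np

# Hypothetical mapping mirroring the table above (a sketch, not nemos internals).
DEFAULT_INVERSE_LINKS = {
    "Poisson": np.exp,
    "Gamma": lambda x: 1.0 / x,
    "Bernoulli": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "NegativeBinomial": np.exp,
    "Gaussian": lambda x: x,
}

# Illustrative linear predictor x^T w + b for three time bins.
eta = np.array([0.5, 1.0, 2.0])

# Map the linear predictor to the mean of each observation model.
means = {name: f(eta) for name, f in DEFAULT_INVERSE_LINKS.items()}
for name, mu in means.items():
    print(name, np.round(mu, 3))
```

The 1/x link only yields a valid (positive) Gamma mean when the linear predictor is positive, which is why the illustrative values are all positive.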

Below is a table listing the default and available solvers for each regularizer.

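=============  ================  ===========================================================
Regularizer    Default Solver    Available Solvers
=============  ================  ===========================================================
UnRegularized  LBFGS             GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Ridge          LBFGS             GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Lasso          ProximalGradient  ProximalGradient
GroupLasso     ProximalGradient  ProximalGradient
=============  ================  ===========================================================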
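Lasso and GroupLasso penalties are not differentiable at zero, which is why they are paired with ProximalGradient: each iteration takes a gradient step on the smooth loss and then applies the proximal operator of the penalty. A minimal NumPy sketch of the two proximal operators and one Lasso step; the helper names are ours, not nemos functions.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1 (element-wise soft-thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def group_soft_threshold(v, t):
    """Proximal operator of t * ||v||_2 for one group (block soft-thresholding)."""
    norm = np.linalg.norm(v)
    if norm <= t:
        return np.zeros_like(v)
    return (1.0 - t / norm) * v


def proximal_gradient_step(w, grad, stepsize, strength):
    """One Lasso proximal gradient (ISTA) step: gradient step, then shrinkage."""
    return soft_threshold(w - stepsize * grad, stepsize * strength)


w = np.array([3.0, -0.5, 0.2])
print(soft_threshold(w, 1.0))  # small entries are set exactly to zero
print(group_soft_threshold(np.array([3.0, 4.0]), 1.0))
```

The exact zeros produced by the shrinkage step are what make Lasso coefficients sparse; a plain gradient method would only push them close to zero.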

Fitting Large Models

For very large models, you may consider using the Stochastic Variance Reduced Gradient (SVRG) solver, nemos.solvers._svrg.SVRG, or its proximal variant, nemos.solvers._svrg.ProxSVRG; both take advantage of batched computation. You can change the solver by passing "SVRG" as solver_name at model initialization.

The performance of the SVRG solver depends critically on the choice of batch_size and stepsize hyperparameters. These parameters control the size of the mini-batches used for gradient computations and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations, we provide batch_size and stepsize default values that are theoretically guaranteed to ensure fast convergence.
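To make the role of the two hyperparameters concrete, here is a minimal NumPy sketch of the SVRG update on a noiseless least-squares problem (a Gaussian GLM with identity inverse link). It illustrates the algorithm only and is not the nemos solver; the step size 1 / (4 L_max), the batch size of 10, and the 60 epochs are our own choices for this toy problem, not the nemos defaults.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, batch_size = 200, 3, 10
X = rng.normal(size=(n_samples, n_features))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true  # noiseless Gaussian GLM with identity link (least squares)


def grad(w, idx):
    """Gradient of 0.5 * mean((X w - y)^2) over the samples in idx."""
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ w - yb) / len(idx)


# Step size from the largest per-sample smoothness constant L_max = max ||x_i||^2.
L_max = np.max(np.sum(X**2, axis=1))
stepsize = 1.0 / (4.0 * L_max)

w = np.zeros(n_features)
all_idx = np.arange(n_samples)
for epoch in range(60):
    w_snap = w.copy()
    full_grad = grad(w_snap, all_idx)  # full gradient at the snapshot
    for _ in range(n_samples // batch_size):
        idx = rng.choice(n_samples, size=batch_size, replace=False)
        # Variance-reduced estimate: unbiased, with variance that shrinks
        # as w approaches the snapshot w_snap.
        g = grad(w, idx) - grad(w_snap, idx) + full_grad
        w = w - stepsize * g

print(np.round(w, 4))
```

A step size that is too large makes the mini-batch correction term blow up, while one that is too small wastes the full-gradient snapshot; this is the trade-off the guaranteed defaults are meant to resolve.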

Below is a list of the configurations for which we can provide guaranteed default hyperparameters:

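===================================  ========  ==========
GLM / PopulationGLM Configuration    Stepsize  Batch Size
===================================  ========  ==========
Poisson + soft-plus + UnRegularized
Poisson + soft-plus + Ridge
Poisson + soft-plus + Lasso
Poisson + soft-plus + GroupLasso
===================================  ========  ==========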

Parameters:
  • observation_model (Union[BernoulliObservations, GammaObservations, GaussianObservations, NegativeBinomialObservations, PoissonObservations, Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian']]) – Observation model to use. The model describes the distribution of the neural activity. Default is the Poisson model. Alternatives are “Gamma”, “Bernoulli”, “NegativeBinomial” and “Gaussian”.

  • inverse_link_function (Optional[Callable]) – A function that maps the linear combination of predictors into a firing rate. The default depends on the observation model, see the table above.

  • regularizer (Union[str, Regularizer, None]) – Regularization to use for model optimization. Defines the regularization scheme and related parameters. Default is UnRegularized regression.

  • regularizer_strength (Any) – Regularizer strength, typically a float. Default is None. If no value is passed and the chosen regularizer needs one, a warning is raised and the strength defaults to 1.0. For finer control, pass a pytree that matches the parameter structure to regularize parameters differentially.

  • solver_name (str) – Solver to use for model optimization. Defines the optimization scheme and related parameters. The solver must be compatible with the chosen regularizer. Default is None, in which case a solver is chosen based on the regularizer. See the table above for regularizer/solver pairings.

  • solver_kwargs (dict) –

    Optional dictionary for keyword arguments that are passed to the solver when instantiated. E.g. stepsize, tol, acceleration, etc.

    For details on each solver’s kwargs, see get_accepted_arguments and get_solver_documentation.
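A per-parameter regularizer_strength simply weights each leaf's contribution to the penalty. The sketch below assumes a Ridge penalty of the form 0.5 · λ · ||w||² applied leaf by leaf, with a plain dict standing in for a general pytree; the 0.5 factor and the helper ridge_penalty are illustrative and not taken from nemos.

```python
import numpy as np

# Coefficients organized as a pytree (here a plain dict), as with dict inputs.
coef = {"input_1": np.array([1.0, 2.0]), "input_2": np.array([3.0])}

# A matching pytree of strengths: weak penalty on input_1, strong on input_2.
strength = {"input_1": 0.1, "input_2": 1.0}


def ridge_penalty(coef, strength):
    """0.5 * sum_k strength[k] * ||coef[k]||^2 (one common ridge convention)."""
    return 0.5 * sum(strength[k] * np.sum(coef[k] ** 2) for k in coef)


print(ridge_penalty(coef, strength))
```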

intercept_

Intercept (baseline) of the linear predictor. For example, if the link is the logarithm (exponential inverse link), the baseline firing rate is jnp.exp(model.intercept_).

coef_

Basis coefficients for the model.

solver_state_

State of the solver after fitting. May include details like optimization error.

scale_

Scale parameter for the model. The scale parameter is the constant \(\Phi\) for which \(\text{Var} \left( y \right) = \Phi V(\mu)\), where \(V(\mu)\) is the variance function of the observation model. Together with the estimate of the mean \(\mu\), this parameter fully specifies the distribution of the activity \(y\).

dof_resid_

Degrees of freedom for the residuals.
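One classical way to estimate \(\Phi\) divides the sum of squared Pearson residuals by the residual degrees of freedom. This page does not state whether scale_ and dof_resid_ are computed exactly this way in nemos, so treat the following as an illustrative sketch; the helper pearson_scale and the choice n_params=1 are hypothetical.

```python
import numpy as np


def pearson_scale(y, mu, variance_function, n_params):
    """Estimate Phi as the sum of squared Pearson residuals / dof_resid."""
    pearson_resid = (y - mu) / np.sqrt(variance_function(mu))
    dof_resid = len(y) - n_params
    return np.sum(pearson_resid**2) / dof_resid


y = np.array([1.0, 2.0, 3.0])
mu = np.array([1.0, 1.0, 1.0])

# Poisson variance function V(mu) = mu; a Gamma model would use V(mu) = mu**2.
phi_hat = pearson_scale(y, mu, lambda m: m, n_params=1)
print(phi_hat)
```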

Raises:

TypeError – If provided regularizer or observation_model are not valid.


Examples

Fit a GLM

Basic model fitting with default Poisson observation model:

>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.normal(size=(100, 5))
>>> y = np.random.poisson(size=100)
>>> model = nmo.glm.GLM().fit(X, y)
>>> model.coef_.shape
(5,)

Customize the Observation Model

Specify the observation model as a string:

>>> model = nmo.glm.GLM(observation_model="Gamma")
>>> model.observation_model
GammaObservations()

Or pass the observation model object directly:

>>> model = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations())
>>> model.observation_model
GammaObservations()

Customize the Inverse Link Function

Use a soft-plus inverse link function instead of the default exponential:

>>> import jax
>>> model = nmo.glm.GLM(inverse_link_function=jax.nn.softplus)
>>> model.inverse_link_function.__name__
'softplus'

Use Regularization

Fit with Ridge regularization:

>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> model.regularizer
Ridge()

Fit with Lasso regularization for sparse coefficients:

>>> model = nmo.glm.GLM(regularizer="Lasso", regularizer_strength=0.01)
>>> model = model.fit(X, y)
>>> model.regularizer
Lasso()

Select a Solver

Use the BFGS solver instead of the default LBFGS:

>>> model = nmo.glm.GLM(solver_name="BFGS").fit(X, y)
>>> model.solver_name
'BFGS'

Use a Pytree of arrays as Input

Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure:

>>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
>>> model = nmo.glm.GLM().fit(X_dict, y)
>>> # The coefficient structure will match the input.
>>> type(model.coef_)
<class 'dict'>
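Conceptually, a pytree input behaves like a design matrix split into blocks: each leaf contributes X_leaf @ coef_leaf to the linear predictor. A NumPy sketch of that equivalence, with dicts standing in for general pytrees and made-up coefficient values; it is not how nemos implements the computation internally.

```python
import numpy as np

rng = np.random.default_rng(123)
X = rng.normal(size=(100, 5))
X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}

# One coefficient array per leaf, with the same keys as the inputs.
coef_dict = {"input_1": np.array([0.1, -0.2]), "input_2": np.array([0.3, 0.0, 0.5])}
intercept = np.array([0.2])

# Linear predictor: sum of per-leaf contributions plus the intercept.
eta_dict = sum(X_dict[k] @ coef_dict[k] for k in X_dict) + intercept

# Same result with the concatenated design matrix and coefficients.
coef_flat = np.concatenate([coef_dict["input_1"], coef_dict["input_2"]])
eta_flat = X @ coef_flat + intercept
print(np.allclose(eta_dict, eta_flat))
```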

Attributes

inverse_link_function

Inverse link function mapping the linear predictor to the response space.

observation_model

The observation model governing the conditional distribution of y.

optimizer_init_state

Provides the initialization function for the optimizer state.

optimizer_run

Provides the function to execute the optimization process.

optimizer_update

Provides the function for updating the state during the optimization process.

regularizer

Getter for the regularizer attribute.

regularizer_strength

Regularizer strength getter.

solver

Getter for the solver class.

solver_kwargs

Getter for the solver_kwargs attribute.

solver_name

Getter for the solver_name attribute.

solver_spec

Getter for the solver specification.

__init__(observation_model='Poisson', inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)

Methods

__init__([observation_model, ...])

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

fit(X, y[, init_params])

Fit GLM to neural activity.

get_params([deep])

From scikit-learn, get parameters by inspecting init.

initialize_optimizer_and_state(init_params, X, y)

Initialize the optimization routine and its state for running fit and update.

initialize_params(X, y)

Initialize model parameters.

predict(X)

Predict rates based on fit parameters.

save_params(filename)

Save GLM model parameters to a .npz file.

score(X, y[, score_type, ...])

Evaluate the goodness-of-fit of the model to the observed neural data.

set_params(**params)

Manage warnings in case of multiple parameter settings.

simulate(random_key, feedforward_input)

Simulate neural activity in response to a feed-forward input.

update(params, opt_state, X, y, *args[, ...])

Update the model parameters and solver state.

classmethod __init_subclass__(**kwargs)

Set the set_{method}_request methods.

This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the default values, which are set using __metadata_request__* class attributes or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or when the developer would like to specify a request value for that metadata that differs from the default None.

References

[1] PEP 487: Simpler customisation of class creation.

__repr__()

Representation of the GLM class.

__sklearn_clone__()

Clone the GLM.

Return type:

GLM

__sklearn_tags__()

Return GLM specific estimator tags.

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

This method validates inputs and converts user-provided parameters to the internal representation before computing the loss.

Parameters:
  • params (UserProvidedParamsT) – Parameter tuple of (coefficients, intercept).

  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

  • *args – Additional positional arguments passed to the model-specific loss function.

  • **kwargs – Additional keyword arguments passed to the model-specific loss function.

Returns:

The loss value (negative log-likelihood).

Return type:

jnp.ndarray

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.
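For a Poisson model with exponential inverse link, the loss can be sketched as the mean negative log-likelihood below. This is an illustration under those assumptions: whether nemos averages or sums over samples, and whether it keeps the constant log(y!) term, is not specified on this page, and poisson_nll is a hypothetical helper.

```python
import math

import numpy as np


def poisson_nll(params, X, y):
    """Mean Poisson negative log-likelihood, assuming an exponential inverse link."""
    coef, intercept = params
    rate = np.exp(X @ coef + intercept)
    log_y_factorial = np.array([math.lgamma(v + 1.0) for v in y])  # log(y!)
    return np.mean(rate - y * np.log(rate) + log_y_factorial)


X = np.zeros((2, 3))
y = np.array([0.0, 1.0])
print(poisson_nll((np.zeros(3), 0.0), X, y))
```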

fit(X, y, init_params=None)

Fit GLM to neural activity.

Fit and store the model parameters as attributes coef_ and intercept_.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or pytree of the same shape.

  • y (ArrayLike) – Target neural activity, array of shape (n_time_bins,).

  • init_params (Optional[GLMUserParams]) – 2-tuple of initial parameter values: (coefficients, intercepts). If None, coefficients are initialized with zeros and intercepts by matching the mean neural activity (for the default exponential inverse link, this is the log of the mean). coefficients is an array of shape (n_features,) or pytree of same; intercepts is an array of shape (1,).

Raises:
  • ValueError – If init_params is not of length two.

  • ValueError – If dimensionality of init_params are not correct.

  • ValueError – If X is not two-dimensional.

  • ValueError – If y is not one-dimensional.

  • ValueError – If solver returns at least one NaN parameter, which means it found an invalid solution. Try tuning optimization hyperparameters.

  • TypeError – If init_params are not array-like.

  • TypeError – If any init_params[i] cannot be converted to jnp.ndarray.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> # get model weights and intercept
>>> model_weights = model.coef_
>>> model_intercept = model.intercept_
get_metadata_routing()

Get metadata routing of this object.

Please check the User Guide on how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

From scikit-learn, get parameters by inspecting init.

Parameters:

deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Return type:

dict

Returns:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

initialize_optimizer_and_state(init_params, X, y)

Initialize the optimization routine and its state for running fit and update.

This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.

Parameters:
  • init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).

  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

Returns:

Initial solver state.

Return type:

SolverState

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

initialize_params(X, y)

Initialize model parameters.

Initialize coefficients with zeros and intercept by matching the mean firing rate.

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

Returns:

Initial parameter tuple of (coefficients, intercept).

Return type:

UserProvidedParamsT
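For the exponential inverse link, matching the mean firing rate amounts to setting the intercept to the log of the mean activity, so that the initial prediction equals mean(y) in every time bin. A NumPy sketch with a hypothetical helper name; for other links, the log would be replaced by the inverse of the chosen inverse link function.

```python
import numpy as np


def initialize_params_sketch(X, y):
    """Zero coefficients; intercept chosen so that exp(intercept) == mean(y).

    Assumes the exponential inverse link and mean(y) > 0 (log(0) is -inf).
    """
    coef = np.zeros(X.shape[1])
    intercept = np.atleast_1d(np.log(np.mean(y)))
    return coef, intercept


X = np.random.default_rng(0).normal(size=(4, 2))
y = np.array([0.0, 2.0, 4.0, 2.0])
coef, intercept = initialize_params_sketch(X, y)

# With zero coefficients, the initial rate is the same in every bin.
initial_rate = np.exp(X @ coef + intercept)
print(initial_rate)
```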

property inverse_link_function

Inverse link function mapping the linear predictor to the response space.

Always a callable. If None was passed at construction time, this is resolved to the observation model’s default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli).

property observation_model: None | Observations

The observation model governing the conditional distribution of y.

Always an instance of an Observations subclass. If a string alias was passed at construction time it is resolved to the corresponding instance here.

property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]

Provides the initialization function for the optimizer state.

This function is responsible for initializing the optimizer state, which is needed to start the optimization process. It sets up initial values for quantities like gradients and step sizes based on the model configuration and input data.

Returns:

The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]

Provides the function to execute the optimization process.

This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.

Returns:

The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]

Provides the function for updating the state during the optimization process.

This function performs a single update step of the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used when fine-grained control over each optimizer step is needed, such as in online learning or other complex optimization settings.

Returns:

The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.

predict(X)

Predict rates based on fit parameters.

Parameters:

X (DESIGN_INPUT_TYPE) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.

Return type:

jnp.ndarray

Returns:

The predicted rates with shape (n_time_bins, ).

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If params is not a JAX pytree of size two.

  • ValueError – If weights and bias terms in params don’t have the expected dimensions.

  • ValueError – If X is not two-dimensional.

  • ValueError – If there’s an inconsistent number of features between spike basis coefficients and X.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # predict firing rates for new inputs
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_rates = model.predict(Xnew)

See also

nemos.glm.GLM.score()

Score predicted rates against target spike counts.

nemos.glm.GLM.simulate()

Simulate neural activity in response to a feed-forward input (feed-forward only).

nemos.simulation.simulate_recurrent()

Simulate neural activity in response to a feed-forward input using the GLM as a recurrent network (feed-forward + coupling).

property regularizer: None | Regularizer

Getter for the regularizer attribute.

property regularizer_strength: Any

Regularizer strength getter.

save_params(filename)

Save GLM model parameters to a .npz file.

This method allows the model parameters to be reused: the saved parameters can be loaded back into a GLM instance using nmo.load_model, as shown in the examples below.

Parameters:

filename (Union[str, Path]) – The name of the file where the model parameters will be saved. The file will be saved in .npz format.

Examples

>>> import nemos as nmo
>>> # Create a GLM model with specified parameters
>>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
>>> model = nmo.glm.GLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1,
...     observation_model="Gamma",
...     solver_name="BFGS",
...     solver_kwargs=solver_args,
... )
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1...
solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
solver_name: BFGS
>>> # Save the model parameters to a file
>>> model.save_params("model_params.npz")
>>> # Load the model from the saved file
>>> model = nmo.load_model("model_params.npz")
>>> # Model has the same parameters before and after load
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1
solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
solver_name: BFGS
>>> # Saving and loading a custom inverse link function
>>> model = nmo.glm.GLM(
...     observation_model="Poisson",
...     inverse_link_function=lambda x: x**2
... )
>>> model.save_params("model_params.npz")
>>> # Provide a mapping for the custom link function when loading.
>>> mapping_dict = {
...     "inverse_link_function": lambda x: x**2,
... }
>>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
>>> # The loaded model uses the custom inverse link function from mapping_dict
>>> for key, value in loaded_model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function <lambda> at ...>
observation_model: PoissonObservations()
regularizer: UnRegularized()
regularizer_strength: None
solver_kwargs: {}
solver_name: LBFGS
score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)

Evaluate the goodness-of-fit of the model to the observed neural data.

This method computes the goodness-of-fit score, which can be either the mean log-likelihood or one of two versions of the pseudo-\(R^2\). Before scoring, it checks that the model has been fitted and that the input data are compatible with the model’s parameters. A higher score indicates a better fit of the model to the observed data.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.

  • y (ArrayLike) – Neural activity. Shape (n_time_bins, ).

  • score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – Type of scoring: either log-likelihood or pseudo-\(R^2\).

  • aggregate_sample_scores (Callable) – Function that aggregates the per-sample scores. Default is the mean.

Returns:

The log-likelihood or the pseudo-\(R^2\) of the current model.

Return type:

jnp.ndarray

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If the structure of X doesn’t match the parameters, or if X and y have a different number of samples.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')

Notes

The log-likelihood is not on a standard scale: its value depends on many factors, including the number of model parameters. The log-likelihood can take both positive and negative values.

The pseudo-\(R^2\) is not equivalent to the \(R^2\) value in linear regression. While both measure model fit and typically take values in the [0, 1] range, their computation and interpretation differ. The pseudo-\(R^2\) is particularly useful for generalized linear models when the interpretation of \(R^2\) as explained variance does not apply (i.e., when the observations are not Gaussian distributed).

Why is the traditional \(R^2\) usually a poor measure of performance in GLMs?

  1. In GLMs, the variance and the mean of the observations are related. Ignoring this relation can lead to underestimating model performance. For instance, when we model a Poisson variable with a large mean, we expect an equally large variance. In this scenario, even if the model perfectly captures the mean, the high variance will produce large residuals and a low \(R^2\). Additionally, when the mean of the observations varies, the variance varies too. This violates the homoscedasticity assumption needed to interpret \(R^2\) as variance explained.

  2. \(R^2\) captures the variance explained when the relationship between the observations and the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors and the mean of the observations, which compromises the interpretation of \(R^2\).

Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional to the model standard deviation (i.e., Pearson residuals). However, this rescaled residual distribution deviates substantially from normality for count data with a low mean, which is common for spike counts. Therefore, even Pearson residuals perform poorly as a measure of fit quality, especially for GLMs of count data.

Refer to the concrete subclasses of nmo.observation_models.Observations for the likelihood and pseudo-\(R^2\) equations.
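To illustrate the scale issue, here is the classic McFadden definition, \(1 - \log L_{\text{model}} / \log L_{\text{null}}\), where the null model predicts the mean count in every bin. This sketch assumes Poisson observations and uses hypothetical helper names; the formulas nemos actually uses are the ones in the observation model subclasses and may differ in detail.

```python
import math

import numpy as np


def poisson_log_likelihood(y, mu):
    """Total Poisson log-likelihood of counts y given means mu."""
    log_y_factorial = np.array([math.lgamma(v + 1.0) for v in y])
    return np.sum(y * np.log(mu) - mu - log_y_factorial)


def pseudo_r2_mcfadden(y, mu):
    """Classic McFadden pseudo-R^2: 1 - LL(model) / LL(null)."""
    mu_null = np.full_like(mu, np.mean(y))  # intercept-only (mean-rate) model
    return 1.0 - poisson_log_likelihood(y, mu) / poisson_log_likelihood(y, mu_null)


y = np.array([0.0, 0.0, 5.0, 5.0])
mu_model = np.array([0.1, 0.1, 5.0, 5.0])
print(pseudo_r2_mcfadden(y, mu_model))
```

Because both log-likelihoods enter as a ratio, the raw scale of the log-likelihood cancels out, and a model no better than the mean rate scores zero.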

set_fit_request(*, init_params='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for init_params parameter in fit.

  • self (GLM)

Returns:

self – The updated object.

Return type:

object

set_params(**params)

Manage warnings in case of multiple parameter settings.

Parameters:

params (Any)

set_score_request(*, aggregate_sample_scores='$UNCHANGED$', score_type='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • aggregate_sample_scores (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for aggregate_sample_scores parameter in score.

  • score_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for score_type parameter in score.

  • self (GLM)

Returns:

self – The updated object.

Return type:

object

simulate(random_key, feedforward_input)

Simulate neural activity in response to a feed-forward input.

Parameters:
  • random_key (jax.Array) – jax.random.key for seeding the simulation.

  • feedforward_input (DESIGN_INPUT_TYPE) – External input predictors to the model, representing factors like convolved currents, light intensities, etc. Array of shape (n_time_bins, n_basis_input) or pytree with leaves of the same shape.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

Returns:

  • simulated_activity – Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins,).

  • firing_rates – Simulated rates for the neuron over time. Shape: (n_time_bins,).

Raises:
  • NotFittedError – If the model hasn’t been fitted prior to calling this method.

  • ValueError – If the instance has not been previously fitted.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> import jax
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> spikes, rates = model.simulate(random_key, Xnew)
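For a Poisson GLM with exponential inverse link, feed-forward simulation amounts to computing the rates (what predict returns) and drawing counts from them. A NumPy sketch under those assumptions: nemos itself takes a JAX random key rather than a NumPy Generator, and the coefficient values here are made up.

```python
import numpy as np

rng = np.random.default_rng(123)

# Hypothetical fitted parameters for two predictors.
coef = np.array([0.5, -0.3])
intercept = np.array([0.1])

X_new = rng.normal(size=(20, 2))

# Feed-forward simulation: rates from the inverse link, then Poisson draws.
rates = np.exp(X_new @ coef + intercept)
spikes = rng.poisson(rates)
print(spikes.shape, rates.shape)
```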

See also

nemos.glm.GLM.predict()

Method to predict rates based on the model’s parameters.

property solver

Getter for the solver class.

property solver_kwargs

Getter for the solver_kwargs attribute.

property solver_name: str

Getter for the solver_name attribute.

property solver_spec: SolverSpec

Getter for the solver specification.

update(params, opt_state, X, y, *args, n_samples=None, **kwargs)

Update the model parameters and solver state.

This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. This method is particularly useful for iterative model fitting, especially in scenarios where model parameters need to be updated incrementally, such as online learning or when dealing with very large datasets that do not fit into memory at once.

Parameters:
  • params (GLMUserParams) – The current model parameters, typically a tuple of coefficients and intercepts.

  • opt_state (SolverState) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.

  • X (DESIGN_INPUT_TYPE) – The predictors used in the model fitting process, which may include feature matrices or a pytree of arrays. Shape (n_time_bins, n_features).

  • y (jnp.ndarray) – The response variable or output data corresponding to the predictors. Shape (n_time_bins,).

  • *args – Additional positional arguments to be passed to the solver’s update method.

  • n_samples (Optional[int]) – The total number of samples, typically larger than the number of samples in an individual batch. It is used to estimate the scale parameter of the GLM.

  • **kwargs – Additional keyword arguments to be passed to the solver’s update method.

Return type:

StepResult

Returns:

  • params – Updated model parameters (coefficients, intercepts).

  • state – Updated optimizer state.

Raises:

ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.

Examples

>>> import nemos as nmo
>>> import numpy as np
>>> import jax
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> glm_instance = nmo.glm.GLM()
>>> params = glm_instance.initialize_params(X, y)
>>> opt_state = glm_instance.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
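The initialize-then-update pattern can be sketched in plain NumPy for a Poisson GLM: initialize the parameters, then repeatedly apply a single gradient-descent update, roughly what a run-until-convergence loop does. This is an illustrative stand-in with our own function names, a fixed step size, and a fixed number of iterations, not the solver nemos uses by default.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
w_true, b_true = np.array([0.4, -0.3]), 0.5
y = rng.poisson(np.exp(X @ w_true + b_true)).astype(float)


def loss(params):
    """Mean Poisson NLL with exponential inverse link, dropping log(y!)."""
    coef, intercept = params
    rate = np.exp(X @ coef + intercept)
    return np.mean(rate - y * np.log(rate))


def update(params, stepsize=0.1):
    """One gradient-descent step on the mean Poisson NLL."""
    coef, intercept = params
    resid = np.exp(X @ coef + intercept) - y  # d loss / d eta, per sample
    grad_coef = X.T @ resid / len(y)
    grad_intercept = np.mean(resid)
    return coef - stepsize * grad_coef, intercept - stepsize * grad_intercept


# Initialization mirrors the zero-coefficient / mean-rate recipe.
params = (np.zeros(2), np.log(np.mean(y)))
losses = [loss(params)]
for _ in range(200):
    params = update(params)
    losses.append(loss(params))
print(losses[0], losses[-1])
```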