nemos.glm.GLM
- class nemos.glm.GLM(observation_model='Poisson', inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)
Bases: BaseRegressor
Generalized Linear Model (GLM) for neural activity data.
This GLM implementation allows users to model neural activity based on a combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where the relationship between predictors and the response variable might be non-linear, and the residuals don’t follow a normal distribution.
Below is a table of the default inverse link function for each available observation model.

| Observation Model | Default Inverse Link Function |
|---|---|
| Poisson | \(e^x\) |
| Gamma | \(1/x\) |
| Bernoulli | \(1 / (1 + e^{-x})\) |
| NegativeBinomial | \(e^x\) |
| Gaussian | \(x\) |
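As a rough illustration of the table above, the following NumPy sketch maps each observation model name to its default inverse link and applies it to a linear predictor. The DEFAULT_INVERSE_LINK dictionary and the mean_from_linear_predictor helper are hypothetical names for this example; nemos resolves these defaults internally and works with JAX arrays.

```python
import numpy as np

# Hypothetical lookup mirroring the table above (not a nemos object):
# observation model name -> default inverse link function.
DEFAULT_INVERSE_LINK = {
    "Poisson": np.exp,                                # e^x
    "Gamma": lambda x: 1.0 / x,                       # 1/x
    "Bernoulli": lambda x: 1.0 / (1.0 + np.exp(-x)),  # logistic sigmoid
    "NegativeBinomial": np.exp,                       # e^x
    "Gaussian": lambda x: x,                          # identity
}


def mean_from_linear_predictor(eta, observation_model="Poisson"):
    """Map the linear predictor eta to the mean of y with the default inverse link."""
    inverse_link = DEFAULT_INVERSE_LINK[observation_model]
    return inverse_link(np.asarray(eta, dtype=float))


eta = np.array([-1.0, 0.0, 2.0])
print(mean_from_linear_predictor(eta, "Poisson"))    # strictly positive rates
print(mean_from_linear_predictor(eta, "Bernoulli"))  # probabilities in (0, 1)
```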
Below is a table listing the default and available solvers for each regularizer.

| Regularizer | Default Solver | Available Solvers |
|---|---|---|
| UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient |
| Lasso | ProximalGradient | ProximalGradient |
| GroupLasso | ProximalGradient | ProximalGradient |
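The pairing rules in the solver table can be expressed as a small lookup. This is a sketch under the assumption that only the regularizers and solvers listed above matter; SOLVER_TABLE and resolve_solver are hypothetical and are not part of nemos, which performs its own validation (and also accepts solvers such as SVRG that the table does not list).

```python
# Hypothetical helper encoding the regularizer/solver table above; it is not
# part of the nemos API.
ALL_LISTED_SOLVERS = frozenset(
    {"GradientDescent", "BFGS", "LBFGS", "NonlinearCG", "ProximalGradient"}
)
SOLVER_TABLE = {
    # regularizer: (default solver, available solvers)
    "UnRegularized": ("LBFGS", ALL_LISTED_SOLVERS),
    "Ridge": ("LBFGS", ALL_LISTED_SOLVERS),
    "Lasso": ("ProximalGradient", frozenset({"ProximalGradient"})),
    "GroupLasso": ("ProximalGradient", frozenset({"ProximalGradient"})),
}


def resolve_solver(regularizer, solver_name=None):
    """Return the solver to use, falling back to the table's default."""
    default, available = SOLVER_TABLE[regularizer]
    if solver_name is None:
        return default
    if solver_name not in available:
        raise ValueError(
            f"{solver_name!r} is not available for {regularizer}; "
            f"choose one of {sorted(available)}."
        )
    return solver_name


print(resolve_solver("Lasso"))          # ProximalGradient
print(resolve_solver("Ridge", "BFGS"))  # BFGS
```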
Fitting Large Models
For very large models, you may consider using the Stochastic Variance Reduced Gradient solver nemos.solvers._svrg.SVRG or its proximal variant nemos.solvers._svrg.ProxSVRG, which take advantage of batched computation. You can change the solver by passing "SVRG" as solver_name at model initialization.
The performance of the SVRG solver depends critically on the choice of the batch_size and stepsize hyperparameters. These parameters control the size of the mini-batches used for gradient computations and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow convergence or even divergence of the optimization process.
To assist with this, for certain GLM configurations, we provide batch_size and stepsize default values that are theoretically guaranteed to ensure fast convergence.
Below is a list of the configurations for which we can provide guaranteed default hyperparameters:
| GLM / PopulationGLM Configuration | Stepsize | Batch Size |
|---|---|---|
| Poisson + soft-plus + UnRegularized | ✅ | ❌ |
| Poisson + soft-plus + Ridge | ✅ | ✅ |
| Poisson + soft-plus + Lasso | ✅ | ❌ |
| Poisson + soft-plus + GroupLasso | ✅ | ❌ |
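To make the role of batch_size and stepsize concrete, here is a minimal NumPy sketch of mini-batch SVRG on a least-squares objective. It is not the nemos SVRG implementation, and the step size rule 1 / (4 * L_max) is an assumption chosen for this example, not one of the guaranteed defaults in the table. Each outer epoch stores a snapshot and its full gradient; each inner step corrects a mini-batch gradient using that snapshot.

```python
import numpy as np


def svrg_least_squares(X, y, stepsize, batch_size, n_epochs=50, seed=0):
    """Minimize f(w) = ||X w - y||^2 / (2 n) with mini-batch SVRG (sketch)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape

    def batch_grad(w, idx):
        # Average gradient of the per-sample squared losses over rows idx.
        Xb = X[idx]
        return Xb.T @ (Xb @ w - y[idx]) / len(idx)

    w = np.zeros(d)
    all_rows = np.arange(n)
    for _ in range(n_epochs):
        w_snap = w.copy()                         # snapshot ("anchor") point
        full_grad = batch_grad(w_snap, all_rows)  # exact gradient at the snapshot
        for _ in range(n // batch_size):
            idx = rng.choice(n, size=batch_size, replace=False)
            # Variance-reduced estimate: unbiased for the full gradient at w,
            # and its variance shrinks as w and w_snap approach the optimum.
            g = batch_grad(w, idx) - batch_grad(w_snap, idx) + full_grad
            w = w - stepsize * g
    return w


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=200)

# Assumed step size rule for this sketch: 1 / (4 * L_max), where L_max is the
# largest per-sample smoothness constant ||x_i||^2 of the squared loss.
L_max = np.max(np.sum(X**2, axis=1))
w_svrg = svrg_least_squares(X, y, stepsize=1.0 / (4.0 * L_max), batch_size=10)
w_exact = np.linalg.lstsq(X, y, rcond=None)[0]
print(w_svrg, w_exact)
```

Because the corrected gradient is unbiased and its variance vanishes near the optimum, a constant step size can reach the exact minimizer, which plain mini-batch gradient descent with a constant step generally cannot.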
- Parameters:
observation_model (Union[BernoulliObservations, GammaObservations, GaussianObservations, NegativeBinomialObservations, PoissonObservations, Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian']]) – Observation model to use. The model describes the distribution of the neural activity. Default is the Poisson model. Alternatives are “Gamma”, “Bernoulli”, “NegativeBinomial” and “Gaussian”.
inverse_link_function (Optional[Callable]) – A function that maps the linear combination of predictors into a firing rate. The default depends on the observation model; see the table above.
regularizer (Union[str, Regularizer, None]) – Regularization to use for model optimization. Defines the regularization scheme and related parameters. Default is UnRegularized regression.
regularizer_strength (Any) – Typically a float. Default is None. Sets the regularizer strength. If a user does not pass a value, and it is needed for regularization, a warning will be raised and the strength will default to 1.0. For finer control, the user can pass a pytree that matches the parameter structure to regularize parameters differentially.
solver_name (str) – Solver to use for model optimization. Defines the optimization scheme and related parameters. The solver must be an appropriate match for the chosen regularizer. Default is None. If no solver is specified, one will be chosen based on the regularizer. Please see the table above for regularizer/optimizer pairings.
solver_kwargs (dict) – Optional dictionary of keyword arguments that are passed to the solver when instantiated, e.g. stepsize, tol, acceleration, etc. For details on each solver’s kwargs, see get_accepted_arguments and get_solver_documentation.
- intercept_
Model baseline linked firing rate parameters; e.g. if the link is the logarithm, the baseline firing rate will be jnp.exp(model.intercept_).
- coef_
Basis coefficients for the model.
- solver_state_
State of the solver after fitting. May include details like optimization error.
- scale_
Scale parameter for the model. The scale parameter is the constant \(\Phi\), for which \(\text{Var} \left( y \right) = \Phi V(\mu)\). This parameter, together with the estimate of the mean \(\mu\), fully specifies the distribution of the activity \(y\).
- dof_resid_
Degrees of freedom for the residuals.
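For intuition about scale_ and dof_resid_, the sketch below computes the classical Pearson (moment) estimate \(\hat{\Phi} = \frac{1}{\text{dof}} \sum_t (y_t - \mu_t)^2 / V(\mu_t)\). The variance functions are the standard ones for these families; whether nemos uses exactly this estimator, and the convention dof_resid = n_samples minus the number of fitted parameters, are assumptions here. The negative binomial variance function is omitted.

```python
import numpy as np

# Standard GLM variance functions V(mu), with Var(y) = Phi * V(mu).
VARIANCE_FUNCTIONS = {
    "Gaussian": lambda mu: np.ones_like(mu),
    "Poisson": lambda mu: mu,
    "Gamma": lambda mu: mu**2,
}


def pearson_scale(y, mu, observation_model, dof_resid):
    """Moment (Pearson chi-square) estimate of the scale Phi (sketch)."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    variance = VARIANCE_FUNCTIONS[observation_model](mu)
    return np.sum((y - mu) ** 2 / variance) / dof_resid


# Gaussian case: Phi is the residual variance, RSS / dof_resid.
y = np.array([1.0, 2.0, 3.0, 4.0])
mu = np.array([1.5, 1.5, 3.5, 3.5])
# Assumed convention: dof_resid = n_samples - n_fitted_params (here 4 - 2).
print(pearson_scale(y, mu, "Gaussian", dof_resid=2))
```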
- Raises:
TypeError – If the provided regularizer or observation_model is not valid.
- Parameters:
observation_model (REGRESSION_GLM_TYPES | Literal['Poisson', 'Gamma', 'Gaussian', 'Bernoulli', 'NegativeBinomial'])
inverse_link_function (Optional[Callable])
regularizer (Optional[Union[str, Regularizer]])
regularizer_strength (Any)
solver_name (str)
solver_kwargs (dict)
Examples
Fit a GLM
Basic model fitting with default Poisson observation model:
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> X = np.random.normal(size=(100, 5))
>>> y = np.random.poisson(size=100)
>>> model = nmo.glm.GLM().fit(X, y)
>>> model.coef_.shape
(5,)
Customize the Observation Model
Specify the observation model as a string:
>>> model = nmo.glm.GLM(observation_model="Gamma")
>>> model.observation_model
GammaObservations()
Or pass the observation model object directly:
>>> model = nmo.glm.GLM(observation_model=nmo.observation_models.GammaObservations())
>>> model.observation_model
GammaObservations()
Customize the Inverse Link Function
Use a soft-plus inverse link function instead of the default exponential:
>>> import jax
>>> model = nmo.glm.GLM(inverse_link_function=jax.nn.softplus)
>>> model.inverse_link_function.__name__
'softplus'
Use Regularization
Fit with Ridge regularization:
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> model.regularizer
Ridge()
Fit with Lasso regularization for sparse coefficients:
>>> model = nmo.glm.GLM(regularizer="Lasso", regularizer_strength=0.01)
>>> model = model.fit(X, y)
>>> model.regularizer
Lasso()
Select a Solver
Use the BFGS solver instead of the default LBFGS:
>>> model = nmo.glm.GLM(solver_name="BFGS").fit(X, y)
>>> model.solver_name
'BFGS'
Use a Pytree of arrays as Input
Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure:
>>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
>>> model = nmo.glm.GLM().fit(X_dict, y)
>>> # The coefficient structure will match the input.
>>> type(model.coef_)
<class 'dict'>
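One way to see why the fitted coef_ mirrors the input structure: with dict-structured features, the linear predictor is the sum of per-leaf products, which equals the product with the concatenated design matrix. This NumPy sketch assumes dict inputs only; linear_predictor is a hypothetical helper, not how nemos traverses general JAX pytrees.

```python
import numpy as np


def linear_predictor(X_tree, coef_tree, intercept):
    """eta = sum_k X_tree[k] @ coef_tree[k] + intercept for dict-structured inputs."""
    if X_tree.keys() != coef_tree.keys():
        raise ValueError("Feature and coefficient structures must match.")
    return sum(X_tree[k] @ coef_tree[k] for k in X_tree) + intercept


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
coef = rng.normal(size=5)

# Split the same features and weights into matching dicts.
X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]}
coef_dict = {"input_1": coef[:2], "input_2": coef[2:]}

eta_tree = linear_predictor(X_dict, coef_dict, intercept=0.3)
eta_flat = X @ coef + 0.3  # same result with the concatenated design matrix
```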
Attributes
inverse_link_function – Inverse link function mapping the linear predictor to the response space.
observation_model – The observation model governing the conditional distribution of y.
optimizer_init_state – Provides the initialization function for the optimizer state.
optimizer_run – Provides the function to execute the optimization process.
optimizer_update – Provides the function for updating the state during the optimization process.
regularizer – Getter for the regularizer attribute.
regularizer_strength – Regularizer strength getter.
solver – Getter for the solver class.
solver_kwargs – Getter for the solver_kwargs attribute.
solver_name – Getter for the solver_name attribute.
solver_spec – Getter for the solver specification.
- __init__(observation_model='Poisson', inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)
- Parameters:
observation_model (BernoulliObservations | GammaObservations | GaussianObservations | NegativeBinomialObservations | PoissonObservations | Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian'])
inverse_link_function (Callable | None)
regularizer (str | Regularizer | None)
regularizer_strength (Any)
solver_name (str)
solver_kwargs (dict)
Methods
__init__([observation_model, ...])
compute_loss(params, X, y, *args, **kwargs) – Compute the loss function for the model.
fit(X, y[, init_params]) – Fit GLM to neural activity.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
initialize_optimizer_and_state(init_params, X, y) – Initialize the optimization routine and its state for running fit and update.
initialize_params(X, y) – Initialize model parameters.
predict(X) – Predict rates based on fit parameters.
save_params(filename) – Save GLM model parameters to a .npz file.
score(X, y[, score_type, ...]) – Evaluate the goodness-of-fit of the model to the observed neural data.
set_params(**params) – Manage warnings in case of multiple parameter settings.
simulate(random_key, feedforward_input) – Simulate neural activity in response to a feed-forward input.
update(params, opt_state, X, y, *args[, ...]) – Update the model parameters and solver state.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values which are set using __metadata_request__* class attributes, or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or if the developer would like to specify a request value for those metadata which is different from the default None.
References
[1] PEP 487
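The PEP 487 hook mentioned above runs once for every subclass at class-creation time. The toy example below uses __init_subclass__ to attach set_<method>_request methods for the names listed in a hypothetical _routed_methods attribute. It only illustrates the mechanism; scikit-learn's real implementation reads __metadata_request__* attributes and method signatures and builds full request objects.

```python
class RoutingBase:
    """Toy base class: subclasses listing `_routed_methods` get set_<name>_request."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for method in getattr(cls, "_routed_methods", ()):
            # Bind `method` as a default argument so each setter keeps its own name.
            def setter(self, _method=method, **requests):
                current = getattr(self, "_requests", {})
                self._requests = {**current, _method: dict(requests)}
                return self

            setter.__name__ = f"set_{method}_request"
            setattr(cls, setter.__name__, setter)


class ToyEstimator(RoutingBase):
    _routed_methods = ("fit", "score")


est = ToyEstimator().set_fit_request(init_params=True)
print(est._requests)  # {'fit': {'init_params': True}}
```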
- compute_loss(params, X, y, *args, **kwargs)
Compute the loss function for the model.
This method validates inputs and converts user-provided parameters to the internal representation before computing the loss.
- Parameters:
params (UserProvidedParamsT) – Parameter tuple of (coefficients, intercept).
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
*args – Additional positional arguments passed to the model-specific loss function.
**kwargs – Additional keyword arguments passed to the model-specific loss function.
- Returns:
The loss value (negative log-likelihood).
- Return type:
jnp.ndarray
- Raises:
ValueError – If inputs or parameters have incompatible shapes or invalid values.
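As an illustration of the returned value, the sketch below evaluates a mean Poisson negative log-likelihood for a (coefficients, intercept) tuple with the default exponential inverse link. The normalization (a mean over samples, keeping the log(y!) term) is an assumption; nemos may aggregate differently or drop terms that do not depend on the parameters.

```python
import math

import numpy as np


def poisson_nll(params, X, y):
    """Mean Poisson negative log-likelihood with an exponential inverse link (sketch)."""
    coef, intercept = params
    eta = X @ coef + intercept  # linear predictor
    rate = np.exp(eta)          # inverse link
    # -log p(y | rate) = rate - y * log(rate) + log(y!), and log(rate) == eta.
    log_y_factorial = np.array([math.lgamma(k + 1.0) for k in y])
    return np.mean(rate - y * eta + log_y_factorial)


X = np.zeros((3, 2))
y = np.array([0, 1, 2])
params = (np.zeros(2), np.array([0.0]))
print(poisson_nll(params, X, y))  # rate = 1 in every time bin
```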
- fit(X, y, init_params=None)
Fit GLM to neural activity.
Fit and store the model parameters as attributes coef_ and intercept_.
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or pytree of the same shape.
y (ArrayLike) – Target neural activity, array of shape (n_time_bins, ).
init_params (Optional[GLMUserParams]) – 2-tuple of initial parameter values: (coefficients, intercepts). If None, coefficients are initialized with zeros and intercepts with the log of the mean neural activity. coefficients is an array of shape (n_features,) or pytree of same; intercepts is an array of shape (1, ).
- Raises:
ValueError – If init_params is not of length two.
ValueError – If the dimensionality of init_params is not correct.
ValueError – If X is not two-dimensional.
ValueError – If y is not one-dimensional.
ValueError – If the solver returns at least one NaN parameter, which means it found an invalid solution. Try tuning optimization hyperparameters.
TypeError – If init_params is not array-like.
TypeError – If any init_params[i] cannot be converted to jnp.ndarray.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # fit a ridge regression Poisson GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
>>> model = model.fit(X, y)
>>> # get model weights and intercept
>>> model_weights = model.coef_
>>> model_intercept = model.intercept_
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Return type:
dict
- Returns:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
- initialize_optimizer_and_state(init_params, X, y)
Initialize the optimization routine and its state for running fit and update.
This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.
- Parameters:
init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
- Returns:
Initial solver state.
- Return type:
SolverState
- Raises:
ValueError – If inputs or parameters have incompatible shapes or invalid values.
- initialize_params(X, y)
Initialize model parameters.
Initialize coefficients with zeros and intercept by matching the mean firing rate.
- Parameters:
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.
- Returns:
Initial parameter tuple of (coefficients, intercept).
- Return type:
UserProvidedParamsT
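The initialization rule above can be sketched in a few lines: coefficients start at zero and the intercept is the link function applied to the mean of y, so the initial predicted rate equals the mean firing rate. initialize_params_sketch is a hypothetical name, and np.log is assumed to be the link for the default exponential inverse link.

```python
import numpy as np


def initialize_params_sketch(X, y, link=np.log):
    """Zero coefficients; intercept set so that inverse_link(intercept) == mean(y).

    `link` is the inverse of the inverse link function (np.log for the
    default exponential inverse link of the Poisson model).
    """
    coef = np.zeros(X.shape[1])
    intercept = np.atleast_1d(link(np.mean(y)))
    return coef, intercept


X = np.random.default_rng(0).normal(size=(4, 3))
y = np.array([0, 1, 2, 3])
coef, intercept = initialize_params_sketch(X, y)
print(coef, np.exp(intercept))  # exp(intercept) matches mean(y) up to rounding
```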
- property inverse_link_function
Inverse link function mapping the linear predictor to the response space.
Always a callable. If None was passed at construction time, this is resolved to the observation model’s default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli).
- property observation_model: None | Observations
The observation model governing the conditional distribution of y.
Always an instance of an Observations subclass. If a string alias was passed at construction time, it is resolved to the corresponding instance here.
- property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]
Provides the initialization function for the optimizer state.
This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.
- Returns:
The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function to execute the optimization process.
This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.
- Returns:
The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function for updating the state during the optimization process.
This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimizer step is necessary, such as in online learning or complex optimization scenarios.
- Returns:
The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.
- predict(X)
Predict rates based on fit parameters.
- Parameters:
X (DESIGN_INPUT_TYPE) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.
- Return type:
jnp.ndarray
- Returns:
The predicted rates with shape (n_time_bins, ).
- Raises:
NotFittedError – If fit has not been called first with this instance.
ValueError – If params is not a JAX pytree of size two.
ValueError – If weights and bias terms in params don’t have the expected dimensions.
ValueError – If X is not two-dimensional.
ValueError – If there’s an inconsistent number of features between spike basis coefficients and X.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # predict firing rates for new inputs
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_rates = model.predict(Xnew)
See also
nemos.glm.GLM.score() – Score predicted rates against target spike counts.
nemos.glm.GLM.simulate() – Simulate neural activity in response to a feed-forward input (feed-forward only).
nemos.simulation.simulate_recurrent() – Simulate neural activity in response to a feed-forward input using the GLM as a recurrent network (feed-forward + coupling).
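For a single-neuron GLM, predicting rates amounts to applying the inverse link to the linear predictor. The NumPy sketch below mirrors the shape checks listed under Raises; predict_rates is a hypothetical helper with an exponential inverse link assumed by default, not the nemos implementation.

```python
import numpy as np


def predict_rates(X, coef, intercept, inverse_link=np.exp):
    """Rates = inverse_link(X @ coef + intercept), one value per time bin."""
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be two-dimensional: (n_time_bins, n_features).")
    if X.shape[1] != coef.shape[0]:
        raise ValueError("X and coef disagree on the number of features.")
    return inverse_link(X @ coef + intercept)


X = np.array([[0.0, 0.0], [1.0, 0.0]])
rates = predict_rates(X, coef=np.array([np.log(2.0), 0.5]), intercept=np.array([0.0]))
print(rates)  # approximately [1. 2.]
```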
- property regularizer: None | Regularizer
Getter for the regularizer attribute.
- save_params(filename)
Save GLM model parameters to a .npz file.
This method allows the model parameters to be reused. The saved parameters can be loaded back into a GLM instance using nmo.load_model.
- Parameters:
filename (Union[str, Path]) – The name of the file where the model parameters will be saved. The file will be saved in .npz format.
Examples
>>> import nemos as nmo
>>> # Create a GLM model with specified parameters
>>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
>>> model = nmo.glm.GLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1,
...     observation_model="Gamma",
...     solver_name="BFGS",
...     solver_kwargs=solver_args,
... )
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1...
solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
solver_name: BFGS
>>> # Save the model parameters to a file
>>> model.save_params("model_params.npz")
>>> # Load the model from the saved file
>>> model = nmo.load_model("model_params.npz")
>>> # Model has the same parameters before and after load
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1
solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
solver_name: BFGS
>>> # Saving and loading a custom inverse link function
>>> model = nmo.glm.GLM(
...     observation_model="Poisson",
...     inverse_link_function=lambda x: x**2
... )
>>> model.save_params("model_params.npz")
>>> # Provide a mapping for the custom link function when loading.
>>> mapping_dict = {
...     "inverse_link_function": lambda x: x**2,
... }
>>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
>>> # The loaded model now uses the custom inverse link function from mapping_dict
>>> for key, value in loaded_model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function <lambda> at ...>
observation_model: PoissonObservations()
regularizer: UnRegularized()
regularizer_strength: None
solver_kwargs: {}
solver_name: LBFGS
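A plausible reason mapping_dict is needed is that a custom callable such as a lambda cannot be stored by value in a .npz file. The sketch below stores the inverse link only by name and lets the caller supply the callable on load. The functions save_params_sketch and load_params_sketch are hypothetical and do not reproduce the actual nemos file layout.

```python
import os
import tempfile

import numpy as np


def save_params_sketch(filename, coef, intercept, inverse_link_name):
    # Arrays are stored as-is; the callable is stored only by its name (a string).
    np.savez(filename, coef=coef, intercept=intercept,
             inverse_link_function=inverse_link_name)


def load_params_sketch(filename, mapping_dict=None):
    mapping_dict = mapping_dict or {}
    with np.load(filename) as data:
        params = {key: data[key] for key in data.files}
    stored_name = params["inverse_link_function"].item()
    # A custom callable cannot be rebuilt from its name alone, so the caller
    # must supply it; otherwise only the stored name is returned.
    params["inverse_link_function"] = mapping_dict.get("inverse_link_function", stored_name)
    return params


def square(x):
    return x**2


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_params.npz")
    save_params_sketch(path, np.array([0.5, -1.0]), np.array([0.1]), "<lambda>")
    without_mapping = load_params_sketch(path)
    with_mapping = load_params_sketch(path, {"inverse_link_function": square})
```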
- score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)
Evaluate the goodness-of-fit of the model to the observed neural data.
This method computes the goodness-of-fit score, which can be either the mean log-likelihood or one of two versions of the pseudo-\(R^2\). The scoring process includes validation of input compatibility with the model’s parameters, ensuring that the model has been previously fitted and the input data are appropriate for scoring. A higher score indicates a better fit of the model to the observed data.
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.
y (ArrayLike) – Neural activity. Shape (n_time_bins, ).
score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – Type of scoring: either log-likelihood or pseudo-\(R^2\).
aggregate_sample_scores (Callable) – Function that aggregates the score of all samples.
- Returns:
The log-likelihood or the pseudo-\(R^2\) of the current model.
- Return type:
jnp.ndarray
- Raises:
NotFittedError – If fit has not been called first with this instance.
ValueError – If the structure of X doesn’t match the params, or if X and y have different numbers of samples.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')
Notes
The log-likelihood is not on a standard scale; its value is influenced by many factors, including the number of model parameters. The log-likelihood can assume both positive and negative values.
The Pseudo-\(R^2\) is not equivalent to the \(R^2\) value in linear regression. While both provide a measure of model fit and typically take values in the [0, 1] range, the methods and interpretations can differ. The Pseudo-\(R^2\) is particularly useful for generalized linear models when the interpretation of the \(R^2\) as explained variance does not apply (i.e., when the observations are not Gaussian distributed).
Why is the traditional \(R^2\) usually a poor measure of performance in GLMs?
In the context of GLMs, the variance and the mean of the observations are related. Ignoring the relation between them can result in underestimating the model performance; for instance, when we model a Poisson variable with a large mean, we expect an equally large variance. In this scenario, even if our model perfectly captures the mean, the high variance will result in large residuals and a low \(R^2\). Additionally, when the mean of the observations varies, the variance will vary too. This violates the “homoscedasticity” assumption, necessary for interpreting the \(R^2\) as variance explained.
The \(R^2\) captures the variance explained when the relationship between the observations and the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors and the mean of the observations, compromising the interpretation of the \(R^2\).
Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional to the model standard deviation (i.e. Pearson residuals). This “rescaled” residual distribution, however, deviates substantially from normality for count data with a low mean (common for spike counts). Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially for GLMs modeling count data.
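The point about low-mean counts can be checked numerically. With the Poisson variance function V(mu) = mu, Pearson residuals for a rate of 0.1 take only two values, strongly skewed, rather than anything resembling a normal distribution. pearson_residuals is a hypothetical helper for this illustration.

```python
import numpy as np


def pearson_residuals(y, mu, variance=lambda mu: mu):
    """(y - mu) / sqrt(V(mu)); the default V(mu) = mu is the Poisson variance."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    return (y - mu) / np.sqrt(variance(mu))


# Low-mean Poisson counts: the residuals take only a few discrete, skewed values.
mu = np.full(5, 0.1)
y = np.array([0, 0, 0, 0, 1])
r = pearson_residuals(y, mu)
print(r)  # four values near -0.32, one near 2.85
```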
Refer to the concrete subclasses of nmo.observation_models.Observations for the likelihood and pseudo-\(R^2\) equations.
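As a sketch of one of these scores, the code below computes the textbook McFadden pseudo-\(R^2\), \(1 - \log L_{\text{model}} / \log L_{\text{null}}\), for Poisson data, where the null model predicts the mean count everywhere. The nemos definitions in the Observations subclasses may differ in detail (for example, in normalization or in how the saturated model enters), so treat this as an illustration rather than a reference implementation.

```python
import math

import numpy as np


def poisson_log_likelihood(y, rate):
    """Summed Poisson log-likelihood, including the log(y!) term."""
    y = np.asarray(y, dtype=float)
    rate = np.asarray(rate, dtype=float)
    log_y_factorial = np.array([math.lgamma(k + 1.0) for k in y])
    return np.sum(y * np.log(rate) - rate - log_y_factorial)


def mcfadden_pseudo_r2(y, rate):
    """Textbook McFadden pseudo-R^2: 1 - LL(model) / LL(null)."""
    ll_model = poisson_log_likelihood(y, rate)
    # Null model: a constant rate equal to the mean observed count.
    ll_null = poisson_log_likelihood(y, np.full(len(y), np.mean(y)))
    return 1.0 - ll_model / ll_null


y = np.array([0, 0, 5, 5])
r2_null = mcfadden_pseudo_r2(y, np.full(4, 2.5))             # null model scores 0
r2_good = mcfadden_pseudo_r2(y, np.array([0.1, 0.1, 5, 5]))  # tracks the data, about 0.65
print(r2_null, r2_good)
```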
- set_fit_request(*, init_params='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- set_params(**params)
Manage warnings in case of multiple parameter settings.
- Parameters:
params (Any)
- set_score_request(*, aggregate_sample_scores='$UNCHANGED$', score_type='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
aggregate_sample_scores (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the aggregate_sample_scores parameter in score.
score_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for the score_type parameter in score.
self (GLM)
- Returns:
self – The updated object.
- Return type:
GLM
- simulate(random_key, feedforward_input)
Simulate neural activity in response to a feed-forward input.
- Parameters:
random_key (jax.Array) – jax.random.key for seeding the simulation.
feedforward_input (DESIGN_INPUT_TYPE) – External input predictors to the model, representing factors like convolved currents, light intensities, etc. Array of shape (n_time_bins, n_basis_input) or pytree with leaves of the same shape.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray]
- Returns:
simulated_activity – Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins, ).
firing_rates – Simulated rates for the neuron over time. Shape: (n_time_bins, ).
- Raises:
NotFittedError – If the model hasn’t been fitted prior to calling this method.
Examples
>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit model
>>> import jax
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> spikes, rates = model.simulate(random_key, Xnew)
See also
nemos.glm.GLM.predict() – Method to predict rates based on the model’s parameters.
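A feed-forward Poisson simulation follows the same two steps as simulate: compute rates from the inputs, then draw one count per time bin. This sketch uses NumPy's Generator in place of jax.random.key; simulate_poisson_glm is a hypothetical helper, and the exponential inverse link is assumed.

```python
import numpy as np


def simulate_poisson_glm(rng, X, coef, intercept):
    """Draw counts y_t ~ Poisson(exp(X_t @ coef + intercept)); return (counts, rates)."""
    rates = np.exp(X @ coef + intercept)  # feed-forward rates, shape (n_time_bins,)
    counts = rng.poisson(rates)           # one independent draw per time bin
    return counts, rates


rng = np.random.default_rng(123)
Xnew = rng.normal(size=(20, 2))
spikes, rates = simulate_poisson_glm(rng, Xnew, np.array([0.5, -0.3]), np.array([0.1]))
```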
- property solver
Getter for the solver class.
- property solver_kwargs
Getter for the solver_kwargs attribute.
- property solver_spec: SolverSpec
Getter for the solver specification.
- update(params, opt_state, X, y, *args, n_samples=None, **kwargs)
Update the model parameters and solver state.
This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. This method is particularly useful for iterative model fitting, especially in scenarios where model parameters need to be updated incrementally, such as online learning or when dealing with very large datasets that do not fit into memory at once.
- Parameters:
params (GLMUserParams) – The current model parameters, typically a tuple of coefficients and intercepts.
opt_state (SolverState) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.
X (DESIGN_INPUT_TYPE) – The predictors used in the model fitting process, which may include feature matrices or a pytree of arrays. Shape (n_time_bins, n_features).
y (jnp.ndarray) – The response variable or output data corresponding to the predictors. Shape (n_time_bins,).
*args – Additional positional arguments to be passed to the solver’s update method.
n_samples (Optional[int]) – The total number of samples. Usually larger than the number of samples in an individual batch; n_samples is used to estimate the scale parameter of the GLM.
**kwargs – Additional keyword arguments to be passed to the solver’s update method.
- Return type:
StepResult
- Returns:
params – Updated model parameters (coefficients, intercepts).
state – Updated optimizer state.
- Raises:
ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> import jax
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> glm_instance = nmo.glm.GLM()
>>> params = glm_instance.initialize_params(X, y)
>>> opt_state = glm_instance.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
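The initialize_optimizer_and_state / update pattern lends itself to streaming data in mini-batches. The sketch below reproduces that loop with hypothetical init_state_sketch and update_sketch functions that take plain gradient steps on a mean Poisson negative log-likelihood. It is not the nemos solver: the real update returns a StepResult and accepts n_samples so that the scale can be estimated from the full dataset.

```python
import numpy as np


def init_state_sketch(params, X, y, stepsize=0.2):
    """Hypothetical solver state: an iteration counter and a fixed step size."""
    return {"iter": 0, "stepsize": stepsize}


def update_sketch(params, state, X, y):
    """One gradient step on the mean Poisson NLL with an exponential inverse link."""
    coef, intercept = params
    rate = np.exp(X @ coef + intercept)
    resid = rate - y  # d(NLL_t)/d(eta_t) for each sample
    grad_coef = X.T @ resid / len(y)
    grad_intercept = np.array([resid.mean()])
    new_params = (
        coef - state["stepsize"] * grad_coef,
        intercept - state["stepsize"] * grad_intercept,
    )
    return new_params, {**state, "iter": state["iter"] + 1}


rng = np.random.default_rng(0)
X = 0.5 * rng.normal(size=(1000, 2))
y = rng.poisson(np.exp(X @ np.array([0.8, -0.4]) + 0.2))

params = (np.zeros(2), np.array([np.log(y.mean())]))
state = init_state_sketch(params, X, y)
for epoch in range(50):
    # Stream the data in mini-batches of 100 time bins, one update per batch.
    for start in range(0, len(y), 100):
        batch = slice(start, start + 100)
        params, state = update_sketch(params, state, X[batch], y[batch])
print(state["iter"], params)
```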