nemos.glm.PopulationGLM#

class nemos.glm.PopulationGLM(observation_model='Poisson', inverse_link_function=None, regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None, feature_mask=None, **kwargs)[source]#

Bases: GLM

Population Generalized Linear Model.

This class implements a Generalized Linear Model for a neural population. This GLM implementation allows users to model the activity of a population of neurons based on a combination of exogenous inputs (like convolved currents or light intensities) and a choice of observation model. It is suitable for scenarios where the relationship between predictors and the response variable might be non-linear, and the residuals don’t follow a normal distribution. The predictors must be stored in tabular format, shape (n_timebins, num_features) or as a pytree of arrays of the same shape. Below is a table listing the default and available solvers for each regularizer.

Regularizer | Default Solver | Available Solvers
UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient
Lasso | ProximalGradient | ProximalGradient
GroupLasso | ProximalGradient | ProximalGradient

Fitting Large Models

For very large models, consider using the Stochastic Variance Reduced Gradient solver (nemos.solvers._svrg.SVRG) or its proximal variant (nemos.solvers._svrg.ProxSVRG), both of which take advantage of batched computation. You can select either one by passing "SVRG" or "ProxSVRG" as solver_name at model initialization.

The performance of the SVRG solver depends critically on the choice of batch_size and stepsize hyperparameters. These parameters control the size of the mini-batches used for gradient computations and the step size for each iteration, respectively. Improper selection of these parameters can lead to slow convergence or even divergence of the optimization process.

To assist with this, for certain GLM configurations, we provide batch_size and stepsize default values that are theoretically guaranteed to ensure fast convergence.

Below is a list of the configurations for which we can provide guaranteed hyperparameters:

GLM / PopulationGLM Configuration | Stepsize | Batch Size
Poisson + soft-plus + UnRegularized | ✅ | ❌
Poisson + soft-plus + Ridge | ✅ | ✅
Poisson + soft-plus + Lasso | ✅ | ❌
Poisson + soft-plus + GroupLasso | ✅ | ❌
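To make the roles of stepsize and batch_size concrete, here is a minimal NumPy sketch of the textbook SVRG update for a single-neuron Poisson GLM with a soft-plus inverse link. It is an illustration only and is not the nemos.solvers._svrg implementation; the helper names (softplus, poisson_softplus_grad, poisson_softplus_loss, svrg) and the chosen hyperparameter values are assumptions made for this example.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return np.logaddexp(0.0, z)

def poisson_softplus_loss(w, X, y):
    # Mean Poisson negative log-likelihood, dropping the log(y!) constant.
    rate = softplus(X @ w)
    return np.mean(rate - y * np.log(rate))

def poisson_softplus_grad(w, X, y):
    # Gradient of the loss above with respect to w.
    z = X @ w
    rate = softplus(z)
    sigmoid = 1.0 / (1.0 + np.exp(-z))  # derivative of soft-plus
    return X.T @ (sigmoid * (1.0 - y / rate)) / X.shape[0]

def svrg(X, y, stepsize, batch_size, n_epochs=20, seed=0):
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    for _ in range(n_epochs):
        # Snapshot: one full-data gradient per epoch.
        w_snap = w.copy()
        full_grad = poisson_softplus_grad(w_snap, X, y)
        for _ in range(n_samples // batch_size):
            idx = rng.choice(n_samples, size=batch_size, replace=False)
            # Variance-reduced gradient estimate on the mini-batch.
            g = poisson_softplus_grad(w, X[idx], y[idx])
            g_snap = poisson_softplus_grad(w_snap, X[idx], y[idx])
            w = w - stepsize * (g - g_snap + full_grad)
    return w

# Synthetic data (hypothetical weights) to exercise the sketch.
rng = np.random.default_rng(123)
X = rng.normal(size=(200, 3))
true_w = np.array([0.5, -0.5, 1.0])
y = rng.poisson(softplus(X @ true_w))

w_hat = svrg(X, y, stepsize=0.05, batch_size=10)
loss_start = poisson_softplus_loss(np.zeros(3), X, y)
loss_end = poisson_softplus_loss(w_hat, X, y)
```

A larger stepsize makes each inner step more aggressive and can cause divergence, while a larger batch_size lowers the variance of each step at a higher per-step cost.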

Parameters:
  • observation_model (Union[BernoulliObservations, GammaObservations, GaussianObservations, NegativeBinomialObservations, PoissonObservations, Literal['Poisson', 'Gamma', 'Bernoulli', 'NegativeBinomial', 'Gaussian']]) – Observation model to use. The model describes the distribution of the neural activity. Default is the Poisson model.

  • inverse_link_function (Optional[Callable]) – A function that maps the linear combination of predictors into a firing rate. The default depends on the observation model (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli).

  • regularizer (Union[str, Regularizer]) – Regularization to use for model optimization. Defines the regularization scheme and related parameters. Default is UnRegularized regression.

  • regularizer_strength (Any) – Typically a float. Default is None. Sets the regularizer strength. If a user does not pass a value, and it is needed for regularization, a warning will be raised and the strength will default to 1.0. For finer control, the user can pass a pytree that matches the parameter structure to regularize parameters differentially.

  • solver_name (str) – Solver to use for model optimization. Defines the optimization scheme and related parameters. The solver must be an appropriate match for the chosen regularizer. Default is None. If no solver is specified, one will be chosen based on the regularizer. See the table above for regularizer/solver pairings.

  • solver_kwargs (dict) –

    Optional dictionary for keyword arguments that are passed to the solver when instantiated. E.g. stepsize, tol, acceleration, etc.

    For details on each solver’s kwargs, see get_accepted_arguments and get_solver_documentation.

  • feature_mask (Optional[Array]) – Either a matrix of shape (num_features, num_neurons) or a PyTree of 0s and 1s, with leaves of shape (num_neurons, ). The mask will be used to select which features are used as predictors for which neuron.

intercept_#

Model baseline linked firing rate parameters, e.g. if the link is the logarithm, the baseline firing rate will be jnp.exp(model.intercept_).

coef_#

Basis coefficients for the model.

solver_state_#

State of the solver after fitting. May include details like optimization error.

Raises:
  • TypeError – If provided regularizer or observation_model are not valid.

  • TypeError – If provided feature_mask is not an array-like of dimension two.


Examples

Fit a PopulationGLM

Basic model fitting for a population of neurons:

>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> np.random.seed(123)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[0.5, 0.0], [-0.5, -0.5], [0.0, 1.0]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> model = nmo.glm.PopulationGLM().fit(X, y)
>>> model.coef_.shape
(3, 2)

Mask Coefficients with an Array

Use a feature mask to specify which features predict each neuron. The mask has shape (num_features, num_neurons):

>>> feature_mask = np.array([[1, 0], [1, 1], [0, 1]])
>>> model = nmo.glm.PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> model.coef_
Array(...)

Use a Dict of Arrays as Input

Features can be passed as a dict (or any JAX pytree). The feature mask should mirror the same structure, with one 1-D entry per leaf:

>>> feature_1 = np.random.normal(size=(num_samples, 2))
>>> feature_2 = np.random.normal(size=(num_samples, 1))
>>> X_dict = {"feature_1": feature_1, "feature_2": feature_2}
>>> weights = dict(
...     feature_1=jnp.array([[0.0, 0.5], [0.0, -0.5]]),
...     feature_2=jnp.array([[1.0, 0.0]])
... )
>>> rate = np.exp(
...     X_dict["feature_1"].dot(weights["feature_1"]) +
...     X_dict["feature_2"].dot(weights["feature_2"])
... )
>>> y = np.random.poisson(rate)
>>> feature_mask = {
...     "feature_1": jnp.array([0, 1], dtype=jnp.int32),
...     "feature_2": jnp.array([1, 0], dtype=jnp.int32)
... }
>>> model = nmo.glm.PopulationGLM(feature_mask=feature_mask).fit(X_dict, y)
>>> model.coef_
{...}

Customize the Observation Model

Use a Gamma observation model for continuous positive data:

>>> model = nmo.glm.PopulationGLM(observation_model="Gamma")
>>> model.observation_model
GammaObservations()

Use Regularization

Fit with Ridge regularization:

>>> X = np.random.normal(size=(num_samples, num_features))
>>> weights = np.array([[0.5, 0.0], [-0.5, -0.5], [0.0, 1.0]])
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> model = nmo.glm.PopulationGLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1
... ).fit(X, y)
>>> model.regularizer
Ridge()

Attributes

feature_mask

Mask indicating which features are used for each neuron.

inverse_link_function

Inverse link function mapping the linear predictor to the response space.

observation_model

The observation model governing the conditional distribution of y.

optimizer_init_state

Provides the initialization function for the optimizer state.

optimizer_run

Provides the function to execute the optimization process.

optimizer_update

Provides the function for updating the state during the optimization process.

regularizer

Getter for the regularizer attribute.

regularizer_strength

Regularizer strength getter.

solver

Getter for the solver class.

solver_kwargs

Getter for the solver_kwargs attribute.

solver_name

Getter for the solver_name attribute.

solver_spec

Getter for the solver specification.

__init__(observation_model='Poisson', inverse_link_function=None, regularizer='UnRegularized', regularizer_strength=None, solver_name=None, solver_kwargs=None, feature_mask=None, **kwargs)[source]#

Methods

__init__([observation_model, ...])

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

fit(X, y[, init_params])

Fit GLM to the activity of a population of neurons.

get_params([deep])

From scikit-learn, get parameters by inspecting init.

initialize_optimizer_and_state(init_params, X, y)

Initialize the optimization routine and its state for running fit and update.

initialize_params(X, y)

Initialize model parameters.

predict(X)

Predict rates based on fit parameters.

save_params(filename)

Save GLM model parameters to a .npz file.

score(X, y[, score_type, ...])

Evaluate the goodness-of-fit of the model to the observed neural data.

set_params(**params)

Manage warnings in case of multiple parameter settings.

simulate(random_key, feedforward_input)

Simulate neural activity in response to a feed-forward input.

update(params, opt_state, X, y, *args[, ...])

Update the model parameters and solver state.

classmethod __init_subclass__(**kwargs)#

Set the set_{method}_request methods.

This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values which are set using __metadata_request__* class attributes, or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept a metadata through its arguments or if the developer would like to specify a request value for those metadata which are different from the default None.

References

__repr__()#

Representation of the GLM class.

__sklearn_clone__()[source]#

Clone the PopulationGLM, dropping feature_mask.

Return type:

PopulationGLM

__sklearn_tags__()[source]#

Return Population GLM specific estimator tags.

compute_loss(params, X, y, *args, **kwargs)#

Compute the loss function for the model.

This method validates inputs and converts user-provided parameters to the internal representation before computing the loss.

Parameters:
  • params (UserProvidedParamsT) – Parameter tuple of (coefficients, intercept).

  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

  • *args – Additional positional arguments passed to the model-specific loss function.

  • **kwargs – Additional keyword arguments passed to the model-specific loss function.

Returns:

The loss value (negative log-likelihood).

Return type:

jnp.ndarray

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

property feature_mask: Array | dict[str, Array]#

Mask indicating which features are used for each neuron.

The feature mask has a tree structure matching the coefficients (coef_):

  • Array input: Shape (n_features, n_neurons). Each entry [i, j] indicates whether feature i is used for neuron j (1 = used, 0 = masked).

  • Pytree: A pytree with structure matching that of coef_. Each leaf array has shape (n_neurons,), indicating whether that feature group is used for each neuron.
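For the pytree form, each leaf of the mask holds one 0/1 entry per neuron, and that entry applies to the whole feature group. The NumPy sketch below illustrates this reading with plain dicts; the coefficient values are hypothetical, and multiplying by the mask is only a way to show the semantics, not necessarily how nemos applies the mask internally.

```python
import numpy as np

n_neurons = 2
# Hypothetical coefficients: one leaf per feature group,
# each of shape (n_features_in_group, n_neurons).
coef = {
    "feature_1": np.ones((2, n_neurons)),
    "feature_2": np.ones((1, n_neurons)),
}
# One 0/1 entry per neuron for each feature group.
feature_mask = {
    "feature_1": np.array([0, 1]),
    "feature_2": np.array([1, 0]),
}
# Broadcasting applies each per-neuron entry to every row of its leaf.
masked_coef = {name: coef[name] * feature_mask[name] for name in coef}
```

Here feature_1 contributes only to the second neuron and feature_2 only to the first.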

Returns:

The feature mask, or None if not set.

Return type:

jnp.ndarray or dict[str, jnp.ndarray]

fit(X, y, init_params=None)[source]#

Fit GLM to the activity of a population of neurons.

Fit and store the model parameters as attributes coef_ and intercept_. Each neuron can have different predictors. The feature_mask will determine which feature will be used for which neurons. See the note below for more information on the feature_mask.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_timebins, n_features) or pytree of the same shape.

  • y (ArrayLike) – Target neural activity arranged in a matrix, shape (n_timebins, n_neurons).

  • init_params (Optional[GLMUserParams]) – 2-tuple of initial parameter values: (coefficients, intercepts). If None, coefficients are initialized with zeros and intercepts with the log of the mean neural activity. coefficients is an array of shape (n_features, n_neurons) or pytree of the same shape; intercepts is an array of shape (n_neurons, ).

Raises:
  • ValueError – If init_params is not of length two.

  • ValueError – If the dimensionality of init_params is not correct.

  • ValueError – If X is not two-dimensional.

  • ValueError – If y is not two-dimensional.

  • ValueError – If the feature_mask is not of the right shape.

  • ValueError – If solver returns at least one NaN parameter, which means it found an invalid solution. Try tuning optimization hyperparameters.

  • TypeError – If init_params are not array-like.

  • TypeError – If init_params[i] cannot be converted to jnp.ndarray for all i.

Notes

The feature_mask is used to select features for each neuron, and it is an NDArray or a PyTree of 0s and 1s. In particular,

  • If the mask is in array format, feature i is a predictor for neuron j if feature_mask[i, j] == 1.

  • If the mask is a PyTree, then a leaf is a predictor of neuron j if the matching leaf in feature_mask is equal to 1.
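The array form of the mask can be read as "zero out coefficient [i, j] whenever feature_mask[i, j] == 0". The NumPy sketch below checks that reading on hypothetical coefficients with an exponential inverse link; it illustrates the semantics described above and is not taken from the nemos source.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))      # (n_timebins, n_features)
coef = rng.normal(size=(3, 2))   # (n_features, n_neurons), hypothetical values
intercept = np.zeros(2)
feature_mask = np.array([[1, 0], [1, 1], [0, 1]])

def masked_rate(X):
    # Zero out the masked coefficients, then apply an exponential inverse link.
    return np.exp(X @ (coef * feature_mask) + intercept)

# Perturb feature 2, which is masked out for neuron 0 (feature_mask[2, 0] == 0).
X_perturbed = X.copy()
X_perturbed[:, 2] += 1.0
rate, rate_perturbed = masked_rate(X), masked_rate(X_perturbed)
```

The rate of neuron 0 is unchanged by the perturbation, while the rate of neuron 1, which does use feature 2, changes.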

Examples

>>> # Generate sample data
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from nemos.glm import PopulationGLM
>>> np.random.seed(0)
>>> # Define predictors (X), weights, and neural activity (y)
>>> num_samples, num_features, num_neurons = 100, 3, 2
>>> X = np.random.normal(size=(num_samples, num_features))
>>> # Weights is defined by how each feature influences the output, shape (num_features, num_neurons)
>>> weights = np.array([[ 0.5,  0. ], [-0.5, -0.5], [ 0. ,  1. ]])
>>> # Output y simulates a Poisson distribution based on a linear model between features X and weights
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> print(model.coef_.shape)
(3, 2)

get_metadata_routing()#

Get metadata routing of this object.

Please check User Guide on how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)#

From scikit-learn, get parameters by inspecting init.

Parameters:

deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Return type:

dict

Returns:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

initialize_optimizer_and_state(init_params, X, y)#

Initialize the optimization routine and its state for running fit and update.

This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

  • init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).

Returns:

Initial solver state.

Return type:

SolverState

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

initialize_params(X, y)#

Initialize model parameters.

Initialize coefficients with zeros and intercept by matching the mean firing rate.

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target data, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models.

Returns:

Initial parameter tuple of (coefficients, intercept).

Return type:

UserProvidedParamsT
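As a rough illustration of "matching the mean firing rate", the NumPy sketch below assumes an exponential inverse link (log link), so the intercept is the log of each neuron's mean count. With zero coefficients the initial predicted rate then equals that mean. For other inverse link functions the intercept would be obtained differently; the data and variable names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_features = 200, 3
X = rng.normal(size=(n_samples, n_features))
# Two hypothetical neurons with different mean counts.
y = rng.poisson(lam=[2.0, 5.0], size=(n_samples, 2))

# Zero coefficients, intercept set to the log of the mean activity.
coef0 = np.zeros((n_features, y.shape[1]))
intercept0 = np.log(y.mean(axis=0))

# With an exponential inverse link the initial rate is constant in time
# and equal to each neuron's mean count.
rate0 = np.exp(X @ coef0 + intercept0)
```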

property inverse_link_function#

Inverse link function mapping the linear predictor to the response space.

Always a callable. If None was passed at construction time, this is resolved to the observation model’s default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli).

property observation_model: None | Observations#

The observation model governing the conditional distribution of y.

Always an instance of an Observations subclass. If a string alias was passed at construction time it is resolved to the corresponding instance here.

property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]#

Provides the initialization function for the optimizer state.

This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.

Returns:

The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]#

Provides the function to execute the optimization process.

This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.

Returns:

The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]#

Provides the function for updating the state during the optimization process.

This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimizer step is necessary, such as in online learning or complex optimization scenarios.

Returns:

The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.

predict(X)#

Predict rates based on fit parameters.

Parameters:

X (DESIGN_INPUT_TYPE) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.

Return type:

jnp.ndarray

Returns:

The predicted rates, with shape (n_time_bins, ) for single neuron models or (n_time_bins, n_neurons) for population models.

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If params is not a JAX pytree of size two.

  • ValueError – If weights and bias terms in params don’t have the expected dimensions.

  • ValueError – If X is not two-dimensional.

  • ValueError – If there’s an inconsistent number of features between spike basis coefficients and X.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit a GLM
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # predict new spike data
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> predicted_spikes = model.predict(Xnew)

See also

nemos.glm.GLM.score()

Score predicted rates against target spike counts.

nemos.glm.GLM.simulate()

Simulate neural activity in response to a feed-forward input (feed-forward only).

nemos.simulation.simulate_recurrent()

Simulate neural activity in response to a feed-forward input using the GLM as a recurrent network (feed-forward + coupling).

property regularizer: None | Regularizer#

Getter for the regularizer attribute.

property regularizer_strength: Any#

Regularizer strength getter.

save_params(filename)#

Save GLM model parameters to a .npz file.

This method allows the model parameters to be reused. The saved parameters can be loaded back into a GLM instance using the nmo.load_model function.

Parameters:

filename (Union[str, Path]) – The name of the file where the model parameters will be saved. The file will be saved in .npz format.

Examples

>>> import nemos as nmo
>>> # Create a GLM model with specified parameters
>>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
>>> model = nmo.glm.GLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1,
...     observation_model="Gamma",
...     solver_name="BFGS",
...     solver_kwargs=solver_args,
... )
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1...
solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
solver_name: BFGS
>>> # Save the model parameters to a file
>>> model.save_params("model_params.npz")
>>> # Load the model from the saved file
>>> model = nmo.load_model("model_params.npz")
>>> # Model has the same parameters before and after load
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1
solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
solver_name: BFGS
>>> # Saving and loading a custom inverse link function
>>> model = nmo.glm.GLM(
...     observation_model="Poisson",
...     inverse_link_function=lambda x: x**2
... )
>>> model.save_params("model_params.npz")
>>> # Provide a mapping for the custom link function when loading.
>>> mapping_dict = {
...     "inverse_link_function": lambda x: x**2,
... }
>>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
>>> # The loaded model now uses the inverse link function provided in the mapping
>>> for key, value in loaded_model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function <lambda> at ...>
observation_model: PoissonObservations()
regularizer: UnRegularized()
regularizer_strength: None
solver_kwargs: {}
solver_name: LBFGS

score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)#

Evaluate the goodness-of-fit of the model to the observed neural data.

This method computes the goodness-of-fit score, which can be either the mean log-likelihood or one of two versions of the pseudo-\(R^2\). The scoring process includes validation of input compatibility with the model’s parameters, ensuring that the model has been previously fitted and the input data are appropriate for scoring. A higher score indicates a better fit of the model to the observed data.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Predictors, array of shape (n_time_bins, n_features) or a pytree of arrays of the same shape.

  • y (ArrayLike) – Neural activity. Shape (n_time_bins, ) for single neuron models or (n_time_bins, n_neurons) for population models.

  • score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – Type of scoring: either log-likelihood or pseudo-\(R^2\).

  • aggregate_sample_scores (Callable) – Function that aggregates the score of all samples.

Returns:

The log-likelihood or the pseudo-\(R^2\) of the current model.

Return type:

jnp.ndarray

Raises:
  • NotFittedError – If fit has not been called first with this instance.

  • ValueError – If the structure of X doesn’t match the params, or if X and y have a different number of samples.

Examples

>>> # example input
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # get model score
>>> log_likelihood_score = model.score(X, y)
>>> # get a pseudo-R2 score
>>> pseudo_r2_score = model.score(X, y, score_type='pseudo-r2-McFadden')

Notes

The log-likelihood is not on a standard scale: its value is influenced by many factors, including the number of model parameters. The log-likelihood can assume both positive and negative values.

The Pseudo-\(R^2\) is not equivalent to the \(R^2\) value in linear regression. While both provide a measure of model fit, and assume values in the [0,1] range, the methods and interpretations can differ. The Pseudo-\(R^2\) is particularly useful for generalized linear models when the interpretation of the \(R^2\) as explained variance does not apply (i.e., when the observations are not Gaussian distributed).

Why is the traditional \(R^2\) usually a poor measure of performance in GLMs?

  1. In the context of GLMs the variance and the mean of the observations are related. Ignoring the relation between them can result in underestimating the model performance; for instance, when we model a Poisson variable with large mean we expect an equally large variance. In this scenario, even if our model perfectly captures the mean, the high variance will result in large residuals and low \(R^2\). Additionally, when the mean of the observations varies, the variance will vary too. This violates the “homoscedasticity” assumption, necessary for interpreting the \(R^2\) as variance explained.

  2. The \(R^2\) captures the variance explained when the relationship between the observations and the predictors is linear. In GLMs, the link function sets a non-linear mapping between the predictors and the mean of the observations, compromising the interpretation of the \(R^2\).
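The first point can be checked numerically. In the NumPy sketch below (synthetic data, not a nemos computation) the prediction is the true Poisson rate, so the mean is captured perfectly, yet the classical \(R^2\) stays low because the Poisson noise variance (about 55) is much larger than the variance of the rate itself (100/12, about 8.3). In expectation the ratio is roughly 8.3 / 63.3, i.e. about 0.13.

```python
import numpy as np

rng = np.random.default_rng(0)
# The true rate varies modestly around a large mean.
rate = rng.uniform(50.0, 60.0, size=5000)
y = rng.poisson(rate)

# Classical R^2, evaluated with the *true* rate as the prediction.
ss_res = np.sum((y - rate) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r2 = 1.0 - ss_res / ss_tot
```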

Note that it is possible to re-normalize the residuals by a mean-dependent quantity proportional to the model standard deviation (i.e. Pearson residuals). This “rescaled” residual distribution, however, deviates substantially from normality for count data with low mean (common for spike counts). Therefore, even the Pearson residuals perform poorly as a measure of fit quality, especially for GLMs modeling count data.

Refer to the nmo.observation_models.Observations concrete subclasses for the likelihood and pseudo-\(R^2\) equations.
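For orientation, the sketch below computes the standard McFadden pseudo-\(R^2\), one minus the ratio of the model log-likelihood to the log-likelihood of a constant-rate null model, for Poisson counts in plain NumPy. The function names are hypothetical and the exact equations used by nemos are the ones in the observation model classes, which may differ in detail from this textbook form.

```python
import math
import numpy as np

def poisson_log_likelihood(y, rate):
    # Sum over samples of log P(y | rate) for a Poisson distribution.
    log_factorial = np.array([math.lgamma(k + 1.0) for k in y])
    return np.sum(y * np.log(rate) - rate - log_factorial)

def pseudo_r2_mcfadden(y, rate):
    # Null model: a constant rate equal to the mean count.
    null_rate = np.full(y.shape, y.mean())
    ll_model = poisson_log_likelihood(y, rate)
    ll_null = poisson_log_likelihood(y, null_rate)
    return 1.0 - ll_model / ll_null

# Synthetic single-neuron data with a hypothetical tuning weight of 0.8.
rng = np.random.default_rng(0)
x = rng.normal(size=500)
rate = np.exp(0.8 * x)
y = rng.poisson(rate)

r2_model = pseudo_r2_mcfadden(y, rate)
r2_null = pseudo_r2_mcfadden(y, np.full(y.shape, y.mean()))
```

By construction the null model scores zero, and a model that explains the counts better than a constant rate scores above zero.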

set_fit_request(*, init_params='$UNCHANGED$')#

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for init_params parameter in fit.

  • self (PopulationGLM)

Returns:

self – The updated object.

Return type:

object

set_params(**params)#

Manage warnings in case of multiple parameter settings.

Parameters:

params (Any)

set_score_request(*, aggregate_sample_scores='$UNCHANGED$', score_type='$UNCHANGED$')#

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters:
  • aggregate_sample_scores (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for aggregate_sample_scores parameter in score.

  • score_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for score_type parameter in score.

  • self (PopulationGLM)

Returns:

self – The updated object.

Return type:

object

simulate(random_key, feedforward_input)#

Simulate neural activity in response to a feed-forward input.

Parameters:
  • random_key (jax.Array) – jax.random.key for seeding the simulation.

  • feedforward_input (DESIGN_INPUT_TYPE) – External input predictors to the model, representing factors like convolved currents, light intensities, etc. When not provided, the simulation is done with coupling-only. Array of shape (n_time_bins, n_basis_input) or pytree with leaves of the same shape.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

Returns:

  • simulated_activity – Simulated activity (spike counts for Poisson GLMs) for the neuron over time. Shape: (n_time_bins, ).

  • firing_rates – Simulated rates for the neuron over time. Shape: (n_time_bins, ).

Raises:
  • NotFittedError – If the model hasn’t been fitted prior to calling this method.

  • ValueError – If the instance has not been previously fitted.

Examples

>>> # example input
>>> import jax
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> # define and fit model
>>> import nemos as nmo
>>> model = nmo.glm.GLM()
>>> model = model.fit(X, y)
>>> # generate spikes and rates
>>> random_key = jax.random.key(123)
>>> Xnew = np.random.normal(size=(20, X.shape[1]))
>>> spikes, rates = model.simulate(random_key, Xnew)
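Conceptually, a feed-forward simulation with a Poisson observation model has two steps: compute the rates from the input, then draw counts from those rates. The NumPy sketch below mirrors the two returned values under the assumption of an exponential inverse link; the parameter values stand in for fitted coef_ and intercept_ and are hypothetical, and nemos itself draws samples with a JAX random key rather than a NumPy generator.

```python
import numpy as np

rng = np.random.default_rng(123)
n_time_bins, n_features, n_neurons = 20, 2, 3
# Hypothetical fitted parameters, standing in for coef_ and intercept_.
coef = rng.normal(scale=0.3, size=(n_features, n_neurons))
intercept = np.log(np.full(n_neurons, 2.0))

feedforward_input = rng.normal(size=(n_time_bins, n_features))
# Rates: inverse link (here exp) applied to the linear predictor.
firing_rates = np.exp(feedforward_input @ coef + intercept)
# Activity: one Poisson draw per time bin and neuron.
simulated_activity = rng.poisson(firing_rates)
```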

See also

nemos.glm.GLM.predict()

Method to predict rates based on the model’s parameters.

property solver#

Getter for the solver class.

property solver_kwargs#

Getter for the solver_kwargs attribute.

property solver_name: str#

Getter for the solver_name attribute.

property solver_spec: SolverSpec#

Getter for the solver specification.

update(params, opt_state, X, y, *args, n_samples=None, **kwargs)#

Update the model parameters and solver state.

This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. This method is particularly useful for iterative model fitting, especially in scenarios where model parameters need to be updated incrementally, such as online learning or when dealing with very large datasets that do not fit into memory at once.

Parameters:
  • params (GLMUserParams) – The current model parameters, typically a tuple of coefficients and intercepts.

  • opt_state (SolverState) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.

  • X (DESIGN_INPUT_TYPE) – The predictors used in the model fitting process, which may include feature matrices or a pytree of arrays. Shape (n_time_bins, n_features).

  • y (jnp.ndarray) – The response variable or output data corresponding to the predictors. Shape (n_time_bins,).

  • *args – Additional positional arguments to be passed to the solver’s update method.

  • n_samples (Optional[int]) – The total number of samples. Usually larger than the samples of an individual batch, the n_samples are used to estimate the scale parameter of the GLM.

  • **kwargs – Additional keyword arguments to be passed to the solver’s update method.

Return type:

StepResult

Returns:

  • params – Updated model parameters (coefficients, intercepts).

  • state – Updated optimizer state.

Raises:

ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.

Examples

>>> import nemos as nmo
>>> import numpy as np
>>> import jax
>>> X, y = np.random.normal(size=(10, 2)), np.random.poisson(size=10)
>>> glm_instance = nmo.glm.GLM()
>>> params = glm_instance.initialize_params(X, y)
>>> opt_state = glm_instance.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)