nemos.glm.ClassifierGLM

class nemos.glm.ClassifierGLM(n_classes=2, inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)

Bases: ClassifierMixin, GLM

Generalized Linear Model for multi-class classification.

This model predicts discrete class labels from input features using a softmax (multinomial logistic) model. It uses an over-parameterized representation with one set of coefficients per class, resulting in coefficient shape (n_features, n_classes) and intercept shape (n_classes,).
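In equation form, the softmax model reads as follows. The notation is ours: x is one row of X, W is coef with shape (n_features, n_classes), b is intercept, and classes are indexed by their internal indices k = 0, ..., n_classes - 1. The last expression is the log_softmax that serves as the default inverse link.

```latex
\eta_k(x) = \sum_{j=1}^{n_\text{features}} x_j W_{jk} + b_k,
\qquad
P(y = k \mid x) = \frac{\exp \eta_k(x)}{\sum_{l=0}^{n_\text{classes}-1} \exp \eta_l(x)},
\qquad
\log P(y = k \mid x) = \eta_k(x) - \log \sum_{l=0}^{n_\text{classes}-1} \exp \eta_l(x).
```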

Parameters:
  • n_classes (Optional[int]) – The number of classes. Must be >= 2.

  • inverse_link_function (Optional[Callable]) – The inverse link function. Default is log_softmax.

  • regularizer (Union[str, Regularizer, None]) – The regularization scheme. Default is Ridge. Because the model is over-parameterized (one set of coefficients per class), regularization is what makes the parameters identifiable. Setting UnRegularized results in non-identifiable coefficients; see the Identifiability note below.

  • regularizer_strength (Any) – The strength of the regularization.

  • solver_name (str) – The solver to use for optimization.

  • solver_kwargs (dict) – Additional keyword arguments for the solver.

coef_

Fitted coefficients of shape (n_features, n_classes) after calling fit().

intercept_

Fitted intercepts of shape (n_classes,) after calling fit().

Notes

Identifiability

This model uses an over-parameterized (symmetric) representation in which each class has its own set of coefficients. Because softmax probabilities are unchanged when the same constant is added to every class's linear predictor, the parameters are not uniquely identifiable without regularization. For example, if (coef, intercept) is a solution, so is (coef + c, intercept + c) for any constant c. More generally, adding the same vector of length n_features to every column of coef leaves all predictions unchanged.
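A small NumPy check of this invariance follows. It is a standalone sketch rather than a call into nemos. It evaluates softmax log-probabilities for random, illustrative parameters and confirms that both kinds of shift leave them unchanged. The helper names log_softmax and class_log_proba are ours.

```python
import numpy as np

def log_softmax(eta):
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = eta - eta.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def class_log_proba(X, coef, intercept):
    # Linear predictor of shape (n_samples, n_classes), then log-softmax over classes.
    return log_softmax(X @ coef + intercept)

rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 6, 3, 4
X = rng.normal(size=(n_samples, n_features))
coef = rng.normal(size=(n_features, n_classes))
intercept = rng.normal(size=n_classes)

base = class_log_proba(X, coef, intercept)

# Shift every coefficient and every intercept by the same scalar c.
c = 2.5
shifted_scalar = class_log_proba(X, coef + c, intercept + c)

# More generally: add one per-feature vector v to every class column
# and one constant to every intercept.
v = rng.normal(size=(n_features, 1))
shifted_vector = class_log_proba(X, coef + v, intercept - 1.0)

print(np.allclose(base, shifted_scalar), np.allclose(base, shifted_vector))  # True True
```

A shift that differs across classes, such as intercept + np.arange(n_classes), does change the probabilities. Only shifts shared by all classes are invisible to the model.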

Using regularization (default is Ridge) resolves this ambiguity by penalizing the parameter magnitudes. Among all equivalent solutions, the penalty favors the one with the smallest coefficients, which effectively centers them across classes. If you use UnRegularized, the optimization may converge to different equivalent solutions depending on initialization, although the predicted probabilities will be identical.
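To see why Ridge effectively centers the coefficients, consider the family of equivalent matrices coef + v, where the same vector v is added to every class column. The NumPy sketch below uses made-up data and assumes the penalty applies to the coefficients only. It shows that the member whose rows sum to zero across classes has the smallest squared L2 norm.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_features, n_classes = 5, 3, 4
X = rng.normal(size=(n_samples, n_features))
coef = rng.normal(size=(n_features, n_classes))

def ridge_penalty(w):
    # Squared L2 norm of the coefficients; the intercept is left out in this sketch.
    return float(np.sum(w ** 2))

# Subtract each feature's mean across classes, so every row of `centered` sums to zero.
centered = coef - coef.mean(axis=1, keepdims=True)

# Both matrices give the same linear predictor up to a per-sample shift shared by
# all classes, so their softmax probabilities are identical.
shift = X @ coef - X @ centered
print(np.allclose(shift, shift[:, :1]))  # True

# Any other member of the family, centered + v, has a larger penalty:
# ||centered + v||^2 = ||centered||^2 + n_classes * ||v||^2, since rows of centered sum to 0.
v = rng.normal(size=(n_features, 1))
print(ridge_penalty(centered), ridge_penalty(centered + v))
```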

Class Labels

The target array y can contain any hashable class labels that can be stored in a NumPy array, including integers, strings, or other hashable types. The model internally maps these labels to indices [0, n_classes - 1] for computation and maps them back when returning predictions.

Performance Considerations

For optimal performance, use integer labels [0, 1, ..., n_classes - 1]. When labels follow this convention, the model skips the encoding/decoding steps entirely. Using other label formats (e.g., ["cat", "dog"] or [5, 10, 15]) incurs a small overhead for label translation.
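The label translation described in the two paragraphs above can be sketched with numpy.unique. Its return_inverse output gives each sample's internal index. This is an illustration under our own assumptions, not the nemos implementation, and the helpers fit_label_encoding and decode are hypothetical. The sketch also shows the check that identifies the fast path for default integer labels.

```python
import numpy as np

def fit_label_encoding(y):
    # Sorted unique labels play the role of classes_; return_inverse gives the
    # internal index in [0, n_classes - 1] for every sample.
    classes, y_idx = np.unique(np.asarray(y), return_inverse=True)
    # Fast path: labels are already 0, 1, ..., n_classes - 1, so no translation is needed.
    is_default = bool(
        np.issubdtype(classes.dtype, np.integer)
        and np.array_equal(classes, np.arange(classes.size))
    )
    return classes, y_idx, is_default

def decode(classes, idx):
    # Map internal indices back to the user's labels.
    return classes[idx]

classes, y_idx, is_default = fit_label_encoding(["dog", "cat", "dog", "bird"])
print(classes, y_idx, is_default)         # ['bird' 'cat' 'dog'] [2 1 2 0] False
print(decode(classes, np.array([0, 2])))  # ['bird' 'dog']
```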

Setting Class Labels

The fit() and initialize_optimizer_and_state() methods automatically infer class labels from the provided y. If you set coef_ and intercept_ manually, you must call set_classes() before using predict(), predict_proba(), simulate(), score(), or compute_loss().

See also

ClassifierPopulationGLM

Multi-class classification for multiple neurons.

GLM

Generalized Linear Model for continuous/count responses.

Examples

Fit a ClassifierGLM

Basic binary classification:

>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> model.coef_.shape
(2, 2)

Predict Class Labels

Get predicted class labels:

>>> predictions = model.predict(X)
>>> predictions.shape
(4,)

Predict Class Probabilities

Get class probabilities or log-probabilities:

>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)
>>> log_proba = model.predict_proba(X, return_type="log-proba")
>>> log_proba.shape
(4, 2)

Use String Labels

Class labels can be strings or any hashable type:

>>> y_str = np.array(["cat", "cat", "dog", "dog"])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y_str)
>>> model.classes_
array(['cat', 'dog'], dtype='<U3')
>>> model.predict(X)
array(['cat', 'cat', 'dog', 'dog'], dtype='<U3')

Multi-class Classification

Classify into more than two classes:

>>> X = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], [6.0, 7.0]])
>>> y = jnp.array([0, 0, 1, 1, 2, 2])
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X, y)
>>> model.coef_.shape
(2, 3)

Use Regularization

Set the regularizer and its strength explicitly:

>>> model = nmo.glm.ClassifierGLM(
...     n_classes=2,
...     regularizer="Ridge",
...     regularizer_strength=0.5
... )
>>> model.regularizer
Ridge()

Use a Pytree of arrays as Input

Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure:

>>> X_dict = {"feature_1": X[:, :1], "feature_2": X[:, 1:]}
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X_dict, y)
>>> # The coefficient structure matches the input
>>> type(model.coef_)
<class 'dict'>

Attributes

classes_

Class labels, or None if not set.

inverse_link_function

Inverse link function mapping the linear predictor to the response space.

n_classes

Number of classes.

observation_model

The observation model governing the conditional distribution of y.

optimizer_init_state

Provides the initialization function for the optimizer state.

optimizer_run

Provides the function to execute the optimization process.

optimizer_update

Provides the function for updating the state during the optimization process.

regularizer

Getter for the regularizer attribute.

regularizer_strength

Regularizer strength getter.

solver

Getter for the solver class.

solver_kwargs

Getter for the solver_kwargs attribute.

solver_name

Getter for the solver_name attribute.

solver_spec

Getter for the solver specification.

__init__(n_classes=2, inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)

Methods

__init__([n_classes, inverse_link_function, ...])

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

fit(X, y[, init_params])

Fit the model to training data.

get_params([deep])

From scikit-learn, get parameters by inspecting init.

initialize_optimizer_and_state(init_params, X, y)

Initialize the solver and its state for running fit and update.

initialize_params(X, y)

Initialize model parameters for categorical GLM.

predict(X)

Predict class labels for samples in X.

predict_proba(X[, return_type])

Predict class probabilities for samples in X.

save_params(filename)

Save GLM model parameters to a .npz file.

score(X, y[, score_type, ...])

Score the model on test data.

set_classes(y)

Infer unique class labels and set the classes_ attribute.

set_params(**params)

Manage warnings in case of multiple parameter settings.

simulate(random_key, feedforward_input)

Simulate categorical responses from the model.

update(params, opt_state, X, y, *args[, ...])

Update the model parameters and solver state.

classmethod __init_subclass__(**kwargs)

Set the set_{method}_request methods.

This uses PEP 487 [1] to set the set_{method}_request methods. It looks for the information available in the default values, which are set using __metadata_request__* class attributes or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept a metadata through its arguments or if the developer would like to specify a request value for those metadata which are different from the default None.

References

[1] PEP 487: Simpler customisation of class creation.

__repr__()

Representation of the GLM class.

__sklearn_clone__()

Clone the GLM.

Return type:

GLM

__sklearn_tags__()

Return GLM specific estimator tags.

property classes_: NDArray | None

Class labels, or None if not set.

compute_loss(params, X, y, *args, **kwargs)

Compute the loss function for the model.

This method validates inputs, encodes class labels to internal indices, and computes the loss (negative log-likelihood).

Parameters:
  • params – Parameter tuple of (coefficients, intercept).

  • X – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y – Target class labels in the same format as classes_.

  • *args – Additional positional arguments passed to the model-specific loss function.

  • **kwargs – Additional keyword arguments passed to the model-specific loss function.

Returns:

loss – The loss value (negative log-likelihood).

Raises:
  • RuntimeError – If classes_ has not been set.

  • ValueError – If inputs or parameters have incompatible shapes or invalid values.

fit(X, y, init_params=None)

Fit the model to training data.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Training input samples of shape (n_samples, n_features) or a pytree of arrays of the same shape.

  • y (ArrayLike) – Target class labels of shape (n_samples,). Labels can be any hashable type (integers, strings, etc.). Float arrays with integer values are accepted and converted automatically.

  • init_params (Optional[GLMUserParams]) – Initial parameter values as tuple of (coef, intercept). If None, parameters are initialized automatically.

Returns:

The fitted model.

Notes

fit calls set_classes() internally, so classes_ is always consistent with the labels in y.

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model = model.fit(X, y)
>>> model.coef_.shape
(2, 2)

get_metadata_routing()

Get metadata routing of this object.

Please check the scikit-learn User Guide on how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

From scikit-learn, get parameters by inspecting init.

Parameters:

deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Return type:

dict

Returns:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

initialize_optimizer_and_state(init_params, X, y)

Initialize the solver and its state for running fit and update.

This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.

Parameters:
  • init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).

  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Target labels, array of shape (n_time_bins,) for single neuron/subject models or (n_time_bins, n_neurons) for population models.

Returns:

Initial solver state.

Return type:

SolverState

Raises:

ValueError – If inputs or parameters have incompatible shapes or invalid values.

initialize_params(X, y)

Initialize model parameters for categorical GLM.

Initialize coefficients with zeros and intercept by matching the mean class proportions. Class labels are automatically converted to one-hot encoding.
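The NumPy sketch below illustrates this initialization. It assumes the intercept is set to the log of the class proportions. With zero coefficients, softmax(intercept) then reproduces the observed frequencies, and any shared additive constant would work equally well. The function name and the small floor eps, used to avoid log(0) for an absent class, are our assumptions rather than documented nemos behavior.

```python
import numpy as np

def initialize_params_sketch(X, y_idx, n_classes, eps=1e-12):
    # One-hot encode the internal class indices: shape (n_samples, n_classes).
    y_onehot = np.eye(n_classes)[y_idx]
    # Zero coefficients, one column per class.
    coef = np.zeros((X.shape[1], n_classes))
    # With zero coefficients the predicted probabilities are softmax(intercept),
    # so log-proportions reproduce the observed class frequencies.
    proportions = y_onehot.mean(axis=0)
    intercept = np.log(np.clip(proportions, eps, None))  # eps floor is an assumption
    return coef, intercept

X = np.ones((8, 2))
y_idx = np.array([0, 0, 0, 1, 1, 2, 2, 2])
coef, intercept = initialize_params_sketch(X, y_idx, n_classes=3)
probs = np.exp(intercept) / np.exp(intercept).sum()
print(coef.shape, probs)  # (2, 3) and the class frequencies 3/8, 2/8, 3/8
```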

Parameters:
  • X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.

  • y (jnp.ndarray) – Class labels, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models. Labels must be a subset of classes_.

Return type:

UserProvidedParamsT

Returns:

Initial parameter tuple of (coefficients, intercept).

Notes

All labels in y must be present in classes_. Passing labels not in classes_ will raise an error.

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> coef, intercept = model.initialize_params(X, y)
>>> coef.shape
(2, 2)

property inverse_link_function

Inverse link function mapping the linear predictor to the response space.

Always a callable. If None was passed at construction time, this is resolved to the model’s default: log_softmax for ClassifierGLM (for the regression GLM, the observation model’s default, e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli).
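As a concrete illustration of this mapping for ClassifierGLM, the sketch below computes the linear predictor X @ coef + intercept and applies a numerically stable log-softmax in NumPy. The parameter values are arbitrary. The sketch is a stand-in for the JAX function nemos uses, not a copy of it.

```python
import numpy as np

def log_softmax(eta):
    # Stable log-softmax over the class axis: subtract the row max before exp.
    m = eta.max(axis=-1, keepdims=True)
    return eta - m - np.log(np.exp(eta - m).sum(axis=-1, keepdims=True))

X = np.array([[1.0, 2.0], [3.0, 4.0]])        # (n_samples, n_features)
coef = np.array([[0.5, -0.5], [1.0, -1.0]])   # (n_features, n_classes), illustrative values
intercept = np.array([0.0, 0.1])              # (n_classes,)

eta = X @ coef + intercept      # linear predictor, shape (n_samples, n_classes)
log_proba = log_softmax(eta)    # response space: log-probabilities per class
print(np.exp(log_proba).sum(axis=1))  # each row sums to 1

# Large linear predictors would overflow a naive exp, but not the shifted version.
big = np.array([[1000.0, 1001.0]])
print(np.exp(log_softmax(big)))
```

The max-subtraction step is what keeps large predictors finite. Mathematically it changes nothing, by the same invariance discussed in the Identifiability note.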

property n_classes

Number of classes.

property observation_model: None | Observations

The observation model governing the conditional distribution of y.

Always an instance of an Observations subclass. If a string alias was passed at construction time it is resolved to the corresponding instance here.

property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]

Provides the initialization function for the optimizer state.

This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.

Returns:

The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]

Provides the function to execute the optimization process.

This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.

Returns:

The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.

property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]

Provides the function for updating the state during the optimization process.

This function performs a single update step in the optimization process, updating the model’s parameters based on the current state, data, and gradients. It is typically used when fine-grained control over each optimizer step is needed, for example in online learning or custom optimization loops.

Returns:

The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.

predict(X)

Predict class labels for samples in X.

Parameters:

X (DESIGN_INPUT_TYPE) – The input samples. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.

Return type:

jnp.ndarray

Returns:

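Predicted class labels for each sample, as an array of shape (n_samples,) whose entries are taken from classes_. With the default labels [0, 1, ..., n_classes - 1], this is an integer array with values in [0, n_classes - 1].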
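Conceptually, prediction picks the most probable class and translates its internal index back to a label. The NumPy sketch below works under that assumption. predict_sketch is a hypothetical helper that operates on precomputed log-probabilities.

```python
import numpy as np

def predict_sketch(log_proba, classes):
    # The most probable class per row is the argmax of the log-probabilities
    # (log is monotonic, so this equals the argmax of the probabilities).
    idx = np.argmax(log_proba, axis=1)
    # Translate internal indices back to the user's labels.
    return classes[idx]

classes = np.array(["cat", "dog"])
log_proba = np.log(np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]))
print(predict_sketch(log_proba, classes))  # ['cat' 'dog' 'cat']
```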

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> predictions = model.predict(X)
>>> predictions.shape
(4,)

predict_proba(X, return_type='log-proba')

Predict class probabilities for samples in X.

Parameters:
  • X (DESIGN_INPUT_TYPE) – The input samples. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.

  • return_type (Literal[‘log-proba’, ‘proba’]) – The format of the returned probabilities. If "log-proba", returns log-probabilities. If "proba", returns probabilities. Defaults to "log-proba".

Return type:

jnp.ndarray

Returns:

Predicted class probabilities, as an array of shape (n_samples, n_classes). For return_type="proba", each row sums to 1. For return_type="log-proba", the logsumexp of each row is 0; equivalently, the exponentiated entries of each row sum to 1.
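Both normalization properties can be checked directly on any array of log-probabilities. The values below are hand-picked for illustration. Note that the plain row sum of log-probabilities is negative, which is why the log-space condition is stated with logsumexp.

```python
import numpy as np

log_proba = np.log(np.array([[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]]))

def logsumexp_rows(a):
    # Stable log(sum(exp(a))) along the class axis.
    m = a.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=1, keepdims=True))).ravel()

print(np.exp(log_proba).sum(axis=1))  # probabilities: each row sums to 1
print(logsumexp_rows(log_proba))      # log-probabilities: row-wise logsumexp is 0
print(log_proba.sum(axis=1))          # the plain row sum is negative, not 0
```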

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)

property regularizer: None | Regularizer

Getter for the regularizer attribute.

property regularizer_strength: Any

Regularizer strength getter.

save_params(filename)

Save GLM model parameters to a .npz file.

This method lets you reuse the model parameters: the saved parameters can be loaded back into a GLM instance with nmo.load_model.

Parameters:

filename (Union[str, Path]) – The name of the file where the model parameters will be saved. The file will be saved in .npz format.

Examples

>>> import nemos as nmo
>>> # Create a GLM model with specified parameters
>>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
>>> model = nmo.glm.GLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1,
...     observation_model="Gamma",
...     solver_name="BFGS",
...     solver_kwargs=solver_args,
... )
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1...
solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
solver_name: BFGS
>>> # Save the model parameters to a file
>>> model.save_params("model_params.npz")
>>> # Load the model from the saved file
>>> model = nmo.load_model("model_params.npz")
>>> # Model has the same parameters before and after load
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1
solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
solver_name: BFGS
>>> # Saving and loading a custom inverse link function
>>> model = nmo.glm.GLM(
...     observation_model="Poisson",
...     inverse_link_function=lambda x: x**2
... )
>>> model.save_params("model_params.npz")
>>> # Provide a mapping for the custom link function when loading.
>>> mapping_dict = {
...     "inverse_link_function": lambda x: x**2,
... }
>>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
>>> # The loaded model uses the custom inverse link function supplied via mapping_dict
>>> for key, value in loaded_model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function <lambda> at ...>
observation_model: PoissonObservations()
regularizer: UnRegularized()
regularizer_strength: None
solver_kwargs: {}
solver_name: LBFGS

score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)

Score the model on test data.

Parameters:
  • X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Test input samples of shape (n_samples, n_features) or a pytree of arrays of the same shape.

  • y (ArrayLike) – True class labels of shape (n_samples,). Labels must be a subset of classes_.

  • score_type (Literal[‘log-likelihood’, ‘pseudo-r2-McFadden’, ‘pseudo-r2-Cohen’]) – The type of score to compute.

  • aggregate_sample_scores (Optional[Callable]) – Function to aggregate per-sample scores.
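For reference, here is a NumPy sketch of two of these scores, computed on precomputed log-probabilities. The first is the mean per-sample log-likelihood, matching the default aggregate_sample_scores (np.mean). The second is the textbook McFadden pseudo-R², 1 - LL(model) / LL(null). We assume the null model predicts the empirical class frequencies; nemos may define its null model, or the Cohen variant, differently.

```python
import numpy as np

def mean_log_likelihood(log_proba, y_idx):
    # Per-sample log-likelihood is the log-probability of the observed class;
    # averaging mirrors aggregate_sample_scores=np.mean.
    return float(np.mean(log_proba[np.arange(len(y_idx)), y_idx]))

def mcfadden_pseudo_r2(log_proba, y_idx, n_classes):
    # Null model (assumption): predict the empirical class frequencies for every sample.
    freqs = np.bincount(y_idx, minlength=n_classes) / len(y_idx)
    null_log_proba = np.tile(np.log(freqs), (len(y_idx), 1))
    return 1.0 - mean_log_likelihood(log_proba, y_idx) / mean_log_likelihood(null_log_proba, y_idx)

y_idx = np.array([0, 0, 1, 1])
log_proba = np.log(np.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.1, 0.9]]))
print(mean_log_likelihood(log_proba, y_idx), mcfadden_pseudo_r2(log_proba, y_idx, 2))
```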

Return type:

jnp.ndarray

Returns:

The computed score.

Notes

All labels in y must be present in classes_. Passing labels not in classes_ will raise an error.

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> score = model.score(X, y)

set_classes(y)

Infer unique class labels and set the classes_ attribute.

This method infers class labels from y and sets up the internal encoding/decoding machinery. When labels are the default [0, 1, ..., n_classes-1], encoding is skipped for performance.

Parameters:

y (ArrayLike) – An array that must contain all the class labels, i.e. len(np.unique(y)) == n_classes.

Raises:

ValueError – If the number of unique class labels in y does not match n_classes.

Return type:

ClassifierMixin

Notes

fit() and initialize_optimizer_and_state() call set_classes internally, making sure that the classes_ attribute matches the provided input. If you are fitting in batches by calling update(), make sure that the classes_ are correctly set by calling set_classes before starting the update() loop.
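The failure mode this guards against is easy to reproduce with plain NumPy. One-hot encoding a batch against a class list inferred from that batch alone yields fewer columns than n_classes. Encoding against a fixed class list, as set_classes provides, keeps the shape consistent across batches. This is an illustrative sketch, not nemos internals.

```python
import numpy as np

classes = np.array([0, 1, 2])          # all labels, fixed once (like set_classes)
y_batch = np.array([0, 1, 0, 1, 0])    # this batch never contains class 2

# Labels must be a subset of the fixed class list.
if not np.isin(y_batch, classes).all():
    raise ValueError("batch contains labels outside the known classes")

# Encoding against the fixed (sorted) class list keeps n_classes columns.
idx = np.searchsorted(classes, y_batch)
onehot_fixed = np.eye(classes.size)[idx]

# Inferring classes from the batch alone silently drops a column.
batch_classes = np.unique(y_batch)
onehot_inferred = np.eye(batch_classes.size)[np.searchsorted(batch_classes, y_batch)]

print(onehot_fixed.shape, onehot_inferred.shape)  # (5, 3) (5, 2)
```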

Examples

When fitting in batches with update(), use set_classes to define all class labels before initialization. This is necessary when individual batches may not contain all classes.

>>> import nemos as nmo
>>> import numpy as np
>>> model = nmo.glm.ClassifierGLM(3)

Generate sample data where the first batch only contains 2 of 3 classes:

>>> X = np.random.randn(100, 5)
>>> y_all_classes = np.array([0, 1, 2])  # all possible classes
>>> y_batch1 = np.array([0, 1, 0, 1, 0])  # first batch missing class 2
>>> X_batch1 = X[:5]

Without set_classes, initialization fails because classes_ has not been defined yet:

>>> init_params = model.initialize_params(X_batch1, y_batch1)
Traceback (most recent call last):
RuntimeError: Classes are not set. Must call ``set_classes`` before calling...

Call set_classes first to define all labels, then initialize:

>>> model.set_classes(y_all_classes)
ClassifierGLM(...)
>>> init_params = model.initialize_params(X_batch1, y_batch1)
>>> state = model.initialize_optimizer_and_state(init_params, X_batch1, y_batch1)

Now batches with any subset of classes work with update():

>>> result = model.update(init_params, state, X_batch1, y_batch1)

set_fit_request(*, init_params='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the fit method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
  • init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for init_params parameter in fit.

  • self (ClassifierGLM)

Returns:

self – The updated object.

Return type:

object

set_params(**params)

Manage warnings in case of multiple parameter settings.

Parameters:

params (Any)

set_predict_proba_request(*, return_type='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the predict_proba method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict_proba if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict_proba.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
  • return_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for return_type parameter in predict_proba.

  • self (ClassifierGLM)

Returns:

self – The updated object.

Return type:

object

set_score_request(*, aggregate_sample_scores='$UNCHANGED$', score_type='$UNCHANGED$')

Configure whether metadata should be requested to be passed to the score method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Parameters:
  • aggregate_sample_scores (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for aggregate_sample_scores parameter in score.

  • score_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for score_type parameter in score.

  • self (ClassifierGLM)

Returns:

self – The updated object.

Return type:

object

simulate(random_key, feedforward_input)

Simulate categorical responses from the model.

Parameters:
  • random_key (jax.Array) – A JAX random key used to generate the simulated responses.

  • feedforward_input (DESIGN_INPUT_TYPE) – The input samples used to generate the responses. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]

Returns:

A tuple (y, log_prob) where:
  • y is an array of shape (n_samples,) containing the simulated class labels (in the same format as classes_).

  • log_prob is an array of shape (n_samples,) containing the log-probability of the simulated responses under the model.
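The NumPy sketch below samples categories with the Gumbel-max trick. Taking the argmax of the log-probabilities plus independent Gumbel noise draws class k with probability softmax(log_proba)[k]. nemos takes a JAX random key, and its exact sampling routine is not stated here. Treat this as an illustration that uses a NumPy Generator instead; simulate_sketch is a hypothetical helper.

```python
import numpy as np

def simulate_sketch(rng, log_proba, classes):
    # Gumbel-max trick: argmax(log p + Gumbel noise) is a draw from Categorical(p).
    gumbel = rng.gumbel(size=log_proba.shape)
    idx = np.argmax(log_proba + gumbel, axis=1)
    # Log-probability of each drawn class under the model.
    log_prob = log_proba[np.arange(log_proba.shape[0]), idx]
    return classes[idx], log_prob

rng = np.random.default_rng(0)
classes = np.array(["cat", "dog"])
log_proba = np.log(np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5], [0.5, 0.5]]))
y_sim, log_prob = simulate_sketch(rng, log_proba, classes)
print(y_sim.shape, log_prob.shape)  # (4,) (4,)
```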

Raises:

RuntimeError – If classes_ has not been set. Call set_classes() or fit() before calling this method.

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> key = jax.random.key(0)
>>> simulated_y, log_prob = model.simulate(key, X)
>>> simulated_y.shape
(4,)

property solver

Getter for the solver class.

property solver_kwargs

Getter for the solver_kwargs attribute.

property solver_name: str

Getter for the solver_name attribute.

property solver_spec: SolverSpec

Getter for the solver specification.

update(params, opt_state, X, y, *args, n_samples=None, **kwargs)

Update the model parameters and solver state.

Performs a single optimization step using the model’s solver. Class labels are automatically encoded to internal indices and converted to one-hot encoding before the update.

Important: Labels of any dtype (integers, floats, strings, etc.) are supported and will be encoded using the classes_ attribute set via set_classes(). For best performance, use integer labels [0, n_classes - 1].
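To make the pieces of a single update concrete, the NumPy sketch below performs one plain gradient-descent step. It minimizes the mean negative log-likelihood of the softmax model plus a ridge penalty on the coefficients. The fixed step size, the penalty form, and the helper names are assumptions for illustration. The actual update is carried out by the model's configured solver.

```python
import numpy as np

def log_softmax(eta):
    m = eta.max(axis=1, keepdims=True)
    return eta - m - np.log(np.exp(eta - m).sum(axis=1, keepdims=True))

def penalized_nll(coef, intercept, X, y_onehot, alpha):
    # Mean negative log-likelihood plus a ridge penalty on the coefficients only (assumed).
    log_proba = log_softmax(X @ coef + intercept)
    return -np.mean(np.sum(y_onehot * log_proba, axis=1)) + 0.5 * alpha * np.sum(coef ** 2)

def gradient_step(coef, intercept, X, y_idx, n_classes, alpha=0.1, lr=0.01):
    # Internal indices -> one-hot targets of shape (n_samples, n_classes).
    y_onehot = np.eye(n_classes)[y_idx]
    proba = np.exp(log_softmax(X @ coef + intercept))
    resid = (proba - y_onehot) / X.shape[0]  # gradient of the mean NLL w.r.t. the linear predictor
    grad_coef = X.T @ resid + alpha * coef
    grad_intercept = resid.sum(axis=0)
    return coef - lr * grad_coef, intercept - lr * grad_intercept

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y_idx = np.array([0, 0, 1, 1])
y_onehot = np.eye(2)[y_idx]
coef, intercept = np.zeros((2, 2)), np.zeros(2)
before = penalized_nll(coef, intercept, X, y_onehot, 0.1)
coef, intercept = gradient_step(coef, intercept, X, y_idx, n_classes=2)
after = penalized_nll(coef, intercept, X, y_onehot, 0.1)
print(before, after)  # the loss decreases after one step
```

The step size here is small because the features are uncentered. A larger fixed step overshoots on this toy data, which is one reason practical solvers use line searches or curvature information.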

Parameters:
  • params (GLMUserParams) – The current model parameters, typically a tuple of coefficients and intercepts.

  • opt_state (SolverState) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state.

  • X (DESIGN_INPUT_TYPE) – The predictors used in the model fitting process. Shape (n_time_bins, n_features) or a pytree of arrays of the same shape.

  • y (jnp.ndarray) – Class labels, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models. Labels must be a subset of classes_.

  • *args – Additional positional arguments to be passed to the solver’s update method.

  • n_samples (Optional[int]) – The total number of samples in the full dataset, which is usually larger than the size of an individual batch. Used to estimate the scale parameter of the GLM.

  • **kwargs – Additional keyword arguments to be passed to the solver’s update method.

Return type:

StepResult

Returns:

  • params – Updated model parameters (coefficients, intercepts).

  • state – Updated optimizer state.

Examples

>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_state = model.update(params, opt_state, X, y)