nemos.glm.ClassifierGLM
- class nemos.glm.ClassifierGLM(n_classes=2, inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)
Bases: ClassifierMixin, GLM
Generalized Linear Model for multi-class classification.
This model predicts discrete class labels from input features using a softmax (multinomial logistic) model. It uses an over-parameterized representation with one set of coefficients per class, resulting in a coefficient array of shape (n_features, n_classes) and an intercept of shape (n_classes,).
- Parameters:
n_classes (Optional[int]) – The number of classes. Must be >= 2.
inverse_link_function (Optional[Callable]) – The inverse link function. Default is log_softmax.
regularizer (Union[str, Regularizer, None]) – The regularization scheme. Default is Ridge. Note that the model is over-parameterized, with one set of coefficients for each class; regularization makes the parameters identifiable. Setting UnRegularized will result in non-identifiable coefficients (see the Identifiability note below).
regularizer_strength (Any) – The strength of the regularization.
solver_name (str) – The solver to use for optimization.
solver_kwargs (dict) – Additional keyword arguments for the solver.
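To make these shapes concrete, the NumPy sketch below computes the linear predictor and the softmax probabilities by hand. It assumes that the linear predictor is X @ coef + intercept and that predicted labels are the most probable class. `log_softmax` is a small helper written for this illustration, not the nemos implementation.

```python
import numpy as np


def log_softmax(eta):
    """Numerically stable log-softmax along the class axis (axis=1)."""
    shifted = eta - eta.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 5, 2, 3
X = rng.normal(size=(n_samples, n_features))
coef = rng.normal(size=(n_features, n_classes))  # one column of weights per class
intercept = rng.normal(size=(n_classes,))  # one intercept per class

eta = X @ coef + intercept  # linear predictor, shape (n_samples, n_classes)
log_proba = log_softmax(eta)  # counterpart of return_type="log-proba"
proba = np.exp(log_proba)  # counterpart of return_type="proba"
labels = log_proba.argmax(axis=1)  # assumed decision rule: most probable class
```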
Notes
Identifiability
This model uses an over-parameterized (symmetric) representation where each class has its own set of coefficients. Because softmax probabilities are unchanged when the same constant is added to the linear predictors of all classes, the parameters are not uniquely identifiable without regularization. For example, if (coef, intercept) is a solution, so is (coef + c, intercept + c) for any constant c.
Using regularization (default is Ridge) resolves this ambiguity by penalizing the parameter magnitudes, effectively centering the solution. If you use UnRegularized, the optimization may converge to different equivalent solutions depending on initialization, though predictions will be identical.
Class Labels
The target array y can contain any hashable class labels that can be stored in a NumPy array, including integers, strings, or other hashable types. The model internally maps these labels to indices [0, n_classes - 1] for computation and maps them back when returning predictions.
Performance Considerations
For optimal performance, use integer labels [0, 1, ..., n_classes - 1]. When labels follow this convention, the model skips the encoding/decoding steps entirely. Using other label formats (e.g., ["cat", "dog"] or [5, 10, 15]) incurs a small overhead for label translation.
Setting Class Labels
The fit() and initialize_optimizer_and_state() methods automatically infer class labels from the provided y. If you set coef_ and intercept_ manually, you must call set_classes() before using predict(), predict_proba(), simulate(), score(), or compute_loss().
See also
ClassifierPopulationGLM – Multi-class classification for multiple neurons.
GLM – Generalized Linear Model for continuous/count responses.
Examples
Fit a ClassifierGLM
Basic binary classification:
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> model.coef_.shape
(2, 2)
Predict Class Labels
Get predicted class labels:
>>> predictions = model.predict(X)
>>> predictions.shape
(4,)
Predict Class Probabilities
Get class probabilities or log-probabilities:
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)
>>> log_proba = model.predict_proba(X, return_type="log-proba")
>>> log_proba.shape
(4, 2)
Use String Labels
Class labels can be strings or any hashable type:
>>> y_str = np.array(["cat", "cat", "dog", "dog"])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y_str)
>>> model.classes_
array(['cat', 'dog'], dtype='<U3')
>>> model.predict(X)
array(['cat', 'cat', 'dog', 'dog'], dtype='<U3')
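The label translation described in the Class Labels note can be pictured with `np.unique`. That function returns the sorted unique labels and, with `return_inverse=True`, each sample's index into them. This is a sketch of the idea only; nemos's internal encoder may work differently.

```python
import numpy as np

y_str = np.array(["cat", "cat", "dog", "dog"])

# Sorted unique labels plus, for each sample, the index of its label in that array.
classes, y_idx = np.unique(y_str, return_inverse=True)

# Decoding maps internal indices back to the original labels.
decoded = classes[y_idx]
```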
Multi-class Classification
Classify into more than two classes:
>>> X = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], [6.0, 7.0]])
>>> y = jnp.array([0, 0, 1, 1, 2, 2])
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X, y)
>>> model.coef_.shape
(2, 3)
Use Regularization
Change regularization strength:
>>> model = nmo.glm.ClassifierGLM(
...     n_classes=2,
...     regularizer="Ridge",
...     regularizer_strength=0.5
... )
>>> model.regularizer
Ridge()
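The Identifiability note can be checked numerically. In this NumPy sketch, which is independent of nemos, shifting every class's parameters by the same amount leaves the log-probabilities unchanged. Among all such equivalent coefficient matrices, the sum of squares is smallest when each feature row is centered across classes. `ridge_penalty` is a hypothetical helper that penalizes only `coef`; whether nemos's Ridge penalty also covers the intercept is not assumed.

```python
import numpy as np


def log_softmax(eta):
    shifted = eta - eta.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


rng = np.random.default_rng(1)
X = rng.normal(size=(6, 2))
coef = rng.normal(size=(2, 3))
intercept = rng.normal(size=(3,))

# Equivalent parameters: add a constant c everywhere, plus an arbitrary
# per-feature vector v to every class column of coef.
c = 2.5
v = np.array([0.7, -1.3])
coef_shifted = coef + c + v[:, None]
intercept_shifted = intercept + c

log_proba = log_softmax(X @ coef + intercept)
log_proba_shifted = log_softmax(X @ coef_shifted + intercept_shifted)


def ridge_penalty(w):
    """Sum of squares of coef + w, where w shifts each feature row uniformly."""
    return np.sum((coef + w[:, None]) ** 2)


# Over this family of equivalent solutions, the penalty is minimized by
# centering each coefficient row across classes.
w_best = -coef.mean(axis=1)
coef_centered = coef + w_best[:, None]
```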
Use a Pytree of arrays as Input
Features can be passed as any JAX pytree of 2-D arrays; the fitted coef_ will share the same pytree structure:
>>> X_dict = {"feature_1": X[:, :1], "feature_2": X[:, 1:]}
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X_dict, y)
>>> # The coefficient structure matches the input
>>> type(model.coef_)
<class 'dict'>
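One reading of "the fitted coef_ will share the same pytree structure" is that each input leaf gets a matching coefficient leaf, and the per-leaf contributions are summed into one linear predictor. The NumPy sketch below assumes that reading, which is not taken from the nemos source, and checks that it agrees with the equivalent flat computation.

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(6, 2))
coef = rng.normal(size=(2, 3))
intercept = rng.normal(size=(3,))

# Split features and coefficients into matching dictionaries (pytree leaves).
X_dict = {"feature_1": X[:, :1], "feature_2": X[:, 1:]}
coef_dict = {"feature_1": coef[:1], "feature_2": coef[1:]}

# Assumption: each leaf contributes X_leaf @ coef_leaf and contributions are summed.
eta_tree = sum(X_dict[k] @ coef_dict[k] for k in X_dict) + intercept
eta_flat = X @ coef + intercept
```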
Attributes
classes_ – Class labels, or None if not set.
inverse_link_function – Inverse link function mapping the linear predictor to the response space.
n_classes – Number of classes.
observation_model – The observation model governing the conditional distribution of y.
optimizer_init_state – Provides the initialization function for the optimizer state.
optimizer_run – Provides the function to execute the optimization process.
optimizer_update – Provides the function for updating the state during the optimization process.
regularizer – Getter for the regularizer attribute.
regularizer_strength – Regularizer strength getter.
solver – Getter for the solver class.
solver_kwargs – Getter for the solver_kwargs attribute.
solver_name – Getter for the solver_name attribute.
solver_spec – Getter for the solver specification.
- __init__(n_classes=2, inverse_link_function=None, regularizer=None, regularizer_strength=None, solver_name=None, solver_kwargs=None)
Methods
__init__([n_classes, inverse_link_function, ...])
compute_loss(params, X, y, *args, **kwargs) – Compute the loss function for the model.
fit(X, y[, init_params]) – Fit the model to training data.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
initialize_optimizer_and_state(init_params, X, y) – Initialize the solver and its state for running fit and update.
initialize_params(X, y) – Initialize model parameters for categorical GLM.
predict(X) – Predict class labels for samples in X.
predict_proba(X[, return_type]) – Predict class probabilities for samples in X.
save_params(filename) – Save GLM model parameters to a .npz file.
score(X, y[, score_type, ...]) – Score the model on test data.
set_classes(y) – Infer unique class labels and set the classes_ attribute.
set_params(**params) – Manage warnings in case of multiple parameter settings.
simulate(random_key, feedforward_input) – Simulate categorical responses from the model.
update(params, opt_state, X, y, *args[, ...]) – Update the model parameters and solver state.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values which are set using __metadata_request__* class attributes, or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or if the developer would like to specify a request value for metadata that differs from the default None.
References
[1] PEP 487 – Simpler customisation of class creation.
- __repr__()
Representation of the GLM class.
- __sklearn_tags__()
Return GLM specific estimator tags.
- compute_loss(params, X, y, *args, **kwargs)
Compute the loss function for the model.
This method validates inputs, encodes class labels to internal indices, and computes the loss (negative log-likelihood).
- Parameters:
params – Parameter tuple of (coefficients, intercept).
X – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y – Target class labels in the same format as classes_.
*args – Additional positional arguments passed to the model-specific loss function.
**kwargs – Additional keyword arguments passed to the model-specific loss function.
- Returns:
loss – The loss value (negative log-likelihood).
- Raises:
RuntimeError – If classes_ has not been set.
ValueError – If inputs or parameters have incompatible shapes or invalid values.
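For intuition, a hand-written version of this loss for integer-encoded labels is sketched below in NumPy. It averages the negative log-probability of the observed class over samples; whether `compute_loss` averages or sums is an assumption here. With all-zero parameters each of the two classes gets probability 0.5, so the loss equals log 2.

```python
import numpy as np


def log_softmax(eta):
    shifted = eta - eta.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def categorical_nll(coef, intercept, X, y_idx):
    """Mean negative log-likelihood for integer labels in [0, n_classes - 1]."""
    log_proba = log_softmax(X @ coef + intercept)
    # Select each sample's log-probability of its observed class.
    return -np.mean(log_proba[np.arange(len(y_idx)), y_idx])


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y_idx = np.array([0, 0, 1, 1])
coef = np.zeros((2, 2))
intercept = np.zeros(2)
loss = categorical_nll(coef, intercept, X, y_idx)
```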
- fit(X, y, init_params=None)
Fit the model to training data.
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Training input samples of shape (n_samples, n_features) or a pytree of arrays of the same shape.
y (ArrayLike) – Target class labels of shape (n_samples,). Labels can be any hashable type (integers, strings, etc.). Float arrays with integer values are accepted and converted automatically.
init_params (Optional[GLMUserParams]) – Initial parameter values as a tuple of (coef, intercept). If None, parameters are initialized automatically.
- Returns:
The fitted model.
Notes
fit calls set_classes() internally, so classes_ is always consistent with the labels in y.
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model = model.fit(X, y)
>>> model.coef_.shape
(2, 2)
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Return type:
dict
- Returns:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
- initialize_optimizer_and_state(init_params, X, y)
Initialize the solver and its state for running fit and update.
This method must be called before using update() for iterative optimization. It sets up the solver with the provided initial parameters and data.
- Parameters:
init_params (UserProvidedParamsT) – Initial parameter tuple of (coefficients, intercept).
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Target labels, array of shape (n_time_bins,) for single neuron/subject models or (n_time_bins, n_neurons) for population models.
- Returns:
Initial solver state.
- Return type:
SolverState
- Raises:
ValueError – If inputs or parameters have incompatible shapes or invalid values.
- initialize_params(X, y)
Initialize model parameters for categorical GLM.
Initialize the coefficients to zero and set the intercept so that the predicted class probabilities match the observed class proportions. Class labels are automatically converted to one-hot encoding.
- Parameters:
X (DESIGN_INPUT_TYPE) – Input data, array of shape (n_time_bins, n_features) or pytree of same.
y (jnp.ndarray) – Class labels, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models. Labels must be a subset of classes_.
- Return type:
UserProvidedParamsT
- Returns:
Initial parameter tuple of (coefficients, intercept).
Notes
All labels in y must be present in classes_. Passing labels not in classes_ will raise an error.
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> coef, intercept = model.initialize_params(X, y)
>>> coef.shape
(2, 2)
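Matching the class proportions can be made concrete. With zero coefficients, the predicted probabilities are the softmax of the intercept alone, so setting the intercept to the log class proportions reproduces them. This is a NumPy sketch assuming integer-encoded labels. The exact formula nemos uses is not confirmed here, including whether it centers the intercept or how it handles classes absent from y.

```python
import numpy as np

y_idx = np.array([0, 0, 0, 1, 2, 2])
n_classes = 3
proportions = np.bincount(y_idx, minlength=n_classes) / len(y_idx)

# With zero coefficients the linear predictor equals the intercept, so
# intercept = log(proportions) gives softmax(intercept) == proportions.
intercept = np.log(proportions)
# Subtracting the mean leaves the softmax unchanged (see the Identifiability note).
intercept_centered = intercept - intercept.mean()

softmax = np.exp(intercept_centered) / np.exp(intercept_centered).sum()
```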
- property inverse_link_function
Inverse link function mapping the linear predictor to the response space.
Always a callable. If None was passed at construction time, this is resolved to the observation model's default (e.g. jnp.exp for Poisson, 1 / x for Gamma, jax.nn.sigmoid for Bernoulli; for ClassifierGLM the default is log_softmax).
- property n_classes
Number of classes.
- property observation_model: None | Observations
The observation model governing the conditional distribution of y.
Always an instance of an Observations subclass. If a string alias was passed at construction time, it is resolved to the corresponding instance here.
- property optimizer_init_state: None | Callable[[Any, Array, Array], SolverState]
Provides the initialization function for the optimizer state.
This function is responsible for initializing the optimizer state, necessary for the start of the optimizer process. It sets up initial values for parameters like gradients and step sizes based on the model configuration and input data.
- Returns:
The function to initialize the optimizer state, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_run: None | Callable[[Any, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function to execute the optimization process.
This function runs the optimizer using the initialized parameters and state, performing the optimization to fit the model to the data. It iteratively updates the model parameters until a stopping criterion is met, such as convergence or exceeding a maximum number of iterations.
- Returns:
The function to run the optimization process, if available; otherwise, None if the optimizer has not yet been instantiated.
- property optimizer_update: None | Callable[[Any, NamedTuple, Array, Array], Tuple[Any, SolverState, Aux]]
Provides the function for updating the state during the optimization process.
This function is used to perform a single update step in the optimization process. It updates the model’s parameters based on the current state, data, and gradients. It is typically used in scenarios where fine-grained control over each optimizer step is necessary, such as in online learning or complex optimization scenarios.
- Returns:
The function to perform a single optimization update step, if available; otherwise, None if the optimizer has not yet been instantiated.
- predict(X)
Predict class labels for samples in X.
- Parameters:
X (DESIGN_INPUT_TYPE) – The input samples. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.
- Return type:
jnp.ndarray
- Returns:
Predicted class labels for each sample, as an array of shape (n_samples,). Labels are returned in the same format as classes_; with the default integer labels, values lie in [0, n_classes - 1].
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> predictions = model.predict(X)
>>> predictions.shape
(4,)
- predict_proba(X, return_type='log-proba')
Predict class probabilities for samples in X.
- Parameters:
X (DESIGN_INPUT_TYPE) – The input samples. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.
return_type (Literal['log-proba', 'proba']) – The format of the returned probabilities. If "log-proba", returns log-probabilities. If "proba", returns probabilities. Defaults to "log-proba".
- Return type:
jnp.ndarray
- Returns:
Predicted class probabilities, as an array of shape (n_samples, n_classes). For probabilities, each row sums to 1. For log-probabilities, each row has a log-sum-exp of 0; that is, exponentiating a row gives values that sum to 1.
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)
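The two return types are normalized differently. Probabilities sum to 1 across classes. Log-probabilities satisfy log(sum(exp(log_proba))) = 0 along each row, while their plain sum is negative. Here is a NumPy check using a hand-written log-softmax (not the nemos implementation):

```python
import numpy as np


def log_softmax(eta):
    shifted = eta - eta.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


eta = np.array([[2.0, -1.0], [0.0, 0.0], [-3.0, 4.0]])
log_proba = log_softmax(eta)

# Log-probabilities are normalized in the log-sum-exp sense, not by their sum.
row_logsumexp = np.log(np.exp(log_proba).sum(axis=1))
row_sum = log_proba.sum(axis=1)
```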
- property regularizer: None | Regularizer
Getter for the regularizer attribute.
- save_params(filename)
Save GLM model parameters to a .npz file.
This method allows the model parameters to be reused. The saved parameters can be loaded back into a GLM instance using nmo.load_model().
- Parameters:
filename (Union[str, Path]) – The name of the file where the model parameters will be saved. The file will be saved in .npz format.
Examples
>>> import nemos as nmo
>>> # Create a GLM model with specified parameters
>>> solver_args = {"stepsize": 0.1, "maxiter": 1000, "tol": 1e-6}
>>> model = nmo.glm.GLM(
...     regularizer="Ridge",
...     regularizer_strength=0.1,
...     observation_model="Gamma",
...     solver_name="BFGS",
...     solver_kwargs=solver_args,
... )
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1...
solver_kwargs: {'stepsize': 0.1, 'maxiter': 1000, 'tol': 1e-06}
solver_name: BFGS
>>> # Save the model parameters to a file
>>> model.save_params("model_params.npz")
>>> # Load the model from the saved file
>>> model = nmo.load_model("model_params.npz")
>>> # Model has the same parameters before and after load
>>> for key, value in model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function one_over_x at ...>
observation_model: GammaObservations()
regularizer: Ridge()
regularizer_strength: 0.1
solver_kwargs: {'maxiter': 1000, 'stepsize': 0.1, 'tol': 1e-06}
solver_name: BFGS
>>> # Saving and loading a custom inverse link function
>>> model = nmo.glm.GLM(
...     observation_model="Poisson",
...     inverse_link_function=lambda x: x**2
... )
>>> model.save_params("model_params.npz")
>>> # Provide a mapping for the custom link function when loading.
>>> mapping_dict = {
...     "inverse_link_function": lambda x: x**2,
... }
>>> loaded_model = nmo.load_model("model_params.npz", mapping_dict=mapping_dict)
>>> # The loaded model now uses the custom inverse link function from the mapping
>>> for key, value in loaded_model.get_params().items():
...     print(f"{key}: {value}")
inverse_link_function: <function <lambda> at ...>
observation_model: PoissonObservations()
regularizer: UnRegularized()
regularizer_strength: None
solver_kwargs: {}
solver_name: LBFGS
- score(X, y, score_type='log-likelihood', aggregate_sample_scores=<function mean>)
Score the model on test data.
- Parameters:
X (Union[DESIGN_INPUT_TYPE, ArrayLike]) – Test input samples of shape (n_samples, n_features) or a pytree of arrays of the same shape.
y (ArrayLike) – True class labels of shape (n_samples,). Labels must be a subset of classes_.
score_type (Literal['log-likelihood', 'pseudo-r2-McFadden', 'pseudo-r2-Cohen']) – The type of score to compute.
aggregate_sample_scores (Optional[Callable]) – Function to aggregate per-sample scores.
- Return type:
jnp.ndarray
- Returns:
The computed score.
Notes
All labels in y must be present in classes_. Passing labels not in classes_ will raise an error.
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> score = model.score(X, y)
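For reference, McFadden's pseudo-R² compares the model's log-likelihood with that of a null model: R² = 1 - log L_model / log L_null. The NumPy sketch below uses made-up probabilities. It assumes the null model predicts the empirical class proportions for every sample; how nemos defines the null model for score_type='pseudo-r2-McFadden' is not confirmed here.

```python
import numpy as np


def mean_log_likelihood(proba, y_idx):
    """Average log-probability assigned to the observed classes."""
    return np.mean(np.log(proba[np.arange(len(y_idx)), y_idx]))


y_idx = np.array([0, 0, 1, 1])
# Made-up fitted-model probabilities for the four samples.
proba_model = np.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.1, 0.9]])
# Null model: every sample gets the empirical class proportions.
proportions = np.bincount(y_idx, minlength=2) / len(y_idx)
proba_null = np.tile(proportions, (len(y_idx), 1))

ll_model = mean_log_likelihood(proba_model, y_idx)
ll_null = mean_log_likelihood(proba_null, y_idx)
pseudo_r2_mcfadden = 1.0 - ll_model / ll_null
```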
- set_classes(y)
Infer unique class labels and set the classes_ attribute.
This method infers class labels from y and sets up the internal encoding/decoding machinery. When labels are the default [0, 1, ..., n_classes - 1], encoding is skipped for performance.
- Parameters:
y (ArrayLike) – An array that must contain all the class labels, i.e. len(np.unique(y)) == n_classes.
- Raises:
ValueError – If the number of unique class labels in y does not match n_classes.
- Return type:
ClassifierMixin
Notes
fit() and initialize_optimizer_and_state() call set_classes internally, making sure that the classes_ attribute matches the provided input. If you are fitting in batches by calling update(), make sure that classes_ is correctly set by calling set_classes before starting the update() loop.
Examples
When fitting in batches with update(), use set_classes to define all class labels before initialization. This is necessary when individual batches may not contain all classes.
>>> import nemos as nmo
>>> import numpy as np
>>> model = nmo.glm.ClassifierGLM(3)
Generate sample data where the first batch only contains 2 of 3 classes:
>>> X = np.random.randn(100, 5)
>>> y_all_classes = np.array([0, 1, 2])  # all possible classes
>>> y_batch1 = np.array([0, 1, 0, 1, 0])  # first batch missing class 2
>>> X_batch1 = X[:5]
Without set_classes, initialization fails because classes_ has not been set (and it cannot be inferred from a batch that lacks class 2):
>>> init_params = model.initialize_params(X_batch1, y_batch1)
Traceback (most recent call last):
RuntimeError: Classes are not set. Must call ``set_classes`` before calling...
Call set_classes first to define all labels, then initialize:
>>> model.set_classes(y_all_classes)
ClassifierGLM(...)
>>> init_params = model.initialize_params(X_batch1, y_batch1)
>>> state = model.initialize_optimizer_and_state(init_params, X_batch1, y_batch1)
Now batches with any subset of classes work with update():
>>> result = model.update(init_params, state, X_batch1, y_batch1)
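Fixing classes_ up front matters because every batch must be encoded against the same label-to-index mapping. That way the one-hot columns line up even when a batch lacks a class. Below is a NumPy sketch of the idea. `encode` is a hypothetical helper, not the nemos internals, and it assumes `classes` is sorted so that `np.searchsorted` applies.

```python
import numpy as np

# All labels the model should know about, fixed before the update loop (sorted).
classes = np.array([0, 1, 2])


def encode(batch_labels, classes):
    """Map labels to indices in the sorted `classes`, rejecting unknown labels."""
    idx = np.searchsorted(classes, batch_labels)
    idx = np.clip(idx, 0, len(classes) - 1)
    if not np.all(classes[idx] == batch_labels):
        raise ValueError("batch contains labels not in classes")
    return idx


y_batch1 = np.array([0, 1, 0, 1, 0])  # class 2 is absent from this batch
y_batch2 = np.array([2, 2, 1])

# One-hot columns always follow the order of `classes`, whatever the batch holds.
one_hot1 = np.eye(len(classes))[encode(y_batch1, classes)]
one_hot2 = np.eye(len(classes))[encode(y_batch2, classes)]
```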
- set_fit_request(*, init_params='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
init_params (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for init_params parameter in fit.
self (ClassifierGLM)
- Returns:
self – The updated object.
- Return type:
ClassifierGLM
- set_params(**params)
Manage warnings in case of multiple parameter settings.
- Parameters:
params (Any)
- set_predict_proba_request(*, return_type='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the predict_proba method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to predict_proba if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to predict_proba.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
return_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for return_type parameter in predict_proba.
self (ClassifierGLM)
- Returns:
self – The updated object.
- Return type:
ClassifierGLM
- set_score_request(*, aggregate_sample_scores='$UNCHANGED$', score_type='$UNCHANGED$')
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
aggregate_sample_scores (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for aggregate_sample_scores parameter in score.
score_type (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for score_type parameter in score.
self (ClassifierGLM)
- Returns:
self – The updated object.
- Return type:
ClassifierGLM
- simulate(random_key, feedforward_input)
Simulate categorical responses from the model.
- Parameters:
random_key (jax.Array) – A JAX random key used to generate the simulated responses.
feedforward_input (DESIGN_INPUT_TYPE) – The input samples used to generate the responses. Can be an array of shape (n_samples, n_features) or a pytree of arrays of the same shape.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray]
- Returns:
A tuple (y, log_prob) where:
y is an array of shape (n_samples,) containing the simulated class labels (in the same format as classes_).
log_prob is an array of shape (n_samples,) containing the log-probability of the simulated responses under the model.
- Raises:
RuntimeError – If classes_ has not been set. Call set_classes() or fit() before calling this method.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> key = jax.random.key(0)
>>> simulated_y, log_prob = model.simulate(key, X)
>>> simulated_y.shape
(4,)
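Conceptually, simulation draws one label per sample from the predicted class probabilities and records the log-probability of that draw. The NumPy sketch below uses inverse-CDF sampling with a hypothetical helper, `sample_categorical`. nemos draws with a JAX random key, and its sampling routine may differ.

```python
import numpy as np


def sample_categorical(rng, proba, classes):
    """Draw one label per row of `proba` by inverting the row-wise CDF."""
    cdf = np.cumsum(proba, axis=1)
    u = rng.uniform(size=(proba.shape[0], 1))
    # Index of the first class whose cumulative probability reaches u.
    idx = (u > cdf).sum(axis=1)
    idx = np.minimum(idx, proba.shape[1] - 1)  # guard against round-off at the top
    log_prob = np.log(proba[np.arange(proba.shape[0]), idx])
    return classes[idx], log_prob


rng = np.random.default_rng(0)
classes = np.array(["cat", "dog"])
proba = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8], [0.0, 1.0]])
y_sim, log_prob = sample_categorical(rng, proba, classes)
```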
- property solver
Getter for the solver class.
- property solver_kwargs
Getter for the solver_kwargs attribute.
- property solver_spec: SolverSpec
Getter for the solver specification.
- update(params, opt_state, X, y, *args, n_samples=None, **kwargs)
Update the model parameters and solver state.
Performs a single optimization step using the model’s solver. Class labels are automatically encoded to internal indices and converted to one-hot encoding before the update.
Important: Labels of any dtype (integers, floats, strings, etc.) are supported and will be encoded using the classes_ attribute set via set_classes(). For best performance, use integer labels [0, n_classes - 1].
- Parameters:
params (GLMUserParams) – The current model parameters, typically a tuple of coefficients and intercepts.
opt_state (SolverState) – The current state of the optimizer, encapsulating information necessary for the optimization algorithm to continue from the current state.
X (DESIGN_INPUT_TYPE) – The predictors used in the model fitting process. Shape (n_time_bins, n_features) or a pytree of arrays of the same shape.
y (jnp.ndarray) – Class labels, array of shape (n_time_bins,) for single neuron models or (n_time_bins, n_neurons) for population models. Labels must match those defined in classes_.
*args – Additional positional arguments to be passed to the solver's update method.
n_samples (Optional[int]) – The total number of samples. Usually larger than the number of samples in an individual batch; used to estimate the scale parameter of the GLM.
**kwargs – Additional keyword arguments to be passed to the solver's update method.
- Return type:
StepResult
- Returns:
params – Updated model parameters (coefficients, intercepts).
state – Updated optimizer state.
Examples
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_state = model.update(params, opt_state, X, y)
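To show what a single step might look like, the NumPy sketch below performs one plain gradient-descent update on the mean categorical negative log-likelihood, with labels converted to one-hot as described above. Gradient descent is a stand-in chosen for illustration; the actual step depends on the configured solver. `gradient_step` and `mean_nll` are hypothetical helpers.

```python
import numpy as np


def softmax(eta):
    e = np.exp(eta - eta.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def mean_nll(coef, intercept, X, y_one_hot):
    """Mean categorical negative log-likelihood."""
    proba = softmax(X @ coef + intercept)
    return -np.mean(np.sum(y_one_hot * np.log(proba), axis=1))


def gradient_step(coef, intercept, X, y_one_hot, stepsize=0.01, ridge=0.0):
    """One gradient-descent step on mean_nll plus (ridge / 2) * ||coef||^2."""
    n = X.shape[0]
    resid = softmax(X @ coef + intercept) - y_one_hot  # shape (n, n_classes)
    grad_coef = X.T @ resid / n + ridge * coef
    grad_intercept = resid.mean(axis=0)
    return coef - stepsize * grad_coef, intercept - stepsize * grad_intercept


X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
y_one_hot = np.eye(2)[np.array([0, 0, 1, 1])]  # labels converted to one-hot
coef, intercept = np.zeros((2, 2)), np.zeros(2)

new_coef, new_intercept = gradient_step(coef, intercept, X, y_one_hot)
```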