nemos.glm.PopulationGLM.update
- PopulationGLM.update(params, opt_state, X, y, *args, n_samples=None, **kwargs)
Update the model parameters and solver state.
This method performs a single optimization step using the model’s current solver. It updates the model’s coefficients and intercept based on the provided parameters, predictors (X), responses (y), and the current optimization state. It is useful for iterative model fitting where parameters must be updated incrementally, such as online learning or datasets too large to fit into memory at once.
- Parameters:
params (Tuple[Array, Array]) – The current model parameters, typically a tuple of coefficients and intercepts.
opt_state (NamedTuple) – The current state of the optimizer, encapsulating the information the optimization algorithm needs to continue from the current state. This includes gradients, step sizes, and other optimizer-specific metrics.
X (Union[Array, FeaturePytree]) – The predictors used in the model fitting process, which may be feature matrices or nemos.pytrees.FeaturePytree objects.
y (Array) – The response variable or output data corresponding to the predictors, used in the model fitting process.
*args – Additional positional arguments to be passed to the solver’s update method.
n_samples (Optional[int]) – The total number of samples. This is usually larger than the number of samples in an individual batch; n_samples is used to estimate the scale parameter of the GLM.
**kwargs – Additional keyword arguments to be passed to the solver’s update method.
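The n_samples description ties the sample count to the GLM scale parameter. To show where a sample count typically enters such an estimate, here is a sketch of the standard Pearson estimator, in which n appears in the residual degrees of freedom n − p. This textbook estimator is assumed for illustration only; it is not necessarily the exact formula NeMoS uses, and `pearson_scale` and its `variance_fn` argument are hypothetical names.

```python
import numpy as np


def pearson_scale(y, mu, variance_fn, n_params):
    """Hypothetical helper: textbook Pearson estimate of a GLM scale.

    scale = sum((y - mu)**2 / V(mu)) / (n - p), computed per column,
    where n is the number of rows of y and p = n_params.
    """
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    dof_resid = y.shape[0] - n_params  # residual degrees of freedom
    if dof_resid <= 0:
        raise ValueError("need more samples than fitted parameters")
    pearson_sq = (y - mu) ** 2 / variance_fn(mu)  # squared Pearson residuals
    return pearson_sq.sum(axis=0) / dof_resid


# Gamma-style variance V(mu) = mu**2 on made-up numbers.
y = np.array([2.0, 4.0, 6.0])
mu = np.array([2.0, 2.0, 2.0])
print(pearson_scale(y, mu, lambda m: m**2, n_params=1))  # 2.5
```

Because n sits in the denominator, the sample count supplied to the estimator directly changes the resulting scale.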
- Returns:
A tuple containing the updated parameters and optimization state. This tuple is typically used to continue the optimization process in subsequent steps.
- Return type:
jaxopt.OptStep
- Raises:
ValueError – If the solver has not been instantiated or if the solver returns NaN values indicating an invalid update step, typically due to numerical instabilities or inappropriate solver configurations.
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=(10, 3))
>>> glm_instance = nmo.glm.PopulationGLM().fit(X, y)
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state_
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)