nemos.solvers.AbstractSolver
- class nemos.solvers.AbstractSolver(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, **solver_init_kwargs)
Base class defining the interface for solvers that can be used by BaseRegressor.
All solver parameters (e.g. tolerance, number of steps) are passed to __init__; the other methods only take the parameters, the solver state, and the positional arguments of the objective function.
- Parameters:
unregularized_loss (Callable)
regularizer (Regularizer)
regularizer_strength (float | None)
has_aux (bool)
init_params (Params | None)
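To make the interface contract concrete, the sketch below is a hypothetical mirror (SolverInterface) of the method set documented on this page. It is written with the standard-library abc module and is not the actual nemos source. It illustrates that solver settings go to __init__, while the remaining methods receive only parameters, state, and the objective's positional arguments. It also shows that a subclass missing any of these methods cannot be instantiated.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable


class SolverInterface(ABC):
    """Hypothetical mirror of the method set documented for AbstractSolver."""

    @abstractmethod
    def __init__(self, unregularized_loss: Callable, regularizer: Any,
                 regularizer_strength: float | None, has_aux: bool,
                 init_params: Any = None, **solver_init_kwargs: Any) -> None:
        """Store loss, regularizer, and all solver settings (tol, steps, ...)."""

    @classmethod
    @abstractmethod
    def get_accepted_arguments(cls) -> set[str]:
        """Names of the arguments that __init__ accepts."""

    @abstractmethod
    def init_state(self, init_params: Any, *args: Any) -> Any:
        """Build the solver state; *args are the objective's positional args."""

    @abstractmethod
    def update(self, params: Any, state: Any, *args: Any) -> Any:
        """Perform one optimization step (return type not specified here)."""

    @abstractmethod
    def run(self, init_params: Any, *args: Any) -> Any:
        """Optimize until a stopping criterion is reached."""


class IncompleteSolver(SolverInterface):
    # Only __init__ is implemented, so the class stays abstract.
    def __init__(self, unregularized_loss, regularizer, regularizer_strength,
                 has_aux, init_params=None, **solver_init_kwargs):
        self.tol = solver_init_kwargs.get("tol", 1e-6)


try:
    IncompleteSolver(lambda params: 0.0, None, None, False)
except TypeError as err:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    print("cannot instantiate:", err)
```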
- abstractmethod __init__(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, **solver_init_kwargs)
Create the solver.
- Parameters:
unregularized_loss (Callable) – Unregularized loss function. Currently, this is BaseRegressor.compute_loss.
regularizer (Regularizer) – Regularizer object from which the penalized loss or the proximal operator is obtained.
has_aux (bool) – Whether unregularized_loss returns auxiliary variables. If False, the loss function is expected to return a single scalar. If True, the loss is expected to return a tuple of length 2 containing a scalar and the auxiliary variables.
init_params (Any | None) – Initial model parameters. Passed to the regularizer’s get_proximal_operator or penalized_loss.
**solver_init_kwargs – Keyword arguments modifying the solver’s behavior.
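The has_aux flag only changes the shape of what the loss returns. JAX uses the same convention for the has_aux argument of jax.value_and_grad. The sketch below uses a hypothetical helper, loss_value, to show how a solver might extract the scalar it minimizes in both cases. The helper and the toy losses are illustrative and are not part of nemos.

```python
from typing import Any, Callable


def loss_value(loss: Callable[..., Any], has_aux: bool, *args: Any) -> float:
    """Return the scalar objective, discarding auxiliary outputs if present."""
    out = loss(*args)
    if has_aux:
        # Expected: a tuple of length 2, (scalar, auxiliary variables).
        value, _aux = out
        return value
    # Expected: a single scalar.
    return out


def plain_loss(w: float, x: float, y: float) -> float:
    return (w * x - y) ** 2


def loss_with_aux(w: float, x: float, y: float) -> tuple[float, dict]:
    residual = w * x - y
    # The auxiliary output here is a dict holding the residual.
    return residual ** 2, {"residual": residual}


print(loss_value(plain_loss, False, 2.0, 1.5, 2.0))    # 1.0
print(loss_value(loss_with_aux, True, 2.0, 1.5, 2.0))  # 1.0
```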
Methods
__init__(unregularized_loss, regularizer, ...) – Create the solver.
get_accepted_arguments() – Set of argument names accepted by the solver.
init_state(init_params, *args) – Initialize the solver state.
run(init_params, *args) – Run a full optimization process until a stopping criterion is reached.
update(params, state, *args) – Perform a single step/update of the optimization process.
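The methods listed above fit together so that run can be built from init_state followed by repeated calls to update until a stopping criterion holds. The following is a minimal pure-Python sketch of that pattern under several assumptions:

- The model has a single scalar parameter.
- The gradient is approximated by a central finite difference.
- A ridge penalty strength * w**2 stands in for the Regularizer object.

The class ToyGradientDescent, its State fields, and the stepsize/maxiter/tol keywords are hypothetical. The class does not subclass the real AbstractSolver.

```python
from typing import Any, Callable, NamedTuple


class State(NamedTuple):
    iter_num: int
    grad: float


class ToyGradientDescent:
    """Hypothetical solver following the documented method layout."""

    def __init__(self, unregularized_loss: Callable[..., float],
                 regularizer_strength: float = 0.0, stepsize: float = 0.1,
                 maxiter: int = 1000, tol: float = 1e-6) -> None:
        # All solver settings are fixed here, not passed to later methods.
        self.loss = unregularized_loss
        self.strength = regularizer_strength
        self.stepsize = stepsize
        self.maxiter = maxiter
        self.tol = tol

    def _penalized(self, w: float, *args: Any) -> float:
        # Ridge penalty used as a stand-in for the regularizer's penalized loss.
        return self.loss(w, *args) + self.strength * w ** 2

    def _grad(self, w: float, *args: Any, eps: float = 1e-6) -> float:
        # Central finite difference; a real solver would use autodiff.
        return (self._penalized(w + eps, *args)
                - self._penalized(w - eps, *args)) / (2 * eps)

    def init_state(self, init_params: float, *args: Any) -> State:
        return State(iter_num=0, grad=self._grad(init_params, *args))

    def update(self, params: float, state: State, *args: Any) -> tuple[float, State]:
        # One gradient step, then store the gradient at the new point.
        new_params = params - self.stepsize * state.grad
        return new_params, State(state.iter_num + 1, self._grad(new_params, *args))

    def run(self, init_params: float, *args: Any) -> tuple[float, State]:
        params, state = init_params, self.init_state(init_params, *args)
        # Stop when the gradient is small or the step budget is used up.
        while abs(state.grad) > self.tol and state.iter_num < self.maxiter:
            params, state = self.update(params, state, *args)
        return params, state


def mse(w: float, xs: list, ys: list) -> float:
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# The data (xs, ys) are forwarded as the objective's positional arguments.
w, state = ToyGradientDescent(mse, stepsize=0.1).run(0.0, [1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(w)  # approximately 2.0
```

In this design, all configuration lives on the instance, so init_state, update, and run stay free of solver settings.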
- abstractmethod classmethod get_accepted_arguments()
Set of argument names accepted by the solver.
Used by BaseRegressor to determine what arguments can be passed to the solver’s __init__.
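One way a caller could use get_accepted_arguments is to reject keyword arguments the solver does not understand before calling __init__. The sketch below derives the accepted names from the __init__ signature with inspect.signature. Both that derivation and the make_solver helper are assumptions for illustration, not necessarily how nemos or BaseRegressor implement it.

```python
import inspect
from typing import Any


class ToySolver:
    """Hypothetical solver with explicit keyword settings."""

    def __init__(self, unregularized_loss, regularizer, regularizer_strength,
                 has_aux, init_params=None, tol=1e-6, maxiter=100):
        self.tol = tol
        self.maxiter = maxiter

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        # Every named parameter of __init__ except `self`.
        params = inspect.signature(cls.__init__).parameters
        return {name for name in params if name != "self"}


def make_solver(solver_cls: type, **kwargs: Any):
    """Hypothetical helper: validate kwargs against the accepted names."""
    unknown = set(kwargs) - solver_cls.get_accepted_arguments()
    if unknown:
        raise ValueError(f"unsupported solver arguments: {sorted(unknown)}")
    return solver_cls(**kwargs)


solver = make_solver(ToySolver, unregularized_loss=None, regularizer=None,
                     regularizer_strength=None, has_aux=False, tol=1e-8)
print(solver.tol)  # 1e-08
```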
- abstractmethod init_state(init_params, *args)
Initialize the solver state.
Used by BaseRegressor.initialize_state.