nemos.solvers._optax_optimistix_solvers.AbstractOptimistixOptaxSolver
- class nemos.solvers._optax_optimistix_solvers.AbstractOptimistixOptaxSolver(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, tol=0.0001, rtol=0.0, maxiter=None, **solver_init_kwargs)
Bases: OptimistixAdapter, ABC
Adapter for optimistix.OptaxMinimiser, which is itself an adapter that lets Optax solvers be driven by Optimistix.
Accepted arguments:
adjoint
has_aux
init_params
maxiter
norm
optim
options
rtol
tags
throw
tol
verbose
Note that, for backward compatibility, the atol parameter used in Optimistix is called tol in NeMoS.
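The following is a minimal sketch of the renaming described above. It assumes the user's options arrive as a plain dictionary. `to_optimistix_kwargs` is a hypothetical helper and is not part of NeMoS.

```python
def to_optimistix_kwargs(nemos_kwargs: dict) -> dict:
    """Hypothetical helper: rename NeMoS's `tol` to Optimistix's `atol`."""
    # Copy so the caller's dictionary is left untouched.
    kwargs = dict(nemos_kwargs)
    if "tol" in kwargs:
        # NeMoS keeps the name `tol` for backward compatibility;
        # Optimistix expects the absolute tolerance under `atol`.
        kwargs["atol"] = kwargs.pop("tol")
    return kwargs


print(to_optimistix_kwargs({"tol": 1e-4, "rtol": 0.0, "maxiter": 500}))
# {'rtol': 0.0, 'maxiter': 500, 'atol': 0.0001}
```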
OptaxMinimiser’s documentation:
A wrapper to use Optax first-order gradient-based optimisers with optimistix.minimise.
More info from OptaxMinimiser.__init__
Arguments:
- optim: The Optax optimiser to use.
- rtol: Relative tolerance for terminating the solve. Keyword-only argument.
- atol: Absolute tolerance for terminating the solve. Keyword-only argument.
- norm: The norm used to determine the difference between two iterates in the convergence criteria. Should be any function PyTree -> Scalar. Optimistix includes three built-in norms: optimistix.max_norm, optimistix.rms_norm, and optimistix.two_norm. Keyword-only argument.
- verbose: Whether to print out extra information about how the solve is proceeding. Can be False to print nothing, True to print all information, or (for customisation) a callable **kwargs -> None. If provided as a callable, each value will be a 2-tuple of (str, jax.Array) giving a human-readable name and its corresponding value.
- __init__(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, tol=0.0001, rtol=0.0, maxiter=None, **solver_init_kwargs)
Create the solver.
- Parameters:
unregularized_loss (Callable) – Unregularized loss function. Currently, this is BaseRegressor.compute_loss.
regularizer (Regularizer) – Regularizer object used to create the penalized loss or to obtain the proximal operator.
has_aux (bool) – Whether unregularized_loss returns auxiliary variables. If False, the loss function is expected to return a single scalar. If True, the loss is expected to return a tuple of length 2 containing a scalar and the auxiliary variables.
init_params (Any | None) – Initial model parameters. Passed to the regularizer’s get_proximal_operator or penalized_loss.
tol (float)
rtol (float)
maxiter (int | None)
**solver_init_kwargs – Keyword arguments modifying the solver’s behavior.
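The following is a minimal sketch of the two loss shapes that `has_aux` distinguishes. It assumes a linear model with parameters `(w, b)`. `loss_no_aux`, `loss_with_aux`, and `scalar_loss` are hypothetical stand-ins, not BaseRegressor.compute_loss.

```python
import jax.numpy as jnp


def loss_no_aux(params, X, y):
    """Hypothetical loss for has_aux=False: returns a single scalar."""
    w, b = params
    residuals = X @ w + b - y
    return jnp.mean(residuals ** 2)


def loss_with_aux(params, X, y):
    """Hypothetical loss for has_aux=True: returns (scalar, auxiliary)."""
    w, b = params
    residuals = X @ w + b - y
    return jnp.mean(residuals ** 2), {"residuals": residuals}


def scalar_loss(loss_fn, has_aux, *args):
    """Extract the scalar the solver minimises, honouring `has_aux`."""
    out = loss_fn(*args)
    return out[0] if has_aux else out
```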
Methods
__init__(unregularized_loss, regularizer, ...) – Create the solver.
adjust_solver_init_kwargs(solver_init_kwargs) – Optionally adjust the parameters (e.g. derive them from self.config) used to instantiate the wrapped solver.
get_accepted_arguments() – Set of accepted argument names, extended with the wrapped solver's arguments.
init_state(init_params, *args) – Initialize the solver state.
run(init_params, *args) – Run a full optimization process until a stopping criterion is reached.
update(params, state, *args) – Perform a single step/update of the optimization process.
- __getattr__(name)
Try getting undefined attributes from the underlying solver.
- Parameters:
name (str)
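The following is a minimal sketch of the attribute-forwarding pattern described above. `WrappedSolver` and `Adapter` are hypothetical classes, not the NeMoS implementation. Python calls `__getattr__` only after normal attribute lookup fails, so attributes defined on the adapter itself take precedence.

```python
class WrappedSolver:
    """Hypothetical stand-in for the wrapped Optimistix solver."""

    def __init__(self):
        self.rtol = 0.0
        self.atol = 1e-4


class Adapter:
    """Hypothetical adapter forwarding unknown attributes to `_solver`."""

    def __init__(self, solver):
        self._solver = solver
        self.maxiter = 100

    def __getattr__(self, name):
        # Only reached when `name` is not found on the adapter itself.
        if name == "_solver":
            # Guard against infinite recursion if `_solver` is not set yet.
            raise AttributeError(name)
        return getattr(self._solver, name)
```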
- adjust_solver_init_kwargs(solver_init_kwargs)
Optionally adjust the parameters (e.g. derive from self.config) for instantiating the wrapped solver.
- classmethod get_accepted_arguments()
Set of accepted argument names, extended with the wrapped solver’s arguments.
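The following is a minimal sketch of extending an adapter's own argument names with those of a wrapped solver's constructor, using `inspect.signature`. `FakeMinimiser`, `AdapterSketch`, and `_own_arguments` are hypothetical. The sketch ignores the arguments NeMoS forwards to optimistix.minimise (such as adjoint, throw, and tags).

```python
import inspect


class FakeMinimiser:
    """Hypothetical wrapped solver with its own constructor arguments."""

    def __init__(self, optim, rtol, atol, norm=None, verbose=False):
        pass


class AdapterSketch:
    """Hypothetical adapter with a few arguments of its own."""

    _own_arguments = {"tol", "maxiter", "init_params", "has_aux"}
    _solver_cls = FakeMinimiser

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        # Parameter names of the wrapped solver's constructor, minus `self`.
        params = inspect.signature(cls._solver_cls.__init__).parameters
        solver_args = {name for name in params if name != "self"}
        # `atol` is exposed under the NeMoS name `tol` instead.
        return (cls._own_arguments | solver_args) - {"atol"}
```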
- init_state(init_params, *args)
Initialize the solver state.
Used by BaseRegressor.initialize_state.
- run(init_params, *args)
Run a full optimization process until a stopping criterion is reached.
Used by BaseRegressor.fit.
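The following is a minimal sketch of how init_state, update, and run typically fit together. It uses a toy one-dimensional gradient-descent solver. `ToyGradientDescent` and `ToyState` are hypothetical, drop the `*args` of the real signatures for brevity, and use a step-size stopping rule chosen for illustration rather than the wrapped solver's actual criterion.

```python
from typing import Callable, NamedTuple


class ToyState(NamedTuple):
    """Hypothetical solver state: iteration count and last step size."""
    iter_num: int
    step_size: float


class ToyGradientDescent:
    """Hypothetical 1-D solver mirroring the init_state / update / run split."""

    def __init__(self, grad_fn: Callable[[float], float], lr=0.1, tol=1e-6, maxiter=1000):
        self.grad_fn = grad_fn
        self.lr = lr
        self.tol = tol
        self.maxiter = maxiter

    def init_state(self, init_params: float) -> ToyState:
        # No steps taken yet; start with an "infinite" last step size.
        return ToyState(iter_num=0, step_size=float("inf"))

    def update(self, params: float, state: ToyState) -> tuple[float, ToyState]:
        # One gradient-descent step.
        step = -self.lr * self.grad_fn(params)
        return params + step, ToyState(state.iter_num + 1, abs(step))

    def run(self, init_params: float) -> tuple[float, ToyState]:
        # Repeat update() until the step is below tol or maxiter is reached.
        params, state = init_params, self.init_state(init_params)
        while state.step_size > self.tol and state.iter_num < self.maxiter:
            params, state = self.update(params, state)
        return params, state


# Minimise (x - 3)^2, whose gradient is 2 (x - 3).
solver = ToyGradientDescent(lambda x: 2.0 * (x - 3.0))
params, state = solver.run(0.0)
```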