nemos.solvers._optax_optimistix_solvers.AbstractOptimistixOptaxSolver

class nemos.solvers._optax_optimistix_solvers.AbstractOptimistixOptaxSolver(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, tol=0.0001, rtol=0.0, maxiter=None, **solver_init_kwargs)

Bases: OptimistixAdapter, ABC

Adapter for optimistix.OptaxMinimiser, which in turn adapts Optax solvers for use with Optimistix.

Accepted arguments:

  • adjoint

  • has_aux

  • init_params

  • maxiter

  • norm

  • optim

  • options

  • rtol

  • tags

  • throw

  • tol

  • verbose

Note that for backward compatibility the atol parameter used in Optimistix is referred to as tol in NeMoS.
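As a minimal sketch of this renaming, the hypothetical helper to_optimistix_kwargs below shows how NeMoS-style tolerances could be translated into the keyword names Optimistix expects. It illustrates the mapping under that assumption and is not the NeMoS implementation.

```python
from typing import Any


def to_optimistix_kwargs(tol: float = 1e-4, rtol: float = 0.0, **extra: Any) -> dict[str, Any]:
    """Hypothetical helper: rename NeMoS's ``tol`` to Optimistix's ``atol``."""
    # Optimistix calls the absolute tolerance atol; NeMoS keeps the older name tol.
    # Any other keyword arguments (e.g. norm) are passed through unchanged.
    return {"atol": tol, "rtol": rtol, **extra}
```

With the defaults from the signature above, to_optimistix_kwargs() yields atol=1e-4 and rtol=0.0.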

OptaxMinimiser’s documentation:

A wrapper to use Optax first-order gradient-based optimisers with optimistix.minimise.

More info from OptaxMinimiser.__init__

Arguments:

  • optim: The Optax optimiser to use.

  • rtol: Relative tolerance for terminating the solve. Keyword-only argument.

  • atol: Absolute tolerance for terminating the solve. Keyword-only argument.

  • norm: The norm used to determine the difference between two iterates in the convergence criteria. Should be any function PyTree -> Scalar. Optimistix includes three built-in norms: optimistix.max_norm, optimistix.rms_norm, and optimistix.two_norm. Keyword-only argument.

  • verbose: Whether to print out extra information about how the solve is proceeding. Can either be False to print out nothing, True to print out all information, or (for customisation) a callable **kwargs -> None. If provided as a callable, each value will be a 2-tuple of (str, jax.Array) providing a human-readable name and its corresponding value.
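The role of norm, atol and rtol in the convergence check can be sketched in plain Python. The three norm functions are intended to mirror optimistix.max_norm, optimistix.rms_norm and optimistix.two_norm, but operate on a flat list of floats rather than a PyTree. has_converged is a hypothetical Cauchy-style test that scales each component of the step by atol + rtol * |y|; Optimistix's actual criterion may also compare successive loss values, so treat this as an approximation rather than its implementation.

```python
import math


# Flat-list analogues of the three built-in Optimistix norms.
def max_norm(x: list[float]) -> float:
    return max(abs(v) for v in x)


def rms_norm(x: list[float]) -> float:
    return math.sqrt(sum(v * v for v in x) / len(x))


def two_norm(x: list[float]) -> float:
    return math.sqrt(sum(v * v for v in x))


def has_converged(y_prev, y_new, atol, rtol, norm=max_norm) -> bool:
    """Hypothetical Cauchy-style test: the scaled step must have norm below 1.

    Assumes atol > 0 so the scale is never zero. With rtol=0 this reduces to
    norm(y_new - y_prev) < atol.
    """
    scaled = [abs(b - a) / (atol + rtol * abs(b)) for a, b in zip(y_prev, y_new)]
    return norm(scaled) < 1.0
```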

Attributes

__init__(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, tol=0.0001, rtol=0.0, maxiter=None, **solver_init_kwargs)

Create the solver.

Parameters:
  • unregularized_loss (Callable) – Unregularized loss function. Currently, this is BaseRegressor.compute_loss.

  • regularizer (Regularizer) – Regularizer object from which the penalized loss or the proximal operator is obtained.

  • regularizer_strength (float | None) – Regularizer strength.

  • has_aux (bool) – Whether unregularized_loss returns auxiliary variables. If False, the loss function is expected to return a single scalar. If True, the loss is expected to return a tuple of length 2 with a scalar and auxiliary variables.

  • init_params (Any | None) – Initial model parameters. Passed to the regularizer’s get_proximal_operator or penalized_loss.

  • **solver_init_kwargs – Keyword arguments modifying the solver’s behavior.

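  • tol (float) – Absolute tolerance for terminating the solve; corresponds to atol in Optimistix.

  • rtol (float) – Relative tolerance for terminating the solve.

  • maxiter (int | None) – Maximum number of iterations.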
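The has_aux convention and the role of the regularizer can be illustrated with hypothetical loss functions on a scalar parameter. penalized is an assumed stand-in for a regularizer's penalized loss using an L2 penalty; it is not the NeMoS Regularizer API, and NeMoS may scale its penalties differently.

```python
from typing import Any, Callable


def mse_loss(params: float, x: list[float], y: list[float]) -> float:
    # has_aux=False: the loss returns a single scalar.
    residuals = [yi - params * xi for xi, yi in zip(x, y)]
    return sum(r * r for r in residuals) / len(residuals)


def mse_loss_with_aux(params: float, x: list[float], y: list[float]) -> tuple[float, dict]:
    # has_aux=True: the loss returns a 2-tuple (scalar, auxiliary variables).
    residuals = [yi - params * xi for xi, yi in zip(x, y)]
    return sum(r * r for r in residuals) / len(residuals), {"residuals": residuals}


def penalized(loss: Callable[..., Any], strength: float, has_aux: bool) -> Callable[..., Any]:
    """Hypothetical penalized loss: unregularized loss plus an assumed L2 penalty."""

    def wrapped(params: float, *args: Any) -> Any:
        out = loss(params, *args)
        # Split off the auxiliary output only when the loss provides one.
        value, aux = out if has_aux else (out, None)
        value = value + 0.5 * strength * params**2
        return (value, aux) if has_aux else value

    return wrapped
```

Whatever the penalty, the wrapper preserves the return shape implied by has_aux, which is what the solver relies on.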

Methods

__init__(unregularized_loss, regularizer, ...)

Create the solver.

adjust_solver_init_kwargs(solver_init_kwargs)

Optionally adjust the parameters (e.g. derive from self.config) for instantiating the wrapped solver.

get_accepted_arguments()

Set of accepted argument names, extended with the wrapped solver's arguments.

init_state(init_params, *args)

Initialize the solver state.

run(init_params, *args)

Run a full optimization process until a stopping criterion is reached.

update(params, state, *args)

Perform a single step/update of the optimization process.

DEFAULT_MAXITER: ClassVar[int] = 10000

__getattr__(name)

Try getting undefined attributes from the underlying solver.

Parameters:

name (str) – Name of the attribute to look up on the underlying solver.
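A minimal sketch of this kind of delegation, using hypothetical classes: Python only calls __getattr__ when normal attribute lookup fails, so attributes defined on the adapter itself take precedence and everything else is forwarded to the wrapped solver. The attribute name _solver is an assumption for illustration.

```python
from typing import Any


class HypotheticalWrappedSolver:
    def __init__(self, rtol: float, atol: float):
        self.rtol = rtol
        self.atol = atol


class DelegatingAdapter:
    def __init__(self, tol: float = 1e-4, rtol: float = 0.0):
        self._solver = HypotheticalWrappedSolver(rtol=rtol, atol=tol)
        self.maxiter = 100

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, so attributes set on the adapter
        # itself (such as maxiter) win over the wrapped solver's attributes.
        if name == "_solver":
            # Avoid infinite recursion if _solver has not been set yet.
            raise AttributeError(name)
        return getattr(self._solver, name)
```

Here DelegatingAdapter(tol=1e-6).atol is resolved on the wrapped solver, while .maxiter is found on the adapter.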

adjust_solver_init_kwargs(solver_init_kwargs)

Optionally adjust the parameters (e.g. derive from self.config) for instantiating the wrapped solver.

Parameters:

solver_init_kwargs (dict[str, Any]) – Original keyword arguments that would be passed to _solver_cls.__init__.

Return type:

dict[str, Any]

Returns:

A dict with the argument names of _solver_cls.__init__ as keys and their corresponding values as values. The default implementation returns solver_init_kwargs unchanged.
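A subclass might override adjust_solver_init_kwargs to derive an entry for the wrapped solver from other options. In the sketch below, the stepsize key, its default, and the make_optimiser factory are hypothetical; in a real subclass the factory could be an Optax constructor such as optax.sgd, but that is an assumption, not a description of NeMoS.

```python
from typing import Any, Callable


class BaseAdapterSketch:
    def adjust_solver_init_kwargs(self, solver_init_kwargs: dict[str, Any]) -> dict[str, Any]:
        # Default behaviour described above: return the kwargs unchanged.
        return solver_init_kwargs


class HypotheticalSGDAdapter(BaseAdapterSketch):
    def __init__(self, make_optimiser: Callable[[float], Any]):
        # make_optimiser stands in for an optimiser constructor (assumption).
        self.make_optimiser = make_optimiser

    def adjust_solver_init_kwargs(self, solver_init_kwargs: dict[str, Any]) -> dict[str, Any]:
        kwargs = dict(solver_init_kwargs)  # copy so the caller's dict is not mutated
        stepsize = kwargs.pop("stepsize", 1e-3)  # hypothetical key and default
        kwargs["optim"] = self.make_optimiser(stepsize)
        return kwargs
```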

classmethod get_accepted_arguments()

Set of accepted argument names, extended with the wrapped solver’s arguments.

Return type:

set[str]
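One way such a set could be assembled, sketched under the assumption that the adapter merges its own argument names with the wrapped solver's __init__ parameters (read via inspect.signature) and exposes atol as tol. HypotheticalMinimiser and _own_arguments are illustrative names, not NeMoS or Optimistix code, and the real accepted list above also includes names taken from optimistix.minimise.

```python
import inspect


class HypotheticalMinimiser:
    # Stand-in for a wrapped solver class; only its __init__ signature matters here.
    def __init__(self, optim, rtol, atol, norm=None, verbose=False):
        self.optim, self.rtol, self.atol = optim, rtol, atol
        self.norm, self.verbose = norm, verbose


class AcceptedArgsSketch:
    _solver_cls = HypotheticalMinimiser
    _own_arguments = {"has_aux", "init_params", "maxiter", "tol"}

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        params = inspect.signature(cls._solver_cls.__init__).parameters
        wrapped = {name for name in params if name != "self"}
        # Expose the wrapped solver's atol under the NeMoS name tol.
        wrapped = {"tol" if name == "atol" else name for name in wrapped}
        return cls._own_arguments | wrapped
```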

init_state(init_params, *args)

Initialize the solver state.

Used by BaseRegressor.initialize_state.

Return type:

OptimistixAdapterState

Parameters:
  • init_params (Any)

  • args (Any)

property maxiter: int
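The maxiter property is typed int even though the constructor accepts maxiter=None. A plausible reading, assumed here and not confirmed by this page, is that None falls back to DEFAULT_MAXITER; MaxiterSketch below is a hypothetical class illustrating that pattern.

```python
from typing import ClassVar, Optional


class MaxiterSketch:
    DEFAULT_MAXITER: ClassVar[int] = 10000

    def __init__(self, maxiter: Optional[int] = None):
        self._maxiter = maxiter

    @property
    def maxiter(self) -> int:
        # Assumed behaviour: use the class default when no maxiter was given.
        return self.DEFAULT_MAXITER if self._maxiter is None else self._maxiter
```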
run(init_params, *args)

Run a full optimization process until a stopping criterion is reached.

Used by BaseRegressor.fit.

Return type:

tuple[Any, OptimistixAdapterState, TypeVar(Aux)]

Parameters:
  • init_params (Any)

  • args (Any)

update(params, state, *args)

Perform a single step/update of the optimization process.

Used by BaseRegressor.update.

Return type:

tuple[Any, OptimistixAdapterState, TypeVar(Aux)]

Parameters:
  • params (Any)

  • state (OptimistixAdapterState)

  • args (Any)
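init_state, update and run form the usual solver lifecycle: init_state builds a state, update advances one step and returns (params, state, aux), and run repeats update until a stopping criterion or the iteration limit is reached. The sketch below shows that lifecycle for plain gradient descent on a scalar parameter; GradientDescentSketch and SketchState are hypothetical and do not reproduce OptimistixAdapterState or the Optax step.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SketchState:
    # Hypothetical stand-in for OptimistixAdapterState.
    iter_num: int
    converged: bool


class GradientDescentSketch:
    """Hypothetical solver showing the init_state / update / run lifecycle."""

    def __init__(self, grad: Callable[..., float], stepsize: float = 0.1,
                 tol: float = 1e-4, maxiter: int = 10000):
        self.grad = grad
        self.stepsize = stepsize
        self.tol = tol
        self.maxiter = maxiter

    def init_state(self, init_params: float, *args: Any) -> SketchState:
        # Fresh state before any step has been taken.
        return SketchState(iter_num=0, converged=False)

    def update(self, params: float, state: SketchState, *args: Any) -> tuple[float, SketchState, Any]:
        # One gradient-descent step; converged once the step is smaller than tol.
        new_params = params - self.stepsize * self.grad(params, *args)
        converged = abs(new_params - params) < self.tol
        aux = None  # this sketch's loss has no auxiliary output
        return new_params, SketchState(state.iter_num + 1, converged), aux

    def run(self, init_params: float, *args: Any) -> tuple[float, SketchState, Any]:
        # Repeat update until convergence or until maxiter steps have been taken.
        params, state, aux = init_params, self.init_state(init_params, *args), None
        while not state.converged and state.iter_num < self.maxiter:
            params, state, aux = self.update(params, state, *args)
        return params, state, aux
```

For example, GradientDescentSketch(lambda p, c: 2.0 * (p - c)).run(0.0, 3.0) minimises (p - c)**2 with c = 3.0 passed through *args, mirroring how extra arguments flow through the documented signatures.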
