nemos.solvers._jaxopt_solvers.JaxoptAdapter

class nemos.solvers._jaxopt_solvers.JaxoptAdapter(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, **solver_init_kwargs)

Bases: SolverAdapter

Base class for adapters wrapping JAXopt-style solvers.

Subclasses must set the _solver_cls class variable to the wrapped solver class. Subclasses that wrap proximal solvers must also set the _proximal class variable to True.
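The subclassing contract can be pictured with the following minimal sketch. Every class in it (`ToyRegularizer`, `ToyGradientDescent`, `ToyProximalGradient`, `SketchAdapter` and its two subclasses) is hypothetical plain Python, not the nemos or JAXopt API. It assumes one plausible reading of `_proximal`: when it is True, the solver receives the unregularized loss together with the regularizer's proximal operator; otherwise it receives the penalized loss.

```python
from typing import Callable, Optional


class ToyRegularizer:
    """Hypothetical stand-in for a nemos Regularizer (L2 penalty on a scalar)."""

    def penalized_loss(self, loss: Callable, strength: float) -> Callable:
        # Non-proximal route: fold the penalty into the objective.
        return lambda w: loss(w) + strength * w ** 2

    def get_proximal_operator(self) -> Callable:
        # Proximal route: argmin_x 0.5 * (x - w)**2 + scaling * strength * x**2,
        # whose closed form is w / (1 + 2 * scaling * strength).
        return lambda w, strength, scaling=1.0: w / (1.0 + 2.0 * scaling * strength)


class ToyGradientDescent:
    """Hypothetical smooth solver: only needs the objective."""

    def __init__(self, fun: Callable, stepsize: float = 0.1):
        self.fun, self.stepsize = fun, stepsize


class ToyProximalGradient:
    """Hypothetical proximal solver: needs the smooth loss and a prox."""

    def __init__(self, fun: Callable, prox: Callable, stepsize: float = 0.1):
        self.fun, self.prox, self.stepsize = fun, prox, stepsize


class SketchAdapter:
    """Simplified, hypothetical version of the adapter's set-up logic."""

    _solver_cls: Optional[type] = None  # every concrete subclass must set this
    _proximal: bool = False  # set to True for proximal solvers

    def __init__(self, unregularized_loss: Callable, regularizer: ToyRegularizer,
                 regularizer_strength: float, **solver_init_kwargs):
        if self._solver_cls is None:
            raise TypeError(f"{type(self).__name__} must set _solver_cls")
        if self._proximal:
            # Smooth part stays unpenalized; the regularizer supplies a prox.
            self.fun = unregularized_loss
            self.solver = self._solver_cls(
                fun=self.fun, prox=regularizer.get_proximal_operator(), **solver_init_kwargs
            )
        else:
            # The penalty is added to the loss itself.
            self.fun = regularizer.penalized_loss(unregularized_loss, regularizer_strength)
            self.solver = self._solver_cls(fun=self.fun, **solver_init_kwargs)
        self.regularizer_strength = regularizer_strength


class GradientDescentAdapter(SketchAdapter):
    _solver_cls = ToyGradientDescent


class ProximalGradientAdapter(SketchAdapter):
    _solver_cls = ToyProximalGradient
    _proximal = True
```

With this split, only the proximal adapter ever asks the regularizer for a proximal operator, and only the smooth adapter optimizes the penalized loss.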

Attributes

maxiter

Parameters:
  • unregularized_loss (Callable)

  • regularizer (Regularizer)

  • regularizer_strength (float | None)

  • has_aux (bool)

  • init_params (Params | None)

__init__(unregularized_loss, regularizer, regularizer_strength, has_aux, init_params=None, **solver_init_kwargs)

Create the solver.

Parameters:
  • unregularized_loss (Callable) – Unregularized loss function (currently BaseRegressor.compute_loss).

  • regularizer (Regularizer) – Regularizer object used either to create the penalized loss or to obtain the proximal operator.

  • regularizer_strength (float | None) – Regularizer strength.

  • has_aux (bool) – Whether unregularized_loss returns auxiliary variables. If False, the loss function is expected to return a single scalar. If True, it is expected to return a tuple of length 2: the scalar loss followed by the auxiliary variables.

  • init_params (Any | None) – Initial model parameters. Passed to the regularizer’s get_proximal_operator or penalized_loss.

  • **solver_init_kwargs – Keyword arguments modifying the solver’s behavior.
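A small sketch of the two loss shapes that has_aux distinguishes. The helper `split_loss` and both loss functions are hypothetical illustrations, not part of nemos; `split_loss` normalizes either shape to a `(value, aux)` pair so downstream code can treat both uniformly.

```python
from typing import Any, Callable, Tuple


def split_loss(loss: Callable, has_aux: bool) -> Callable[..., Tuple[float, Any]]:
    """Hypothetical helper: always return (loss_value, aux); aux is None if has_aux=False."""
    if has_aux:
        def wrapped(*args):
            value, aux = loss(*args)  # loss must return a tuple of length 2
            return value, aux
    else:
        def wrapped(*args):
            return loss(*args), None  # loss must return a single scalar
    return wrapped


def mse(w, x, y):
    # Scalar-only loss: pair it with has_aux=False.
    return sum((w * xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)


def mse_with_residuals(w, x, y):
    # Loss with an auxiliary output (the residuals): pair it with has_aux=True.
    residuals = [w * xi - yi for xi, yi in zip(x, y)]
    return sum(r ** 2 for r in residuals) / len(x), residuals
```

Passing `mse_with_residuals` with has_aux=False would hand a tuple to code expecting a scalar, which is why the flag must match what the loss actually returns.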

Methods

__init__(unregularized_loss, regularizer, ...)

Create the solver.

get_accepted_arguments()

Set of accepted argument names, extended with the wrapped solver's arguments.

init_state(init_params, *args)

Initialize the solver state.

run(init_params, *args)

Run a full optimization process until a stopping criterion is reached.

update(params, state, *args)

Perform a single step/update of the optimization process.

__getattr__(name)

Delegate lookups of attributes not defined on the adapter to the underlying solver.

Parameters:
  • name (str)
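The delegation that __getattr__ performs follows a standard Python pattern, sketched below with hypothetical `ToySolver` and `DelegatingAdapter` classes (assumptions, not nemos code). Python calls `__getattr__` only after normal attribute lookup on the object fails, so attributes defined on the adapter always take precedence over the solver's.

```python
class ToySolver:
    """Hypothetical wrapped solver exposing its own settings."""

    def __init__(self, maxiter=500, tol=1e-3):
        self.maxiter = maxiter
        self.tol = tol


class DelegatingAdapter:
    """Hypothetical adapter that forwards unknown attributes to its solver."""

    def __init__(self, **solver_init_kwargs):
        self._solver = ToySolver(**solver_init_kwargs)

    def __getattr__(self, name: str):
        # Reached only when normal lookup fails on the adapter.
        # Guard against infinite recursion if _solver was never set.
        if name == "_solver":
            raise AttributeError(name)
        # Raises AttributeError itself if the solver lacks the attribute too.
        return getattr(self._solver, name)
```

For example, `DelegatingAdapter(tol=1e-6).tol` is served by the wrapped solver, while a name unknown to both objects still raises AttributeError.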

classmethod __init_subclass__(**kw)

Generate the docstring including accepted arguments and the wrapped solver’s documentation.
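One way such docstring generation can work is sketched below. `DocAdapter`, `ToyGradientDescent`, and the exact docstring layout are hypothetical; nemos may format the generated text differently. `__init_subclass__` runs once for each new subclass, after its class body executes, which is what lets the base class rewrite the subclass's `__doc__`.

```python
import inspect
from typing import Optional


class ToyGradientDescent:
    """Toy gradient descent solver (hypothetical).

    Parameters
    ----------
    stepsize : float
        Step size of each update.
    """

    def __init__(self, fun=None, stepsize=0.1, maxiter=100):
        self.fun, self.stepsize, self.maxiter = fun, stepsize, maxiter


class DocAdapter:
    """Hypothetical base class that documents its subclasses automatically."""

    _solver_cls: Optional[type] = None

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)
        if cls._solver_cls is None:
            return  # nothing to document for abstract intermediate classes
        # Constructor arguments of the wrapped solver, without 'self'.
        solver_args = list(inspect.signature(cls._solver_cls.__init__).parameters)[1:]
        cls.__doc__ = (
            (inspect.getdoc(cls) or "")
            + "\n\nAccepted solver arguments: " + ", ".join(solver_args)
            + "\n\nWrapped solver documentation:\n\n"
            + (inspect.getdoc(cls._solver_cls) or "")
        )


class GradientDescentAdapter(DocAdapter):
    """Adapter for ToyGradientDescent."""

    _solver_cls = ToyGradientDescent
```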

classmethod get_accepted_arguments()

Set of accepted argument names, extended with the wrapped solver’s arguments.

Return type:

set[str]
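A sketch of how the accepted-argument set could be assembled with `inspect.signature`. `ToySolver` and `ArgsAdapter` are hypothetical, and excluding `fun` is an assumption made for illustration (the adapter is presumed to build the objective itself). The union of the adapter's own parameters and the wrapped solver's constructor parameters gives a set that can be used to reject unsupported keyword arguments early.

```python
import inspect
from typing import Set


class ToySolver:
    """Hypothetical wrapped solver."""

    def __init__(self, fun, maxiter=500, tol=1e-3, stepsize=0.0):
        self.fun, self.maxiter, self.tol, self.stepsize = fun, maxiter, tol, stepsize


class ArgsAdapter:
    """Hypothetical adapter with the same constructor signature as JaxoptAdapter."""

    _solver_cls = ToySolver

    def __init__(self, unregularized_loss, regularizer, regularizer_strength,
                 has_aux, init_params=None, **solver_init_kwargs):
        self.solver_init_kwargs = solver_init_kwargs

    @classmethod
    def get_accepted_arguments(cls) -> Set[str]:
        own = set(inspect.signature(cls.__init__).parameters)
        wrapped = set(inspect.signature(cls._solver_cls.__init__).parameters)
        # Drop 'self' and the **kwargs catch-all. Assumption: 'fun' is built
        # by the adapter, so it is not a user-facing argument.
        return (own | wrapped) - {"self", "solver_init_kwargs", "fun"}


# Example: flag keyword arguments that neither the adapter nor the solver accepts.
user_kwargs = {"tol": 1e-6, "momentum": 0.9}
unknown = set(user_kwargs) - ArgsAdapter.get_accepted_arguments()
```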

init_state(init_params, *args)

Initialize the solver state.

Used by BaseRegressor.initialize_state.

Return type:

JaxoptAdapterState

Parameters:
  • init_params (Any)

  • args (Any)

property maxiter

run(init_params, *args)

Run a full optimization process until a stopping criterion is reached.

Used by BaseRegressor.fit.

Return type:

Tuple[Any, JaxoptAdapterState, TypeVar(Aux)]

Parameters:
  • init_params (Any)

  • args (Any)
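The relation between init_state, update, and run can be illustrated with a toy gradient-descent adapter in plain Python. `ToyState` stands in for JaxoptAdapterState and `ToyAdapter` for a concrete adapter; both are hypothetical, as is the stopping rule (gradient magnitude at or below `tol`, or `maxiter` steps reached). The return value follows the documented `(params, state, aux)` layout.

```python
from typing import Any, Callable, NamedTuple, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for JaxoptAdapterState."""
    iter_num: int
    error: float


class ToyAdapter:
    """Hypothetical adapter running plain gradient descent on a scalar."""

    def __init__(self, grad: Callable[[float], float], stepsize: float = 0.1,
                 maxiter: int = 1000, tol: float = 1e-8):
        self.grad, self.stepsize, self.maxiter, self.tol = grad, stepsize, maxiter, tol

    def init_state(self, init_params: float) -> ToyState:
        return ToyState(iter_num=0, error=float("inf"))

    def update(self, params: float, state: ToyState) -> Tuple[float, ToyState, Any]:
        g = self.grad(params)
        # One gradient step; the error records |gradient| before the step.
        return params - self.stepsize * g, ToyState(state.iter_num + 1, abs(g)), None

    def run(self, init_params: float) -> Tuple[float, ToyState, Any]:
        params, state, aux = init_params, self.init_state(init_params), None
        # Assumed stopping rule: small gradient or maxiter reached.
        while state.iter_num < self.maxiter and state.error > self.tol:
            params, state, aux = self.update(params, state)
        return params, state, aux
```

For instance, `ToyAdapter(grad=lambda w: 2.0 * (w - 3.0)).run(0.0)` minimizes (w - 3)² and returns parameters close to 3.0 together with the final state.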

update(params, state, *args)

Perform a single step/update of the optimization process.

Used by BaseRegressor.update.

Return type:

Tuple[Any, JaxoptAdapterState, TypeVar(Aux)]

Parameters:
  • params (Any)

  • state (JaxoptAdapterState)

  • args (Any)
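Calling update directly gives step-by-step control, for example over mini-batches. In this sketch, `ToyState`, `grad_mse`, and the standalone `update` function are hypothetical, and forwarding `*args` (here a data batch) unchanged to the loss gradient is an assumption about how the extra arguments are used.

```python
from typing import Any, NamedTuple, Sequence, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for JaxoptAdapterState."""
    iter_num: int


def grad_mse(w: float, x: Sequence[float], y: Sequence[float]) -> float:
    # Derivative of mean((w * x - y)**2) with respect to w.
    return 2.0 * sum((w * xi - yi) * xi for xi, yi in zip(x, y)) / len(x)


def update(params: float, state: ToyState, *args: Sequence[float],
           stepsize: float = 0.05) -> Tuple[float, ToyState, Any]:
    # Assumption: *args (here x, y) are forwarded unchanged to the loss gradient.
    new_params = params - stepsize * grad_mse(params, *args)
    return new_params, ToyState(state.iter_num + 1), None


# Manual stepping: one update per mini-batch, repeated for several epochs.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
params, state = 0.0, ToyState(iter_num=0)
for epoch in range(50):
    for x_batch, y_batch in batches:
        params, state, aux = update(params, state, x_batch, y_batch)
```

The data follow y = 2x, so the parameter approaches 2.0, and the state counts one iteration per call (100 here).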