Source code for nemos.solvers._jaxopt_adapter

"""Base class for adapters wrapping JAXopt-style solvers."""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    NamedTuple,
    Tuple,
    Type,
    TypeAlias,
)

import lazy_loader as lazy

from ..typing import Aux, Params

if TYPE_CHECKING:
    from ..regularizer import Regularizer

from ._abstract_solver import OptimizationInfo, SolverAdapterState
from ._solver_adapter import SolverAdapter

# `jax` is obtained through `lazy_loader` instead of a plain `import jax`;
# the rest of this module accesses it as `jax.numpy`.
jax = lazy.load("jax")

# Annotation alias for the state object owned by the wrapped solver.
# Aliased to `NamedTuple` — presumably because JAXopt-style solver states are
# NamedTuple instances; verify against the wrapped solver classes.
JaxoptSolverState: TypeAlias = NamedTuple


class JaxoptAdapterState(SolverAdapterState[JaxoptSolverState]):
    """
    Solver state for JAXopt-based adapters.

    Specializes ``SolverAdapterState`` with ``JaxoptSolverState`` as the type
    of the wrapped solver's state. No fields or methods are added here.
    """


# Return type of `JaxoptAdapter.update` and `JaxoptAdapter.run`:
# (parameters, adapter state, auxiliary output).
JaxoptStepResult: TypeAlias = Tuple[Params, JaxoptAdapterState, Aux]


class JaxoptAdapter(SolverAdapter[JaxoptAdapterState]):
    """
    Base class for adapters wrapping JAXopt-style solvers.

    Besides `_solver_cls`, for proximal solvers the `_proximal` class variable
    needs to be set to `True`.

    When `_proximal` is `True` the unregularized loss is handed to the solver
    together with the regularizer's proximal operator; otherwise the solver
    receives the penalized loss built by the regularizer.
    """

    # Solver class instantiated in `__init__`; to be provided by subclasses.
    _solver_cls: ClassVar[Type]
    # Whether the wrapped solver takes a proximal operator (`prox`).
    _proximal: ClassVar[bool] = False

    def __init__(
        self,
        unregularized_loss: Callable,
        regularizer: Regularizer,
        regularizer_strength: float | None,
        has_aux: bool,
        init_params: Params | None = None,
        **solver_init_kwargs,
    ):
        """
        Build the objective and instantiate the wrapped solver.

        Parameters
        ----------
        unregularized_loss :
            Loss without the penalty term. Used as-is for proximal solvers,
            otherwise wrapped by ``regularizer.penalized_loss``.
        regularizer :
            Object providing ``get_proximal_operator`` (proximal solvers) or
            ``penalized_loss`` (all other solvers).
        regularizer_strength :
            Passed as ``strength`` to the regularizer and stored on the
            instance.
        has_aux :
            Forwarded unchanged to the solver constructor.
        init_params :
            Passed as ``params`` to the regularizer method that is called.
        **solver_init_kwargs :
            Extra keyword arguments forwarded to ``_solver_cls``. For proximal
            solvers the ``"prox"`` entry is set here, replacing any value
            supplied by the caller.
        """
        if self._proximal:
            self.fun = unregularized_loss
            solver_init_kwargs["prox"] = regularizer.get_proximal_operator(
                params=init_params, strength=regularizer_strength
            )
        else:
            self.fun = regularizer.penalized_loss(
                unregularized_loss, params=init_params, strength=regularizer_strength
            )

        self.regularizer_strength = regularizer_strength

        # Prepend the regularizer strength to args for proximal solvers.
        # Methods of `jaxopt.ProximalGradient` expect `hyperparams_prox` before
        # the objective function's arguments, while others do not need this.
        self.hyperparams_prox = (self.regularizer_strength,) if self._proximal else ()

        self._solver = self._solver_cls(
            fun=self.fun,
            has_aux=has_aux,
            **solver_init_kwargs,
        )

    def init_state(self, init_params: Params, *args: Any) -> JaxoptAdapterState:
        """
        Create the initial adapter state.

        Parameters
        ----------
        init_params :
            Initial parameters, forwarded to the wrapped solver's
            ``init_state``.
        *args :
            Objective-function arguments; ``hyperparams_prox`` is inserted in
            front of them.

        Returns
        -------
        :
            State wrapping the solver's initial state together with placeholder
            statistics: NaN function value, zero steps, not converged, and
            maximum number of steps not reached.
        """
        solver_state = self._solver.init_state(
            init_params, *self.hyperparams_prox, *args
        )
        stats = OptimizationInfo(
            function_val=jax.numpy.nan,  # pyright: ignore
            num_steps=jax.numpy.array(0),
            converged=jax.numpy.array(False),  # pyright: ignore
            reached_max_steps=jax.numpy.array(False),
        )
        return JaxoptAdapterState(solver_state=solver_state, stats=stats)

    def update(
        self, params: Params, state: JaxoptAdapterState, *args: Any
    ) -> JaxoptStepResult:
        """
        Perform a single solver update.

        Parameters
        ----------
        params :
            Current parameters.
        state :
            Current adapter state; only its ``solver_state`` is passed on to
            the wrapped solver.
        *args :
            Objective-function arguments; ``hyperparams_prox`` is inserted in
            front of them.

        Returns
        -------
        :
            Tuple ``(params, state, aux)`` with the updated parameters, a new
            adapter state with refreshed statistics, and the auxiliary output
            read from the solver state (``aux``, falling back to
            ``aux_batch``).
        """
        params, solver_state = self._solver.update(
            params, state.solver_state, *self.hyperparams_prox, *args
        )
        aux = self._extract_aux(solver_state, fallback_name="aux_batch")
        stats = self._get_optim_info(solver_state)
        state = JaxoptAdapterState(solver_state=solver_state, stats=stats)
        return (params, state, aux)

    def run(self, init_params: Params, *args: Any) -> JaxoptStepResult:
        """
        Run the wrapped solver's ``run`` method starting from ``init_params``.

        Parameters
        ----------
        init_params :
            Initial parameters.
        *args :
            Objective-function arguments; ``hyperparams_prox`` is inserted in
            front of them.

        Returns
        -------
        :
            Tuple ``(params, state, aux)`` with the final parameters, the
            adapter state with statistics derived from the final solver state,
            and the auxiliary output read from the solver state (``aux``,
            falling back to ``aux_full``).
        """
        params, solver_state = self._solver.run(
            init_params, *self.hyperparams_prox, *args
        )
        aux = self._extract_aux(solver_state, fallback_name="aux_full")
        stats = self._get_optim_info(solver_state)
        state = JaxoptAdapterState(solver_state=solver_state, stats=stats)
        return (params, state, aux)

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        """
        Return the argument names accepted by the adapter.

        Starts from the parent class's result and, for proximal solvers,
        removes ``"prox"``.

        Raises
        ------
        KeyError
            If ``_proximal`` is ``True`` but ``"prox"`` is not among the
            arguments reported by the parent class.
        """
        arguments = super().get_accepted_arguments()

        # prox is read from the regularizer, not provided as a solver argument
        if cls._proximal:
            arguments.remove("prox")

        return arguments

    def _get_optim_info(self, state: JaxoptSolverState, **kwargs) -> OptimizationInfo:
        """
        Summarize a solver state as an ``OptimizationInfo``.

        Parameters
        ----------
        state :
            State of the wrapped solver; must expose ``iter_num`` and
            ``error``, and may expose ``value``.
        **kwargs :
            Accepted but not used by this implementation.

        Returns
        -------
        :
            Function value (``None`` when the state has no ``value``), number
            of steps, convergence flag (``error <= tol``) and whether
            ``maxiter`` was reached.
        """
        num_steps = state.iter_num  # pyright: ignore
        # NOTE(review): `init_state` uses NaN as the placeholder function value,
        # whereas a solver state without `value` yields None here.
        function_val = getattr(state, "value", None)

        # NOTE(review): `self.tol` is not defined in this class; presumably it
        # is resolved through the parent adapter — confirm.
        return OptimizationInfo(
            function_val=function_val,  # pyright: ignore
            num_steps=num_steps,
            converged=jax.numpy.array(state.error <= self.tol),  # pyright: ignore
            reached_max_steps=jax.numpy.array(num_steps >= self.maxiter),
        )

    @property
    def maxiter(self):
        """Maximum number of iterations, read from the wrapped solver."""
        return self._solver.maxiter

    def _extract_aux(self, state: JaxoptSolverState, fallback_name: str):
        """
        Return auxiliary output from a solver state.

        Prefers `state.aux` when present; otherwise falls back to the provided
        field name (e.g., `aux_batch` for SVRG updates or `aux_full` for SVRG run).

        Parameters
        ----------
        state :
            State of the wrapped solver (not the adapter state).
        fallback_name :
            Attribute to read when the state has no ``aux`` attribute.

        Raises
        ------
        AttributeError
            If the state has neither ``aux`` nor ``fallback_name``.
        """
        # solvers imported from jaxopt have state.aux
        if hasattr(state, "aux"):
            return state.aux

        # for SVRG get state.aux_batch or state.aux_full
        return getattr(state, fallback_name)