Source code for nemos.solvers._abstract_solver

"""Base class defining the interface for solvers that can be used by `BaseRegressor`."""

from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, runtime_checkable

from ..typing import Params, SolverState, StepResult

if TYPE_CHECKING:
    from ..regularizer import Regularizer
import equinox as eqx
import jax


class SolverAdapterState(eqx.Module, Generic[SolverState]):
    """
    Bundle a backend-specific solver state with optimization statistics.

    The class is generic over ``SolverState``, the type of the wrapped
    backend state.
    """

    # Backend-specific internal state.
    solver_state: SolverState
    # Diagnostics (num_steps, converged, etc.) computed during the run; a valid
    # JAX pytree. `OptimizationInfo` is defined further down in this module;
    # the forward reference works because `from __future__ import annotations`
    # means the annotation is not evaluated when the class body runs.
    stats: OptimizationInfo


class OptimizationInfo(eqx.Module):
    """Basic diagnostic information about finished optimization runs."""

    # The function value is optional because not every JAXopt solver stores it.
    # None marks a missing value, whereas NaN usually indicates that the
    # optimization diverged.

    #: Function value. Optional as not all solvers store it.
    function_val: jax.numpy.ndarray | None
    #: Number of optimization steps taken, array of int.
    num_steps: jax.numpy.ndarray
    #: Whether the optimization converged, array of bool.
    converged: jax.numpy.ndarray
    #: Reached the maximum number of allowed steps.
    reached_max_steps: jax.numpy.ndarray
class AbstractSolver(abc.ABC, Generic[SolverState]):
    """
    Base class defining the interface for solvers that can be used by `BaseRegressor`.

    All solver parameters (e.g. tolerance, number of steps) are passed to
    `__init__`; the other methods only take parameters, solver state, and the
    positional arguments of the objective function.
    """

    @abc.abstractmethod
    def __init__(
        self,
        unregularized_loss: Callable,
        regularizer: Regularizer,
        regularizer_strength: float | None,
        has_aux: bool,
        init_params: Params | None = None,
        **solver_init_kwargs: Any,
    ) -> None:
        """
        Create the solver.

        Parameters
        ----------
        unregularized_loss :
            Unregularized loss function. Currently `BaseRegressor.compute_loss`.
        regularizer :
            Regularizer object used to create the penalized loss or get the
            proximal operator from.
        regularizer_strength :
            Regularizer strength.
        has_aux :
            Whether `unregularized_loss` returns auxiliary variables.
            If False, the loss function is expected to return a single scalar.
            If True, the loss is expected to return a tuple of length 2 with a
            scalar and auxiliary variables.
        init_params :
            Initial model parameters. Passed to the regularizer's
            `get_proximal_operator` or `penalized_loss`.
        **solver_init_kwargs :
            Keyword arguments modifying the solver's behavior.
        """

    @abc.abstractmethod
    def init_state(self, init_params: Params, *args: Any) -> SolverState:
        """
        Initialize the solver state.

        Used by `BaseRegressor.initialize_state`.

        Parameters
        ----------
        init_params :
            Initial model parameters.
        *args :
            Positional arguments of the objective function.
        """

    @abc.abstractmethod
    def update(self, params: Params, state: SolverState, *args: Any) -> StepResult:
        """
        Perform a single step/update of the optimization process.

        Used by `BaseRegressor.update`.

        Parameters
        ----------
        params :
            Current model parameters.
        state :
            Current solver state.
        *args :
            Positional arguments of the objective function.
        """

    @abc.abstractmethod
    def run(self, init_params: Params, *args: Any) -> StepResult:
        """
        Run a full optimization process until a stopping criterion is reached.

        Used by `BaseRegressor.fit`.

        Parameters
        ----------
        init_params :
            Initial model parameters.
        *args :
            Positional arguments of the objective function.
        """

    @classmethod
    @abc.abstractmethod
    def get_accepted_arguments(cls) -> set[str]:
        """
        Set of argument names accepted by the solver.

        Used by `BaseRegressor` to determine what arguments can be passed to
        the solver's `__init__`.
        """

    @abc.abstractmethod
    def _get_optim_info(self, state: SolverState, **kwargs: Any) -> OptimizationInfo:
        """
        Extract some common info about the optimization process.

        Currently, the following info is extracted:

        - final function value (where available)
        - number of steps
        - whether the optimization converged
        - whether the max number of steps were reached
        """
@runtime_checkable
class SolverProtocol(Protocol, Generic[SolverState]):
    """
    Protocol mirroring the interface of AbstractSolver.

    Because the protocol is runtime-checkable, implementations can be
    verified with ``isinstance(solver_object, SolverProtocol)`` and
    ``issubclass(solver_class, SolverProtocol)``. Such runtime checks only
    test that the required methods are present, not their signatures.

    The private ``_get_optim_info`` hook of `AbstractSolver` is not part of
    this protocol.
    """

    def __init__(
        self,
        unregularized_loss: Callable,
        regularizer: Regularizer,
        regularizer_strength: float | None,
        has_aux: bool,
        init_params: Params | None = None,
        **solver_init_kwargs: Any,
    ) -> None:
        """Create the solver; same signature as `AbstractSolver.__init__`."""
        ...

    def init_state(self, init_params: Params, *args: Any) -> SolverState:
        """Initialize the solver state; counterpart of `AbstractSolver.init_state`."""
        ...

    def update(self, params: Params, state: SolverState, *args: Any) -> StepResult:
        """Perform a single optimization step; counterpart of `AbstractSolver.update`."""
        ...

    def run(self, init_params: Params, *args: Any) -> StepResult:
        """Run a full optimization; counterpart of `AbstractSolver.run`."""
        ...

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        """Return the set of argument names accepted by the solver."""
        ...