"""Base class defining the interface for solvers that can be used by `BaseRegressor`."""
from __future__ import annotations
import abc
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, runtime_checkable
from ..typing import Params, SolverState, StepResult
if TYPE_CHECKING:
from ..regularizer import Regularizer
import equinox as eqx
import jax
class SolverAdapterState(eqx.Module, Generic[SolverState]):
    """Bundle a backend solver's internal state with optimization diagnostics."""

    # Backend-specific internal state of the wrapped solver.
    solver_state: SolverState
    # Diagnostics (num_steps, converged, etc.) computed during the run;
    # stored as a valid JAX pytree.
    stats: OptimizationInfo
[docs]
class OptimizationInfo(eqx.Module):
    """Basic diagnostic information about finished optimization runs."""

    # Some JAXopt solvers do not store the function value, hence the optional
    # type below. None marks a missing value, whereas NaN usually indicates
    # that the optimization diverged.

    #: Function value. Optional as not all solvers store it.
    function_val: jax.numpy.ndarray | None
    #: Number of optimization steps taken, array of int.
    num_steps: jax.numpy.ndarray
    #: Whether the optimization converged, array of bool.
    converged: jax.numpy.ndarray
    #: Reached the maximum number of allowed steps.
    reached_max_steps: jax.numpy.ndarray
[docs]
class AbstractSolver(abc.ABC, Generic[SolverState]):
    """
    Base class defining the interface for solvers that can be used by `BaseRegressor`.

    All solver parameters (e.g. tolerance, number of steps) are passed to `__init__`,
    the other methods only take parameters, solver state, and the positional arguments of
    the objective function.
    """

    @abc.abstractmethod
    def __init__(
        self,
        unregularized_loss: Callable,
        regularizer: Regularizer,
        regularizer_strength: float | None,
        has_aux: bool,
        init_params: Params | None = None,
        **solver_init_kwargs,
    ):
        """
        Create the solver.

        Arguments
        ---------
        unregularized_loss:
            Unregularized loss function.
            Currently `BaseRegressor.compute_loss`.
        regularizer:
            Regularizer object used to create the penalized loss
            or get the proximal operator from.
        regularizer_strength:
            Regularizer strength.
        has_aux:
            Whether `unregularized_loss` returns auxiliary variables.
            If False, the loss function is expected to return a single scalar.
            If True, the loss is expected to return a tuple of length 2 with a scalar and auxiliary variables.
        init_params:
            Initial model parameters. Passed to the regularizer's `get_proximal_operator` or `penalized_loss`.
        **solver_init_kwargs:
            Keyword arguments modifying the solver's behavior.
        """

    @abc.abstractmethod
    def init_state(self, init_params: Params, *args: Any) -> SolverState:
        """
        Initialize the solver state.

        Used by `BaseRegressor.initialize_state`.

        Arguments
        ---------
        init_params:
            Initial model parameters.
        *args:
            Positional arguments of the objective function.

        Returns
        -------
        :
            The initial solver state.
        """

    @abc.abstractmethod
    def update(self, params: Params, state: SolverState, *args: Any) -> StepResult:
        """
        Perform a single step/update of the optimization process.

        Used by `BaseRegressor.update`.

        Arguments
        ---------
        params:
            Current model parameters.
        state:
            Current solver state.
        *args:
            Positional arguments of the objective function.

        Returns
        -------
        :
            The outcome of the step, as a `StepResult`.
        """

    @abc.abstractmethod
    def run(self, init_params: Params, *args: Any) -> StepResult:
        """
        Run a full optimization process until a stopping criterion is reached.

        Used by `BaseRegressor.fit`.

        Arguments
        ---------
        init_params:
            Initial model parameters.
        *args:
            Positional arguments of the objective function.

        Returns
        -------
        :
            The outcome of the optimization, as a `StepResult`.
        """

    # `classmethod` must be the outermost decorator when combined with
    # `abstractmethod`.
    @classmethod
    @abc.abstractmethod
    def get_accepted_arguments(cls) -> set[str]:
        """
        Set of argument names accepted by the solver.

        Used by `BaseRegressor` to determine what arguments
        can be passed to the solver's __init__.
        """

    @abc.abstractmethod
    def _get_optim_info(self, state: SolverState, **kwargs) -> OptimizationInfo:
        """Extract some common info about the optimization process.

        Currently, the following info is extracted:

        - final function value (where available)
        - number of steps
        - whether the optimization converged
        - whether the max number of steps were reached

        Arguments
        ---------
        state:
            Solver state to read the diagnostics from.
        **kwargs:
            Extra keyword arguments; their meaning is left to the concrete
            solver (not specified by this interface).
        """
[docs]
@runtime_checkable
class SolverProtocol(Protocol, Generic[SolverState]):
    """
    Protocol mirroring the interface of AbstractSolver.

    Implementations can be checked at runtime via isinstance(solver_object, SolverProtocol)
    and issubclass(solver_class, SolverProtocol).
    """

    def __init__(
        self,
        unregularized_loss: Callable,
        regularizer: Regularizer,
        regularizer_strength: float | None,
        has_aux: bool,
        init_params: Params | None = None,
        **solver_init_kwargs: Any,
    ) -> None:
        """Create the solver; the signature mirrors `AbstractSolver.__init__`."""
        ...

    def init_state(self, init_params: Params, *args: Any) -> SolverState:
        """Initialize the solver state."""
        ...

    def update(self, params: Params, state: SolverState, *args: Any) -> StepResult:
        """Perform a single step/update of the optimization process."""
        ...

    def run(self, init_params: Params, *args: Any) -> StepResult:
        """Run a full optimization process."""
        ...

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        """Set of argument names accepted by the solver."""
        ...