Source code for nemos.solvers._optimistix_adapter

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Type, TypeAlias

import equinox as eqx
import lazy_loader as lazy
import optimistix as optx
from packaging.version import Version

from ..typing import Aux, Params

if TYPE_CHECKING:
    from ..regularizer import Regularizer

from ._abstract_solver import OptimizationInfo, SolverAdapterState
from ._aux_helpers import (
    convert_fn,
    drop_aux,
    pack_args,
    tree_map_inexact_asarray,
    wrap_aux,
)
from ._solver_adapter import SolverAdapter

jax = lazy.load("jax")

_OPTX_V_010 = Version(optx.__version__) >= Version("0.1.0")

DEFAULT_ATOL = 1e-4
DEFAULT_RTOL = 0.0
DEFAULT_MAX_STEPS = 10_000

OptimistixSolverState: TypeAlias = eqx.Module


class OptimistixAdapterState(SolverAdapterState[OptimistixSolverState]):
    """Solver state for Optimistix-based adapters."""


OptimistixStepResult: TypeAlias = tuple[Params, OptimistixAdapterState, Aux]


@dataclasses.dataclass
class OptimistixConfig:
    """
    Collection of arguments required by and cached for methods of Optimistix solvers.

    They rarely need to be overwritten, and the defaults here should suffice.
    The user has the ability to overwrite them with `solver_kwargs`, and on the solver's construction
    they are saved in `OptimistixAdapter.config` for later use: passing them to `optimistix.optimise`, `init`, `step`.
    """

    # max number of steps
    maxiter: int
    # options dict passed around within optimistix
    options: dict[str, Any] = dataclasses.field(default_factory=dict)
    # "Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy."
    tags: frozenset = frozenset()
    # sets if the minimisation throws an error if an iterative solver runs out of steps
    throw: bool = False
    # norm used in the Cauchy convergence criterion. Required by all Optimistix solvers.
    norm: Callable = optx.two_norm
    # way of autodifferentiation: https://docs.kidger.site/optimistix/api/adjoints/
    adjoint: optx.AbstractAdjoint = optx.ImplicitAdjoint()


[docs] class OptimistixAdapter(SolverAdapter[OptimistixAdapterState]): """ Base class for adapters wrapping Optimistix minimizers. Subclasses must define the `_solver_cls` class attribute. The `_solver` attribute is assumed to exist after construction, so if a subclass is overwriting `__init__`, these must be created. Note that for backward compatibility the `atol` parameter used in Optimistix is referred to as `tol` in NeMoS. The `maxiter` default is taken from `DEFAULT_MAXITER`, which subclasses may override to set solver-specific defaults. """ _solver_cls: ClassVar[Type] _solver: optx.AbstractMinimiser DEFAULT_MAXITER: ClassVar[int] = DEFAULT_MAX_STEPS _proximal: ClassVar[bool] = False
[docs] def __init__( self, unregularized_loss: Callable, regularizer: Regularizer, regularizer_strength: float | None, has_aux: bool, init_params: Params | None = None, tol: float = DEFAULT_ATOL, rtol: float = DEFAULT_RTOL, maxiter: int | None = None, **solver_init_kwargs, ): if "atol" in solver_init_kwargs: raise TypeError("Please use tol instead of atol.") if "max_steps" in solver_init_kwargs: raise TypeError("Please use maxiter instead of max_steps.") if self._proximal: loss_fn = unregularized_loss solver_init_kwargs["prox"] = regularizer.get_proximal_operator( params=init_params, strength=regularizer_strength ) else: loss_fn = regularizer.penalized_loss( unregularized_loss, params=init_params, strength=regularizer_strength ) # take out the arguments that go into minimise, init, terminate and so on # and only pass the actually needed things to __init__ user_args = {} for f in dataclasses.fields(OptimistixConfig): kw = f.name if kw in solver_init_kwargs: user_args[kw] = solver_init_kwargs.pop(kw) if maxiter is None: maxiter = self.DEFAULT_MAXITER self.config = OptimistixConfig(maxiter=maxiter, **user_args) if has_aux: self.fun_with_aux = pack_args(loss_fn) self.fun = drop_aux(self.fun_with_aux) else: self.fun = pack_args(loss_fn) self.fun_with_aux = wrap_aux(self.fun) # make custom adjustments such as adding a derived "while_loop_kind" parameter for FISTA solver_init_kwargs = self.adjust_solver_init_kwargs(solver_init_kwargs) self._solver = self._solver_cls( atol=tol, rtol=rtol, norm=self.config.norm, **solver_init_kwargs, )
[docs] def init_state(self, init_params: Params, *args: Any) -> OptimistixAdapterState: init_params = tree_map_inexact_asarray(init_params) fn = convert_fn(self.fun_with_aux, True, init_params, args) f_struct, aux_struct = fn.out_struct solver_state = self._solver.init( fn, init_params, args, self.config.options, f_struct, aux_struct, self.config.tags, ) stats = OptimizationInfo( function_val=jax.numpy.nan, # pyright: ignore num_steps=jax.numpy.array(0), converged=jax.numpy.array(False), # pyright: ignore reached_max_steps=jax.numpy.array(False), ) return OptimistixAdapterState(solver_state=solver_state, stats=stats)
[docs] def update( self, params: Params, state: OptimistixAdapterState, *args: Any, ) -> OptimistixStepResult: params = tree_map_inexact_asarray(params) fn = convert_fn(self.fun_with_aux, True, params, args) new_params, solver_state, aux = self._solver.step( fn=fn, y=params, args=args, state=state.solver_state, options=self.config.options, tags=self.config.tags, ) num_steps = state.stats.num_steps + 1 stats = self._get_optim_info(solver_state, num_steps=num_steps) state = OptimistixAdapterState(solver_state=solver_state, stats=stats) return new_params, state, aux
[docs] def run( self, init_params: Params, *args: Any, ) -> OptimistixStepResult: solution = optx.minimise( fn=self.fun_with_aux, solver=self._solver, y0=init_params, args=args, options=self.config.options, has_aux=True, max_steps=self.config.maxiter, adjoint=self.config.adjoint, throw=self.config.throw, tags=self.config.tags, ) stats = self._get_optim_info( solution.state, num_steps=solution.stats["num_steps"] ) state = OptimistixAdapterState(solver_state=solution.state, stats=stats) return solution.value, state, solution.aux
[docs] @classmethod def get_accepted_arguments(cls) -> set[str]: own_and_solver_args = super().get_accepted_arguments() # atol is added from wrapped optimistix solvers # but currently throughout nemos tol is used own_and_solver_args.remove("atol") common_optx_arguments = set( [f.name for f in dataclasses.fields(OptimistixConfig)] ) all_arguments = own_and_solver_args | common_optx_arguments # prox is read from the regularizer, not provided as a solver argument if cls._proximal: all_arguments.remove("prox") return all_arguments
@classmethod def _note_about_accepted_arguments(cls) -> str: return """ Note that for backward compatibility the `atol` parameter used in Optimistix is referred to as `tol` in NeMoS. """ @property def maxiter(self) -> int: return self.config.maxiter def _get_optim_info( self, state: OptimistixSolverState, num_steps: jax.numpy.ndarray = jax.numpy.array(0), ) -> OptimizationInfo: function_val = ( state.f if hasattr(state, "f") else state.f_info.f ) # pyright: ignore return OptimizationInfo( function_val=function_val, num_steps=num_steps, converged=state.terminate, # pyright: ignore reached_max_steps=jax.numpy.asarray(num_steps >= self.maxiter), )
[docs] def adjust_solver_init_kwargs( self, solver_init_kwargs: dict[str, Any] ) -> dict[str, Any]: """ Optionally adjust the parameters (e.g. derive from self.config) for instantiating the wrapped solver. Parameters ---------- solver_init_kwargs: Original keyword arguments that would be passed to _solver_cls.__init__. Returns ------- dict with argument names of _solver_cls.__init__ as keys and their corresponding values as values. Default implementation just returns solver_init_kwargs. """ return solver_init_kwargs