Source code for nemos.solvers._solver_registry

"""Registry of optimization algorithms and their implementations."""

from dataclasses import dataclass
from importlib.util import find_spec as _find_spec
from typing import Type

from ._abstract_solver import SolverProtocol
from ._fista import OptimistixFISTA, OptimistixNAG
from ._optax_optimistix_solvers import (
    OptimistixOptaxGradientDescent,
    OptimistixOptaxLBFGS,
)
from ._optimistix_solvers import OptimistixBFGS, OptimistixNonlinearCG
from ._svrg import WrappedProxSVRG, WrappedSVRG
from ._validation import validate_solver_class

JAXOPT_AVAILABLE = _find_spec("jaxopt") is not None


[docs] @dataclass class SolverSpec: """ Solver specification representing an entry in the solver registry. A solver is specified by: - the name of the algorithm it implements - its backend (optimization library or custom) - the class implementing the optimization method (ideally compatible with the AbstractSolver and SolverProtocol interface) Examples -------- >>> import nemos as nmo >>> spec = nmo.solvers.SolverSpec("BFGS", "optimistix", nmo.solvers._optimistix_solvers.OptimistixBFGS) >>> spec.algo_name 'BFGS' >>> spec.backend 'optimistix' >>> spec.implementation <class 'nemos.solvers._optimistix_solvers.OptimistixBFGS'> """ algo_name: str backend: str implementation: Type[SolverProtocol] @property def full_name(self) -> str: return f"{self.algo_name}[{self.backend}]" def __repr__(self) -> str: return ( f'"{self.full_name}" - ' f"{self.__class__.__name__}(\n" f' algo_name="{self.algo_name}",\n' f' backend="{self.backend}",\n' f" implementation={self.implementation.__module__}.{self.implementation.__qualname__}\n)" )
# mapping is {algo_name : {backend : implementation}} _registry: dict[str, dict[str, SolverSpec]] = {} # mapping is {algo_name : backend} _defaults: dict[str, str] = {} def _parse_name(name: str) -> tuple[str, str | None]: """Parse an algo_name[backend] string.""" algo_name = name backend = None if "[" in name: if name.count("[") > 1: raise ValueError( f"Found multiple opening '[' in solver name of {name}. " "Only use '[' for specifying the backend " "using the algo_name[backend_name] syntax. " ) if name.count("]") > 1: raise ValueError( f"Found multiple closing ']' in solver name of {name}. " "Only use ']' for specifying the backend " "using the algo_name[backend_name] syntax. " ) if "]" not in name: raise ValueError( "Found opening '[' in solver name but it does not end with closing ']'. " "Solver name can only use '[' for specifying the backend " "using the algo_name[backend_name] syntax. " f"Got {name}" ) if name.index("]") != len(name) - 1: raise ValueError( "Found closing ']' in the middle of the solver name. " "Brackets are reserved for specifying the backend " "using the algo_name[backend_name] syntax. " f"Got {name}" ) algo_name = name[: name.index("[")] backend = name[name.index("[") + 1 : -1] elif "]" in name: raise ValueError( "Found closing ']' in solver name without opening '['. " "Solver name can only use '[' for specifying the backend " "using the algo_name[backend_name] syntax. " f"Got {name}" ) if algo_name == "": raise ValueError("Algorithm name cannot be an empty string.") if backend == "": raise ValueError("Backend name cannot be an empty string.") return algo_name, backend def _raise_if_not_in_registry(algo_name: str): """Raise an error if an algorithm is not in the registry.""" if algo_name not in _registry: raise ValueError(f"No solver registered for algorithm {algo_name}.") def _resolve_backend(name: str, raise_if_given: bool) -> str: """ Return the backend that will be used for the algorithm if not specified. Parameters ---------- name: Name of the algorithm. raise_if_given: Raise an error if a backend is given, i.e. algo_name[backend_name] format is used. Returns ------- Backend name extracted from the registry. """ algo_name, backend = _parse_name(name) if backend is not None: if not raise_if_given: return backend raise ValueError( f"Provide an algorithm name only. Got {algo_name} with backend {backend}." ) _raise_if_not_in_registry(algo_name) algo_versions = _registry[algo_name] backend = _defaults.get(algo_name, None) if backend is None: if len(algo_versions) == 1: backend = next(iter(algo_versions.keys())) else: _spec = " " if raise_if_given else " specify or " raise ValueError( f"Multiple backends and no default found for {algo_name}. " f"Please{_spec}set a default backend." ) return backend
[docs] def get_solver(name: str) -> SolverSpec: """ Fetch the solver spec. from the registry for a given solver. Parameters ---------- name : Name of the solver with or without backend specified. Returns ------- spec : Specification for the solver, listing algorithm name, backend, implementation class. """ algo_name, _ = _parse_name(name) backend = _resolve_backend(name, False) # make sure we have the algorithm _raise_if_not_in_registry(algo_name) algo_versions = _registry[algo_name] if backend not in algo_versions: raise ValueError( f"{backend} backend not available for {algo_name}. " f"Available backends: {list_algo_backends(algo_name)}" ) return algo_versions[backend]
[docs] def register( algo_name: str, implementation: Type[SolverProtocol], backend: str = "custom", replace: bool = False, default: bool = False, validate: bool = True, test_ridge_without_aux: bool = False, test_ridge_with_aux: bool = False, ) -> None: """ Register a solver implementation in the registry. Parameters ---------- algo_name : Name of the optimization algorithm. implementation : Class implementing the solver. Has to adhere to the AbstractSolver interface. backend : Backend name. Defaults to "custom". When wrapping and registering an existing solver from an external package, this would be the package name. replace : If an implementation for the given algorithm and backend names is already present in the registry, overwrite it. default : Set this implementation as the default for the algorithm. Can also be done with `set_default_backend`. validate : Validate all required methods exist and have correct signatures. test_ridge_without_aux : Validate solver signatures and functionality by running a small ridge regression, objective function without aux. test_ridge_with_aux : Validate solver signatures and functionality by running a small ridgeregression, testing that objective functions with auxiliary variables are handled. Examples -------- >>> import nemos as nmo >>> nmo.solvers.register("FISTA", nmo.solvers._fista.OptimistixFISTA, backend="optimistix") """ if not replace and backend in _registry.get(algo_name, {}): raise ValueError( f"{algo_name}[{backend}] already registered. Use replace=True to overwrite." ) if not issubclass(implementation, SolverProtocol): raise TypeError(f"{implementation.__name__} doesn't implement SolverProtocol.") if validate: validate_solver_class(implementation, False, False) if test_ridge_without_aux: validate_solver_class(implementation, True, False) if test_ridge_with_aux: validate_solver_class(implementation, True, True) if algo_name not in _registry: _registry[algo_name] = {} _registry[algo_name][backend] = SolverSpec(algo_name, backend, implementation) if default: set_default_backend(algo_name, backend)
[docs] def set_default_backend(algo_name: str, backend: str) -> None: """ Set the default backend for a given algorithm. Parameters ---------- algo_name : Name of the optimization algorithm whose default backend to set. backend : Name of the backend to set as default. Examples -------- >>> import nemos as nmo >>> nmo.solvers.set_default_backend("LBFGS", "optax+optimistix") >>> nmo.solvers.get_solver("LBFGS").backend 'optax+optimistix' """ _raise_if_not_in_registry(algo_name) if backend not in _registry[algo_name]: raise ValueError( f"{backend} backend not available for {algo_name}." f"Available backends: {list_algo_backends(algo_name)}" ) _defaults[algo_name] = backend
[docs] def list_algo_backends(algo_name: str) -> list[str]: """ List the available backends for an algorithm. Parameters ---------- algo_name : Name of the optimization algorithm. """ _raise_if_not_in_registry(algo_name) return list(_registry[algo_name].keys())
[docs] def list_available_solvers() -> list[SolverSpec]: """List all available solvers.""" return [ spec for algo_versions in _registry.values() for spec in algo_versions.values() ]
[docs] def list_available_algorithms() -> list[str]: """ List the available algorithms that can be used for fitting models. To list the available backends for a given algorithm, see `list_algo_backends`. To access an extended documentation about a specific solver, see `nemos.solvers.get_solver_documentation`. Example ------- >>> import nemos as nmo >>> nmo.solvers.list_available_algorithms() ['GradientDescent', 'ProximalGradient', 'LBFGS', 'BFGS', 'NonlinearCG', 'SVRG', 'ProxSVRG'] """ return list(_registry.keys())
register("GradientDescent", OptimistixNAG, "optimistix", default=True) register("ProximalGradient", OptimistixFISTA, "optimistix", default=True) register("LBFGS", OptimistixOptaxLBFGS, "optax+optimistix", default=True) register("BFGS", OptimistixBFGS, "optimistix", default=True) register("NonlinearCG", OptimistixNonlinearCG, "optimistix", default=True) register("SVRG", WrappedSVRG, "nemos", default=True) register("ProxSVRG", WrappedProxSVRG, "nemos", default=True) register( "GradientDescent", OptimistixOptaxGradientDescent, "optax+optimistix", default=False ) if JAXOPT_AVAILABLE: from ._jaxopt_solvers import ( JaxoptBFGS, JaxoptGradientDescent, JaxoptLBFGS, JaxoptNonlinearCG, JaxoptProximalGradient, ) register("GradientDescent", JaxoptGradientDescent, "jaxopt", default=False) register("ProximalGradient", JaxoptProximalGradient, "jaxopt", default=False) register("LBFGS", JaxoptLBFGS, "jaxopt", default=False) register("BFGS", JaxoptBFGS, "jaxopt", default=False) register("NonlinearCG", JaxoptNonlinearCG, "jaxopt", default=False)