The solvers Module
Background
Earlier versions of NeMoS relied on JAXopt as their optimization backend. As JAXopt is no longer maintained, we added support for alternative optimization backends. JAXopt remains optionally supported as an extra dependency.
To support flexibility and long-term maintainability, NeMoS now provides a backend-agnostic solver interface, which allows it to use solvers from backend libraries that each have their own interface.
In particular, NeMoS's solver interface is designed to be compatible with solvers from JAXopt, Google's Optax, and the community-run Optimistix.
AbstractSolver interface
This interface is defined by AbstractSolver and mostly follows the JAXopt API.
All solvers implemented in NeMoS are subclasses of AbstractSolver; however, subclassing is not strictly required for implementing solvers that can be used with NeMoS (see Custom solvers).
The AbstractSolver interface requires implementing the following methods:
__init__: Construct a solver object. All solver parameters and settings (tolerance, maximum number of steps, etc.) are passed here. The other methods only take the solver state, the current or initial solution (model parameters), and the input data for the objective function.
init_state: Initialize the solver state.
update: Take one step of the optimization algorithm.
run: Run a full optimization.
get_accepted_arguments: Set of argument names that can be passed to __init__. These will be the parameters users can change by passing solver_kwargs to NeMoS models (e.g., GLM).
_get_optim_info: Collect diagnostic information about the optimization run into an OptimizationInfo namedtuple, described in the next section.
AbstractSolver is a generic class parametrized by SolverState and StepResult.
SolverState in concrete subclasses should be the type of the solver state.
StepResult is the type of what is returned by each step of the solver. Typically this is a tuple of the parameters, the solver state, and auxiliary variables returned by the objective function.
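A minimal sketch of what a class implementing this interface might look like, written in plain Python so it runs without NeMoS or JAX. ToyAbstractSolver, GDState, and ToyGradientDescent are hypothetical names; the real AbstractSolver has its own constructor arguments (for example, how the objective and regularizer are passed), which are not reproduced here. The sketch only shows the two type parameters in use: the solver state is a small NamedTuple, and each step returns a (params, state, aux) tuple.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, NamedTuple, Tuple, TypeVar

# Hypothetical stand-in for AbstractSolver; not the NeMoS class.
SolverState = TypeVar("SolverState")
StepResult = TypeVar("StepResult")


class ToyAbstractSolver(ABC, Generic[SolverState, StepResult]):
    @abstractmethod
    def init_state(self, init_params, *args) -> SolverState:
        """Initialize the solver state."""

    @abstractmethod
    def update(self, params, state: SolverState, *args) -> StepResult:
        """Take one optimization step."""

    @abstractmethod
    def run(self, init_params, *args) -> StepResult:
        """Run a full optimization."""

    @classmethod
    @abstractmethod
    def get_accepted_arguments(cls) -> set:
        """Names of the arguments accepted by __init__."""


class GDState(NamedTuple):
    iter_num: int
    grad_norm: float


# Here StepResult is (params, state, aux); aux is unused (None).
GDStep = Tuple[float, GDState, Any]


class ToyGradientDescent(ToyAbstractSolver[GDState, GDStep]):
    """Gradient descent on a scalar parameter (hypothetical example)."""

    def __init__(self, grad: Callable, stepsize: float = 0.1, tol: float = 1e-8, maxiter: int = 1000):
        # All settings are fixed here; the other methods only receive
        # parameters, state, and the objective's input data (*args).
        self.grad = grad
        self.stepsize = stepsize
        self.tol = tol
        self.maxiter = maxiter

    def init_state(self, init_params, *args):
        return GDState(iter_num=0, grad_norm=float("inf"))

    def update(self, params, state, *args):
        g = self.grad(params, *args)
        return params - self.stepsize * g, GDState(state.iter_num + 1, abs(g)), None

    def run(self, init_params, *args):
        params, state, aux = init_params, self.init_state(init_params, *args), None
        while state.iter_num < self.maxiter and state.grad_norm > self.tol:
            params, state, aux = self.update(params, state, *args)
        return params, state, aux

    @classmethod
    def get_accepted_arguments(cls):
        return {"stepsize", "tol", "maxiter"}


# Minimize (x - target)**2 with target = 3.0 passed as input data.
solver = ToyGradientDescent(lambda x, target: 2.0 * (x - target))
params, state, _ = solver.run(0.0, 3.0)
```

Because the methods are declared with abc.abstractmethod, a subclass that omits one of them cannot be instantiated, which is one reason subclassing is a convenient way to stay within an interface.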
Optimization info
Because different libraries store information about the optimization run in different places, we decided to standardize some common diagnostics.
These are accessed using the _get_optim_info method, which takes the solver state and returns an OptimizationInfo.
OptimizationInfo holds the following fields:
function_val: The final value of the objective function. As not all solvers store this by default, and as it is potentially expensive to evaluate, this field is optional.
num_steps: The number of steps taken by the solver.
converged: Whether the optimization converged according to the solver's criteria.
reached_max_steps: Whether the solver reached the maximum number of steps allowed.
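The sketch below re-creates a namedtuple with these four fields and a hypothetical translation function from a made-up backend state (a dict with iter_num, error, tol, and an optional value). It only illustrates mapping backend-specific state onto the common fields; the field types and the convergence rule (error <= tol) are assumptions of this sketch, not the NeMoS definitions.

```python
from typing import NamedTuple, Optional

# Hypothetical re-creation for illustration; not imported from NeMoS.


class OptimizationInfo(NamedTuple):
    function_val: Optional[float]  # optional: not every backend tracks it
    num_steps: int
    converged: bool
    reached_max_steps: bool


def get_optim_info_from_state(state: dict, maxiter: int) -> OptimizationInfo:
    """Translate a made-up backend state dict into the common diagnostics."""
    num_steps = state["iter_num"]
    return OptimizationInfo(
        function_val=state.get("value"),  # None when the backend does not store it
        num_steps=num_steps,
        converged=state["error"] <= state["tol"],
        reached_max_steps=num_steps >= maxiter,
    )


info = get_optim_info_from_state({"iter_num": 12, "error": 1e-9, "tol": 1e-8}, maxiter=100)
```

Here info is OptimizationInfo(function_val=None, num_steps=12, converged=True, reached_max_steps=False).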
Adapters
Support for existing solvers from external libraries and the custom implementation of (Prox-)SVRG is done through adapters that “translate” between the interfaces of these external solvers and the AbstractSolver interface.
Adapters for existing solvers can be created in multiple ways. In our experience, wrapping solver objects in adapters provides a clean way of doing this, and the adapters in NeMoS follow this pattern.
Currently there are adapters implemented for two optimization backends:
OptimistixAdapter wraps Optimistix solvers.
JaxoptAdapter wraps JAXopt solvers when the optional jaxopt dependency is installed. As SVRG and ProxSVRG follow the JAXopt-style interface, these are also wrapped with JaxoptAdapter, even without JAXopt installed.
Both are subclasses of SolverAdapter, which provides common methods for wrapping existing solvers.
Each subclass of SolverAdapter defines the methods of AbstractSolver, as well as a _solver_cls class variable indicating the type of solver it wraps.
During construction, each adapter sets a _solver attribute holding a concrete instance of _solver_cls.
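The wrapping pattern can be sketched in plain Python. ExternalGD stands in for a third-party solver whose methods (start, step) do not match the interface; ToyAdapter holds an instance of it in _solver and translates init_state, update, and run into calls on it. Both classes are hypothetical and far simpler than the NeMoS adapters.

```python
# Hypothetical sketch of the adapter pattern; not the NeMoS SolverAdapter.


class ExternalGD:
    """Stand-in for a third-party solver with its own method names."""

    def __init__(self, grad, learning_rate=0.1, max_steps=100):
        self.grad = grad
        self.learning_rate = learning_rate
        self.max_steps = max_steps

    def start(self, x0):
        return {"step": 0}

    def step(self, x, data, opt_state):
        x = x - self.learning_rate * self.grad(x, data)
        return x, {"step": opt_state["step"] + 1}


class ToyAdapter:
    """Exposes init_state/update/run on top of ExternalGD."""

    _solver_cls = ExternalGD  # type of the wrapped solver

    def __init__(self, grad, **solver_kwargs):
        # Concrete instance of _solver_cls, analogous to the _solver attribute.
        self._solver = self._solver_cls(grad, **solver_kwargs)

    def init_state(self, init_params, data):
        return self._solver.start(init_params)

    def update(self, params, state, data):
        params, state = self._solver.step(params, data, state)
        return params, state, None  # no auxiliary output in this example

    def run(self, init_params, data):
        params, state = init_params, self.init_state(init_params, data)
        for _ in range(self._solver.max_steps):
            params, state, _ = self.update(params, state, data)
        return params, state, None


adapter = ToyAdapter(lambda x, target: 2.0 * (x - target), learning_rate=0.1)
params, state, _ = adapter.run(0.0, 3.0)
```

The rest of the code only sees init_state, update, and run, so swapping in a different external solver only requires a new adapter.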
Default method implementations in SolverAdapter:
get_accepted_arguments returns the arguments to __init__, _solver_cls, and _solver_cls.__init__, discarding the ones required by AbstractSolver.__init__.
__getattr__ forwards attribute lookups that are not found on the adapter itself to the wrapped _solver.
__init_subclass__ generates a docstring for the adapter, including the accepted arguments and the wrapped solver's documentation. Extra notes about accepted arguments can be included in the docstrings of subclasses using _note_about_accepted_arguments. OptimistixAdapter uses this to add a note about the different naming of the tolerance parameter.
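A toy version of these three defaults, to show the Python mechanics involved (inspect.signature, __getattr__, and __init_subclass__). ToySolverAdapter, ExternalSolver, and the _required_args set, which stands in for the arguments of AbstractSolver.__init__, are assumptions of this sketch, not NeMoS code.

```python
import inspect

# Hypothetical sketch mimicking the SolverAdapter defaults; not NeMoS code.


class ExternalSolver:
    """A made-up third-party solver that the adapter will wrap."""

    def __init__(self, fun, stepsize=0.1, maxiter=100):
        self.fun = fun
        self.stepsize = stepsize
        self.maxiter = maxiter

    def describe(self):
        return f"stepsize={self.stepsize}"


class ToySolverAdapter:
    _solver_cls = None
    # Stand-in for the arguments required by AbstractSolver.__init__ (assumption).
    _required_args = {"self", "fun", "solver_kwargs"}
    _note_about_accepted_arguments = ""

    def __init__(self, fun, **solver_kwargs):
        self._solver = self._solver_cls(fun, **solver_kwargs)

    @classmethod
    def get_accepted_arguments(cls):
        names = set()
        for target in (cls.__init__, cls._solver_cls, cls._solver_cls.__init__):
            names |= set(inspect.signature(target).parameters)
        return names - cls._required_args

    def __getattr__(self, name):
        # Only called when normal lookup fails, so the adapter's own
        # attributes take precedence over the wrapped solver's.
        if name == "_solver":
            raise AttributeError(name)  # avoid infinite recursion before __init__
        return getattr(self._solver, name)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls._solver_cls is None:
            return  # intermediate abstract adapter: nothing to document yet
        accepted = ", ".join(sorted(cls.get_accepted_arguments()))
        cls.__doc__ = "\n".join(
            [
                f"Adapter for {cls._solver_cls.__name__}.",
                f"Accepted arguments: {accepted}",
                cls._note_about_accepted_arguments,
                cls._solver_cls.__doc__ or "",
            ]
        )


class ExternalSolverAdapter(ToySolverAdapter):
    _solver_cls = ExternalSolver
    _note_about_accepted_arguments = "Note: `stepsize` plays the role of a learning rate."


adapter = ExternalSolverAdapter(lambda x: x ** 2, stepsize=0.5)
```

Here adapter.describe() and adapter.maxiter are resolved on the wrapped ExternalSolver, because the adapter itself defines neither.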
List of available solvers
Abstract Class AbstractSolver
│
├─ Abstract Subclass SolverAdapter
│ │
│ ├─ Abstract Subclass OptimistixAdapter
│ │ │
│ │ ├─ Concrete Subclass OptimistixBFGS
│ │ ├─ Concrete Subclass OptimistixFISTA
│ │ ├─ Concrete Subclass OptimistixNAG
│ │ ├─ Concrete Subclass OptimistixNonlinearCG
│ │ └─ Abstract Subclass AbstractOptimistixOptaxSolver
│ │ │
│ │ ├─ Concrete Subclass OptimistixOptaxLBFGS
│ │ └─ Concrete Subclass OptimistixOptaxGradientDescent
│ │
│ └─ Abstract Subclass JaxoptAdapter
│ │
│ ├─ Concrete Subclass JaxoptLBFGS (optional)
│ ├─ Concrete Subclass JaxoptGradientDescent (optional)
│ ├─ Concrete Subclass JaxoptProximalGradient (optional)
│ ├─ Concrete Subclass JaxoptBFGS (optional)
│ ├─ Concrete Subclass JaxoptNonlinearCG (optional)
│ │
│ ├─ Concrete Subclass WrappedSVRG
│ └─ Concrete Subclass WrappedProxSVRG
AbstractOptimistixOptaxSolver is an adapter for Optax solvers that relies on optimistix.OptaxMinimiser to run the full optimization loop. If needed, it can also be used to wrap adaptive solvers (e.g., Adam).
Gradient descent is implemented by two classes:
One wraps optax.sgd, which supports momentum and acceleration. Note that what Optax calls Nesterov acceleration is not the original method developed for convex optimization but the version adapted for training deep networks with SGD.
As JAXopt implemented the original method, the other is a port of JAXopt's GradientDescent, added to NeMoS as OptimistixNAG.
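To make the distinction concrete, here is a plain-Python sketch of both update rules on a small quadratic. The convex version extrapolates with the classical schedule t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2, so the momentum coefficient grows from 0 towards 1; the deep-learning variant uses a fixed momentum coefficient. Both functions are our own minimal illustrations with hypothetical names, not the Optax, JAXopt, or NeMoS implementations.

```python
import math

# Hypothetical sketch contrasting two "Nesterov" update rules; this is not the
# Optax, JAXopt, or NeMoS code.


def grad_f(x):
    # Gradient of the quadratic f(x) = 0.5 * (x[0]**2 + 10 * x[1]**2).
    return (x[0], 10.0 * x[1])


def nesterov_convex(x0, stepsize, num_steps):
    """Classical accelerated gradient with the t_k momentum schedule."""
    x_prev, y, t = x0, x0, 1.0
    for _ in range(num_steps):
        g = grad_f(y)
        x = tuple(yj - stepsize * gj for yj, gj in zip(y, g))
        t_next = (1.0 + math.sqrt(1.0 + 4.0 * t * t)) / 2.0
        beta = (t - 1.0) / t_next  # starts at 0 and increases towards 1
        y = tuple(xj + beta * (xj - xpj) for xj, xpj in zip(x, x_prev))
        x_prev, t = x, t_next
    return x_prev


def nesterov_constant_momentum(x0, stepsize, momentum, num_steps):
    """Deep-learning style variant with a fixed momentum coefficient."""
    x = x0
    trace = tuple(0.0 for _ in x0)
    for _ in range(num_steps):
        g = grad_f(x)
        trace = tuple(gj + momentum * mj for gj, mj in zip(g, trace))
        # Look-ahead: apply the momentum once more on top of the fresh trace.
        direction = tuple(gj + momentum * mj for gj, mj in zip(g, trace))
        x = tuple(xj - stepsize * dj for xj, dj in zip(x, direction))
    return x


x_a = nesterov_convex((1.0, 1.0), stepsize=0.1, num_steps=500)
x_b = nesterov_constant_momentum((1.0, 1.0), stepsize=0.01, momentum=0.9, num_steps=500)
```

Both reach the minimizer at the origin on this problem; the difference lies in the momentum schedule, which is what the convergence guarantees for convex problems rely on.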
Similarly to NAG, an accelerated proximal gradient algorithm (FISTA) was ported from JAXopt as OptimistixFISTA.
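FISTA adds a proximal step to the same accelerated scheme. The sketch below is our own illustration, not the code of OptimistixFISTA: it solves a small separable Lasso-type problem, 0.5 * sum_j d_j (x_j - b_j)^2 + lam * ||x||_1 with d = (1, 4), b = (3, 0.2), and lam = 1, whose exact solution is the per-coordinate soft-thresholding S(b_j, lam / d_j) = (2, 0).

```python
import math

# Hypothetical FISTA sketch for illustration; not the NeMoS implementation.
D = (1.0, 4.0)  # diagonal of the quadratic term
B = (3.0, 0.2)  # center of the quadratic term
LAM = 1.0       # L1 penalty strength


def grad_smooth(x):
    # Gradient of 0.5 * sum(d_j * (x_j - b_j)**2).
    return tuple(d * (xj - bj) for d, xj, bj in zip(D, x, B))


def soft_threshold(v, thr):
    # Proximal operator of thr * |v|.
    return math.copysign(max(abs(v) - thr, 0.0), v)


def fista(x0, num_steps=500):
    stepsize = 1.0 / max(D)  # 1 / Lipschitz constant of the smooth gradient
    x_prev, y, t = x0, x0, 1.0
    for _ in range(num_steps):
        g = grad_smooth(y)
        # Proximal gradient step taken from the extrapolated point y.
        x = tuple(soft_threshold(yj - stepsize * gj, stepsize * LAM) for yj, gj in zip(y, g))
        t_next = (1.0 + math.sqrt(1.0 + 4.0 * t * t)) / 2.0
        y = tuple(xj + (t - 1.0) / t_next * (xj - xpj) for xj, xpj in zip(x, x_prev))
        x_prev, t = x, t_next
    return x_prev


x_hat = fista((0.0, 0.0))
x_exact = tuple(soft_threshold(bj, LAM / d) for bj, d in zip(B, D))
```

With the proximal operator replaced by the identity, the loop reduces to the accelerated gradient method used by NAG.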
Available solvers and which implementation they dispatch to are defined in the solver registry.
A list of available algorithms is provided by nemos.solvers.list_available_algorithms().
All solvers in the registry can be listed with nemos.solvers.list_available_solvers(), and extended documentation about each solver can be accessed using nemos.solvers.get_solver_documentation().
Custom solvers
The solver registry, implemented in nemos.solvers._solver_registry, stores the list of available algorithms and their implementations.
Alternatively, users can fit NeMoS models with their own solvers: any solver that adheres to the AbstractSolver interface should be straightforward to plug in.
Fitting models using this custom solver can be done by:
1. Registering the class implementing the solver in the solver registry:
   nemos.solvers.register("Fancy-Algorithm", MyCustomSolverClass, backend="custom")
   Note that the class/type itself, not a solver instance, has to be passed.
2. Declaring the algorithm's compatibility with the appropriate regularizers:
   nemos.regularizer.UnRegularized.allow_solver("Fancy-Algorithm")
3. Referring to the algorithm by name when creating a GLM (or any BaseRegressor):
   GLM(solver_name="Fancy-Algorithm[custom]")
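To illustrate why a class is registered and how a name such as "Fancy-Algorithm[custom]" can be mapped back to a backend, here is a toy registry keyed by algorithm name and backend. register and resolve below are hypothetical helpers, not the functions in nemos.solvers._solver_registry, and the fallback used when no backend is given is an assumption of this sketch.

```python
import re

# Hypothetical toy registry; not nemos.solvers._solver_registry.
_REGISTRY = {}  # {algorithm name: {backend name: solver class}}

_NAME_PATTERN = re.compile(r"(?P<algo>[^\[\]]+)(?:\[(?P<backend>[^\[\]]+)\])?")


def register(algo_name, solver_cls, backend):
    # Our reading: a class is stored so the model can later instantiate it
    # with its own settings (e.g. solver_kwargs).
    if not isinstance(solver_cls, type):
        raise TypeError("register expects a solver class, not an instance")
    _REGISTRY.setdefault(algo_name, {})[backend] = solver_cls


def resolve(solver_name):
    """Map 'Algorithm[backend]' (or just 'Algorithm') to a registered class."""
    match = _NAME_PATTERN.fullmatch(solver_name)
    if match is None:
        raise ValueError(f"malformed solver name: {solver_name!r}")
    backends = _REGISTRY[match["algo"]]
    backend = match["backend"]
    if backend is None:
        # Assumption for this sketch: fall back to the first registered backend.
        backend = next(iter(backends))
    return backends[backend]


class MyCustomSolverClass:
    """Placeholder for a class implementing the AbstractSolver interface."""


register("Fancy-Algorithm", MyCustomSolverClass, backend="custom")
```

Passing MyCustomSolverClass() instead of MyCustomSolverClass raises a TypeError in this sketch.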
When registering a solver, NeMoS performs basic compatibility checks: it verifies that the class implements the required methods and that their signatures match SolverProtocol (which requires all public abstract methods of AbstractSolver).
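A sketch of this kind of check using only the standard library. Note that isinstance against a typing.runtime_checkable Protocol only checks that the methods exist, not their signatures, so parameter names are compared separately with inspect.signature. ToySolverProtocol lists only three methods and check_solver_class is a hypothetical helper; the real SolverProtocol and the exact checks NeMoS performs may differ.

```python
import inspect
from typing import Protocol, runtime_checkable

# Hypothetical sketch of a registration-time check; the real SolverProtocol
# and nemos.solvers.validate_solver_class may check more (or differently).


@runtime_checkable
class ToySolverProtocol(Protocol):
    def init_state(self, init_params, *args): ...

    def update(self, params, state, *args): ...

    def run(self, init_params, *args): ...


REQUIRED = ("init_state", "update", "run")


def check_solver_class(cls):
    """Return a list of problems found on cls; an empty list means it passed."""
    problems = []
    for name in REQUIRED:
        method = getattr(cls, name, None)
        if not callable(method):
            problems.append(f"missing method: {name}")
            continue
        expected = list(inspect.signature(getattr(ToySolverProtocol, name)).parameters)
        actual = list(inspect.signature(method).parameters)
        if actual != expected:
            problems.append(f"{name}: expected {expected}, got {actual}")
    return problems


class GoodSolver:
    def init_state(self, init_params, *args): ...

    def update(self, params, state, *args): ...

    def run(self, init_params, *args): ...


class BadSolver:
    def init_state(self, x): ...

    def update(self, params, state, *args): ...
```

For BadSolver the check reports both the mismatched init_state signature and the missing run method.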
There are also options in nemos.solvers.register to run a small ridge regression problem, testing that the solver’s methods can be used as intended.
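A rough idea of such a smoke test: fit a tiny one-dimensional ridge regression with the candidate solver and compare the result against the closed-form solution. The data, the ToyGD class, and the ridge_smoke_test helper are all hypothetical; the actual problem and options used by nemos.solvers.register are not shown here.

```python
# Hypothetical smoke test; not the check implemented in nemos.solvers.register.


class ToyGD:
    """Minimal gradient-descent solver with init_state/update/run (hypothetical)."""

    def __init__(self, grad, stepsize=0.1, maxiter=500):
        self.grad = grad
        self.stepsize = stepsize
        self.maxiter = maxiter

    def init_state(self, init_params, *args):
        return 0  # the state is just a step counter

    def update(self, params, state, *args):
        return params - self.stepsize * self.grad(params, *args), state + 1, None

    def run(self, init_params, *args):
        params, state = init_params, self.init_state(init_params, *args)
        for _ in range(self.maxiter):
            params, state, _ = self.update(params, state, *args)
        return params, state, None


def ridge_smoke_test(solver_cls, tol=1e-6):
    xs = [0.0, 1.0, 2.0, 3.0]
    ys = [0.1, 1.9, 4.2, 5.8]
    alpha = 0.5
    n = len(xs)

    def grad(w, xs, ys):
        # Gradient of (1 / (2n)) * sum((w*x - y)**2) + (alpha / 2) * w**2.
        return sum((w * x - y) * x for x, y in zip(xs, ys)) / n + alpha * w

    # Closed-form minimizer of the same objective.
    w_star = (sum(x * y for x, y in zip(xs, ys)) / n) / (sum(x * x for x in xs) / n + alpha)

    solver = solver_cls(grad)
    state = solver.init_state(0.0, xs, ys)
    solver.update(0.0, state, xs, ys)  # a single step must run without errors
    w, _, _ = solver.run(0.0, xs, ys)
    return abs(w - w_star) < tol
```

A solver whose methods run but never move the parameters fails this check, which a purely structural signature check would not catch.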
To validate a solver without registering it, use nemos.solvers.validate_solver_class.
While it is not required, subclassing AbstractSolver is a way to ensure adherence to the interface.
Stochastic optimization
To run stochastic (mini-batch) optimization, JAXopt used a run_iterator method.
Instead of the full input data, run_iterator accepts a generator or iterator that provides batches of data.
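Conceptually, a run_iterator-style loop pulls one batch at a time from the iterator and takes one solver step on it. The sketch below shows that idea in plain Python with a hypothetical ToySGD solver and make_batches generator; it does not reproduce JAXopt's actual run_iterator signature or behavior.

```python
# Hypothetical sketch of a run_iterator-style loop; not JAXopt's implementation.


class ToySGD:
    """Takes one gradient step on whatever batch it is given (hypothetical)."""

    def __init__(self, stepsize=0.1):
        self.stepsize = stepsize

    def init_state(self, params, batch):
        return 0  # the state is just a step counter

    def update(self, params, state, batch):
        # Gradient of 0.5 * mean((params - y)**2) over the batch.
        grad = params - sum(batch) / len(batch)
        return params - self.stepsize * grad, state + 1, None


def make_batches(data, batch_size, num_epochs):
    """Yield argument tuples, one mini-batch at a time."""
    for _ in range(num_epochs):
        for start in range(0, len(data), batch_size):
            yield (data[start:start + batch_size],)


def run_iterator(solver, init_params, iterator):
    """Consume the iterator, taking one solver step per batch."""
    params, state = init_params, None
    for batch_args in iterator:
        if state is None:
            state = solver.init_state(params, *batch_args)
        params, state, _ = solver.update(params, state, *batch_args)
    return params, state


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
params, state = run_iterator(ToySGD(), 0.0, make_batches(data, batch_size=2, num_epochs=100))
```

With a constant step size the estimate hovers near the data mean (3.5) rather than converging to it exactly, since each step only sees one batch.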
For information on how stochastic optimization is planned to be supported in NeMoS, see the issue tracking the stochastic optimization interface.
Stochastic optimization interface for (Prox-)SVRG
Note that (Prox-)SVRG is especially well suited to stochastic optimization; however, it currently requires the optimization loop to be implemented separately, as it is more involved than what run_iterator does.
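To see why SVRG needs more than one step per batch, here is a minimal SVRG sketch for one-dimensional least squares: each epoch starts with a full pass over the data to compute the gradient at a snapshot, followed by inner steps that use variance-reduced single-sample gradients. This is the textbook algorithm written in plain Python with a hypothetical svrg function and hand-picked step size and loop lengths, not the NeMoS SVRG implementation.

```python
import random

# Hypothetical SVRG sketch; not the NeMoS SVRG/ProxSVRG implementation.


def svrg(xs, ys, stepsize=0.01, num_epochs=30, inner_steps=50, seed=0):
    """Minimize mean(0.5 * (w * x_i - y_i)**2) over a scalar w."""
    rng = random.Random(seed)
    n = len(xs)

    def grad_i(w, i):
        # Gradient of the i-th sample loss 0.5 * (w * x_i - y_i)**2.
        return (w * xs[i] - ys[i]) * xs[i]

    w = 0.0
    for _ in range(num_epochs):
        # Full pass over the data at the snapshot: this is the step that a
        # plain "one update per batch" loop does not provide.
        w_snap = w
        full_grad = sum(grad_i(w_snap, i) for i in range(n)) / n
        for _ in range(inner_steps):
            i = rng.randrange(n)
            # Variance-reduced estimate: unbiased, and its variance shrinks
            # as w and w_snap approach the optimum.
            v = grad_i(w, i) - grad_i(w_snap, i) + full_grad
            w -= stepsize * v
    return w


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]
w_star = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
w_hat = svrg(xs, ys)
```

Unlike plain SGD with a constant step size, the variance-reduced updates let w settle on the exact least-squares solution.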