Source code for nemos.solvers._solver_adapter

import abc
import inspect
from typing import Any, ClassVar, Type

from ._abstract_solver import AbstractSolver, SolverState


class SolverAdapter(AbstractSolver[SolverState], abc.ABC):
    """
    Base class for adapters wrapping existing solvers.

    Needs to define the class attribute `_solver_cls` and set the wrapped solver
    in the `_solver` attribute.
    """

    # Class of the wrapped solver; concrete adapters set this.
    _solver_cls: ClassVar[Type]
    # Instance of the wrapped solver; concrete adapters set this.
    _solver: Any

    def __getattr__(self, name: str) -> Any:
        """
        Try getting undefined attributes from the underlying solver.

        Only invoked when normal attribute lookup on the adapter fails.

        Raises
        ------
        AttributeError
            If the wrapped solver has not been set on this instance, or if the
            wrapped solver does not have the attribute either.
        """
        # `_solver` is read with object.__getattribute__ so that a missing
        # `_solver` does not re-enter __getattr__. Without this guard,
        # copy.deepcopy (which probes attributes on a not-yet-initialized
        # instance) leads to a RecursionError.
        try:
            solver = object.__getattribute__(self, "_solver")
        except AttributeError:
            # `from None`: the internal lookup failure is an implementation
            # detail and would only clutter the traceback.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(solver, name)

    @classmethod
    def get_accepted_arguments(cls) -> set[str]:
        """
        Set of accepted argument names, extended with the wrapped solver's arguments.

        Collects the parameter names of the adapter's ``__init__``, of the
        wrapped solver class (its constructor signature) and of the wrapped
        solver's ``__init__``. Both positional-or-keyword and keyword-only
        parameters are collected. ``self`` and the arguments that are passed
        by ``BaseRegressor`` are removed.

        Returns
        -------
        :
            The accepted argument names.
        """

        def _argument_names(func) -> set[str]:
            # getfullargspec reports keyword-only parameters separately from
            # `args`; both can be passed by name, so both are accepted.
            spec = inspect.getfullargspec(func)
            return set(spec.args) | set(spec.kwonlyargs)

        all_arguments = (
            _argument_names(cls.__init__)
            | _argument_names(cls._solver_cls)
            | _argument_names(cls._solver_cls.__init__)
        )
        # discard `self` and the arguments that are passed by BaseRegressor
        all_arguments -= {
            "self",
            "unregularized_loss",
            "regularizer",
            "regularizer_strength",
        }
        return all_arguments

    @classmethod
    def _note_about_accepted_arguments(cls) -> str:
        """
        Add a potential note about the accepted arguments in the docstring.

        Subclasses may override this. A non-empty return value is appended
        after the list of accepted arguments in the generated class docstring.
        The base implementation adds no note.
        """
        return ""

    def __init_subclass__(cls, **kw):
        """
        Generate the docstring including accepted arguments and the wrapped solver's documentation.

        Runs only for subclasses that have a ``_solver_cls`` attribute. The
        generated ``__doc__`` consists of, separated by blank lines:

        - the subclass's own docstring, or ``"Adapter for <solver name>"``;
        - the sorted list of accepted arguments, followed by the optional note
          from ``_note_about_accepted_arguments``;
        - the wrapped solver class's documentation;
        - the wrapped solver's ``__init__`` documentation.
        """
        super().__init_subclass__(**kw)

        # can only do anything if there is a _solver_cls class attribute
        solver_cls = getattr(cls, "_solver_cls", None)
        if solver_cls is None:
            return

        def _titled_section(title: str, body: str) -> str:
            # Title underlined with dashes of the same length, then the body.
            return f"{title}\n{'-' * len(title)}\n{body}"

        # Only use a docstring written on this very class. inspect.getdoc would
        # fall back to an inherited docstring (SolverAdapter's own, or a parent
        # adapter's already-generated one). That made the default unreachable
        # and duplicated the generated sections in sub-subclasses.
        own_doc = cls.__dict__.get("__doc__")
        if not isinstance(own_doc, str):
            own_doc = None
        adapter_doc = inspect.cleandoc(own_doc or f"Adapter for {solver_cls.__name__}")

        # list of accepted arguments, plus an optional note after it
        accepted_doc = _titled_section(
            "Accepted arguments:",
            "\n".join(f"- {a}" for a in sorted(cls.get_accepted_arguments())),
        )
        note = inspect.cleandoc(cls._note_about_accepted_arguments()).strip()
        if note:
            accepted_doc += "\n\n" + note

        # the wrapped solver class's documentation
        solver_doc = _titled_section(
            f"{solver_cls.__name__}'s documentation:",
            inspect.cleandoc(
                inspect.getdoc(solver_cls) or "No class documentation found."
            ),
        )

        # the wrapped solver's __init__ documentation
        solver_init_doc = _titled_section(
            f"More info from {solver_cls.__name__}.__init__",
            inspect.cleandoc(
                inspect.getdoc(solver_cls.__init__)
                or "No __init__ documentation found."
            ),
        )

        # the whole documentation is the parts after each other separated by blank lines
        cls.__doc__ = inspect.cleandoc(
            "\n\n".join((adapter_doc, accepted_doc, solver_doc, solver_init_doc))
        )