"""Source code for ``nemos.solvers._validation``: checks that custom solver classes match the ``AbstractSolver`` interface."""

import inspect
import warnings
from typing import Any, Type

import jax.numpy as jnp
import numpy as np

from .._inspect_utils import get_params, implements_methods
from ._abstract_solver import AbstractSolver

# Notes
# We could enforce adherence to the API with type checkers
# https://github.com/agronholm/typeguard
# https://github.com/beartype/beartype

# Validate only public-facing abstract methods (i.e. protocol methods).
# Private abstract methods (e.g. _get_optim_info) are internal implementation
# details and should not be enforced on custom solver classes.
METHOD_NAMES = frozenset(
    name
    for name in AbstractSolver.__abstractmethods__
    if name == "__init__" or not name.startswith("_")
)
# Aux value emitted by the toy ridge loss when has_aux=True; solver outputs
# are compared against it during the ridge validation.
AUX_VAL = -1


def _validate_method_signature(
    solver_class: Type, method_name: str
) -> tuple[bool, str | None]:
    """
    Compare the parameter names of a method on ``solver_class`` with AbstractSolver's.

    Names are enforced, not just the number of parameters. For ``__init__``
    only the first parameters that AbstractSolver requires are compared (as
    many as ``AbstractSolver.__init__`` has positional-or-keyword parameters);
    the following ones (e.g. ``**solver_kwargs``) can be anything.

    Parameters
    ----------
    solver_class :
        The custom solver class to check.
    method_name :
        Name of the method to compare. It must exist on both ``solver_class``
        and ``AbstractSolver``, otherwise ``getattr`` raises ``AttributeError``.

    Returns
    -------
    :
        ``(True, None)`` if there are no problems, and
        ``(False, error_message)`` if there are.
    """
    reference_method = getattr(AbstractSolver, method_name)

    n_params_to_check = None
    if method_name == "__init__":
        # only the explicitly named parameters of the reference constructor
        # are required; anything after them is solver-specific
        n_params_to_check = sum(
            p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
            for p in get_params(reference_method, names_only=False)
        )

    reference = get_params(reference_method, n_params_to_check)
    got = get_params(getattr(solver_class, method_name), n_params_to_check)

    # enforcing names, not just the number of parameters
    if got != reference:
        problem = (
            f"Incompatible signature for {method_name}. Got {got}. Expected {reference}"
        )
        return False, problem

    return True, None


def _check_all_signatures_match(solver_class: Type) -> None:
    """
    Check that the signature of all required methods matches AbstractSolver's.

    They must have the same argument names.
    In __init__ only the required ones are checked, the rest (**solver_kwargs)
    can be anything.

    Parameters
    ----------
    solver_class :
        The custom solver class to check.

    Raises
    ------
    ValueError
        If any method's signature does not match. The message lists one
        problem per line, sorted by method name.
    """
    problems = []
    # METHOD_NAMES is a frozenset of strings, whose iteration order can vary
    # between interpreter runs (hash randomization); sort for a stable message.
    for method_name in sorted(METHOD_NAMES):
        is_ok, problem = _validate_method_signature(solver_class, method_name)
        if not is_ok:
            problems.append(problem)

    # raise one error with all the problems found
    if problems:
        raise ValueError("\n".join(problems))


def _check_required_methods_exist(solver_class: Type):
    """
    Verify that ``solver_class`` implements every method in ``METHOD_NAMES``.

    ``METHOD_NAMES`` holds AbstractSolver's public abstract methods (plus
    ``__init__`` when it is abstract). The check is delegated to
    ``implements_methods`` and its result is returned unchanged; this gives
    more detail than ``issubclass(solver_class, SolverProtocol)`` would.
    """
    required_methods = METHOD_NAMES
    result = implements_methods(solver_class, required_methods)
    return result


def _assert_step_result(step_result: Any, method_name: str) -> tuple[Any, Any, Any]:
    """
    Ensure a solver method returned a 3-tuple ``(params, state, aux)``.

    Parameters
    ----------
    step_result :
        Whatever the solver method returned.
    method_name :
        Name of the solver method, used in the error message.

    Returns
    -------
    :
        ``step_result`` itself, unchanged, so callers can unpack it inline.

    Raises
    ------
    TypeError
        If ``step_result`` is not a tuple, or is a tuple whose length is not 3.
    """
    expectation = f"{method_name} must return a tuple of (params, state, aux), "

    if not isinstance(step_result, tuple):
        raise TypeError(expectation + f"got {type(step_result)!r}.")

    n_items = len(step_result)
    if n_items != 3:
        raise TypeError(expectation + f"got a tuple of length {n_items}.")

    return step_result


def _tiny_ridge_regression_problem(
    has_aux: bool,
    seed: int = 123,
    n_samples: int = 100,
    n_features: int = 3,
):
    """
    Build a small, seeded linear-regression problem for smoke-testing solvers.

    The loss returned here is the plain mean squared error; the ridge penalty
    is added separately by the caller.

    Parameters
    ----------
    has_aux :
        If True, the loss returns ``(loss_value, AUX_VAL)`` instead of a scalar.
    seed :
        Seed for ``np.random.default_rng``.
    n_samples, n_features :
        Shape of the design matrix.

    Returns
    -------
    :
        ``(X, y, init_params, loss)``: the design matrix of shape
        ``(n_samples, n_features)``, targets of shape ``(n_samples,)`` (both as
        JAX arrays), zero initial parameters of shape ``(n_features,)``, and the
        loss callable ``loss(params, XX, yy)``.
    """
    generator = np.random.default_rng(seed)
    # draw order (design, coefficients, noise) fixes the generated data
    design = generator.normal(size=(n_samples, n_features))
    true_coef = generator.normal(size=(n_features,))
    noise = 0.1 * generator.normal(size=(n_samples,))
    targets = design.dot(true_coef) + noise

    def _loss(params, XX, yy):
        residuals = yy - jnp.dot(XX, params)
        return jnp.power(residuals, 2).mean()

    if has_aux:

        def loss(params, XX, yy):
            # (loss_val, aux)
            return (_loss(params, XX, yy), AUX_VAL)

    else:
        loss = _loss

    starting_params = jnp.zeros((n_features,))
    return jnp.asarray(design), jnp.asarray(targets), starting_params, loss


def _check_aux(aux: Any, has_aux: bool, method_name: str) -> None:
    """Raise AssertionError if the ``aux`` returned by ``method_name`` does not match the toy loss."""
    if has_aux:
        # the toy loss returns (loss_value, AUX_VAL); the solver must pass aux through
        if not (aux == AUX_VAL):
            raise AssertionError(
                f"{method_name} must return the loss function's aux value "
                f"{AUX_VAL!r} when has_aux=True, got {aux!r}."
            )
    elif aux is not None:
        raise AssertionError(
            f"{method_name} must return aux=None when has_aux=False, got {aux!r}."
        )


def _validate_solver_class_on_ridge(
    solver_class: type,
    has_aux: bool,
    solver_kwargs: dict[str, Any] | None = None,
):
    """
    Validate a custom solver by running a tiny ridge regression problem.

    This checks that ``init_state``, ``run`` and ``update`` can be called with
    the expected inputs, that ``run``/``update`` return ``(params, state, aux)``
    tuples whose ``aux`` matches the loss, and that the resulting penalized
    losses are finite.

    Parameters
    ----------
    solver_class :
        Solver class under test. It is instantiated as
        ``solver_class(loss, regularizer, regularizer_strength, has_aux,
        init_params=init_params, **solver_kwargs)``.
    has_aux :
        Whether the toy loss returns ``(loss_value, aux)`` instead of a scalar.
    solver_kwargs :
        Extra keyword arguments forwarded to the solver's constructor.

    Raises
    ------
    TypeError
        If ``run`` or ``update`` does not return a 3-tuple.
    AssertionError
        If a returned ``aux`` does not match what the loss produces. This is
        raised explicitly, so the check also runs under ``python -O``.
    ValueError
        If any of the evaluated penalized losses is not finite.

    Warns
    -----
    UserWarning
        If ``run`` ends at a higher penalized loss than the initial parameters.
    """
    if solver_kwargs is None:
        solver_kwargs = {}
    # NOTE(review): local import, presumably to avoid a circular import at module load
    from nemos.regularizer import Ridge

    regularizer = Ridge()
    regularizer_strength = 1e-2
    X, y, init_params, unregularized_loss = _tiny_ridge_regression_problem(has_aux)

    solver = solver_class(
        unregularized_loss,
        regularizer,
        regularizer_strength,
        has_aux,
        init_params=init_params,
        **solver_kwargs,
    )

    # init_state works
    _ = solver.init_state(init_params, X, y)

    # run can be called as intended
    run_params, run_state, run_aux = _assert_step_result(
        solver.run(init_params, X, y), "run"
    )
    _check_aux(run_aux, has_aux, "run")

    # update can proceed from run's output
    first_params, first_state, first_aux = _assert_step_result(
        solver.update(run_params, run_state, X, y), "update"
    )
    _check_aux(first_aux, has_aux, "update")

    # update can proceed from its own output
    update_params, _, update_aux = _assert_step_result(
        solver.update(first_params, first_state, X, y), "update"
    )
    _check_aux(update_aux, has_aux, "update")

    penalized_loss = regularizer.penalized_loss(
        unregularized_loss, init_params, regularizer_strength
    )
    losses = [
        penalized_loss(params, X, y)
        for params in (init_params, run_params, update_params)
    ]
    # only look at the function value
    if has_aux:
        losses = [loss_and_aux[0] for loss_and_aux in losses]
    init_loss, run_loss, update_loss = losses

    if not jnp.all(jnp.isfinite(jnp.array([init_loss, run_loss, update_loss]))):
        raise ValueError("Loss values must be finite for the validation problem.")

    if run_loss > init_loss:
        warnings.warn(f"{solver_class.__name__} increases loss on ridge problem")


def validate_solver_class(
    solver_class: Type,
    test_ridge: bool,
    loss_has_aux: bool,
    solver_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Validate required methods against AbstractSolver and optionally run a quick ridge regression.

    1. Check if all required methods are there.
    2. Check their signatures and make sure they have the same argument names.
       In __init__ only the required ones are checked.
    3. If `test_ridge` is True, run a ridge regression toy problem to see if the
       solver actually works. If `loss_has_aux` is True, the ridge loss will
       carry an aux variable, otherwise it's a scalar loss value.

    Parameters
    ----------
    solver_class :
        The custom solver class to validate.
    test_ridge :
        Whether to also run the toy ridge regression problem.
    loss_has_aux :
        Whether the toy loss returns ``(loss_value, aux)``. Only used when
        ``test_ridge`` is True.
    solver_kwargs :
        Extra keyword arguments forwarded to the solver's constructor during the
        ridge test, for solvers that need additional configuration. Ignored
        when ``test_ridge`` is False. Defaults to no extra arguments.

    Raises
    ------
    ValueError
        If a method signature does not match AbstractSolver's, or if the ridge
        test produces non-finite losses.
    """
    # NOTE(review): the return value of this check is discarded; presumably
    # ``implements_methods`` raises on missing methods -- confirm.
    _check_required_methods_exist(solver_class)
    _check_all_signatures_match(solver_class)
    if test_ridge:
        _validate_solver_class_on_ridge(
            solver_class, has_aux=loss_has_aux, solver_kwargs=solver_kwargs
        )