The regularizer Module#

Introduction#

The regularizer module introduces an archetype class Regularizer which provides the structural components for each concrete sub-class.

Objects of type Regularizer provide methods to define a regularized optimization objective. These objects serve as attribute of the nemos.glm.GLM, equipping the glm with an appropriate regularization scheme.

Each Regularizer object defines a default solver, and a set of allowed solvers, which depends on the loss function characteristics (smooth vs non-smooth).

Abstract Class Regularizer
|
├─ Concrete Class UnRegularized
|
├─ Concrete Class Ridge
|
├─ Concrete Class Lasso
|
└─ Concrete Class GroupLasso

The Abstract Class Regularizer#

The abstract class Regularizer enforces the implementation of the penalized_loss and get_proximal_operator methods.

Attributes#

The attributes of Regularizer consist of the default_solver and allowed_solvers, which are stored as read-only properties of type string and tuple of strings respectively.

Abstract Methods#

  • penalized_loss: Returns a penalized version of the input loss function which is uniquely defined by the regularization scheme and the regularizer strength parameter.

  • get_proximal_operator: Returns the proximal projection operator which is uniquely defined by the regularization scheme.

Core Functions#

apply_operator#

The apply_operator function applies a transformation to all regularizable components of a parameter pytree:

def apply_operator(func, params, *args, **kwargs):
    """
    Apply an operator to all regularizable subtrees of a parameter pytree.

    Uses params.regularizable_subtrees() to identify which parameters
    should be transformed, applies func to each, and returns updated params.
    """

This function enables selective regularization: models can specify which parameter components should be regularized via the regularizable_subtrees() method on their parameter containers. For example, GLMs regularize coefficients but not intercepts.

Benefits:

  • No hardcoded assumptions about parameter structure

  • Model-specific control over what gets regularized

  • Works with any pytree structure

_penalize#

Base method that computes regularization penalties using the regularizable_subtrees() hook. The current implementation assumes penalties are additive across parameter groups (e.g., separate penalty for each neuron’s coefficients), which covers most use cases but can be extended if needed.

Proximal Operators#

Proximal operators have been updated to work with arbitrary pytree structures rather than assuming specific parameter layouts. Each regularizer’s get_proximal_operator() method returns a function that:

  1. Accepts any pytree of parameters

  2. Applies the proximal operation element-wise

  3. Returns a pytree with the same structure

The apply_operator function then uses the model’s regularizable_subtrees() specification to apply the proximal operator only to the appropriate parameter components.

The UnRegularized Class#

The UnRegularized class extends the base Regularizer class and is designed specifically for optimizing unregularized models. This means that the solver instantiated by this class does not add any regularization penalty to the loss function during the optimization process.

Concrete Methods Specifics#

Contributor Guidelines#

Implementing Regularizer Subclasses#

When developing a functional (i.e., concrete) Regularizer class:

  • Must inherit from Regularizer or one of its derivatives.

  • Must implement the penalized_loss and get_proximal_operator methods.

  • Must define a default solver and a tuple of allowed solvers.

  • Should implement proximal operators to work on arbitrary pytrees (element-wise operations).

  • Should use the regularizable_subtrees() hook on parameter containers to determine which components to regularize.

  • May require extra initialization parameters, like the mask argument of GroupLasso.

  • May override _penalize if penalty computation requires non-additive aggregation across parameter groups.

Convergence Test

When adding a new regularizer, you must include a convergence test, which verifies that the model parameters the regularizer finds for a convex problem such as the GLM are identical whether one minimizes the penalized loss directly and uses the proximal operator (i.e., when using ProximalGradient). In practice, this means you should test the result of the ProximalGradient optimization against that of either GradientDescent (if your regularization is differentiable) or Nelder-Mead from scipy.optimize.minimize (or another non-gradient based method, if your regularization is non-differentiable). You can refer to NeMoS test_lasso_convergence from tests/test_convergence.py for a concrete example.

Interaction with Parameter Containers#

Regularizers interact with model parameters through the regularizable_subtrees() hook defined on parameter containers (e.g., GLMParams). This method returns a list of selector functions that identify which parameter components should be regularized.

Example workflow:

  1. Model defines parameter container with regularization hook:

    class GLMParams(eqx.Module):
        coef: jnp.ndarray
        intercept: jnp.ndarray
    
        @staticmethod
        def regularizable_subtrees():
            return [lambda p: p.coef]  # Only regularize coefficients
    
  2. Regularizer applies operations using apply_operator:

    # Apply proximal operator only to coefficients, leave intercept unchanged
    updated_params = apply_operator(proximal_op, params, strength=0.1)
    
  3. Penalty computation respects the same hook:

    # Compute penalty only on coefficients
    penalty = regularizer._penalize(params, strength)
    

This design allows:

  • Model flexibility: Each model controls what gets regularized

  • Code reuse: Same regularizer works with different model types

  • Extensibility: Easy to add new models with custom regularization needs