The regularizer Module#
Introduction#
The regularizer module introduces an archetype class Regularizer which provides the structural components for each concrete sub-class.
Objects of type Regularizer provide methods to define a regularized optimization objective. These objects serve as an attribute of nemos.glm.GLM, equipping the GLM with an appropriate regularization scheme.
Each Regularizer object defines a default solver and a set of allowed solvers, which depend on the loss function characteristics (smooth vs non-smooth).
Abstract Class Regularizer
|
├─ Concrete Class UnRegularized
|
├─ Concrete Class Ridge
|
├─ Concrete Class Lasso
|
└─ Concrete Class GroupLasso
The Abstract Class Regularizer#
The abstract class Regularizer enforces the implementation of the penalized_loss and get_proximal_operator methods.
Attributes#
The attributes of Regularizer consist of the default_solver and allowed_solvers, which are stored as read-only properties of type string and tuple of strings respectively.
Abstract Methods#
penalized_loss: Returns a penalized version of the input loss function, which is uniquely defined by the regularization scheme and the regularizer strength parameter.
get_proximal_operator: Returns the proximal projection operator, which is uniquely defined by the regularization scheme.
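The contract described above can be sketched with Python's abc module. This is a minimal illustration, not the library's code: RegularizerSketch, IncompleteRegularizer, the private _default_solver / _allowed_solvers attributes, and the method signatures are all assumptions made for the example.

```python
import abc
from typing import Callable, Tuple


class RegularizerSketch(abc.ABC):
    """Hypothetical stand-in for the abstract base class described above."""

    # Subclasses are expected to override these private class attributes.
    _default_solver: str = ""
    _allowed_solvers: Tuple[str, ...] = ()

    @property
    def default_solver(self) -> str:
        # No setter is defined, so the property is read-only.
        return self._default_solver

    @property
    def allowed_solvers(self) -> Tuple[str, ...]:
        return self._allowed_solvers

    @abc.abstractmethod
    def penalized_loss(self, loss: Callable, strength: float) -> Callable:
        """Return a penalized version of ``loss``."""

    @abc.abstractmethod
    def get_proximal_operator(self) -> Callable:
        """Return the proximal operator for this regularization scheme."""


class IncompleteRegularizer(RegularizerSketch):
    """Implements only one abstract method, so it remains abstract."""

    def penalized_loss(self, loss, strength):
        return loss


# Instantiating a subclass that misses an abstract method raises TypeError.
try:
    IncompleteRegularizer()
    instantiation_failed = False
except TypeError:
    instantiation_failed = True
```

With this pattern, a subclass that forgets either method fails at instantiation time rather than later, in the middle of an optimization run.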
Core Functions#
apply_operator#
The apply_operator function applies a transformation to all regularizable components of a parameter pytree:
def apply_operator(func, params, *args, **kwargs):
    """
    Apply an operator to all regularizable subtrees of a parameter pytree.

    Uses params.regularizable_subtrees() to identify which parameters
    should be transformed, applies func to each, and returns updated params.
    """
This function enables selective regularization: models can specify which parameter components should be regularized via the regularizable_subtrees() method on their parameter containers. For example, GLMs regularize coefficients but not intercepts.
Benefits:
No hardcoded assumptions about parameter structure
Model-specific control over what gets regularized
Works with any pytree structure
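To make the selective update concrete, here is a minimal sketch of the idea using only the standard library. ToyParams, apply_operator_sketch, and shrink are hypothetical names; a frozen dataclass stands in for a real pytree container, and the identity-based field lookup is a simplification (the actual function most likely relies on a pytree utility for the update).

```python
from dataclasses import dataclass, fields, replace
from typing import Callable, List


@dataclass(frozen=True)
class ToyParams:
    """Hypothetical parameter container standing in for a real pytree."""

    coef: List[float]
    intercept: float

    @staticmethod
    def regularizable_subtrees():
        # Selectors pick out the components that should be transformed.
        return [lambda p: p.coef]


def apply_operator_sketch(func: Callable, params, *args, **kwargs):
    """Apply func to each regularizable subtree and return updated params."""
    for where in params.regularizable_subtrees():
        target = where(params)
        # Locate the selected field by object identity (sketch-level shortcut).
        for f in fields(params):
            if getattr(params, f.name) is target:
                new_value = func(target, *args, **kwargs)
                params = replace(params, **{f.name: new_value})
                break
    return params


def shrink(values, factor=0.5):
    """Toy operator: scale every entry by ``factor``."""
    return [v * factor for v in values]


params = ToyParams(coef=[2.0, -4.0], intercept=1.0)
updated = apply_operator_sketch(shrink, params, factor=0.5)
```

Only coef is passed through shrink; intercept is carried over unchanged because no selector points at it, and the original params object is left untouched.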
_penalize#
Base method that computes regularization penalties using the regularizable_subtrees() hook. The current implementation assumes penalties are additive across parameter groups (e.g., separate penalty for each neuron’s coefficients), which covers most use cases but can be extended if needed.
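The additive assumption can be illustrated with a short sketch. Everything here is hypothetical: GroupedParams, squared_norm_penalty, and penalize_sketch are illustrative names, and the ridge-style penalty is just one possible choice, not necessarily the formula any NeMoS regularizer uses.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class GroupedParams:
    """Hypothetical container with two regularizable groups."""

    coef_a: List[float]
    coef_b: List[float]
    intercept: float

    @staticmethod
    def regularizable_subtrees():
        # One selector per parameter group; the intercept is not selected.
        return [lambda p: p.coef_a, lambda p: p.coef_b]


def squared_norm_penalty(values, strength):
    """Illustrative ridge-style penalty on a single parameter group."""
    return 0.5 * strength * sum(v * v for v in values)


def penalize_sketch(params, strength, penalty_fn=squared_norm_penalty):
    """Sum the per-group penalties (the additive assumption)."""
    return sum(
        penalty_fn(where(params), strength)
        for where in params.regularizable_subtrees()
    )


params = GroupedParams(coef_a=[1.0, -2.0], coef_b=[3.0], intercept=10.0)
penalty = penalize_sketch(params, strength=0.1)
```

A penalty that couples groups (for instance a norm taken over the concatenation of all groups) would not decompose into this sum, which is the case where overriding _penalize becomes necessary.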
Proximal Operators#
Proximal operators work with arbitrary pytree structures rather than assuming a specific parameter layout. Each regularizer’s get_proximal_operator() method returns a function that:
Accepts any pytree of parameters
Applies the proximal operation element-wise
Returns a pytree with the same structure
The apply_operator function then uses the model’s regularizable_subtrees() specification to apply the proximal operator only to the appropriate parameter components.
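A minimal sketch of a pytree-aware proximal operator, assuming an L1-style penalty (soft-thresholding) purely for illustration. tree_map_sketch and l1_prox_sketch are hypothetical helpers, plain Python containers stand in for JAX arrays, and the (params, strength, scaling) argument order is an assumption rather than the library's confirmed signature.

```python
def tree_map_sketch(func, tree):
    """Minimal stand-in for a pytree map over dicts, lists, tuples, floats."""
    if isinstance(tree, dict):
        return {key: tree_map_sketch(func, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map_sketch(func, value) for value in tree)
    return func(tree)


def soft_threshold(x, threshold):
    """Compute sign(x) * max(|x| - threshold, 0) for a scalar."""
    magnitude = max(abs(x) - threshold, 0.0)
    return magnitude if x >= 0 else -magnitude


def l1_prox_sketch(params, strength, scaling=1.0):
    """Apply soft-thresholding to every leaf, preserving the tree structure."""
    threshold = strength * scaling
    return tree_map_sketch(lambda x: soft_threshold(x, threshold), params)


params = {"a": [1.0, -0.2], "b": (0.5, -3.0)}
result = l1_prox_sketch(params, strength=0.5)
```

The output has the same keys and container types as the input, so it can be dropped back into the parameter container without reshaping.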
The UnRegularized Class#
The UnRegularized class extends the base Regularizer class and is designed specifically for optimizing unregularized models. This means that no regularization penalty is added to the loss function during the optimization process.
Concrete Methods Specifics#
penalized_loss: Returns the original loss without any changes.
get_proximal_operator: Returns the identity operator.
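The two methods are simple enough to sketch directly. UnRegularizedSketch below is a hypothetical stand-in that does not inherit from the real base class, and the argument names of the returned operator are assumptions.

```python
class UnRegularizedSketch:
    """Hypothetical stand-in for the behavior described above."""

    def penalized_loss(self, loss, strength=None):
        # No penalty is added: hand back the loss function untouched.
        return loss

    def get_proximal_operator(self):
        # The proximal operator of a zero penalty is the identity map.
        def identity_prox(params, strength=None, scaling=1.0):
            return params

        return identity_prox


def squared_error(w):
    """Toy loss used only to exercise the sketch."""
    return (w - 3.0) ** 2


reg = UnRegularizedSketch()
same_loss = reg.penalized_loss(squared_error)
prox = reg.get_proximal_operator()
```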
Contributor Guidelines#
Implementing Regularizer Subclasses#
When developing a functional (i.e., concrete) Regularizer class:
Must inherit from Regularizer or one of its derivatives.
Must implement the penalized_loss and get_proximal_operator methods.
Must define a default solver and a tuple of allowed solvers.
Should implement proximal operators to work on arbitrary pytrees (element-wise operations).
Should use the regularizable_subtrees() hook on parameter containers to determine which components to regularize.
May require extra initialization parameters, like the mask argument of GroupLasso.
May override _penalize if penalty computation requires non-additive aggregation across parameter groups.
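As a rough illustration of the checklist above, the following sketch defines a concrete subclass with a squared-norm penalty. RegularizerSketch and RidgeLikeSketch are hypothetical, the method signatures are assumptions, and the solver names are simply the two mentioned in this document; a real subclass would inherit from Regularizer and operate on pytrees of arrays rather than lists of floats.

```python
import abc


class RegularizerSketch(abc.ABC):
    """Hypothetical minimal base class standing in for Regularizer."""

    _default_solver = ""
    _allowed_solvers = ()

    @property
    def default_solver(self):
        return self._default_solver

    @property
    def allowed_solvers(self):
        return self._allowed_solvers

    @abc.abstractmethod
    def penalized_loss(self, loss, strength):
        ...

    @abc.abstractmethod
    def get_proximal_operator(self):
        ...


class RidgeLikeSketch(RegularizerSketch):
    """Illustrative concrete subclass with a squared-norm penalty."""

    # The exact set of allowed solvers is an assumption for this example.
    _default_solver = "GradientDescent"
    _allowed_solvers = ("GradientDescent", "ProximalGradient")

    def penalized_loss(self, loss, strength):
        def _penalized(coef, *args, **kwargs):
            penalty = 0.5 * strength * sum(c * c for c in coef)
            return loss(coef, *args, **kwargs) + penalty

        return _penalized

    def get_proximal_operator(self):
        def prox(coef, strength, scaling=1.0):
            # Closed form for the squared-norm penalty: shrink every entry.
            return [c / (1.0 + strength * scaling) for c in coef]

        return prox


reg = RidgeLikeSketch()
penalized = reg.penalized_loss(lambda coef: 0.0, strength=1.0)
shrunk = reg.get_proximal_operator()([2.0, -4.0], strength=1.0)
```

The proximal operator here is the closed-form minimizer of 0.5 * ||y - x||^2 + scaling * 0.5 * strength * ||y||^2, which is why each entry is divided by 1 + strength * scaling.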
Convergence Test
When adding a new regularizer, you must include a convergence test, which verifies that
the model parameters the regularizer finds for a convex problem such as the GLM agree (up to
numerical tolerance) whether one minimizes the penalized loss directly or uses the proximal
operator (i.e., when using ProximalGradient). In practice, this means you should test the result of the ProximalGradient
optimization against that of either GradientDescent (if your regularization is differentiable) or
Nelder-Mead from scipy.optimize.minimize
(or another non-gradient based method, if your regularization is non-differentiable). You can refer to the NeMoS test_lasso_convergence
test in tests/test_convergence.py for a concrete example.
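The shape of such a test can be sketched on a one-dimensional toy problem. This is not the NeMoS test: the quadratic loss, the hand-written proximal gradient loop, and the ternary search (used here as a stand-in for a derivative-free method such as Nelder-Mead) are all assumptions chosen to keep the example dependency-free.

```python
def soft_threshold(x, threshold):
    """Proximal operator of threshold * |x| for a scalar."""
    magnitude = max(abs(x) - threshold, 0.0)
    return magnitude if x >= 0 else -magnitude


def penalized_loss(w, target=2.0, strength=0.5):
    """Smooth quadratic loss plus a non-differentiable L1 penalty."""
    return 0.5 * (w - target) ** 2 + strength * abs(w)


def proximal_gradient(target=2.0, strength=0.5, step=0.5, n_iter=200):
    """Gradient step on the smooth part, then the L1 proximal operator."""
    w = 0.0
    for _ in range(n_iter):
        w = soft_threshold(w - step * (w - target), step * strength)
    return w


def ternary_search(func, lower=-10.0, upper=10.0, n_iter=200):
    """Derivative-free minimizer for a convex one-dimensional function."""
    for _ in range(n_iter):
        third = (upper - lower) / 3.0
        left, right = lower + third, upper - third
        if func(left) < func(right):
            upper = right
        else:
            lower = left
    return 0.5 * (lower + upper)


# Both routes should land on the same minimizer of the penalized loss.
w_prox = proximal_gradient()
w_direct = ternary_search(penalized_loss)
assert abs(w_prox - w_direct) < 1e-5
```

For this toy problem the minimizer is also available in closed form as soft_threshold(target, strength), which gives a third, independent check on both optimizers.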
Interaction with Parameter Containers#
Regularizers interact with model parameters through the regularizable_subtrees() hook defined on parameter containers (e.g., GLMParams). This method returns a list of selector functions that identify which parameter components should be regularized.
Example workflow:
Model defines parameter container with regularization hook:
import equinox as eqx
import jax.numpy as jnp

class GLMParams(eqx.Module):
    coef: jnp.ndarray
    intercept: jnp.ndarray

    @staticmethod
    def regularizable_subtrees():
        return [lambda p: p.coef]  # Only regularize coefficients
Regularizer applies operations using apply_operator:
# Apply proximal operator only to coefficients, leave intercept unchanged
updated_params = apply_operator(proximal_op, params, strength=0.1)
Penalty computation respects the same hook:
# Compute penalty only on coefficients
penalty = regularizer._penalize(params, strength)
This design allows:
Model flexibility: Each model controls what gets regularized
Code reuse: Same regularizer works with different model types
Extensibility: Easy to add new models with custom regularization needs