# Source code for nemos.regularizer

"""
A Module for Optimization with Various Regularization Schemes.

This module provides a series of classes that facilitate the optimization of models
with different types of regularizations. Each `Regularizer` class in this module interfaces
with various optimization methods, and they can be applied depending on the model's requirements.
"""

import abc
import math
from typing import Any, Callable, Tuple, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

from . import tree_utils
from .base_class import Base
from .proximal_operator import (
    compute_normalization,
    masked_norm_2,
    prox_elastic_net,
    prox_group_lasso,
    prox_lasso,
    prox_none,
    prox_ridge,
)
from .tree_utils import pytree_map_and_reduce
from .type_casting import _is_scalar_or_0d
from .typing import ProximalOperator
from .utils import format_repr
from .validation import convert_tree_leaves_to_jax_array

__all__ = [
    "UnRegularized",
    "Ridge",
    "Lasso",
    "GroupLasso",
    "ElasticNet",
]


def __dir__() -> list[str]:
    """Restrict ``dir()`` on this module to the public names listed in ``__all__``."""
    public_names = __all__
    return public_names


def apply_operator(func, params, *args, filter_kwargs=None, **kwargs):
    """
    Apply an operator to all regularizable subtrees of a parameter pytree.

    Each selector returned by ``params.regularizable_subtrees()`` picks out one
    subtree. ``func`` is applied to that subtree, and the result is written back
    into ``params`` with :func:`equinox.tree_at`. If ``params`` has no
    ``regularizable_subtrees`` method, the whole tree is treated as a single
    regularizable subtree. Fields that are not selected (e.g. intercepts or
    structural metadata) are left unchanged. Typical uses are proximal operators
    and other per-tensor transformations.

    Parameters
    ----------
    func :
        A callable with signature ``func(x, *args, **kwargs) -> Any``.
        It receives each regularizable subtree ``x`` and must return a value
        with the same pytree structure that should replace that subtree.
    params :
        Any parameter object. If it implements ``regularizable_subtrees()``,
        that method must return an iterable of selector functions (suitable for
        ``eqx.tree_at``) identifying the subtrees to transform. Otherwise the
        entire object is transformed.
    *args :
        Additional positional arguments passed directly to ``func``.
    filter_kwargs :
        Optional keyword-only dict of keyword arguments whose values are PyTrees
        aligned with ``params``. For each regularizable subtree, the subtree
        selector is applied to every value in this dict. The selected pieces are
        then passed to ``func`` as keyword arguments. This is useful for
        operators that need PyTree-structured metadata (e.g. masks) aligned with
        the parameter structure. Default is None, meaning no filtered kwargs.
    **kwargs :
        Additional keyword arguments passed directly to ``func``.

    Returns
    -------
    params_new : same type as ``params``
        A new pytree/module with ``func`` applied to all regularizable
        subtrees. Non-regularized fields are preserved unchanged.

    Notes
    -----
    - ``regularizable_subtrees()`` must return a sequence of callables
      compatible with ``eqx.tree_at``. Each callable should extract a subtree
      from ``params``.
    - ``func`` must be pure and JAX-compatible if this function is used inside
      JIT-compiled code.
    - A key present in both ``kwargs`` and ``filter_kwargs`` makes the call to
      ``func`` fail with a duplicate-keyword ``TypeError``.

    Examples
    --------
    A minimal working example with a fake ``Params`` object:

    >>> import equinox as eqx
    >>> class Params(eqx.Module):
    ...     w: float
    ...     b: float
    ...
    ...     # Only `w` is regularizable
    ...     def regularizable_subtrees(self):
    ...         return [lambda p: p.w]

    >>> p = Params(w=3.0, b=10.0)

    Define an operator that halves the value:

    >>> def halve(x):
    ...     return x / 2

    Apply it only to the regularizable subtree (`w`):

    >>> p2 = apply_operator(halve, p)
    >>> p2.w
    1.5
    >>> p2.b
    10.0

    The bias `b` is unchanged because it is not listed in
    `regularizable_subtrees`.

    Example with ``filter_kwargs`` for PyTree-structured metadata:

    >>> def masked_op(x, mask=None):
    ...     if mask is not None:
    ...         return x * mask
    ...     return x

    >>> # Create a mask with same structure as params
    >>> mask_tree = Params(w=0.5, b=1.0)

    >>> # Apply operator with filtered kwargs - only the relevant mask piece
    >>> # is passed to each subtree
    >>> p3 = apply_operator(masked_op, p, filter_kwargs={"mask": mask_tree})
    >>> p3.w  # w was multiplied by mask.w (0.5)
    1.5
    >>> p3.b  # b is not regularizable, so unchanged
    10.0
    """
    aligned_kwargs = filter_kwargs or {}
    selectors = (
        params.regularizable_subtrees()
        if hasattr(params, "regularizable_subtrees")
        else [lambda tree: tree]  # no selectors declared: transform the whole tree
    )

    for select in selectors:
        # Slice every aligned metadata tree down to the current subtree.
        sliced_kwargs = {
            name: select(value) for name, value in aligned_kwargs.items()
        }
        replacement = func(select(params), *args, **kwargs, **sliced_kwargs)
        params = eqx.tree_at(select, params, replacement)

    return params


class Regularizer(Base, abc.ABC):
    """
    Abstract base class for regularized solvers.

    This class is designed to provide a consistent interface for optimization solvers,
    enabling users to easily switch between different regularizers, ensuring compatibility
    with various loss functions and optimization algorithms.

    Attributes
    ----------
    allowed_solvers : Tuple[str, ...]
        Tuple of solver names that are allowed for use with this regularizer.
    default_solver : str
        String of the default solver name allowed for use with this regularizer.
    """

    _allowed_solvers: Tuple[str, ...]
    _default_solver: str
    _proximal_operator: Callable

    @property
    def allowed_solvers(self) -> Tuple[str, ...]:
        """Names of the solvers compatible with this regularizer."""
        return self._allowed_solvers

    @property
    def default_solver(self) -> str:
        """Name of the default solver for this regularizer."""
        return self._default_solver

    def check_solver(self, solver_name: str) -> None:
        """
        Raise an error if the given solver is not allowed.

        Parameters
        ----------
        solver_name :
            Name of the solver to check.

        Raises
        ------
        ValueError
            If ``solver_name`` is not among the allowed solvers of this regularizer.
        """
        if solver_name not in self._allowed_solvers:
            name = self.__class__.__name__
            # Each fragment ends with a space so the concatenated sentences stay readable.
            raise ValueError(
                f"The solver: {solver_name} is not allowed for "
                f"{name} regularization. Allowed solvers are "
                f"{self.allowed_solvers}. "
                f"If {solver_name} is your implementation and is designed to be "
                f"compatible with {name}, register it with "
                f"{name}.allow_solver({solver_name})"
            )

    @classmethod
    def allow_solver(cls, algo_name: str) -> None:
        """
        Add an algorithm to the list of compatible solvers.

        The augmented tuple is assigned on ``cls`` itself, so registering a solver
        on a subclass does not alter the tuple seen by its siblings.

        Parameters
        ----------
        algo_name :
            Name of the optimization algorithm to add.
        """
        if algo_name in cls._allowed_solvers:
            return
        cls._allowed_solvers += (algo_name,)

    def __repr__(self) -> str:
        return format_repr(self)

    def __str__(self) -> str:
        return format_repr(self)

    @staticmethod
    def _check_loss_output_tuple(output: tuple):
        """Raise ``ValueError`` unless a tuple loss output is ``(loss, aux)``."""
        if len(output) != 2:
            n_out = len(output)
            word = "value" if n_out == 1 else "values"
            raise ValueError(
                f"Invalid loss function return. The loss function returns a tuple with {n_out} {word}.\n"
                "A valid loss function can return either a single value (float or a 0-dim array), the loss, "
                "or a tuple with two values, the loss and an auxiliary variable."
            )

    def get_proximal_operator(self, params: Any, strength: Any) -> ProximalOperator:
        """
        Retrieve the proximal operator.

        Parameters
        ----------
        params:
            The parameters to be regularized. Their structure is used to align
            ``strength`` (and any other per-subtree metadata) with the parameters.
        strength:
            Regularizer strength. It is aligned with ``params`` once, when the
            operator is built.

        Returns
        -------
        :
            The proximal operator, applying regularization to the provided parameters.
        """
        filter_kwargs = self._get_filter_kwargs(strength=strength, params=params)

        # hyperparams is unused: strength is captured in filter_kwargs at construction time.
        # The argument is required to match the jaxopt prox interface:
        # prox(params, hyperparams_prox, scaling=1.0).
        def prox_op(params, hyperparams, scaling=1.0, *args):
            return apply_operator(
                self._proximal_operator,
                params,
                filter_kwargs=filter_kwargs,
                scaling=scaling,
            )

        return prox_op

    def penalized_loss(self, loss: Callable, params: Any, strength: Any) -> Callable:
        """
        Return a function for calculating the penalized loss.

        Parameters
        ----------
        loss :
            Loss function called as ``loss(params, *args, **kwargs)``. It may return
            either the loss value or a ``(loss_value, aux)`` tuple.
        params :
            Parameters whose structure is used to align ``strength``.
        strength :
            Regularizer strength.

        Returns
        -------
        :
            A function with the same call signature as ``loss``. It adds the
            regularization penalty to the loss value and passes ``aux`` through
            unchanged when present.

        Raises
        ------
        ValueError
            When the returned function is called, if ``loss`` returns a tuple
            whose length is not 2.
        """
        filter_kwargs = self._get_filter_kwargs(strength=strength, params=params)

        def _penalized_loss(params, *args, **kwargs):
            result = loss(params, *args, **kwargs)
            penalty = self._penalization(params, filter_kwargs=filter_kwargs)
            if isinstance(result, tuple):
                self._check_loss_output_tuple(result)
                loss_value, aux = result
                return loss_value + penalty, aux
            return result + penalty

        return _penalized_loss

    def _penalization(self, params: Any, filter_kwargs: dict) -> jnp.ndarray:
        """Sum ``_penalty_on_subtree`` over every regularizable subtree of ``params``."""
        penalty = jnp.array(0.0)
        if hasattr(params, "regularizable_subtrees"):
            for where in params.regularizable_subtrees():
                subtree = where(params)
                # Slice each aligned kwarg tree (e.g. strength, mask) to this subtree.
                subtree_kwargs = {key: where(val) for key, val in filter_kwargs.items()}
                penalty = penalty + self._penalty_on_subtree(subtree, **subtree_kwargs)
        else:
            penalty = penalty + self._penalty_on_subtree(params, **filter_kwargs)
        return penalty

    @abc.abstractmethod
    def _penalty_on_subtree(self, subtree, **kwargs) -> jnp.ndarray:
        """Return the penalty for a single regularizable subtree."""
        pass

    def _validate_strength(self, strength: Any):
        """
        Validate regularizer strength type.

        Parameters
        ----------
        strength : Any
            Regularizer strength specified as one of:

            - None
                Defaults to a scalar strength of 1.0.
            - Python scalar (int or float)
                Preserved as-is.
            - 0-D NumPy/JAX array
                Converted to a Python float.
            - array, list or tuple, or a PyTree containing these
                Each such leaf is converted to a JAX array with ``jnp.asarray``.

        Returns
        -------
        Any
            A scalar or PyTree where:

            - Python scalar leaves are preserved
            - 0-D array leaves become Python floats
            - Array-like leaves are converted to `jnp.ndarray`

        Raises
        ------
        TypeError
            If a leaf has an unsupported type or cannot be converted to a JAX array.
        """
        if strength is None:
            return 1.0

        def _convert_if_arraylike(x):
            if x is None:
                return 1.0
            elif isinstance(x, (int, float)):
                return x
            elif isinstance(x, (jnp.ndarray, np.ndarray)) and x.ndim == 0:
                return float(x)  # use Python floats when possible
            elif isinstance(x, (jnp.ndarray, np.ndarray, list, tuple)):
                return jnp.asarray(x)
            else:
                raise TypeError

        try:
            return jax.tree_util.tree_map(
                _convert_if_arraylike,
                strength,
                # Lists/tuples are treated as array-likes, not as containers.
                is_leaf=lambda x: isinstance(x, (np.ndarray, jnp.ndarray, list, tuple)),
            )
        except (ValueError, TypeError) as e:
            raise TypeError(
                f"Could not convert regularizer strength to floats: {strength}"
            ) from e

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Align and broadcast regularizer strength to match model parameters.

        This function takes a validated regularizer strength specification and aligns it
        with the structure of `params`, filling only regularizable subtrees and inserting
        `None` elsewhere. Regularizable subtrees are determined via
        `params.regularizable_subtrees()` if available; otherwise, the entire parameter
        tree is treated as regularizable.

        Parameters
        ----------
        params : Any
            Model parameters structured as a PyTree.
        strength : Any
            Regularizer strength specification. Accepted forms:

            - None
                Uses a scalar strength of 1.0 for all regularizable parameters.
            - scalar or 0-D array
                Logically broadcast to all regularizable parameter leaves.
            - list
                One entry per regularizable subtree, in the order returned by
                ``regularizable_subtrees()``.
            - PyTree
                Must match the structure of the regularizable subtrees. Each leaf may be
                a scalar, 0-D array, or an array matching the corresponding parameter
                leaf shape.

        Returns
        -------
        structured_strength : Any
            PyTree with the same structure as `params`:

            - Regularizable parameter leaves contain strength values (scalars or arrays)
            - Non-regularizable leaves are `None`

        Raises
        ------
        ValueError
            If:

            - The number of provided strength subtrees does not match the number of
              regularizable subtrees.
            - A strength PyTree does not match the structure of its corresponding
              parameter subtree.
            - A non-scalar strength leaf does not match the shape of the corresponding
              parameter leaf.
        """
        wheres = getattr(params, "regularizable_subtrees", lambda: [lambda x: x])()
        struct = jax.tree_util.tree_structure(params)
        # Start from an all-None tree; regularizable positions are filled below.
        structured_strength = jax.tree_util.tree_unflatten(
            struct, [None] * struct.num_leaves
        )
        if strength is None:
            strength = 1.0
        substrengths = (
            strength if isinstance(strength, list) else [strength] * len(wheres)
        )
        if len(substrengths) != len(wheres):
            raise ValueError(f"Expected {len(wheres)} strength values, got {strength}")

        def _structured_strength(strength_leaf, param_leaf):
            if _is_scalar_or_0d(strength_leaf):
                return strength_leaf
            if strength_leaf.shape != param_leaf.shape:
                raise ValueError(
                    f"Strength shape {strength_leaf.shape} does not match "
                    f"parameter shape {param_leaf.shape}"
                )
            return strength_leaf

        for substrength, where in zip(substrengths, wheres):
            subtree = where(params)
            validated = (
                jax.tree_util.tree_map(lambda p: substrength, subtree)
                if _is_scalar_or_0d(substrength)
                else jax.tree_util.tree_map(_structured_strength, substrength, subtree)
            )
            # is_leaf lets tree_at address the None placeholders as replaceable nodes.
            structured_strength = eqx.tree_at(
                where, structured_strength, validated, is_leaf=lambda x: x is None
            )
        return structured_strength

    def _get_filter_kwargs(self, params: Any, strength: Any):
        """Return the per-subtree kwargs (aligned strength) for prox/penalty calls."""
        strength = self._validate_strength_structure(params, strength)
        return {"strength": strength}
class UnRegularized(Regularizer):
    """
    Regularizer class for unregularized models.

    Models using this regularizer get the identity proximal operator (no shrinkage)
    and a loss with no penalty added: every subtree contributes zero.
    """

    _allowed_solvers = (
        "GradientDescent",
        "BFGS",
        "LBFGS",
        "NonlinearCG",
        "ProximalGradient",
        "SVRG",
        "ProxSVRG",
    )
    _default_solver = "LBFGS"
    _proximal_operator = staticmethod(prox_none)

    def _validate_strength(self, strength: Any):
        # Without regularization the strength is meaningless; always discard it.
        return None

    def _penalty_on_subtree(self, subtree, **kwargs) -> jnp.ndarray:
        # Constant zero penalty, independent of the subtree and of any kwargs.
        return jnp.array(0.0)
class Ridge(Regularizer):
    """
    Regularizer class for Ridge (L2 regularization).

    Models using this regularizer get the Ridge proximal operator and a loss
    penalized by half the strength-weighted squared L2 norm of the coefficients.
    """

    _allowed_solvers = (
        "GradientDescent",
        "BFGS",
        "LBFGS",
        "NonlinearCG",
        "ProximalGradient",
        "SVRG",
        "ProxSVRG",
    )
    _default_solver = "LBFGS"
    _proximal_operator = staticmethod(prox_ridge)

    def _penalty_on_subtree(self, subtree, strength: Any, **kwargs) -> jnp.ndarray:
        """
        Compute the Ridge penalization for a parameter subtree.

        Each leaf ``w`` paired with strength ``s`` contributes
        ``0.5 * sum(s * w**2)``. The contributions are reduced with ``sum``.

        Parameters
        ----------
        subtree :
            Model parameter subtree for which to compute the penalization.
        strength :
            Regularization strength, aligned leaf-by-leaf with ``subtree``.

        Returns
        -------
        float
            The Ridge penalization value.
        """

        def _half_weighted_square(weights, weight_strength):
            squared = jnp.square(weights)
            return 0.5 * jnp.sum(weight_strength * squared)

        return tree_utils.pytree_map_and_reduce(
            _half_weighted_square, sum, subtree, strength
        )
class Lasso(Regularizer):
    """
    Regularizer class for Lasso (L1 regularization).

    Models using this regularizer get the Lasso proximal operator and a loss
    penalized by the strength-weighted L1 norm of the coefficients.
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_lasso)

    def _penalty_on_subtree(self, subtree, strength: Any, **kwargs) -> jnp.ndarray:
        """
        Compute the Lasso penalization for a parameter subtree.

        Each leaf ``w`` paired with strength ``s`` contributes ``sum(s * |w|)``.
        The contributions are reduced with ``sum``.

        Parameters
        ----------
        subtree :
            Model parameters for which to compute the penalization.
        strength :
            Regularization strength, aligned leaf-by-leaf with ``subtree``.

        Returns
        -------
        float
            The Lasso penalization value.
        """

        def _weighted_abs(weights, weight_strength):
            magnitudes = jnp.abs(weights)
            return jnp.sum(weight_strength * magnitudes)

        return tree_utils.pytree_map_and_reduce(
            _weighted_abs, sum, subtree, strength
        )
class ElasticNet(Regularizer):
    r"""
    Regularizer class for Elastic Net (L1 + L2 regularization).

    The Elastic Net penalty [3]_ [4]_ is defined as:

    .. math::
        P(\beta) = \alpha \left((1 - \lambda) \frac{1}{2} ||\beta||_{\ell_2}^2 + \lambda ||\beta||_{\ell_1} \right)

    where :math:`\alpha` is the regularizer strength, and :math:`\lambda` is the regularizer ratio.
    The regularizer ratio controls the balance between L1 (Lasso) and L2 (Ridge) regularization,
    where :math:`\lambda = 0` is equivalent to Ridge regularization and :math:`\lambda = 1`
    is equivalent to Lasso regularization.

    The strength is given either as a tuple ``(strength, ratio)`` or as a single
    strength, in which case the ratio defaults to 0.5. ``None`` defaults to
    ``(1.0, 0.5)``. The ratio is validated to lie in :math:`(0, 1]`, so
    :math:`\lambda = 0` itself is rejected; use :class:`Ridge` for a pure L2 penalty.

    This class equips models with the Elastic Net proximal operator and the
    Elastic Net penalized loss function.

    References
    ----------
    .. [3] Zou, H., & Hastie, T. (2005). Regularization and variable selection via the elastic net.
        Journal of the Royal Statistical Society: Series B (Statistical Methodology), 67(2), 301-320.
        https://doi.org/10.1111/j.1467-9868.2005.00503.x

    .. [4] https://en.wikipedia.org/wiki/Elastic_net_regularization
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_elastic_net)

    def _penalty_on_subtree(self, subtree: Any, strength: Any, **kwargs) -> jnp.ndarray:
        r"""
        Compute the Elastic Net penalization for given parameters.

        The elastic net penalty is defined as:

        .. math::
            P(\beta) = \alpha \left((1 - \lambda) \frac{1}{2} ||\beta||_{\ell_2}^2 + \lambda ||\beta||_{\ell_1} \right)

        where :math:`\alpha` is the regularizer strength, and :math:`\lambda` is the regularizer ratio.
        The regularizer ratio controls the balance between L1 (Lasso) and L2 (Ridge) regularization,
        where :math:`\lambda = 0` is equivalent to Ridge regularization and :math:`\lambda = 1`
        is equivalent to Lasso regularization.

        Parameters
        ----------
        subtree :
            Model parameters for which to compute the penalization.
        strength :
            Regularization strength. Each leaf is a ``(strength, ratio)`` pair, as
            produced by ``_validate_strength_structure``.

        Returns
        -------
        :
            The Elastic Net penalization value.
        """

        def net_penalty(coeff, leaf_strength):
            s, r = leaf_strength
            quad = 0.5 * (1.0 - r) * jnp.square(coeff)
            l1 = r * jnp.abs(coeff)
            return jnp.sum(s * (quad + l1))

        return tree_utils.pytree_map_and_reduce(
            net_penalty,
            sum,
            subtree,
            strength,
        )

    def _validate_strength(self, strength: Any):
        """
        Validate an Elastic Net strength and split it into ``(strength, ratio)``.

        Raises
        ------
        TypeError
            If a tuple of length other than 2 is given, or a value cannot be converted.
        ValueError
            If the ratio is not in ``(0, 1]``.
        """
        if strength is None:
            strength, ratio = 1.0, 0.5
        elif isinstance(strength, tuple):
            if len(strength) != 2:
                raise TypeError(
                    "ElasticNet regularizer strength must be a tuple (strength, ratio)"
                )
            strength, ratio = strength
        else:
            strength, ratio = strength, 0.5

        strength = super()._validate_strength(strength)
        ratio = super()._validate_strength(ratio)

        def check_ratio(r):
            if jnp.any((r <= 0) | (r > 1)):
                raise ValueError(
                    f"ElasticNet regularization ratio must be in (0, 1], got {r}"
                )
            return r

        ratio = jax.tree_util.tree_map(check_ratio, ratio)
        return strength, ratio

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Align ``(strength, ratio)`` with ``params``; each regularizable leaf gets a pair.

        Non-regularizable positions stay ``None``. A ``None`` strength uses the same
        defaults as ``_validate_strength``, namely ``(1.0, 0.5)``.
        """
        if strength is None:
            # The base contract accepts None; indexing None below would raise TypeError.
            strength = (1.0, 0.5)
        _strength = super()._validate_strength_structure(params, strength[0])
        ratio = super()._validate_strength_structure(params, strength[1])

        def zip_leaves(s, r):
            if s is None:
                return None
            return (s, r)

        return jax.tree_util.tree_map(zip_leaves, _strength, ratio)
class GroupLasso(Regularizer):
    """
    Regularizer class for Group Lasso (group-L1) regularized models.

    This class equips models with the group-lasso proximal operator and the
    group-lasso penalized loss function.

    Attributes
    ----------
    mask :
        A mask array (or PyTree of arrays) indicating group membership for regularization.
        Each regularizable parameter leaf with shape ``(n_features, ...)`` requires a
        corresponding mask leaf with shape ``(n_groups, n_features, ...)``, i.e.
        ``(n_groups, *params.shape)``. Row ``i`` of a mask leaf is 1 for features
        belonging to group ``i`` and 0 elsewhere; each feature may belong to at most
        one group. If ``None`` (default), a mask is auto-initialized so that each
        trailing dimension of each parameter leaf forms its own group.

    Notes
    -----
    For GroupLasso, the regularizer strength is defined **per group**, not per
    parameter. It must be either a scalar or a 1D array of length ``n_groups``.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.regularizer import GroupLasso
    >>> from nemos.glm import GLM
    >>> # simulate some counts
    >>> num_samples, num_features, num_groups = 1000, 5, 3
    >>> X = np.random.normal(size=(num_samples, num_features)) # design matrix
    >>> w = [0, 0.5, 1, 0, -0.5] # define some weights
    >>> y = np.random.poisson(np.exp(X.dot(w))) # observed counts
    >>> # Define a mask for 3 groups and 5 features
    >>> mask = np.zeros((num_groups, num_features))
    >>> mask[0] = [1, 0, 0, 1, 0] # Group 0 includes features 0 and 3
    >>> mask[1] = [0, 1, 0, 0, 0] # Group 1 includes feature 1
    >>> mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4
    >>> # Create the GroupLasso regularizer instance
    >>> group_lasso = GroupLasso(mask=mask)
    >>> # fit a group-lasso glm
    >>> model = GLM(regularizer=group_lasso, regularizer_strength=0.1).fit(X, y)
    >>> print(f"coeff shape: {model.coef_.shape}")
    coeff shape: (5,)

    For a :class:`~nemos.glm.PopulationGLM`, where ``coef_`` has shape
    ``(n_features, n_neurons)``, the mask must have the matching shape
    ``(n_groups, n_features, n_neurons)``:

    >>> import nemos as nmo
    >>> num_samples, num_features, num_neurons = 1000, 4, 3
    >>> X = np.random.normal(size=(num_samples, num_features))
    >>> w = np.random.randn(num_features, num_neurons) * 0.1
    >>> y = np.random.poisson(np.exp(X.dot(w)))
    >>> # group 0: regularize all features jointly for neurons 0-1
    >>> # group 1: regularize all features jointly for neuron 2
    >>> mask = np.zeros((2, num_features, num_neurons))
    >>> mask[0, :, :2] = 1
    >>> mask[1, :, 2:] = 1
    >>> model = nmo.glm.PopulationGLM(
    ...     regularizer=nmo.regularizer.GroupLasso(mask=mask),
    ...     regularizer_strength=0.1,
    ... ).fit(X, y)
    >>> print(f"coef shape: {model.coef_.shape}")
    coef shape: (4, 3)
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_group_lasso)

    def __init__(
        self,
        mask: Any = None,
    ):
        super().__init__()
        self.mask = mask

    @property
    def mask(self):
        """Getter for the mask attribute."""
        return self._mask

    @mask.setter
    def mask(self, mask: Union[jnp.ndarray, None]):
        """Setter for the mask attribute."""
        # check mask if passed by user, else will be initialized later
        if mask is not None:
            mask = self._cast_and_check_mask(mask)
        self._mask = mask

    def initialize_mask(self, x: Any) -> Any:
        """
        Initialize a default group mask for a PyTree of parameters.

        Creates a mask where each leaf array and each of its trailing dimensions
        (beyond the first) are assigned to separate groups. This default grouping treats:

        - Each leaf in the PyTree as a distinct parameter set
        - The first dimension (axis 0) as the feature dimension
        - Each trailing dimension as a separate group of features

        For a leaf with shape (n_features, d1, d2, ..., dk), this creates
        (d1 * d2 * ... * dk) groups, where each group's mask is 1.0 for all
        features in that specific trailing dimension combination and 0.0 elsewhere.

        Only regularizable subtrees (per ``x.regularizable_subtrees()``, or the whole
        tree when that method is absent) receive a mask; other leaves are ``None``.

        Parameters
        ----------
        x : Any
            PyTree of parameter arrays. Each leaf should have shape (n_features, ...)
            where n_features is the number of features and trailing dimensions
            define additional grouping structure.

        Returns
        -------
        mask : Any
            PyTree with the same structure as x, where each leaf has shape
            (n_groups, n_features, ...) matching the original leaf shape.
            The mask contains 1.0 for elements in each group and 0.0 elsewhere.
        """
        reg_subtrees = (
            x.regularizable_subtrees()
            if hasattr(x, "regularizable_subtrees")
            else [lambda z: z]
        )
        struct = jax.tree_util.tree_structure(x)
        mask = jax.tree_util.tree_unflatten(struct, [None] * struct.num_leaves)
        for where in reg_subtrees:
            mask = eqx.tree_at(
                where,
                mask,
                self._initialize_subtree_mask(where(x)),
                is_leaf=lambda m: m is None,
            )
        return mask

    @staticmethod
    def _initialize_subtree_mask(subtree: Any) -> Any:
        """
        Build the default mask for a single regularizable subtree.

        Every trailing-dimension slice ``leaf[:, i1, ..., ik]`` of every leaf becomes
        its own group. Groups are numbered consecutively across leaves in
        ``tree_flatten`` order, so each returned leaf has shape
        ``(total_groups, *leaf.shape)`` and is nonzero only on its own rows.
        """
        flat_x, struct = jax.tree_util.tree_flatten(subtree)
        # One group per element of the trailing dims (axis 0 is the feature axis).
        n_groups_per_leaf = [math.prod(leaf.shape[1:]) for leaf in flat_x]
        total_groups = sum(n_groups_per_leaf)

        mask_flat = []
        group_offset = 0
        for leaf, n_groups in zip(flat_x, n_groups_per_leaf):
            extra_shape = leaf.shape[1:]
            # Row i of the identity, reshaped in C order, is the one-hot at
            # unravel_index(i, extra_shape): shape (n_groups, *extra_shape).
            one_hot = jnp.eye(n_groups, dtype=float).reshape((n_groups, *extra_shape))
            # Repeat across the feature axis: shape (n_groups, n_features, *extra_shape).
            leaf_block = jnp.broadcast_to(one_hot[:, None, ...], (n_groups, *leaf.shape))
            mask = jnp.zeros((total_groups, *leaf.shape), dtype=float)
            # A single slice update replaces one full-array copy per group.
            mask = mask.at[group_offset : group_offset + n_groups].set(leaf_block)
            mask_flat.append(mask)
            group_offset += n_groups

        return jax.tree_util.tree_unflatten(struct, mask_flat)

    @staticmethod
    def _cast_and_check_mask(mask: Any) -> Any:
        """
        Cast to jax array of floats and validate the mask.

        This method ensures the mask adheres to requirements:

        - The mask should be castable to a PyTree of arrays of float type.
        - Every leaf must have at least 2 dimensions ``(n_groups, n_features, ...)``.
        - All leaves must share the same ``n_groups`` (first dimension).
        - Each element must be either 0 or 1.
        - Each feature should belong to only one group.
        - The mask should not be empty.

        Raises
        ------
        ValueError
            If any of the above conditions are not met.
        """
        mask = convert_tree_leaves_to_jax_array(
            mask,
            "Unable to convert mask to a tree with ``jax.ndarray`` leaves.",
        )
        flat_mask = jax.tree_util.tree_leaves(mask)
        if not flat_mask:
            raise ValueError("Empty mask provided!")

        # Check rank first so 0-d/1-d leaves get this message instead of an IndexError.
        if any(m.ndim < 2 for m in flat_mask):
            raise ValueError(
                "Mask arrays should have at least 2 dimensions ``(n_groups, n_features, ...)``."
            )

        n_groups = flat_mask[0].shape[0]
        if any(m.shape[0] != n_groups for m in flat_mask[1:]):
            group_lengths = sorted({m.shape[0] for m in flat_mask})
            raise ValueError(
                "The length of the first dimension array leaves in the mask PyTree "
                "should be equal to ``n_groups``. "
                f"Leaves of the mask tree have inconsistent first dimension lengths: {group_lengths}."
            )

        if n_groups == 0:
            raise ValueError("Empty mask provided!")

        has_invalid_entries = pytree_map_and_reduce(
            lambda m: jnp.any((m != 1) & (m != 0)), any, mask
        )
        if has_invalid_entries:
            raise ValueError("Mask elements must be 0s and 1s!")

        all_zeros = pytree_map_and_reduce(lambda m: jnp.all(m == 0), all, mask)
        if all_zeros:
            raise ValueError("Empty mask provided!")

        multi_group_assignment = pytree_map_and_reduce(
            lambda m: jnp.any(m.sum(axis=0) > 1), any, mask
        )
        if multi_group_assignment:
            raise ValueError(
                "Incorrect group assignment. Some of the features are assigned "
                "to more than one group."
            )
        return mask

    def _penalty_on_subtree(
        self, subtree, strength: Any, mask: Any = None, **kwargs
    ) -> jnp.ndarray:
        r"""
        Apply the Group Lasso penalty to a subtree.

        The penalty term added to the loss is

        .. math::
            \alpha_j \cdot \sum_{j=1,\dots,g} \sqrt{\dim(\beta_j)} \, \|\beta_j\|_2

        i.e. the penalized loss is
        :math:`\text{loss}(\beta_1,\dots,\beta_g) + \sum_{j} \alpha_j \sqrt{\dim(\beta_j)} \|\beta_j\|_2`,
        where :math:`g` is the number of groups, :math:`\alpha_j` is the strength of
        group :math:`j`, :math:`\dim(\cdot)` is the number of coefficients in each
        :math:`\beta_j`, and :math:`\|\cdot\|_2` is the Euclidean norm.

        Parameters
        ----------
        subtree :
            Parameter subtree to penalize.
        strength :
            Per-group strength for each leaf, aligned with ``subtree``.
        mask :
            Group mask aligned with ``subtree`` (always supplied by ``_get_filter_kwargs``).
        """

        def penalty_leaf(leaf, leaf_mask, leaf_strength):
            leaf_l2_norm = masked_norm_2(leaf, leaf_mask, normalize=False)
            leaf_norm = compute_normalization(leaf_mask)
            return jnp.sum(leaf_strength * leaf_norm * leaf_l2_norm)

        penalties = jax.tree_util.tree_map(
            penalty_leaf,
            subtree,
            mask,
            strength,
        )
        return jnp.sum(jnp.array(jax.tree_util.tree_leaves(penalties)))

    def _check_mask_and_params_shape_match(self, mask, params):
        """Raise ``ValueError`` unless every regularizable mask leaf is ``(n_groups, *param.shape)``."""
        reg_subtrees = (
            params.regularizable_subtrees()
            if hasattr(params, "regularizable_subtrees")
            else [lambda z: z]
        )
        for where in reg_subtrees:
            sub_mask = where(mask)
            sub_params = where(params)
            shape_mismatched = pytree_map_and_reduce(
                lambda s, p: s.shape[1:] != p.shape, any, sub_mask, sub_params
            )
            if shape_mismatched:
                flat_mask_leaves = jax.tree_util.tree_leaves(sub_mask)
                flat_param_leaves = jax.tree_util.tree_leaves(sub_params)
                mismatches = [
                    f"mask {s.shape} (expected {(s.shape[0], *p.shape)})"
                    for s, p in zip(flat_mask_leaves, flat_param_leaves)
                    if s.shape[1:] != p.shape
                ]
                sep = "\n\t- "
                raise ValueError(
                    "GroupLasso mask shape mismatch: the mask must have shape "
                    "``(n_groups, *params.shape)`` for every regularizable parameter leaf.\n"
                    f"Mismatched leaves:\n\t- {sep.join(mismatches)}"
                )

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Expand ``strength`` to one ``(n_groups,)`` vector per mask leaf.

        ``strength`` may be ``None`` (unit strength, as in the base class), a scalar
        or 0-D array (broadcast to every group), or any array-like of length ``n_groups``.

        Raises
        ------
        ValueError
            If ``strength`` is neither scalar nor of shape ``(n_groups,)``.
        """
        mask = self.mask if self.mask is not None else self.initialize_mask(params)
        flat_mask = jax.tree_util.tree_leaves(mask)
        n_groups = flat_mask[0].shape[0]

        # Normalize first: None/lists previously crashed on ``.ndim``.
        strength = jnp.asarray(1.0 if strength is None else strength, dtype=float)
        if strength.ndim == 0:
            per_group_strength = jnp.full(n_groups, strength, dtype=float)
        elif strength.ndim == 1 and strength.shape[0] == n_groups:
            per_group_strength = strength
        else:
            raise ValueError(
                f"GroupLasso strength must be a scalar or shape ({n_groups},), "
                f"got shape {strength.shape}"
            )
        # Non-regularizable positions in the mask are None and are skipped by tree_map.
        return jax.tree_util.tree_map(lambda _: per_group_strength, mask)

    def _get_filter_kwargs(self, params: Any, strength: Any) -> dict:
        """Return the aligned ``mask`` and ``strength`` kwargs for prox/penalty calls."""
        if self.mask is not None:
            mask = self.mask
            self._check_mask_and_params_shape_match(mask, params)
        else:
            mask = self.initialize_mask(params)
        return {"mask": mask, **super()._get_filter_kwargs(params, strength)}