"""
A Module for Optimization with Various Regularization Schemes.
This module provides a series of classes that facilitate the optimization of models
with different types of regularizations. Each `Regularizer` class in this module interfaces
with various optimization methods, and they can be applied depending on the model's requirements.
"""
import abc
import math
from typing import Any, Callable, Tuple, Union
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from . import tree_utils
from .base_class import Base
from .proximal_operator import (
compute_normalization,
masked_norm_2,
prox_elastic_net,
prox_group_lasso,
prox_lasso,
prox_none,
prox_ridge,
)
from .tree_utils import pytree_map_and_reduce
from .type_casting import _is_scalar_or_0d
from .typing import ProximalOperator
from .utils import format_repr
from .validation import convert_tree_leaves_to_jax_array
__all__ = ["UnRegularized", "Ridge", "Lasso", "GroupLasso", "ElasticNet"]
def __dir__() -> list[str]:
    """Return the public names of this module, as declared in ``__all__``."""
    return __all__
def apply_operator(func, params, *args, filter_kwargs=None, **kwargs):
    """
    Transform every regularizable subtree of a parameter pytree with ``func``.

    The selectors returned by ``params.regularizable_subtrees()`` are visited
    in order. For each selector, the selected subtree is handed to ``func`` and
    the result is written back into ``params`` with :func:`equinox.tree_at`.
    Fields that no selector points to (e.g., intercepts or structural
    metadata) are left unchanged. When ``params`` does not define
    ``regularizable_subtrees``, the whole tree is treated as a single
    regularizable subtree.

    Parameters
    ----------
    func :
        A callable with signature ``func(x, *args, **kwargs) -> Any``.
        It receives each regularizable subtree ``x`` and must return a value
        with the same pytree structure, which replaces that subtree.
    params :
        Any parameter object. If it implements ``regularizable_subtrees()``,
        that method must return an iterable of selector functions (suitable
        for ``eqx.tree_at``) identifying the leaves/subtrees to transform.
    *args :
        Additional positional arguments passed directly to ``func``.
    filter_kwargs :
        Optional keyword-only dictionary whose values are PyTrees that must be
        filtered per subtree. For each regularizable subtree, the subtree
        selector is applied to every value of this dict, and the extracted
        pieces are passed to ``func`` as keyword arguments. This is useful for
        operators that need PyTree-structured metadata (e.g., masks) aligned
        with the parameter structure. Default is None, which results in no
        filtering.
    **kwargs :
        Additional keyword arguments passed unchanged to ``func``.

    Returns
    -------
    params_new : same type as ``params``
        A new pytree/module with ``func`` applied to all regularizable
        subtrees. Non-regularized fields are preserved unchanged.

    Notes
    -----
    - ``regularizable_subtrees()`` must return a sequence of callables
      compatible with ``eqx.tree_at``. Each callable should extract a subtree
      from ``params``.
    - ``func`` must be pure and JAX-compatible if this function is used inside
      JIT-compiled code.
    - When ``filter_kwargs`` is provided, each value in the dict must be a
      PyTree with the same structure as ``params`` (or at least compatible
      with the subtree selectors).

    Examples
    --------
    A minimal working example with a fake ``Params`` object:

    >>> import equinox as eqx
    >>> class Params(eqx.Module):
    ...     w: float
    ...     b: float
    ...
    ...     # Only `w` is regularizable
    ...     def regularizable_subtrees(self):
    ...         return [lambda p: p.w]
    >>> p = Params(w=3.0, b=10.0)

    Define an operator that halves the value:

    >>> def halve(x):
    ...     return x / 2

    Apply it only to the regularizable subtree (`w`):

    >>> p2 = apply_operator(halve, p)
    >>> p2.w
    1.5
    >>> p2.b
    10.0

    The bias `b` is unchanged because it is not listed in
    `regularizable_subtrees`.

    Example with ``filter_kwargs`` for PyTree-structured metadata:

    >>> def masked_op(x, mask=None):
    ...     if mask is not None:
    ...         return x * mask
    ...     return x
    >>> # Create a mask with same structure as params
    >>> mask_tree = Params(w=0.5, b=1.0)
    >>> # Apply operator with filtered kwargs - only the relevant mask piece
    >>> # is passed to each subtree
    >>> p3 = apply_operator(masked_op, p, filter_kwargs={"mask": mask_tree})
    >>> p3.w  # w was multiplied by mask.w (0.5)
    1.5
    >>> p3.b  # b is not regularizable, so unchanged
    10.0
    """
    filter_kwargs = filter_kwargs or {}

    # Use the declared regularizable subtrees when available; otherwise the
    # identity selector makes the whole tree regularizable.
    selectors = (
        params.regularizable_subtrees()
        if hasattr(params, "regularizable_subtrees")
        else [lambda tree: tree]
    )

    for select in selectors:
        # Narrow every tree-structured keyword argument to the current subtree.
        per_subtree = {name: select(tree) for name, tree in filter_kwargs.items()}
        replacement = func(select(params), *args, **kwargs, **per_subtree)
        params = eqx.tree_at(select, params, replacement)
    return params
class Regularizer(Base, abc.ABC):
    """
    Abstract base class for regularized solvers.

    This class is designed to provide a consistent interface for optimization solvers,
    enabling users to easily switch between different regularizers, ensuring compatibility
    with various loss functions and optimization algorithms.

    Attributes
    ----------
    allowed_solvers : Tuple[str]
        Tuple of solver names that are allowed for use with this regularizer.
    default_solver : str
        String of the default solver name allowed for use with this regularizer.
    """

    _allowed_solvers: Tuple[str]
    _default_solver: str
    _proximal_operator: Callable

    @property
    def allowed_solvers(self) -> Tuple[str]:
        """Names of the solvers that are allowed for use with this regularizer."""
        return self._allowed_solvers

    @property
    def default_solver(self) -> str:
        """Name of the default solver for this regularizer."""
        return self._default_solver

    def check_solver(self, solver_name: str) -> None:
        """
        Raise an error if the given solver is not allowed.

        Parameters
        ----------
        solver_name :
            Name of the solver to check against the allowed solvers.

        Raises
        ------
        ValueError
            If ``solver_name`` is not among the allowed solvers.
        """
        if solver_name not in self._allowed_solvers:
            # Adjacent f-strings are concatenated verbatim, so every fragment
            # must carry its own trailing space.
            raise ValueError(
                f"The solver: {solver_name} is not allowed for "
                f"{self.__class__.__name__} regularization. Allowed solvers are "
                f"{self.allowed_solvers}. "
                f"If {solver_name} is your implementation and is designed to be "
                f"compatible with {self.__class__.__name__}, register it with "
                f"{self.__class__.__name__}.allow_solver({solver_name!r})"
            )

    @classmethod
    def allow_solver(cls, algo_name: str) -> None:
        """
        Add an algorithm to the list of compatible solvers.

        Nothing happens if the name is already listed.

        Parameters
        ----------
        algo_name :
            Name of the optimization algorithm to add.
        """
        if algo_name in cls._allowed_solvers:
            return
        cls._allowed_solvers += (algo_name,)

    def __repr__(self) -> str:
        return format_repr(self)

    def __str__(self) -> str:
        return format_repr(self)

    @staticmethod
    def _check_loss_output_tuple(output: tuple):
        """Raise a ``ValueError`` unless the loss output tuple has exactly two items."""
        if len(output) != 2:
            n_out = len(output)
            word = "value" if n_out == 1 else "values"
            raise ValueError(
                f"Invalid loss function return. The loss function returns a tuple with {n_out} {word}.\n"
                "A valid loss function can return either a single value (float or a 0-dim array), the loss, "
                "or a tuple with two values, the loss and an auxiliary variable."
            )

    def get_proximal_operator(self, params: Any, strength: Any) -> ProximalOperator:
        """
        Retrieve the proximal operator.

        Parameters
        ----------
        params :
            The parameters to be regularized.
        strength :
            The regularizer strength, aligned to the structure of ``params``
            before being captured by the returned operator.

        Returns
        -------
        :
            The proximal operator, applying regularization to the provided parameters.
        """
        filter_kwargs = self._get_filter_kwargs(strength=strength, params=params)

        # hyperparams is unused: strength is captured in filter_kwargs at construction time.
        # The argument is required to match the jaxopt prox interface:
        # prox(params, hyperparams_prox, scaling=1.0).
        def prox_op(params, hyperparams, scaling=1.0, *args):
            return apply_operator(
                self._proximal_operator,
                params,
                filter_kwargs=filter_kwargs,
                scaling=scaling,
            )

        return prox_op

    def penalized_loss(self, loss: Callable, params: Any, strength: Any) -> Callable:
        """
        Return a function for calculating the penalized loss.

        Parameters
        ----------
        loss :
            The unpenalized loss. It is called as ``loss(params, *args, **kwargs)``
            and may return either the loss value or a ``(loss, aux)`` tuple.
        params :
            Parameters used to align the regularizer strength.
        strength :
            The regularizer strength.

        Returns
        -------
        :
            A callable with the same signature as ``loss`` that adds the penalty
            to the loss value. An auxiliary output, if any, is passed through.
        """
        filter_kwargs = self._get_filter_kwargs(strength=strength, params=params)

        def _penalized_loss(params, *args, **kwargs):
            result = loss(params, *args, **kwargs)
            penalty = self._penalization(params, filter_kwargs=filter_kwargs)
            if isinstance(result, tuple):
                self._check_loss_output_tuple(result)
                loss_value, aux = result
                return loss_value + penalty, aux
            return result + penalty

        return _penalized_loss

    def _penalization(self, params: Any, filter_kwargs: dict) -> jnp.ndarray:
        """Sum the penalty over all regularizable subtrees of ``params``."""
        penalty = jnp.array(0.0)
        if hasattr(params, "regularizable_subtrees"):
            for where in params.regularizable_subtrees():
                subtree = where(params)
                # Narrow tree-structured kwargs (e.g. strength) to this subtree.
                subtree_kwargs = {key: where(val) for key, val in filter_kwargs.items()}
                penalty = penalty + self._penalty_on_subtree(subtree, **subtree_kwargs)
        else:
            penalty = penalty + self._penalty_on_subtree(params, **filter_kwargs)
        return penalty

    @abc.abstractmethod
    def _penalty_on_subtree(self, subtree, **kwargs) -> jnp.ndarray:
        """Compute the penalty contributed by a single regularizable subtree."""
        pass

    def _validate_strength(self, strength: Any):
        """
        Validate regularizer strength type.

        Parameters
        ----------
        strength : Any
            Regularizer strength specified as one of:

            - None
                Defaults to a scalar strength of 1.0.
            - scalar (Python number, NumPy scalar, or 0-D array)
                Python numbers are preserved as-is; NumPy scalars and 0-D
                arrays are converted to Python floats.
            - array-like or PyTree
                Converted leaf-wise to JAX arrays.

        Returns
        -------
        Any
            A scalar or PyTree where:

            - Python scalar leaves are preserved
            - Array-like leaves are converted to `jnp.ndarray`

        Raises
        ------
        TypeError
            If a leaf has an unsupported type or cannot be converted.
        """
        if strength is None:
            return 1.0

        def _convert_if_arraylike(x):
            if x is None:
                return 1.0
            elif isinstance(x, (int, float)):
                return x
            elif isinstance(x, (np.integer, np.floating)) or (
                isinstance(x, (jnp.ndarray, np.ndarray)) and x.ndim == 0
            ):
                # NumPy scalars (e.g. ``np.float32``) are neither Python numbers
                # nor ndarrays; treat them like 0-D arrays.
                return float(x)  # use Python floats when possible
            elif isinstance(x, (jnp.ndarray, np.ndarray, list, tuple)):
                return jnp.asarray(x)
            else:
                raise TypeError(
                    f"Unsupported regularizer strength leaf of type {type(x).__name__}."
                )

        try:
            # Lists and tuples are treated as array-like leaves rather than
            # as PyTree containers.
            return jax.tree_util.tree_map(
                _convert_if_arraylike,
                strength,
                is_leaf=lambda x: isinstance(x, (np.ndarray, jnp.ndarray, list, tuple)),
            )
        except (ValueError, TypeError) as e:
            raise TypeError(
                f"Could not convert regularizer strength to floats: {strength}"
            ) from e

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Align and broadcast regularizer strength to match model parameters.

        This function takes a validated regularizer strength specification and
        aligns it with the structure of `params`, filling only regularizable
        subtrees and inserting `None` elsewhere.

        Regularizable subtrees are determined via
        `params.regularizable_subtrees()` if available; otherwise, the entire
        parameter tree is treated as regularizable.

        Parameters
        ----------
        params : Any
            Model parameters structured as a PyTree.
        strength : Any
            Regularizer strength specification. Accepted forms:

            - None
                Uses a scalar strength of 1.0 for all regularizable parameters.
            - scalar or 0-D array
                Logically broadcast to all regularizable parameter leaves.
            - list
                One entry per regularizable subtree, in the order returned by
                `params.regularizable_subtrees()`.
            - PyTree
                Must match the structure of the regularizable subtrees. Each leaf
                may be a scalar, 0-D array, or an array matching the corresponding
                parameter leaf shape.

        Returns
        -------
        structured_strength : Any
            PyTree with the same structure as `params`:

            - Regularizable parameter leaves contain strength values
              (scalars or arrays)
            - Non-regularizable leaves are `None`

        Raises
        ------
        ValueError
            If:

            - The number of provided strength subtrees does not match the number of
              regularizable subtrees.
            - A strength PyTree does not match the structure of its corresponding
              parameter subtree.
            - A non-scalar strength leaf does not match the shape of the
              corresponding parameter leaf.
        """
        wheres = getattr(params, "regularizable_subtrees", lambda: [lambda x: x])()

        # Start from a tree shaped like ``params`` whose leaves are all None;
        # regularizable positions are filled in below.
        struct = jax.tree_util.tree_structure(params)
        structured_strength = jax.tree_util.tree_unflatten(
            struct, [None] * struct.num_leaves
        )

        if strength is None:
            strength = 1.0

        # A list provides one strength per subtree; anything else is shared.
        substrengths = (
            strength if isinstance(strength, list) else [strength] * len(wheres)
        )
        if len(substrengths) != len(wheres):
            raise ValueError(f"Expected {len(wheres)} strength values, got {strength}")

        def _structured_strength(strength_leaf, param_leaf):
            if _is_scalar_or_0d(strength_leaf):
                return strength_leaf
            if strength_leaf.shape != param_leaf.shape:
                raise ValueError(
                    f"Strength shape {strength_leaf.shape} does not match "
                    f"parameter shape {param_leaf.shape}"
                )
            return strength_leaf

        for substrength, where in zip(substrengths, wheres):
            subtree = where(params)
            validated = (
                jax.tree_util.tree_map(lambda p: substrength, subtree)
                if _is_scalar_or_0d(substrength)
                else jax.tree_util.tree_map(_structured_strength, substrength, subtree)
            )
            # ``is_leaf`` lets ``tree_at`` address the None placeholders.
            structured_strength = eqx.tree_at(
                where, structured_strength, validated, is_leaf=lambda x: x is None
            )
        return structured_strength

    def _get_filter_kwargs(self, params: Any, strength: Any):
        """Build the tree-structured keyword arguments passed to per-subtree operators."""
        strength = self._validate_strength_structure(params, strength)
        return {"strength": strength}
class UnRegularized(Regularizer):
    """
    Regularizer for models fitted without any penalty.

    Models using this class get the identity proximal operator (no shrinkage)
    and a loss to which a zero penalty is added.
    """

    _allowed_solvers = (
        "GradientDescent",
        "BFGS",
        "LBFGS",
        "NonlinearCG",
        "ProximalGradient",
        "SVRG",
        "ProxSVRG",
    )
    _default_solver = "LBFGS"
    _proximal_operator = staticmethod(prox_none)

    def _validate_strength(self, strength: Any):
        """Discard the strength: an unregularized model always gets ``None``."""
        return None

    def _penalty_on_subtree(self, subtree, **kwargs) -> jnp.ndarray:
        """Return a zero penalty, whatever the subtree contains."""
        return jnp.array(0.0)
class Ridge(Regularizer):
    """
    Ridge (L2) regularizer.

    Models using this class get the Ridge proximal operator and the Ridge
    penalized loss function.
    """

    _allowed_solvers = (
        "GradientDescent",
        "BFGS",
        "LBFGS",
        "NonlinearCG",
        "ProximalGradient",
        "SVRG",
        "ProxSVRG",
    )
    _default_solver = "LBFGS"
    _proximal_operator = staticmethod(prox_ridge)

    def _penalty_on_subtree(self, subtree, strength: Any, **kwargs) -> jnp.ndarray:
        """
        Compute the Ridge penalty of a parameter subtree.

        Each leaf contributes half the sum of its squared coefficients,
        weighted by the matching strength leaf; the leaf contributions are
        then added together.

        Parameters
        ----------
        subtree :
            Model parameter subtree for which to compute the penalization.
        strength :
            Regularization strength, with one leaf per leaf of ``subtree``.

        Returns
        -------
        float
            The Ridge penalization value.
        """
        return pytree_map_and_reduce(
            lambda weights, alpha: 0.5 * jnp.sum(alpha * jnp.square(weights)),
            sum,
            subtree,
            strength,
        )
class Lasso(Regularizer):
    """
    Lasso (L1) regularizer.

    Models using this class get the Lasso proximal operator and the Lasso
    penalized loss function.
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_lasso)

    def _penalty_on_subtree(self, subtree, strength: Any, **kwargs) -> jnp.ndarray:
        """
        Compute the Lasso penalty of a parameter subtree.

        Each leaf contributes the sum of its absolute coefficients, weighted
        by the matching strength leaf; the leaf contributions are then added
        together.

        Parameters
        ----------
        subtree :
            Model parameters for which to compute the penalization.
        strength :
            Regularization strength, with one leaf per leaf of ``subtree``.

        Returns
        -------
        float
            The Lasso penalization value.
        """
        return pytree_map_and_reduce(
            lambda weights, alpha: jnp.sum(alpha * jnp.abs(weights)),
            sum,
            subtree,
            strength,
        )
class ElasticNet(Regularizer):
    r"""
    Regularizer class for Elastic Net (L1 + L2 regularization).

    The Elastic Net penalty [3]_ [4]_ is defined as:

    .. math::
        P(\beta) = \alpha \left((1 - \lambda) \frac{1}{2} ||\beta||_{\ell_2}^2 +
        \lambda ||\beta||_{\ell_1} \right)

    where :math:`\alpha` is the regularizer strength, and :math:`\lambda` is the regularizer ratio.
    The regularizer ratio controls the balance between L1 (Lasso) and L2 (Ridge)
    regularization, where :math:`\lambda = 0` is equivalent to Ridge regularization and
    :math:`\lambda = 1` is equivalent to Lasso regularization.

    This class equips models with the Elastic Net proximal operator and the
    Elastic Net penalized loss function.

    References
    ----------
    .. [3] Zou, H., & Hastie, T. (2005).
        Regularization and variable selection via the elastic net.
        Journal of the Royal Statistical Society: Series B (Statistical Methodology), 67(2), 301-320.
        https://doi.org/10.1111/j.1467-9868.2005.00503.x
    .. [4] https://en.wikipedia.org/wiki/Elastic_net_regularization
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_elastic_net)

    def _penalty_on_subtree(self, subtree: Any, strength: Any, **kwargs) -> jnp.ndarray:
        r"""
        Compute the Elastic Net penalization for given parameters.

        The elastic net penalty is defined as:

        .. math::
            P(\beta) = \alpha \left((1 - \lambda) \frac{1}{2} ||\beta||_{\ell_2}^2 +
            \lambda ||\beta||_{\ell_1} \right)

        where :math:`\alpha` is the regularizer strength, and :math:`\lambda` is the
        regularizer ratio.

        Parameters
        ----------
        subtree :
            Model parameters for which to compute the penalization.
        strength :
            PyTree whose entries, one per leaf of ``subtree``, are
            ``(strength, ratio)`` pairs.

        Returns
        -------
        :
            The Elastic Net penalization value.
        """

        def leaf_penalty(weights, strength_and_ratio):
            alpha, ratio = strength_and_ratio
            ridge_term = 0.5 * (1.0 - ratio) * jnp.square(weights)
            lasso_term = ratio * jnp.abs(weights)
            return jnp.sum(alpha * (ridge_term + lasso_term))

        return pytree_map_and_reduce(
            leaf_penalty,
            sum,
            subtree,
            strength,
        )

    def _validate_strength(self, strength: Any):
        """
        Split the input into validated strength and ratio.

        ``None`` maps to a strength of 1.0 and a ratio of 0.5. A tuple must be
        ``(strength, ratio)``. Any other input is used as the strength, with a
        ratio of 0.5. Both parts go through the base-class validation and the
        ratio must additionally lie in ``(0, 1]``.

        Raises
        ------
        TypeError
            If a tuple that does not have exactly two items is provided.
        ValueError
            If any ratio value is outside ``(0, 1]``.
        """
        ratio = 0.5
        if strength is None:
            strength = 1.0
        elif isinstance(strength, tuple):
            if len(strength) != 2:
                raise TypeError(
                    "ElasticNet regularizer strength must be a tuple (strength, ratio)"
                )
            strength, ratio = strength

        strength = super()._validate_strength(strength)
        ratio = super()._validate_strength(ratio)

        def check_ratio(r):
            out_of_range = (r <= 0) | (r > 1)
            if jnp.any(out_of_range):
                raise ValueError(
                    f"ElasticNet regularization ratio must be in (0, 1], got {r}"
                )
            return r

        return strength, jax.tree_util.tree_map(check_ratio, ratio)

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Align strength and ratio with ``params`` and pair them leaf by leaf.

        ``strength`` is indexed as ``(strength, ratio)``. Each part is aligned
        with the base-class method, then matching leaves are combined into
        ``(strength, ratio)`` tuples.
        """
        strength_tree = super()._validate_strength_structure(params, strength[0])
        ratio_tree = super()._validate_strength_structure(params, strength[1])

        def pair_up(s, r):
            return None if s is None else (s, r)

        return jax.tree_util.tree_map(pair_up, strength_tree, ratio_tree)
class GroupLasso(Regularizer):
    """
    Regularizer class for Group Lasso (group-L1) regularized models.

    This class equips models with the group-lasso proximal operator and the
    group-lasso penalized loss function.

    Attributes
    ----------
    mask :
        A mask array (or PyTree of arrays) indicating group membership for regularization.
        Each regularizable parameter leaf with shape ``(n_features, ...)`` requires a corresponding
        mask leaf with shape ``(n_groups, n_features, ...)``, i.e. ``(n_groups, *params.shape)``.
        Row ``i`` of a mask leaf is 1 for features belonging to group ``i`` and 0 elsewhere;
        each feature may belong to at most one group.
        If ``None`` (default), a mask is auto-initialized so that each trailing dimension of each
        parameter leaf forms its own group.

    Notes
    -----
    For GroupLasso, the regularizer strength is defined **per group**, not per parameter.
    It must be either a scalar or a 1D array of length ``n_groups``.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.regularizer import GroupLasso
    >>> from nemos.glm import GLM
    >>> # simulate some counts
    >>> num_samples, num_features, num_groups = 1000, 5, 3
    >>> X = np.random.normal(size=(num_samples, num_features))  # design matrix
    >>> w = [0, 0.5, 1, 0, -0.5]  # define some weights
    >>> y = np.random.poisson(np.exp(X.dot(w)))  # observed counts
    >>> # Define a mask for 3 groups and 5 features
    >>> mask = np.zeros((num_groups, num_features))
    >>> mask[0] = [1, 0, 0, 1, 0]  # Group 0 includes features 0 and 3
    >>> mask[1] = [0, 1, 0, 0, 0]  # Group 1 includes features 1
    >>> mask[2] = [0, 0, 1, 0, 1]  # Group 2 includes features 2 and 4
    >>> # Create the GroupLasso regularizer instance
    >>> group_lasso = GroupLasso(mask=mask)
    >>> # fit a group-lasso glm
    >>> model = GLM(regularizer=group_lasso, regularizer_strength=0.1).fit(X, y)
    >>> print(f"coeff shape: {model.coef_.shape}")
    coeff shape: (5,)

    For a :class:`~nemos.glm.PopulationGLM`, where ``coef_`` has shape
    ``(n_features, n_neurons)``, the mask must have the matching shape
    ``(n_groups, n_features, n_neurons)``:

    >>> import nemos as nmo
    >>> num_samples, num_features, num_neurons = 1000, 4, 3
    >>> X = np.random.normal(size=(num_samples, num_features))
    >>> w = np.random.randn(num_features, num_neurons) * 0.1
    >>> y = np.random.poisson(np.exp(X.dot(w)))
    >>> # group 0: regularize all features jointly for neurons 0-1
    >>> # group 1: regularize all features jointly for neuron 2
    >>> mask = np.zeros((2, num_features, num_neurons))
    >>> mask[0, :, :2] = 1
    >>> mask[1, :, 2:] = 1
    >>> model = nmo.glm.PopulationGLM(
    ...     regularizer=nmo.regularizer.GroupLasso(mask=mask),
    ...     regularizer_strength=0.1,
    ... ).fit(X, y)
    >>> print(f"coef shape: {model.coef_.shape}")
    coef shape: (4, 3)
    """

    _allowed_solvers = (
        "ProximalGradient",
        "ProxSVRG",
    )
    _default_solver = "ProximalGradient"
    _proximal_operator = staticmethod(prox_group_lasso)

    def __init__(
        self,
        mask: Any = None,
    ):
        super().__init__()
        self.mask = mask

    @property
    def mask(self):
        """Getter for the mask attribute."""
        return self._mask

    @mask.setter
    def mask(self, mask: Union[jnp.ndarray, None]):
        """Setter for the mask attribute."""
        # check mask if passed by user, else will be initialized later
        if mask is not None:
            mask = self._cast_and_check_mask(mask)
        self._mask = mask

    def initialize_mask(self, x: Any) -> Any:
        """
        Initialize a default group mask for a PyTree of parameters.

        Creates a mask where each leaf array and each of its trailing dimensions
        (beyond the first) are assigned to separate groups. This default grouping
        treats:

        - Each leaf in the PyTree as a distinct parameter set
        - The first dimension (axis 0) as the feature dimension
        - Each trailing dimension as a separate group of features

        For a leaf with shape (n_features, d1, d2, ..., dk), this creates
        (d1 * d2 * ... * dk) groups, where each group's mask is 1.0 for all
        features in that specific trailing dimension combination and 0.0 elsewhere.
        Groups are numbered independently within each regularizable subtree.

        Parameters
        ----------
        x : Any
            PyTree of parameter arrays. Each leaf should have shape
            (n_features, ...) where n_features is the number of features
            and trailing dimensions define additional grouping structure.

        Returns
        -------
        mask : Any
            PyTree with the same structure as x, where each regularizable leaf
            has shape (n_groups, n_features, ...) matching the original leaf
            shape, and non-regularizable leaves are ``None``.
            The mask contains 1.0 for elements in each group and 0.0 elsewhere.
        """
        reg_subtrees = (
            x.regularizable_subtrees()
            if hasattr(x, "regularizable_subtrees")
            else [lambda z: z]
        )
        # Start from an all-None tree shaped like ``x`` and fill in the
        # regularizable subtrees one at a time.
        struct = jax.tree_util.tree_structure(x)
        mask = jax.tree_util.tree_unflatten(struct, [None] * struct.num_leaves)
        for where in reg_subtrees:
            mask = eqx.tree_at(
                where,
                mask,
                self._initialize_subtree_mask(where(x)),
                is_leaf=lambda m: m is None,
            )
        return mask

    @staticmethod
    def _initialize_subtree_mask(subtree: Any) -> Any:
        """Initialize individual subtree mask matching structure."""
        flat_x, struct = jax.tree_util.tree_flatten(subtree)

        # Calculate total number of groups across all leaves
        n_groups_per_leaf = [math.prod(leaf.shape[1:]) for leaf in flat_x]
        total_groups = sum(n_groups_per_leaf)

        # Build mask for each leaf
        mask_flat = []
        group_offset = 0
        for leaf, n_groups in zip(flat_x, n_groups_per_leaf):
            # Create mask: (total_groups, n_features, *extra_dims)
            mask_shape = (total_groups, *leaf.shape)
            mask = jnp.zeros(mask_shape, dtype=float)
            extra_shape = leaf.shape[1:]

            # Set 1.0 for this leaf's groups along the flattened extra dimensions
            for i in range(n_groups):
                # Map the linear group index to a multi-dimensional index
                multi_idx = jnp.unravel_index(i, extra_shape)
                # Build index tuple: (group_id, slice(:), *multi_idx)
                # When dropping support for Python < 3.11, replace with
                # mask = mask.at[group_offset + i, :, *multi_idx].set(1.0)
                full_idx = (group_offset + i, slice(None)) + multi_idx
                mask = mask.at[full_idx].set(1.0)

            mask_flat.append(mask)
            group_offset += n_groups
        return jax.tree_util.tree_unflatten(struct, mask_flat)

    @staticmethod
    def _cast_and_check_mask(mask: Any) -> Any:
        """
        Cast to jax array of floats and validate the mask.

        This method ensures the mask adheres to requirements:

        - The mask should be castable to a PyTree of arrays of float type.
        - Every leaf has at least 2 dimensions and all leaves share the same
          first dimension length (``n_groups``).
        - Each element must be either 0 or 1.
        - Each feature should belong to only one group.
        - The mask should not be empty.

        Raises
        ------
        ValueError
            If any of the above conditions are not met.
        """
        mask = convert_tree_leaves_to_jax_array(
            mask,
            "Unable to convert mask to a tree with ``jax.ndarray`` leaves.",
        )
        flat_mask = jax.tree_util.tree_leaves(mask)

        # A tree without array leaves carries no group information at all.
        if not flat_mask:
            raise ValueError("Empty mask provided!")

        # Check dimensionality first: ``shape[0]`` is read right below and
        # does not exist for 0-dimensional leaves.
        if any(m.ndim < 2 for m in flat_mask):
            raise ValueError(
                "Mask arrays should have at least 2 dimensions ``(n_groups, n_features, ...)``."
            )

        group_counts = {m.shape[0] for m in flat_mask}
        if len(group_counts) > 1:
            raise ValueError(
                "The length of the first dimension array leaves in the mask PyTree "
                "should be equal to ``n_groups``. "
                f"Leaves of the mask tree have inconsistent first dimension lengths: {sorted(group_counts)}."
            )

        (n_groups,) = group_counts
        if n_groups == 0:
            raise ValueError("Empty mask provided!")

        has_invalid_entries = pytree_map_and_reduce(
            lambda m: jnp.any((m != 1) & (m != 0)), any, mask
        )
        if has_invalid_entries:
            raise ValueError("Mask elements must be 0s and 1s!")

        all_zeros = pytree_map_and_reduce(lambda m: jnp.all(m == 0), all, mask)
        if all_zeros:
            raise ValueError("Empty mask provided!")

        # Summing over the group axis counts how many groups claim each feature.
        multi_group_assignment = pytree_map_and_reduce(
            lambda m: jnp.any(m.sum(axis=0) > 1), any, mask
        )
        if multi_group_assignment:
            raise ValueError(
                "Incorrect group assignment. Some of the features are assigned "
                "to more than one group."
            )
        return mask

    def _penalty_on_subtree(
        self, subtree, strength: Any, mask: Any = None, **kwargs
    ) -> jnp.ndarray:
        r"""
        Apply the Group Lasso penalty to a subtree.

        Note: the penalty is being calculated according to the following formula:

        .. math::
            \text{loss}(\beta_1,...,\beta_g) + \alpha \cdot \sum _{j=1...,g} \sqrt{\dim(\beta_j)} || \beta_j||_2

        where :math:`g` is the number of groups, :math:`\dim(\cdot)` is the dimension of the vector,
        i.e. the number of coefficient in each :math:`\beta_j`, and :math:`||\cdot||_2` is the euclidean norm.

        Parameters
        ----------
        subtree :
            Model parameter subtree for which to compute the penalization.
        strength :
            PyTree of strengths, with one leaf per leaf of ``subtree``.
        mask :
            PyTree of group masks, with one leaf per leaf of ``subtree``.

        Returns
        -------
        :
            The Group Lasso penalization value, summed over leaves.
        """

        def penalty_leaf(leaf, leaf_mask, leaf_strength):
            leaf_l2_norm = masked_norm_2(leaf, leaf_mask, normalize=False)
            leaf_norm = compute_normalization(leaf_mask)
            return jnp.sum(leaf_strength * leaf_norm * leaf_l2_norm)

        penalties = jax.tree_util.tree_map(
            penalty_leaf,
            subtree,
            mask,
            strength,
        )
        return jnp.sum(jnp.array(jax.tree_util.tree_leaves(penalties)))

    def _check_mask_and_params_shape_match(self, mask, params):
        """
        Check that every mask leaf has shape ``(n_groups, *param_leaf.shape)``.

        Raises
        ------
        ValueError
            If the trailing dimensions of any mask leaf differ from the shape
            of the corresponding regularizable parameter leaf.
        """
        reg_subtrees = (
            params.regularizable_subtrees()
            if hasattr(params, "regularizable_subtrees")
            else [lambda z: z]
        )
        for where in reg_subtrees:
            sub_mask = where(mask)
            sub_params = where(params)
            shape_mismatched = pytree_map_and_reduce(
                lambda s, p: s.shape[1:] != p.shape, any, sub_mask, sub_params
            )
            if shape_mismatched:
                flat_mask_leaves = jax.tree_util.tree_leaves(sub_mask)
                flat_param_leaves = jax.tree_util.tree_leaves(sub_params)
                mismatches = [
                    f"mask {s.shape} (expected {(s.shape[0], *p.shape)})"
                    for s, p in zip(flat_mask_leaves, flat_param_leaves)
                    if s.shape[1:] != p.shape
                ]
                sep = "\n\t- "
                raise ValueError(
                    "GroupLasso mask shape mismatch: the mask must have shape "
                    "``(n_groups, *params.shape)`` for every regularizable parameter leaf. "
                    f"Mismatched leaves:\n\t- {sep.join(mismatches)}"
                )

    def _validate_strength_structure(self, params: Any, strength: Any):
        """
        Expand the strength into one per-group vector for every mask leaf.

        Parameters
        ----------
        params : Any
            Model parameters; used to auto-initialize the mask when no mask
            was provided.
        strength : Any
            ``None`` (treated as 1.0), a scalar, or a 1D array-like with one
            value per group.

        Returns
        -------
        :
            PyTree with the structure of the mask, where each leaf is a 1D
            array with one strength per group of the matching mask leaf.

        Raises
        ------
        ValueError
            If a non-scalar strength is not 1D or its length differs from the
            number of groups of a mask leaf.
        """
        mask = self.mask if self.mask is not None else self.initialize_mask(params)

        if strength is None:
            strength = 1.0
        # Casting up front lets Python numbers, lists and arrays share one path.
        strength = jnp.asarray(strength, dtype=float)

        def per_group_strength(mask_leaf):
            # Auto-initialized masks count groups per regularizable subtree,
            # so the group count is read from each leaf rather than assumed
            # to be the same everywhere.
            n_groups = mask_leaf.shape[0]
            if strength.ndim == 0:
                return jnp.full(n_groups, strength, dtype=float)
            if strength.ndim != 1 or strength.shape[0] != n_groups:
                raise ValueError(
                    f"GroupLasso strength must be a scalar or shape ({n_groups},), "
                    f"got shape {strength.shape}"
                )
            return strength

        return jax.tree_util.tree_map(per_group_strength, mask)

    def _get_filter_kwargs(self, params: Any, strength: Any) -> dict:
        """Return the mask and the aligned strength as tree-structured kwargs."""
        if self.mask is not None:
            mask = self.mask
            self._check_mask_and_params_shape_match(mask, params)
        else:
            mask = self.initialize_mask(params)
        return {"mask": mask, **super()._get_filter_kwargs(params, strength)}