nemos.regularizer.GroupLasso
- class nemos.regularizer.GroupLasso(mask=None)
Bases: Regularizer
Regularizer class for Group Lasso (group-L1) regularized models.
This class equips models with the group-lasso proximal operator and the group-lasso penalized loss function.
- Parameters:
mask (Any)
- mask
A mask array (or PyTree of arrays) indicating group membership for regularization. Each regularizable parameter leaf with shape (n_features, ...) requires a corresponding mask leaf with shape (n_groups, n_features, ...), i.e. (n_groups, *params.shape). Row i of a mask leaf is 1 for features belonging to group i and 0 elsewhere; each feature may belong to at most one group. If None (default), a mask is auto-initialized so that each trailing dimension of each parameter leaf forms its own group.
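The rules for a mask leaf (shape (n_groups, *params.shape), 0/1 entries, each feature in at most one group) can be checked mechanically. The sketch below uses plain NumPy; `validate_group_mask` is a hypothetical helper written for illustration and is not part of the nemos API, which may validate masks differently.

```python
import numpy as np


def validate_group_mask(mask, param_shape):
    """Hypothetical helper: check one mask leaf against the rules for `mask`.

    Returns the number of groups, or raises ValueError if a rule is violated.
    """
    mask = np.asarray(mask)
    param_shape = tuple(param_shape)
    # A mask leaf must have shape (n_groups, *params.shape).
    if mask.ndim != len(param_shape) + 1 or mask.shape[1:] != param_shape:
        raise ValueError(f"mask shape {mask.shape} is not (n_groups, *{param_shape})")
    # Entries mark membership, so they must be 0 or 1.
    if not np.isin(mask, (0, 1)).all():
        raise ValueError("mask entries must be 0 or 1")
    # Summing over the group axis counts how many groups each feature belongs to.
    if (mask.sum(axis=0) > 1).any():
        raise ValueError("each feature may belong to at most one group")
    return mask.shape[0]


# The 3-group, 5-feature mask used in the GLM example on this page.
mask = np.zeros((3, 5))
mask[0] = [1, 0, 0, 1, 0]
mask[1] = [0, 1, 0, 0, 0]
mask[2] = [0, 0, 1, 0, 1]
n_groups = validate_group_mask(mask, (5,))
```

Summing over axis 0 is a compact way to express "at most one group": any entry greater than 1 means a feature was assigned to two or more groups.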
Notes
For GroupLasso, the regularizer strength is defined per group, not per parameter. It must be either a scalar or a 1D array of length n_groups.
Examples
>>> import numpy as np
>>> from nemos.regularizer import GroupLasso
>>> from nemos.glm import GLM
>>> # simulate some counts
>>> num_samples, num_features, num_groups = 1000, 5, 3
>>> X = np.random.normal(size=(num_samples, num_features))  # design matrix
>>> w = [0, 0.5, 1, 0, -0.5]  # define some weights
>>> y = np.random.poisson(np.exp(X.dot(w)))  # observed counts
>>> # Define a mask for 3 groups and 5 features
>>> mask = np.zeros((num_groups, num_features))
>>> mask[0] = [1, 0, 0, 1, 0]  # Group 0 includes features 0 and 3
>>> mask[1] = [0, 1, 0, 0, 0]  # Group 1 includes feature 1
>>> mask[2] = [0, 0, 1, 0, 1]  # Group 2 includes features 2 and 4
>>> # Create the GroupLasso regularizer instance
>>> group_lasso = GroupLasso(mask=mask)
>>> # fit a group-lasso glm
>>> model = GLM(regularizer=group_lasso, regularizer_strength=0.1).fit(X, y)
>>> print(f"coef shape: {model.coef_.shape}")
coef shape: (5,)
For a PopulationGLM, where coef_ has shape (n_features, n_neurons), the mask must have the matching shape (n_groups, n_features, n_neurons):
>>> import nemos as nmo
>>> num_samples, num_features, num_neurons = 1000, 4, 3
>>> X = np.random.normal(size=(num_samples, num_features))
>>> w = np.random.randn(num_features, num_neurons) * 0.1
>>> y = np.random.poisson(np.exp(X.dot(w)))
>>> # group 0: regularize all features jointly for neurons 0-1
>>> # group 1: regularize all features jointly for neuron 2
>>> mask = np.zeros((2, num_features, num_neurons))
>>> mask[0, :, :2] = 1
>>> mask[1, :, 2:] = 1
>>> model = nmo.glm.PopulationGLM(
...     regularizer=nmo.regularizer.GroupLasso(mask=mask),
...     regularizer_strength=0.1,
... ).fit(X, y)
>>> print(f"coef shape: {model.coef_.shape}")
coef shape: (4, 3)
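The Notes state that the strength is defined per group and may be a scalar or a length-n_groups array. A minimal NumPy sketch of the textbook group-lasso penalty, the strength-weighted sum of each group's L2 norm, shows how both forms apply. This is an illustration under that assumption, not the nemos implementation; in particular it omits any extra per-group scaling (for example by group size) that the library may apply.

```python
import numpy as np


def group_lasso_penalty(params, mask, strength):
    """Sketch: sum over groups g of strength[g] * ||params restricted to group g||_2."""
    params = np.asarray(params, dtype=float)
    mask = np.asarray(mask, dtype=float)
    n_groups = mask.shape[0]
    # A scalar strength is shared by all groups; an array must have length n_groups
    # (np.broadcast_to raises ValueError otherwise).
    strength = np.broadcast_to(np.asarray(strength, dtype=float), (n_groups,))
    # mask * params has shape (n_groups, *params.shape) and zeroes coefficients outside each group.
    masked = mask * params
    group_norms = np.sqrt((masked**2).reshape(n_groups, -1).sum(axis=1))
    return float(np.sum(strength * group_norms))


mask = np.array([[1, 0, 0, 1, 0],
                 [0, 1, 0, 0, 0],
                 [0, 0, 1, 0, 1]])
params = np.array([3.0, 0.0, 4.0, 4.0, 0.0])
# The three group norms are 5, 0 and 4.
shared = group_lasso_penalty(params, mask, 0.1)                  # 0.1 * (5 + 0 + 4) = 0.9
per_group = group_lasso_penalty(params, mask, [1.0, 1.0, 0.5])  # 5 + 0 + 0.5 * 4 = 7.0
```

Because the norm is taken over a whole group rather than over single coefficients, the penalty favors solutions in which entire groups are zero.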
Attributes
mask
Getter for the mask attribute.
Methods
__init__([mask])
allow_solver(algo_name) – Add an algorithm to the list of compatible solvers.
check_solver(solver_name) – Raise an error if the given solver is not allowed.
get_params([deep]) – From scikit-learn, get parameters by inspecting __init__.
get_proximal_operator(params, strength) – Retrieve the proximal operator.
initialize_mask(x) – Initialize a default group mask for a PyTree of parameters.
penalized_loss(loss, params, strength) – Return a function for calculating the penalized loss.
set_params(**params) – Set the parameters of this estimator.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values, which are set using __metadata_request__* class attributes or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or when the developer would like to specify a request value for metadata that differs from the default None.
References
[1] PEP 487: Simpler customisation of class creation.
- classmethod allow_solver(algo_name)
Add an algorithm to the list of compatible solvers.
- check_solver(solver_name)
Raise an error if the given solver is not allowed.
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
From scikit-learn, get parameters by inspecting __init__.
- Parameters:
deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
- Return type:
dict
- get_proximal_operator(params, strength)
Retrieve the proximal operator.
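For the group-lasso penalty, the standard proximal operator is block soft-thresholding: each group's coefficients are rescaled so that the group's L2 norm shrinks by step_size * strength, and groups whose norm falls at or below that threshold become exactly zero. The NumPy sketch below assumes that textbook form; `group_soft_threshold` and its `step_size` argument are hypothetical names, and the operator returned by nemos may differ in signature and scaling.

```python
import numpy as np


def group_soft_threshold(params, mask, strength, step_size=1.0):
    """Sketch of block soft-thresholding, the standard group-lasso proximal operator.

    Each group's coefficients are scaled by max(0, 1 - step_size * strength[g] / ||w_g||_2).
    """
    params = np.asarray(params, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    n_groups = mask.shape[0]
    strength = np.broadcast_to(np.asarray(strength, dtype=float), (n_groups,))
    out = params.copy()  # coefficients outside every group are left unchanged in this sketch
    for g in range(n_groups):
        in_group = mask[g]
        norm = np.linalg.norm(params[in_group])
        threshold = step_size * strength[g]
        # A group whose norm is at or below the threshold is set exactly to zero.
        scale = max(0.0, 1.0 - threshold / norm) if norm > 0 else 0.0
        out[in_group] = scale * params[in_group]
    return out


mask = np.array([[1, 0, 0, 1, 0],
                 [0, 1, 0, 0, 0],
                 [0, 0, 1, 0, 1]])
params = np.array([3.0, 0.0, 4.0, 4.0, 0.0])
shrunk = group_soft_threshold(params, mask, strength=1.0)  # [2.4, 0.0, 3.0, 3.2, 0.0]
zeroed = group_soft_threshold(params, mask, strength=6.0)  # all zeros
```

With strength 1.0, group 0 (norm 5) is scaled by 0.8 and group 2 (norm 4) by 0.75; with strength 6.0 both thresholds exceed the group norms, so every group is removed.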
- initialize_mask(x)
Initialize a default group mask for a PyTree of parameters.
Creates a mask where each leaf array and each of its trailing dimensions (beyond the first) are assigned to separate groups. This default grouping treats:
- Each leaf in the PyTree as a distinct parameter set
- The first dimension (axis 0) as the feature dimension
- Each trailing dimension as a separate group of features
For a leaf with shape (n_features, d1, d2, …, dk), this creates (d1 * d2 * … * dk) groups, where each group’s mask is 1.0 for all features in that specific trailing dimension combination and 0.0 elsewhere.
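The default grouping for a single leaf can be sketched in NumPy: a leaf of shape (n_features, d1, ..., dk) gets d1 * ... * dk groups, and group g marks every feature at one trailing-index combination. This is an illustrative sketch, not the nemos implementation; `default_group_mask` is a hypothetical name, and the group ordering (row-major over the trailing indices) is an assumption. For a PyTree, the same function would be mapped over each leaf, for example with `jax.tree_util.tree_map`.

```python
import itertools
import math

import numpy as np


def default_group_mask(param):
    """Sketch: one group per trailing-index combination of a (n_features, d1, ..., dk) array."""
    param = np.asarray(param)
    trailing = param.shape[1:]
    # math.prod(()) == 1, so a 1D leaf gets a single group containing all features.
    n_groups = math.prod(trailing)
    mask = np.zeros((n_groups, *param.shape))
    # Enumerate trailing-index combinations in row-major order (an assumed ordering).
    for g, idx in enumerate(itertools.product(*(range(d) for d in trailing))):
        mask[(g, slice(None), *idx)] = 1.0
    return mask


# PopulationGLM-like coefficients: 4 features x 3 neurons -> one group per neuron.
population_mask = default_group_mask(np.zeros((4, 3)))
single_mask = default_group_mask(np.zeros(5))
```

For the (4, 3) leaf the result has shape (3, 4, 3) and group j covers all four features of neuron j; every coefficient belongs to exactly one group.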
- Parameters:
x (Any) – PyTree of parameter arrays. Each leaf should have shape (n_features, …), where n_features is the number of features and the trailing dimensions define additional grouping structure.
- Returns:
mask – PyTree with the same structure as x, where each leaf has shape (n_groups, n_features, …), i.e. (n_groups, *leaf.shape). The mask contains 1.0 for elements in each group and 0.0 elsewhere.
- Return type:
- property mask
Getter for the mask attribute.
- penalized_loss(loss, params, strength)
Return a function for calculating the penalized loss.
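Conceptually, a penalized loss is a closure that evaluates the unpenalized loss and adds the strength-weighted group penalty. The sketch below illustrates that composition with NumPy; `make_penalized_loss` and the toy `mse` loss are hypothetical, and the (params, X, y) call signature is an assumption that does not reproduce the signature of nemos's `penalized_loss`.

```python
import numpy as np


def make_penalized_loss(loss, mask, strength):
    """Hypothetical sketch: wrap `loss(params, X, y)` so it adds a group-lasso penalty."""
    mask = np.asarray(mask, dtype=float)
    n_groups = mask.shape[0]
    strength = np.broadcast_to(np.asarray(strength, dtype=float), (n_groups,))

    def penalized(params, X, y):
        # Per-group L2 norms of the masked coefficients.
        norms = np.sqrt(((mask * params) ** 2).reshape(n_groups, -1).sum(axis=1))
        return loss(params, X, y) + float(np.sum(strength * norms))

    return penalized


def mse(params, X, y):
    """Toy unpenalized loss used only for illustration."""
    return float(np.mean((X @ params - y) ** 2))


penalized_mse = make_penalized_loss(mse, mask=np.ones((1, 2)), strength=0.5)
value = penalized_mse(np.array([3.0, 4.0]), np.eye(2), np.zeros(2))  # 12.5 + 0.5 * 5 = 15.0
```

Returning a function rather than a number lets an optimizer call the penalized objective repeatedly with new parameters while the mask and strength stay fixed.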
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
**params (Any) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance