Download this notebook: finegrained_regularization.ipynb!
Regularizing parameters with different strengths#
NeMoS allows for regularizing individual parameters with different regularization strengths. By passing a structure of regularization strengths that matches the parameter structure, you get fine control over how each parameter is regularized.
Traditional regularization: all parameters are regularized equally#
We will first generate some synthetic data with two feature groups:
Note
We will store the features in a dict, but any JAX pytree works. See the background note on pytrees to find out what they are.
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import nemos as nmo
np.random.seed(123)
n_samples = 500
n_features = 7
# random design matrix, containing two feature groups: f1 and f2
X = dict(
    f1=0.5 * np.random.normal(size=(n_samples, 5)),
    f2=0.5 * np.random.normal(size=(n_samples, 2)),
)
# log-rates & weights
b_true = 1.0
w1_true = np.random.uniform(size=(5,))
w2_true = np.random.uniform(size=(2,))
# generate counts; spikes will have shape (n_samples,)
rate = jnp.exp(jnp.dot(X['f1'], w1_true) + jnp.dot(X['f2'], w2_true) + b_true)
spikes = np.random.poisson(rate)
print(spikes.shape)
(500,)
Let’s start with the traditional case where every parameter is regularized with the same strength.
We’ll fit a GLM with Ridge regularization (everything shown here works identically for Lasso):
glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=0.1)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=Ridge(),
regularizer_strength=0.1,
solver_name='LBFGS'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | Ridge() |
| regularizer_strength | 0.1 |
| solver_name | 'LBFGS' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6940...dtype=float32), 'f2': Array([0.1343...dtype=float32)} |
| dof_resid_ | ArrayImpl[float32](1,) | Array([492.], dtype=float32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.0842...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
In this case, the regularization strength of 0.1 is used for all 7 parameters (the intercept is not regularized).
Group-wise regularization: all parameters within a group are regularized equally#
If we want different regularization strengths for the two parameter groups, we can pass a dictionary matching the design matrix:
glm = nmo.glm.GLM(
    regularizer="Ridge",
    regularizer_strength=dict(f1=0.1, f2=0.2),
)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=Ridge(),
regularizer_strength={'f1': 0.1, 'f2': 0.2},
solver_name='LBFGS'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | Ridge() |
| regularizer_strength | {'f1': 0.1, 'f2': 0.2} |
| solver_name | 'LBFGS' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6984...dtype=float32), 'f2': Array([0.1297...dtype=float32)} |
| dof_resid_ | ArrayImpl[float32](1,) | Array([492.], dtype=float32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.0994...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
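One way to picture the group-wise dictionary is as a weighted L2 penalty in which each group's scalar multiplies every squared coefficient of that group. The sketch below uses plain NumPy to illustrate this. The function name `weighted_ridge_penalty`, the example coefficients, and the 0.5 scaling factor are assumptions made for illustration, not NeMoS internals.

```python
import numpy as np

def weighted_ridge_penalty(coef, strength):
    """Illustrative penalty: 0.5 * sum_k strength_k * coef_k**2 over all groups.

    `coef` and `strength` are dicts with the same keys. Each strength entry
    may be a scalar or an array that broadcasts against the coefficients.
    The intercept is not part of `coef`, so it is never penalized.
    """
    return sum(
        0.5 * np.sum(np.asarray(strength[name]) * np.square(w))
        for name, w in coef.items()
    )

# hypothetical coefficients with the same group structure as X
coef = dict(f1=np.full(5, 0.5), f2=np.full(2, 0.5))

# one scalar per group ...
per_group = weighted_ridge_penalty(coef, dict(f1=0.1, f2=0.2))
# ... gives the same penalty as repeating that scalar for every coefficient
per_param = weighted_ridge_penalty(coef, dict(f1=[0.1] * 5, f2=[0.2] * 2))
print(per_group, per_param)
```

Under this reading, a scalar per group is simply shorthand for a constant array of strengths within that group.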
Parameter-wise regularization: every parameter has its own regularization strength#
If we want even finer control over regularization, we can pass a list or array for each key of the dictionary, with one strength per feature in that group:
glm = nmo.glm.GLM(
    regularizer="Ridge",
    regularizer_strength=dict(
        f1=[0.1, 0.3, 0.3, 0.1, 1.0],
        f2=[0.2, 0.1],
    ),
)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=Ridge(),
regularizer_strength={'f1': Array([0.1, 0.3, 0.3, 0.1, 1. ], dtype=float32), 'f2': Array([0.2, 0.1], dtype=float32)},
solver_name='LBFGS'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | Ridge() |
| regularizer_strength | {'f1': Array([0.1, 0...dtype=float32), 'f2': Array([0.2, 0...dtype=float32)} |
| solver_name | 'LBFGS' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6907...dtype=float32), 'f2': Array([0.1308...dtype=float32)} |
| dof_resid_ | ArrayImpl[float32](1,) | Array([492.], dtype=float32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.1214...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
You can also mix different approaches, such as passing a single value for one group, and a list for the other:
glm = nmo.glm.GLM(
    regularizer="Ridge",
    regularizer_strength=dict(
        f1=0.1,
        f2=[0.2, 0.1],
    ),
)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=Ridge(),
regularizer_strength={'f1': 0.1, 'f2': Array([0.2, 0.1], dtype=float32)},
solver_name='LBFGS'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | Ridge() |
| regularizer_strength | {'f1': 0.1, 'f2': Array([0.2, 0...dtype=float32)} |
| solver_name | 'LBFGS' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6942...dtype=float32), 'f2': Array([0.1223...dtype=float32)} |
| dof_resid_ | ArrayImpl[float32](1,) | Array([492.], dtype=float32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.0841...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
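Mixing scalars and lists works as long as every entry is compatible with the shape of its coefficient group. A rough sketch of that idea with NumPy broadcasting follows. `expand_strength` is a hypothetical helper written for this illustration, and the validation NeMoS actually performs may differ.

```python
import numpy as np

def expand_strength(strength, n_features):
    """Expand scalars or lists to one strength per coefficient in each group.

    strength:   dict mapping group name -> scalar or sequence of strengths
    n_features: dict mapping group name -> number of coefficients in the group
    """
    expanded = {}
    for name, n in n_features.items():
        value = np.asarray(strength[name], dtype=float)
        # a non-scalar entry must provide exactly one value per coefficient
        if value.ndim > 0 and value.shape != (n,):
            raise ValueError(
                f"group {name!r}: expected a scalar or {n} values, "
                f"got shape {value.shape}"
            )
        expanded[name] = np.broadcast_to(value, (n,))
    return expanded

n_features = dict(f1=5, f2=2)
expanded = expand_strength(dict(f1=0.1, f2=[0.2, 0.1]), n_features)
print(expanded["f1"])  # the scalar for f1, repeated for each of its 5 coefficients
print(expanded["f2"])  # the two values passed for f2, unchanged
```

A list whose length does not match its group (for example two values for `f1`) raises a `ValueError` in this sketch.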
Special cases#
There are a couple of special cases to keep in mind.
ElasticNet regularization#
ElasticNet regularization combines L1 and L2 regularization, introducing a ratio parameter that determines the relative contribution of each.
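A common way to write this combination is sketched below in NumPy. The parametrization (L1 term weighted by `ratio`, L2 term weighted by `1 - ratio` with a factor of 0.5) follows the convention used by scikit-learn and glmnet and is an assumption here; the NeMoS API reference is the place to check the exact definition. The function name `elastic_net_penalty` is hypothetical.

```python
import numpy as np

def elastic_net_penalty(w, strength, ratio):
    """Assumed form: strength * (ratio * ||w||_1 + 0.5 * (1 - ratio) * ||w||_2^2)."""
    w = np.asarray(w, dtype=float)
    l1 = np.sum(np.abs(w))
    l2 = 0.5 * np.sum(np.square(w))
    return strength * (ratio * l1 + (1.0 - ratio) * l2)

w = np.array([1.0, -2.0, 0.0])
pure_lasso = elastic_net_penalty(w, strength=1.0, ratio=1.0)  # L1 term only
pure_ridge = elastic_net_penalty(w, strength=1.0, ratio=0.0)  # L2 term only
mixed = elastic_net_penalty(w, strength=1.0, ratio=0.5)       # equal mix
print(pure_lasso, pure_ridge, mixed)
```

In this form the strength sets the overall amount of regularization, while the ratio only moves weight between the L1 and L2 terms.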
In the traditional case, you can pass the strength and ratio as a tuple:
glm = nmo.glm.GLM(regularizer="ElasticNet", regularizer_strength=(1.0, 0.5))
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=ElasticNet(),
regularizer_strength=(1.0, 0.5),
solver_name='ProximalGradient'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | ElasticNet() |
| regularizer_strength | (1.0, ...) |
| solver_name | 'ProximalGradient' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.2151...dtype=float32), 'f2': Array([0. ...dtype=float32)} |
| dof_resid_ | ArrayImpl[int32]() | Array(495, dtype=int32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.3568...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
However, if you want finer control, you can again pass dictionaries matching the parameter structure: this time a tuple of two, one for the strengths and one for the ratios:
glm = nmo.glm.GLM(
    regularizer="ElasticNet",
    regularizer_strength=(
        dict(  # strength
            f1=[0.1, 0.3, 0.3, 0.1, 1.0],
            f2=[0.2, 0.1],
        ),
        dict(  # ratio
            f1=[0.5, 0.3, 0.5, 0.5, 0.5],
            f2=[0.5, 0.4],
        ),
    ),
)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=ElasticNet(),
regularizer_strength=({'f1': Array([0.1, 0.3, 0.3, 0.1, 1. ], dtype=float32), 'f2': Array([0.2, 0.1], dtype=float32)}, {'f1': Array([0.5, 0.3, 0.5, 0.5, 0.5], dtype=float32), 'f2': Array([0.5, 0.4], dtype=float32)}),
solver_name='ProximalGradient'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | ElasticNet() |
| regularizer_strength | ({'f1': Array([0.1, 0...dtype=float32), 'f2': Array([0.2, 0...dtype=float32)}, ...) |
| solver_name | 'ProximalGradient' |
| solver_kwargs | {} |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6824...dtype=float32), 'f2': Array([0.0619...dtype=float32)} |
| dof_resid_ | ArrayImpl[int32]() | Array(492, dtype=int32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.1417...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
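The output above shows ElasticNet being fitted with the `ProximalGradient` solver. To give a feel for how per-parameter strengths and ratios can act inside such a solver, here is a sketch of the textbook proximal operator of an elastic-net penalty, applied element-wise. It assumes the penalty form strength * (ratio * L1 + 0.5 * (1 - ratio) * L2); the function name, step size, and example values are hypothetical, and the NeMoS implementation may differ.

```python
import numpy as np

def prox_elastic_net(v, strength, ratio, step=1.0):
    """Element-wise proximal operator of the (assumed) elastic-net penalty.

    Soft-thresholds by the L1 part, then rescales by the L2 part.
    `strength` and `ratio` may be scalars or arrays shaped like `v`.
    """
    v = np.asarray(v, dtype=float)
    strength = np.asarray(strength, dtype=float)
    ratio = np.asarray(ratio, dtype=float)
    l1 = step * strength * ratio
    l2 = step * strength * (1.0 - ratio)
    soft = np.sign(v) * np.maximum(np.abs(v) - l1, 0.0)
    return soft / (1.0 + l2)

v = np.array([0.3, 0.3])          # two coefficients with the same value
strength = np.array([0.2, 1.0])   # the second one is penalized harder
ratio = np.array([0.5, 0.5])
shrunk = prox_elastic_net(v, strength, ratio)
print(shrunk)
```

With these values the first coefficient is only shrunk, while the second falls below its threshold and is set exactly to zero, which is how a larger per-parameter strength can remove a single coefficient.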
GroupLasso regularization#
GroupLasso works like Lasso, but it operates on groups of features instead of individual features.
It either keeps all features in a group or shrinks the whole group to zero.
Regularizing individual parameters differently therefore does not make sense for GroupLasso.
Instead, NeMoS allows for regularizing each group differently.
Here you pass one strength per group rather than one per parameter; in the example below this is a list with one entry for each of the two groups:
glm = nmo.glm.GLM(
    regularizer="GroupLasso",
    regularizer_strength=[0.1, 0.4],
)
glm.fit(X, spikes)
GLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=GroupLasso(),
solver_name='ProximalGradient'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | GroupLasso() |
| regularizer_strength | Array([0.1, 0...dtype=float32) |
| solver_name | 'ProximalGradient' |
| solver_kwargs | {} |
| regularizer__mask | None |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([0.6702...dtype=float32), 'f2': Array([0.0689...dtype=float32)} |
| dof_resid_ | ArrayImpl[int32]() | Array(492, dtype=int32) |
| intercept_ | ArrayImpl[float32](1,) | Array([1.2005...dtype=float32) |
| scale_ | ArrayImpl[float32](1,) | Array([1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
By default, GroupLasso generates masks that match the group structure of the design matrix X.
If you pass your own mask, you need to make sure the regularizer strength matches its structure.
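The all-or-nothing behaviour of GroupLasso can be illustrated with block soft-thresholding, the shrinkage step that proximal solvers typically apply for this penalty. The NumPy sketch below uses one strength per group. It assumes a flat coefficient vector and a mask of shape `(n_groups, n_features)` with a 1 wherever a feature belongs to a group; the helper name and example values are hypothetical. Some formulations also scale each threshold by the square root of the group size, which is left out here.

```python
import numpy as np

def prox_group_lasso(w, mask, strength, step=1.0):
    """Block soft-thresholding with one strength per group.

    w:        (n_features,) coefficients
    mask:     (n_groups, n_features) array of 0/1 group memberships
    strength: (n_groups,) one regularization strength per group
    """
    w = np.asarray(w, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    if len(strength) != mask.shape[0]:
        raise ValueError("need exactly one strength per row of the mask")
    out = np.zeros_like(w)
    for members, lam in zip(mask, strength):
        norm = np.linalg.norm(w[members])
        # shrink the whole group towards zero; drop it if its norm is too small
        scale = max(0.0, 1.0 - step * lam / norm) if norm > 0 else 0.0
        out[members] = scale * w[members]
    return out

# two groups mirroring f1 (5 features) and f2 (2 features)
mask = np.zeros((2, 7))
mask[0, :5] = 1
mask[1, 5:] = 1

w = np.array([0.6, 0.6, 0.6, 0.6, 0.6, 0.1, 0.1])
w_new = prox_group_lasso(w, mask, strength=[0.1, 0.4])
print(w_new)
```

The first group is shrunk but kept, while the second group, whose norm is below its threshold, is zeroed as a whole. The length check mirrors the requirement that the strengths match the structure of the mask.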
PopulationGLM#
A PopulationGLM models many neurons simultaneously.
Internally, this means it will have a set of parameters per neuron.
All regularization strategies above work for a PopulationGLM as well:
# we'll create a second neuron with twice as many spikes
spikes = jnp.stack([spikes, spikes*2], axis=1)
glm = nmo.glm.PopulationGLM(
    regularizer="Ridge",
    regularizer_strength=dict(
        f1=[[0.1, 0.2], [0.1, 0.2], [0.1, 0.2], [0.1, 0.2], [0.1, 0.2]],
        f2=[[0.1, 0.2], [0.1, 0.2]],
    ),
)
glm.fit(X, spikes)
PopulationGLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=Ridge(),
regularizer_strength={'f1': Array([[0.1, 0.2],
[0.1, 0.2],
[0.1, 0.2],
[0.1, 0.2],
[0.1, 0.2]], dtype=float32), 'f2': Array([[0.1, 0.2],
[0.1, 0.2]], dtype=float32)},
solver_name='LBFGS'
)
Parameters
| Parameter | Value |
|---|---|
| observation_model | PoissonObservations() |
| inverse_link_function | <function exp...x709bda1a3f60> |
| regularizer | Ridge() |
| regularizer_strength | {'f1': Array([[0.1, ...dtype=float32), 'f2': Array([[0.1, ...dtype=float32)} |
| solver_name | 'LBFGS' |
| solver_kwargs | {} |
| feature_mask | None |
Fitted attributes
| Name | Type | Value |
|---|---|---|
| aux_ | NoneType | None |
| coef_ | dict | {'f1': Array([[0.694...dtype=float32), 'f2': Array([[0.134...dtype=float32)} |
| dof_resid_ | ArrayImpl[float32](2,) | Array([492., ...dtype=float32) |
| intercept_ | ArrayImpl[float32](2,) | Array([1.0842...dtype=float32) |
| scale_ | ArrayImpl[float32](2,) | Array([1., 1.], dtype=float32) |
| solver_state_ | OptimistixAdapterState | OptimistixAda...k_bool[] ) ) |
For every parameter of every feature group (5 for f1 and 2 for f2) we are now passing two regularization strengths, one for each neuron. Writing them out is a bit tedious, but the model now fits two neurons at the same time and regularizes the parameters of each neuron differently.
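The nested lists can also be generated rather than typed out. The following sketch builds the same values with NumPy broadcasting. It assumes the `(n_features_in_group, n_neurons)` layout seen in the output above, and that array-like inputs are accepted in the same way as nested lists; the conversion to arrays in the repr suggests this, but it is worth verifying for your NeMoS version.

```python
import numpy as np

# one regularization strength per neuron
per_neuron = np.array([0.1, 0.2])

# repeat the per-neuron strengths for every coefficient in each group
n_features = dict(f1=5, f2=2)
strength = {
    name: np.broadcast_to(per_neuron, (n, per_neuron.size)).copy()
    for name, n in n_features.items()
}

print(strength["f1"].shape)  # (5, 2)
print(strength["f2"].shape)  # (2, 2)
```

Calling `.tolist()` on each entry recovers plain nested lists identical to the ones written by hand above.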