Handling Composite Bases
Structure of Composite Basis
Composite bases, i.e. objects of type AdditiveBasis or MultiplicativeBasis, are containers of multiple “atomic” one-dimensional bases, organized in a tree structure. Every time we add or multiply two bases, the operands are stored as the basis1 and basis2 attributes of the resulting AdditiveBasis or MultiplicativeBasis, respectively.
import nemos as nmo
# define a composite basis
add = nmo.basis.RaisedCosineLinearEval(5, label="input1") + nmo.basis.BSplineEval(6, label="input2")
# `add` stores the two one-dimensional bases as the attributes basis1 and basis2
print(add)
print(add.basis1)
print(add.basis2)
'(input1 + input2)': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=6, order=4),
)
'input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)
'input2': BSplineEval(n_basis_funcs=6, order=4)
Composing further results in deeper nesting of attributes.
add = add + nmo.basis.MSplineEval(4, label="input3")
print(add)
print(add.basis1.basis1)
print(add.basis1.basis2)
print(add.basis2)
'((input1 + input2) + input3)': AdditiveBasis(
basis1='(input1 + input2)': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=6, order=4),
),
basis2='input3': MSplineEval(n_basis_funcs=4, order=4),
)
'input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)
'input2': BSplineEval(n_basis_funcs=6, order=4)
'input3': MSplineEval(n_basis_funcs=4, order=4)
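The following plain-Python sketch makes the tree structure concrete. ToyBasis, ToyAdditive, leaf_labels and depth are hypothetical names and are not part of nemos. The sketch relies only on what the outputs above show: each composite node keeps its operands as basis1 and basis2, and builds its default label as "(left + right)". Because `+` is left-associative, chaining three bases produces two levels of nesting.

```python
# Toy model of a composite-basis tree (hypothetical classes, not the nemos API).


class ToyBasis:
    """An 'atomic' basis that only carries a label."""

    def __init__(self, label):
        self.label = label

    def __add__(self, other):
        return ToyAdditive(self, other)


class ToyAdditive(ToyBasis):
    """A composite node holding its two operands as attributes."""

    def __init__(self, basis1, basis2):
        self.basis1 = basis1
        self.basis2 = basis2
        # default label follows the "(left + right)" pattern shown above
        super().__init__(f"({basis1.label} + {basis2.label})")


def leaf_labels(node):
    """Labels of the atomic bases, from left to right."""
    if isinstance(node, ToyAdditive):
        return leaf_labels(node.basis1) + leaf_labels(node.basis2)
    return [node.label]


def depth(node):
    """Number of composite levels between the root and the deepest leaf."""
    if isinstance(node, ToyAdditive):
        return 1 + max(depth(node.basis1), depth(node.basis2))
    return 0


tree = ToyBasis("input1") + ToyBasis("input2") + ToyBasis("input3")
print(tree.label)                # ((input1 + input2) + input3)
print(tree.basis1.basis2.label)  # input2
print(leaf_labels(tree))         # ['input1', 'input2', 'input3']
print(depth(tree))               # 2
```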
Retrieving Basis Components and Their Parameters
Without labels, nesting makes retrieving or setting the parameters of individual components cumbersome, because you have to follow the chain of basis1/basis2 attributes down to the right element.
# retrieve the number of basis functions of the input2 basis
add.basis1.basis2.n_basis_funcs
6
However, if you assigned a label to a basis, you can index the composite basis with that label to retrieve the corresponding element.
add["input2"]
'input2': BSplineEval(n_basis_funcs=6, order=4)
Parameters
| Parameter | Value |
|---|---|
| n_basis_funcs | 6 |
| label | 'input2' |
| order | 4 |
| bounds | None |
| fill_value | nan |
Its parameters can then be accessed directly as attributes.
add["input2"].n_basis_funcs
6
This works for any sub-element, including composite ones.
# get input1 + input2
add["(input1 + input2)"]
'(input1 + input2)': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=6, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(input1 + input2)' |
| input1__bounds | None |
| input1__fill_value | nan |
| input1__label | 'input1' |
| input1__n_basis_funcs | 5 |
| input1__width | 2.0 |
| input1 | 'input1': Rai...=5, width=2.0) |
| input2__bounds | None |
| input2__fill_value | nan |
| input2__label | 'input2' |
| input2__n_basis_funcs | 6 |
| input2__order | 4 |
| input2 | 'input2': BSp...cs=6, order=4) |
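Label-based retrieval can be pictured as a depth-first search. The search compares the label of each node, composite nodes included, against the requested key. The sketch below assumes that model. ToyBasis, ToyAdditive and _walk are hypothetical, and nemos's actual `__getitem__` may be implemented differently.

```python
# Hypothetical sketch of label-based retrieval (not nemos's implementation).


class ToyBasis:
    def __init__(self, label):
        self.label = label

    def __add__(self, other):
        return ToyAdditive(self, other)

    def _walk(self):
        # a leaf only yields itself
        yield self

    def __getitem__(self, label):
        # depth-first search: return the node whose label matches
        for node in self._walk():
            if node.label == label:
                return node
        raise KeyError(label)


class ToyAdditive(ToyBasis):
    def __init__(self, basis1, basis2):
        self.basis1, self.basis2 = basis1, basis2
        super().__init__(f"({basis1.label} + {basis2.label})")

    def _walk(self):
        # visit this node, then the left subtree, then the right subtree
        yield self
        yield from self.basis1._walk()
        yield from self.basis2._walk()


add = ToyBasis("input1") + ToyBasis("input2") + ToyBasis("input3")
print(add["input2"] is add.basis1.basis2)      # True
print(add["(input1 + input2)"] is add.basis1)  # True
```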
Note that the label of this composite basis was assigned automatically. You can overwrite it with a custom label.
add["(input1 + input2)"].label = "my_custom_label"
add
'(my_custom_label + input3)': AdditiveBasis(
basis1='my_custom_label': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=6, order=4),
),
basis2='input3': MSplineEval(n_basis_funcs=4, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(my_custom_label + input3)' |
| input1__bounds | None |
| input1__fill_value | nan |
| input1__label | 'input1' |
| input1__n_basis_funcs | 5 |
| input1__width | 2.0 |
| input1 | 'input1': Rai...=5, width=2.0) |
| input2__bounds | None |
| input2__fill_value | nan |
| input2__label | 'input2' |
| input2__n_basis_funcs | 6 |
| input2__order | 4 |
| input2 | 'input2': BSp...cs=6, order=4) |
| my_custom_label__label | 'my_custom_label' |
| my_custom_label | 'my_custom_la...6, order=4), ) |
| input3__bounds | None |
| input3__fill_value | nan |
| input3__label | 'input3' |
| input3__n_basis_funcs | 4 |
| input3__order | 4 |
| input3 | 'input3': MSp...cs=4, order=4) |
If you instantiate the composite basis directly, you can specify a label at initialization.
nmo.basis.AdditiveBasis(
nmo.basis.BSplineEval(5),
nmo.basis.MSplineEval(5),
label="my_custom_label"
)
'my_custom_label': AdditiveBasis(
basis1=BSplineEval(n_basis_funcs=5, order=4),
basis2=MSplineEval(n_basis_funcs=5, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | 'my_custom_label' |
| BSplineEval__bounds | None |
| BSplineEval__fill_value | nan |
| BSplineEval__label | 'BSplineEval' |
| BSplineEval__n_basis_funcs | 5 |
| BSplineEval__order | 4 |
| BSplineEval | BSplineEval(n...cs=5, order=4) |
| MSplineEval__bounds | None |
| MSplineEval__fill_value | nan |
| MSplineEval__label | 'MSplineEval' |
| MSplineEval__n_basis_funcs | 5 |
| MSplineEval__order | 4 |
| MSplineEval | MSplineEval(n...cs=5, order=4) |
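The relabeling behavior shown above is consistent with a composite label that is recomputed from the children on every access, unless a custom label has been set. The toy sketch below implements that idea with a Python property. The class names and the `_custom_label` attribute are hypothetical. This is not a claim about how nemos stores labels internally.

```python
# Hypothetical sketch: derived vs. custom labels on a composite node.


class ToyBasis:
    def __init__(self, label):
        self.label = label

    def __add__(self, other):
        return ToyAdditive(self, other)


class ToyAdditive(ToyBasis):
    def __init__(self, basis1, basis2, label=None):
        self.basis1, self.basis2 = basis1, basis2
        self._custom_label = label  # None means "derive from the children"

    @property
    def label(self):
        if self._custom_label is not None:
            return self._custom_label
        # recomputed on each access, so renaming a child updates this label
        return f"({self.basis1.label} + {self.basis2.label})"

    @label.setter
    def label(self, value):
        self._custom_label = value


add = ToyBasis("input1") + ToyBasis("input2") + ToyBasis("input3")
print(add.label)  # ((input1 + input2) + input3)

add.basis1.label = "my_custom_label"
print(add.label)  # (my_custom_label + input3)

direct = ToyAdditive(ToyBasis("a"), ToyBasis("b"), label="my_custom_label")
print(direct.label)  # my_custom_label
```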
If you are wondering what happens when two bases with the same label are composed: this raises a ValueError. This guarantees that labels within a composite basis are always unique.
nmo.basis.BSplineEval(5, label="x") + nmo.basis.MSplineEval(5, label="x")
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[9], line 1
----> 1 nmo.basis.BSplineEval(5, label="x") + nmo.basis.MSplineEval(5, label="x")
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_composition_utils.py:609, in promote_to_transformer.<locals>.wrapper(*args, **kwargs)
607 args_bas = (b.basis if is_transformer(b) else b for b in args)
608 kwargs_bas = {k: b.basis if is_transformer(b) else b for k, b in kwargs.items()}
--> 609 out = method(*args_bas, **kwargs_bas)
610 any_transformer = any(
611 (
612 *(is_transformer(a) for a in args),
613 *(is_transformer(b) for b in kwargs.values()),
614 )
615 )
616 if any_transformer and hasattr(out, "to_transformer"):
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_basis.py:433, in Basis.__add__(self, other)
418 @promote_to_transformer
419 def __add__(self, other: BasisMixin) -> AdditiveBasis:
420 """
421 Add two Basis objects together.
422
(...) 431 The resulting Basis object.
432 """
--> 433 return AdditiveBasis(self, other)
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_basis.py:534, in AdditiveBasis.__init__(self, basis1, basis2, label)
531 def __init__(
532 self, basis1: BasisMixin, basis2: BasisMixin, label: Optional[str] = None
533 ) -> None:
--> 534 CompositeBasisMixin.__init__(self, basis1, basis2, label=label)
535 Basis.__init__(self)
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_basis_mixin.py:935, in CompositeBasisMixin.__init__(self, basis1, basis2, label)
932 basis2 = copy.deepcopy(basis2)
934 self.basis1 = basis1
--> 935 self.basis2 = basis2
937 # set parents
938 self.basis1._parent = self
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_basis_mixin.py:989, in CompositeBasisMixin.basis2(self, basis)
987 self._is_basis_like(basis2=basis)
988 if self._basis1:
--> 989 self._set_labels(self._basis1, basis)
990 if self._basis2:
991 basis = _composite_basis_setter_logic(basis, self._basis2)
File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/basis/_basis_mixin.py:1041, in CompositeBasisMixin._set_labels(self, basis1, basis2)
1039 non_unique, err_msg = self._check_unique_labels(basis1, basis2)
1040 if non_unique:
-> 1041 raise ValueError(err_msg)
1043 self.update_default_label_id(basis1, basis2)
ValueError: All user-provided labels of basis elements must be distinct.
The basis you are composing share the following labels: 'x'.
Please change the labels for one of the elements before composition.
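The uniqueness rule can be sketched as a set intersection that runs before the two operands are attached to the new node. This is a simplified, hypothetical model. It treats every label as user-provided and skips the automatic disambiguation that nemos applies to default labels. The first sentence of the error text is copied from the message above, and the rest is illustrative.

```python
# Hypothetical sketch of the label-uniqueness check (not nemos's implementation).


class ToyBasis:
    def __init__(self, label):
        self.label = label

    def labels(self):
        return {self.label}

    def __add__(self, other):
        return ToyAdditive(self, other)


class ToyAdditive(ToyBasis):
    def __init__(self, basis1, basis2):
        # refuse to compose trees that share any label
        shared = basis1.labels() & basis2.labels()
        if shared:
            names = ", ".join(repr(name) for name in sorted(shared))
            raise ValueError(
                "All user-provided labels of basis elements must be distinct. "
                f"Shared labels: {names}."
            )
        self.basis1, self.basis2 = basis1, basis2
        super().__init__(f"({basis1.label} + {basis2.label})")

    def labels(self):
        # the composite's own label plus every label in both subtrees
        return {self.label} | self.basis1.labels() | self.basis2.labels()


try:
    ToyBasis("x") + ToyBasis("x")
except ValueError as err:
    message = str(err)
print(message)

ok = ToyBasis("x") + ToyBasis("y") + ToyBasis("z")
print(sorted(ok.labels()))
```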
Because we ensure that all basis labels are unique, you can always retrieve a specific basis using its label, even when the composite basis is made up of many individual basis objects.
# compose 10 bases
composite_bas = nmo.basis.MSplineEval(4, label="label_0")
for k in range(1, 10):
    composite_bas = composite_bas + nmo.basis.MSplineEval(4, label=f"label_{k}")
# retrieve one of them using its label
composite_bas["label_5"]
'label_5': MSplineEval(n_basis_funcs=4, order=4)
Parameters
| Parameter | Value |
|---|---|
| n_basis_funcs | 4 |
| label | 'label_5' |
| order | 4 |
| bounds | None |
| fill_value | nan |
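The loop above builds a left-nested tree one addition at a time. `functools.reduce` with `operator.add` expresses the same fold in a single call. The toy classes below are hypothetical stand-ins for nemos bases. Because all labels are unique, every node can also be collected into a plain dict keyed by label without collisions. That is one simple way to see why label lookup is always well-defined.

```python
import functools
import operator

# Hypothetical stand-ins for nemos bases (not the nemos API).


class ToyBasis:
    def __init__(self, label):
        self.label = label

    def __add__(self, other):
        return ToyAdditive(self, other)

    def nodes(self):
        yield self


class ToyAdditive(ToyBasis):
    def __init__(self, basis1, basis2):
        self.basis1, self.basis2 = basis1, basis2
        super().__init__(f"({basis1.label} + {basis2.label})")

    def nodes(self):
        yield self
        yield from self.basis1.nodes()
        yield from self.basis2.nodes()


# same fold as the loop: ((label_0 + label_1) + label_2) + ...
bases = [ToyBasis(f"label_{k}") for k in range(10)]
composite = functools.reduce(operator.add, bases)

# 10 leaves + 9 composite nodes, all with distinct labels
index = {node.label: node for node in composite.nodes()}
print(len(index))              # 19
print(index["label_5"].label)  # label_5
```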
Get and Set Composite Basis Parameters
When working with composite bases, you often want to reconfigure specific components. Again, the easiest way to do this is to label each element and use the label to retrieve the basis.
# get the basis function parameter
print(add["input2"].n_basis_funcs)
# set a new value for the parameter
add["input2"].n_basis_funcs = 8
print(add["input2"].n_basis_funcs)
6
8
This change is reflected in the composite basis, since indexing by label returns the component stored in the tree itself.
# check that the input2 basis now has 8 basis functions
add
'(my_custom_label + input3)': AdditiveBasis(
basis1='my_custom_label': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=8, order=4),
),
basis2='input3': MSplineEval(n_basis_funcs=4, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(my_custom_label + input3)' |
| input1__bounds | None |
| input1__fill_value | nan |
| input1__label | 'input1' |
| input1__n_basis_funcs | 5 |
| input1__width | 2.0 |
| input1 | 'input1': Rai...=5, width=2.0) |
| input2__bounds | None |
| input2__fill_value | nan |
| input2__label | 'input2' |
| input2__n_basis_funcs | 8 |
| input2__order | 4 |
| input2 | 'input2': BSp...cs=8, order=4) |
| my_custom_label__label | 'my_custom_label' |
| my_custom_label | 'my_custom_la...8, order=4), ) |
| input3__bounds | None |
| input3__fill_value | nan |
| input3__label | 'input3' |
| input3__n_basis_funcs | 4 |
| input3__order | 4 |
| input3 | 'input3': MSp...cs=4, order=4) |
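A related detail: the traceback earlier on this page shows a `copy.deepcopy(basis2)` call inside `CompositeBasisMixin.__init__`. This suggests that composing bases stores copies of the operands. If so, modifying the original one-dimensional object after composition does not affect the composite, and indexing by label is how you reach the stored component. The toy sketch below assumes that both operands are deep-copied. The classes are hypothetical and are not nemos code.

```python
import copy

# Hypothetical toy classes; assumes composition deep-copies both operands.


class ToyBasis:
    def __init__(self, label, n_basis_funcs):
        self.label = label
        self.n_basis_funcs = n_basis_funcs

    def __add__(self, other):
        return ToyAdditive(self, other)


class ToyAdditive:
    def __init__(self, basis1, basis2):
        # store copies, so later edits to the originals do not leak in
        self.basis1 = copy.deepcopy(basis1)
        self.basis2 = copy.deepcopy(basis2)

    def __getitem__(self, label):
        for basis in (self.basis1, self.basis2):
            if basis.label == label:
                return basis
        raise KeyError(label)


b1 = ToyBasis("input1", n_basis_funcs=5)
b2 = ToyBasis("input2", n_basis_funcs=6)
add = b1 + b2

b2.n_basis_funcs = 10               # changes the original object only
print(add["input2"].n_basis_funcs)  # 6

add["input2"].n_basis_funcs = 8     # changes the copy stored in the composite
print(add["input2"].n_basis_funcs)  # 8
print(b2.n_basis_funcs)             # 10
```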
If you don't provide a label, the basis class name is used to construct the keys. If the same basis class appears more than once, the keys are disambiguated by appending a numerical suffix, as in BSplineEval_1 below.
nmo.basis.BSplineEval(10) + nmo.basis.BSplineEval(5)
'(BSplineEval + BSplineEval_1)': AdditiveBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2='BSplineEval_1': BSplineEval(n_basis_funcs=5, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(BSplineEval + BSplineEval_1)' |
| BSplineEval__bounds | None |
| BSplineEval__fill_value | nan |
| BSplineEval__label | 'BSplineEval' |
| BSplineEval__n_basis_funcs | 10 |
| BSplineEval__order | 4 |
| BSplineEval | BSplineEval(n...s=10, order=4) |
| BSplineEval_1__bounds | None |
| BSplineEval_1__fill_value | nan |
| BSplineEval_1__label | 'BSplineEval_1' |
| BSplineEval_1__n_basis_funcs | 5 |
| BSplineEval_1__order | 4 |
| BSplineEval_1 | 'BSplineEval_...cs=5, order=4) |
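The naming rule for unlabeled bases can be sketched as a counter over class names. The first occurrence keeps the bare class name, and later ones get a numerical suffix. `default_labels` is a hypothetical helper. The source only shows the `_1` case, so the `_2` suffix for a third repeat is an assumption. Collisions with user labels such as "BSplineEval_1" are not handled here.

```python
from collections import Counter


def default_labels(class_names):
    """Hypothetical helper: label repeated class names as Name, Name_1, Name_2, ..."""
    seen = Counter()
    labels = []
    for name in class_names:
        count = seen[name]
        # the first occurrence keeps the bare class name
        labels.append(name if count == 0 else f"{name}_{count}")
        seen[name] += 1
    return labels


print(default_labels(["BSplineEval", "BSplineEval"]))
# ['BSplineEval', 'BSplineEval_1']
print(default_labels(["BSplineEval", "MSplineEval", "BSplineEval", "BSplineEval"]))
# ['BSplineEval', 'MSplineEval', 'BSplineEval_1', 'BSplineEval_2']
```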
Modifying Basis Parameters with get_params and set_params
Another way to get and set basis parameters is through the get_params and set_params methods. This is the interface scikit-learn uses to interact with basis objects, and it is what enables cross-validation of basis parameters.
The get_params method returns a dictionary containing all the parameters. The keys start with the basis label, followed by a double underscore and the parameter name. Each label also appears as a key of its own, mapped to the corresponding basis object, and the key label holds the label of the composite basis itself.
add.get_params()
{'input1__bounds': None,
'input1__fill_value': nan,
'input1__label': 'input1',
'input1__n_basis_funcs': 5,
'input1__width': 2.0,
'input1': 'input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
'input2__bounds': None,
'input2__fill_value': nan,
'input2__label': 'input2',
'input2__n_basis_funcs': 8,
'input2__order': 4,
'input2': 'input2': BSplineEval(n_basis_funcs=8, order=4),
'my_custom_label__label': 'my_custom_label',
'my_custom_label': 'my_custom_label': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0),
basis2='input2': BSplineEval(n_basis_funcs=8, order=4),
),
'input3__bounds': None,
'input3__fill_value': nan,
'input3__label': 'input3',
'input3__n_basis_funcs': 4,
'input3__order': 4,
'input3': 'input3': MSplineEval(n_basis_funcs=4, order=4),
'label': '(my_custom_label + input3)'}
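A short recursion can reproduce the flat key scheme above:

- Leaves contribute "<label>__<param>" entries, plus an entry that maps their label to the object.
- Nested composites contribute "<label>__label" and their own label entry.
- The root adds "label".

Because labels are unique, leaves inside a nested composite can be keyed by their own label alone, as input1 and input2 are in the output above. The dataclasses and the `get_params` function are a hypothetical toy model, not nemos's implementation.

```python
from dataclasses import dataclass, field

# Hypothetical toy model of flattening a basis tree into get_params-style keys.


@dataclass
class ToyLeaf:
    label: str
    params: dict = field(default_factory=dict)


@dataclass
class ToyAdditive:
    label: str
    basis1: object
    basis2: object


def get_params(node):
    """Return a flat dict keyed by '<label>__<param>' plus one key per label."""
    params = {}
    for child in (node.basis1, node.basis2):
        if isinstance(child, ToyAdditive):
            nested = get_params(child)
            # the nested composite's label is stored as '<label>__label' below
            nested.pop("label")
            params.update(nested)
        else:
            for name, value in child.params.items():
                params[f"{child.label}__{name}"] = value
        params[f"{child.label}__label"] = child.label
        params[child.label] = child
    params["label"] = node.label
    return params


inner = ToyAdditive(
    "my_custom_label",
    ToyLeaf("input1", {"n_basis_funcs": 5, "width": 2.0}),
    ToyLeaf("input2", {"n_basis_funcs": 8, "order": 4}),
)
add = ToyAdditive("(my_custom_label + input3)", inner, ToyLeaf("input3", {"order": 4}))

params = get_params(add)
print(params["input2__n_basis_funcs"])   # 8
print(params["my_custom_label__label"])  # my_custom_label
print(params["label"])                   # (my_custom_label + input3)
```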
Each of these keys can be passed as a keyword argument to set_params, which sets one or more parameter values in a single call.
add.set_params(input3__order=3, input1__bounds=(-1,1))
'(my_custom_label + input3)': AdditiveBasis(
basis1='my_custom_label': AdditiveBasis(
basis1='input1': RaisedCosineLinearEval(n_basis_funcs=5, width=2.0, bounds=(-1.0, 1.0), fill_value=nan),
basis2='input2': BSplineEval(n_basis_funcs=8, order=4),
),
basis2='input3': MSplineEval(n_basis_funcs=4, order=3),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(my_custom_label + input3)' |
| input1__bounds | (-1.0, ...) |
| input1__fill_value | nan |
| input1__label | 'input1' |
| input1__n_basis_funcs | 5 |
| input1__width | 2.0 |
| input1 | 'input1': Rai...ill_value=nan) |
| input2__bounds | None |
| input2__fill_value | nan |
| input2__label | 'input2' |
| input2__n_basis_funcs | 8 |
| input2__order | 4 |
| input2 | 'input2': BSp...cs=8, order=4) |
| my_custom_label__label | 'my_custom_label' |
| my_custom_label | 'my_custom_la...8, order=4), ) |
| input3__bounds | None |
| input3__fill_value | nan |
| input3__label | 'input3' |
| input3__n_basis_funcs | 4 |
| input3__order | 3 |
| input3 | 'input3': MSp...cs=4, order=3) |
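Routing a set_params call can be sketched in two steps. First, split each key at the first double underscore, similar to the convention scikit-learn uses for nested parameters. Then set the attribute on the basis with that label. This is a hypothetical toy, not nemos's code. It assumes labels never contain a double underscore, and it does not handle keys that replace a whole sub-basis (such as `input3=...`).

```python
# Hypothetical toy model of routing set_params keywords (not nemos's code).


class ToyBasis:
    def __init__(self, label, **params):
        self.label = label
        for name, value in params.items():
            setattr(self, name, value)

    def __add__(self, other):
        return ToyAdditive(self, other)


class ToyAdditive(ToyBasis):
    def __init__(self, basis1, basis2):
        self.basis1, self.basis2 = basis1, basis2
        self.label = f"({basis1.label} + {basis2.label})"


def walk(node):
    """Yield every node of the tree, root first."""
    yield node
    if isinstance(node, ToyAdditive):
        yield from walk(node.basis1)
        yield from walk(node.basis2)


def set_params(root, **kwargs):
    index = {node.label: node for node in walk(root)}
    for key, value in kwargs.items():
        label, sep, name = key.partition("__")
        if not sep:
            # no double underscore: a parameter of the root itself, e.g. label
            setattr(root, key, value)
        elif label in index:
            setattr(index[label], name, value)
        else:
            raise ValueError(f"Unknown basis label {label!r} in {key!r}.")
    return root


add = (
    ToyBasis("input1", bounds=None)
    + ToyBasis("input2", order=4)
    + ToyBasis("input3", order=4)
)
set_params(add, input3__order=3, input1__bounds=(-1.0, 1.0))
print(add.basis2.order)          # 3
print(add.basis1.basis1.bounds)  # (-1.0, 1.0)
```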
Grid definition
The parameter keys returned by get_params are the ones needed to define a parameter grid when cross-validating your hyperparameters with scikit-learn. Learn how to cross-validate basis parameters using pipelines with this notebook.
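As an illustration, a grid over two of the parameters above is simply a dict that maps get_params keys to candidate values. The sketch below expands such a grid with `itertools.product`, similar to how scikit-learn's ParameterGrid enumerates combinations, so it runs without scikit-learn installed. The candidate values are arbitrary. When the basis is a step of a scikit-learn Pipeline, scikit-learn's general convention is to prefix each key with the step name and a double underscore. See the linked notebook for the exact keys that nemos pipelines expect.

```python
import itertools

# Hypothetical grid: keys follow the '<label>__<parameter>' pattern from get_params.
param_grid = {
    "input2__n_basis_funcs": [6, 8, 10],
    "input3__order": [3, 4],
}


def expand_grid(grid):
    """Yield one parameter dict per combination, with keys in sorted order."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


candidates = list(expand_grid(param_grid))
print(len(candidates))  # 6
print(candidates[0])    # {'input2__n_basis_funcs': 6, 'input3__order': 3}
# each candidate could then be applied with basis.set_params(**candidate)
```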
As noted above, when labels are not provided, get_params uses the auto-generated ones as key prefixes.
basis = nmo.basis.BSplineEval(10) + nmo.basis.BSplineEval(5)
basis.get_params()
{'BSplineEval__bounds': None,
'BSplineEval__fill_value': nan,
'BSplineEval__label': 'BSplineEval',
'BSplineEval__n_basis_funcs': 10,
'BSplineEval__order': 4,
'BSplineEval': BSplineEval(n_basis_funcs=10, order=4),
'BSplineEval_1__bounds': None,
'BSplineEval_1__fill_value': nan,
'BSplineEval_1__label': 'BSplineEval_1',
'BSplineEval_1__n_basis_funcs': 5,
'BSplineEval_1__order': 4,
'BSplineEval_1': 'BSplineEval_1': BSplineEval(n_basis_funcs=5, order=4),
'label': '(BSplineEval + BSplineEval_1)'}
Setting parameters this way still works, but we recommend always providing informative labels to improve code readability.
basis.set_params(BSplineEval_1__n_basis_funcs=12)
'(BSplineEval + BSplineEval_1)': AdditiveBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2='BSplineEval_1': BSplineEval(n_basis_funcs=12, order=4),
)
Parameters
| Parameter | Value |
|---|---|
| label | '(BSplineEval + BSplineEval_1)' |
| BSplineEval__bounds | None |
| BSplineEval__fill_value | nan |
| BSplineEval__label | 'BSplineEval' |
| BSplineEval__n_basis_funcs | 10 |
| BSplineEval__order | 4 |
| BSplineEval | BSplineEval(n...s=10, order=4) |
| BSplineEval_1__bounds | None |
| BSplineEval_1__fill_value | nan |
| BSplineEval_1__label | 'BSplineEval_1' |
| BSplineEval_1__n_basis_funcs | 12 |
| BSplineEval_1__order | 4 |
| BSplineEval_1 | 'BSplineEval_...s=12, order=4) |