nemos.basis._custom_basis.CustomBasis
- class nemos.basis._custom_basis.CustomBasis(funcs, ndim_input=1, output_shape=None, basis_kwargs=None, pynapple_support=True, label=None, is_complex=False, bounds=None, fill_value=nan)
Bases: BasisMixin, BasisTransformerMixin, Base
Custom basis class.
Create a custom basis class from a list of callables (the basis functions).
- Parameters:
funcs (Union[List[Callable[[NDArray], NDArray]], Callable[[NDArray], NDArray]]) – List of basis functions.
ndim_input (int) – Dimensionality of the input for each sample, i.e. if your time series is of shape (n_samples, n, m), ndim_input is two.
output_shape (Union[Tuple[int, ...], int, None]) – Shape of the output excluding the number of samples. Set automatically when compute_features is called.
basis_kwargs (Optional[dict]) – Additional keyword arguments to pass to the basis functions.
pynapple_support (bool) – Enable pynapple support if True.
is_complex (bool) – Whether the basis should be treated as complex. This flag ensures that multiplication with other bases behaves correctly: two real bases, or a real and a complex basis, can be multiplied, but two complex bases cannot. This restriction exists because, after multiplication, basis.compute_features does not distinguish between real and imaginary components, which would lead to incorrect outputs.
fill_value (float)
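Why two complex bases cannot be multiplied can be illustrated in plain NumPy. The sketch below is not how nemos represents complex bases; it assumes, for illustration only, that each complex basis contributes its real and imaginary parts as two unlabeled real-valued columns, and it models multiplication as a per-sample outer product of columns. The real and imaginary parts of the true product are signed combinations of those product columns, so no single column is correct on its own.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4

# Two hypothetical complex-valued basis outputs, one value per sample.
z1 = rng.normal(size=n_samples) + 1j * rng.normal(size=n_samples)
z2 = rng.normal(size=n_samples) + 1j * rng.normal(size=n_samples)

# Assumption: each complex basis is stored as two real columns [real, imag].
cols1 = np.column_stack([z1.real, z1.imag])  # shape (4, 2)
cols2 = np.column_stack([z2.real, z2.imag])  # shape (4, 2)

# Multiplication modeled as a per-sample outer product of the columns.
# Resulting columns: re1*re2, re1*im2, im1*re2, im1*im2
naive = np.einsum("ni,nj->nij", cols1, cols2).reshape(n_samples, 4)

# The true complex product needs signed combinations of those columns.
true_prod = z1 * z2
real_part = naive[:, 0] - naive[:, 3]  # re1*re2 - im1*im2
imag_part = naive[:, 1] + naive[:, 2]  # re1*im2 + im1*re2

print(np.allclose(real_part, true_prod.real))  # True
print(np.allclose(imag_part, true_prod.imag))  # True
```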
Examples
>>> import numpy as np
>>> import nemos as nmo
>>> from functools import partial
>>> # Define a function
>>> def decay_exp(x, rate, shift=0):
...     return np.exp(-rate * (x + shift)**2)
>>> # Define a list of basis functions
>>> funcs = [partial(decay_exp, rate=r) for r in np.linspace(0, 1, 10)]
>>> bas = nmo.basis.CustomBasis(funcs=funcs, basis_kwargs=dict(shift=1))
>>> bas
CustomBasis(
    funcs=[partial(decay_exp, rate=np.float64(0.0)), ..., partial(decay_exp, rate=np.float64(1.0))],
    ndim_input=1,
    basis_kwargs={'shift': 1},
    pynapple_support=True,
    is_complex=False
)
>>> samples = np.linspace(0, 1, 50)
>>> X = bas.compute_features(samples)
>>> X.shape
(50, 10)
>>> # Can be composed with other bases (including other custom bases)
>>> add = bas + bas
>>> X = add.compute_features(samples, samples)
>>> X.shape
(50, 20)
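The class example can be mirrored in plain NumPy to show what `basis_kwargs` and the `+` composition amount to in terms of array shapes. This is a sketch, not the nemos implementation: `design_matrix` is a hypothetical helper, and treating addition as column-wise concatenation of the two design matrices is an assumption consistent with the (50, 10) and (50, 20) shapes shown above.

```python
from functools import partial

import numpy as np


def decay_exp(x, rate, shift=0):
    return np.exp(-rate * (x + shift) ** 2)


funcs = [partial(decay_exp, rate=r) for r in np.linspace(0, 1, 10)]
basis_kwargs = dict(shift=1)


def design_matrix(funcs, x, **kwargs):
    # Hypothetical helper: call every function with the shared keyword
    # arguments and place each result in its own column.
    return np.column_stack([f(x, **kwargs) for f in funcs])


samples = np.linspace(0, 1, 50)
X = design_matrix(funcs, samples, **basis_kwargs)
print(X.shape)  # (50, 10)

# Additive composition modeled as concatenating the design matrices.
X_add = np.hstack([X, design_matrix(funcs, samples, **basis_kwargs)])
print(X_add.shape)  # (50, 20)
```

With `rate=0` the first function is constant, so the first column is all ones.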
Attributes
basis_kwargs – Additional keyword arguments to pass to the basis functions.
funcs – User-defined list of basis functions.
input_shape – Input shape as a tuple or list of tuples.
label – Label for the basis.
n_basis_funcs – The number of basis functions.
n_output_features – The number of output features, i.e. the number of columns of the design matrix.
output_shape – The shape of the output excluding the number of samples and the number of basis functions.
pynapple_support – Support pynapple Tsd/TsdFrame/TsdTensor as inputs.
- __init__(funcs, ndim_input=1, output_shape=None, basis_kwargs=None, pynapple_support=True, label=None, is_complex=False, bounds=None, fill_value=nan)
Methods
__init__(funcs[, ndim_input, output_shape, ...])
compute_features(*xi) – Apply the basis transformation to the input data.
evaluate(*xi) – Evaluate the basis functions in a vectorized form at the given sample points.
get_params([deep]) – From scikit-learn, get parameters by inspecting init.
set_input_shape(*xi) – Set the expected input shape for the basis object.
set_params(**params) – Set the parameters of this estimator.
split_by_feature(x[, axis]) – Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.
to_transformer() – Turn the Basis into a TransformerBasis for use with scikit-learn.
- classmethod __init_subclass__(**kwargs)
Set the set_{method}_request methods.
This uses PEP-487 [1] to set the set_{method}_request methods. It looks for the information available in the set default values which are set using __metadata_request__* class attributes, or inferred from method signatures.
The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or if the developer would like to specify a request value for that metadata that differs from the default None.
References
[1] PEP 487 – Simpler customisation of class creation.
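PEP 487's `__init_subclass__` hook runs every time a subclass is defined, which is what lets a base class attach generated methods to its subclasses. The toy below is a sketch of that mechanism only: `RequestMixin`, `_metadata_methods`, and `_requests` are hypothetical names, and scikit-learn's actual logic (which inspects `__metadata_request__*` attributes and method signatures, as described above) is more involved.

```python
class RequestMixin:
    """Toy base class that generates set_<method>_request methods (PEP 487)."""

    # Hypothetical class attribute listing methods that accept metadata.
    _metadata_methods = ()

    def __init_subclass__(cls, **kwargs):
        # Runs once per subclass, at class-creation time.
        super().__init_subclass__(**kwargs)
        for method in cls._metadata_methods:
            setattr(cls, f"set_{method}_request", cls._make_setter(method))

    @staticmethod
    def _make_setter(method):
        def setter(self, **requests):
            # Record the requested metadata for this method on the instance.
            self.__dict__.setdefault("_requests", {})[method] = dict(requests)
            return self

        setter.__name__ = f"set_{method}_request"
        return setter


class MyEstimator(RequestMixin):
    _metadata_methods = ("fit", "score")


est = MyEstimator().set_fit_request(sample_weight=True)
print(est._requests)  # {'fit': {'sample_weight': True}}
print(hasattr(est, "set_score_request"))  # True
```

The hook is not called for `RequestMixin` itself, so only subclasses gain the generated setters.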
- __iter__()
Make the basis iterable. Re-implemented for additive bases.
- __sklearn_clone__()
Clone the basis while preserving attributes related to input shapes.
This method ensures that input shape attributes (e.g., _input_shape_product, _input_shape_) are preserved during cloning. Reinitializing the class as in the regular sklearn clone would drop these attributes, rendering cross-validation unusable.
- Return type:
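The difference between a regular scikit-learn style clone and a shape-preserving clone can be sketched with a toy class. `ToyBasis`, `naive_clone`, and `shape_preserving_clone` are hypothetical names; only the idea (rebuild from `__init__` parameters, then carry the input-shape attribute over) comes from the description above.

```python
class ToyBasis:
    """Hypothetical stand-in for a basis with a single init parameter."""

    def __init__(self, n_funcs=3):
        self.n_funcs = n_funcs

    def get_params(self):
        return {"n_funcs": self.n_funcs}

    def set_input_shape(self, shape):
        # Configured state that is NOT an __init__ parameter.
        self._input_shape_ = shape
        return self


def naive_clone(obj):
    # Regular clone: rebuild from __init__ parameters only.
    return type(obj)(**obj.get_params())


def shape_preserving_clone(obj):
    # Rebuild from parameters, then copy the input-shape attribute over.
    new = naive_clone(obj)
    if hasattr(obj, "_input_shape_"):
        new._input_shape_ = obj._input_shape_
    return new


basis = ToyBasis(n_funcs=5).set_input_shape((4, 5))
print(hasattr(naive_clone(basis), "_input_shape_"))  # False
print(shape_preserving_clone(basis)._input_shape_)  # (4, 5)
```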
- compute_features(*xi)
Apply the basis transformation to the input data.
This method applies each function in self.funcs to the input arrays *xi. These functions are called with the arguments (*xi, **self.basis_kwargs) and must return an array of shape (n_samples, ...), where the first dimension corresponds to the number of samples, and the output must have at least one dimension (i.e., ndim >= 1).
The outputs of all function calls are reshaped into 2D arrays with shape (n_samples, n_output), and then concatenated along the feature axis (second dimension) to form the full design matrix.
If the input arrays have more dimensions than self.ndim_input, the function calls are automatically vectorized over the additional axes. This is done using Python loops, which may be slow. For better performance, users are encouraged to provide fully vectorized functions.
- Parameters:
*xi (Union[ArrayLike, Tsd, pynapple.TsdFrame, TsdTensor]) – Input arrays. Each must have at least self.ndim_input dimensions. If extra dimensions are present, they are interpreted as batch or window dimensions over which the basis functions are applied.
- Return type:
Union[TsdFrame, NDArray, Array]
- Returns:
The resulting design matrix, with one row per sample and one column per output feature.
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> from functools import partial
>>> def power_func(n, x):
...     return x ** n
>>> bas = nmo.basis.CustomBasis([partial(power_func, 1), partial(power_func, 2)])
>>> bas.compute_features(np.arange(1, 4))
array([[1., 1.],
       [2., 4.],
       [3., 9.]])
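A plain-NumPy sketch of the behaviour described above, assuming `ndim_input=1` and functions that return one value per sample. `sketch_compute_features` is hypothetical, not the nemos implementation: it loops over the extra axes in Python, like the slow fallback described above, and lays out the columns input-major, which is the ordering implied by the `split_by_feature` example on this page.

```python
from functools import partial

import numpy as np


def power_func(n, x):
    return x ** n


def sketch_compute_features(funcs, *xi, **basis_kwargs):
    """Hypothetical sketch; assumes ndim_input=1, one value per sample."""
    xi = [np.asarray(x, dtype=float) for x in xi]
    n_samples = xi[0].shape[0]
    extra_shape = xi[0].shape[1:]  # axes beyond ndim_input=1
    out = np.empty((n_samples, *extra_shape, len(funcs)))
    # np.ndindex(*()) yields a single empty tuple, so inputs without extra
    # axes go through the same loop exactly once.
    for idx in np.ndindex(*extra_shape):
        sl = (slice(None),) + idx
        args = [x[sl] for x in xi]
        for k, f in enumerate(funcs):
            out[sl + (k,)] = f(*args, **basis_kwargs)
    # Flatten everything after the sample axis into the feature axis.
    return out.reshape(n_samples, -1)


funcs = [partial(power_func, 1), partial(power_func, 2)]
print(sketch_compute_features(funcs, np.arange(1, 4)))
print(sketch_compute_features(funcs, np.arange(1, 7).reshape(3, 2)).shape)  # (3, 4)
```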
- evaluate(*xi)
Evaluate the basis functions in a vectorized form at the given sample points.
- Parameters:
*xi (NDArray) – The samples at which the basis functions are evaluated. Each element in xi corresponds to an input dimension, and must be broadcastable to a common shape along the sample axis. The shape of each input array should be (n_samples, …) where the first axis indexes samples.
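- Returns:
basis_funcs – The basis functions evaluated at the given input points. The first axis indexes samples and the last axis indexes basis functions; any additional (vectorized) input axes are kept in between, so an input of shape (10, 3) evaluated with two basis functions gives an output of shape (10, 3, 2), as in the example below.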
Notes
This method supports both NumPy and pynapple inputs. If pynapple support is enabled, the inputs and outputs are automatically cast using the configured backend (e.g., JAX or NumPy). Evaluation is performed by applying a vectorized function over each basis function and concatenating the results along the last axis.
Examples
>>> import numpy as np
>>> from nemos.basis import CustomBasis
>>> basis = CustomBasis(funcs=[lambda x: x, lambda x: x**2])
>>> x = np.linspace(0, 1, 10)
>>> out = basis.evaluate(x)
>>> out.shape
(10, 2)
>>> # vectorize over 3 inputs
>>> out = basis.evaluate(np.random.randn(10, 3))
>>> out.shape
(10, 3, 2)
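The shapes in the example above follow from stacking each function's output along a new last axis. The sketch below reproduces them in plain NumPy; `sketch_evaluate` is hypothetical and ignores the pynapple and backend casting mentioned in the Notes.

```python
import numpy as np

# Same two functions as in the example above.
funcs = [lambda x: x, lambda x: x ** 2]


def sketch_evaluate(funcs, *xi):
    # Hypothetical: apply each function to the inputs unchanged and stack
    # the results along a new last axis, one slot per basis function.
    return np.stack([np.asarray(f(*xi)) for f in funcs], axis=-1)


x = np.linspace(0, 1, 10)
print(sketch_evaluate(funcs, x).shape)  # (10, 2)

x3 = np.random.randn(10, 3)
out = sketch_evaluate(funcs, x3)
print(out.shape)  # (10, 3, 2)

# Flattening everything after the sample axis gives a 2-D, input-major layout.
print(out.reshape(out.shape[0], -1).shape)  # (10, 6)
```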
- property funcs
User-defined list of basis functions.
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
From scikit-learn, get parameters by inspecting init.
- Parameters:
deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Return type:
- Returns:
A dictionary containing the parameters. Key is the parameter name, value is the parameter value.
- property input_shape: None | List[None] | Tuple[int, ...] | List[Tuple[int, ...]]
Input shape as a tuple or list of tuples.
The property mimics the behavior of atomic bases, and uses the assumption that _input_shape_ for custom bases is a list of length one.
- property is_complex
- property n_output_features
The number of output features, i.e. the number of columns of the design matrix.
- property output_shape: Tuple[int, ...]
The shape of the output excluding the number of samples and the number of basis functions.
- set_input_shape(*xi)
Set the expected input shape for the basis object.
This method sets the input shape for each input required by the funcs in the CustomBasis. One xi must be provided for each input, specified as an integer, a tuple of integers, or an array. The method calculates and stores the total number of output features based on the number of basis functions, the number of inputs per function, and the provided input shapes.
- Parameters:
*xi (Union[int, tuple[int, ...], NDArray]) – The input shape specifications. For every k, xi[k] can be:
  - An integer: represents the dimensionality of the input. A value of 1 is treated as scalar input.
  - A tuple: represents the exact input shape excluding the first axis (sample axis). All elements must be integers.
  - An array: the shape is extracted, excluding the first axis (assumed to be the sample axis).
- Raises:
ValueError – If a tuple is provided and it contains non-integer elements, or if not enough inputs are provided.
- Returns:
Returns the instance itself to allow method chaining.
- Return type:
self
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> from functools import partial
>>> # Basis with one input only
>>> def power_func(n, x):
...     return x ** n
>>> basis = nmo.basis.CustomBasis([partial(power_func, n) for n in range(1, 6)])
>>> # Configure with an integer input:
>>> _ = basis.set_input_shape(3)
>>> basis.n_output_features
15
>>> # Configure with a tuple:
>>> _ = basis.set_input_shape((4, 5))
>>> basis.n_output_features
100
>>> # Configure with an array:
>>> x = np.ones((10, 4, 5))
>>> _ = basis.set_input_shape(x)
>>> basis.n_output_features
100
>>> # basis with 2 inputs
>>> def power_add_func(n, x, y):
...     return x ** n + y ** n
>>> basis = nmo.basis.CustomBasis([partial(power_add_func, n) for n in range(1, 6)])
>>> _ = basis.set_input_shape(3, 3)
>>> basis.n_output_features
15
>>> _ = basis.set_input_shape((3, 2), (3, 2))
>>> basis.n_output_features
30
>>> _ = basis.set_input_shape(np.ones((10, 3, 2)), (3, 2))
>>> basis.n_output_features
30
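The counts in the examples above follow a simple rule when every input shares the same per-sample shape: the number of basis functions times the number of entries in that shape. The sketch below encodes that rule; `sketch_n_output_features` is hypothetical, and the same-shape requirement is an assumption of the sketch, not a documented nemos constraint.

```python
import math

import numpy as np


def sketch_n_output_features(n_basis_funcs, *xi):
    """Hypothetical helper reproducing the documented counts.

    Assumes every input shares the same per-sample shape.
    """
    shapes = []
    for x in xi:
        if isinstance(x, int):
            shape = () if x == 1 else (x,)  # 1 is treated as scalar input
        elif isinstance(x, tuple):
            if not all(isinstance(s, int) for s in x):
                raise ValueError("tuple shapes must contain only integers")
            shape = x
        else:
            shape = np.shape(x)[1:]  # drop the sample axis
        shapes.append(tuple(shape))
    if len(set(shapes)) != 1:
        raise ValueError("this sketch assumes all inputs share one shape")
    return n_basis_funcs * math.prod(shapes[0])


print(sketch_n_output_features(5, 3))  # 15
print(sketch_n_output_features(5, np.ones((10, 4, 5))))  # 100
print(sketch_n_output_features(5, (3, 2), (3, 2)))  # 30
```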
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
**params (Any) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance
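The `<component>__<parameter>` convention can be sketched as a small routing function. `route_params` is hypothetical and only shows how such keys split on the first double underscore; the `glm__` and `basis__` keys mirror the grid search in the `to_transformer` example, while `verbose` is an invented top-level parameter.

```python
def route_params(params):
    """Hypothetical helper: group '<component>__<parameter>' keys.

    Keys without '__' belong to the top-level object; the rest are grouped
    by component, splitting only on the first '__' so deeper nesting is
    left for the component to resolve.
    """
    own, nested = {}, {}
    for key, value in params.items():
        component, sep, sub_key = key.partition("__")
        if sep:
            nested.setdefault(component, {})[sub_key] = value
        else:
            own[key] = value
    return own, nested


own, nested = route_params({
    "glm__regularizer_strength": 0.1,
    "basis__basis_kwargs": {"bias": 1},
    "verbose": True,
})
print(own)     # {'verbose': True}
print(nested)  # {'glm': {'regularizer_strength': 0.1}, 'basis': {'basis_kwargs': {'bias': 1}}}
```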
- split_by_feature(x, axis=1)
Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.
This function takes an array (e.g., a design matrix or model coefficients) and splits it along a designated axis.
How it works:
  - If the basis expects an input of shape (n_samples, n_inputs), then the feature axis length will be total_n_features = n_inputs * n_basis_funcs. This axis is reshaped into dimensions (n_inputs, n_basis_funcs).
  - If the basis expects an input of shape (n_samples,), then the feature axis length will be total_n_features = n_basis_funcs. This axis is reshaped into (1, n_basis_funcs).
For example, if the input array x has shape (1, 2, total_n_features, 4, 5), then after applying this method with axis=2, it will be reshaped into (1, 2, n_inputs, n_basis_funcs, 4, 5).
The specified axis (axis) determines where the split occurs, and all other dimensions remain unchanged. See the example section below for the most common use cases.
- Parameters:
x (NDArray) – The input array to be split, representing concatenated features, coefficients, or other data. The shape of x along the specified axis must match the total number of features generated by the basis, i.e., self.n_output_features. Examples:
  - For a design matrix: (n_samples, total_n_features)
  - For model coefficients: (total_n_features,) or (total_n_features, n_neurons).
axis (int) – The axis along which to split the features. Defaults to 1. Use axis=1 for design matrices (features along columns) and axis=0 for coefficient arrays (features along rows). All other dimensions are preserved.
- Raises:
ValueError – If the shape of x along the specified axis does not match self.n_output_features.
- Returns:
A dictionary where:
  - Key: Label of the basis.
  - Value: the array reshaped to (..., n_inputs, n_basis_funcs, ...).
- Return type:
Examples
>>> import nemos as nmo
>>> import numpy as np
>>> from functools import partial
>>> def power_func(n, x):
...     return x ** n
>>> bas = nmo.basis.CustomBasis([partial(power_func, 1), partial(power_func, 2)])
>>> # define a 3 x 2 input
>>> inp = np.arange(1, 7).reshape(3, 2)
>>> X = bas.compute_features(inp)
>>> X.shape  # (3, 2 * 2)
(3, 4)
>>> bas.split_by_feature(X)["CustomBasis"]  # split to (3, 2, 2)
array([[[ 1.,  1.],
        [ 2.,  4.]],
...
       [[ 3.,  9.],
        [ 4., 16.]],
...
       [[ 5., 25.],
        [ 6., 36.]]], dtype=float32)
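The reshape described under "How it works" can be sketched in plain NumPy. `sketch_split` is hypothetical, not the nemos implementation; it assumes input-major feature columns (as in the example above) and expands the chosen axis in place, which also reproduces the higher-dimensional (1, 2, total_n_features, 4, 5) case from the text.

```python
import numpy as np


def sketch_split(x, n_inputs, n_basis_funcs, axis=1, label="CustomBasis"):
    """Hypothetical sketch of the documented reshape."""
    x = np.asarray(x)
    if x.shape[axis] != n_inputs * n_basis_funcs:
        raise ValueError("feature axis does not match n_inputs * n_basis_funcs")
    # Replace the feature axis with (n_inputs, n_basis_funcs); keep the rest.
    new_shape = x.shape[:axis] + (n_inputs, n_basis_funcs) + x.shape[axis + 1:]
    return {label: x.reshape(new_shape)}


# Design matrix as in the example above: 3 samples, 2 inputs, 2 functions.
inp = np.arange(1, 7).reshape(3, 2)
X = np.stack([inp ** 1, inp ** 2], axis=-1).reshape(3, -1)  # input-major
print(sketch_split(X, n_inputs=2, n_basis_funcs=2)["CustomBasis"][0])
# [[1 1]
#  [2 4]]

# Higher-dimensional case from the text, splitting axis 2.
n_inputs, n_basis_funcs = 3, 4
x = np.zeros((1, 2, n_inputs * n_basis_funcs, 4, 5))
print(sketch_split(x, n_inputs, n_basis_funcs, axis=2)["CustomBasis"].shape)
# (1, 2, 3, 4, 4, 5)
```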
- to_transformer()
Turn the Basis into a TransformerBasis for use with scikit-learn.
- Return type:
- Returns:
A transformer basis.
Examples
>>> from functools import partial
>>>
>>> import numpy as np
>>> from sklearn.model_selection import GridSearchCV
>>> from sklearn.pipeline import Pipeline
>>>
>>> import nemos as nmo
>>>
>>> # load some data
>>> x = 0.1 * np.random.normal(size=(100, 1))
>>> y = np.random.poisson(np.exp(x[:, 0]), size=100)
>>>
>>>
>>> def power_func(n, x, bias=0):
...     return (x + bias) ** n
>>>
>>>
>>> basis = nmo.basis.CustomBasis([partial(power_func, n) for n in range(1, 6)])
>>> basis = basis.to_transformer()
>>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.0)
>>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
>>> param_grid = dict(
...     glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
...     basis__basis_kwargs=(dict(bias=0), dict(bias=1)),
... )
>>> gridsearch = GridSearchCV(
...     pipeline,
...     param_grid=param_grid,
...     cv=2,
... )
>>> gridsearch = gridsearch.fit(x, y)