nemos.basis.RaisedCosineLinearEval.split_by_feature
- RaisedCosineLinearEval.split_by_feature(x, axis=1)
Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.
This function takes an array (e.g., a design matrix or model coefficients) and splits it along a designated axis.
How it works:
If the basis expects an input of shape (n_samples, n_inputs), then the feature axis has length total_n_features = n_inputs * n_basis_funcs. This axis is reshaped into the dimensions (n_inputs, n_basis_funcs).
If the basis expects an input of shape (n_samples,), then the feature axis has length total_n_features = n_basis_funcs. This axis is reshaped into (1, n_basis_funcs).
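The two cases above can be sketched in plain NumPy. The helper reshape_feature_axis is hypothetical and not part of nemos. It assumes that the features of each input occupy a contiguous block of n_basis_funcs entries, which is the layout a row-major (C-order) reshape undoes. The sketch follows the rule as written here, so a single-input array gains a size-1 axis. Note that the library example further down prints (20, 6) for a single-input basis, so the actual method may not add that axis.

```python
import numpy as np


def reshape_feature_axis(x, n_inputs, n_basis_funcs, axis=1):
    """Hypothetical helper: split one axis of length n_inputs * n_basis_funcs
    into the two axes (n_inputs, n_basis_funcs), leaving other axes unchanged."""
    # Normalize negative axes; raises IndexError if axis is out of range.
    axis = range(x.ndim)[axis]
    if x.shape[axis] != n_inputs * n_basis_funcs:
        raise ValueError(
            f"expected {n_inputs * n_basis_funcs} features along axis {axis}, "
            f"got {x.shape[axis]}"
        )
    new_shape = x.shape[:axis] + (n_inputs, n_basis_funcs) + x.shape[axis + 1:]
    # C-order reshape: assumes each input's n_basis_funcs features are contiguous.
    return x.reshape(new_shape)


n_basis_funcs = 6

# Input of shape (n_samples, n_inputs) with n_inputs=3 -> 3 * 6 = 18 features.
X_multi = np.random.randn(20, 3 * n_basis_funcs)
print(reshape_feature_axis(X_multi, 3, n_basis_funcs).shape)  # (20, 3, 6)

# Input of shape (n_samples,) -> n_basis_funcs features, treated as n_inputs=1.
X_single = np.random.randn(20, n_basis_funcs)
print(reshape_feature_axis(X_single, 1, n_basis_funcs).shape)  # (20, 1, 6)

# Feature axis in the middle of a higher-dimensional array (axis=2).
x = np.zeros((1, 2, 3 * n_basis_funcs, 4, 5))
print(reshape_feature_axis(x, 3, n_basis_funcs, axis=2).shape)  # (1, 2, 3, 6, 4, 5)

# Column j of input i lands at [..., i, j]: column 6 is input 1, basis 0.
cols = np.arange(18).reshape(1, 18)
print(reshape_feature_axis(cols, 3, n_basis_funcs)[0, 1, 0])  # 6
```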
For example, if the input array x has shape (1, 2, total_n_features, 4, 5), then splitting along axis=2 (the position of total_n_features) reshapes it into (1, 2, n_inputs, n_basis_funcs, 4, 5). The axis argument determines where the split occurs, and all other dimensions remain unchanged. See the Examples section below for the most common use cases.
- Parameters:
x (NDArray) – The input array to be split, representing concatenated features, coefficients, or other data. The shape of x along the specified axis must match the total number of features generated by the basis, i.e., self.n_output_features. Examples:
For a design matrix: (n_samples, total_n_features).
For model coefficients: (total_n_features,) or (total_n_features, n_neurons).
axis (int, optional) – The axis along which to split the features. Defaults to 1. Use axis=1 for design matrices (features along columns) and axis=0 for coefficient arrays (features along rows). All other dimensions are preserved.
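To illustrate how axis selects the feature dimension for design matrices versus coefficient arrays, here is a NumPy-only sketch. The function split_by_feature_sketch, the label "speed", and the array sizes are hypothetical. The function imitates the documented contract for a single basis: it returns a one-entry dictionary keyed by the label and raises ValueError on a length mismatch. It is not the nemos implementation.

```python
import numpy as np


def split_by_feature_sketch(x, label, n_inputs, n_basis_funcs, axis=1):
    """Hypothetical stand-in for split_by_feature on a single basis.

    Returns {label: array}, where the feature axis of x is replaced by the
    two axes (n_inputs, n_basis_funcs) and every other axis is unchanged.
    """
    n_output_features = n_inputs * n_basis_funcs
    # Normalize negative axes; raises IndexError if axis is out of range.
    axis = range(x.ndim)[axis]
    if x.shape[axis] != n_output_features:
        raise ValueError(
            f"x has {x.shape[axis]} entries along axis {axis}, "
            f"expected n_output_features={n_output_features}."
        )
    new_shape = x.shape[:axis] + (n_inputs, n_basis_funcs) + x.shape[axis + 1:]
    return {label: x.reshape(new_shape)}


n_inputs, n_basis_funcs, n_neurons = 2, 6, 5
n_features = n_inputs * n_basis_funcs  # 12

# Design matrix (n_samples, total_n_features): features along columns, axis=1.
X = np.random.randn(100, n_features)
split_X = split_by_feature_sketch(X, "speed", n_inputs, n_basis_funcs, axis=1)
print(split_X["speed"].shape)  # (100, 2, 6)

# Single-neuron coefficients (total_n_features,): features along rows, axis=0.
coef = np.random.randn(n_features)
split_coef = split_by_feature_sketch(coef, "speed", n_inputs, n_basis_funcs, axis=0)
print(split_coef["speed"].shape)  # (2, 6)

# Population coefficients (total_n_features, n_neurons): still axis=0.
coef_pop = np.random.randn(n_features, n_neurons)
split_pop = split_by_feature_sketch(coef_pop, "speed", n_inputs, n_basis_funcs, axis=0)
print(split_pop["speed"].shape)  # (2, 6, 5)

# Using the wrong axis triggers the length check (100 rows vs 12 features).
try:
    split_by_feature_sketch(X, "speed", n_inputs, n_basis_funcs, axis=0)
except ValueError as err:
    print("ValueError:", err)
```

Keeping the feature axis as an argument, rather than assuming columns, is what lets the same call serve both design matrices and coefficient arrays.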
- Raises:
ValueError – If the shape of x along the specified axis does not match self.n_output_features.
- Returns:
A dictionary where:
Key: the label of the basis.
Value: the array reshaped to (..., n_inputs, n_basis_funcs, ...).
- Return type:
dict
Examples
>>> import numpy as np
>>> from nemos.basis import RaisedCosineLinearEval
>>> basis = RaisedCosineLinearEval(n_basis_funcs=6, label="one_input")
>>> X = basis.compute_features(np.random.randn(20,))
>>> split_features_multi = basis.split_by_feature(X, axis=1)
>>> for feature, arr in split_features_multi.items():
...     print(f"{feature}, shape {arr.shape}")
one_input, shape (20, 6)