nemos.basis.BSplineEval.split_by_feature
- BSplineEval.split_by_feature(x, axis=1)
Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.
This function takes an array (e.g., a design matrix or model coefficients) and splits it along a designated axis.
How it works:
- If the basis expects an input of shape (n_samples, n_inputs), then the feature axis length will be total_n_features = n_inputs * n_basis_funcs. This axis is reshaped into dimensions (n_inputs, n_basis_funcs).
- If the basis expects an input of shape (n_samples,), then the feature axis length will be total_n_features = n_basis_funcs. This axis is reshaped into (1, n_basis_funcs).
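The reshape described above can be sketched in plain NumPy for the multi-input case. This is a minimal illustration, not the nemos implementation: the sizes are arbitrary, and it assumes the features are laid out input-major (all n_basis_funcs columns of input 0, then those of input 1, and so on), which is what a default C-order reshape expects.

```python
import numpy as np

# Arbitrary sizes chosen only for illustration.
n_samples, n_inputs, n_basis_funcs = 20, 3, 6
total_n_features = n_inputs * n_basis_funcs  # 18

# Stand-in design matrix with the features along axis 1.
X = np.arange(n_samples * total_n_features, dtype=float).reshape(
    n_samples, total_n_features
)

# Split the feature axis into (n_inputs, n_basis_funcs).
# With C order, column i * n_basis_funcs + b maps to [:, i, b]
# (assumes the input-major layout described above).
X_split = X.reshape(n_samples, n_inputs, n_basis_funcs)

print(X_split.shape)  # (20, 3, 6)
# Input 1 corresponds to the second contiguous block of columns.
print(np.array_equal(X_split[:, 1, :], X[:, 6:12]))  # True
```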
For example, if the input array x has shape (1, 2, total_n_features, 4, 5) and axis=2, then after applying this method it will be reshaped into (1, 2, n_inputs, n_basis_funcs, 4, 5). The axis argument determines where the split occurs, and all other dimensions remain unchanged. See the example section below for the most common use cases.
- Parameters:
x (NDArray) – The input array to be split, representing concatenated features, coefficients, or other data. The shape of x along the specified axis must match the total number of features generated by the basis, i.e., self.n_output_features. Examples:
For a design matrix: (n_samples, total_n_features).
For model coefficients: (total_n_features,) or (total_n_features, n_neurons).
axis (int, optional) – The axis along which to split the features. Defaults to 1. Use axis=1 for design matrices (features along columns) and axis=0 for coefficient arrays (features along rows). All other dimensions are preserved.
- Raises:
ValueError – If the shape of x along the specified axis does not match self.n_output_features.
- Returns:
A dictionary where:
Key: Label of the basis.
Value: the array reshaped to:
(..., n_inputs, n_basis_funcs, ...)
- Return type:
dict
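The validation and return value documented above can be mimicked with a small NumPy function. split_by_feature_sketch is a hypothetical helper, not part of nemos: it takes label, n_inputs and n_basis_funcs as explicit arguments where the real method reads them from the basis object, and it assumes the same input-major feature layout as a C-order reshape.

```python
import numpy as np


def split_by_feature_sketch(x, label, n_inputs, n_basis_funcs, axis=1):
    """Hypothetical stand-in for the documented behavior (not nemos code).

    Replaces the feature axis of ``x`` with (n_inputs, n_basis_funcs) and
    returns a one-entry dict keyed by ``label``.
    """
    x = np.asarray(x)
    axis = axis if axis >= 0 else x.ndim + axis  # normalize negative axes
    n_output_features = n_inputs * n_basis_funcs
    if x.shape[axis] != n_output_features:
        raise ValueError(
            f"x.shape[{axis}] is {x.shape[axis]}, expected {n_output_features}."
        )
    # Split one axis into two; every other dimension is left untouched.
    new_shape = x.shape[:axis] + (n_inputs, n_basis_funcs) + x.shape[axis + 1 :]
    return {label: x.reshape(new_shape)}


# The shape example from the text, with the feature axis at position 2.
x = np.zeros((1, 2, 3 * 6, 4, 5))
out = split_by_feature_sketch(x, "my_basis", n_inputs=3, n_basis_funcs=6, axis=2)
print(out["my_basis"].shape)  # (1, 2, 3, 6, 4, 5)

# A feature axis of the wrong length raises ValueError.
try:
    split_by_feature_sketch(np.zeros((20, 17)), "my_basis", 3, 6, axis=1)
except ValueError as err:
    print("ValueError:", err)
```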
Examples
>>> import numpy as np
>>> from nemos.basis import BSplineEval
>>> from nemos.glm import GLM
>>> basis = BSplineEval(n_basis_funcs=6, label="one_input")
>>> X = basis.compute_features(np.random.randn(20,))
>>> split_features_multi = basis.split_by_feature(X, axis=1)
>>> for feature, sub_dict in split_features_multi.items():
...     print(f"{feature}, shape {sub_dict.shape}")
one_input, shape (20, 6)
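Splitting the coefficients with axis=0 in the same way as the design matrix with axis=1 lets each input's contribution to the linear predictor be inspected separately. The sketch below uses plain NumPy with random placeholder data and does not call nemos; it assumes the coefficient rows follow the same input-major ordering as the design-matrix columns.

```python
import numpy as np

# Arbitrary sizes; the data are random placeholders.
rng = np.random.default_rng(0)
n_samples, n_inputs, n_basis_funcs, n_neurons = 50, 2, 6, 4
total_n_features = n_inputs * n_basis_funcs

X = rng.normal(size=(n_samples, total_n_features))  # features along axis 1
W = rng.normal(size=(total_n_features, n_neurons))  # features along axis 0

# Split both arrays consistently: axis 1 of X, axis 0 of W.
X_split = X.reshape(n_samples, n_inputs, n_basis_funcs)
W_split = W.reshape(n_inputs, n_basis_funcs, n_neurons)

# Per-input contribution, shape (n_inputs, n_samples, n_neurons):
# contrib[i] = X_split[:, i, :] @ W_split[i]
contrib = np.einsum("sib,ibn->isn", X_split, W_split)

# Because the splits match, the contributions add up to X @ W.
print(np.allclose(contrib.sum(axis=0), X @ W))  # True
```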