nemos.basis._basis.MultiplicativeBasis.set_input_shape
- MultiplicativeBasis.set_input_shape(*xi)
Set the expected input shape for the basis object.
This method sets the input shape for each component of the basis. One xi must be provided for each basis component, specified as an integer, a tuple of integers, or an array. The method calculates and stores the total number of output features based on the number of basis functions in each component and the provided input shapes.
- Parameters:
*xi (int | tuple[int, …] | NDArray) –
The input shape specifications. For every k, xi[k] can be:
  - An integer: represents the dimensionality of the input. A value of 1 is treated as scalar input.
  - A tuple: represents the exact input shape, excluding the first axis (the sample axis). All elements must be integers.
  - An array: the shape is extracted, excluding the first axis (assumed to be the sample axis).
- Raises:
ValueError – If a tuple is provided and it contains non-integer elements.
- Returns:
Returns the instance itself to allow method chaining.
- Return type:
self
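The three accepted forms of xi can be illustrated with a small standalone sketch. The helper to_input_shape below is hypothetical and is not part of nemos; it only mirrors the rules described under Parameters and Raises. Representing a scalar input as an empty tuple is an assumption made for this illustration.

```python
import numpy as np


def to_input_shape(xi):
    """Hypothetical helper: map one specification to a per-sample shape."""
    if isinstance(xi, (int, np.integer)):
        # An integer is a dimensionality; 1 is treated as scalar input.
        # Using () for the scalar case is an assumption of this sketch.
        return () if xi == 1 else (int(xi),)
    if isinstance(xi, tuple):
        # A tuple is the exact shape without the sample axis.
        if not all(isinstance(i, (int, np.integer)) for i in xi):
            raise ValueError("All elements of a shape tuple must be integers.")
        return tuple(int(i) for i in xi)
    # Anything else is treated as an array: drop the first (sample) axis.
    return tuple(np.shape(xi)[1:])


shapes = [to_input_shape(x) for x in (1, (2, 3), np.ones((10, 4, 5)))]
```

Here shapes holds one per-sample shape per component, which is the information the method needs in order to count the inputs of each component.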
Examples
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> # define a multiplicative basis
>>> basis_1 = nmo.basis.BSplineEval(5)
>>> basis_2 = nmo.basis.RaisedCosineLinearEval(6)
>>> basis_3 = nmo.basis.MSplineEval(7)
>>> multiplicative_basis = basis_1 * basis_2 * basis_3
Specify the input shape using all 3 allowed ways: integer, tuple, array
>>> _ = multiplicative_basis.set_input_shape(1, (2, 3), np.ones((10, 4, 5)))
Expected output features are: (5 * 6 * 7 bases) * (1 * 6 * 20 inputs) = 25200
>>> multiplicative_basis.n_output_features
25200
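The arithmetic behind the expected value can be checked without nemos. The sketch below only reproduces the formula stated in the example (product of the basis-function counts times the product of the per-component input counts); the variable names are illustrative, and it is not a description of how the library computes the value internally.

```python
import math

# Number of basis functions in each component (5, 6 and 7 in the example).
n_basis_funcs = [5, 6, 7]

# Per-sample input shapes implied by 1, (2, 3) and np.ones((10, 4, 5)).
input_shapes = [(), (2, 3), (4, 5)]

# Inputs per component: the product of each shape; math.prod(()) is 1.
n_inputs = [math.prod(shape) for shape in input_shapes]

n_output_features = math.prod(n_basis_funcs) * math.prod(n_inputs)
print(n_inputs, n_output_features)  # [1, 6, 20] 25200
```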