nemos.basis._spline_basis.SplineBasis

class nemos.basis._spline_basis.SplineBasis(n_basis_funcs, order=2, label=None)

Bases: AtomicBasisMixin, Basis, ABC

SplineBasis is the abstract base class for spline basis functions. It inherits from the Basis class.

Parameters:
  • n_basis_funcs (int) – Number of basis functions.

  • order (int, optional) – Spline order, i.e. the polynomial degree of the spline plus one. Defaults to 2.

  • label (Optional[str]) – The label of the basis, intended to be descriptive of the task variable being processed. For example: velocity, position, spike_counts.

order

Spline order.

Type:

int

Attributes

input_shape

Expected per-sample input shape.

is_complex

Whether the basis is intrinsically complex.

label

Label for the basis.

n_basis_funcs

Number of basis functions.

n_output_features

Number of features returned by the basis.

order

Spline order.

__init__(n_basis_funcs, order=2, label=None)
Parameters:
  • n_basis_funcs (int)

  • order (int)

  • label (str | None)

Return type:

None

Methods

__init__(n_basis_funcs[, order, label])

compute_features(*xi)

Apply the basis transformation to the input data.

evaluate(*xi)

Abstract method to evaluate the basis functions at given points.

evaluate_on_grid(*n_samples)

Evaluate the basis set on a grid of equi-spaced sample points.

get_metadata_routing()

Get metadata routing of this object.

get_params([deep])

Get the parameters of this object by inspecting __init__ (adapted from scikit-learn).

set_input_shape(xi)

Set the expected input shape for the basis object.

set_params(**params)

Set the parameters of this estimator.

setup_basis(*xi)

Pre-compute all basis state variables.

split_by_feature(x[, axis])

Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.

to_transformer()

Turn the Basis into a TransformerBasis for use with scikit-learn.

__add__(other)

Add two Basis objects together.

Parameters:

other (BasisMixin) – The other Basis object to add.

Returns:

The resulting Basis object.

Return type:

AdditiveBasis
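Conceptually, the additive basis returned by + evaluates each component on its own input and places the resulting feature columns side by side. The NumPy sketch below shows only that column bookkeeping, assuming column concatenation in component order; the two feature matrices are random stand-ins for hypothetical components, not nemos output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100

# Random stand-ins for the features of two components, each evaluated on its own input.
features_velocity = rng.random((n_samples, 5))  # hypothetical 5-function basis
features_position = rng.random((n_samples, 3))  # hypothetical 3-function basis

# Additive combination: concatenate the feature columns, component by component.
additive_features = np.hstack([features_velocity, features_position])

# Columns 0-4 come from the first component, columns 5-7 from the second.
n_out = features_velocity.shape[1] + features_position.shape[1]
```

The number of output features of the sum is therefore the sum of the components' output features.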

classmethod __init_subclass__(**kwargs)

Set the set_{method}_request methods.

This uses PEP 487 [1] to set the set_{method}_request methods. It looks for the request information in the default values, which are either set through __metadata_request__* class attributes or inferred from method signatures.

The __metadata_request__* class attributes are used when a method does not explicitly accept metadata through its arguments, or when the developer wants to specify a request value for that metadata that differs from the default None.
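The PEP 487 hook __init_subclass__ runs once for every subclass definition, which is what lets a base class attach generated methods to its subclasses. The sketch below shows that mechanism in miniature. RequestMixin, its fixed method list, _make_request_setter, and the _requests storage are all hypothetical and far simpler than scikit-learn's metadata-routing code.

```python
def _make_request_setter(method):
    """Build a hypothetical `set_<method>_request` function for one method name."""

    def setter(self, **requests):
        # Record which metadata this method should receive, e.g. sample_weight=True.
        if not hasattr(self, "_requests"):
            self._requests = {}
        self._requests.setdefault(method, {}).update(requests)
        return self

    setter.__name__ = f"set_{method}_request"
    return setter


class RequestMixin:
    _routed_methods = ("fit", "transform", "score")

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # PEP 487: this runs at class-creation time for every subclass.
        for method in cls._routed_methods:
            if callable(getattr(cls, method, None)):
                setattr(cls, f"set_{method}_request", _make_request_setter(method))


class MyEstimator(RequestMixin):
    def fit(self, X, y=None):
        return self
```

Only methods the subclass actually defines get a setter, so MyEstimator gains set_fit_request but not set_transform_request.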

References

[1] PEP 487 – Simpler customisation of class creation.

__iter__()

Make the basis iterable. Re-implemented for additive bases.

__len__()

Return the number of additive basis components.

__mul__(other)

Multiply two Basis objects together.

Parameters:

other (BasisMixin | int) – The other Basis object to multiply.

Return type:

Basis

Returns:

The resulting Basis object.
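A product of two bases is commonly built as a row-wise Kronecker (tensor) product: for every sample, each feature of the first component is multiplied by each feature of the second, giving n_a * n_b columns. The NumPy sketch below assumes that construction; row_wise_kron is an illustrative name and the feature matrices are random stand-ins, not nemos output.

```python
import numpy as np


def row_wise_kron(a, b):
    """Per-sample outer product of two feature matrices, flattened to 2D."""
    n_samples = a.shape[0]
    # (n, i), (n, j) -> (n, i, j) -> (n, i * j); column index is i * n_b + j.
    return np.einsum("ni,nj->nij", a, b).reshape(n_samples, -1)


rng = np.random.default_rng(1)
features_a = rng.random((50, 4))  # stand-in for a 4-function component
features_b = rng.random((50, 3))  # stand-in for a 3-function component
product_features = row_wise_kron(features_a, features_b)
```

Each row equals np.kron of the two corresponding rows, so the product has 4 * 3 = 12 feature columns.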

__pow__(exponent)

Exponentiation of a Basis object.

Define the power of a basis by repeatedly applying the multiplication method __mul__. The exponent must be a positive integer.

Parameters:

exponent (int) – Positive integer exponent

Return type:

BasisMixin

Returns:

The product of the basis with itself “exponent” times. Equivalent to self * self * ... * self.

Raises:
  • TypeError – If the provided exponent is not an integer.

  • ValueError – If the integer is zero or negative.
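The rules above can be mirrored with a toy class: __pow__ validates the exponent, then folds __mul__ over copies of the basis. CountBasis is hypothetical and tracks only a function count, assuming (as in a tensor-product construction) that multiplying two bases multiplies their numbers of basis functions; it is not the nemos class, and its error messages are illustrative.

```python
from functools import reduce
import operator


class CountBasis:
    """Hypothetical stand-in that only tracks how many functions a basis has."""

    def __init__(self, n_basis_funcs):
        self.n_basis_funcs = n_basis_funcs

    def __mul__(self, other):
        # Tensor-product combination: function counts multiply.
        return CountBasis(self.n_basis_funcs * other.n_basis_funcs)

    def __pow__(self, exponent):
        if not isinstance(exponent, int):
            raise TypeError("exponent must be an integer")
        if exponent <= 0:
            raise ValueError("exponent must be a positive integer")
        # self * self * ... * self, with `exponent` factors.
        return reduce(operator.mul, [self] * exponent)
```

Under this assumption, CountBasis(3) ** 3 has 27 functions, which is why large exponents quickly produce very wide design matrices.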

__rmul__(other)

Right multiplication operator for basis.

Parameters:

other (BasisMixin | int)

__sklearn_clone__()

Clone the basis while preserving attributes related to input shapes.

This method ensures that input shape attributes (e.g., _input_shape_product, _input_shape_) are preserved during cloning. Reinitializing the class as in the regular sklearn clone would drop these attributes, rendering cross-validation unusable.
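The difference between a parameters-only clone and the override described above can be seen with a toy class. In recent scikit-learn versions, sklearn.base.clone calls an object's __sklearn_clone__ method when one is defined. ShapeAwareBasis, its set_input_shape, and params_only_clone are hypothetical simplifications; only the _input_shape_ attribute name comes from the text above.

```python
import copy


class ShapeAwareBasis:
    """Hypothetical basis with one init parameter and one input-dependent attribute."""

    def __init__(self, n_basis_funcs):
        self.n_basis_funcs = n_basis_funcs

    def get_params(self, deep=True):
        return {"n_basis_funcs": self.n_basis_funcs}

    def set_input_shape(self, shape):
        # Input-dependent state, stored with a trailing underscore.
        self._input_shape_ = tuple(shape)
        return self

    def __sklearn_clone__(self):
        # Re-initialize from the init parameters, as a regular clone would...
        new = type(self)(**self.get_params())
        # ...then carry over the input-shape state that re-initialization drops.
        if hasattr(self, "_input_shape_"):
            new._input_shape_ = copy.deepcopy(self._input_shape_)
        return new


def params_only_clone(obj):
    """What a plain re-initialization from get_params() produces."""
    return type(obj)(**obj.get_params())
```

The params-only copy loses _input_shape_, whereas the override returns a fresh object that still knows its input shape.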

Return type:

Basis

compute_features(*xi)

Apply the basis transformation to the input data.

This method is designed to be a high-level interface for transforming input data using the basis functions defined by the subclass. Depending on the basis’ mode (‘Eval’ or ‘Conv’), it either evaluates the basis functions at the sample points or performs a convolution operation between the input data and the basis functions.
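In 'Eval' mode the features are the basis functions evaluated at each sample. 'Conv' mode is less obvious: each basis function acts as a kernel that is convolved with the input, producing one feature column per kernel. The NumPy sketch below assumes a causal, zero-padded convolution and uses random stand-in kernels; conv_features is a hypothetical name, and nemos's own kernel construction and boundary handling may differ.

```python
import numpy as np


def conv_features(signal, kernels):
    """Causally convolve a 1D signal with each kernel column.

    signal: shape (n_samples,); kernels: shape (window_size, n_basis_funcs).
    Returns shape (n_samples, n_basis_funcs); row t depends only on signal[:t + 1].
    """
    n_samples = signal.shape[0]
    # np.convolve's "full" output starts at lag 0, so its first n_samples
    # entries form a causal convolution with implicit zeros before t = 0.
    return np.column_stack([np.convolve(signal, k)[:n_samples] for k in kernels.T])


rng = np.random.default_rng(2)
kernels = rng.random((4, 3))  # window_size = 4, three basis functions
impulse = np.zeros(10)
impulse[2] = 1.0
features = conv_features(impulse, kernels)
```

Feeding an impulse makes the construction visible: each feature column reproduces its kernel, shifted to start at the impulse.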

Parameters:

*xi (Union[ArrayLike, Tsd, pynapple.TsdFrame, TsdTensor]) – Input data arrays to be transformed. The shape and content requirements depend on the subclass and mode of operation (‘Eval’ or ‘Conv’).

Return type:

TsdFrame | ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]

Returns:

Transformed features. In ‘Eval’ mode, it corresponds to the basis functions evaluated at the input samples. In ‘Conv’ mode, it consists of convolved input samples with the basis functions. The output shape varies based on the subclass and mode.

Notes

Subclasses should implement how to handle the transformation specific to their basis function types and operation modes.

abstractmethod evaluate(*xi)

Abstract method to evaluate the basis functions at given points.

This method must be implemented by subclasses to define the specific behavior of the basis transformation. The implementation depends on the type of basis (e.g., spline, raised cosine), and it should evaluate the basis functions at the specified points in the domain.

Parameters:

*xi (Union[ArrayLike, Tsd, pynapple.TsdFrame, TsdTensor]) – Variable number of arguments, each representing an array of points at which to evaluate the basis functions. The dimensions and requirements of these inputs vary depending on the specific basis implementation.

Return type:

TsdFrame | ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]

Returns:

An array containing the evaluated values of the basis functions at the input points. The shape and structure of this array are specific to the subclass implementation.

evaluate_on_grid(*n_samples)

Evaluate the basis set on a grid of equi-spaced sample points.

Parameters:

*n_samples (int) – The number of grid points, one value per input of the basis.

Return type:

Tuple[Tuple[NDArray], NDArray]

Returns:

  • X – Array of shape (n_samples,) containing the equi-spaced sample points where we’ve evaluated the basis.

  • basis_funcs – The spline basis functions evaluated at the grid points, of shape (n_samples, n_basis_funcs).
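For a 1D basis, the procedure amounts to building an equi-spaced grid and evaluating every basis function on it. The sketch below assumes a domain of [0, 1] and uses piecewise-linear "hat" functions (an order-2 spline with equally spaced knots) as a stand-in for the actual spline evaluation; hat_basis, this evaluate_on_grid, and the domain are illustrative assumptions.

```python
import numpy as np


def hat_basis(x, n_basis_funcs):
    """Piecewise-linear hat functions centered on an equi-spaced grid over [0, 1]."""
    centers = np.linspace(0.0, 1.0, n_basis_funcs)
    width = centers[1] - centers[0]
    # Each column peaks at its center and falls linearly to zero one width away.
    return np.clip(1.0 - np.abs(x[:, None] - centers[None, :]) / width, 0.0, None)


def evaluate_on_grid(n_samples, n_basis_funcs):
    """Return the grid and the (n_samples, n_basis_funcs) matrix of evaluations."""
    X = np.linspace(0.0, 1.0, n_samples)
    return X, hat_basis(X, n_basis_funcs)


X, basis_funcs = evaluate_on_grid(101, 6)
```

Plotting each column of basis_funcs against X is the usual way to inspect a basis; for these hats every row sums to one.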

get_metadata_routing()

Get metadata routing of this object.

See the scikit-learn User Guide for how the routing mechanism works.

Returns:

routing – A MetadataRequest encapsulating routing information.

Return type:

MetadataRequest

get_params(deep=True)

Get the parameters of this object by inspecting __init__ (adapted from scikit-learn).

Parameters:

deep – If True, will return the parameters for this estimator and contained subobjects that are estimators.

Return type:

dict

Returns:

A dictionary containing the parameters. Key is the parameter name, value is the parameter value.

property input_shape

Expected per-sample input shape.

Returns:

If inputs are shaped (n_samples, *shape), returns shape.

property is_complex

Whether the basis is intrinsically complex.

Returns:

True if the basis is complex; False otherwise.

Notes

compute_features() always returns a real-valued design matrix. For complex bases (e.g., FourierEval), the real and imaginary parts are returned as separate columns.

property label: str

Label for the basis.

property n_basis_funcs

Number of basis functions.

property n_output_features: int | None

Number of features returned by the basis.

Notes

The number of output features can be determined only when the number of inputs provided to the basis is known. Therefore, before the first call to compute_features, this property will return None. After that call, or after setting the input shape with set_input_shape, n_output_features will be available.

property order

Spline order.

Spline order, i.e. the polynomial degree of the spline plus one.
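To make the relation between order and degree concrete, the NumPy sketch below evaluates B-spline basis functions with the Cox–de Boor recursion: order 1 is piecewise constant, order 2 piecewise linear, and order 4 cubic. This is a generic textbook construction, not nemos's implementation; the name bspline_basis and the clamped knot vectors are illustrative choices.

```python
import numpy as np


def bspline_basis(x, knots, order):
    """Evaluate all B-spline basis functions of a given order at points x.

    Cox-de Boor recursion; `order` is the polynomial degree plus one.
    Returns shape (n_samples, len(knots) - order).
    """
    x = np.asarray(x, dtype=float)
    knots = np.asarray(knots, dtype=float)
    # Order 1: indicator of each half-open knot interval [t_i, t_{i+1}).
    basis = np.array(
        [(knots[i] <= x) & (x < knots[i + 1]) for i in range(len(knots) - 1)],
        dtype=float,
    ).T
    # Raise the order one step at a time, blending neighbouring lower-order functions.
    for k in range(2, order + 1):
        n_funcs = len(knots) - k
        new = np.zeros((x.size, n_funcs))
        for i in range(n_funcs):
            left_den = knots[i + k - 1] - knots[i]
            right_den = knots[i + k] - knots[i + 1]
            # Terms with a zero-length support are dropped (the 0/0 := 0 convention).
            if left_den > 0:
                new[:, i] += (x - knots[i]) / left_den * basis[:, i]
            if right_den > 0:
                new[:, i] += (knots[i + k] - x) / right_den * basis[:, i + 1]
        basis = new
    return basis


x = np.linspace(0, 0.99, 50)
linear = bspline_basis(x, [0, 0, 0.5, 1, 1], order=2)  # 3 hat functions
cubic = bspline_basis(x, [0, 0, 0, 0, 0.5, 1, 1, 1, 1], order=4)  # 5 cubic functions
```

Because the order-1 step uses half-open intervals, this sketch returns all-zero features at the right endpoint x = 1; library implementations typically special-case that boundary.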

set_input_shape(xi)

Set the expected input shape for the basis object.

This method configures the shape of the input data that the basis object expects. xi can be specified as an integer, a tuple of integers, or derived from an array. The method also calculates the total number of input features and output features based on the number of basis functions.

Parameters:

xi (Union[int, tuple[int, ...], NDArray]) – The input shape specification:

  • An integer: Represents the dimensionality of the input. A value of 1 is treated as scalar input.

  • A tuple: Represents the exact input shape excluding the first axis (sample axis). All elements must be integers.

  • An array: The shape is extracted, excluding the first axis (assumed to be the sample axis).
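The rules for xi, together with the n_output_features property, can be summarized in a few lines. The helper names below are hypothetical; the sketch assumes, consistent with split_by_feature, that the number of output features equals the number of input features (the product of the per-sample shape) times the number of basis functions.

```python
import numpy as np


def normalize_input_shape(xi):
    """Map an int, tuple, or array to the per-sample input shape (sample axis excluded)."""
    if isinstance(xi, (int, np.integer)):
        # 1 means scalar input, i.e. samples shaped (n_samples,).
        return () if xi == 1 else (int(xi),)
    if isinstance(xi, tuple):
        if not all(isinstance(d, (int, np.integer)) for d in xi):
            raise ValueError("All elements of the shape tuple must be integers.")
        return tuple(int(d) for d in xi)
    # Array-like: drop the first (sample) axis.
    return np.asarray(xi).shape[1:]


def n_output_features(xi, n_basis_funcs):
    """Number of input features times number of basis functions."""
    n_inputs = int(np.prod(normalize_input_shape(xi)))  # np.prod(()) == 1.0
    return n_inputs * n_basis_funcs
```

For example, inputs shaped (n_samples, 2, 3) with 4 basis functions give 2 * 3 * 4 = 24 output features.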

Raises:

ValueError – If a tuple is provided and it contains non-integer elements.

Returns:

Returns the instance itself to allow method chaining.

Return type:

self

Notes

All state attributes that depend on the input must be set in this method for the basis API to work correctly. In particular, this method is called by setup_basis, which is equivalent to fit for a transformer. If any input-dependent state is not set in this method, then compute_features (equivalent to fit_transform) will break.

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
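The <component>__<parameter> convention amounts to splitting each key at the first double underscore and forwarding the remainder to that component, which may split it again. route_params below is a hypothetical illustration of the naming rule, not scikit-learn's implementation; parameters without a prefix are grouped under the key None.

```python
def route_params(params):
    """Group flat parameter names by the component before the first '__'."""
    routed = {}
    for key, value in params.items():
        component, sep, sub_key = key.partition("__")
        if sep:
            # Nested: "<component>__<parameter>" is forwarded to that component.
            routed.setdefault(component, {})[sub_key] = value
        else:
            # Plain parameter of the top-level estimator itself.
            routed.setdefault(None, {})[key] = value
    return routed


grid_point = {"glm__regularizer_strength": 0.1, "basis__n_basis_funcs": 5}
routed = route_params(grid_point)
```

Keys of this form are what a Pipeline grid search passes in, for example to tune basis and GLM parameters jointly.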

Parameters:

**params (dict) – Estimator parameters.

Returns:

self – Estimator instance.

Return type:

estimator instance

abstractmethod setup_basis(*xi)

Pre-compute all basis state variables.

This method is intended to be equivalent to the sklearn transformer fit method. Like fit, it computes all the state attributes and stores them following the convention that attribute names end with “_”, for example self.kernel_ and self._input_shape_.

The method differs from a transformer’s fit in the structure of the input it accepts. In particular, setup_basis accepts several time series, one per 1D basis component, while fit requires all inputs to be concatenated in a single array.
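The trailing-underscore convention makes it easy to tell init parameters from state learned at setup time. RangeBasis below is a hypothetical toy whose setup step learns the input range (bounds_ is an invented attribute) and returns self like fit; state_attributes mimics the kind of check scikit-learn uses to decide whether an estimator is fitted.

```python
import numpy as np


class RangeBasis:
    """Hypothetical basis whose setup step learns the input range."""

    def __init__(self, n_basis_funcs):
        self.n_basis_funcs = n_basis_funcs  # init parameter: no trailing "_"

    def setup_basis(self, *xi):
        # One array per 1D component; this toy basis has a single component.
        (x,) = xi
        x = np.asarray(x, dtype=float)
        # State computed from the data is stored with a trailing underscore.
        self.bounds_ = (float(np.nanmin(x)), float(np.nanmax(x)))
        self._input_shape_ = x.shape[1:]
        return self


def state_attributes(obj):
    """Instance attributes following the trailing-underscore convention."""
    return sorted(n for n in vars(obj) if n.endswith("_") and not n.startswith("__"))


basis = RangeBasis(5)
before = state_attributes(basis)
basis.setup_basis(np.array([2.0, 5.0, 3.0]))
after = state_attributes(basis)
```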

Return type:

TsdFrame | ndarray[tuple[Any, ...], dtype[TypeVar(_ScalarT, bound= generic)]]

Parameters:

xi (ArrayLike)

split_by_feature(x, axis=1)

Decompose an array along a specified axis into sub-arrays based on the number of expected inputs.

This function takes an array (e.g., a design matrix or model coefficients) and splits it along a designated axis.

How it works:

  • If the basis expects an input shape (n_samples, n_inputs), then the feature axis length will be total_n_features = n_inputs * n_basis_funcs. This axis is reshaped into dimensions (n_inputs, n_basis_funcs).

  • If the basis expects an input of shape (n_samples,), then the feature axis length will be total_n_features = n_basis_funcs. This axis is reshaped into (1, n_basis_funcs).

For example, if the input array x has shape (1, 2, total_n_features, 4, 5) and axis=2, then after applying this method it will be reshaped into (1, 2, n_inputs, n_basis_funcs, 4, 5).

The axis argument determines where the split occurs; all other dimensions remain unchanged.
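The reshape itself is a single call. The sketch below takes n_inputs and n_basis_funcs explicitly and returns the reshaped array rather than the label-keyed dictionary described under Returns; split_feature_axis is a hypothetical name. It assumes features are ordered input by input, which is what a C-order reshape of the feature axis into (n_inputs, n_basis_funcs) implies.

```python
import numpy as np


def split_feature_axis(x, n_inputs, n_basis_funcs, axis=1):
    """Reshape the feature axis of x into (n_inputs, n_basis_funcs), keeping other axes."""
    x = np.asarray(x)
    axis = axis % x.ndim  # allow negative axes
    if x.shape[axis] != n_inputs * n_basis_funcs:
        raise ValueError(
            f"Expected {n_inputs * n_basis_funcs} features along axis {axis}, "
            f"got {x.shape[axis]}."
        )
    new_shape = x.shape[:axis] + (n_inputs, n_basis_funcs) + x.shape[axis + 1:]
    return x.reshape(new_shape)


design = np.arange(60).reshape(10, 6)  # 10 samples, 2 inputs x 3 basis functions
split = split_feature_axis(design, n_inputs=2, n_basis_funcs=3)
```

With this ordering, column 3 of the design matrix maps to input 1, basis function 0. Use axis=0 for a coefficient vector of shape (total_n_features,).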

Parameters:
  • x (NDArray) –

    The input array to be split, representing concatenated features, coefficients, or other data. The shape of x along the specified axis must match the total number of features generated by the basis, i.e., self.n_output_features.

    Examples:

    • For a design matrix: (n_samples, total_n_features)

    • For model coefficients: (total_n_features,) or (total_n_features, n_neurons).

  • axis (int, optional) – The axis along which to split the features. Defaults to 1. Use axis=1 for design matrices (features along columns) and axis=0 for coefficient arrays (features along rows). All other dimensions are preserved.

Raises:

ValueError – If the shape of x along the specified axis does not match self.n_output_features.

Returns:

A dictionary where:

  • Key: Label of the basis.

  • Value: the array reshaped to: (..., n_inputs, n_basis_funcs, ...)

Return type:

dict

to_transformer()

Turn the Basis into a TransformerBasis for use with scikit-learn.

Return type:

TransformerBasis

Examples

Jointly cross-validating basis and GLM parameters with scikit-learn.

>>> import numpy as np
>>> import nemos as nmo
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.model_selection import GridSearchCV
>>> # generate some synthetic data
>>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
>>> basis = nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1).to_transformer()
>>> glm = nmo.glm.GLM(regularizer="Ridge", regularizer_strength=1.)
>>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
>>> param_grid = dict(
...     glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
...     basis__n_basis_funcs=(3, 5, 10, 20, 100),
... )
>>> gridsearch = GridSearchCV(
...     pipeline,
...     param_grid=param_grid,
...     cv=5,
... )
>>> gridsearch = gridsearch.fit(X, y)