"""Custom Basis Class.
Facilitate the construction of a custom basis class.
"""
from __future__ import annotations
import inspect
import itertools
import re
from copy import deepcopy
from numbers import Number
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple
import jax.numpy as jnp
import lazy_loader as lazy
import numpy as np
from numpy.typing import ArrayLike, NDArray
from nemos.typing import FeatureMatrix
from ..base_class import Base
from ..type_casting import support_pynapple
from ..utils import format_repr
from . import AdditiveBasis, MultiplicativeBasis
from ._basis_mixin import BasisMixin, BasisTransformerMixin, set_input_shape_state
from ._check_basis import _check_transform_input
from ._composition_utils import (
_check_unique_shapes,
_check_valid_shape_tuple,
add_docstring,
count_positional_and_var_args,
infer_input_dimensionality,
is_basis_like,
multiply_basis_by_integer,
promote_to_transformer,
raise_basis_to_power,
set_input_shape,
)
# Lazy load pynapple to improve import times
nap = lazy.load("pynapple")
if TYPE_CHECKING:
import pynapple as nap # noqa: F811
from pynapple import Tsd, TsdFrame, TsdTensor
from . import TransformerBasis
def simplify_func_repr(string: str):
    """
    Simplify the repr of functions embedded in a string.

    Every ``<function NAME at 0x...>`` fragment is collapsed to ``NAME`` and
    every ``functools.partial`` is shortened to ``partial``.

    Parameters
    ----------
    string:
        The string repr of a parameter.

    Returns
    -------
    :
        A simplified version of the string.
    """
    # The name is captured non-greedily: a greedy ``.+`` runs from the first
    # "<function " up to the LAST " at 0x...>" and therefore fuses neighbouring
    # function reprs into a single match.
    # The substitution is done in one pass with a callable, so the matched text
    # is never re-used as a regex pattern (it may contain unbalanced
    # parentheses) and the captured name is inserted literally.
    new_string = re.sub(
        r"<function (.+?) at 0x[0-9a-fA-F]+>",
        lambda match: match.group(1),
        string,
    )
    return new_string.replace("functools.partial", "partial")
class FunctionList:
    """List-like container of callables with a compact ``repr``.

    The container supports ``len``, iteration and indexing. Its ``repr`` shows
    the simplified repr of every function when there are at most two of them,
    and only the first and last one (separated by ``...``) otherwise.

    Parameters
    ----------
    funcs:
        Iterable of callables; it is materialized into a list.
    name:
        Optional label prefix, stored as-is.
    """

    def __init__(self, funcs, name="b"):
        # materialize the iterable so it can be indexed and measured
        self.funcs_ = list(funcs)
        # optional label prefix
        self.name = name

    def __len__(self):
        return len(self.funcs_)

    def __iter__(self):
        return iter(self.funcs_)

    def __getitem__(self, idx):
        return self.funcs_[idx]

    @staticmethod
    def _unwrap(func):
        """Return ``func.__wrapped__`` when the attribute exists, else ``func``."""
        return getattr(func, "__wrapped__", func)

    def __repr__(self):
        """Return the compact representation of the stored functions."""
        bare = [self._unwrap(f) for f in self.funcs_]
        if len(bare) > 2:
            first = simplify_func_repr(repr(bare[0]))
            last = simplify_func_repr(repr(bare[-1]))
            return f"[{first}, ..., {last}]"
        return simplify_func_repr(repr(bare))
def apply_f_vectorized(
    func: Callable[[NDArray], NDArray], *xi: NDArray, ndim_input: int = 1, **kwargs
):
    """Apply ``func`` to the inputs, looping over any axes beyond ``ndim_input``.

    Parameters
    ----------
    func:
        The function to apply. It is called as ``func(*arrays, **kwargs)``.
    *xi:
        Input arrays. The extra (vectorized) shape is read from the first
        input only, so all inputs are expected to share it.
    ndim_input:
        Number of leading axes of each input that ``func`` consumes in a
        single call.
    **kwargs:
        Keyword arguments forwarded to ``func``.

    Returns
    -------
    :
        The output of ``func`` with one extra trailing axis. Without
        vectorization that axis has length one; otherwise it indexes the
        combinations of the extra axes, in ``itertools.product`` order.
    """
    needs_vectorization = any(x.ndim != ndim_input for x in xi)
    if not needs_vectorization:
        # single call, only add the trailing axis
        return func(*xi, **kwargs)[..., np.newaxis]

    # full slices over the axes consumed by ``func``
    leading = (slice(None),) * ndim_input
    extra_shape = xi[0].shape[ndim_input:]

    # one call per combination of indices along the extra axes
    pieces = [
        func(*(x[leading + idx] for x in xi), **kwargs)[..., np.newaxis]
        for idx in itertools.product(*map(range, extra_shape))
    ]
    return jnp.concatenate(pieces, axis=-1)
def check_valid_shape(shape):
    """
    Validate that every entry of ``shape`` is an integer-valued number.

    Parameters
    ----------
    shape:
        Iterable of candidate dimension sizes.

    Raises
    ------
    ValueError
        If any entry is not a number, or is a number that is not
        integer-valued (e.g. ``2.5``, ``inf`` or a complex value).
    """
    # Materialize once so that one-shot iterables are fully checked.
    for dim in tuple(shape):
        # Every entry must be numeric before the integer comparison makes sense.
        is_valid = isinstance(dim, Number)
        if is_valid:
            try:
                is_valid = bool(dim == int(dim))
            except (TypeError, ValueError, OverflowError):
                # complex, NaN and infinite values cannot be converted to int
                is_valid = False
        if not is_valid:
            raise ValueError("`output_shape` must be an iterable of integers.")
[docs]
class CustomBasis(BasisMixin, BasisTransformerMixin, Base):
    """
    Custom basis class.

    Create a custom basis class from a list of callables (the basis functions).

    Parameters
    ----------
    funcs:
        List of basis functions. A single callable is accepted too and is
        wrapped into a one-element list.
    ndim_input:
        Number of leading axes (sample axis included) of each input that a basis
        function consumes in a single call. With the default ``1`` each function
        receives arrays of shape ``(n_samples,)``; any axis beyond the first
        ``ndim_input`` ones is looped over, i.e. the functions are applied to
        every slice along those axes.
    output_shape:
        Shape of the output excluding the number of samples. Set automatically when `compute_features` is called.
    basis_kwargs:
        Additional keyword arguments to pass to the basis function.
    pynapple_support:
        Enable pynapple support if True.
    label:
        The label of the basis function.
    is_complex : bool, optional
        Whether the basis should be treated as complex. This flag ensures that
        multiplication with other bases behaves correctly: two real bases, or a real
        and a complex basis, can be multiplied, but two complex bases cannot. This
        restriction exists because after multiplication, ``basis.compute_features``
        does not distinguish between real and imaginary components, which would lead
        to incorrect outputs.
    bounds:
        Optional ``(lower, upper)`` pair. When provided, ``compute_features``
        overwrites with ``fill_value`` every sample for which any input value is
        smaller than ``lower`` or larger than ``upper``. ``None`` (default)
        disables the check.
    fill_value:
        Value assigned to the out-of-bounds samples, NaN by default. Only used
        when ``bounds`` is not ``None``.

    Examples
    --------
    >>> import numpy as np
    >>> import nemos as nmo
    >>> from functools import partial
    >>> # Define a function
    >>> def decay_exp(x, rate, shift=0):
    ...     return np.exp(-rate * (x + shift)**2)
    >>> # Define a list of basis functions
    >>> funcs = [partial(decay_exp, rate=r) for r in np.linspace(0, 1, 10)]
    >>> bas = nmo.basis.CustomBasis(funcs=funcs, basis_kwargs=dict(shift=1))
    >>> bas
    CustomBasis(
    funcs=[partial(decay_exp, rate=np.float64(0.0)), ..., partial(decay_exp, rate=np.float64(1.0))],
    ndim_input=1,
    basis_kwargs={'shift': 1},
    pynapple_support=True,
    is_complex=False
    )
    >>> samples = np.linspace(0, 1, 50)
    >>> X = bas.compute_features(samples)
    >>> X.shape
    (50, 10)
    >>> # Can be composed with other basis (including other custom basis)
    >>> add = bas + bas
    >>> X = add.compute_features(samples, samples)
    >>> X.shape
    (50, 20)
    """

    # NOTE(review): presumably read by the shared input-shape checks (e.g.
    # ``_check_unique_shapes``) to require identically shaped inputs — confirm.
    _allow_inputs_of_different_shape = False
[docs]
def __init__(
    self,
    funcs: List[Callable[[NDArray], NDArray]] | Callable[[NDArray], NDArray],
    ndim_input: int = 1,
    output_shape: Optional[Tuple[int, ...] | int] = None,
    basis_kwargs: Optional[dict] = None,
    pynapple_support: bool = True,
    label: Optional[str] = None,
    is_complex: bool = False,
    bounds: Optional[Tuple[float, float]] = None,
    fill_value: float = jnp.nan,
):
    """Store and validate the configuration of the custom basis.

    Parameters
    ----------
    funcs:
        A callable or a list of callables (the basis functions).
    ndim_input:
        Number of leading input axes consumed by a single function call;
        cast to ``int``.
    output_shape:
        Shape of the function output excluding the samples; ``None`` is
        stored as an empty tuple.
    basis_kwargs:
        Extra keyword arguments forwarded to every basis function.
    pynapple_support:
        Whether pynapple inputs are supported; cast to ``bool``.
    label:
        Label of the basis, forwarded to the parent initializer.
    is_complex:
        Whether the basis is treated as complex; cast to ``bool``.
    bounds:
        Optional ``(lower, upper)`` pair used by ``compute_features`` to
        flag out-of-bounds samples.
    fill_value:
        Value assigned to out-of-bounds samples when ``bounds`` is set.
    """
    self._pynapple_support = bool(pynapple_support)
    # ``_n_inputs`` does not exist yet, so the ``funcs`` setter skips its
    # input-count consistency check on this first assignment.
    self.funcs = funcs
    self.ndim_input = int(ndim_input)
    self.output_shape = () if output_shape is None else output_shape
    self._input_shape_product = None

    # inferred from the basis itself, hence after ``funcs`` has been stored
    self._n_inputs = infer_input_dimensionality(self)
    self._n_basis_funcs = len(self.funcs)

    self.basis_kwargs = basis_kwargs
    self._is_complex = bool(is_complex)
    self.bounds = bounds
    self.fill_value = fill_value
    super().__init__(label=label)
@property
def is_complex(self):
# custom classes could be complex or real, so the attribute
# is an instance attribute not a class attribute
return self._is_complex
@property
def pynapple_support(self) -> bool:
"""Support pynapple Tsd/TsdFrame/TsdTensor as inputs."""
return self._pynapple_support
@property
def funcs(self):
"""User defined list of basis functions."""
return self._funcs
@funcs.setter
def funcs(self, val: Iterable[Callable[[NDArray, ...], NDArray]]):
    """Validate and store the basis functions.

    Parameters
    ----------
    val:
        A callable or an iterable of callables. A single callable is
        wrapped into a one-element list; the result is stored as a
        ``FunctionList``.

    Raises
    ------
    ValueError
        If any element is not callable, or if the basis has already
        inferred its number of inputs and a function requires a different
        number of inputs.
    """
    if callable(val):
        val = [val]
    val = FunctionList(val)

    if not all(callable(f) for f in val):
        raise ValueError("User must provide an iterable of callable.")

    # ``_n_inputs`` is only available after ``__init__`` inferred it, so the
    # check applies to re-assignments only.
    if hasattr(self, "_n_inputs"):
        # All functions are called with the same ``*xi``, hence EACH function
        # must accept ``_n_inputs`` inputs. Summing the counts over the list
        # would reject every list holding more than one function.
        # NOTE(review): assumes ``count_positional_and_var_args(f)[0]`` is the
        # per-function input count that ``infer_input_dimensionality`` stores
        # in ``_n_inputs`` — confirm against ``_composition_utils``.
        if any(
            count_positional_and_var_args(f)[0] != self._n_inputs for f in val
        ):
            raise ValueError(
                "The number of input time series required by the CustomBasis must be consistent. "
                "Redefine a CustomBasis for a different number of inputs."
            )

    self._funcs = val
    # keep the cached count (returned by ``n_basis_funcs``) aligned with the list
    self._n_basis_funcs = len(val)
@property
def n_basis_funcs(self) -> int:
"""The number of basis."""
return self._n_basis_funcs
@property
def output_shape(self) -> Tuple[int, ...]:
"""The shape of the output excluding the number of samples and the number of basis functions."""
return self._output_shape
@output_shape.setter
def output_shape(self, output_shape: Tuple[int] | int):
if isinstance(output_shape, int):
if output_shape == 1:
self._output_shape = ()
elif output_shape > 1:
self._output_shape = (output_shape,)
else:
raise ValueError(
f"Output shape must be strictly positive (> 0), {output_shape} provided instead."
)
else:
try:
output_shape = tuple(output_shape)
except TypeError:
raise TypeError(
"`output_shape` must be an iterable of positive integers or a positive integer."
)
_check_valid_shape_tuple(output_shape)
self._output_shape = output_shape
@property
def basis_kwargs(self) -> dict:
"""Additional keyword arguments to pass to the basis functions."""
return self._basis_kwargs
@basis_kwargs.setter
def basis_kwargs(self, basis_kwargs: dict):
    """Validate and store the keyword arguments for the basis functions.

    ``None`` is stored as an empty dictionary. Names matching a
    keyword-only parameter of ``apply_f_vectorized`` are rejected.

    Raises
    ------
    ValueError
        If the value is not a dictionary, or if one of its keys matches a
        keyword-only parameter of ``apply_f_vectorized``.
    """
    if basis_kwargs is None:
        basis_kwargs = {}
    if not isinstance(basis_kwargs, dict):
        raise ValueError("`basis_kwargs` must be a dictionary.")

    # names reserved by the vectorization helper
    reserved = {
        name
        for name, param in inspect.signature(apply_f_vectorized).parameters.items()
        if param.kind is inspect.Parameter.KEYWORD_ONLY
    }
    clashing = reserved.intersection(basis_kwargs.keys())
    if clashing:
        raise ValueError(
            f"Invalid kwargs name in ``basis_kwargs``: {clashing}. "
            "Change parameter name in your function definition."
        )
    self._basis_kwargs = basis_kwargs
@support_pynapple(conv_type="jax")
def _apply_fill_value(self, *xi: ArrayLike, out: NDArray) -> NDArray:
    """Replace the out-of-bounds samples of ``out`` with ``self.fill_value``.

    A sample is out of bounds when, for at least one input, any of its
    values is smaller than ``self.bounds[0]`` or larger than
    ``self.bounds[1]``.

    Parameters
    ----------
    *xi:
        The inputs, with the samples on the first axis.
    out:
        The array to be masked, with the samples on the first axis.

    Returns
    -------
    :
        ``out`` with every out-of-bounds sample set to ``self.fill_value``.
    """
    lower, upper = self.bounds[0], self.bounds[1]

    def _sample_is_outside(x):
        # collapse all non-sample axes, then flag samples with any value outside
        flat = jnp.reshape(x, (x.shape[0], -1))
        return jnp.any((flat < lower) | (flat > upper), axis=1)

    # a sample is filled as soon as one of the inputs is outside the bounds
    to_fill = jnp.any(jnp.stack([_sample_is_outside(x) for x in xi]), axis=0)

    # (n_samples,) -> (n_samples, 1, 1, ...) so that it broadcasts against ``out``
    mask = to_fill.reshape((to_fill.shape[0],) + (1,) * (out.ndim - 1))
    return jnp.where(mask, self.fill_value, out)
[docs]
def compute_features(
    self, *xi: ArrayLike | Tsd | TsdFrame | TsdTensor
) -> FeatureMatrix:
    """
    Apply the basis transformation to the input data.

    Every function in ``self.funcs`` is called with the arguments
    ``(*xi, **self.basis_kwargs)`` and must return an array of shape
    ``(n_samples, ...)``: the first dimension corresponds to the number of samples
    and the output must have at least one dimension (i.e., ``ndim >= 1``).

    The outputs of all function calls are reshaped into 2D arrays with shape
    ``(n_samples, n_output)`` and concatenated along the feature axis (second
    dimension) to form the full design matrix.

    If the input arrays have more dimensions than ``self.ndim_input``, the function
    calls are automatically vectorized over the additional axes. This is done using
    Python loops, which may be slow. For better performance, users are encouraged to
    provide fully vectorized functions.

    When ``self.bounds`` is not ``None``, out-of-bounds samples are overwritten
    with the fill value before the final reshape.

    Parameters
    ----------
    *xi :
        Input arrays. Each must have at least ``self.ndim_input`` dimensions. If extra
        dimensions are present, they are interpreted as batch or window dimensions over
        which the basis functions are applied.

    Returns
    -------
    :
        The resulting design matrix, with one row per sample and one column per output feature.

    Raises
    ------
    ValueError
        If any input has fewer than ``self.ndim_input`` dimensions.

    Examples
    --------
    >>> import nemos as nmo
    >>> import numpy as np
    >>> from functools import partial
    >>> def power_func(n, x):
    ...     return x ** n
    >>> bas = nmo.basis.CustomBasis([partial(power_func, 1), partial(power_func, 2)])
    >>> bas.compute_features(np.arange(1, 4))
    array([[1., 1.],
           [2., 4.],
           [3., 9.]])
    """
    xi = _check_transform_input(self, *xi)

    invalid_dims = [x.ndim for x in xi if x.ndim < self.ndim_input]
    if invalid_dims:
        raise ValueError(
            f"Each input must have at least {self.ndim_input} dimensions, as required by this basis. "
            f"However, some inputs have fewer dimensions: {invalid_dims}."
        )

    _check_unique_shapes(xi, basis=self)
    set_input_shape(self, *xi)

    # no samples: return an empty design matrix with the right number of columns
    if len(xi[0]) == 0:
        if not self._pynapple_support:
            return jnp.zeros((0, self.n_output_features))
        conv_type = "numpy" if nap.nap_config.backend == "numba" else "jax"
        make_empty = support_pynapple(conv_type)(
            lambda *x: np.zeros((0, self.n_output_features))
        )
        return make_empty(*xi)

    # (n_samples, *n_output_shape, n_vec_dim, n_basis)
    design_matrix = self.evaluate(*xi)

    # overwrite the out-of-bounds samples, if bounds are provided
    if self.bounds is not None:
        design_matrix = self._apply_fill_value(*xi, out=design_matrix)

    # flatten everything but the samples into the feature axis
    return design_matrix.reshape((xi[0].shape[0], -1))
[docs]
def evaluate(self, *xi: NDArray):
    """
    Evaluate the basis functions in a vectorized form at the given sample points.

    As a side effect, ``self.output_shape`` is updated from the shape of the
    evaluated functions.

    Parameters
    ----------
    *xi :
        The samples at which the basis functions are evaluated. Each element in `xi`
        corresponds to an input dimension. The shape of each input array should be
        (n_samples, ...) where the first axis indexes samples.

    Returns
    -------
    basis_funcs :
        The basis functions evaluated at the given input points. The functions are
        stacked on the last axis, giving shape
        ``(n_samples, *output_shape, n_vect_input, n_basis_funcs)``, where
        ``n_vect_input`` counts the slices the inputs are vectorized over. When no
        input has extra axes, the ``n_vect_input`` axis is dropped.

    Notes
    -----
    If pynapple support is enabled, the vectorized application is wrapped with
    ``support_pynapple``, using the "numpy" conversion when the pynapple backend is
    "numba" and the "jax" conversion otherwise.
    Evaluation is performed by applying each basis function over the inputs and
    stacking the results along the last axis.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.basis import CustomBasis
    >>> basis = CustomBasis(funcs=[lambda x: x, lambda x: x**2])
    >>> x = np.linspace(0, 1, 10)
    >>> out = basis.evaluate(x)
    >>> out.shape
    (10, 2)
    >>> # vectorize over 3 inputs
    >>> out = basis.evaluate(np.random.randn(10, 3))
    >>> out.shape
    (10, 3, 2)
    """
    if self._pynapple_support:
        conv_type = "numpy" if nap.nap_config.backend == "numba" else "jax"
        apply_func = support_pynapple(conv_type)(apply_f_vectorized)
        stack = np.stack
    else:
        apply_func, stack = apply_f_vectorized, jnp.stack

    # one array per basis function, each with a trailing vectorization axis
    per_func = [
        apply_func(f, *xi, **self.basis_kwargs, ndim_input=self.ndim_input)
        for f in self.funcs
    ]

    # (n_samples, *out_shape, n_vec_features, n_funcs)
    stacked = stack(per_func, axis=-1)
    self.output_shape = stacked.shape[1:-2]

    is_vectorized = any(x.ndim != self.ndim_input for x in xi)
    # without vectorization the n_vec_features axis has length one: drop it
    return stacked if is_vectorized else stacked[..., 0, :]
[docs]
@set_input_shape_state(states=("_input_shape_product", "_input_shape_", "_label"))
def __sklearn_clone__(self) -> CustomBasis:
    """Clone the basis while preserving attributes related to input shapes.

    A new instance is built from ``self.get_params()``. The decorator is
    given the names of the states to carry over (``_input_shape_product``,
    ``_input_shape_`` and ``_label``): reinitializing the class as in the
    regular sklearn clone would drop these attributes, rendering
    cross-validation unusable.
    """
    params = self.get_params()
    return self.__class__(**params)
@property
def n_output_features(self):
"""The number of output features, i.e. the number of columns of the design matrix."""
if self.input_shape is None:
return None
# Computation for number of output features:
# 1. Compute the number of vectorized dimensions:
# - discard axis corresponding to input dimensionality
# - multiply the shape of the remaining axis
# 2. multiply by the number of basis and the shape of the output
# Note that self._input_shape_ could be:
# 1. A list of tuple, if the funcs require > 1 input.
# 2. A single tuple, if the funcs require 1 input.
ishape = (
[self._input_shape_]
if not isinstance(self._input_shape_, list)
else self._input_shape_
)
vec_inp = np.prod([p for shape in ishape for p in shape[self.ndim_input - 1 :]])
return int(vec_inp * np.prod(self.output_shape) * len(self.funcs))
@staticmethod
def _reshape_concatenated_arrays(
array: NDArray, bas: CustomBasis, axis: int
) -> NDArray:
# reshape the arrays to match input shapes
shape = list(array.shape)
array = array.reshape(
shape[:axis]
+ [
*bas.output_shape,
*(
i
for shape in bas._input_shape_
for i in shape[bas.ndim_input - 1 :]
),
len(bas.funcs),
]
+ shape[axis + 1 :]
)
return array
[docs]
@promote_to_transformer
def __add__(self, other: BasisMixin) -> AdditiveBasis:
    """Combine this basis with ``other`` into an ``AdditiveBasis``."""
    return AdditiveBasis(self, other)
@promote_to_transformer
def __rmul__(self, other: BasisMixin | int) -> BasisMixin:
    """Reflected multiplication; delegates to ``__mul__``."""
    return self.__mul__(other)
[docs]
@promote_to_transformer
def __mul__(self, other: BasisMixin | int) -> BasisMixin:
    """Multiply the basis by an integer or by another basis.

    An ``int`` is handled by ``multiply_basis_by_integer``; a basis-like
    object yields a ``MultiplicativeBasis``.

    Raises
    ------
    TypeError
        If ``other`` is neither an ``int`` nor basis-like.
    """
    if isinstance(other, int):
        return multiply_basis_by_integer(self, other)
    if is_basis_like(other):
        return MultiplicativeBasis(self, other)
    raise TypeError(
        "Basis multiplicative factor should be a Basis object or a positive integer!"
    )
[docs]
@promote_to_transformer
def __pow__(self, exponent) -> BasisMixin:
    """Raise the basis to ``exponent`` through ``raise_basis_to_power``."""
    return raise_basis_to_power(self, exponent)
def __repr__(self, n=0):
    """Return the multi-line representation of the basis.

    ``fill_value`` is left out of the representation when no bounds are
    set. Every line after the first is prefixed with ``n`` copies of the
    indentation string.
    """
    # the fill value is only meaningful together with the bounds
    extra = {} if self.bounds is not None else {"exclude_keys": ["fill_value"]}
    text = format_repr(self, **extra, multiline=True)
    indent = " " * n
    return text.replace("\n", "\n" + indent)
def __len__(self) -> int:
return 1
[docs]
@add_docstring("split_by_feature", BasisMixin)
def split_by_feature(
    self,
    x: NDArray,
    axis: int = 1,
):
    """
    Examples
    --------
    >>> import nemos as nmo
    >>> import numpy as np
    >>> from functools import partial
    >>> def power_func(n, x):
    ...     return x ** n
    >>> bas = nmo.basis.CustomBasis([partial(power_func, 1), partial(power_func, 2)])
    >>> # define a 3 x 2 input
    >>> inp = np.arange(1, 7).reshape(3, 2)
    >>> X = bas.compute_features(inp)
    >>> X.shape  # (3, 2 * 2)
    (3, 4)
    >>> bas.split_by_feature(X)["CustomBasis"]  # split to (3, 2, 2)
    array([[[ 1.,  1.],
            [ 2.,  4.]],
    ...
           [[ 3.,  9.],
            [ 4., 16.]],
    ...
           [[ 5., 25.],
            [ 6., 36.]]], dtype=float32)
    """
    # NOTE(review): the docstring holds only the Examples section; presumably
    # ``add_docstring`` merges it with the ``BasisMixin.split_by_feature``
    # docstring (hence the D205/D400 suppression below) — confirm.
    # ruff: noqa: D205, D400
    # the splitting itself is fully delegated to the parent implementation
    return super().split_by_feature(x, axis=axis)
@property
def input_shape(
self,
) -> None | List[None] | Tuple[int, ...] | List[Tuple[int, ...]]:
"""Input shape as a tuple or list of tuple.
The property mimics the behavior of atomic bases, and uses the
assumption that _input_shape_ for custom bases is a list of
length one.
"""
input_shape = self._input_shape_
if input_shape is None:
if self._n_inputs == 1:
return None
else:
return [None] * self._n_inputs
if self._n_inputs == 1:
return input_shape[0]
return input_shape * self._n_inputs