Source code for nemos.identifiability_constraints

"""Utility functions for applying identifiability constraints to rank deficient feature matrices."""

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Callable, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike as JaxArray
from numpy.typing import NDArray

from . import validation
from .tree_utils import get_valid_multitree, tree_slice
from .type_casting import support_pynapple

if TYPE_CHECKING:
    from .basis._basis import Basis

# Warning text passed to ``validation._warn_if_not_float64`` by the functions in
# this module. Matrix-rank computation is the numerically sensitive step here,
# so the message suggests switching JAX to 64-bit precision.
_WARN_FLOAT32_MESSAGE = (
    "The feature matrix is not of dtype `float64`. Consider converting it to `float64` "
    "for increased numerical precision when computing the matrix rank. You can enable "
    "float64 precision globally by adding:\n\n    jax.config.update('jax_enable_x64', True)\n"
)


def add_constant(x):
    """
    Prepend an intercept column to a feature matrix.

    Parameters
    ----------
    x:
        Array whose first axis indexes samples; a column of ones with the same
        dtype and number of rows is placed in front of its columns.

    Returns
    -------
    :
        ``x`` with one extra leading column filled with ones.
    """
    intercept = jnp.ones((x.shape[0], 1), dtype=x.dtype)
    return jnp.concatenate((intercept, x), axis=1)


@partial(jax.jit, static_argnums=(2,))
def _drop_and_compute_rank(feature_matrix, idx, preprocessing_func=add_constant):
    """
    Zero out one column and compute the rank of the preprocessed result.

    The column is "dropped" by overwriting it with zeros rather than removing
    it, so the array shape is unchanged. ``preprocessing_func`` is a static
    argument of the jitted function.

    Returns
    -------
    :
        The matrix with column ``idx`` zeroed, and the rank of
        ``preprocessing_func`` applied to that matrix.
    """
    zeroed = feature_matrix.at[:, idx].set(0.0)
    processed = preprocessing_func(zeroed)
    return zeroed, jnp.linalg.matrix_rank(processed)


@partial(jax.jit, static_argnums=(1, 2, 3))
def _find_drop_column(
    feature_matrix: JaxArray,
    rank: int,
    max_drop: int,
    preprocessing_func: Callable = add_constant,
) -> JaxArray:
    """
    Find linearly dependent columns that can be dropped without changing the rank.

    This function loops over the columns of a matrix, in order, and checks whether each column is linearly
    dependent on the others: the column is set to zero and the rank of the preprocessed matrix is recomputed.
    If the rank is unchanged, drop_cols[i] is set to True and feature_matrix[:, i] stays zeroed for all the
    following checks; otherwise the column is left untouched.
    Once max_drop columns have been flagged, the remaining iterations leave the state unchanged.

    ``rank``, ``max_drop`` and ``preprocessing_func`` are static arguments of the jitted function
    (``static_argnums=(1, 2, 3)``).

    Parameters
    ----------
    feature_matrix:
        The rank deficient feature matrix.
    rank:
        The rank that must be preserved. It is compared against the rank of
        ``preprocessing_func(matrix)`` after each tentative drop.
    max_drop:
        Maximum number of columns to be dropped.
    preprocessing_func:
        Additional processing of the feature matrix applied before each rank computation. By default, add an
        intercept term. Other processing could entail mean-centering the columns or similar.

    Returns
    -------
    drop_cols:
        A boolean vector with one entry per column of ``feature_matrix``, True if the column should be
        dropped, False otherwise.

    """

    # Both branches below take the same four arguments because ``jax.lax.cond``
    # calls either branch with the same operands.
    def drop_col_and_update(features, dropped_features, drop_cols, iter_num):
        """Keep the matrix with the column zeroed and flag the column as dropped."""
        return dropped_features, drop_cols.at[iter_num].set(True)

    def do_not_drop(features, dropped_features, drop_cols, iter_num):
        """Keep the matrix and the drop flags as they were."""
        return features, drop_cols

    def check_column(iter_num, state):
        """Drop a column if rank is not affected."""
        # state layout: (matrix, last computed rank, original rank, drop flags, max_drop);
        # the previously computed rank is not needed here.
        matrix, _, original_rank, drop_cols, mx_drop = state
        # drop the column (by setting it to zero) and compute rank
        col_dropped_matrix, mat_rank = _drop_and_compute_rank(
            matrix, iter_num, preprocessing_func
        )
        # apply the change if rank stays constant, do nothing otherwise.
        matrix, drop_cols = jax.lax.cond(
            mat_rank == original_rank,  # condition
            drop_col_and_update,  # true function
            do_not_drop,  # false function
            matrix,  # parameters
            col_dropped_matrix,
            drop_cols,
            iter_num,
        )
        return matrix, mat_rank, original_rank, drop_cols, mx_drop

    def body_func(iter_num, state):
        # Only test the column while fewer than max_drop columns are flagged;
        # otherwise pass the state through unchanged.
        drop_cols, max_drop = state[-2:]
        return jax.lax.cond(
            drop_cols.sum() < max_drop, check_column, lambda it, x: x, iter_num, state
        )

    # Loop carry: (matrix, placeholder for the last computed rank, rank to
    # preserve, per-column drop flags, max_drop). The layout must match what
    # check_column returns.
    init_state = (
        feature_matrix,
        jnp.array(0),
        jnp.array(rank),
        jnp.zeros(feature_matrix.shape[1], dtype=bool),
        max_drop,
    )
    # one iteration per column of the feature matrix
    final_state = jax.lax.fori_loop(0, feature_matrix.shape[1], body_func, init_state)

    # index 3 of the carry holds the boolean drop flags
    return final_state[3]


def _add_invalid_entries(feature_matrix, shape_first_axis, is_valid):
    """
    Scatter the rows of ``feature_matrix`` back into a NaN-filled array.

    Builds an array with ``shape_first_axis`` rows (trailing dimensions and
    dtype taken from ``feature_matrix``), filled with NaN, and writes
    ``feature_matrix`` at the positions selected by ``is_valid`` along the
    first axis. Positions not selected by ``is_valid`` stay NaN.
    """
    padded_shape = (shape_first_axis, *feature_matrix.shape[1:])
    nan_filled = jnp.full(padded_shape, jnp.nan, dtype=feature_matrix.dtype)
    return nan_filled.at[is_valid].set(feature_matrix)


def _apply_identifiability_constraints(
    feature_matrix: JaxArray,
    preprocessing_func: Callable = add_constant,
    warn_if_float32: bool = True,
) -> Tuple[JaxArray, JaxArray]:
    """
    Apply identifiability constraints to a single design matrix.

    Private worker behind the public wrappers. The rows selected by
    ``get_valid_multitree`` are used to compute the rank of
    ``preprocessing_func(feature_matrix)``; if that matrix is rank deficient,
    ``_find_drop_column`` selects the columns to remove. The output has the
    original number of rows, with NaN in the rows that were excluded.

    Parameters
    ----------
    feature_matrix:
        The design matrix, samples along the first axis.
    preprocessing_func:
        Transformation applied before every rank computation. Defaults to
        adding an intercept column.
    warn_if_float32:
        If True, run the float64 check with ``_WARN_FLOAT32_MESSAGE``.

    Returns
    -------
    constrained_matrix:
        ``feature_matrix`` without the dropped columns.
    drop_cols:
        Boolean vector over the original columns, True for dropped columns.
    """
    if warn_if_float32:
        validation._warn_if_not_float64(feature_matrix, _WARN_FLOAT32_MESSAGE)

    n_samples = feature_matrix.shape[0]
    valid_rows = get_valid_multitree(feature_matrix)

    # rank is computed on the valid rows only, after preprocessing
    valid_features = tree_slice(feature_matrix, valid_rows)
    processed = preprocessing_func(valid_features)
    n_processed_cols = processed.shape[1]
    rank = jnp.linalg.matrix_rank(processed)

    # already full rank: nothing to drop
    if rank == n_processed_cols:
        restored = _add_invalid_entries(valid_features, n_samples, valid_rows)
        return restored, jnp.zeros(restored.shape[1], dtype=bool)

    # rank deficient: search for up to (n_cols - rank) redundant columns
    drop_cols = _find_drop_column(
        valid_features,
        rank=int(rank),
        max_drop=int(n_processed_cols - rank),
        preprocessing_func=preprocessing_func,
    )

    kept_features = valid_features[:, ~drop_cols]
    restored = _add_invalid_entries(kept_features, n_samples, valid_rows)
    return restored, drop_cols


def _no_preprocessing(x):
    """Return ``x`` unchanged; preprocessing used when no intercept is added."""
    return x


@support_pynapple(conv_type="jax")
def apply_identifiability_constraints(
    feature_matrix: NDArray | JaxArray,
    add_intercept: bool = True,
    warn_if_float32: bool = True,
) -> Tuple[NDArray, NDArray[int]]:
    """
    Apply identifiability constraints to a design matrix ``X``.

    Removes columns from ``X`` until it is full rank to ensure the uniqueness
    of the GLM (Generalized Linear Model) maximum-likelihood solution. This is particularly
    crucial for models using bases like BSplines and CyclicBspline, which, due to their construction,
    sum to 1 and can cause rank deficiency when combined with an intercept.

    For GLMs, this rank deficiency means that different sets of coefficients might yield
    identical predicted rates and log-likelihoods, complicating parameter learning, especially
    in the absence of regularization.

    For very large feature matrices generated by a sum of low-dimensional basis components, consider
    ``apply_identifiability_constraints_by_basis_component``.

    Parameters
    ----------
    feature_matrix:
        The design matrix before applying the identifiability constraints.
    add_intercept:
        Set to True if your model will add an intercept term, False otherwise.
    warn_if_float32:
        Raise a warning if feature matrix dtype is float32.

    Returns
    -------
    constrained_x:
        The adjusted design matrix with redundant columns dropped. Rows that are
        excluded as invalid before the rank computation are filled with NaN.
    kept_columns:
        Indices of the columns that have been kept.

    Examples
    --------
    >>> import numpy as np
    >>> from nemos.identifiability_constraints import apply_identifiability_constraints
    >>> from nemos.basis import BSplineEval
    >>> from nemos.glm import GLM
    >>> import jax
    >>> jax.config.update('jax_enable_x64', True)
    >>> # define a feature matrix
    >>> bas = BSplineEval(5) + BSplineEval(6)
    >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100))
    >>> # apply constraints
    >>> constrained_x, kept_columns = apply_identifiability_constraints(feature_matrix)
    >>> constrained_x.shape
    (100, 9)
    >>> kept_columns
    array([ 1,  2,  3,  4,  6,  7,  8,  9, 10])

    Notes
    -----
    Compilation is triggered at every loop. This can be slower than pure python for low number of samples
    and low dimension for the feature matrix. Usually, the design matrices we work with have a large
    number of samples. Running the code on GPU will reduce the computation time significantly.

    """
    # The preprocessing function is a static argument of the jitted helpers, so
    # it must be the same object across calls for the jit cache to be reused;
    # a function defined inside this body would be a new object on every call.
    preproc_design = add_constant if add_intercept else _no_preprocessing

    # return the output matrix and the dropped indices
    constrained_x, discarded_columns = _apply_identifiability_constraints(
        jnp.asarray(feature_matrix),
        preprocessing_func=preproc_design,
        warn_if_float32=warn_if_float32,
    )
    # convert the boolean "dropped" mask into indices of the kept columns
    kept_columns = np.arange(feature_matrix.shape[1])[~discarded_columns]
    return np.asarray(constrained_x), kept_columns
@support_pynapple(conv_type="jax")
def apply_identifiability_constraints_by_basis_component(
    basis: Basis,
    feature_matrix: NDArray,
    add_intercept: bool = True,
) -> Tuple[NDArray, NDArray]:
    """Apply identifiability constraint to a design matrix for each component of an additive basis.

    The feature matrix is split into the sub-matrices returned by
    ``basis.split_by_feature`` (one per component and per input), the constraints
    are applied to each sub-matrix independently, and the results are stacked
    back together column-wise.

    Parameters
    ----------
    basis:
        The basis that computed ``feature_matrix``.
    feature_matrix:
        The feature matrix before applying the identifiability constraints.
    add_intercept:
        Set to True if your model will add an intercept term, False otherwise.

    Returns
    -------
    constrained_x:
        The adjusted feature matrix after applying the identifiability constraints as numpy array.
    kept_columns:
        Indices of the columns that are kept, referred to the columns of the input
        ``feature_matrix``. This should be used for applying the same transformation
        to a feature matrix generated from a different set of inputs (as for a test set).

    Examples
    --------
    >>> import numpy as np
    >>> import jax
    >>> from nemos.identifiability_constraints import apply_identifiability_constraints_by_basis_component
    >>> from nemos.basis import BSplineEval
    >>> from nemos.glm import GLM
    >>> jax.config.update('jax_enable_x64', True)
    >>> # define a feature matrix
    >>> bas = BSplineEval(5) + BSplineEval(6)
    >>> feature_matrix = bas.compute_features(np.random.randn(100), np.random.randn(100))
    >>> # apply constraints
    >>> constrained_x, kept_columns = apply_identifiability_constraints_by_basis_component(bas, feature_matrix)
    >>> constrained_x.shape
    (100, 9)
    >>> # generate a test set, shape (20, 11)
    >>> test_x = bas.compute_features(np.random.randn(20), np.random.randn(20))
    >>> test_x.shape
    (20, 11)
    >>> # apply constraint to test set
    >>> test_x = test_x[:, kept_columns]
    >>> test_x.shape
    (20, 9)
    >>> # fit on train and predict on test set
    >>> rate = GLM().fit(constrained_x, np.random.poisson(size=100)).predict(test_x)

    """
    # gets a dictionary with feature specific feature matrices
    # stored in tensors of shape (n_samples, n_inputs, n_features)
    # n_inputs can be larger than one if basis is used to perform
    # convolutions on multiple signals (as for counts TsdFrames)
    splits_x = basis.split_by_feature(feature_matrix)

    # list the arrays
    split_x = jax.tree_util.tree_leaves(splits_x)

    # add dim if needed (the dim is at least 2, (n_samples, n_basis))
    split_x = [x if x.ndim > 2 else x[:, None] for x in split_x]

    # flatten over inputs
    split_x = [x.reshape(x.shape[0], -1, x.shape[-1]) for x in split_x]

    # list leaves and unwrap over input dimension. Additive components have shapes:
    # (n_samples, n_inputs, n_basis_funcs)
    split_by_input_x = [x[:, k] for x in split_x for k in range(x.shape[1])]

    # warn once for the whole list instead of once per sub-matrix
    apply_identifiability = partial(
        apply_identifiability_constraints,
        add_intercept=add_intercept,
        warn_if_float32=False,
    )
    validation._warn_if_not_float64(split_by_input_x, _WARN_FLOAT32_MESSAGE)
    constrained_x_and_columns = jax.tree_util.tree_map(
        apply_identifiability, split_by_input_x
    )

    # unpack the outputs into arrays and kept column indices
    def is_leaf(x):
        return isinstance(x, tuple)

    constrained_x = tree_slice(constrained_x_and_columns, idx=0, is_leaf=is_leaf)
    kept_columns = tree_slice(constrained_x_and_columns, idx=1, is_leaf=is_leaf)

    # stack the arrays back into a feature matrix
    constrained_x = np.hstack(constrained_x)

    # indices are referenced to the sub-matrices, get the absolute index in the feature matrix
    # calculate the shifts for each component
    shifts = list(
        np.cumsum(
            [0]
            + [
                sub_x.shape[1]
                for sub_x in jax.tree_util.tree_leaves(split_by_input_x)[:-1]
            ]
        )
    )
    kept_columns = jax.tree_util.tree_map(lambda x, y: x + y, kept_columns, shifts)
    kept_columns = np.hstack(kept_columns)
    return constrained_x, kept_columns