"""GLM for Classification."""
# required to get ArrayLike to render correctly
from __future__ import annotations
from numbers import Number
from typing import Any, Callable, Literal, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from numpy.typing import ArrayLike, NDArray
from .. import observation_models as obs
from .. import tree_utils
from ..label_encoder import LabelEncoder
from ..regularizer import ElasticNet, GroupLasso, Lasso, Regularizer, Ridge
from ..type_casting import is_numpy_array_like, support_pynapple
from ..typing import (
DESIGN_INPUT_TYPE,
SolverState,
StepResult,
UserProvidedParamsT,
)
from .glm import GLM, PopulationGLM
from .params import GLMUserParams
from .validation import (
ClassifierGLMValidator,
PopulationClassifierGLMValidator,
)
__all__ = ["ClassifierGLM", "ClassifierPopulationGLM"]
class ClassifierMixin:
"""GLM for classification."""
# observation model inferred
_invalid_observation_types = ()
def set_classes(self, y: ArrayLike) -> ClassifierMixin:
"""
Infer unique class labels and set the ``classes_`` attribute.
This method infers class labels from ``y`` and sets up the internal
encoding/decoding machinery. When labels are the default ``[0, 1, ..., n_classes-1]``,
encoding is skipped for performance.
Parameters
----------
y
An array that must contain all the class labels,
i.e. ``len(np.unique(y)) == n_classes``.
Raises
------
ValueError
If the number of unique class labels in ``y`` does not match ``n_classes``.
Notes
-----
:meth:`fit` and :meth:`initialize_optimizer_and_state` call ``set_classes`` internally,
making sure that the ``classes_`` attribute matches the provided input.
If you are fitting in batches by calling :meth:`update`, make sure that the ``classes_``
are correctly set by calling ``set_classes`` before starting the :meth:`update` loop.
Examples
--------
When fitting in batches with :meth:`update`, use ``set_classes`` to define
all class labels before initialization. This is necessary when individual
batches may not contain all classes.
>>> import nemos as nmo
>>> import numpy as np
>>> model = nmo.glm.ClassifierGLM(3)
Generate sample data where the first batch only contains 2 of 3 classes:
>>> X = np.random.randn(100, 5)
>>> y_all_classes = np.array([0, 1, 2]) # all possible classes
>>> y_batch1 = np.array([0, 1, 0, 1, 0]) # first batch missing class 2
>>> X_batch1 = X[:5]
Without ``set_classes``, initialization fails if batch lacks all classes:
>>> init_params = model.initialize_params(X_batch1, y_batch1)
Traceback (most recent call last):
RuntimeError: Classes are not set. Must call ``set_classes`` before calling...
Call ``set_classes`` first to define all labels, then initialize:
>>> model.set_classes(y_all_classes)
ClassifierGLM(...)
>>> init_params = model.initialize_params(X_batch1, y_batch1)
>>> state = model.initialize_optimizer_and_state(init_params, X_batch1, y_batch1)
Now batches with any subset of classes work with :meth:`update`:
>>> result = model.update(init_params, state, X_batch1, y_batch1)
"""
self._label_encoder.set_classes(y)
return self
@property
def classes_(self) -> NDArray | None:
"""Class labels, or None if not set."""
return self._label_encoder.classes_
@classes_.setter
def classes_(self, value: NDArray | None) -> None:
if value is not None:
self._label_encoder.set_classes(value)
else:
self._label_encoder.reset()
def compute_loss(
self,
params,
X,
y,
*args,
**kwargs,
):
"""
Compute the loss function for the model.
This method validates inputs, encodes class labels to internal indices,
and computes the loss (negative log-likelihood).
Parameters
----------
params
Parameter tuple of (coefficients, intercept).
X
Input data, array of shape ``(n_time_bins, n_features)`` or pytree of same.
y
Target class labels in the same format as ``classes_``.
*args
Additional positional arguments passed to the model-specific loss function.
**kwargs
Additional keyword arguments passed to the model-specific loss function.
Returns
-------
loss
The loss value (negative log-likelihood).
Raises
------
RuntimeError
If ``classes_`` has not been set.
ValueError
If inputs or parameters have incompatible shapes or invalid values.
"""
self._label_encoder.check_classes_is_set("compute_loss")
y = self._label_encoder.encode(y)
return super().compute_loss(params, X, y, *args, **kwargs)
@property
def n_classes(self):
"""Number of classes."""
return self._label_encoder.n_classes
@n_classes.setter
def n_classes(self, value: int):
# extract item from scalar arrays
if is_numpy_array_like(value)[1] and value.size == 1:
value = value.item()
if not isinstance(value, Number) or value < 2 or not int(value) == value:
raise ValueError(
"The number of classes must be an integer greater than or equal to 2."
)
self._label_encoder = LabelEncoder(int(value))
# reset validator.
self._validator = self._validator_class(
extra_params=self._get_validator_extra_params()
)
def _get_validator_extra_params(self) -> dict:
"""Get validator extra parameters."""
return {"n_classes": self._label_encoder.n_classes}
def _preprocess_inputs(
self,
X: DESIGN_INPUT_TYPE,
y: Optional[jnp.ndarray] = None,
drop_nans: bool = True,
) -> Tuple[dict[str, jnp.ndarray] | jnp.ndarray, jnp.ndarray | None]:
"""Preprocess inputs before initializing state."""
X, y = super()._preprocess_inputs(X, y=y, drop_nans=drop_nans)
if y is not None:
y = self._validator.check_and_cast_y_to_integer(y)
y = jax.nn.one_hot(y, self._label_encoder.n_classes)
return X, y
# Note: necessary double decorator. The super().predict is decorated as well,
# but the pynapple metadata would be dropped if we do not decorate here.
# This happens because super().predict returns the log-proba which have the same
# shape of one_hot(y), not matching the original y.shape.
@support_pynapple(conv_type="jax")
def predict(self, X: DESIGN_INPUT_TYPE) -> jnp.ndarray:
"""
Predict class labels for samples in X.
Parameters
----------
X :
The input samples. Can be an array of shape ``(n_samples, n_features)``
or a pytree of arrays of the same shape.
Returns
-------
:
Predicted class labels for each sample.
Returns an integer array of shape ``(n_samples, )`` with values in
``[0, n_classes - 1]``.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> predictions = model.predict(X)
>>> predictions.shape
(4,)
"""
# Below will raise if user set manually coef and intercept
# and calls predict.
# One could assume default labels 0,...,n-1
# but requiring to be explicit is safer
self._label_encoder.check_classes_is_set("predict")
log_proba = super().predict(X)
return self._label_encoder.decode(jnp.argmax(log_proba, axis=-1))
def predict_proba(
self,
X: DESIGN_INPUT_TYPE,
return_type: Literal["log-proba", "proba"] = "log-proba",
) -> jnp.ndarray:
"""
Predict class probabilities for samples in X.
Parameters
----------
X :
The input samples. Can be an array of shape ``(n_samples, n_features)``
or a pytree of arrays of the same shape.
return_type :
The format of the returned probabilities. If ``"log-proba"``, returns
log-probabilities. If ``"proba"``, returns probabilities. Defaults to
``"log-proba"``.
Returns
-------
:
Predicted class probabilities. Returns an array of shape ``(n_samples, n_classes)``
where each row sums to 1 (for probabilities) or to 0 in log-space (for log-probabilities).
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)
"""
# Below will raise if user set manually coef and intercept
# and calls predict without setting the class label mapping.
# One could assume default labels 0,...,n-1
# but requiring to be explicit makes the mapping between
# the class labels and the probability index less ambiguous:
# `log_proba[:, i]` is the log-proba of class `self.classes_[i]`
self._label_encoder.check_classes_is_set("predict_proba")
# log-proba for categorical, proba for Bernoulli
log_proba = super().predict(X)
if return_type == "log-proba":
return log_proba
elif return_type == "proba":
exp = support_pynapple(conv_type="jax")(jnp.exp)
proba = exp(log_proba)
# renormalize (sum to 1 constraint)
proba /= proba.sum(axis=-1, keepdims=True)
return proba
else:
raise ValueError(f"Unrecognized return type ``'{return_type}'``")
def _estimate_resid_degrees_of_freedom(
self, X: DESIGN_INPUT_TYPE, n_samples: Optional[int] = None
) -> jnp.ndarray:
"""
Estimate the degrees of freedom of the residuals for classifier GLM.
Parameters
----------
X :
The design matrix.
n_samples :
The number of samples observed. If not provided, n_samples is set to
``X.shape[0]``. If the fit is batched, n_samples could be larger than
``X.shape[0]``.
Returns
-------
:
An estimate of the degrees of freedom of the residuals.
"""
# Convert a pytree to a design-matrix
x_leaf = jax.tree_util.tree_leaves(X)
if n_samples is None:
n_samples = x_leaf[0].shape[0]
else:
if not isinstance(n_samples, int):
raise TypeError(
f"`n_samples` must be `None` or of type `int`. "
f"Type {type(n_samples)} provided instead!"
)
n_features = sum(x.shape[1] for x in x_leaf)
# Effective degrees of freedom is n_classes - 1 due to probability simplex constraint
n_m1_classes = self._label_encoder.n_classes - 1
params = self._get_model_params()
# Infer n_neurons from coef shape:
# ClassifierGLM: coef is (n_features, n_classes) -> n_neurons = 1
# ClassifierPopulationGLM: coef is (n_features, n_neurons, n_classes) -> n_neurons = shape[1]
coef_leaf = jax.tree_util.tree_leaves(params.coef)[0]
n_neurons = 1 if coef_leaf.ndim == 2 else coef_leaf.shape[1]
# For Lasso-type regularizers, use the non-zero coefficients as DOF estimate
# see https://arxiv.org/abs/0712.0881
if isinstance(self.regularizer, (GroupLasso, Lasso, ElasticNet)):
# Sum over features (axis 0) and classes (axis -1)
# This leaves shape (n_neurons,) for ClassifierPopulationGLM
# or scalar for ClassifierGLM
resid_dof = tree_utils.pytree_map_and_reduce(
lambda x: ~jnp.isclose(x, jnp.zeros_like(x)),
lambda x: sum([jnp.sum(i, axis=(0, -1)) for i in x]),
params.coef,
)
return jnp.atleast_1d(n_samples - resid_dof - n_m1_classes)
elif isinstance(self.regularizer, Ridge):
# For Ridge, use total parameters
return (n_samples - n_m1_classes * n_features - n_m1_classes) * jnp.ones(
n_neurons
)
else:
# For UnRegularized, use the rank
rank = jnp.linalg.matrix_rank(jnp.concatenate(x_leaf, axis=1))
return (n_samples - rank * n_m1_classes - n_m1_classes) * jnp.ones(
n_neurons
)
def simulate(
self,
random_key: jax.Array,
feedforward_input: DESIGN_INPUT_TYPE,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Simulate categorical responses from the model.
Parameters
----------
random_key :
A JAX random key used to generate the simulated responses.
feedforward_input :
The input samples used to generate the responses. Can be an array of
shape ``(n_samples, n_features)`` or a pytree of arrays of the same
shape.
Returns
-------
:
A tuple ``(y, log_prob)`` where:
- ``y`` is an array of shape ``(n_samples,)`` containing the
simulated class labels (in the same format as ``classes_``).
- ``log_prob`` is an array of shape ``(n_samples,)`` containing the
log-probability of the simulated responses under the model.
Raises
------
RuntimeError
If ``classes_`` has not been set. Call :meth:`set_classes` or :meth:`fit`
before calling this method.
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> key = jax.random.key(0)
>>> simulated_y, log_prob = model.simulate(key, X)
>>> simulated_y.shape
(4,)
"""
self._label_encoder.check_classes_is_set("simulate")
y, log_prob = super().simulate(random_key, feedforward_input)
argmax = support_pynapple(conv_type="jax")(lambda x: jnp.argmax(x, axis=-1))
y = self._label_encoder.decode(argmax(y))
return y, log_prob
def initialize_optimizer_and_state(
self,
init_params: UserProvidedParamsT,
X: DESIGN_INPUT_TYPE,
y: jnp.ndarray,
) -> SolverState:
"""Initialize the solver and its state for running fit and update.
This method must be called before using :meth:`update` for iterative optimization.
It sets up the solver with the provided initial parameters and data.
Parameters
----------
init_params
Initial parameter tuple of (coefficients, intercept).
X
Input data, array of shape ``(n_time_bins, n_features)`` or pytree of same.
y
Target labels, array of shape ``(n_time_bins,)`` for single neuron/subject models or
``(n_time_bins, n_neurons)`` for population models.
Returns
-------
state
Initial solver state.
Raises
------
ValueError
If inputs or parameters have incompatible shapes or invalid values.
"""
self._label_encoder.check_classes_is_set("initialize_optimizer_and_state")
y = self._label_encoder.encode(y)
return super().initialize_optimizer_and_state(init_params, X, y)
def initialize_params(
self,
X: DESIGN_INPUT_TYPE,
y: jnp.ndarray,
) -> UserProvidedParamsT:
"""
Initialize model parameters for categorical GLM.
Initialize coefficients with zeros and intercept by matching the mean class
proportions. Class labels are automatically converted to one-hot encoding.
Parameters
----------
X :
Input data, array of shape ``(n_time_bins, n_features)`` or pytree of same.
y :
Class labels, array of shape ``(n_time_bins,)`` for single neuron
models or ``(n_time_bins, n_neurons)`` for population models. Labels
must be a subset of ``classes_``.
Returns
-------
:
Initial parameter tuple of (coefficients, intercept).
Notes
-----
All labels in ``y`` must be present in ``classes_``. Passing labels not
in ``classes_`` will raise an error.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> coef, intercept = model.initialize_params(X, y)
>>> coef.shape
(2, 2)
"""
self._label_encoder.check_classes_is_set("initialize_params")
y = self._label_encoder.encode(y)
y = self._validator.check_and_cast_y_to_integer(y)
y = jax.nn.one_hot(y, self.n_classes)
return super().initialize_params(X, y)
def update(
self,
params: GLMUserParams,
opt_state: SolverState,
X: DESIGN_INPUT_TYPE,
y: jnp.ndarray,
*args,
n_samples: Optional[int] = None,
**kwargs,
) -> StepResult:
"""
Update the model parameters and solver state.
Performs a single optimization step using the model's solver. Class labels
are automatically encoded to internal indices and converted to one-hot
encoding before the update.
**Important**: Labels of any dtype (integers, floats, strings, etc.) are
supported and will be encoded using the ``classes_`` attribute set via
:meth:`set_classes`. For best performance, use integer labels ``[0, n_classes - 1]``.
Parameters
----------
params :
The current model parameters, typically a tuple of coefficients and intercepts.
opt_state :
The current state of the optimizer, encapsulating information necessary for the
optimization algorithm to continue from the current state.
X :
The predictors used in the model fitting process. Shape ``(n_time_bins, n_features)``
or a pytree of arrays of the same shape.
y :
Class labels, array of shape ``(n_time_bins,)`` for single neuron
models or ``(n_time_bins, n_neurons)`` for population models. Labels must
match those defined in ``classes_``.
*args :
Additional positional arguments to be passed to the solver's update method.
n_samples :
The total number of samples. Usually larger than the samples of an individual batch,
used to estimate the scale parameter of the GLM.
**kwargs :
Additional keyword arguments to be passed to the solver's update method.
Returns
-------
params :
Updated model parameters (coefficients, intercepts).
state :
Updated optimizer state.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model.set_classes(y)
ClassifierGLM(...)
>>> params = model.initialize_params(X, y)
>>> opt_state = model.initialize_optimizer_and_state(params, X, y)
>>> new_params, new_state = model.update(params, opt_state, X, y)
"""
self._label_encoder.check_classes_is_set("update")
# note: do not check and cast here. Risky but the performance of
# the update has priority.
y = self._label_encoder.encode(y, safe=False)
y = jax.nn.one_hot(y, self.n_classes)
return super().update(
params, opt_state, X, y, *args, n_samples=n_samples, **kwargs
)
[docs]
class ClassifierGLM(ClassifierMixin, GLM):
"""
Generalized Linear Model for multi-class classification.
This model predicts discrete class labels from input features using a
softmax (multinomial logistic) model. It uses an over-parameterized
representation with one set of coefficients per class, resulting in
coefficient shape ``(n_features, n_classes)`` and intercept shape ``(n_classes,)``.
Parameters
----------
n_classes
The number of classes. Must be >= 2.
inverse_link_function
The inverse link function. Default is ``log_softmax``.
regularizer
The regularization scheme. Default is ``Ridge``. Note that the
model is over-parameterized: one set of coefficients for each class.
Regularization makes the parameters identifiable. Setting ``UnRegularized``
will result in non-identifiable coefficients, see note below.
regularizer_strength
The strength of the regularization.
solver_name
The solver to use for optimization.
solver_kwargs
Additional keyword arguments for the solver.
Attributes
----------
coef_
Fitted coefficients of shape ``(n_features, n_classes)`` after calling :meth:`fit`.
intercept_
Fitted intercepts of shape ``(n_classes,)`` after calling :meth:`fit`.
Notes
-----
**Identifiability**
This model uses an over-parameterized (symmetric) representation where each class
has its own set of coefficients. Since probabilities from softmax are invariant to
adding a constant to all linear predictors, the parameters are not uniquely
identifiable without regularization. For example, if ``(coef, intercept)`` is a
solution, so is ``(coef + c, intercept + c)`` for any constant ``c``.
Using regularization (default is ``Ridge``) resolves this ambiguity by penalizing
the parameter magnitudes, effectively centering the solution. If you use
``UnRegularized``, the optimization may converge to different equivalent solutions
depending on initialization, though predictions will be identical.
**Class Labels**
The target array ``y`` can contain any hashable class labels that can be stored
in a NumPy array, including integers, strings, or other hashable types. The model
internally maps these labels to indices ``[0, n_classes - 1]`` for computation
and maps them back when returning predictions.
**Performance Considerations**
For optimal performance, use integer labels ``[0, 1, ..., n_classes - 1]``. When
labels follow this convention, the model skips the encoding/decoding steps entirely.
Using other label formats (e.g., ``["cat", "dog"]`` or ``[5, 10, 15]``) incurs a
small overhead for label translation.
**Setting Class Labels**
The :meth:`fit` and :meth:`initialize_optimizer_and_state` methods automatically infer
class labels from the provided ``y``. If you set ``coef_`` and ``intercept_`` manually,
you must call :meth:`set_classes` before using :meth:`predict`, :meth:`predict_proba`,
:meth:`simulate`, :meth:`score`, or :meth:`compute_loss`.
See Also
--------
ClassifierPopulationGLM : Multi-class classification for multiple neurons.
GLM : Generalized Linear Model for continuous/count responses.
Examples
--------
**Fit a ClassifierGLM**
Basic binary classification:
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> model.coef_.shape
(2, 2)
**Predict Class Labels**
Get predicted class labels:
>>> predictions = model.predict(X)
>>> predictions.shape
(4,)
**Predict Class Probabilities**
Get class probabilities or log-probabilities:
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(4, 2)
>>> log_proba = model.predict_proba(X, return_type="log-proba")
>>> log_proba.shape
(4, 2)
**Use String Labels**
Class labels can be strings or any hashable type:
>>> y_str = np.array(["cat", "cat", "dog", "dog"])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y_str)
>>> model.classes_
array(['cat', 'dog'], dtype='<U3')
>>> model.predict(X) # doctest: +NORMALIZE_WHITESPACE
array(['cat', 'cat', 'dog', 'dog'], dtype='<U3')
**Multi-class Classification**
Classify into more than two classes:
>>> X = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], [6.0, 7.0]])
>>> y = jnp.array([0, 0, 1, 1, 2, 2])
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X, y)
>>> model.coef_.shape
(2, 3)
**Use Regularization**
Change regularization strength:
>>> model = nmo.glm.ClassifierGLM(
... n_classes=2,
... regularizer="Ridge",
... regularizer_strength=0.5
... )
>>> model.regularizer
Ridge()
**Use a Pytree of arrays as Input**
Features can be passed as any JAX pytree of 2-D arrays; the fitted
``coef_`` will share the same pytree structure:
>>> X_dict = {"feature_1": X[:, :1], "feature_2": X[:, 1:]}
>>> model = nmo.glm.ClassifierGLM(n_classes=3).fit(X_dict, y)
>>> # The coefficient structure matches the input
>>> type(model.coef_)
<class 'dict'>
"""
_validator_class = ClassifierGLMValidator
[docs]
def __init__(
self,
n_classes: Optional[int] = 2,
inverse_link_function: Optional[Callable] = None,
regularizer: Optional[Union[str, Regularizer]] = None,
regularizer_strength: Any = None,
solver_name: str = None,
solver_kwargs: dict = None,
):
self.n_classes = n_classes
observation_model = obs.CategoricalObservations()
if regularizer is None:
regularizer = "Ridge"
super().__init__(
observation_model=observation_model,
inverse_link_function=inverse_link_function,
regularizer=regularizer,
regularizer_strength=regularizer_strength,
solver_name=solver_name,
solver_kwargs=solver_kwargs,
)
[docs]
def fit(
self,
X: Union[DESIGN_INPUT_TYPE, ArrayLike],
y: ArrayLike,
init_params: Optional[GLMUserParams] = None,
):
"""
Fit the model to training data.
Parameters
----------
X
Training input samples of shape ``(n_samples, n_features)`` or a pytree of arrays of the same shape.
y
Target class labels of shape ``(n_samples,)``. Labels can be any hashable
type (integers, strings, etc.). Float arrays with integer values are
accepted and converted automatically.
init_params
Initial parameter values as tuple of ``(coef, intercept)``. If None,
parameters are initialized automatically.
Returns
-------
:
The fitted model.
Notes
-----
``fit`` calls :meth:`set_classes` internally, so ``classes_`` is always
consistent with the labels in ``y``.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2)
>>> model = model.fit(X, y)
>>> model.coef_.shape
(2, 2)
"""
self.set_classes(y)
y = self._label_encoder.encode(y)
return super().fit(X, y, init_params)
[docs]
def score(
self,
X: Union[DESIGN_INPUT_TYPE, ArrayLike],
y: ArrayLike,
score_type: Literal[
"log-likelihood", "pseudo-r2-McFadden", "pseudo-r2-Cohen"
] = "log-likelihood",
aggregate_sample_scores: Optional[Callable] = jnp.mean,
) -> jnp.ndarray:
"""
Score the model on test data.
Parameters
----------
X
Test input samples of shape ``(n_samples, n_features)`` or a pytree of arrays of the same shape.
y
True class labels of shape ``(n_samples,)``. Labels must be a subset
of ``classes_``.
score_type
The type of score to compute.
aggregate_sample_scores
Function to aggregate per-sample scores.
Returns
-------
:
The computed score.
Notes
-----
All labels in ``y`` must be present in ``classes_``. Passing labels not
in ``classes_`` will raise an error.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
>>> y = jnp.array([0, 0, 1, 1])
>>> model = nmo.glm.ClassifierGLM(n_classes=2).fit(X, y)
>>> score = model.score(X, y)
"""
# check if classes are not set, aka user set the coef and intercept
# manually, raise otherwise there may be ambiguity in interpreting
# the labels.
self._label_encoder.check_classes_is_set("score")
y = self._label_encoder.encode(y)
return super().score(X, y, score_type, aggregate_sample_scores)
[docs]
class ClassifierPopulationGLM(ClassifierMixin, PopulationGLM):
"""
Population Generalized Linear Model for multi-class classification.
This model predicts discrete class labels from input features using a
softmax (multinomial logistic) model for multiple neurons simultaneously.
It uses an over-parameterized representation with one set of coefficients
per class, resulting in coefficient shape ``(n_features, n_neurons, n_classes)``
and intercept shape ``(n_neurons, n_classes)``.
Parameters
----------
n_classes
The number of classes. Must be >= 2.
inverse_link_function
The inverse link function. Default is ``log_softmax``.
regularizer
The regularization scheme. Default is ``Ridge``. Note that the
model is over-parameterized: one set of coefficients for each class.
Regularization makes the parameters identifiable. Setting ``UnRegularized``
will result in non-identifiable coefficients, see note below.
regularizer_strength
The strength of the regularization.
solver_name
The solver to use for optimization.
solver_kwargs
Additional keyword arguments for the solver.
feature_mask
Mask indicating which features are used for each neuron.
Attributes
----------
coef_
Fitted coefficients of shape ``(n_features, n_neurons, n_classes)``
after calling :meth:`fit`.
intercept_
Fitted intercepts of shape ``(n_neurons, n_classes)`` after calling :meth:`fit`.
Notes
-----
**Identifiability**
This model uses an over-parameterized (symmetric) representation where each class
has its own set of coefficients. Since probabilities from softmax are invariant to
adding a constant to all linear predictors, the parameters are not uniquely
identifiable without regularization. For example, if ``(coef, intercept)`` is a
solution, so is ``(coef + c, intercept + c)`` for any constant ``c``.
Using regularization (default is ``Ridge``) resolves this ambiguity by penalizing
the parameter magnitudes, effectively centering the solution. If you use
``UnRegularized``, the optimization may converge to different equivalent solutions
depending on initialization, though predictions will be identical.
**Class Labels**
The target array ``y`` can contain any hashable class labels that can be stored
in a NumPy array, including integers, strings, or other hashable types. The model
internally maps these labels to indices ``[0, n_classes - 1]`` for computation
and maps them back when returning predictions.
**Performance Considerations**
For optimal performance, use integer labels ``[0, 1, ..., n_classes - 1]``. When
labels follow this convention, the model skips the encoding/decoding steps entirely.
Using other label formats (e.g., ``["cat", "dog"]`` or ``[5, 10, 15]``) incurs a
small overhead for label translation.
**Setting Class Labels**
The :meth:`fit` and :meth:`initialize_optimizer_and_state` methods automatically infer
class labels from the provided ``y``. If you set ``coef_`` and ``intercept_`` manually,
you must call :meth:`set_classes` before using :meth:`predict`, :meth:`predict_proba`,
:meth:`simulate`, :meth:`score`, or :meth:`compute_loss`.
See Also
--------
ClassifierGLM : Multi-class classification for a single neuron.
PopulationGLM : Population GLM for continuous/count responses.
Examples
--------
**Fit a ClassifierPopulationGLM**
Basic multi-class classification for multi-subjects
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import nemos as nmo
>>> X = jnp.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.], [6., 7.]])
>>> y = jnp.array([[0, 0], [0, 1], [1, 0], [1, 2], [2, 1], [2, 2]])
>>> model = nmo.glm.ClassifierPopulationGLM(n_classes=3).fit(X, y)
>>> model.coef_.shape
(2, 2, 3)
**Predict Class Labels**
Get predicted class labels for each subject:
>>> predictions = model.predict(X)
>>> predictions.shape
(6, 2)
**Predict Class Probabilities**
Get class probabilities for each subject:
>>> proba = model.predict_proba(X, return_type="proba")
>>> proba.shape
(6, 2, 3)
**Use String Labels**
Class labels can be strings or any hashable type:
>>> y_str = np.array([["a", "a"], ["a", "b"], ["b", "a"], ["b", "c"], ["c", "b"], ["c", "c"]])
>>> model = nmo.glm.ClassifierPopulationGLM(n_classes=3).fit(X, y_str)
>>> model.classes_
array(['a', 'b', 'c'], dtype='<U1')
>>> model.predict(X).shape
(6, 2)
**Use a Feature Mask**
Specify which features predict each neuron:
>>> feature_mask = jnp.array([[[1, 1, 1], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]]])
>>> y = jnp.array([[0, 0], [0, 1], [1, 0], [1, 2], [2, 1], [2, 2]])
>>> model = nmo.glm.ClassifierPopulationGLM(
... n_classes=3,
... feature_mask=feature_mask
... ).fit(X, y)
>>> model.coef_
Array(...)
**Use Regularization**
Change regularization strength:
>>> model = nmo.glm.ClassifierPopulationGLM(
... n_classes=3,
... regularizer="Ridge",
... regularizer_strength=0.5
... )
>>> model.regularizer
Ridge()
**Use a Pytree of arrays as Input**
Features can be passed as any JAX pytree of 2-D arrays; the fitted
``coef_`` will share the same pytree structure:
>>> X_dict = {"feature_1": X[:, :1], "feature_2": X[:, 1:]}
>>> model = nmo.glm.ClassifierPopulationGLM(n_classes=3).fit(X_dict, y)
>>> type(model.coef_)
<class 'dict'>
"""
_validator_class = PopulationClassifierGLMValidator
[docs]
def __init__(
self,
n_classes: Optional[int] = 2,
inverse_link_function: Optional[Callable] = None,
regularizer: Optional[Union[str, Regularizer]] = None,
regularizer_strength: Any = None,
solver_name: str = None,
solver_kwargs: dict = None,
feature_mask: Optional[jnp.ndarray] = None,
):
self.n_classes = n_classes
observation_model = obs.CategoricalObservations()
if regularizer is None:
regularizer = "Ridge"
super().__init__(
observation_model=observation_model,
inverse_link_function=inverse_link_function,
regularizer=regularizer,
regularizer_strength=regularizer_strength,
solver_name=solver_name,
solver_kwargs=solver_kwargs,
feature_mask=feature_mask,
)
@property
def feature_mask(self) -> Union[jnp.ndarray, dict[str, jnp.ndarray]]:
"""
Mask indicating which weights are used, matching the coefficients shape.
The feature mask has the same structure and shape as the coefficients (``coef_``):
- **Array input**: Shape ``(n_features, n_neurons, n_classes)``.
Each entry ``[i, j, k]`` indicates whether the weight for feature ``i``,
neuron ``j``, and category ``k`` is used (1 = used, 0 = masked).
- **Pytree input**: A pytree matching the ``coef_`` structure.
Each leaf array has the same shape as the corresponding coefficient leaf
``(n_features_per_key, n_neurons, n_classes)``.
Returns
-------
:
The feature mask, or None if not set.
"""
return self._feature_mask
@feature_mask.setter
def feature_mask(self, feature_mask: Union[DESIGN_INPUT_TYPE, dict]):
# do not allow reassignment after fit
if (self.coef_ is not None) and (self.intercept_ is not None):
raise AttributeError(
"property 'feature_mask' of 'populationGLM' cannot be set after fitting."
)
self._feature_mask = self._validator.validate_and_cast_feature_mask(
feature_mask
)
[docs]
def fit(
self,
X: Union[DESIGN_INPUT_TYPE, ArrayLike],
y: ArrayLike,
init_params: Optional[GLMUserParams] = None,
):
"""
Fit the model to training data.
Parameters
----------
X
Training input samples of shape ``(n_samples, n_features)`` or a pytree of arrays of the same shape.
y
Target class labels of shape ``(n_samples, n_neurons)``. Labels can be
any hashable type (integers, strings, etc.). Float arrays with integer
values are accepted and converted automatically.
init_params
Initial parameter values as tuple of ``(coef, intercept)``. If None,
parameters are initialized automatically.
Returns
-------
:
The fitted model.
Notes
-----
``fit`` calls :meth:`set_classes` internally, so ``classes_`` is always
consistent with the labels in ``y``.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.], [6., 7.]])
>>> y = jnp.array([[0, 0], [0, 1], [1, 0], [1, 2], [2, 1], [2, 2]])
>>> model = nmo.glm.ClassifierPopulationGLM(n_classes=3)
>>> model = model.fit(X, y)
>>> model.coef_.shape
(2, 2, 3)
"""
self.set_classes(y)
y = self._label_encoder.encode(y)
return super().fit(X, y, init_params)
[docs]
def score(
self,
X: Union[DESIGN_INPUT_TYPE, ArrayLike],
y: ArrayLike,
score_type: Literal[
"log-likelihood", "pseudo-r2-McFadden", "pseudo-r2-Cohen"
] = "log-likelihood",
aggregate_sample_scores: Optional[Callable] = jnp.mean,
) -> jnp.ndarray:
"""
Score the model on test data.
Parameters
----------
X
Test input samples of shape ``(n_samples, n_features)`` or a pytree of arrays of the same shape.
y
True class labels of shape ``(n_samples, n_neurons)``. Labels must
be a subset of ``classes_``.
score_type
The type of score to compute.
aggregate_sample_scores
Function to aggregate per-sample scores.
Returns
-------
:
The computed score.
Notes
-----
All labels in ``y`` must be present in ``classes_``. Passing labels not
in ``classes_`` will raise an error.
Examples
--------
>>> import jax.numpy as jnp
>>> import nemos as nmo
>>> X = jnp.array([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.], [6., 7.]])
>>> y = jnp.array([[0, 0], [0, 1], [1, 0], [1, 2], [2, 1], [2, 2]])
>>> model = nmo.glm.ClassifierPopulationGLM(n_classes=3).fit(X, y)
>>> score = model.score(X, y)
"""
self._label_encoder.check_classes_is_set("score")
y = self._label_encoder.encode(y)
return super().score(X, y, score_type, aggregate_sample_scores)