Source code for nemos.glm_hmm.glm_hmm

"""API for the GLM-HMM model."""

from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any, Callable, Literal, NamedTuple, Optional, Tuple, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import pynapple as nap
from numpy.typing import ArrayLike, NDArray

from .. import observation_models as obs
from .._observation_model_builder import instantiate_observation_model
from ..hmm.expectation_maximization import EMState, em_hmm, em_step
from ..hmm.hmm import BaseHMM
from ..hmm.initialize_parameters import HMM_INITIALIZATION_FN_DICT, InitFunctionHMM
from ..hmm.utils import _check_state_format
from ..inverse_link_function_utils import resolve_inverse_link_function
from ..observation_models import Observations
from ..regularizer import GroupLasso, Lasso, Regularizer, Ridge
from ..tree_utils import pytree_map_and_reduce
from ..typing import (
    DESIGN_INPUT_TYPE,
    ModelParamsT,
    SolverState,
    StepResult,
)
from ..utils import format_repr
from .algorithm_configs import prepare_estep_log_likelihood, prepare_mstep_update_fn
from .initialize_parameters import (
    DEFAULT_INIT_FUNCTIONS_GLMHMM,
    GLMHMM_INITIALIZATION_FN_DICT,
    InitFunctionGLM,
    KMeansInitializerGLM,
    generate_glm_hmm_initial_model_params,
    kmeans_glm_params_init,
    kmeans_scale_init,
    setup_glm_hmm_initialization,
)
from .params import GLMHMMParams, GLMHMMUserParams
from .utils import compute_rate_per_state
from .validation import GLMHMMValidator


[docs] class GLMHMM( BaseHMM[ GLMHMMUserParams, GLMHMMParams, GLMHMM_INITIALIZATION_FN_DICT, GLMHMMValidator ] ): r"""Generalized Linear Model with Hidden Markov Model (GLM-HMM). This model combines a Generalized Linear Model (GLM) with a Hidden Markov Model (HMM) to capture state-dependent relationships between predictors and neural or behavioral responses. The GLM-HMM is suitable for modeling time series data where the relationship between inputs and outputs varies according to an underlying latent state that evolves over time following Markovian dynamics. The model assumes that at each time step, the system is in one of ``n_states`` discrete hidden states. Each state has its own GLM parameters (coefficients and intercept), and transitions between states are governed by a transition probability matrix. The model is fitted using the Expectation-Maximization (EM) algorithm. Below is a table of the default inverse link function for the available observation models. +---------------------+---------------------------------+ | Observation Model | Default Inverse Link Function | +=====================+=================================+ | Poisson | :math:`e^x` | +---------------------+---------------------------------+ | Gamma | :math:`1/x` | +---------------------+---------------------------------+ | Bernoulli | :math:`1 / (1 + e^{-x})` | +---------------------+---------------------------------+ | NegativeBinomial | :math:`e^x` | +---------------------+---------------------------------+ | Gaussian | :math:`x` | +---------------------+---------------------------------+ Below is a table listing the default and available solvers for each regularizer. +---------------+------------------+-------------------------------------------------------------+ | Regularizer | Default Solver | Available Solvers | +===============+==================+=============================================================+ | UnRegularized | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient | +---------------+------------------+-------------------------------------------------------------+ | Ridge | LBFGS | GradientDescent, BFGS, LBFGS, NonlinearCG, ProximalGradient | +---------------+------------------+-------------------------------------------------------------+ | Lasso | ProximalGradient | ProximalGradient | +---------------+------------------+-------------------------------------------------------------+ | ElasticNet | ProximalGradient | ProximalGradient | +---------------+------------------+-------------------------------------------------------------+ | GroupLasso | ProximalGradient | ProximalGradient | +---------------+------------------+-------------------------------------------------------------+ Parameters ---------- n_states : The number of hidden states in the HMM. Must be a positive integer. observation_model : Observation model to use. The model describes the distribution of the response variable. Default is the Bernoulli model. Alternatives are "Poisson", "Gamma", "NegativeBinomial", and "Gaussian". inverse_link_function : A function that maps the linear combination of predictors into a rate or probability. The default depends on the observation model, see the table above. regularizer : Regularization scheme used in the M-step for the per-state GLM coefficients. Default is ``"Ridge"``. Pass ``"UnRegularized"`` to disable regularization. regularizer_strength : Strength of the regularization applied to the GLM coefficients. Default is ``1.0``. Ignored when ``regularizer="UnRegularized"``. dirichlet_initial_proba : Alpha parameters for the Dirichlet prior over the initial state probabilities. Shape ``(n_states,)``. If None, a flat (uninformative) prior is assumed. dirichlet_transition_proba : Alpha parameters for the Dirichlet prior over the transition probabilities. Shape ``(n_states, n_states)``. If None, a flat (uninformative) prior is assumed. solver_name : Solver used for the GLM M-step. The solver must be valid for the chosen regularizer (see table above). Default is ``None``, in which case the regularizer's default solver is selected (``"LBFGS"`` for Ridge / UnRegularized, ``"ProximalGradient"`` for Lasso / ElasticNet / GroupLasso). solver_kwargs : Optional dictionary for keyword arguments that are passed to the solver when instantiated. E.g., stepsize, tol, acceleration, etc. maxiter : Maximum number of EM iterations. Default is 1000. tol : Convergence tolerance for the EM algorithm. The algorithm stops when the absolute change in log-likelihood between consecutive iterations falls below this threshold. Default is 1e-8. seed : JAX PRNG key for random number generation during initialization. Default is ``jax.random.PRNGKey(123)``. hmm_initialization_funcs : dict, optional Dictionary of initialization functions for HMM probabilities (initial and transition). Included for scikit-learn compatibility; prefer configuring via the :meth:`setup` method after construction. If ``None``, defaults from ``DEFAULT_INIT_FUNCTIONS`` are used. model_initialization_funcs : dict, optional Dictionary of initialization functions for the GLM-specific parameters (coefficients, intercept, and scale). Included for scikit-learn compatibility; prefer configuring via the :meth:`setup` method after construction. If ``None``, defaults from ``DEFAULT_INIT_FUNCTIONS_GLMHMM`` are used. Attributes ---------- transition_prob_ : Transition probability matrix of shape ``(n_states, n_states)``. Entry ``[i, j]`` represents the probability of transitioning from state ``i`` to state ``j``. initial_prob_ : Initial state probability vector of shape ``(n_states,)``. Entry ``[i]`` represents the probability of starting in state ``i``. coef_ : GLM coefficients for each state, shape ``(n_features, n_states)``. intercept_ : GLM intercepts (bias terms) for each state, shape ``(n_states,)``. solver_state_ : State of the solver after fitting. May include details like optimization error. scale_ : Scale parameter for the observation model, shape ``(n_states,)``. dof_resid_ : Degrees of freedom for the residuals. Notes ----- To bypass the initialization functions entirely and provide parameter arrays directly, pass them to the ``fit()`` method:: model.fit(X, y, init_params=my_params) Raises ------ TypeError If ``n_states`` is not a positive integer. TypeError If provided ``regularizer`` or ``observation_model`` are not valid. TypeError If ``seed`` is not a valid JAX PRNG key. KeyError If ``hmm_initialization_funcs`` or ``model_initialization_funcs`` contains keys that are not valid for their respective default dictionary. ValueError If any ``*_kwargs`` entry in either initialization-funcs dictionary contains keyword arguments that don't match the signature of the corresponding initialization function. ValueError If ``maxiter`` is not a positive integer. ValueError If ``tol`` is not a positive float. Examples -------- **Fit a GLM-HMM** Basic model fitting with the default Bernoulli observation model. The number of hidden states is the only required argument; ``coef_`` carries one column per state, and the HMM transition matrix and initial distribution are exposed as fitted attributes. >>> import jax >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(123) >>> X = np.random.normal(size=(200, 4)) >>> y = np.random.binomial(n=1, p=0.5, size=200) >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y) >>> model.coef_.shape (4, 2) >>> model.transition_prob_.shape (2, 2) >>> model.initial_prob_.shape (2,) **Customize the Observation Model** Specify the observation model as a string: >>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Poisson") >>> model.observation_model PoissonObservations() Or pass the observation model object directly: >>> model = nmo.glm_hmm.GLMHMM( ... n_states=2, observation_model=nmo.observation_models.PoissonObservations() ... ) >>> model.observation_model PoissonObservations() **Customize the Inverse Link Function** Use a soft-plus inverse link function instead of the observation-model default: >>> model = nmo.glm_hmm.GLMHMM(n_states=2, inverse_link_function=jax.nn.softplus) >>> model.inverse_link_function.__name__ 'softplus' **Change the Regularization** Regularization applies to the per-state GLM coefficients. The default is Ridge with strength ``1.0``. Tune the strength: >>> model = nmo.glm_hmm.GLMHMM(n_states=2, regularizer_strength=0.1).fit(X, y) >>> model.regularizer, float(model.regularizer_strength) (Ridge(), 0.1) Or switch to Lasso for sparse per-state coefficients (Lasso requires a proximal solver): >>> model = nmo.glm_hmm.GLMHMM( ... n_states=2, ... regularizer="Lasso", ... regularizer_strength=0.01, ... solver_name="ProximalGradient", ... ).fit(X, y) >>> model.regularizer Lasso() **Select a Solver** The solver is used for the M-step inside EM. Pick LBFGS for potentially faster convergence on smooth losses: >>> model = nmo.glm_hmm.GLMHMM(n_states=2, solver_name="LBFGS").fit(X, y) >>> model.solver_name 'LBFGS' **Fit Across Multiple Sessions** Mark session boundaries with ``session_starts`` so the HMM resets at each new session start instead of treating the data as a single chain. Pass either a boolean mask of shape ``(n_time_bins,)`` with ``True`` at each session start, or an integer array of session-start indices — the two are equivalent: >>> is_new_mask = np.zeros(200, dtype=bool) >>> is_new_mask[0] = True >>> is_new_mask[100] = True >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=is_new_mask) >>> # Equivalent: pass the starts as integer indices. >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=np.array([0, 100])) **Decode Hidden States** Recover the most-likely state sequence (Viterbi-style) or the smoothed posterior probabilities from the forward-backward pass: >>> states = model.decode_state(X, y, session_starts=is_new_mask) >>> states.shape (200, 2) >>> post = model.smooth_proba(X, y, session_starts=is_new_mask) >>> post.shape (200, 2) **Simulate from the Fitted Model** Sample a hidden-state trajectory and observations conditioned on inputs: >>> activity, rates, sim_states = model.simulate( ... jax.random.key(0), X, state_format="index" ... ) >>> activity.shape, sim_states.shape ((200,), (200,)) **Use a Dict of Arrays as Input** Features can be passed as any JAX pytree of 2-D arrays; the fitted ``coef_`` will share the same pytree structure, with the trailing axis indexing states: >>> X_dict = {"input_1": X[:, :2], "input_2": X[:, 2:]} >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X_dict, y) >>> type(model.coef_) <class 'dict'> """ _validator_class = GLMHMMValidator _model_default_init_dict = DEFAULT_INIT_FUNCTIONS_GLMHMM _kmeans_init_class = KMeansInitializerGLM
[docs] def __init__( self, n_states: int, observation_model: ( Observations | Literal["Poisson", "Gamma", "Bernoulli", "NegativeBinomial", "Gaussian"] ) = "Bernoulli", inverse_link_function: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, regularizer: Union[str, Regularizer] = "Ridge", regularizer_strength: Any = 1.0, # this is used to regularize GLM coef. # prior to regularize init prob and transition dirichlet_initial_proba: Union[jnp.ndarray, None] = None, # (n_state, ) dirichlet_transition_proba: Union[ jnp.ndarray | None ] = None, # (n_state, n_state) solver_name: str = None, solver_kwargs: Optional[dict] = None, maxiter: int = 1000, tol: float = 1e-8, seed=jax.random.PRNGKey(123), hmm_initialization_funcs: Optional[HMM_INITIALIZATION_FN_DICT] = None, model_initialization_funcs: Optional[GLMHMM_INITIALIZATION_FN_DICT] = None, ): super().__init__( n_states=n_states, dirichlet_initial_proba=dirichlet_initial_proba, dirichlet_transition_proba=dirichlet_transition_proba, regularizer=regularizer, regularizer_strength=regularizer_strength, solver_name=solver_name, solver_kwargs=solver_kwargs, maxiter=maxiter, tol=tol, seed=seed, hmm_initialization_funcs=hmm_initialization_funcs, ) self.observation_model = observation_model self.inverse_link_function = inverse_link_function self.model_initialization_funcs = model_initialization_funcs # fit attributes self.coef_: jnp.ndarray | None = None self.intercept_: jnp.ndarray | None = None self.solver_state_: NamedTuple | None = None self.scale_: jnp.ndarray | None = None self.dof_resid_: int | None = None # cache the log-like self._log_like_cache = {}
def _log_likelihood( self, params: GLMHMMParams, X: DESIGN_INPUT_TYPE, y: ArrayLike ) -> jnp.ndarray: """Compute the log-likelihood of the data given the model parameters. Use cached values to avoid unnecessary computations. """ cache_key = ( y.ndim > 1, self._observation_model, self._inverse_link_function, ) ll_func = self._log_like_cache.get(cache_key) if ll_func is None: ll_func = prepare_estep_log_likelihood( y.ndim > 1, self._observation_model, self._inverse_link_function ) self._log_like_cache[cache_key] = ll_func return ll_func(params, X, y)
[docs] def setup( self, initial_proba_init: Optional[ Literal["uniform", "random", "dirichlet", "kmeans"] | InitFunctionHMM ] = None, initial_proba_init_kwargs: Optional[dict] = None, transition_proba_init: Optional[ Literal["sticky", "uniform", "random", "dirichlet", "kmeans"] | InitFunctionHMM ] = None, transition_proba_init_kwargs: Optional[dict] = None, glm_params_init: Optional[Literal["random", "kmeans"] | InitFunctionGLM] = None, glm_params_init_kwargs: Optional[dict] = None, scale_init: Optional[Literal["constant", "kmeans"] | InitFunctionGLM] = None, scale_init_kwargs: Optional[dict] = None, ): """Configure how :meth:`fit` initializes each model parameter. Calling :meth:`setup` is optional: if it is never called, fitting starts from the default initializers listed below. Use it to change the initialization strategy by providing either the name of a built-in initialization function or a custom callable. Each argument left as ``None`` keeps the previously configured value; only the parameters you supply are updated. Available built-in initialization functions: - ``initial_proba_init``: ``"uniform"`` (default), ``"random"``, ``"dirichlet"``, ``"kmeans"``. - ``transition_proba_init``: ``"sticky"`` (default), ``"uniform"``, ``"random"``, ``"dirichlet"``, ``"kmeans"``. - ``glm_params_init``: ``"random"`` (default), ``"kmeans"``. - ``scale_init``: ``"constant"`` (default), ``"kmeans"``. Parameters ---------- initial_proba_init : Built-in name or custom callable used to initialize the initial-state probabilities (shape ``(n_states,)``). initial_proba_init_kwargs : Extra keyword arguments forwarded to ``initial_proba_init``. transition_proba_init : Built-in name or custom callable used to initialize the transition matrix (shape ``(n_states, n_states)``). transition_proba_init_kwargs : Extra keyword arguments forwarded to ``transition_proba_init``. glm_params_init : Built-in name or custom callable used to initialize the per-state GLM coefficients and intercepts. glm_params_init_kwargs : Extra keyword arguments forwarded to ``glm_params_init``. scale_init : Built-in name or custom callable used to initialize the scale parameter of the observation model (e.g. variance for Gaussian, dispersion for NegativeBinomial). Ignored by observation models without a scale. scale_init_kwargs : Extra keyword arguments forwarded to ``scale_init``. Raises ------ ValueError If a custom callable's signature is incompatible with the protocol above, or if a ``*_kwargs`` entry contains keys that don't match the corresponding initializer's signature. Notes ----- Custom callables must satisfy one of two ``typing.Protocol`` classes: - ``initial_proba_init`` and ``transition_proba_init`` must satisfy :class:`~nemos.hmm.initialize_parameters.InitFunctionHMM` and return a ``jnp.ndarray`` of shape ``(n_states,)`` or ``(n_states, n_states)`` respectively. - ``glm_params_init`` and ``scale_init`` must satisfy :class:`~nemos.glm_hmm.initialize_parameters.InitFunctionGLM`. ``glm_params_init`` returns ``(coef, intercept)`` matched to the design and ``n_states``; ``scale_init`` returns the scale array for the observation model. To inspect a protocol's signature, import and ``help()`` it:: from nemos.hmm.initialize_parameters import InitFunctionHMM from nemos.glm_hmm.initialize_parameters import InitFunctionGLM help(InitFunctionHMM) # or help(InitFunctionGLM) All arguments must appear in the function signature even when unused, so the framework can supply them uniformly. Examples -------- Switch a parameter to a different built-in scheme by passing its label: >>> from nemos.glm_hmm import GLMHMM >>> model = GLMHMM(n_states=3) >>> model.setup(initial_proba_init="random", glm_params_init="kmeans") Plug in a custom callable matching the GLM-side protocol: >>> import jax.numpy as jnp >>> def my_glm_init( ... n_states, X, y, inverse_link_function, observation_model, ... session_starts, random_key, ... ): ... coef = jnp.zeros((X.shape[1], n_states)) ... intercept = jnp.zeros((n_states,)) ... return coef, intercept >>> model.setup(glm_params_init=my_glm_init) """ super().setup( initial_proba_init=initial_proba_init, initial_proba_init_kwargs=initial_proba_init_kwargs, transition_proba_init=transition_proba_init, transition_proba_init_kwargs=transition_proba_init_kwargs, glm_params_init=glm_params_init, glm_params_init_kwargs=glm_params_init_kwargs, scale_init=scale_init, scale_init_kwargs=scale_init_kwargs, )
def _model_setup( self, glm_params_init: Optional[str | Callable] = None, glm_params_init_kwargs=None, scale_init: Optional[str | Callable] = None, scale_init_kwargs=None, ): """Validate and set GLM-side initialization functions. Derives ``_model_use_kmeans`` from the identity of the stored callables so the flag stays accurate regardless of whether the user passed the string ``"kmeans"`` or the kmeans callable directly (e.g. when set via the ``model_initialization_funcs`` property). """ self._model_initialization_funcs = setup_glm_hmm_initialization( glm_params_init=glm_params_init, glm_params_init_kwargs=glm_params_init_kwargs, scale_init=scale_init, scale_init_kwargs=scale_init_kwargs, init_funcs=self._model_initialization_funcs, ) self._model_use_kmeans = { "glm_params_init": ( self._model_initialization_funcs["glm_params_init"] is kmeans_glm_params_init ), "scale_init": ( self._model_initialization_funcs["scale_init"] is kmeans_scale_init ), } @property def observation_model(self) -> obs.Observations: """The observation model governing the emission distribution at each state. Always an instance of an :class:`~nemos.observation_models.Observations` subclass. The same distribution is used across all hidden states (per-state differences come from the state-specific coefficients/intercept/scale, not from the family). If a string alias was passed at construction time it is resolved to the corresponding instance here. """ return self._observation_model @observation_model.setter def observation_model(self, observation: obs.Observations): """Validate and set the observation model. Parameters ---------- observation : Either an :class:`~nemos.observation_models.Observations` instance, or a string alias from ``{"Poisson", "Gamma", "Bernoulli", "NegativeBinomial", "Gaussian"}``. String aliases are instantiated via :func:`nemos.observation_models.instantiate_observation_model`. Raises ------ AttributeError, TypeError If the instance does not implement the :class:`~nemos.observation_models.Observations` interface (checked via :func:`nemos.observation_models.check_observation_model`). """ if isinstance(observation, str): self._observation_model = instantiate_observation_model(observation) return # check that the model has the required attributes # and that the attribute can be called obs.check_observation_model(observation) self._observation_model = observation @property def inverse_link_function(self): """Inverse link function mapping the linear predictor to the emission space. Always a callable. If ``None`` was passed at construction time, this is resolved to the observation model's default (e.g. ``jnp.exp`` for Poisson, ``1 / x`` for Gamma, ``jax.nn.sigmoid`` for Bernoulli). Shared across all hidden states. """ return self._inverse_link_function @inverse_link_function.setter def inverse_link_function(self, inverse_link_function: Callable): """Validate and set the inverse link function. Parameters ---------- inverse_link_function : One of: - ``None`` — use the observation model's default inverse link. - ``str`` — name of a built-in (e.g. ``"identity"``, ``"log"``, ``"logit"``); resolved by :func:`nemos.inverse_link_function_utils.resolve_inverse_link_function`. - ``Callable`` — a custom function. Must be JAX-traceable (differentiable) and return a ``jax.numpy.ndarray`` or scalar when called on a JAX array. Raises ------ TypeError If the value is neither callable nor a string. ValueError If a callable is non-differentiable or returns an unsupported type. """ self._inverse_link_function = resolve_inverse_link_function( inverse_link_function, self._observation_model ) def _check_model_is_fit(self): """Ensure the instance has been fitted.""" flat_params = [ self.coef_, self.intercept_, self.scale_, ] is_missing = [x is None for x in flat_params] if any(is_missing): param_labels = [ "coef_", "intercept_", "scale_", ] missing_params = [ p for p, missing in zip(param_labels, is_missing) if missing ] raise ValueError( "This GLMHMM instance is not fitted yet. The following attributes are not set:" f" {missing_params}.\nPlease fit the GLM-HMM model first or " "set the missing attributes." ) def _kmeans_extra_kwargs(self) -> dict: return { "inverse_link_function": self.inverse_link_function, "observation_model": self.observation_model, } def _model_params_initialization( self, X: DESIGN_INPUT_TYPE, y: jnp.ndarray, session_starts: jnp.ndarray, random_key: jax.Array, ) -> Tuple[GLMHMMUserParams, bool]: """GLM-HMM initialization.""" user_params = generate_glm_hmm_initial_model_params( self._n_states, X, y, inverse_link_function=self._inverse_link_function, observation_model=self._observation_model, session_starts=session_starts, random_key=random_key, init_funcs=self._model_initialization_funcs, ) validate_params = any( self._model_initialization_funcs[s] for s in ("glm_params_init_custom", "scale_init_custom") ) return user_params, validate_params
[docs] def fit( self, X: DESIGN_INPUT_TYPE, y: Union[NDArray, jnp.ndarray, nap.Tsd], init_params: Optional[GLMHMMUserParams] = None, session_starts: Optional[jnp.ndarray] = None, ) -> "GLMHMM": """Fit the GLM-HMM via Expectation-Maximization. Runs the EM algorithm until the absolute change in log-likelihood between consecutive iterations falls below ``tol`` or ``maxiter`` is reached. Fitted parameters are exposed on the instance as ``coef_``, ``intercept_``, ``scale_``, ``initial_prob_``, ``transition_prob_``, plus ``solver_state_`` (EM trace) and ``dof_resid_``. How parameters are initialized: - If ``init_params`` is ``None`` (typical), the per-state GLM parameters and HMM probabilities are produced by the initializers configured via :meth:`setup` (or the package defaults when :meth:`setup` was never called). - If ``init_params`` is provided, it bypasses the initializers entirely. It must be a 5-tuple ``(coef, intercept, scale, initial_prob, transition_prob)`` whose shapes are consistent with ``X``, ``y``, and ``n_states``. Parameters ---------- X : Predictors, shape ``(n_time_bins, n_features)``. A pytree of arrays sharing leading dimension is also accepted; the fitted ``coef_`` mirrors the pytree structure (with a trailing state axis). A pynapple ``TsdFrame`` is accepted. y : Observations, shape ``(n_time_bins,)`` for single neuron or ``(n_time_bins, n_neurons)`` for population models. A pynapple ``Tsd``/``TsdFrame`` is accepted. init_params : Optional explicit initial parameters as a 5-tuple ``(coef, intercept, scale, initial_prob, transition_prob)``. When ``None`` (default), the initializers configured by :meth:`setup` (or the defaults) are used. session_starts : Optional session boundaries for the HMM. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``X`` or ``y`` to be a pynapple object to supply timestamps). If ``X`` or ``y`` is a pynapple object and ``session_starts`` is ``None``, the (unique, enforced) ``time_support`` of the pynapple input determines the session starts. With no pynapple input and ``session_starts=None``, the whole input is treated as a single session. Returns ------- self : The fitted estimator. Raises ------ ValueError If inputs fail dimensionality, shape, or consistency checks (e.g. ``coef`` features do not match ``X.shape[1]``, or NaNs appear mid-epoch). TypeError If ``init_params`` is not a 5-tuple or has incompatible leaf types. Warns ----- RuntimeWarning Emitted when EM runs out of iterations without satisfying the ``tol`` criterion (``solver_state_.iterations == maxiter``). Consider enabling float64, raising ``maxiter``, or loosening ``tol``. Examples -------- Basic fit with default Bernoulli observations: >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(0) >>> X = np.random.normal(size=(200, 4)) >>> y = np.random.binomial(n=1, p=0.5, size=200) >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y) >>> model.coef_.shape, model.transition_prob_.shape ((4, 2), (2, 2)) Multiple sessions via explicit ``session_starts``: >>> session_starts = np.array([0, 100]) >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y, session_starts=session_starts) See Also -------- setup : Configure the initializers used when ``init_params is None``. update : Run a single EM iteration (advanced, manual loop). """ self._validator.validate_inputs(X=X, y=y) # validate and cast session boundaries, shifting markers off NaN samples session_starts = self._validator.validate_and_cast_session_starts( X, y, session_starts=session_starts ) # validate the inputs & initialize solver # initialize params if no params are provided if init_params is None: init_params = self._model_specific_initialization(X, y, session_starts) else: init_params = self._validator.validate_and_cast_params(init_params) self._validator.validate_consistency(init_params, X=X, y=y) self._validator.feature_mask_consistency( getattr(self, "_feature_mask", None), init_params ) # filter for non-nans, grab data if needed data, y, session_starts = self._preprocess_inputs(X, y, session_starts) # set up optimization self._initialize_optimizer_and_state(init_params, data, y) # run EM ( fit_params, self.solver_state_, ) = self._optimizer_run(init_params, X=data, y=y, session_starts=session_starts) if self.solver_state_.iterations == self.maxiter: warnings.warn( "The fit did not converge. " "Consider the following:" "\n1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``" "\n2) Increase the ``maxiter`` parameter (max number of iterations of the EM) " "or increase the ``tol`` parameter (tolerance).", RuntimeWarning, ) # assign fit attributes self._set_model_params(fit_params) self.dof_resid_ = self._estimate_resid_degrees_of_freedom(data) return self
def _estimate_resid_degrees_of_freedom( self, X: DESIGN_INPUT_TYPE, n_samples: Optional[int] = None ): """ Estimate the degrees of freedom of the residuals. Parameters ---------- self : A fitted GLM model. X : The design matrix. n_samples : The number of samples observed. If not provided, n_samples is set to ``X.shape[0]``. If the fit is batched, the n_samples could be larger than ``X.shape[0]``. Returns ------- : An estimate of the degrees of freedom of the residuals. """ # Convert a pytree to a design-matrix with pytrees X = jnp.hstack(jax.tree_util.tree_leaves(X)) if n_samples is None: n_samples = X.shape[0] else: if not isinstance(n_samples, int): raise TypeError( f"`n_samples` must either `None` or of type `int`. Type {type(n_samples)} provided " "instead!" ) params = self._get_model_params() coef = params.model_params.coef coef_leaf = jax.tree_util.tree_leaves(coef)[0] if coef_leaf.ndim == 3: n_neurons = coef_leaf.shape[1] else: n_neurons = 1 dof_intercept_and_hmm = ( self._n_states * n_neurons # intercept + ( self._n_states - 1 ) # init prob (n values but sum to 1, so n-1 free values) + (self._n_states - 1) * self._n_states ) # transition n n-dim vectors that sum to 1 # if the regularizer is lasso use the non-zero # coef as an estimate of the dof # see https://arxiv.org/abs/0712.0881 if isinstance(self.regularizer, (GroupLasso, Lasso)): resid_dof = sum( pytree_map_and_reduce( lambda x: ~jnp.isclose(x, jnp.zeros_like(x)), lambda x: sum([jnp.sum(i, axis=0) for i in x]), coef, ) ) return n_samples - resid_dof - dof_intercept_and_hmm elif isinstance(self.regularizer, Ridge): # for Ridge, use the tot parameters (X.shape[1] + intercept) return ( n_samples - (X.shape[1] * self.n_states) - dof_intercept_and_hmm ) * jnp.ones(n_neurons) else: # for UnRegularized, use the rank rank = jnp.linalg.matrix_rank(X) return (n_samples - rank - dof_intercept_and_hmm) * jnp.ones(n_neurons) def _simulate( self, random_key: jax.Array, params: GLMHMMParams, X: jnp.ndarray, session_starts: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Simulate activity vis jax.lax.scan. Parameters ---------- random_key : JAX random key. params : Model parameters. X : Design matrix of shape ``(n_time_bins, n_features)``. session_starts : Boolean array marking session starts. Returns ------- simulated_activity : Simulated observations. firing_rates : Predicted rates conditioned on simulated states. simulated_states : State indices at each time point. """ # unpack log probabilities directly (avoid exp then log in categorical) log_initial_prob = params.hmm_params.log_initial_prob log_transition_prob = params.hmm_params.log_transition_prob scale = jnp.exp(params.model_params.log_scale) # pre-compute rates for all states: (n_time_bins, n_states) or (n_time_bins, n_neurons, n_states) all_rates = compute_rate_per_state( X, params.model_params, self._inverse_link_function ) # pre-generate random keys for all time steps n_time_bins = jax.tree_util.tree_leaves(X)[0].shape[0] all_keys = jax.random.split(random_key, n_time_bins * 2) state_keys = all_keys[:n_time_bins] obs_keys = all_keys[n_time_bins:] def simulate_step(carry, inputs): """Single simulation step.""" prev_state_idx = carry rates_t, is_new_sess, state_key, obs_key = inputs # sample state: log_initial_prob if new session, else log transition from prev log_state_probs = jax.lax.cond( is_new_sess, lambda: log_initial_prob, lambda: log_transition_prob[prev_state_idx], ) state_idx = jax.random.categorical(state_key, log_state_probs) # get rate and scale for sampled state # handles both (n_states,) and (n_neurons, n_states) rate = rates_t[..., state_idx] state_scale = scale[..., state_idx] # sample observation y_t = self._observation_model.sample_generator( key=obs_key, predicted_rate=rate, scale=state_scale ) return state_idx, (y_t, rate, state_idx) # initialize carry (state will be overwritten at first step since session_starts[0]=True) init_carry = jnp.array(0) # run scan _, (simulated_activity, firing_rates, simulated_states) = jax.lax.scan( simulate_step, init_carry, (all_rates, session_starts, state_keys, obs_keys) ) return simulated_activity, firing_rates, simulated_states
[docs] def simulate( self, random_key: jax.Array, feedforward_input: DESIGN_INPUT_TYPE, state_format: Literal["one-hot", "index"] = "index", session_starts: Optional[jax.Array] = None, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Simulate neural activity and hidden states from the model. Simulates a trajectory through the hidden state space according to the HMM dynamics, then generates observations from the GLM emission model conditioned on each state. Parameters ---------- random_key : JAX random key for reproducible simulation. feedforward_input : Design matrix of shape ``(n_time_bins, n_features)``. If a pynapple Tsd/TsdFrame is provided, session boundaries are detected from ``time_support`` and the hidden state chain is reset at each session start. state_format : Format for the returned states: - ``"index"``: Integer array of shape ``(n_time_bins,)`` with state indices. - ``"one-hot"``: Binary array of shape ``(n_time_bins, n_states)``. session_starts : Optional session boundaries. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``feedforward_input`` to be a pynapple object to supply timestamps). If ``feedforward_input`` is a pynapple object and ``session_starts`` is ``None``, the ``time_support`` determines the session starts. With no pynapple input and ``session_starts=None``, the whole input is treated as a single session. Returns ------- simulated_activity : Simulated observations from the emission model. Shape ``(n_time_bins,)`` for single neuron or ``(n_time_bins, n_neurons)`` for population models. firing_rates : Predicted firing rates conditioned on the simulated states. Shape ``(n_time_bins,)`` or ``(n_time_bins, n_neurons)``. simulated_states : Simulated hidden state trajectory. Shape depends on ``state_format``. Raises ------ ValueError If the model has not been fit. Examples -------- >>> import jax >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(123) >>> X = np.random.randn(100, 3) >>> y = np.random.binomial(1, 0.5, 100) >>> model = nmo.glm_hmm.GLMHMM(n_states=2, observation_model="Bernoulli") >>> model = model.fit(X, y) >>> key = jax.random.key(0) >>> X_new = np.random.randn(50, 3) >>> activity, rates, states = model.simulate(key, X_new) >>> activity.shape (50,) >>> states.shape (50,) See Also -------- decode_state : Infer most likely state sequence from observations. smooth_proba : Compute posterior state probabilities. """ _check_state_format(state_format) params, feedforward_input, _, session_starts = ( self._validate_and_prepare_inputs(feedforward_input, None, session_starts) ) # preprocess inputs (drop nans, extract data) data, _, session_starts = self._preprocess_inputs( feedforward_input, None, session_starts ) # run simulation simulated_activity, firing_rates, simulated_states = self._simulate( random_key, params, data, session_starts ) # format state output if state_format == "one-hot": simulated_states = jax.nn.one_hot( simulated_states, self._n_states, dtype=jnp.int32 ) return simulated_activity, firing_rates, simulated_states
[docs] def smooth_proba( self, X: Union[DESIGN_INPUT_TYPE, ArrayLike], y: Union[NDArray, jnp.ndarray, nap.Tsd], session_starts: Optional[ArrayLike] = None, ) -> jnp.ndarray | nap.TsdFrame: """Compute smoothing posterior probabilities over hidden states. Computes the probability of being in each hidden state at each time bin, conditioned on the entire observed sequence. Uses the forward-backward algorithm to incorporate information from both past and future observations, providing optimal state estimates given all available data. The smoothing posteriors answer: "Given all observations, what is the probability that the system was in state ``k`` at time ``t``?" Parameters ---------- X : Predictors, shape ``(n_time_bins, n_features)``. A pytree of 2-D arrays sharing the leading time axis is also accepted. y : Observations, shape ``(n_time_bins,)`` for a single neuron or ``(n_time_bins, n_neurons)`` for a population model. A pynapple ``Tsd``/``TsdFrame`` is accepted; session boundaries are then inferred from ``time_support``. session_starts : Optional session boundaries. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``X`` or ``y`` to be a pynapple object to supply timestamps). If ``None``, the entire input is treated as a single session. Returns ------- posteriors : Smoothing posterior probabilities, shape ``(n_time_bins, n_states)``. Each row sums to 1. Returns a pynapple ``TsdFrame`` (with columns named ``"state_0"``, ``"state_1"``, …) when the inputs are pynapple objects; otherwise returns a JAX array. Raises ------ ValueError If the model has not been fitted (call :meth:`fit` first). ValueError If ``X`` or ``y`` contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference). ValueError If ``X`` and ``y`` have inconsistent shapes or feature counts. See Also -------- filter_proba : Compute filtering posteriors (conditioned on past observations only). decode_state : Compute the most likely state sequence via Viterbi decoding. Notes ----- Smoothing uses all data (non-causal) and gives better state estimates than filtering. For online or real-time applications use :meth:`filter_proba` instead. Session boundaries reset the HMM chain so that no information crosses session borders. Examples -------- Fit a GLM-HMM and compute smoothing posteriors: >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(123) >>> X = np.random.randn(100, 5) >>> y = np.random.poisson(2, size=100) >>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y) >>> posteriors = model.smooth_proba(X, y) >>> posteriors.shape (100, 3) >>> bool(np.allclose(posteriors.sum(axis=1), 1.0)) True With pynapple inputs the result is returned as a ``TsdFrame``: >>> import pynapple as nap >>> t = np.arange(100) * 0.01 >>> X_tsd = nap.TsdFrame(t=t, d=X) >>> y_tsd = nap.Tsd(t=t, d=y.astype(float)) >>> type(model.smooth_proba(X_tsd, y_tsd)).__name__ 'TsdFrame' """ return super().smooth_proba(X, y, session_starts=session_starts)
[docs] def filter_proba( self, X: Union[DESIGN_INPUT_TYPE, ArrayLike], y: Union[NDArray, jnp.ndarray, nap.Tsd], session_starts: Optional[ArrayLike] = None, ) -> jnp.ndarray | nap.TsdFrame: """Compute filtering posterior probabilities over hidden states. Computes the probability of being in each hidden state at each time bin, conditioned only on observations up to that time bin. Uses the forward pass of the forward-backward algorithm, providing causal (online) state estimates that rely solely on past and current observations. The filtering posteriors answer: "Given observations up to time ``t``, what is the probability that the system is in state ``k`` at time ``t``?" Parameters ---------- X : Predictors, shape ``(n_time_bins, n_features)``. A pytree of 2-D arrays sharing the leading time axis is also accepted. y : Observations, shape ``(n_time_bins,)`` for a single neuron or ``(n_time_bins, n_neurons)`` for a population model. A pynapple ``Tsd``/``TsdFrame`` is accepted; session boundaries are then inferred from ``time_support``. session_starts : Optional session boundaries. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``X`` or ``y`` to be a pynapple object to supply timestamps). If ``None``, the entire input is treated as a single session. Returns ------- posteriors : Filtering posterior probabilities, shape ``(n_time_bins, n_states)``. Each row sums to 1. Returns a pynapple ``TsdFrame`` (with columns named ``"state_0"``, ``"state_1"``, …) when the inputs are pynapple objects; otherwise returns a JAX array. Raises ------ ValueError If the model has not been fitted (call :meth:`fit` first). ValueError If ``X`` or ``y`` contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference). ValueError If ``X`` and ``y`` have inconsistent shapes or feature counts. See Also -------- smooth_proba : Compute smoothing posteriors (conditioned on all observations). decode_state : Compute the most likely state sequence via Viterbi decoding. Notes ----- Filtering is causal: each posterior at time ``t`` uses only observations up to ``t``, making it suitable for online or real-time applications. For retrospective analysis where all data are available, :meth:`smooth_proba` gives better state estimates. Session boundaries reset the HMM chain so that no information crosses session borders. Examples -------- Fit a GLM-HMM and compute filtering posteriors (causal/online): >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(123) >>> X = np.random.randn(100, 5) >>> y = np.random.poisson(2, size=100) >>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y) >>> filt = model.filter_proba(X, y) >>> filt.shape (100, 3) >>> bool(np.allclose(filt.sum(axis=1), 1.0)) True With pynapple inputs the result is returned as a ``TsdFrame``: >>> import pynapple as nap >>> t = np.arange(100) * 0.01 >>> X_tsd = nap.TsdFrame(t=t, d=X) >>> y_tsd = nap.Tsd(t=t, d=y.astype(float)) >>> type(model.filter_proba(X_tsd, y_tsd)).__name__ 'TsdFrame' """ return super().filter_proba(X, y, session_starts=session_starts)
[docs] def decode_state( self, X: Union[DESIGN_INPUT_TYPE, ArrayLike], y: ArrayLike, session_starts: Optional[ArrayLike] = None, state_format: Literal["one-hot", "index"] = "one-hot", ) -> jnp.ndarray | nap.TsdFrame: """Compute the most likely hidden state sequence (Viterbi decoding). Finds the single most likely sequence of hidden states that best explains the observed data. Uses the Viterbi (max-sum) algorithm to compute the state sequence that maximizes the joint probability of states and observations. Unlike :meth:`smooth_proba` and :meth:`filter_proba`, which return a probability distribution over states at each time bin, this method makes a hard assignment to the single globally optimal state path. The decoded states answer: "What is the most likely sequence of states that generated the observed data?" Parameters ---------- X : Predictors, shape ``(n_time_bins, n_features)``. A pytree of 2-D arrays sharing the leading time axis is also accepted. y : Observations, shape ``(n_time_bins,)`` for a single neuron or ``(n_time_bins, n_neurons)`` for a population model. A pynapple ``Tsd``/``TsdFrame`` is accepted; session boundaries are then inferred from ``time_support``. session_starts : Optional session boundaries. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``X`` or ``y`` to be a pynapple object to supply timestamps). If ``None``, the entire input is treated as a single session. state_format : Format of the returned state sequence: - ``"one-hot"`` (default): binary array of shape ``(n_time_bins, n_states)`` with a single 1 per row. - ``"index"``: integer array of shape ``(n_time_bins,)`` with values in ``[0, n_states - 1]``. Returns ------- decoded_states : Most likely state sequence. Shape and dtype depend on ``state_format`` (see above). Returns a pynapple ``TsdFrame`` (columns ``"state_0"``, ``"state_1"``, …) for ``"one-hot"`` format or a pynapple ``Tsd`` for ``"index"`` format when the inputs are pynapple objects; otherwise returns a JAX array. Raises ------ ValueError If the model has not been fitted (call :meth:`fit` first). ValueError If ``state_format`` is not ``"one-hot"`` or ``"index"``. ValueError If ``X`` or ``y`` contain NaN values in the interior of an epoch (boundary NaNs are allowed and removed before inference). ValueError If ``X`` and ``y`` have inconsistent shapes or feature counts. See Also -------- smooth_proba : Compute smoothing posteriors (soft, probabilistic state assignments). filter_proba : Compute filtering posteriors (causal, conditioned on past observations). Notes ----- Viterbi decoding finds the globally optimal state *sequence*, which can differ from the sequence of states that are individually most probable at each time bin (as returned by :meth:`smooth_proba`). For uncertainty estimates use :meth:`smooth_proba` instead. Session boundaries reset the Viterbi recursion so that no path crosses session borders. Examples -------- Decode the most likely state sequence as integer indices: >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(123) >>> X = np.random.randn(100, 5) >>> y = np.random.poisson(2, size=100) >>> model = nmo.glm_hmm.GLMHMM(n_states=3, observation_model="Poisson").fit(X, y) >>> states = model.decode_state(X, y, state_format="index") >>> states.shape (100,) One-hot output (default): >>> states_onehot = model.decode_state(X, y) >>> states_onehot.shape (100, 3) >>> bool(np.all(states_onehot.sum(axis=1) == 1)) True """ return super().decode_state( X, y, session_starts=session_starts, state_format=state_format )
[docs] def save_params( self, filename: Union[str, Path], ): """Save GLM-HMM model parameters and fit state to a .npz file. Persists hyperparameters returned by :meth:`get_params` together with the fitted attributes (``coef_``, ``intercept_``, ``scale_``, ``initial_prob_``, ``transition_prob_``, ``dof_resid_``). The ``solver_state_`` is intentionally excluded as it is solver-specific and not needed to reuse the fitted model. The file can be reloaded with :func:`nemos.load_model`. If the model was configured with custom initialization functions, pass them back to :func:`nemos.load_model` via ``mapping_dict`` to restore them (see example below). Built-in initializers are resolved automatically. Parameters ---------- filename : Path of the output file (``.npz`` format). Examples -------- Default round-trip — built-in initializers are resolved automatically on load: >>> import os, tempfile >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(0) >>> X = np.random.normal(size=(80, 3)) >>> y = np.random.binomial(n=1, p=0.5, size=80) >>> model = nmo.glm_hmm.GLMHMM(n_states=2).fit(X, y) >>> with tempfile.TemporaryDirectory() as d: ... path = os.path.join(d, "glmhmm.npz") ... model.save_params(path) ... loaded = nmo.load_model(path) >>> bool(np.allclose(model.coef_, loaded.coef_)) True Round-trip with a custom GLM-params initializer. Pass it back as a partial dict under ``model_initialization_funcs``; remaining slots fall back to the saved (built-in) names: >>> import jax.numpy as jnp >>> def my_glm_init( ... n_states, X, y, inverse_link_function, observation_model, ... session_starts, random_key, ... ): ... return jnp.zeros((X.shape[1], n_states)), jnp.zeros((n_states,)) >>> model = nmo.glm_hmm.GLMHMM(n_states=2) >>> model.setup(glm_params_init=my_glm_init) >>> _ = model.fit(X, y) >>> with tempfile.TemporaryDirectory() as d: ... path = os.path.join(d, "glmhmm.npz") ... model.save_params(path) ... loaded = nmo.load_model( ... path, ... mapping_dict={ ... "model_initialization_funcs": {"glm_params_init": my_glm_init}, ... }, ... ) >>> loaded.model_initialization_funcs["glm_params_init"] is my_glm_init True """ # initialize saving dictionary fit_attrs = self._get_fit_state() fit_attrs.pop("solver_state_", None) string_attrs = ["inverse_link_function"] self._save_params(filename, fit_attrs, string_attrs)
# SVRG specific optimization not available. def _get_optimal_solver_params_config(self): """No optimal parameters known for SVRG in GLMHMM.""" return None, None, None def _get_model_params(self) -> GLMHMMParams: return self._validator.to_model_params( ( self.coef_, self.intercept_, self.scale_, self.initial_prob_, self.transition_prob_, ) ) def _set_model_params(self, params: GLMHMMParams): coef, intercept, scale, initial_prob, transition_prob = ( self._validator.from_model_params(params) ) self.coef_ = coef self.intercept_ = intercept self.scale_ = scale self.initial_prob_ = initial_prob self.transition_prob_ = transition_prob
[docs] def update( self, params: GLMHMMUserParams, opt_state: NamedTuple, X: DESIGN_INPUT_TYPE, y: jnp.ndarray, *args, session_starts: Optional[jnp.ndarray] = None, n_samples: Optional[int] = None, **kwargs, ) -> StepResult: """Run a single EM iteration on the GLM-HMM. Performs one E-step / M-step pair starting from the supplied parameters and EM state, updates the model's fitted attributes (``coef_``, ``intercept_``, ``scale_``, ``initial_prob_``, ``transition_prob_``, ``solver_state_``, ``dof_resid_``) in place, and returns the updated parameter tuple and EM state. Intended for callers that need fine-grained control over EM iteration (e.g. checkpointing, custom convergence criteria) instead of the bundled :meth:`fit` loop. :meth:`initialize_optimizer_and_state` must be called first so that the EM step function and initial ``opt_state`` are available. Parameters ---------- params : Current model parameters as a 5-tuple ``(coef, intercept, scale, initial_prob, transition_prob)`` matching the structure produced by :meth:`initialize_params`. opt_state : EM state returned by :meth:`initialize_optimizer_and_state` or by the previous call to :meth:`update`. X : Predictors, shape ``(n_time_bins, n_features)`` (or a pytree of arrays of the same shape). y : Observations, shape ``(n_time_bins,)`` or ``(n_time_bins, n_neurons)``. session_starts : Optional session boundaries. Accepts: - a boolean array of shape ``(n_time_bins,)`` with ``True`` at each session start, - an integer array of session-start indices, - a pynapple ``IntervalSet`` (requires ``X`` or ``y`` to be a pynapple object to supply timestamps). If ``None``, the entire input is treated as a single session. n_samples : Total sample count to use when estimating the residual degrees of freedom. Defaults to ``X.shape[0]``. Returns ------- params : Updated user-facing parameter tuple. state : Updated EM state. Raises ------ ValueError If inputs fail shape/consistency validation. Examples -------- >>> import numpy as np >>> import nemos as nmo >>> np.random.seed(0) >>> X = np.random.normal(size=(80, 3)) >>> y = np.random.binomial(n=1, p=0.5, size=80) >>> model = nmo.glm_hmm.GLMHMM(n_states=2) >>> init_params = model.initialize_params(X, y) >>> opt_state = model.initialize_optimizer_and_state(init_params, X, y) >>> new_params, new_state = model.update(init_params, opt_state, X, y) """ # validate inputs and session boundaries self._validator.validate_inputs(X=X, y=y) session_starts = self._validator.validate_and_cast_session_starts( X, y, session_starts=session_starts ) # drop nans and pull pytree data data, y, session_starts = self._preprocess_inputs(X, y, session_starts) # wrap into model params (assumes init was done via # `initialize_optimizer_and_state` so the EM step function is in place) params = self._validator.to_model_params(params) # one EM step updated_params, updated_state = self._optimizer_update( params, opt_state, data, y, session_starts=session_starts ) # persist self._set_model_params(updated_params) self.solver_state_ = updated_state self.dof_resid_ = self._estimate_resid_degrees_of_freedom( data, n_samples=n_samples ) return self._validator.from_model_params(updated_params), updated_state
[docs] def __repr__(self) -> str: """Hierarchical repr for the GLMHMM class.""" return format_repr( self, multiline=True, use_name_keys=["inverse_link_function"] )
def _initialize_optimizer_and_state( self, init_params: ModelParamsT, X: DESIGN_INPUT_TYPE, y: jnp.ndarray, ) -> SolverState: """Initialize the optimizer and state of the model.""" # glm params m-step setup is_population = y.ndim > 1 m_step_update = prepare_mstep_update_fn( is_population_glm=is_population, observation_model=self._observation_model, inverse_link_function=self._inverse_link_function, setup_solver=self._instantiate_solver, init_params=init_params.model_params, ) # cannot wrap session_starts, that's to be calculated at each update form the provided X and y. # for consistency, do not make a partial of that argument in run as well. self._optimizer_run = eqx.Partial( em_hmm, log_likelihood_func=self._log_likelihood, m_step_fn_model_params=m_step_update, maxiter=self.maxiter, tol=self.tol, ) self._optimizer_update = eqx.Partial( em_step, log_likelihood_func=self._log_likelihood, m_step_fn_model_params=m_step_update, ) def init_state_fn(*args, **kwargs) -> SolverState: state = EMState( data_log_likelihood=-jnp.array(jnp.inf), previous_data_log_likelihood=-jnp.array(jnp.inf), log_likelihood_history=jnp.full(self.maxiter, jnp.nan), iterations=0, converged=False, ) return state self._optimizer_init_state = init_state_fn return init_state_fn()