Download this notebook: save_and_load.ipynb!

Saving and Loading

Saving and Loading a Model

In NeMoS, you can save a model by calling its save_params() method. This writes an .npz file, which is NumPy's zipped archive format for storing multiple arrays. To restore the model, call nmo.load_model() on that file.

import nemos as nmo

# define a ridge-regularized GLM that uses the LBFGS solver
model = nmo.glm.GLM(
    regularizer="Ridge",
    solver_name="LBFGS"
)

# save
model.save_params("ridge_glm_params.npz")


# load
loaded_model = nmo.load_model("ridge_glm_params.npz")

print("Original Model: \n", model)
print("\nLoaded Model: \n", loaded_model)
Original Model: 
 GLM(
    observation_model=PoissonObservations(),
    inverse_link_function=exp,
    regularizer=Ridge(),
    regularizer_strength=1.0,
    solver_name='LBFGS'
)

Loaded Model: 
 GLM(
    observation_model=PoissonObservations(),
    inverse_link_function=exp,
    regularizer=Ridge(),
    regularizer_strength=1.0,
    solver_name='LBFGS'
)
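Because an .npz file is a standard zip archive containing one .npy file per stored array, you can also look inside it without NeMoS. The sketch below uses only NumPy and the standard library and writes to an in-memory buffer; the key names (coef, intercept) are illustrative assumptions, not the exact keys NeMoS writes.

```python
import io
import zipfile

import numpy as np

# Write two arrays into an in-memory .npz archive (illustrative keys only).
buffer = io.BytesIO()
np.savez(buffer, coef=np.array([0.5, -1.2]), intercept=np.array([0.1]))

# An .npz file is an ordinary zip archive with one .npy member per array.
buffer.seek(0)
with zipfile.ZipFile(buffer) as archive:
    members = sorted(archive.namelist())
print(members)  # ['coef.npy', 'intercept.npy']

# np.load returns a lazy, dict-like NpzFile; .files lists the stored keys.
buffer.seek(0)
with np.load(buffer) as data:
    keys = sorted(data.files)
    coef = data["coef"]
print(keys, coef)
```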

Saving and Loading a Fitted Model

The same workflow applies to fitted models: the learned coefficients and intercept are saved and restored along with the model settings.

import numpy as np

# generate some data
np.random.seed(123)
X, weights = np.random.randn(50, 1), 0.1 * np.random.randn(1)
counts = np.random.poisson(np.exp(X @ weights))

# fit and save
model.fit(X, counts)
model.save_params("ridge_glm_params_fitted.npz")

# load
loaded_model = nmo.load_model("ridge_glm_params_fitted.npz")

print("Original coefficient and intercept:", model.coef_, model.intercept_)
print("Loaded coefficient and intercept:", loaded_model.coef_, loaded_model.intercept_)
Original coefficient and intercept: [-0.00710396] [0.05829372]
Loaded coefficient and intercept: [-0.00710396] [0.05829372]
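The loaded values match exactly, not just approximately, because each .npy entry stores the raw binary contents of its array together with its dtype and shape. A small NumPy-only sketch of that property, using random values as stand-ins for fitted parameters (they are not real model output):

```python
import io

import numpy as np

# Stand-ins for fitted parameters (random values, not real model output).
rng = np.random.default_rng(0)
coef = rng.normal(size=3)
intercept = rng.normal(size=1)

# Save to an in-memory .npz archive and read the arrays back.
buffer = io.BytesIO()
np.savez(buffer, coef=coef, intercept=intercept)
buffer.seek(0)
with np.load(buffer) as data:
    loaded_coef = data["coef"]
    loaded_intercept = data["intercept"]

# The round trip is bit-for-bit: values, dtype, and shape are all preserved.
print(np.array_equal(coef, loaded_coef), loaded_coef.dtype == coef.dtype)
```

For the fitted model above, np.testing.assert_array_equal(model.coef_, loaded_model.coef_) would perform the analogous check on the actual NeMoS parameters.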

Inspecting the .npz File

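You can inspect the contents of a saved .npz file with nmo.inspect_npz(). It displays the stored metadata (the library versions used when saving, next to the versions currently installed), the saved model class, and the parameter keys. This is useful for debugging (e.g., when loading fails) or for verifying saved models. Because ridge_glm_params.npz was saved before fitting, all of its fit parameters are None.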

nmo.inspect_npz("ridge_glm_params.npz")
Metadata
--------
jax version            : 0.10.1 (installed: 0.10.1)
jaxlib version         : 0.10.1 (installed: 0.10.1)
scipy version          : 1.17.1 (installed: 1.17.1)
scikit-learn version   : 1.9.0 (installed: 1.9.0)
nemos version          : 0.2.9.dev414 (installed: 0.2.9.dev414)

Model class
-----------
Saved model class      : nemos.glm.glm.GLM

Model parameters
----------------
inverse_link_function  : jax.numpy.exp
observation_model      : {'class': 'nemos.observation_models.PoissonObservations'}
regularizer            : {'class': 'nemos.regularizer.Ridge'}
regularizer_strength   : 1
solver_kwargs          : None
solver_name            : LBFGS

Model fit parameters
--------------------
aux_: None
coef_: None
dof_resid_: None
intercept_: None
scale_: None
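If you only need a quick look at an .npz file outside of NeMoS, for example where the library is not installed, NumPy alone can list each stored key with its shape and dtype. The summarize_npz helper below is hypothetical and not part of NeMoS; it opens the archive with allow_pickle=False so that no pickled objects are ever deserialized, and it assumes the archive holds only plain numeric or string arrays (object arrays would raise ValueError).

```python
import io

import numpy as np


def summarize_npz(file):
    """Return {key: (shape, dtype)} for every array in an .npz archive.

    Hypothetical helper for illustration; not part of NeMoS.
    allow_pickle=False makes NumPy refuse object arrays instead of unpickling them.
    """
    summary = {}
    with np.load(file, allow_pickle=False) as data:
        for key in data.files:
            array = data[key]
            summary[key] = (array.shape, str(array.dtype))
    return summary


# Build a small example archive in memory: one numeric array, one string.
buffer = io.BytesIO()
np.savez(buffer, coef=np.zeros(4), solver_name=np.array("LBFGS"))
buffer.seek(0)

summary = summarize_npz(buffer)
for key, (shape, dtype) in summary.items():
    print(f"{key:12s}: shape={shape}, dtype={dtype}")
```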

Save and Load Custom Objects

Advanced users may want to build models from custom components and still be able to save and load them. For example, you could use a different inverse link function (non-linearity) or a custom regularizer.

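# a custom inverse link function (a square non-linearity)
def custom_link(x):
    return x**2

# a custom regularizer: Ridge with an extra, illustrative parameter
class CustomRegularizer(nmo.regularizer.Ridge):
    def __init__(self, new_param):
        self.new_param = new_param

# build a GLM from the custom components and save it
model = nmo.glm.GLM(inverse_link_function=custom_link, regularizer=CustomRegularizer(10))
model.save_params("custom_regularizer_params.npz")

# inspect how the custom objects were recorded
nmo.inspect_npz("custom_regularizer_params.npz")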
Metadata
--------
jax version            : 0.10.1 (installed: 0.10.1)
jaxlib version         : 0.10.1 (installed: 0.10.1)
scipy version          : 1.17.1 (installed: 1.17.1)
scikit-learn version   : 1.9.0 (installed: 1.9.0)
nemos version          : 0.2.9.dev414 (installed: 0.2.9.dev414)

Model class
-----------
Saved model class      : nemos.glm.glm.GLM

Model parameters
----------------
inverse_link_function  : __main__.custom_link
observation_model      : {'class': 'nemos.observation_models.PoissonObservations'}
regularizer            : {'class': '__main__.CustomRegularizer', 'params': {'new_param': 10}}
regularizer_strength   : 1
solver_kwargs          : None
solver_name            : LBFGS

Model fit parameters
--------------------
aux_: None
coef_: None
dof_resid_: None
intercept_: None
scale_: None

As you can see, both custom objects are stored by name only, as strings of the form "{object_class.__module__}.{object_class.__name__}" (here, "__main__.custom_link" and "__main__.CustomRegularizer"); the regularizer's constructor arguments are stored alongside its name. Because NeMoS doesn't pickle objects, it cannot recreate CustomRegularizer from that name automatically, so loading this model directly raises an error.
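The stored name is simply the object's module path joined with its name. A minimal sketch of that convention; qualified_name is a hypothetical helper written for illustration, not NeMoS's own code, and the CustomRegularizer here is a stand-in without the Ridge base class so the snippet runs without NeMoS:

```python
def qualified_name(obj):
    """Build the '{module}.{name}' string that identifies a class or function."""
    return f"{obj.__module__}.{obj.__name__}"


# Stand-ins for the custom objects defined earlier.
def custom_link(x):
    return x**2


class CustomRegularizer:
    def __init__(self, new_param):
        self.new_param = new_param


# Functions and classes both carry __module__ and __name__ attributes.
print(qualified_name(custom_link))        # __main__.custom_link when run as a script
print(qualified_name(CustomRegularizer))  # __main__.CustomRegularizer when run as a script

# Only the name is recorded, not the code, so the string alone cannot rebuild the object.
```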

Why prevent pickling?

Unpickling can execute arbitrary code, which poses a security risk: a third party could tamper with a pickled file to insert malicious code that runs as soon as the object is unpickled.

For a real-world example, see this discussion.
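To see why unpickling is risky, note that pickle lets an object name, through __reduce__, any importable callable to be invoked at load time. The harmless sketch below uses operator.add as that callable, where a malicious file could name something far more dangerous; the Payload class is hypothetical and exists only for this demonstration. The second half shows that NumPy's np.load refuses object arrays by default (allow_pickle=False) for the same reason.

```python
import io
import operator
import pickle

import numpy as np


class Payload:
    # __reduce__ tells pickle how to rebuild the object: "call this function
    # with these arguments". pickle.loads performs that call.
    def __reduce__(self):
        return (operator.add, (2, 3))


blob = pickle.dumps(Payload())

# Loading the bytes calls operator.add(2, 3); no Payload instance comes back.
result = pickle.loads(blob)
print(result)  # 5

# Object arrays can only be stored by pickling them...
buffer = io.BytesIO()
np.save(buffer, np.array([{"class": "x"}], dtype=object), allow_pickle=True)
buffer.seek(0)

# ...so np.load, whose allow_pickle defaults to False, refuses to read them.
refused = False
try:
    np.load(buffer)
except ValueError:
    refused = True
print("NumPy refused the object array:", refused)
```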

loaded_model = nmo.load_model("custom_regularizer_params.npz")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 1
----> 1 loaded_model = nmo.load_model("custom_regularizer_params.npz")

File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/io/io.py:157, in load_model(filename, mapping_dict)
    150     raise ValueError(
    151         "The following keys in your mapping do not match any parameters in the loaded model:\n\n"
    152         f"{suggestions}\n"
    153         "Please double-check your mapping dictionary."
    154     )
    155 # if any value from saved_params is a key in mapping_dict,
    156 # replace it with the corresponding value from mapping_dict
--> 157 saved_params, updated_keys = _apply_custom_map(saved_params, nested_map_dict)
    159 if len(updated_keys) > 0:
    160     warnings.warn(
    161         f"The following keys have been replaced in the model parameters: {updated_keys}.",
    162         UserWarning,
    163     )

File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/io/io.py:323, in _apply_custom_map(params, mapping_dict, updated_keys)
    318 if not is_mapped:
    319     # check for nested callable/classes save instantiate based on the string
    320     new_params, updated_keys = _apply_custom_map(
    321         val.pop("params", {}), mapped_params, updated_keys=updated_keys
    322     )
--> 323     updated_params[key] = _safe_instantiate(key, class_name, **new_params)
    324 else:
    325     mapped_class = mapped_val["class"]

File ~/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/io/io.py:213, in _safe_instantiate(param_name, class_name, **kwargs)
    211 else:
    212     class_type = "regularization"
--> 213 raise ValueError(
    214     f"The class '{class_basename}' is not a native NeMoS class.\n"
    215     f"To load a custom {class_type} class, please provide the following mapping:\n\n"
    216     f" - nemos.load_model(save_path, mapping_dict={{'{param_name}': {class_basename}}})"
    217 )

ValueError: The class 'CustomRegularizer' is not a native NeMoS class.
To load a custom regularization class, please provide the following mapping:

 - nemos.load_model(save_path, mapping_dict={'regularizer': CustomRegularizer})

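As the error message explains, you can tell NeMoS how to recreate the custom objects by passing a mapping_dict to load_model(). Its keys are the names of the model parameters (here, "regularizer" and "inverse_link_function"), and its values are the corresponding custom classes or functions.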
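Conceptually, the mapping replaces a saved name that cannot be resolved safely with an object you supply, and any stored constructor arguments are then passed to it. The sketch below mimics that idea with a hypothetical rebuild_param helper and a simplified saved layout modeled on the inspect_npz output earlier (a plain string for a function, a {'class': ..., 'params': ...} dict for a class). It is an illustration only; NeMoS's load_model is more involved and validates more.

```python
def rebuild_param(key, saved, mapping_dict):
    """Rebuild one saved parameter using a user-supplied mapping.

    `saved` is either a function name (str) or a dict {'class': name, 'params': {...}}.
    Hypothetical, simplified illustration; not NeMoS's implementation.
    """
    name = saved["class"] if isinstance(saved, dict) else saved
    if key not in mapping_dict:
        # Without a mapping, a bare name cannot be turned back into code safely.
        raise ValueError(f"'{name}' is not known; pass mapping_dict={{'{key}': ...}}")
    target = mapping_dict[key]
    if isinstance(saved, dict):
        # Saved constructor arguments are forwarded to the supplied class.
        return target(**saved.get("params", {}))
    # A saved function name maps directly to the supplied callable.
    return target


# Stand-ins for the custom objects defined earlier (no NeMoS base class here).
def custom_link(x):
    return x**2


class CustomRegularizer:
    def __init__(self, new_param):
        self.new_param = new_param


mapping = {"regularizer": CustomRegularizer, "inverse_link_function": custom_link}
regularizer = rebuild_param(
    "regularizer",
    {"class": "__main__.CustomRegularizer", "params": {"new_param": 10}},
    mapping,
)
link = rebuild_param("inverse_link_function", "__main__.custom_link", mapping)
print(type(regularizer).__name__, regularizer.new_param, link(3))  # CustomRegularizer 10 9
```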

mapping = {
    "regularizer": CustomRegularizer,
    "inverse_link_function": custom_link
}
loaded_model = nmo.load_model("custom_regularizer_params.npz", mapping_dict=mapping)
loaded_model
/home/docs/checkouts/readthedocs.org/user_builds/nemos/envs/latest/lib/python3.12/site-packages/nemos/io/io.py:160: UserWarning: The following keys have been replaced in the model parameters: ['inverse_link_function', 'regularizer'].
  warnings.warn(
GLM(
    observation_model=PoissonObservations(),
    inverse_link_function=custom_link,
    regularizer=CustomRegularizer(new_param=10),
    regularizer_strength=1.0,
    solver_name='LBFGS'
)

Allowed Mappings

Mapping is allowed only for callables (functions) and classes, because these cannot be stored directly without pickling. Other values (like numbers, strings, or arrays) are always stored directly in the .npz and cannot be remapped.

When mapping a custom class, you must pass the class itself (e.g., mapping = {"regularizer": CustomRegularizer}), not an instance (e.g., CustomRegularizer(10)). Passing an instance would overwrite the saved parameters and could lead to inconsistencies.
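A mapping can be sanity-checked before loading by confirming that each value is a class or a function rather than an instance or a plain data value. check_mapping below is a hypothetical helper written for this illustration; it is not a NeMoS function, and load_model may perform its own, different checks. The stand-in CustomRegularizer again omits the Ridge base class so the snippet runs on its own.

```python
import inspect


def check_mapping(mapping_dict):
    """Raise TypeError unless every mapped value is a class or a function.

    Hypothetical helper for illustration; not part of NeMoS.
    """
    for key, value in mapping_dict.items():
        if not (inspect.isclass(value) or inspect.isroutine(value)):
            raise TypeError(
                f"mapping_dict['{key}'] must be a class or function, "
                f"got an instance of {type(value).__name__}"
            )


# Stand-ins for the custom objects defined earlier.
def custom_link(x):
    return x**2


class CustomRegularizer:
    def __init__(self, new_param):
        self.new_param = new_param


# Classes and functions pass the check.
check_mapping({"regularizer": CustomRegularizer, "inverse_link_function": custom_link})

# An instance (or a plain value such as 1.0) is rejected.
try:
    check_mapping({"regularizer": CustomRegularizer(10)})
    message = ""
except TypeError as err:
    message = str(err)
print(message)
```

The check is deliberately strict: it accepts only classes and plain Python or built-in functions, which matches the rule above that mappings are for callables and classes, never for instances or stored values.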