Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,28 @@ jobs:
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.6.0", "2.10.0"]
sklearn-version: ["latest"]
numpy-version: ["latest"]

include:
# windows test with standard config
- os: windows-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "latest"
numpy-version: "latest"

# legacy sklearn (several API differences)
- os: ubuntu-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "legacy"
numpy-version: "latest"

- os: ubuntu-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "latest"
numpy-version: "legacy"

# TODO(stes): latest torch and python
# requires a PyTables release compatible with
Expand All @@ -55,6 +65,7 @@ jobs:
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"
numpy-version: "latest"

runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -88,6 +99,11 @@ jobs:
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'

- name: Check numpy legacy version
if: matrix.numpy-version == 'legacy'
run: |
pip install "numpy<2" '.[dev,datasets,integrations]'

- name: Run the formatter
run: |
make format
Expand Down
141 changes: 95 additions & 46 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import importlib.metadata
import itertools
import pickle
import warnings
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
Union)

Expand Down Expand Up @@ -50,8 +52,13 @@
# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
CEBRA_LOAD_SAFE_GLOBALS = [
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType
cebra.data.Offset,
torch.torch_version.TorchVersion,
np.dtype,
np.dtypes.Int32DType,
np.dtypes.Int64DType,
np.dtypes.Float32DType,
np.dtypes.Float64DType,
]


Expand All @@ -62,20 +69,22 @@ def check_version(estimator):
sklearn.__version__) < packaging.version.parse("1.6.dev")


def _safe_torch_load(filename, weights_only, **kwargs):
if weights_only is None:
if packaging.version.parse(
torch.__version__) >= packaging.version.parse("2.6.0"):
weights_only = True
else:
weights_only = False
def _safe_torch_load(filename, weights_only=False, **kwargs):
    """Load a CEBRA checkpoint with version-appropriate safety settings.

    Args:
        filename: Path to a checkpoint file written by :py:meth:`CEBRA.save`.
        weights_only: Forwarded to :py:func:`torch.load` on torch >= 2.6.
            Ignored on older torch versions, which always load with
            ``weights_only=False``.
        kwargs: Additional keyword arguments forwarded to :py:func:`torch.load`.

    Returns:
        The checkpoint as a dictionary. Checkpoints that stored a full
        estimator object (legacy "torch" backend) are normalized to the
        dictionary format via ``_get_state_dict``.
    """
    checkpoint = None
    # torch < 2.6.0 has no safe_globals allow-listing, so fall back to the
    # historical unrestricted load there.
    legacy_mode = packaging.version.parse(
        torch.__version__) < packaging.version.parse("2.6.0")

    if legacy_mode:
        checkpoint = torch.load(filename, weights_only=False, **kwargs)
    else:
        # NOTE(stes): This is only supported for torch 2.6+. The allow-list
        # contains the classes CEBRA checkpoints are known to contain.
        with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
            checkpoint = torch.load(filename,
                                    weights_only=weights_only,
                                    **kwargs)

    # Legacy checkpoints pickled the whole estimator instead of a state
    # dictionary; validate the type and convert to the dict format.
    if not isinstance(checkpoint, dict):
        _check_type_checkpoint(checkpoint)
        checkpoint = checkpoint._get_state_dict()

    return checkpoint

Expand Down Expand Up @@ -315,8 +324,9 @@ def _require_arg(key):

def _check_type_checkpoint(checkpoint):
if not isinstance(checkpoint, cebra.CEBRA):
raise RuntimeError("Model loaded from file is not compatible with "
"the current CEBRA version.")
raise RuntimeError(
"Model loaded from file is not compatible with "
f"the current CEBRA version. Got: {type(checkpoint)}")
if not sklearn_utils.check_fitted(checkpoint):
raise ValueError(
"CEBRA model is not fitted. Loading it is not supported.")
Expand Down Expand Up @@ -1336,6 +1346,26 @@ def _get_state(self):
}
return state

def _get_state_dict(self):
    """Assemble the serializable checkpoint dictionary (sklearn backend).

    Returns:
        A dictionary holding the estimator's constructor arguments, its
        fitted state, the solver's ``state_dict``, and version metadata
        for later compatibility checks on load.
    """
    # Record the library versions the checkpoint was produced with so that
    # loading code can detect incompatibilities.
    metadata = {
        'backend': "sklearn",
        'cebra_version': cebra.__version__,
        'torch_version': torch.__version__,
        'numpy_version': np.__version__,
        'sklearn_version':
            importlib.metadata.distribution("scikit-learn").version,
    }
    return {
        'args': self.get_params(),
        'state': self._get_state(),
        'state_dict': self.solver_.state_dict(),
        'metadata': metadata,
    }

def save(self,
filename: str,
backend: Literal["torch", "sklearn"] = "sklearn"):
Expand Down Expand Up @@ -1384,28 +1414,16 @@ def save(self,
"""
if sklearn_utils.check_fitted(self):
if backend == "torch":
warnings.warn(
"Saving with backend='torch' is deprecated and will be removed in a future version. "
"Please use backend='sklearn' instead.",
DeprecationWarning,
stacklevel=2,
)
checkpoint = torch.save(self, filename)

elif backend == "sklearn":
checkpoint = torch.save(
{
'args': self.get_params(),
'state': self._get_state(),
'state_dict': self.solver_.state_dict(),
'metadata': {
'backend':
backend,
'cebra_version':
cebra.__version__,
'torch_version':
torch.__version__,
'numpy_version':
np.__version__,
'sklearn_version':
importlib.metadata.distribution("scikit-learn"
).version
}
}, filename)
checkpoint = torch.save(self._get_state_dict(), filename)
else:
raise NotImplementedError(f"Unsupported backend: {backend}")
else:
Expand Down Expand Up @@ -1457,29 +1475,60 @@ def load(cls,
>>> tmp_file.unlink()
"""
supported_backends = ["auto", "sklearn", "torch"]

if backend not in supported_backends:
raise NotImplementedError(
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
)

checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
if backend not in ["auto", "sklearn"]:
warnings.warn(
"From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
"the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
category=DeprecationWarning,
stacklevel=2,
)

if backend == "auto":
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
backend = "sklearn"

# NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
# introduced in torch 2.6.0.
try:
checkpoint = _safe_torch_load(filename, **kwargs)
except pickle.UnpicklingError as e:
if weights_only is not False:
if packaging.version.parse(
cebra.__version__) < packaging.version.parse("0.7"):
warnings.warn(
"Failed to unpickle checkpoint with weights_only=True. "
"Falling back to loading with weights_only=False. "
"This is unsafe and should only be done if you trust the source of the model file. "
"In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
category=UserWarning,
stacklevel=2,
)
else:
raise ValueError(
"Failed to unpickle checkpoint with weights_only=True. "
"This may be due to an incompatible model file format. "
"To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
"Example: CEBRA.load(filename, weights_only=False)."
) from e

if isinstance(checkpoint, dict) and backend == "torch":
raise RuntimeError(
"Cannot use 'torch' backend with a dictionary-based checkpoint. "
"Please try a different backend.")
if not isinstance(checkpoint, dict) and backend == "sklearn":
checkpoint = _safe_torch_load(filename,
weights_only=False,
**kwargs)

if backend != "sklearn":
raise ValueError(f"Unsupported backend: {backend}")

if not isinstance(checkpoint, dict):
raise RuntimeError(
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
"Please try a different backend.")
f"Please try a different backend. Got: {type(checkpoint)}")

if backend == "sklearn":
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
else:
cebra_ = _check_type_checkpoint(checkpoint)
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)

n_features = cebra_.n_features_
cebra_.solver_.n_features = ([
Expand Down
18 changes: 17 additions & 1 deletion cebra/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from __future__ import annotations

import fnmatch
import functools
import itertools
import sys
import textwrap
Expand Down Expand Up @@ -214,14 +215,29 @@ def _zip_dict(d):
yield dict(zip(keys, combination))

def _create_class(cls, **default_kwargs):
    """Create and register one parametrized subclass of ``cls``.

    The subclass is registered under ``pattern.format(**default_kwargs)``
    and pre-binds ``default_kwargs`` as keyword defaults of ``__init__``.
    """
    class_name = pattern.format(**default_kwargs)

    @register(class_name, base=pattern)
    class _ParametrizedClass(cls):

        def __init__(self, *args, **kwargs):
            # Merge into a fresh dict. The previous implementation called
            # ``default_kwargs.update(kwargs)``, mutating the closure dict
            # in place — overrides passed to one instantiation leaked into
            # the defaults of every later instantiation of this class.
            merged_kwargs = {**default_kwargs, **kwargs}
            super().__init__(*args, **merged_kwargs)

    # Make the class pickleable by copying metadata from the base class
    # and registering it in the module namespace
    functools.update_wrapper(_ParametrizedClass, cls, updated=[])

    # Set a unique qualname so pickle finds this class, not the base class
    unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
    _ParametrizedClass.__qualname__ = unique_name
    _ParametrizedClass.__name__ = unique_name

    # Register in module namespace so pickle can find it via getattr
    parent_module = sys.modules.get(cls.__module__)
    if parent_module is not None:
        setattr(parent_module, unique_name, _ParametrizedClass)

def _parametrize(cls):
for _default_kwargs in kwargs:
_create_class(cls, **_default_kwargs)
Expand Down
Loading