diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index a0337e3..f951335 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -29,18 +29,28 @@ jobs:
         # https://pytorch.org/get-started/previous-versions/
         torch-version: ["2.6.0", "2.10.0"]
         sklearn-version: ["latest"]
+        numpy-version: ["latest"]
+
         include:
           # windows test with standard config
           - os: windows-latest
             torch-version: 2.6.0
             python-version: "3.12"
             sklearn-version: "latest"
+            numpy-version: "latest"
           # legacy sklearn (several API differences)
           - os: ubuntu-latest
             torch-version: 2.6.0
             python-version: "3.12"
             sklearn-version: "legacy"
+            numpy-version: "latest"
+
+          - os: ubuntu-latest
+            torch-version: 2.6.0
+            python-version: "3.12"
+            sklearn-version: "latest"
+            numpy-version: "legacy"
 
         # TODO(stes): latest torch and python
         # requires a PyTables release compatible with
@@ -55,6 +65,7 @@ jobs:
             torch-version: 2.4.0
             python-version: "3.10"
             sklearn-version: "legacy"
+            numpy-version: "latest"
 
     runs-on: ${{ matrix.os }}
 
@@ -88,6 +99,11 @@ jobs:
       run: |
         pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
 
+    - name: Check numpy legacy version
+      if: matrix.numpy-version == 'legacy'
+      run: |
+        pip install "numpy<2" '.[dev,datasets,integrations]'
+
     - name: Run the formatter
       run: |
         make format
diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py
index 474145e..0064552 100644
--- a/cebra/integrations/sklearn/cebra.py
+++ b/cebra/integrations/sklearn/cebra.py
@@ -23,6 +23,8 @@
 import importlib.metadata
 import itertools
+import pickle
+import warnings
 from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
                     Union)
 
@@ -50,8 +52,13 @@
 # windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
 # on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
 CEBRA_LOAD_SAFE_GLOBALS = [
-    cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
-    np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType
+    cebra.data.Offset,
+    torch.torch_version.TorchVersion,
+    np.dtype,
+    np.dtypes.Int32DType,
+    np.dtypes.Int64DType,
+    np.dtypes.Float32DType,
+    np.dtypes.Float64DType,
 ]
 
 
@@ -62,20 +69,22 @@ def check_version(estimator):
         sklearn.__version__) < packaging.version.parse("1.6.dev")
 
 
-def _safe_torch_load(filename, weights_only, **kwargs):
-    if weights_only is None:
-        if packaging.version.parse(
-                torch.__version__) >= packaging.version.parse("2.6.0"):
-            weights_only = True
-        else:
-            weights_only = False
+def _safe_torch_load(filename, weights_only=False, **kwargs):
+    checkpoint = None
+    legacy_mode = packaging.version.parse(
+        torch.__version__) < packaging.version.parse("2.6.0")
 
-    if not weights_only:
+    if legacy_mode:
         checkpoint = torch.load(filename, weights_only=False, **kwargs)
     else:
-        # NOTE(stes): This is only supported for torch 2.6+
         with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
-            checkpoint = torch.load(filename, weights_only=True, **kwargs)
+            checkpoint = torch.load(filename,
+                                    weights_only=weights_only,
+                                    **kwargs)
+
+    if not isinstance(checkpoint, dict):
+        _check_type_checkpoint(checkpoint)
+        checkpoint = checkpoint._get_state_dict()
 
     return checkpoint
 
@@ -315,8 +324,9 @@ def _require_arg(key):
 
 def _check_type_checkpoint(checkpoint):
     if not isinstance(checkpoint, cebra.CEBRA):
-        raise RuntimeError("Model loaded from file is not compatible with "
-                           "the current CEBRA version.")
+        raise RuntimeError(
+            "Model loaded from file is not compatible with "
+            f"the current CEBRA version. Got: {type(checkpoint)}")
     if not sklearn_utils.check_fitted(checkpoint):
         raise ValueError(
             "CEBRA model is not fitted. Loading it is not supported.")
@@ -1336,6 +1346,26 @@ def _get_state(self):
         }
         return state
 
+    def _get_state_dict(self):
+        backend = "sklearn"
+        return {
+            'args': self.get_params(),
+            'state': self._get_state(),
+            'state_dict': self.solver_.state_dict(),
+            'metadata': {
+                'backend':
+                    backend,
+                'cebra_version':
+                    cebra.__version__,
+                'torch_version':
+                    torch.__version__,
+                'numpy_version':
+                    np.__version__,
+                'sklearn_version':
+                    importlib.metadata.distribution("scikit-learn").version
+            }
+        }
+
     def save(self,
              filename: str,
              backend: Literal["torch", "sklearn"] = "sklearn"):
@@ -1384,28 +1414,16 @@ def save(self,
 
         """
         if sklearn_utils.check_fitted(self):
             if backend == "torch":
+                warnings.warn(
+                    "Saving with backend='torch' is deprecated and will be removed in a future version. "
+                    "Please use backend='sklearn' instead.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
                 checkpoint = torch.save(self, filename)
             elif backend == "sklearn":
-                checkpoint = torch.save(
-                    {
-                        'args': self.get_params(),
-                        'state': self._get_state(),
-                        'state_dict': self.solver_.state_dict(),
-                        'metadata': {
-                            'backend':
-                                backend,
-                            'cebra_version':
-                                cebra.__version__,
-                            'torch_version':
-                                torch.__version__,
-                            'numpy_version':
-                                np.__version__,
-                            'sklearn_version':
-                                importlib.metadata.distribution("scikit-learn"
-                                                               ).version
-                        }
-                    }, filename)
+                checkpoint = torch.save(self._get_state_dict(), filename)
             else:
                 raise NotImplementedError(f"Unsupported backend: {backend}")
         else:
@@ -1457,29 +1475,60 @@ def load(cls,
         >>> tmp_file.unlink()
         """
         supported_backends = ["auto", "sklearn", "torch"]
+
         if backend not in supported_backends:
             raise NotImplementedError(
                 f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
             )
 
-        checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
+        if backend not in ["auto", "sklearn"]:
+            warnings.warn(
+                "From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
+                "the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
 
-        if backend == "auto":
-            backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
+        backend = "sklearn"
+
+        # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
+        # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
+        # introduced in torch 2.6.0.
+        try:
+            checkpoint = _safe_torch_load(filename, **kwargs)
+        except pickle.UnpicklingError as e:
+            if weights_only is not False:
+                if packaging.version.parse(
+                        cebra.__version__) < packaging.version.parse("0.7"):
+                    warnings.warn(
+                        "Failed to unpickle checkpoint with weights_only=True. "
+                        "Falling back to loading with weights_only=False. "
+                        "This is unsafe and should only be done if you trust the source of the model file. "
+                        "In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
+                        category=UserWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    raise ValueError(
+                        "Failed to unpickle checkpoint with weights_only=True. "
+                        "This may be due to an incompatible model file format. "
+                        "To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
+                        "Example: CEBRA.load(filename, weights_only=False)."
+                    ) from e
 
-        if isinstance(checkpoint, dict) and backend == "torch":
-            raise RuntimeError(
-                "Cannot use 'torch' backend with a dictionary-based checkpoint. "
-                "Please try a different backend.")
-        if not isinstance(checkpoint, dict) and backend == "sklearn":
+            checkpoint = _safe_torch_load(filename,
+                                          weights_only=False,
+                                          **kwargs)
+
+        if backend != "sklearn":
+            raise ValueError(f"Unsupported backend: {backend}")
+
+        if not isinstance(checkpoint, dict):
             raise RuntimeError(
                 "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
-                "Please try a different backend.")
+                f"Please try a different backend. Got: {type(checkpoint)}")
 
-        if backend == "sklearn":
-            cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
-        else:
-            cebra_ = _check_type_checkpoint(checkpoint)
+        cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
 
         n_features = cebra_.n_features_
         cebra_.solver_.n_features = ([
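Taken together, the loading changes above mean that `CEBRA.load` now always goes through the dictionary-based sklearn checkpoint format, and falls back to unrestricted unpickling for legacy checkpoints with a warning today and an explicit `weights_only=False` opt-in required from 0.7.0 onwards. A minimal sketch of the resulting round trip, using random data and placeholder file names (paths and hyperparameters here are illustrative, not taken from the patch):

```python
import numpy as np

import cebra

# Illustrative data and file names only.
X = np.random.normal(0, 1, (100, 10))

model = cebra.CEBRA(max_iterations=10, batch_size=32, device="cpu")
model.fit(X)

# Default save uses the sklearn backend: a plain dict written via torch.save.
model.save("model.pt")

# Default load takes the safe path (weights_only) with the allow-listed globals.
loaded = cebra.CEBRA.load("model.pt")

# Older checkpoints written with backend="torch" may require an explicit opt-out,
# mirroring the error message introduced above.
legacy = cebra.CEBRA.load("old_torch_model.pt", weights_only=False)
```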
diff --git a/cebra/registry.py b/cebra/registry.py
index 994fbd5..1bbc509 100644
--- a/cebra/registry.py
+++ b/cebra/registry.py
@@ -46,6 +46,7 @@
 from __future__ import annotations
 
 import fnmatch
+import functools
 import itertools
 import sys
 import textwrap
@@ -214,14 +215,29 @@ def _zip_dict(d):
             yield dict(zip(keys, combination))
 
     def _create_class(cls, **default_kwargs):
+        class_name = pattern.format(**default_kwargs)
 
-        @register(pattern.format(**default_kwargs), base=pattern)
+        @register(class_name, base=pattern)
         class _ParametrizedClass(cls):
 
             def __init__(self, *args, **kwargs):
                 default_kwargs.update(kwargs)
                 super().__init__(*args, **default_kwargs)
 
+        # Make the class pickleable by copying metadata from the base class
+        # and registering it in the module namespace
+        functools.update_wrapper(_ParametrizedClass, cls, updated=[])
+
+        # Set a unique qualname so pickle finds this class, not the base class
+        unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
+        _ParametrizedClass.__qualname__ = unique_name
+        _ParametrizedClass.__name__ = unique_name
+
+        # Register in module namespace so pickle can find it via getattr
+        parent_module = sys.modules.get(cls.__module__)
+        if parent_module is not None:
+            setattr(parent_module, unique_name, _ParametrizedClass)
+
     def _parametrize(cls):
         for _default_kwargs in kwargs:
             _create_class(cls, **_default_kwargs)
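The registry change matters because `backend="torch"` saving pickles the whole estimator via `torch.save(self, filename)`, and pickle resolves classes by module and qualified name; the dynamically generated parametrized classes previously could not be found that way. A rough sketch of the behaviour this enables, reusing the test suite's "parametrized-model-5" architecture (registered in tests/test_sklearn.py) and a short fit on random data:

```python
import pickle

import numpy as np

import cebra

# "parametrized-model-5" is registered in the test suite through the parametrize
# machinery patched above; with the unique __qualname__ and module registration,
# pickle can resolve the generated class again when deserializing.
model = cebra.CEBRA(model_architecture="parametrized-model-5",
                    max_iterations=5,
                    batch_size=100,
                    device="cpu")
model.fit(np.random.randn(1000, 10))

blob = pickle.dumps(model)      # previously failed with an AttributeError
restored = pickle.loads(blob)
```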
diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py
index 7dfbac0..831ad49 100644
--- a/tests/test_sklearn.py
+++ b/tests/test_sklearn.py
@@ -19,7 +19,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
+import contextlib
 import itertools
+import os
 import tempfile
 import warnings
@@ -47,6 +49,34 @@
 
 _DEVICES = ("cpu",)
 
+
+@contextlib.contextmanager
+def _windows_compatible_tempfile(mode="w+b", delete=True, **kwargs):
+    """Context manager for creating temporary files compatible with Windows.
+
+    On Windows, files opened with delete=True cannot be accessed by other
+    processes or reopened. This context manager creates a temporary file
+    with delete=False, yields its path, and ensures cleanup in a finally block.
+
+    Args:
+        mode: File mode (default: "w+b")
+        **kwargs: Additional arguments passed to NamedTemporaryFile
+
+    Yields:
+        str: Path to the temporary file
+    """
+    if not delete:
+        raise ValueError("'delete' must be True")
+
+    with tempfile.NamedTemporaryFile(mode=mode, delete=False, **kwargs) as f:
+        tempname = f.name
+
+    try:
+        yield tempname
+    finally:
+        if os.path.exists(tempname):
+            os.remove(tempname)
+
+
 def test_imports():
     import cebra
@@ -980,7 +1010,11 @@ def _assert_equal(original_model, loaded_model):
     if check_if_fit(loaded_model):
         _assert_same_state_dict(original_model.state_dict_,
                                 loaded_model.state_dict_)
-        X = np.random.normal(0, 1, (100, 1))
+
+        n_features = loaded_model.n_features_
+        if isinstance(n_features, list):
+            n_features = n_features[0]
+        X = np.random.normal(0, 1, (100, n_features))
 
         if loaded_model.num_sessions is not None:
             assert np.allclose(loaded_model.transform(X, session_id=0),
@@ -1027,7 +1061,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
 @pytest.mark.parametrize("model_architecture",
                          ["offset1-model", "parametrized-model-5"])
 @pytest.mark.parametrize("device", ["cpu"] +
-                         ["cuda"] if torch.cuda.is_available() else [])
+                         (["cuda"] if torch.cuda.is_available() else []))
 def test_save_and_load(action, backend_save, backend_load, model_architecture,
                        device):
     original_model = cebra_sklearn_cebra.CEBRA(
@@ -1037,26 +1071,19 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture,
         device=device)
     original_model = action(original_model)
 
-    with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
         if not check_if_fit(original_model):
             with pytest.raises(ValueError):
-                original_model.save(savefile.name, backend=backend_save)
+                original_model.save(tempname, backend=backend_save)
         else:
-            if "parametrized" in original_model.model_architecture and backend_save == "torch":
-                with pytest.raises(AttributeError):
-                    original_model.save(savefile.name, backend=backend_save)
-            else:
-                original_model.save(savefile.name, backend=backend_save)
-
-            if (backend_load != "auto") and (backend_save != backend_load):
-                with pytest.raises(RuntimeError):
-                    cebra_sklearn_cebra.CEBRA.load(savefile.name,
-                                                   backend_load)
-            else:
-                loaded_model = cebra_sklearn_cebra.CEBRA.load(
-                    savefile.name, backend_load)
-                _assert_equal(original_model, loaded_model)
-                action(loaded_model)
+            original_model.save(tempname, backend=backend_save)
+
+            weights_only = None
+
+            loaded_model = cebra_sklearn_cebra.CEBRA.load(
+                tempname, backend_load, weights_only=weights_only)
+            _assert_equal(original_model, loaded_model)
+            action(loaded_model)
 
 
 def get_ordered_cuda_devices():
@@ -1130,9 +1157,9 @@ def test_move_cpu_to_cuda_device(device):
         device_str = f'cuda:{device_model.index}'
     assert device_str == new_device
 
-    with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
-        cebra_model.save(savefile.name)
-        loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
+        cebra_model.save(tempname)
+        loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
 
     assert cebra_model.device == loaded_model.device
     assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1159,9 +1186,9 @@ def test_move_cpu_to_mps_device(device):
     device_model = next(cebra_model.solver_.model.parameters()).device
     assert device_model.type == new_device
 
-    with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
-        cebra_model.save(savefile.name)
-        loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
+        cebra_model.save(tempname)
+        loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
 
     assert cebra_model.device == loaded_model.device
     assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1198,9 +1225,9 @@ def test_move_mps_to_cuda_device(device):
         device_str = f'cuda:{device_model.index}'
     assert device_str == new_device
 
-    with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
-        cebra_model.save(savefile.name)
-        loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
+        cebra_model.save(tempname)
+        loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
 
     assert cebra_model.device == loaded_model.device
     assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1460,7 +1487,7 @@ def test_new_transform(model_architecture, device):
                                                               X,
                                                               session_id=0)
     assert np.allclose(embedding1, embedding2, rtol=1e-5,
-                       atol=1e-8), "Arrays are not close enough"
+                       atol=1e-8), " are not close enough"
 
     embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
     embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
@@ -1561,3 +1588,35 @@ def test_non_writable_array():
     embedding = cebra_model.transform(X)
     assert isinstance(embedding, np.ndarray)
     assert embedding.shape[0] == X.shape[0]
+
+
+def test_read_write():
+    X = np.random.randn(100, 10)
+    y = np.random.randn(100, 2)
+    cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
+    cebra_model.fit(X, y)
+    cebra_model.transform(X)
+
+    with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname:
+        cebra_model.save(tempname)
+        loaded_model = cebra.CEBRA.load(tempname)
+        _assert_equal(cebra_model, loaded_model)
+
+
+def test_repro_pickle_error():
+    """The torch backend for save/loading fails with python 3.14.
+
+    See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/292.
+
+    This test is a minimal repro of the error.
+    """
+
+    model = cebra_sklearn_cebra.CEBRA(model_architecture='parametrized-model-5',
+                                      max_iterations=5,
+                                      batch_size=100,
+                                      device='cpu')
+
+    model.fit(np.random.randn(1000, 10))
+
+    with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname:
+        model.save(tempname, backend="torch")
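One behaviour the updated tests do not assert explicitly is the new DeprecationWarning emitted when saving with backend="torch". A small, hypothetical follow-up test in the same style (the test name and the use of pytest's tmp_path fixture are assumptions, not part of the patch):

```python
import numpy as np
import pytest

import cebra


def test_torch_backend_deprecation_warning(tmp_path):
    # Hypothetical test: backend="torch" saving should now warn before pickling.
    model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
    model.fit(np.random.randn(100, 10))

    with pytest.warns(DeprecationWarning, match="backend='torch' is deprecated"):
        model.save(str(tmp_path / "model.pt"), backend="torch")
```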