From 214bf6a6d2f7c08f5228a74878d31c06b59b5de0 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 11:24:00 +0000 Subject: [PATCH 1/9] fix accedentally skipped save/load test --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7dfbac0..145c466 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1027,7 +1027,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + - ["cuda"] if torch.cuda.is_available() else []) + (["cuda"] if torch.cuda.is_available() else [])) def test_save_and_load(action, backend_save, backend_load, model_architecture, device): original_model = cebra_sklearn_cebra.CEBRA( From a893974feb0c3306350030000229f967dec4ca22 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 11:33:52 +0000 Subject: [PATCH 2/9] include a numpy legacy test --- .github/workflows/build.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0337e3..f951335 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,18 +29,28 @@ jobs: # https://pytorch.org/get-started/previous-versions/ torch-version: ["2.6.0", "2.10.0"] sklearn-version: ["latest"] + numpy-version: ["latest"] + include: # windows test with standard config - os: windows-latest torch-version: 2.6.0 python-version: "3.12" sklearn-version: "latest" + numpy-version: "latest" # legacy sklearn (several API differences) - os: ubuntu-latest torch-version: 2.6.0 python-version: "3.12" sklearn-version: "legacy" + numpy-version: "latest" + + - os: ubuntu-latest + torch-version: 2.6.0 + python-version: "3.12" + sklearn-version: "latest" + numpy-version: "legacy" # TODO(stes): latest torch and python # requires a PyTables release compatible with @@ -55,6 +65,7 @@ jobs: torch-version: 2.4.0 python-version: "3.10" sklearn-version: "legacy" + numpy-version: "latest" runs-on: ${{ matrix.os }} @@ -88,6 +99,11 @@ jobs: run: | pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Check numpy legacy version + if: matrix.numpy-version == 'legacy' + run: | + pip install "numpy<2" '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format From 3e98794f90c608188a571c1bf9cd9228d25147a2 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 12:45:12 +0000 Subject: [PATCH 3/9] fix windows compatibility for tempfile --- tests/test_sklearn.py | 74 +++++++++++++++++++++++++++++++++---------- 1 file changed, 58 insertions(+), 16 deletions(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 145c466..7bee251 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -19,7 +19,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import contextlib import itertools +import os import tempfile import warnings @@ -47,6 +49,34 @@ _DEVICES = ("cpu",) +@contextlib.contextmanager +def _windows_compatible_tempfile(mode="w+b", delete=True, **kwargs): + """Context manager for creating temporary files compatible with Windows. + + On Windows, files opened with delete=True cannot be accessed by other + processes or reopened. This context manager creates a temporary file + with delete=False, yields its path, and ensures cleanup in a finally block. + + Args: + mode: File mode (default: "w+b") + **kwargs: Additional arguments passed to NamedTemporaryFile + + Yields: + str: Path to the temporary file + """ + if not delete: + raise ValueError("'delete' must be True") + + with tempfile.NamedTemporaryFile(mode=mode, delete=False, **kwargs) as f: + tempname = f.name + + try: + yield tempname + finally: + if os.path.exists(tempname): + os.remove(tempname) + + def test_imports(): import cebra @@ -1037,24 +1067,23 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture, device=device) original_model = action(original_model) - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: + with _windows_compatible_tempfile(mode="w+b") as tempname: if not check_if_fit(original_model): with pytest.raises(ValueError): - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) else: if "parametrized" in original_model.model_architecture and backend_save == "torch": with pytest.raises(AttributeError): - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) else: - original_model.save(savefile.name, backend=backend_save) + original_model.save(tempname, backend=backend_save) if (backend_load != "auto") and (backend_save != backend_load): with pytest.raises(RuntimeError): - cebra_sklearn_cebra.CEBRA.load(savefile.name, - backend_load) + cebra_sklearn_cebra.CEBRA.load(tempname, backend_load) else: loaded_model = cebra_sklearn_cebra.CEBRA.load( - savefile.name, backend_load) + tempname, backend_load) _assert_equal(original_model, loaded_model) action(loaded_model) @@ -1130,9 +1159,9 @@ def test_move_cpu_to_cuda_device(device): device_str = f'cuda:{device_model.index}' assert device_str == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1159,9 +1188,9 @@ def test_move_cpu_to_mps_device(device): device_model = next(cebra_model.solver_.model.parameters()).device assert device_model.type == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1198,9 +1227,9 @@ def test_move_mps_to_cuda_device(device): device_str = f'cuda:{device_model.index}' assert device_str == new_device - with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - cebra_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) + with _windows_compatible_tempfile(mode="w+b") as tempname: + cebra_model.save(tempname) + loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) assert cebra_model.device == loaded_model.device assert next(cebra_model.solver_.model.parameters()).device == next( @@ -1561,3 +1590,16 @@ def test_non_writable_array(): embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) assert embedding.shape[0] == X.shape[0] + + +def test_read_write(): + X = np.random.randn(100, 10) + y = np.random.randn(100, 2) + cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu") + cebra_model.fit(X, y) + cebra_model.transform(X) + + with _windows_compatible_tempfile(mode="w+b", delete=False) as tempname: + cebra_model.save(tempname) + loaded_model = cebra.CEBRA.load(tempname) + _assert_equal(cebra_model, loaded_model) From d71effae6ddbcef044799589236358026f34e0e8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 12:59:30 +0000 Subject: [PATCH 4/9] fix added test --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7bee251..7d13383 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1599,7 +1599,7 @@ def test_read_write(): cebra_model.fit(X, y) cebra_model.transform(X) - with _windows_compatible_tempfile(mode="w+b", delete=False) as tempname: + with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname: cebra_model.save(tempname) loaded_model = cebra.CEBRA.load(tempname) _assert_equal(cebra_model, loaded_model) From 7bc5ca353faba8d08a081f27f9004a61d132faf3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:08:40 +0000 Subject: [PATCH 5/9] Fix legacy loading logic --- cebra/integrations/sklearn/cebra.py | 99 +++++++++++++++++++++-------- cebra/registry.py | 18 +++++- tests/test_sklearn.py | 43 ++++++++----- 3 files changed, 118 insertions(+), 42 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 474145e..b9f9c79 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -23,6 +23,8 @@ import importlib.metadata import itertools +import pickle +import warnings from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union) @@ -1336,6 +1338,26 @@ def _get_state(self): } return state + def _get_state_dict(self): + backend = "sklearn" + return { + 'args': self.get_params(), + 'state': self._get_state(), + 'state_dict': self.solver_.state_dict(), + 'metadata': { + 'backend': + backend, + 'cebra_version': + cebra.__version__, + 'torch_version': + torch.__version__, + 'numpy_version': + np.__version__, + 'sklearn_version': + importlib.metadata.distribution("scikit-learn").version + } + } + def save(self, filename: str, backend: Literal["torch", "sklearn"] = "sklearn"): @@ -1384,28 +1406,16 @@ def save(self, """ if sklearn_utils.check_fitted(self): if backend == "torch": + warnings.warn( + "Saving with backend='torch' is deprecated and will be removed in a future version. " + "Please use backend='sklearn' instead.", + DeprecationWarning, + stacklevel=2, + ) checkpoint = torch.save(self, filename) elif backend == "sklearn": - checkpoint = torch.save( - { - 'args': self.get_params(), - 'state': self._get_state(), - 'state_dict': self.solver_.state_dict(), - 'metadata': { - 'backend': - backend, - 'cebra_version': - cebra.__version__, - 'torch_version': - torch.__version__, - 'numpy_version': - np.__version__, - 'sklearn_version': - importlib.metadata.distribution("scikit-learn" - ).version - } - }, filename) + checkpoint = torch.save(self._get_state_dict(), filename) else: raise NotImplementedError(f"Unsupported backend: {backend}") else: @@ -1457,15 +1467,52 @@ def load(cls, >>> tmp_file.unlink() """ supported_backends = ["auto", "sklearn", "torch"] + if backend not in supported_backends: raise NotImplementedError( f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}" ) - checkpoint = _safe_torch_load(filename, weights_only, **kwargs) + if backend not in ["auto", "sklearn"]: + warnings.warn( + "From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; " + "the sklearn backend is now always used. Models saved with the torch backend can still be loaded.", + category=DeprecationWarning, + stacklevel=2, + ) - if backend == "auto": - backend = "sklearn" if isinstance(checkpoint, dict) else "torch" + backend = "sklearn" + + # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards, + # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes + # introduced in torch 2.6.0. + try: + checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs) + except pickle.UnpicklingError as e: + if weights_only is not False: + if packaging.version.parse( + cebra.__version__) < packaging.version.parse("0.7"): + warnings.warn( + "Failed to unpickle checkpoint with weights_only=True. " + "Falling back to loading with weights_only=False. " + "This is unsafe and should only be done if you trust the source of the model file. " + "In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.", + category=UserWarning, + stacklevel=2, + ) + else: + raise ValueError( + "Failed to unpickle checkpoint with weights_only=True. " + "This may be due to an incompatible model file format. " + "To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. " + "Example: CEBRA.load(filename, weights_only=False)." + ) from e + + checkpoint = _safe_torch_load(filename, + weights_only=False, + **kwargs) + checkpoint = _check_type_checkpoint(checkpoint) + checkpoint = checkpoint._get_state_dict() if isinstance(checkpoint, dict) and backend == "torch": raise RuntimeError( @@ -1476,10 +1523,10 @@ def load(cls, "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " "Please try a different backend.") - if backend == "sklearn": - cebra_ = _load_cebra_with_sklearn_backend(checkpoint) - else: - cebra_ = _check_type_checkpoint(checkpoint) + if backend != "sklearn": + raise ValueError(f"Unsupported backend: {backend}") + + cebra_ = _load_cebra_with_sklearn_backend(checkpoint) n_features = cebra_.n_features_ cebra_.solver_.n_features = ([ diff --git a/cebra/registry.py b/cebra/registry.py index 994fbd5..1bbc509 100644 --- a/cebra/registry.py +++ b/cebra/registry.py @@ -46,6 +46,7 @@ from __future__ import annotations import fnmatch +import functools import itertools import sys import textwrap @@ -214,14 +215,29 @@ def _zip_dict(d): yield dict(zip(keys, combination)) def _create_class(cls, **default_kwargs): + class_name = pattern.format(**default_kwargs) - @register(pattern.format(**default_kwargs), base=pattern) + @register(class_name, base=pattern) class _ParametrizedClass(cls): def __init__(self, *args, **kwargs): default_kwargs.update(kwargs) super().__init__(*args, **default_kwargs) + # Make the class pickleable by copying metadata from the base class + # and registering it in the module namespace + functools.update_wrapper(_ParametrizedClass, cls, updated=[]) + + # Set a unique qualname so pickle finds this class, not the base class + unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}" + _ParametrizedClass.__qualname__ = unique_name + _ParametrizedClass.__name__ = unique_name + + # Register in module namespace so pickle can find it via getattr + parent_module = sys.modules.get(cls.__module__) + if parent_module is not None: + setattr(parent_module, unique_name, _ParametrizedClass) + def _parametrize(cls): for _default_kwargs in kwargs: _create_class(cls, **_default_kwargs) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 7d13383..999bc7f 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1053,7 +1053,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("action", _iterate_actions()) @pytest.mark.parametrize("backend_save", ["torch", "sklearn"]) -@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"]) +@pytest.mark.parametrize("backend_load", ["sklearn", "auto", "torch"]) @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + @@ -1072,20 +1072,14 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture, with pytest.raises(ValueError): original_model.save(tempname, backend=backend_save) else: - if "parametrized" in original_model.model_architecture and backend_save == "torch": - with pytest.raises(AttributeError): - original_model.save(tempname, backend=backend_save) - else: - original_model.save(tempname, backend=backend_save) + original_model.save(tempname, backend=backend_save) + + weights_only = None - if (backend_load != "auto") and (backend_save != backend_load): - with pytest.raises(RuntimeError): - cebra_sklearn_cebra.CEBRA.load(tempname, backend_load) - else: - loaded_model = cebra_sklearn_cebra.CEBRA.load( - tempname, backend_load) - _assert_equal(original_model, loaded_model) - action(loaded_model) + loaded_model = cebra_sklearn_cebra.CEBRA.load( + tempname, backend_load, weights_only=weights_only) + _assert_equal(original_model, loaded_model) + action(loaded_model) def get_ordered_cuda_devices(): @@ -1489,7 +1483,7 @@ def test_new_transform(model_architecture, device): X, session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, - atol=1e-8), "Arrays are not close enough" + atol=1e-8), " are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, @@ -1603,3 +1597,22 @@ def test_read_write(): cebra_model.save(tempname) loaded_model = cebra.CEBRA.load(tempname) _assert_equal(cebra_model, loaded_model) + + +def test_repro_pickle_error(): + """The torch backend for save/loading fails with python 3.14. + + See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/292. + + This test is a minimal repro of the error. + """ + + model = cebra_sklearn_cebra.CEBRA(model_architecture='parametrized-model-5', + max_iterations=5, + batch_size=100, + device='cpu') + + model.fit(np.random.randn(1000, 10)) + + with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname: + model.save(tempname, backend="torch") From 52023c9965952c9a4a8ca79f35573c68cec266db Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:20:46 +0000 Subject: [PATCH 6/9] minimize diff in tests --- tests/test_sklearn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 999bc7f..de3cec4 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1053,7 +1053,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset: @pytest.mark.parametrize("action", _iterate_actions()) @pytest.mark.parametrize("backend_save", ["torch", "sklearn"]) -@pytest.mark.parametrize("backend_load", ["sklearn", "auto", "torch"]) +@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"]) @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) @pytest.mark.parametrize("device", ["cpu"] + From abdba0047219cf20cc1fc989f79b03f475d4b6df Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:29:07 +0000 Subject: [PATCH 7/9] Fix _assert_equal check --- tests/test_sklearn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index de3cec4..831ad49 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1010,7 +1010,11 @@ def _assert_equal(original_model, loaded_model): if check_if_fit(loaded_model): _assert_same_state_dict(original_model.state_dict_, loaded_model.state_dict_) - X = np.random.normal(0, 1, (100, 1)) + + n_features = loaded_model.n_features_ + if isinstance(n_features, list): + n_features = n_features[0] + X = np.random.normal(0, 1, (100, n_features)) if loaded_model.num_sessions is not None: assert np.allclose(loaded_model.transform(X, session_id=0), From 2aad99a9593c31c7c7a905cf00f8e8ac6f02f5d8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 14:50:33 +0000 Subject: [PATCH 8/9] fix loading logic for legacy torch --- cebra/integrations/sklearn/cebra.py | 45 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index b9f9c79..3de6d4f 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -64,20 +64,22 @@ def check_version(estimator): sklearn.__version__) < packaging.version.parse("1.6.dev") -def _safe_torch_load(filename, weights_only, **kwargs): - if weights_only is None: - if packaging.version.parse( - torch.__version__) >= packaging.version.parse("2.6.0"): - weights_only = True - else: - weights_only = False +def _safe_torch_load(filename, weights_only=False, **kwargs): + checkpoint = None + legacy_mode = packaging.version.parse( + torch.__version__) < packaging.version.parse("2.6.0") - if not weights_only: + if legacy_mode: checkpoint = torch.load(filename, weights_only=False, **kwargs) else: - # NOTE(stes): This is only supported for torch 2.6+ with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): - checkpoint = torch.load(filename, weights_only=True, **kwargs) + checkpoint = torch.load(filename, + weights_only=weights_only, + **kwargs) + + if not isinstance(checkpoint, dict): + _check_type_checkpoint(checkpoint) + checkpoint = checkpoint._get_state_dict() return checkpoint @@ -317,8 +319,9 @@ def _require_arg(key): def _check_type_checkpoint(checkpoint): if not isinstance(checkpoint, cebra.CEBRA): - raise RuntimeError("Model loaded from file is not compatible with " - "the current CEBRA version.") + raise RuntimeError( + "Model loaded from file is not compatible with " + f"the current CEBRA version. Got: {type(checkpoint)}") if not sklearn_utils.check_fitted(checkpoint): raise ValueError( "CEBRA model is not fitted. Loading it is not supported.") @@ -1487,7 +1490,7 @@ def load(cls, # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes # introduced in torch 2.6.0. try: - checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs) + checkpoint = _safe_torch_load(filename, **kwargs) except pickle.UnpicklingError as e: if weights_only is not False: if packaging.version.parse( @@ -1511,21 +1514,15 @@ def load(cls, checkpoint = _safe_torch_load(filename, weights_only=False, **kwargs) - checkpoint = _check_type_checkpoint(checkpoint) - checkpoint = checkpoint._get_state_dict() - - if isinstance(checkpoint, dict) and backend == "torch": - raise RuntimeError( - "Cannot use 'torch' backend with a dictionary-based checkpoint. " - "Please try a different backend.") - if not isinstance(checkpoint, dict) and backend == "sklearn": - raise RuntimeError( - "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " - "Please try a different backend.") if backend != "sklearn": raise ValueError(f"Unsupported backend: {backend}") + if not isinstance(checkpoint, dict): + raise RuntimeError( + "Cannot use 'sklearn' backend a non dictionary-based checkpoint. " + f"Please try a different backend. Got: {type(checkpoint)}") + cebra_ = _load_cebra_with_sklearn_backend(checkpoint) n_features = cebra_.n_features_ From bd27653db9885f6dd8103937720dea037fd687d3 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 1 Feb 2026 15:06:45 +0000 Subject: [PATCH 9/9] allowlist float32d --- cebra/integrations/sklearn/cebra.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 3de6d4f..0064552 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -52,8 +52,13 @@ # windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072) # on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn) CEBRA_LOAD_SAFE_GLOBALS = [ - cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype, - np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType + cebra.data.Offset, + torch.torch_version.TorchVersion, + np.dtype, + np.dtypes.Int32DType, + np.dtypes.Int64DType, + np.dtypes.Float32DType, + np.dtypes.Float64DType, ]