From 712808de42ee266493f0f26cb04ef82c65091a83 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Wed, 28 May 2025 12:06:05 -0700 Subject: [PATCH 1/8] Enable CrossFit to work without Dask Signed-off-by: Sarah Yurick --- crossfit/backend/__init__.py | 15 ++++-- crossfit/backend/dask/cluster.py | 13 +++-- crossfit/backend/gpu.py | 22 ++++++--- crossfit/data/array/conversion.py | 4 +- crossfit/data/array/dispatch.py | 4 +- crossfit/data/dataframe/dispatch.py | 15 ++++-- crossfit/data/sparse/dispatch.py | 5 +- crossfit/dataset/base.py | 9 ++-- crossfit/op/base.py | 18 ++++--- crossfit/op/vector_search.py | 18 ++++--- crossfit/utils/dispatch_utils.py | 77 ++++++++++++++++++++++++++++- requirements/base.txt | 2 - requirements/cuda12x.txt | 3 -- requirements/dask.txt | 2 + requirements/dask_cuda12x.txt | 3 ++ setup.py | 10 ++-- 16 files changed, 167 insertions(+), 53 deletions(-) create mode 100644 requirements/dask.txt create mode 100644 requirements/dask_cuda12x.txt diff --git a/crossfit/backend/__init__.py b/crossfit/backend/__init__.py index 32b40602..9e6328b4 100644 --- a/crossfit/backend/__init__.py +++ b/crossfit/backend/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,12 @@ # flake8: noqa import logging -from crossfit.backend.dask.dataframe import * +try: + from crossfit.backend.dask.dataframe import * +except ImportError: + logging.warning("Import error for Dask backend in CrossFit. Skipping it.") + pass + from crossfit.backend.numpy.sparse import * from crossfit.backend.pandas.array import * from crossfit.backend.pandas.dataframe import * @@ -24,20 +29,20 @@ from crossfit.backend.cudf.array import * from crossfit.backend.cudf.dataframe import * except ImportError: - logging.warning("Import Error for cudf backend in Crossfit. Skipping it.") + logging.warning("Import error for cuDF backend in Crossfit. Skipping it.") pass try: from crossfit.backend.cupy.array import * from crossfit.backend.cupy.sparse import * except ImportError: - logging.warning("Import Error for cupy backend in Crossfit. Skipping it.") + logging.warning("Import error for CuPy backend in Crossfit. Skipping it.") pass try: from crossfit.backend.torch.array import * except ImportError: - logging.warning("Import Error for Torch backend in Crossfit. Skipping it.") + logging.warning("Import error for Torch backend in Crossfit. Skipping it.") pass # from crossfit.backend.tf.array import * diff --git a/crossfit/backend/dask/cluster.py b/crossfit/backend/dask/cluster.py index 04b54a8c..2acef612 100644 --- a/crossfit/backend/dask/cluster.py +++ b/crossfit/backend/dask/cluster.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,9 +18,12 @@ from contextvars import ContextVar from typing import Any, Callable, Optional -import dask -import distributed -from dask.distributed import Client, get_client +try: + import dask + import distributed + from dask.distributed import Client, get_client +except ImportError: + pass from crossfit.backend.gpu import HAS_GPU @@ -403,7 +406,7 @@ def set_dask_client(client="auto", new_cluster=None, force_new=False, **cluster_ return None if active == "auto" else active -def global_dask_client() -> Optional[distributed.Client]: +def global_dask_client() -> Optional["distributed.Client"]: """Get Global Dask client if it's been set. Returns diff --git a/crossfit/backend/gpu.py b/crossfit/backend/gpu.py index 3d582e99..92de1c98 100644 --- a/crossfit/backend/gpu.py +++ b/crossfit/backend/gpu.py @@ -1,23 +1,26 @@ -# -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# # pylint: disable=unused-import import os -from dask.distributed.diagnostics import nvml +try: + from dask.distributed.diagnostics.nvml import device_get_count +except ImportError: + import pynvml + + device_get_count = None def _get_gpu_count(): @@ -30,7 +33,14 @@ def _get_gpu_count(): # that are incompatible with Dask-CUDA. If CUDA runtime functions are # called before Dask-CUDA can spawn worker processes # then Dask-CUDA it will not work correctly (raises an exception) - nvml_device_count = nvml.device_get_count() + if device_get_count is not None: + nvml_device_count = device_get_count() + else: + try: + nvml_device_count = pynvml.nvmlDeviceGetCount() + except Exception: + nvml_device_count = 0 + if nvml_device_count == 0: return 0 try: diff --git a/crossfit/data/array/conversion.py b/crossfit/data/array/conversion.py index 77b8163e..0f3c3e6c 100644 --- a/crossfit/data/array/conversion.py +++ b/crossfit/data/array/conversion.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ from typing import Any, Type, TypeVar import numpy as np -from dask.utils import Dispatch from crossfit.utils import dispatch_utils +from crossfit.utils.dispatch_utils import Dispatch InputType = TypeVar("InputType") IntermediateType = TypeVar("IntermediateType") diff --git a/crossfit/data/array/dispatch.py b/crossfit/data/array/dispatch.py index f8154f92..c454a482 100644 --- a/crossfit/data/array/dispatch.py +++ b/crossfit/data/array/dispatch.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,9 +17,9 @@ from typing import TypeVar import numpy as np -from dask.utils import Dispatch from crossfit.utils import dispatch_utils +from crossfit.utils.dispatch_utils import Dispatch class NPBackendDispatch(Dispatch): diff --git a/crossfit/data/dataframe/dispatch.py b/crossfit/data/dataframe/dispatch.py index 011a077b..7f41a25c 100644 --- a/crossfit/data/dataframe/dispatch.py +++ b/crossfit/data/dataframe/dispatch.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dask.utils import Dispatch +from crossfit.utils.dispatch_utils import Dispatch class _CrossFrameDispatch(Dispatch): @@ -22,11 +22,18 @@ def __call__(self, data, *args, **kwargs): if isinstance(data, FrameBackend): return data + backends = [] # TODO: Fix this - from crossfit.backend.dask.dataframe import DaskDataFrame + try: + from crossfit.backend.dask.dataframe import DaskDataFrame + + backends.append(DaskDataFrame) + except ImportError: + pass + from crossfit.backend.pandas.dataframe import PandasDataFrame - backends = [PandasDataFrame, DaskDataFrame] + backends.append(PandasDataFrame) try: from crossfit.backend.cudf.dataframe import CudfDataFrame diff --git a/crossfit/data/sparse/dispatch.py b/crossfit/data/sparse/dispatch.py index 45c0d6e8..b2293010 100644 --- a/crossfit/data/sparse/dispatch.py +++ b/crossfit/data/sparse/dispatch.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dask.utils import Dispatch - from crossfit.data.sparse.core import SparseMatrixProtocol +from crossfit.utils.dispatch_utils import Dispatch class _CrossSparseDispatch(Dispatch, SparseMatrixProtocol): diff --git a/crossfit/dataset/base.py b/crossfit/dataset/base.py index 0a9094ab..e9743f5f 100644 --- a/crossfit/dataset/base.py +++ b/crossfit/dataset/base.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,10 @@ from dataclasses import dataclass from typing import Dict, Optional, Union -import dask_cudf +try: + from dask_cudf import read_parquet +except ImportError: + from cudf import read_parquet _SPLIT_ALIASES = { "val": ["validation", "valid", "dev"], @@ -36,7 +39,7 @@ class Dataset: engine: str = "parquet" def ddf(self): - return dask_cudf.read_parquet(self.path) + return read_parquet(self.path) class FromDirMixin: diff --git a/crossfit/op/base.py b/crossfit/op/base.py index 38f5df36..7a16a898 100644 --- a/crossfit/op/base.py +++ b/crossfit/op/base.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,8 +14,14 @@ import inspect -import dask.dataframe as dd -from dask.distributed import get_worker, wait +try: + import dask.dataframe as dd + from dask.distributed import get_worker as get_dask_worker + from dask.distributed import wait +except ImportError: + dd = None + get_dask_worker = None + from tqdm.auto import tqdm from crossfit.backend.dask.cluster import global_dask_client @@ -42,13 +48,13 @@ def meta(self): def get_worker(self): try: - worker = get_worker() + worker = get_dask_worker() if get_dask_worker is not None else self except ValueError: worker = self return worker - def call_dask(self, data: dd.DataFrame): + def call_dask(self, data: "dd.DataFrame"): output = data.map_partitions(self, meta=self._build_dask_meta(data)) if global_dask_client(): @@ -79,7 +85,7 @@ def add_keep_cols(self, data, output): return output def __call__(self, data, *args, partition_info=None, **kwargs): - if isinstance(data, dd.DataFrame): + if dd is not None and isinstance(data, dd.DataFrame): output = self.call_dask(data, *args, **kwargs) self.teardown() return output diff --git a/crossfit/op/vector_search.py b/crossfit/op/vector_search.py index 1491c77c..a9ba7fda 100644 --- a/crossfit/op/vector_search.py +++ b/crossfit/op/vector_search.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,13 +17,17 @@ import cudf import cupy as cp import cuvs -import dask.dataframe as dd import pylibraft -from cuml.dask.neighbors import NearestNeighbors -from dask import delayed -from dask_cudf import from_delayed from packaging.version import parse as parse_version +try: + import dask.dataframe as dd + from cuml.dask.neighbors import NearestNeighbors + from dask import delayed + from dask_cudf import from_delayed +except ImportError: + pass + from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar from crossfit.backend.dask.cluster import global_dask_client from crossfit.dataset.base import EmbeddingDatataset @@ -295,8 +299,8 @@ def _get_embedding_cupy(data, embedding_col, normalize=True): def _per_dim_ddf( - data: dd.DataFrame, embedding_col: str, index_col: str = "index", normalize: bool = True -) -> dd.DataFrame: + data: "dd.DataFrame", embedding_col: str, index_col: str = "index", normalize: bool = True +) -> "dd.DataFrame": dim = len(data.head()[embedding_col].iloc[0]) def to_map(part, dim): diff --git a/crossfit/utils/dispatch_utils.py b/crossfit/utils/dispatch_utils.py index 1bff95ec..af0b80b8 100644 --- a/crossfit/utils/dispatch_utils.py +++ b/crossfit/utils/dispatch_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dask.utils import Dispatch +# Code recycled from Dask: https://github.com/dask/dask/blob/main/dask/utils.py +class Dispatch: + """Simple single dispatch.""" + + def __init__(self, name=None): + self._lookup = {} + self._lazy = {} + if name: + self.__name__ = name + + def register(self, type, func=None): + """Register dispatch of `func` on arguments of type `type`""" + + def wrapper(func): + if isinstance(type, tuple): + for t in type: + self.register(t, func) + else: + self._lookup[type] = func + return func + + return wrapper(func) if func is not None else wrapper + + def register_lazy(self, toplevel, func=None): + """ + Register a registration function which will be called if the + *toplevel* module (e.g. 'pandas') is ever loaded. + """ + + def wrapper(func): + self._lazy[toplevel] = func + return func + + return wrapper(func) if func is not None else wrapper + + def dispatch(self, cls): + """Return the function implementation for the given ``cls``""" + lk = self._lookup + for cls2 in cls.__mro__: + # Is a lazy registration function present? + toplevel, _, _ = cls2.__module__.partition(".") + try: + register = self._lazy[toplevel] + except KeyError: + pass + else: + register() + self._lazy.pop(toplevel, None) + return self.dispatch(cls) # recurse + try: + impl = lk[cls2] + except KeyError: + pass + else: + if cls is not cls2: + # Cache lookup + lk[cls] = impl + return impl + raise TypeError(f"No dispatch for {cls}") + + def __call__(self, arg, *args, **kwargs): + """ + Call the corresponding method based on type of argument. + """ + meth = self.dispatch(type(arg)) + return meth(arg, *args, **kwargs) + + @property + def __doc__(self): + try: + func = self.dispatch(object) + return func.__doc__ + except TypeError: + return "Single Dispatch for %s" % self.__name__ def supports(dispatch: Dispatch) -> set: diff --git a/requirements/base.txt b/requirements/base.txt index 1ba7d78a..93c96e87 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -3,8 +3,6 @@ pandas pyarrow numpy numba -dask[dataframe]>=2024.12.1 -distributed>=2024.12.1 scikit-learn>=1.2.0 fsspec>=2022.7.1 tensorflow_metadata diff --git a/requirements/cuda12x.txt b/requirements/cuda12x.txt index 24632b9d..7b2b2ad9 100644 --- a/requirements/cuda12x.txt +++ b/requirements/cuda12x.txt @@ -1,10 +1,7 @@ cudf-cu12>=24.4 -dask-cudf-cu12>=24.4 cuml-cu12>=24.4 pylibraft-cu12>=24.4 -raft-dask-cu12>=24.4 cuvs-cu12>=24.4 -dask-cuda>=24.6 torch>=2.0 transformers>=4.0 curated-transformers>=1.0 diff --git a/requirements/dask.txt b/requirements/dask.txt new file mode 100644 index 00000000..1bef038d --- /dev/null +++ b/requirements/dask.txt @@ -0,0 +1,2 @@ +dask[dataframe]>=2024.12.1 +distributed>=2024.12.1 diff --git a/requirements/dask_cuda12x.txt b/requirements/dask_cuda12x.txt new file mode 100644 index 00000000..318fa9df --- /dev/null +++ b/requirements/dask_cuda12x.txt @@ -0,0 +1,3 @@ +dask-cudf-cu12>=24.4 +raft-dask-cu12>=24.4 +dask-cuda>=24.6 diff --git a/setup.py b/setup.py index e9ab05e3..2b2d6bc4 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,11 +36,15 @@ def read_requirements(filename): return [line for line in lineiter if line and not line.startswith("#")] +_dask = read_requirements("requirements/dask.txt") +_dask_cuda12x = read_requirements("requirements/dask_cuda12x.txt") _dev = read_requirements("requirements/dev.txt") requirements = { - "base": read_requirements("requirements/base.txt"), - "cuda12x": read_requirements("requirements/cuda12x.txt"), + "base": read_requirements("requirements/base.txt") + _dask, + "base_no_dask": read_requirements("requirements/base.txt"), + "cuda12x": read_requirements("requirements/cuda12x.txt") + _dask_cuda12x, + "cuda12x_no_dask": read_requirements("requirements/cuda12x.txt"), "dev": _dev, "tensorflow": read_requirements("requirements/tensorflow.txt"), "pytorch": read_requirements("requirements/pytorch.txt"), From 21649e45e68d2ef5a37028f5fa3b3e1d070f9578 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Wed, 28 May 2025 12:15:31 -0700 Subject: [PATCH 2/8] add whitespace Signed-off-by: Sarah Yurick --- crossfit/utils/dispatch_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/crossfit/utils/dispatch_utils.py b/crossfit/utils/dispatch_utils.py index af0b80b8..06ac1211 100644 --- a/crossfit/utils/dispatch_utils.py +++ b/crossfit/utils/dispatch_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + # Code recycled from Dask: https://github.com/dask/dask/blob/main/dask/utils.py class Dispatch: """Simple single dispatch.""" From f9224a7dc1aa8db784ba311a57e7ff02bba85ff7 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Thu, 29 May 2025 13:46:36 -0700 Subject: [PATCH 3/8] Fix `convert_array` bug --- crossfit/data/dataframe/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crossfit/data/dataframe/core.py b/crossfit/data/dataframe/core.py index 341bffa0..407d7f72 100644 --- a/crossfit/data/dataframe/core.py +++ b/crossfit/data/dataframe/core.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -154,7 +154,7 @@ def cast(self, columns: type | dict | None = None, backend: type | bool = True): frame = frame.assign(**new_columns) else: try: - frame = CrossFrame(self.to_dict()).apply(cf.convert_array, columns) + frame = CrossFrame(self.to_dict()).apply(cf.data.array.conversion.convert_array, columns) except TypeError as err: raise TypeError( f"Unable to cast all column types to {columns}.\nOriginal error: {err}" From fa32f363c76e0c777c1034130dcff479aa5b29ea Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Thu, 29 May 2025 13:49:51 -0700 Subject: [PATCH 4/8] format --- crossfit/data/dataframe/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crossfit/data/dataframe/core.py b/crossfit/data/dataframe/core.py index 407d7f72..9e403bff 100644 --- a/crossfit/data/dataframe/core.py +++ b/crossfit/data/dataframe/core.py @@ -154,7 +154,9 @@ def cast(self, columns: type | dict | None = None, backend: type | bool = True): frame = frame.assign(**new_columns) else: try: - frame = CrossFrame(self.to_dict()).apply(cf.data.array.conversion.convert_array, columns) + frame = CrossFrame(self.to_dict()).apply( + cf.data.array.conversion.convert_array, columns + ) except TypeError as err: raise TypeError( f"Unable to cast all column types to {columns}.\nOriginal error: {err}" From 5234237786e46ffdacfdd8ed270727bf78b3ded0 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Mon, 30 Jun 2025 11:42:39 -0700 Subject: [PATCH 5/8] fix req files Signed-off-by: Sarah Yurick --- requirements/cuda12x.txt | 4 ---- requirements/dask_cuda12x.txt | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/requirements/cuda12x.txt b/requirements/cuda12x.txt index 774fd90c..2d4aeca3 100644 --- a/requirements/cuda12x.txt +++ b/requirements/cuda12x.txt @@ -1,7 +1,3 @@ -cudf-cu12>=24.4 -cuml-cu12>=24.4 -pylibraft-cu12>=24.4 -cuvs-cu12>=24.4 cudf-cu12>=25.6 cuml-cu12>=25.6 pylibraft-cu12>=25.6 diff --git a/requirements/dask_cuda12x.txt b/requirements/dask_cuda12x.txt index 318fa9df..c894b080 100644 --- a/requirements/dask_cuda12x.txt +++ b/requirements/dask_cuda12x.txt @@ -1,3 +1,3 @@ -dask-cudf-cu12>=24.4 -raft-dask-cu12>=24.4 -dask-cuda>=24.6 +dask-cudf-cu12>=25.6 +raft-dask-cu12>=25.6 +dask-cuda>=25.6 From 9c1ed15b3b90a0288c33ad1590f48dc604e4d4dd Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 3 Jul 2025 12:34:39 -0700 Subject: [PATCH 6/8] add DISABLE_DASK param, update setup.py Signed-off-by: Sarah Yurick --- crossfit/__init__.py | 5 +---- crossfit/backend/__init__.py | 13 ++++--------- crossfit/backend/dask/cluster.py | 6 +++--- crossfit/backend/gpu.py | 6 ++++-- crossfit/config.py | 20 ++++++++++++++++++++ crossfit/data/dataframe/dispatch.py | 5 ++--- crossfit/dataset/base.py | 13 ++++++++++--- crossfit/op/base.py | 29 ++++++++++++++++------------- crossfit/op/vector_search.py | 28 +++++++++++++++++++--------- setup.py | 7 +++---- 10 files changed, 82 insertions(+), 50 deletions(-) create mode 100644 crossfit/config.py diff --git a/crossfit/__init__.py b/crossfit/__init__.py index 5e563f08..b7de710d 100644 --- a/crossfit/__init__.py +++ b/crossfit/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2025 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,7 +15,6 @@ # flake8: noqa from crossfit import backend, metric, op -from crossfit.backend.dask.cluster import Distributed, Serial from crossfit.calculate.aggregate import Aggregator from crossfit.calculate.module import CrossModule from crossfit.data.array.conversion import convert_array @@ -67,12 +66,10 @@ def __call__(self, *args, **kwargs): "CrossFrame", "crossarray", "convert_array", - "Distributed", "FrameBackend", "op", "metric", "setup_dask_cluster", - "Serial", ] # Using the lazy import function diff --git a/crossfit/backend/__init__.py b/crossfit/backend/__init__.py index 9e6328b4..2be1d7a9 100644 --- a/crossfit/backend/__init__.py +++ b/crossfit/backend/__init__.py @@ -15,12 +15,6 @@ # flake8: noqa import logging -try: - from crossfit.backend.dask.dataframe import * -except ImportError: - logging.warning("Import error for Dask backend in CrossFit. Skipping it.") - pass - from crossfit.backend.numpy.sparse import * from crossfit.backend.pandas.array import * from crossfit.backend.pandas.dataframe import * @@ -29,20 +23,21 @@ from crossfit.backend.cudf.array import * from crossfit.backend.cudf.dataframe import * except ImportError: - logging.warning("Import error for cuDF backend in Crossfit. Skipping it.") + logging.warning("Import error for cuDF backend in CrossFit. Skipping it.") pass try: from crossfit.backend.cupy.array import * from crossfit.backend.cupy.sparse import * except ImportError: - logging.warning("Import error for CuPy backend in Crossfit. Skipping it.") + logging.warning("Import error for CuPy backend in CrossFit. Skipping it.") pass +# NOTE: Removing this block is useful for debugging. try: from crossfit.backend.torch.array import * except ImportError: - logging.warning("Import error for Torch backend in Crossfit. Skipping it.") + logging.warning("Import error for Torch backend in CrossFit. Skipping it.") pass # from crossfit.backend.tf.array import * diff --git a/crossfit/backend/dask/cluster.py b/crossfit/backend/dask/cluster.py index 2acef612..ebeaa32b 100644 --- a/crossfit/backend/dask/cluster.py +++ b/crossfit/backend/dask/cluster.py @@ -18,12 +18,12 @@ from contextvars import ContextVar from typing import Any, Callable, Optional -try: +import crossfit.config + +if not crossfit.config.DISABLE_DASK: import dask import distributed from dask.distributed import Client, get_client -except ImportError: - pass from crossfit.backend.gpu import HAS_GPU diff --git a/crossfit/backend/gpu.py b/crossfit/backend/gpu.py index 92de1c98..1f65512f 100644 --- a/crossfit/backend/gpu.py +++ b/crossfit/backend/gpu.py @@ -15,9 +15,11 @@ # pylint: disable=unused-import import os -try: +import crossfit.config + +if not crossfit.config.DISABLE_DASK: from dask.distributed.diagnostics.nvml import device_get_count -except ImportError: +else: import pynvml device_get_count = None diff --git a/crossfit/config.py b/crossfit/config.py new file mode 100644 index 00000000..764fb790 --- /dev/null +++ b/crossfit/config.py @@ -0,0 +1,20 @@ +# Copyright 2025 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +DISABLE_DASK = False + + +def set_disable_dask(value: bool): + global DISABLE_DASK + DISABLE_DASK = value diff --git a/crossfit/data/dataframe/dispatch.py b/crossfit/data/dataframe/dispatch.py index 7f41a25c..d7abd849 100644 --- a/crossfit/data/dataframe/dispatch.py +++ b/crossfit/data/dataframe/dispatch.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import crossfit.config from crossfit.utils.dispatch_utils import Dispatch @@ -24,12 +25,10 @@ def __call__(self, data, *args, **kwargs): backends = [] # TODO: Fix this - try: + if not crossfit.config.DISABLE_DASK: from crossfit.backend.dask.dataframe import DaskDataFrame backends.append(DaskDataFrame) - except ImportError: - pass from crossfit.backend.pandas.dataframe import PandasDataFrame diff --git a/crossfit/dataset/base.py b/crossfit/dataset/base.py index e9743f5f..5aa293fe 100644 --- a/crossfit/dataset/base.py +++ b/crossfit/dataset/base.py @@ -16,9 +16,16 @@ from dataclasses import dataclass from typing import Dict, Optional, Union -try: - from dask_cudf import read_parquet -except ImportError: +import crossfit.config + +if not crossfit.config.DISABLE_DASK: + # Still need a try/except here because crossfit.config needs to import this file + # before we can set DISABLE_DASK. + try: + from dask_cudf import read_parquet + except ImportError: + pass +else: from cudf import read_parquet _SPLIT_ALIASES = { diff --git a/crossfit/op/base.py b/crossfit/op/base.py index 7a16a898..98514561 100644 --- a/crossfit/op/base.py +++ b/crossfit/op/base.py @@ -14,18 +14,21 @@ import inspect -try: - import dask.dataframe as dd - from dask.distributed import get_worker as get_dask_worker - from dask.distributed import wait -except ImportError: - dd = None - get_dask_worker = None +import crossfit.config + +if not crossfit.config.DISABLE_DASK: + # Still need a try/except here because crossfit.config needs to import this file + # before we can set DISABLE_DASK. 
+ try: + import dask.dataframe as dd + from dask.distributed import get_worker as get_dask_worker + from dask.distributed import wait + from crossfit.backend.dask.cluster import global_dask_client + except ImportError: + pass from tqdm.auto import tqdm -from crossfit.backend.dask.cluster import global_dask_client - class Op: def __init__(self, pre=None, cols=False, keep_cols=None): @@ -47,9 +50,9 @@ def meta(self): return None def get_worker(self): - try: + if not crossfit.config.DISABLE_DASK: worker = get_dask_worker() if get_dask_worker is not None else self - except ValueError: + else: worker = self return worker @@ -57,7 +60,7 @@ def get_worker(self): def call_dask(self, data: "dd.DataFrame"): output = data.map_partitions(self, meta=self._build_dask_meta(data)) - if global_dask_client(): + if not crossfit.config.DISABLE_DASK and global_dask_client(): wait(output) return output @@ -85,7 +88,7 @@ def add_keep_cols(self, data, output): return output def __call__(self, data, *args, partition_info=None, **kwargs): - if dd is not None and isinstance(data, dd.DataFrame): + if not crossfit.config.DISABLE_DASK and isinstance(data, dd.DataFrame): output = self.call_dask(data, *args, **kwargs) self.teardown() return output diff --git a/crossfit/op/vector_search.py b/crossfit/op/vector_search.py index a9ba7fda..3f0e7787 100644 --- a/crossfit/op/vector_search.py +++ b/crossfit/op/vector_search.py @@ -20,16 +20,21 @@ import pylibraft from packaging.version import parse as parse_version -try: - import dask.dataframe as dd - from cuml.dask.neighbors import NearestNeighbors - from dask import delayed - from dask_cudf import from_delayed -except ImportError: - pass +import crossfit.config + +if not crossfit.config.DISABLE_DASK: + # Still need a try/except here because crossfit.config needs to import this file + # before we can set DISABLE_DASK. 
+ try: + import dask.dataframe as dd + from cuml.dask.neighbors import NearestNeighbors + from dask import delayed + from dask_cudf import from_delayed + from crossfit.backend.dask.cluster import global_dask_client + except ImportError: + pass from crossfit.backend.cudf.series import create_list_series_from_1d_or_2d_ar -from crossfit.backend.dask.cluster import global_dask_client from crossfit.dataset.base import EmbeddingDatataset from crossfit.op.base import Op @@ -209,10 +214,15 @@ def __init__( self.normalize = normalize def fit(self, items, **kwargs): + if not crossfit.config.DISABLE_DASK: + client = global_dask_client() + else: + client = None + knn = NearestNeighbors( n_neighbors=self.k, algorithm=self.algorithm, - client=global_dask_client(), + client=client, metric=self.metric, **kwargs, ) diff --git a/setup.py b/setup.py index 001bdb86..b0652b4a 100644 --- a/setup.py +++ b/setup.py @@ -41,10 +41,9 @@ def read_requirements(filename): _dev = read_requirements("requirements/dev.txt") requirements = { - "base": read_requirements("requirements/base.txt") + _dask, - "base_no_dask": read_requirements("requirements/base.txt"), + "base": read_requirements("requirements/base.txt"), + "base_with_dask": read_requirements("requirements/base.txt") + _dask, "cuda12x": read_requirements("requirements/cuda12x.txt") + _dask_cuda12x, - "cuda12x_no_dask": read_requirements("requirements/cuda12x.txt"), "dev": _dev, "tensorflow": read_requirements("requirements/tensorflow.txt"), "pytorch": read_requirements("requirements/pytorch.txt"), @@ -75,7 +74,7 @@ def read_requirements(filename): version=VERSION, packages=find_packages(), package_dir={"crossfit": "crossfit"}, - install_requires=requirements["base"], + install_requires=requirements["base_with_dask"], include_package_data=True, extras_require={ **requirements, From c496c7cf1e5ccb7413e83e22cab3cd030d82c781 Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 3 Jul 2025 12:38:33 -0700 Subject: [PATCH 7/8] run isort Signed-off-by: Sarah Yurick --- crossfit/op/base.py | 1 + crossfit/op/vector_search.py | 1 + 2 files changed, 2 insertions(+) diff --git a/crossfit/op/base.py b/crossfit/op/base.py index 98514561..e257d76b 100644 --- a/crossfit/op/base.py +++ b/crossfit/op/base.py @@ -23,6 +23,7 @@ import dask.dataframe as dd from dask.distributed import get_worker as get_dask_worker from dask.distributed import wait + from crossfit.backend.dask.cluster import global_dask_client except ImportError: pass diff --git a/crossfit/op/vector_search.py b/crossfit/op/vector_search.py index 3f0e7787..69ac3b59 100644 --- a/crossfit/op/vector_search.py +++ b/crossfit/op/vector_search.py @@ -30,6 +30,7 @@ from cuml.dask.neighbors import NearestNeighbors from dask import delayed from dask_cudf import from_delayed + from crossfit.backend.dask.cluster import global_dask_client except ImportError: pass From b9cc359810a5a1bf1ab0456ab543b65aefd36c3a Mon Sep 17 00:00:00 2001 From: Sarah Yurick Date: Thu, 3 Jul 2025 13:03:05 -0700 Subject: [PATCH 8/8] fix no worker error Signed-off-by: Sarah Yurick --- crossfit/op/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crossfit/op/base.py b/crossfit/op/base.py index e257d76b..0dc3afad 100644 --- a/crossfit/op/base.py +++ b/crossfit/op/base.py @@ -52,7 +52,10 @@ def meta(self): def get_worker(self): if not crossfit.config.DISABLE_DASK: - worker = get_dask_worker() if get_dask_worker is not None else self + try: + worker = get_dask_worker() + except ValueError: + worker = self else: worker = self
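
Two aspects of the series are worth illustrating. First, dask.utils.Dispatch is replaced by a vendored copy in crossfit/utils/dispatch_utils.py. Below is a minimal sketch of how that vendored dispatcher behaves, assuming the series is applied; the "describe" dispatcher and its handlers are illustrative names for this note, not part of CrossFit:

    import numpy as np

    from crossfit.utils.dispatch_utils import Dispatch

    describe = Dispatch(name="describe")

    @describe.register(np.ndarray)
    def _describe_numpy(arr):
        # Concrete registration: resolved by walking type(arg).__mro__.
        return f"numpy array with shape {arr.shape}"

    @describe.register_lazy("cupy")
    def _register_cupy():
        # Runs only the first time an object from the "cupy" top-level module
        # is dispatched, so GPU-only backends stay optional imports.
        import cupy as cp

        @describe.register(cp.ndarray)
        def _describe_cupy(arr):
            return f"cupy array with shape {arr.shape}"

    print(describe(np.zeros((2, 3))))  # -> "numpy array with shape (2, 3)"

Second, patch 6/8 adds crossfit/config.py with a module-level DISABLE_DASK flag and a set_disable_dask() setter. As the comments added in crossfit/op/base.py and crossfit/op/vector_search.py note, the flag is evaluated at import time, which is why those modules still keep try/except guards around their Dask imports. A sketch of flipping the flag, assuming it is called before the Dask-backed modules read DISABLE_DASK (illustrative only, not a snippet from the patches):

    from crossfit.config import set_disable_dask

    # Assumption: set early, before Dask-dependent modules evaluate DISABLE_DASK.
    set_disable_dask(True)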