Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions crossfit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,6 @@
# flake8: noqa

from crossfit import backend, metric, op
from crossfit.backend.dask.cluster import Distributed, Serial
from crossfit.calculate.aggregate import Aggregator
from crossfit.calculate.module import CrossModule
from crossfit.data.array.conversion import convert_array
Expand Down Expand Up @@ -67,12 +66,10 @@ def __call__(self, *args, **kwargs):
"CrossFrame",
"crossarray",
"convert_array",
"Distributed",
"FrameBackend",
"op",
"metric",
"setup_dask_cluster",
"Serial",
]

# Using the lazy import function
Expand Down
10 changes: 5 additions & 5 deletions crossfit/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,6 @@
# flake8: noqa
import logging

from crossfit.backend.dask.dataframe import *
from crossfit.backend.numpy.sparse import *
from crossfit.backend.pandas.array import *
from crossfit.backend.pandas.dataframe import *
Expand All @@ -24,20 +23,21 @@
from crossfit.backend.cudf.array import *
from crossfit.backend.cudf.dataframe import *
except ImportError:
logging.warning("Import Error for cudf backend in Crossfit. Skipping it.")
logging.warning("Import error for cuDF backend in CrossFit. Skipping it.")
pass

try:
from crossfit.backend.cupy.array import *
from crossfit.backend.cupy.sparse import *
except ImportError:
logging.warning("Import Error for cupy backend in Crossfit. Skipping it.")
logging.warning("Import error for CuPy backend in CrossFit. Skipping it.")
pass

# NOTE: Removing this block is useful for debugging.
try:
from crossfit.backend.torch.array import *
except ImportError:
logging.warning("Import Error for Torch backend in Crossfit. Skipping it.")
logging.warning("Import error for Torch backend in CrossFit. Skipping it.")
pass

# from crossfit.backend.tf.array import *
Expand Down
13 changes: 8 additions & 5 deletions crossfit/backend/dask/cluster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,9 +18,12 @@
from contextvars import ContextVar
from typing import Any, Callable, Optional

import dask
import distributed
from dask.distributed import Client, get_client
import crossfit.config

if not crossfit.config.DISABLE_DASK:
import dask
import distributed
from dask.distributed import Client, get_client

from crossfit.backend.gpu import HAS_GPU

Expand Down Expand Up @@ -403,7 +406,7 @@ def set_dask_client(client="auto", new_cluster=None, force_new=False, **cluster_
return None if active == "auto" else active


def global_dask_client() -> Optional[distributed.Client]:
def global_dask_client() -> Optional["distributed.Client"]:
"""Get Global Dask client if it's been set.

Returns
Expand Down
24 changes: 18 additions & 6 deletions crossfit/backend/gpu.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=unused-import
import os

from dask.distributed.diagnostics import nvml
import crossfit.config

if not crossfit.config.DISABLE_DASK:
from dask.distributed.diagnostics.nvml import device_get_count
else:
import pynvml

device_get_count = None


def _get_gpu_count():
Expand All @@ -30,7 +35,14 @@ def _get_gpu_count():
# that are incompatible with Dask-CUDA. If CUDA runtime functions are
# called before Dask-CUDA can spawn worker processes
# then Dask-CUDA will not work correctly (raises an exception)
nvml_device_count = nvml.device_get_count()
if device_get_count is not None:
nvml_device_count = device_get_count()
else:
try:
nvml_device_count = pynvml.nvmlDeviceGetCount()
except Exception:
nvml_device_count = 0
Comment on lines -33 to +44
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would love an opinion here. This code works as is, but I am not sure if the logic is sound enough to cover all cases here.

I was wondering whether I should copy over the relevant functions from https://github.com/dask/distributed/blob/main/distributed/diagnostics/nvml.py?


if nvml_device_count == 0:
return 0
try:
Expand Down
20 changes: 20 additions & 0 deletions crossfit/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

DISABLE_DASK = False


def set_disable_dask(value: bool) -> None:
    """Set the module-level ``DISABLE_DASK`` flag.

    Parameters
    ----------
    value : bool
        ``True`` to mark Dask as disabled, ``False`` to mark it as enabled.

    Notes
    -----
    Modules that check the flag while they are being imported are not
    re-evaluated; only checks performed after this call observe the new
    value.
    """
    # ``global`` is required so the assignment rebinds the module attribute
    # instead of creating a function-local variable.
    global DISABLE_DASK
    DISABLE_DASK = value
4 changes: 2 additions & 2 deletions crossfit/data/array/conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,9 @@
from typing import Any, Type, TypeVar

import numpy as np
from dask.utils import Dispatch

from crossfit.utils import dispatch_utils
from crossfit.utils.dispatch_utils import Dispatch

InputType = TypeVar("InputType")
IntermediateType = TypeVar("IntermediateType")
Expand Down
4 changes: 2 additions & 2 deletions crossfit/data/array/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,9 +17,9 @@
from typing import TypeVar

import numpy as np
from dask.utils import Dispatch

from crossfit.utils import dispatch_utils
from crossfit.utils.dispatch_utils import Dispatch


class NPBackendDispatch(Dispatch):
Expand Down
6 changes: 4 additions & 2 deletions crossfit/data/dataframe/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,7 +154,9 @@ def cast(self, columns: type | dict | None = None, backend: type | bool = True):
frame = frame.assign(**new_columns)
else:
try:
frame = CrossFrame(self.to_dict()).apply(cf.convert_array, columns)
frame = CrossFrame(self.to_dict()).apply(
cf.data.array.conversion.convert_array, columns
)
except TypeError as err:
raise TypeError(
f"Unable to cast all column types to {columns}.\nOriginal error: {err}"
Expand Down
14 changes: 10 additions & 4 deletions crossfit/data/dataframe/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dask.utils import Dispatch
import crossfit.config
from crossfit.utils.dispatch_utils import Dispatch


class _CrossFrameDispatch(Dispatch):
Expand All @@ -22,11 +23,16 @@ def __call__(self, data, *args, **kwargs):
if isinstance(data, FrameBackend):
return data

backends = []
# TODO: Fix this
from crossfit.backend.dask.dataframe import DaskDataFrame
if not crossfit.config.DISABLE_DASK:
from crossfit.backend.dask.dataframe import DaskDataFrame

backends.append(DaskDataFrame)

from crossfit.backend.pandas.dataframe import PandasDataFrame

backends = [PandasDataFrame, DaskDataFrame]
backends.append(PandasDataFrame)

try:
from crossfit.backend.cudf.dataframe import CudfDataFrame
Expand Down
5 changes: 2 additions & 3 deletions crossfit/data/sparse/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dask.utils import Dispatch

from crossfit.data.sparse.core import SparseMatrixProtocol
from crossfit.utils.dispatch_utils import Dispatch


class _CrossSparseDispatch(Dispatch, SparseMatrixProtocol):
Expand Down
16 changes: 13 additions & 3 deletions crossfit/dataset/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,17 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union

import dask_cudf
import crossfit.config

if not crossfit.config.DISABLE_DASK:
# Still need a try/except here because this file is imported (as part of
# importing the crossfit package) before DISABLE_DASK can be set.
try:
from dask_cudf import read_parquet
except ImportError:
pass
else:
from cudf import read_parquet

_SPLIT_ALIASES = {
"val": ["validation", "valid", "dev"],
Expand All @@ -36,7 +46,7 @@ class Dataset:
engine: str = "parquet"

def ddf(self):
return dask_cudf.read_parquet(self.path)
return read_parquet(self.path)


class FromDirMixin:
Expand Down
35 changes: 24 additions & 11 deletions crossfit/op/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2025 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,11 +14,21 @@

import inspect

import dask.dataframe as dd
from dask.distributed import get_worker, wait
from tqdm.auto import tqdm
import crossfit.config

if not crossfit.config.DISABLE_DASK:
# Still need a try/except here because this file is imported (as part of
# importing the crossfit package) before DISABLE_DASK can be set.
try:
import dask.dataframe as dd
from dask.distributed import get_worker as get_dask_worker
from dask.distributed import wait

from crossfit.backend.dask.cluster import global_dask_client
except ImportError:
pass

from crossfit.backend.dask.cluster import global_dask_client
from tqdm.auto import tqdm


class Op:
Expand All @@ -41,17 +51,20 @@ def meta(self):
return None

def get_worker(self):
try:
worker = get_worker()
except ValueError:
if not crossfit.config.DISABLE_DASK:
try:
worker = get_dask_worker()
except ValueError:
worker = self
else:
worker = self

return worker

def call_dask(self, data: dd.DataFrame):
def call_dask(self, data: "dd.DataFrame"):
output = data.map_partitions(self, meta=self._build_dask_meta(data))

if global_dask_client():
if not crossfit.config.DISABLE_DASK and global_dask_client():
wait(output)

return output
Expand Down Expand Up @@ -79,7 +92,7 @@ def add_keep_cols(self, data, output):
return output

def __call__(self, data, *args, partition_info=None, **kwargs):
if isinstance(data, dd.DataFrame):
if not crossfit.config.DISABLE_DASK and isinstance(data, dd.DataFrame):
output = self.call_dask(data, *args, **kwargs)
self.teardown()
return output
Expand Down
Loading
Loading