Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 85 additions & 9 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2619,7 +2619,32 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
scale=scale,
)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
    ("kernel", "input_type"),
    [
        (F.to_dtype_image, torch.Tensor),
        (F.to_dtype_video, tv_tensors.Video),
        pytest.param(
            F._misc._to_dtype_image_cvcuda,
            None,
            marks=pytest.mark.needs_cvcuda,
        ),
    ],
)
def test_functional_signature(self, kernel, input_type):
    """Check that each ``to_dtype`` kernel has a signature matching the ``F.to_dtype`` dispatcher.

    The CV-CUDA case is parametrized with ``None`` as its input type because
    the ``cvcuda.Tensor`` class can only be looked up once cvcuda is importable,
    i.e. inside the test body rather than at collection time.
    """
    is_cvcuda_kernel = kernel is F._misc._to_dtype_image_cvcuda
    # Resolve the placeholder lazily; every other case uses the parametrized type as-is.
    resolved_type = _import_cvcuda().Tensor if is_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(F.to_dtype, kernel=kernel, input_type=resolved_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this test!


@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -2634,7 +2659,14 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
[
make_image_tensor,
make_image,
make_bounding_boxes,
make_segmentation_mask,
make_video,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
Expand Down Expand Up @@ -2680,25 +2712,69 @@ def fn(value):

return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

def _get_dtype_conversion_atol_cvcuda(self, input_dtype, output_dtype):
in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
narrows_bits = in_bits is not None and out_bits is not None and out_bits < in_bits

# int->int with narrowing bits, allow atol=1 for rounding diffs
if narrows_bits:
atol = 1
# float->int check for same diff, rounding error on float
elif input_dtype.is_floating_point and not output_dtype.is_floating_point:
atol = 1
# if generating a float value from an int, allow small rounding error
elif not input_dtype.is_floating_point and output_dtype.is_floating_point:
atol = 1e-7
# all other cases, should be exact
# uint8 -> uint16 promotion would be here
else:
atol = 0

return atol

@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(make_image_cvcuda, marks=pytest.mark.needs_cvcuda),
],
)
@pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)])
def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")
if (
input_dtype == torch.uint16
and output_dtype == torch.uint8
and not scale
and make_input is make_image_cvcuda
):
pytest.xfail("uint16 to uint8 conversion without scale is not supported for CV-CUDA.")

input = make_image(dtype=input_dtype, device=device)
input = make_input(dtype=input_dtype, device=device)
out = fn(input, dtype=output_dtype, scale=scale)

if make_input is make_image_cvcuda:
input = F.cvcuda_to_tensor(input)
out = F.cvcuda_to_tensor(out)

out = F.to_dtype(input, dtype=output_dtype, scale=scale)
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)

if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
torch.testing.assert_close(out, expected, atol=1, rtol=0)
else:
torch.testing.assert_close(out, expected)
atol, rtol = None, None
if make_input is make_image_cvcuda:
atol = self._get_dtype_conversion_atol_cvcuda(input_dtype, output_dtype)
rtol = 0
elif input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
atol, rtol = 1, 0

torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)

def was_scaled(self, inpt):
# this assumes the target dtype is float
Expand Down
13 changes: 10 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor

from ._utils import (
_parse_labels_getter,
Expand Down Expand Up @@ -267,7 +268,7 @@ class ToDtype(Transform):
Default: ``False``.
"""

_transformed_types = (torch.Tensor,)
_transformed_types = (torch.Tensor, _is_cvcuda_tensor)

def __init__(
self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False
Expand All @@ -294,7 +295,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
if (
not is_pure_tensor(inpt)
and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
and not _is_cvcuda_tensor(inpt)
):
return inpt

dtype: Optional[torch.dtype] = self.dtype
Expand All @@ -311,7 +316,9 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)

supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
supports_scaling = (
is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or _is_cvcuda_tensor(inpt)
)
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
Expand Down
104 changes: 102 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, TYPE_CHECKING

import PIL.Image
import torch
Expand All @@ -13,7 +13,12 @@

from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

# Probed once at import time; used further down in this module to decide whether the
# CV-CUDA kernel gets registered.
CVCUDA_AVAILABLE = _is_cvcuda_available()

# Only static type checkers need the real module, to resolve the quoted "cvcuda.*"
# annotations in this file. At runtime cvcuda is obtained through _import_cvcuda().
if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


def normalize(
Expand Down Expand Up @@ -340,6 +345,101 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo
return inpt.to(dtype)


# Lookup tables between torch dtypes and CV-CUDA types (one per direction).
# cvcuda is only used if it is installed, so the tables start out empty and this module
# stays importable without it; `_to_dtype_image_cvcuda` fills them in on its first call.
_torch_to_cvcuda_dtypes: dict[torch.dtype, "cvcuda.Type"] = {}
_cvcuda_to_torch_dtypes: dict["cvcuda.Type", torch.dtype] = {}


def _to_dtype_image_cvcuda(
    inpt: "cvcuda.Tensor",
    dtype: torch.dtype = torch.float,
    scale: bool = False,
) -> "cvcuda.Tensor":
    """
    Convert the dtype of a CV-CUDA tensor, based on a torch.dtype.

    Args:
        inpt: The CV-CUDA tensor to convert the dtype of.
        dtype: The torch.dtype to convert to. It is translated to the matching
            ``cvcuda.Type`` (U8/U16/U32/U64, S8/S16/S32/S64, F32, F64).
        scale: Whether to rescale the values to the value range of the new dtype.
            If ``False``, values are converted with a factor of 1.0.
            If ``True``, the multiplicative factor depends on the kind of conversion:

            1. float -> float: 1.0 (no rescaling).
            2. int -> int: ``2 ** (out_value_bits - in_value_bits)``, where the value
               bits exclude the sign bit (e.g. 8 for uint8, 15 for int16). This is the
               multiplicative equivalent of the bit shift torchvision applies to tensors.
            3. float -> int: ``max_value(dtype) + 1 - eps`` with ``eps = 1e-3``.
            4. int -> float: ``1 / max_value(input dtype)``.

            The additive offset is 0.0 in every case.

    Returns:
        out (cvcuda.Tensor): The CV-CUDA tensor with the converted dtype.

    Raises:
        ValueError: If either the input's CV-CUDA type or ``dtype`` has no entry in
            the dtype mapping, or for uint16 -> uint8 without ``scale`` (see below).
    """
    cvcuda = _import_cvcuda()

    # The mapping tables are module-level and empty until the first call, because
    # cvcuda.Type can only be referenced once cvcuda has actually been imported.
    if not _torch_to_cvcuda_dtypes:
        _torch_to_cvcuda_dtypes.update(
            {
                torch.uint8: cvcuda.Type.U8,
                torch.uint16: cvcuda.Type.U16,
                torch.uint32: cvcuda.Type.U32,
                torch.uint64: cvcuda.Type.U64,
                torch.int8: cvcuda.Type.S8,
                torch.int16: cvcuda.Type.S16,
                torch.int32: cvcuda.Type.S32,
                torch.int64: cvcuda.Type.S64,
                torch.float32: cvcuda.Type.F32,
                torch.float64: cvcuda.Type.F64,
            }
        )

    if not _cvcuda_to_torch_dtypes:
        _cvcuda_to_torch_dtypes.update(
            {cvc_type: torch_type for torch_type, cvc_type in _torch_to_cvcuda_dtypes.items()}
        )

    dtype_in = _cvcuda_to_torch_dtypes.get(inpt.dtype)
    cvc_dtype = _torch_to_cvcuda_dtypes.get(dtype)
    if dtype_in is None or cvc_dtype is None:
        raise ValueError(f"No torch or cvcuda dtype found for dtype {dtype} or {inpt.dtype}")

    # torchvision will overflow the values of uint16 when converting down to uint8 without scale
    # example: 300 -> 255 (cvcuda) vs 300 mod 256 = 44 (torchvision)
    # since it is not equivalent, raise an error for unsupported behavior
    # the workaround could be using torch for dtype conversion directly via zero-copy
    if dtype_in == torch.uint16 and dtype == torch.uint8 and not scale:
        raise ValueError("uint16 to uint8 conversion without scale is not supported for CV-CUDA.")

    # Without scaling (and for float -> float) the values are converted as-is.
    scale_val, offset = 1.0, 0.0
    if scale:
        in_dtype_float = dtype_in.is_floating_point
        out_dtype_float = dtype.is_floating_point

        if not in_dtype_float and not out_dtype_float:
            # Use the number of *value* bits (sign bit excluded), not the storage width:
            # uint8 -> int16 must scale by 2 ** (15 - 8) = 128, not 2 ** (16 - 8) = 256,
            # which would push most of the input range past the int16 maximum.
            in_bits = torch.iinfo(dtype_in).max.bit_length()
            out_bits = torch.iinfo(dtype).max.bit_length()
            scale_val = 2.0 ** (out_bits - in_bits)
        elif in_dtype_float and not out_dtype_float:
            # Mirror the scaling factor which torchvision uses
            eps = 1e-3
            scale_val = float(_max_value(dtype)) + 1.0 - eps
        elif not in_dtype_float and out_dtype_float:
            scale_val = 1.0 / float(_max_value(dtype_in))

    return cvcuda.convertto(
        inpt,
        dtype=cvc_dtype,
        scale=scale_val,
        offset=offset,
    )


# Register the CV-CUDA kernel for `to_dtype` only when cvcuda is available: the
# registration is keyed on the cvcuda.Tensor class, which requires importing cvcuda.
if CVCUDA_AVAILABLE:
_register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_image_cvcuda)


def sanitize_bounding_boxes(
bounding_boxes: torch.Tensor,
format: Optional[tv_tensors.BoundingBoxFormat] = None,
Expand Down