diff --git a/cuda_core/cuda/core/_memory/_legacy.py b/cuda_core/cuda/core/_memory/_legacy.py index 9250819610..62b9072a2d 100644 --- a/cuda_core/cuda/core/_memory/_legacy.py +++ b/cuda_core/cuda/core/_memory/_legacy.py @@ -46,8 +46,11 @@ def allocate(self, size, stream=None) -> Buffer: from cuda.core._stream import default_stream stream = default_stream() - err, ptr = driver.cuMemAllocHost(size) - raise_if_driver_error(err) + if size: + err, ptr = driver.cuMemAllocHost(size) + raise_if_driver_error(err) + else: + ptr = 0 return Buffer._init(ptr, size, self, stream) def deallocate(self, ptr: DevicePointerT, size, stream): @@ -64,8 +67,10 @@ def deallocate(self, ptr: DevicePointerT, size, stream): """ if stream is not None: stream.sync() - (err,) = driver.cuMemFreeHost(ptr) - raise_if_driver_error(err) + + if size: + (err,) = driver.cuMemFreeHost(ptr) + raise_if_driver_error(err) @property def is_device_accessible(self) -> bool: @@ -96,15 +101,19 @@ def allocate(self, size, stream=None) -> Buffer: from cuda.core._stream import default_stream stream = default_stream() - err, ptr = driver.cuMemAlloc(size) - raise_if_driver_error(err) + if size: + err, ptr = driver.cuMemAlloc(size) + raise_if_driver_error(err) + else: + ptr = 0 return Buffer._init(ptr, size, self, stream) def deallocate(self, ptr, size, stream): if stream is not None: stream.sync() - (err,) = driver.cuMemFree(ptr) - raise_if_driver_error(err) + if size: + (err,) = driver.cuMemFree(ptr) + raise_if_driver_error(err) @property def is_device_accessible(self) -> bool: diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index e4365ac0c9..8726c65a30 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -47,6 +47,16 @@ def test_device_alloc(deinit_cuda): assert buffer.device_id == int(device) +def test_device_alloc_zero_bytes(deinit_cuda): + device = Device() + device.set_current() + buffer = device.allocate(0) + device.sync() + assert buffer.handle >= 0 + assert buffer.size == 0 + assert buffer.device_id == int(device) + + def test_device_id(deinit_cuda): for device in Device.get_all_devices(): device.set_current() diff --git a/cuda_core/tests/test_graph_mem.py b/cuda_core/tests/test_graph_mem.py index 5159fd2b2b..bcb8a800a1 100644 --- a/cuda_core/tests/test_graph_mem.py +++ b/cuda_core/tests/test_graph_mem.py @@ -182,6 +182,23 @@ def test_graph_alloc_with_output(mempool_device, mode): assert compare_buffer_to_constant(out, 6) +@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"]) +def test_graph_mem_alloc_zero(mempool_device, mode): + device = mempool_device + gb = device.create_graph_builder().begin_building(mode) + stream = device.create_stream() + gmr = GraphMemoryResource(device) + buffer = gmr.allocate(0, stream=gb) + graph = gb.end_building().complete() + graph.upload(stream) + graph.launch(stream) + stream.sync() + + assert buffer.handle >= 0 + assert buffer.size == 0 + assert buffer.device_id == int(device) + + @pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"]) def test_graph_mem_set_attributes(mempool_device, mode): device = mempool_device diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 71adb4ffc7..8851a4600a 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -1240,3 +1240,28 @@ def test_graph_memory_resource_object(init_cuda): # These objects are interned. assert gmr1 is gmr2 is gmr3 assert gmr1 == gmr2 == gmr3 + + +def test_memory_resource_alloc_zero_bytes(init_cuda, memory_resource_factory): + MR, MROps = memory_resource_factory + + device = Device() + device.set_current() + + if MR is DeviceMemoryResource and not device.properties.memory_pools_supported: + pytest.skip("Device does not support mempool operations") + elif MR is PinnedMemoryResource: + skip_if_pinned_memory_unsupported(device) + mr = MR() + elif MR is ManagedMemoryResource: + skip_if_managed_memory_unsupported(device) + mr = create_managed_memory_resource_or_skip(MROps(preferred_location=device.device_id)) + else: + assert MR is DeviceMemoryResource + mr = MR(device) + + buffer = mr.allocate(0) + device.sync() + assert buffer.handle >= 0 + assert buffer.size == 0 + assert buffer.device_id == mr.device_id