25 changes: 17 additions & 8 deletions cuda_core/cuda/core/_memory/_legacy.py
@@ -46,8 +46,11 @@ def allocate(self, size, stream=None) -> Buffer:
             from cuda.core._stream import default_stream
 
             stream = default_stream()
-        err, ptr = driver.cuMemAllocHost(size)
-        raise_if_driver_error(err)
+        if size:
+            err, ptr = driver.cuMemAllocHost(size)
+            raise_if_driver_error(err)
+        else:
+            ptr = 0
         return Buffer._init(ptr, size, self, stream)
 
     def deallocate(self, ptr: DevicePointerT, size, stream):
@@ -64,8 +67,10 @@ def deallocate(self, ptr: DevicePointerT, size, stream):
         """
         if stream is not None:
            stream.sync()
-        (err,) = driver.cuMemFreeHost(ptr)
-        raise_if_driver_error(err)
+
+        if size:
+            (err,) = driver.cuMemFreeHost(ptr)
+            raise_if_driver_error(err)
 
     @property
     def is_device_accessible(self) -> bool:
@@ -96,15 +101,19 @@ def allocate(self, size, stream=None) -> Buffer:
             from cuda.core._stream import default_stream
 
             stream = default_stream()
-        err, ptr = driver.cuMemAlloc(size)
-        raise_if_driver_error(err)
+        if size:
+            err, ptr = driver.cuMemAlloc(size)
+            raise_if_driver_error(err)
+        else:
+            ptr = 0
         return Buffer._init(ptr, size, self, stream)
 
     def deallocate(self, ptr, size, stream):
         if stream is not None:
             stream.sync()
-        (err,) = driver.cuMemFree(ptr)
-        raise_if_driver_error(err)
+        if size:
+            (err,) = driver.cuMemFree(ptr)
+            raise_if_driver_error(err)
 
     @property
     def is_device_accessible(self) -> bool:
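The guard above exists because the driver allocators reject zero-byte requests (`cuMemAlloc`, for instance, returns `CUDA_ERROR_INVALID_VALUE` when the byte size is 0), so a zero-size allocation short-circuits to a NULL pointer and the matching free is skipped. A minimal standalone sketch of the same guard pattern, calling `cuda.bindings.driver` directly — a current CUDA context is assumed and error handling is simplified:

```python
from cuda.bindings import driver


def alloc_device(size: int) -> int:
    """Return a device pointer, treating size 0 as a NULL allocation."""
    if size == 0:
        return 0  # skip the driver call: cuMemAlloc(0) would fail
    err, ptr = driver.cuMemAlloc(size)
    if err != driver.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuMemAlloc failed: {err}")
    return int(ptr)


def free_device(ptr: int, size: int) -> None:
    """Release a pointer from alloc_device; a size-0 buffer owns nothing."""
    if size == 0:
        return  # nothing was allocated, so there is nothing to free
    (err,) = driver.cuMemFree(ptr)
    if err != driver.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuMemFree failed: {err}")
```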
10 changes: 10 additions & 0 deletions cuda_core/tests/test_device.py
@@ -47,6 +47,16 @@ def test_device_alloc(deinit_cuda):
     assert buffer.device_id == int(device)
 
 
+def test_device_alloc_zero_bytes(deinit_cuda):
+    device = Device()
+    device.set_current()
+    buffer = device.allocate(0)
+    device.sync()
+    assert buffer.handle >= 0
+    assert buffer.size == 0
+    assert buffer.device_id == int(device)
+
+
 def test_device_id(deinit_cuda):
     for device in Device.get_all_devices():
         device.set_current()
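For reference, the behavior the new test pins down, as a short usage sketch (the `cuda.core` import path is assumed; older releases expose `Device` from `cuda.core.experimental`):

```python
from cuda.core import Device  # import path assumed; may be cuda.core.experimental

device = Device()
device.set_current()

buf = device.allocate(0)  # a zero-byte request now succeeds
device.sync()
assert buf.handle >= 0 and buf.size == 0
buf.close()  # freeing the zero-size buffer is safe as well
```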
17 changes: 17 additions & 0 deletions cuda_core/tests/test_graph_mem.py
@@ -182,6 +182,23 @@ def test_graph_alloc_with_output(mempool_device, mode):
     assert compare_buffer_to_constant(out, 6)
 
 
+@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
+def test_graph_mem_alloc_zero(mempool_device, mode):
+    device = mempool_device
+    gb = device.create_graph_builder().begin_building(mode)
+    stream = device.create_stream()
+    gmr = GraphMemoryResource(device)
+    buffer = gmr.allocate(0, stream=gb)
+    graph = gb.end_building().complete()
+    graph.upload(stream)
+    graph.launch(stream)
+    stream.sync()
+
+    assert buffer.handle >= 0
+    assert buffer.size == 0
+    assert buffer.device_id == int(device)
+
+
 @pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
 def test_graph_mem_set_attributes(mempool_device, mode):
     device = mempool_device
25 changes: 25 additions & 0 deletions cuda_core/tests/test_memory.py
@@ -1240,3 +1240,28 @@ def test_graph_memory_resource_object(init_cuda):
     # These objects are interned.
     assert gmr1 is gmr2 is gmr3
     assert gmr1 == gmr2 == gmr3
+
+
+def test_memory_resource_alloc_zero_bytes(init_cuda, memory_resource_factory):
+    MR, MROps = memory_resource_factory
+
+    device = Device()
+    device.set_current()
+
+    if MR is DeviceMemoryResource and not device.properties.memory_pools_supported:
+        pytest.skip("Device does not support mempool operations")
+    elif MR is PinnedMemoryResource:
+        skip_if_pinned_memory_unsupported(device)
+        mr = MR()
+    elif MR is ManagedMemoryResource:
+        skip_if_managed_memory_unsupported(device)
+        mr = create_managed_memory_resource_or_skip(MROps(preferred_location=device.device_id))
+    else:
+        assert MR is DeviceMemoryResource
+        mr = MR(device)
+
+    buffer = mr.allocate(0)
+    device.sync()
+    assert buffer.handle >= 0
+    assert buffer.size == 0
+    assert buffer.device_id == mr.device_id