diff --git a/cuda_core/cuda/core/_cpp/resource_handles.cpp b/cuda_core/cuda/core/_cpp/resource_handles.cpp index 724ea97169..c4bb47261a 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.cpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.cpp @@ -51,6 +51,11 @@ decltype(&cuMemFreeHost) p_cuMemFreeHost = nullptr; decltype(&cuMemPoolImportPointer) p_cuMemPoolImportPointer = nullptr; +decltype(&cuLibraryLoadFromFile) p_cuLibraryLoadFromFile = nullptr; +decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr; +decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr; +decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; + // ============================================================================ // GIL management helpers // ============================================================================ @@ -682,4 +687,81 @@ DevicePtrHandle deviceptr_import_ipc(MemoryPoolHandle h_pool, const void* export } } +// ============================================================================ +// Library Handles +// ============================================================================ + +namespace { +struct LibraryBox { + CUlibrary resource; +}; +} // namespace + +LibraryHandle create_library_handle_from_file(const char* path) { + GILReleaseGuard gil; + CUlibrary library; + if (CUDA_SUCCESS != (err = p_cuLibraryLoadFromFile(&library, path, nullptr, nullptr, 0, nullptr, nullptr, 0))) { + return {}; + } + + auto box = std::shared_ptr( + new LibraryBox{library}, + [](const LibraryBox* b) { + GILReleaseGuard gil; + p_cuLibraryUnload(b->resource); + delete b; + } + ); + return LibraryHandle(box, &box->resource); +} + +LibraryHandle create_library_handle_from_data(const void* data) { + GILReleaseGuard gil; + CUlibrary library; + if (CUDA_SUCCESS != (err = p_cuLibraryLoadData(&library, data, nullptr, nullptr, 0, nullptr, nullptr, 0))) { + return {}; + } + + auto box = std::shared_ptr( + new LibraryBox{library}, + [](const LibraryBox* b) { + GILReleaseGuard 
gil; + p_cuLibraryUnload(b->resource); + delete b; + } + ); + return LibraryHandle(box, &box->resource); +} + +LibraryHandle create_library_handle_ref(CUlibrary library) { + auto box = std::make_shared(LibraryBox{library}); + return LibraryHandle(box, &box->resource); +} + +// ============================================================================ +// Kernel Handles +// ============================================================================ + +namespace { +struct KernelBox { + CUkernel resource; + LibraryHandle h_library; // Keeps library alive +}; +} // namespace + +KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name) { + GILReleaseGuard gil; + CUkernel kernel; + if (CUDA_SUCCESS != (err = p_cuLibraryGetKernel(&kernel, *h_library, name))) { + return {}; + } + + return create_kernel_handle_ref(kernel, h_library); +} + +KernelHandle create_kernel_handle_ref(CUkernel kernel, LibraryHandle h_library) { + auto box = std::make_shared(KernelBox{kernel, h_library}); + return KernelHandle(box, &box->resource); +} + } // namespace cuda_core diff --git a/cuda_core/cuda/core/_cpp/resource_handles.hpp b/cuda_core/cuda/core/_cpp/resource_handles.hpp index 4a6d9bb241..1df181ee56 100644 --- a/cuda_core/cuda/core/_cpp/resource_handles.hpp +++ b/cuda_core/cuda/core/_cpp/resource_handles.hpp @@ -61,6 +61,12 @@ extern decltype(&cuMemFreeHost) p_cuMemFreeHost; extern decltype(&cuMemPoolImportPointer) p_cuMemPoolImportPointer; +// Library +extern decltype(&cuLibraryLoadFromFile) p_cuLibraryLoadFromFile; +extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData; +extern decltype(&cuLibraryUnload) p_cuLibraryUnload; +extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel; + // ============================================================================ // Handle type aliases - expose only the raw CUDA resource // ============================================================================ @@ -69,6 +75,8 @@ using ContextHandle = std::shared_ptr; using 
StreamHandle = std::shared_ptr; using EventHandle = std::shared_ptr; using MemoryPoolHandle = std::shared_ptr; +using LibraryHandle = std::shared_ptr; +using KernelHandle = std::shared_ptr; // ============================================================================ // Context handle functions @@ -218,6 +226,40 @@ StreamHandle deallocation_stream(const DevicePtrHandle& h) noexcept; // Set the deallocation stream for a device pointer handle. void set_deallocation_stream(const DevicePtrHandle& h, StreamHandle h_stream) noexcept; +// ============================================================================ +// Library handle functions +// ============================================================================ + +// Create an owning library handle by loading from a file path. +// When the last reference is released, cuLibraryUnload is called automatically. +// Returns empty handle on error (caller must check). +LibraryHandle create_library_handle_from_file(const char* path); + +// Create an owning library handle by loading from memory data. +// The driver makes an internal copy of the data; caller can free it after return. +// When the last reference is released, cuLibraryUnload is called automatically. +// Returns empty handle on error (caller must check). +LibraryHandle create_library_handle_from_data(const void* data); + +// Create a non-owning library handle (references existing library). +// Use for borrowed libraries (e.g., from foreign code). +// The library will NOT be unloaded when the handle is released. +LibraryHandle create_library_handle_ref(CUlibrary library); + +// ============================================================================ +// Kernel handle functions +// ============================================================================ + +// Get a kernel from a library by name. +// The kernel structurally depends on the provided library handle. +// Kernels have no explicit destroy - their lifetime is tied to the library. 
+// Returns empty handle on error (caller must check). +KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name); + +// Create a non-owning kernel handle with library dependency. +// Use for borrowed kernels. The library handle keeps the library alive. +KernelHandle create_kernel_handle_ref(CUkernel kernel, LibraryHandle h_library); + // ============================================================================ // Overloaded helper functions to extract raw resources from handles // ============================================================================ @@ -243,6 +285,14 @@ inline CUdeviceptr as_cu(const DevicePtrHandle& h) noexcept { return h ? *h : 0; } +inline CUlibrary as_cu(const LibraryHandle& h) noexcept { + return h ? *h : nullptr; +} + +inline CUkernel as_cu(const KernelHandle& h) noexcept { + return h ? *h : nullptr; +} + // as_intptr() - extract handle as intptr_t for Python interop // Using signed intptr_t per C standard convention and issue #1342 inline std::intptr_t as_intptr(const ContextHandle& h) noexcept { @@ -265,6 +315,14 @@ inline std::intptr_t as_intptr(const DevicePtrHandle& h) noexcept { return static_cast(as_cu(h)); } +inline std::intptr_t as_intptr(const LibraryHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + +inline std::intptr_t as_intptr(const KernelHandle& h) noexcept { + return reinterpret_cast(as_cu(h)); +} + // as_py() - convert handle to Python driver wrapper object (returns new reference) namespace detail { // n.b. 
class lookup is not cached to avoid deadlock hazard, see DESIGN.md @@ -300,4 +358,12 @@ inline PyObject* as_py(const DevicePtrHandle& h) noexcept { return detail::make_py("CUdeviceptr", as_intptr(h)); } +inline PyObject* as_py(const LibraryHandle& h) noexcept { + return detail::make_py("CUlibrary", as_intptr(h)); +} + +inline PyObject* as_py(const KernelHandle& h) noexcept { + return detail::make_py("CUkernel", as_intptr(h)); +} + } // namespace cuda_core diff --git a/cuda_core/cuda/core/_launcher.pyx b/cuda_core/cuda/core/_launcher.pyx index 9559f7697a..48eb2038b2 100644 --- a/cuda_core/cuda/core/_launcher.pyx +++ b/cuda_core/cuda/core/_launcher.pyx @@ -8,6 +8,7 @@ from cuda.bindings cimport cydriver from cuda.core._launch_config cimport LaunchConfig from cuda.core._kernel_arg_handler cimport ParamHolder +from cuda.core._module cimport Kernel from cuda.core._resource_handles cimport as_cu from cuda.core._stream cimport Stream_accept, Stream from cuda.core._utils.cuda_utils cimport ( @@ -77,11 +78,11 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern cdef ParamHolder ker_args = ParamHolder(kernel_args) cdef void** args_ptr = (ker_args.ptr) - # TODO: cythonize Module/Kernel/... # Note: We now use CUkernel handles exclusively (CUDA 12+), but they can be cast to # CUfunction for use with cuLaunchKernel, as both handle types are interchangeable # for kernel launch purposes. - cdef cydriver.CUfunction func_handle = ((kernel._handle)) + cdef Kernel ker = kernel + cdef cydriver.CUfunction func_handle = as_cu(ker._h_kernel) # Note: CUkernel can still be launched via cuLaunchKernel (not just cuLaunchKernelEx). 
# We check both binding & driver versions here mainly to see if the "Ex" API is diff --git a/cuda_core/cuda/core/_linker.py b/cuda_core/cuda/core/_linker.py index 1f6f221a39..0257655164 100644 --- a/cuda_core/cuda/core/_linker.py +++ b/cuda_core/cuda/core/_linker.py @@ -444,13 +444,13 @@ def __init__(self, *object_codes: ObjectCode, options: LinkerOptions = None): self._add_code_object(code) def _add_code_object(self, object_code: ObjectCode): - data = object_code._module + data = object_code.code with _exception_manager(self): name_str = f"{object_code.name}" if _nvjitlink and isinstance(data, bytes): _nvjitlink.add_data( self._mnff.handle, - self._input_type_from_code_type(object_code._code_type), + self._input_type_from_code_type(object_code.code_type), data, len(data), name_str, @@ -458,7 +458,7 @@ def _add_code_object(self, object_code: ObjectCode): elif _nvjitlink and isinstance(data, str): _nvjitlink.add_file( self._mnff.handle, - self._input_type_from_code_type(object_code._code_type), + self._input_type_from_code_type(object_code.code_type), data, ) elif (not _nvjitlink) and isinstance(data, bytes): @@ -466,7 +466,7 @@ def _add_code_object(self, object_code: ObjectCode): handle_return( _driver.cuLinkAddData( self._mnff.handle, - self._input_type_from_code_type(object_code._code_type), + self._input_type_from_code_type(object_code.code_type), data, len(data), name_bytes, @@ -481,7 +481,7 @@ def _add_code_object(self, object_code: ObjectCode): handle_return( _driver.cuLinkAddFile( self._mnff.handle, - self._input_type_from_code_type(object_code._code_type), + self._input_type_from_code_type(object_code.code_type), data.encode(), 0, None, diff --git a/cuda_core/cuda/core/_memory/_memory_pool.pxd b/cuda_core/cuda/core/_memory/_memory_pool.pxd index eaff8e4bab..434e6b07c1 100644 --- a/cuda_core/cuda/core/_memory/_memory_pool.pxd +++ b/cuda_core/cuda/core/_memory/_memory_pool.pxd @@ -16,7 +16,16 @@ cdef class _MemPool(MemoryResource): IPCDataForMR _ipc_data 
object _attributes object _peer_accessible_by - object __weakref__ + + +cdef class _MemPoolAttributes: + cdef: + MemoryPoolHandle _h_pool + + @staticmethod + cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool) + + cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except? -1 cdef class _MemPoolOptions: diff --git a/cuda_core/cuda/core/_memory/_memory_pool.pyx b/cuda_core/cuda/core/_memory/_memory_pool.pyx index 563f556015..b5823048e1 100644 --- a/cuda_core/cuda/core/_memory/_memory_pool.pyx +++ b/cuda_core/cuda/core/_memory/_memory_pool.pyx @@ -30,7 +30,6 @@ from cuda.core._utils.cuda_utils cimport ( from typing import TYPE_CHECKING import platform # no-cython-lint -import weakref from cuda.core._utils.cuda_utils import driver @@ -50,16 +49,15 @@ cdef class _MemPoolOptions: cdef class _MemPoolAttributes: - cdef: - object _mr_weakref + """Provides access to memory pool attributes.""" def __init__(self, *args, **kwargs): raise RuntimeError("_MemPoolAttributes cannot be instantiated directly. 
Please use MemoryResource APIs.") - @classmethod - def _init(cls, mr): - cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(cls) - self._mr_weakref = mr + @staticmethod + cdef _MemPoolAttributes _init(MemoryPoolHandle h_pool): + cdef _MemPoolAttributes self = _MemPoolAttributes.__new__(_MemPoolAttributes) + self._h_pool = h_pool return self def __repr__(self): @@ -69,12 +67,8 @@ cdef class _MemPoolAttributes: ) cdef int _getattribute(self, cydriver.CUmemPool_attribute attr_enum, void* value) except?-1: - cdef _MemPool mr = <_MemPool>(self._mr_weakref()) - if mr is None: - raise RuntimeError("_MemPool is expired") - cdef cydriver.CUmemoryPool pool_handle = as_cu(mr._h_pool) with nogil: - HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(pool_handle, attr_enum, value)) + HANDLE_RETURN(cydriver.cuMemPoolGetAttribute(as_cu(self._h_pool), attr_enum, value)) return 0 @property @@ -202,8 +196,7 @@ cdef class _MemPool(MemoryResource): def attributes(self) -> _MemPoolAttributes: """Memory pool attributes.""" if self._attributes is None: - ref = weakref.ref(self) - self._attributes = _MemPoolAttributes._init(ref) + self._attributes = _MemPoolAttributes._init(self._h_pool) return self._attributes @property diff --git a/cuda_core/cuda/core/_module.pxd b/cuda_core/cuda/core/_module.pxd new file mode 100644 index 0000000000..9333703175 --- /dev/null +++ b/cuda_core/cuda/core/_module.pxd @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver +from cuda.core._resource_handles cimport LibraryHandle, KernelHandle + +cdef class ObjectCode +cdef class Kernel +cdef class KernelOccupancy +cdef class KernelAttributes + + +cdef class Kernel: + cdef: + KernelHandle _h_kernel + KernelAttributes _attributes # lazy + KernelOccupancy _occupancy # lazy + + @staticmethod + cdef Kernel _from_obj(KernelHandle h_kernel) + + cdef tuple _get_arguments_info(self, bint param_info=*) + + +cdef class ObjectCode: + cdef: + LibraryHandle _h_library + str _code_type + object _module # bytes/str source + dict _sym_map + str _name + + cdef int _lazy_load_module(self) except -1 + + +cdef class KernelOccupancy: + cdef: + KernelHandle _h_kernel + + @staticmethod + cdef KernelOccupancy _init(KernelHandle h_kernel) + + +cdef class KernelAttributes: + cdef: + KernelHandle _h_kernel + dict _cache + + @staticmethod + cdef KernelAttributes _init(KernelHandle h_kernel) + + cdef int _get_cached_attribute(self, int device_id, cydriver.CUfunction_attribute attribute) except? -1 + cdef int _resolve_device_id(self, device_id) except? 
-1 diff --git a/cuda_core/cuda/core/_module.py b/cuda_core/cuda/core/_module.pyx similarity index 63% rename from cuda_core/cuda/core/_module.py rename to cuda_core/cuda/core/_module.pyx index dd3f4494d5..49a564b6b9 100644 --- a/cuda_core/cuda/core/_module.py +++ b/cuda_core/cuda/core/_module.pyx @@ -2,35 +2,54 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from libc.stddef cimport size_t + import functools import threading -import weakref from collections import namedtuple -from typing import Union from cuda.core._device import Device -from cuda.core._launch_config import LaunchConfig, _to_native_launch_config +from cuda.core._launch_config cimport LaunchConfig +from cuda.core._launch_config import LaunchConfig +from cuda.core._stream cimport Stream +from cuda.core._resource_handles cimport ( + LibraryHandle, + KernelHandle, + create_library_handle_from_file, + create_library_handle_from_data, + create_library_handle_ref, + create_kernel_handle, + create_kernel_handle_ref, + get_last_error, + as_cu, + as_py, +) from cuda.core._stream import Stream from cuda.core._utils.clear_error_support import ( - assert_type, assert_type_str_or_bytes_like, raise_code_path_meant_to_be_unreachable, ) -from cuda.core._utils.cuda_utils import CUDAError, driver, get_binding_version, handle_return, precondition +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cuda.core._utils.cuda_utils import driver, get_binding_version +from cuda.bindings cimport cydriver + +__all__ = ["Kernel", "ObjectCode"] # Lazy initialization state and synchronization # For Python 3.13t (free-threaded builds), we use a lock to ensure thread-safe initialization. # For regular Python builds with GIL, the lock overhead is minimal and the code remains safe. 
-_init_lock = threading.Lock() -_inited = False -_py_major_ver = None -_py_minor_ver = None -_driver_ver = None -_kernel_ctypes = None -_backend = {} +cdef object _init_lock = threading.Lock() +cdef bint _inited = False +cdef int _py_major_ver = 0 +cdef int _py_minor_ver = 0 +cdef int _driver_ver = 0 +cdef tuple _kernel_ctypes = None +cdef bint _paraminfo_supported = False -def _lazy_init(): +cdef int _lazy_init() except -1: """ Initialize module-level state in a thread-safe manner. @@ -45,57 +64,61 @@ def _lazy_init(): global _inited # Fast path: already initialized (no lock needed for read) if _inited: - return + return 0 + cdef int drv_ver # Slow path: acquire lock and initialize with _init_lock: # Double-check: another thread might have initialized while we waited if _inited: - return + return 0 - global _py_major_ver, _py_minor_ver, _driver_ver, _kernel_ctypes, _backend + global _py_major_ver, _py_minor_ver, _driver_ver, _kernel_ctypes, _paraminfo_supported # binding availability depends on cuda-python version _py_major_ver, _py_minor_ver = get_binding_version() - _backend = { - "file": driver.cuLibraryLoadFromFile, - "data": driver.cuLibraryLoadData, - "kernel": driver.cuLibraryGetKernel, - "attribute": driver.cuKernelGetAttribute, - } _kernel_ctypes = (driver.CUkernel,) - _driver_ver = handle_return(driver.cuDriverGetVersion()) - if _driver_ver >= 12040: - _backend["paraminfo"] = driver.cuKernelGetParamInfo + with nogil: + HANDLE_RETURN(cydriver.cuDriverGetVersion(&drv_ver)) + _driver_ver = drv_ver + _paraminfo_supported = _driver_ver >= 12040 # Mark as initialized (must be last to ensure all state is set) _inited = True + return 0 -# Auto-initializing property accessors -def _get_py_major_ver(): + +# Auto-initializing accessors (cdef for internal use) +cdef inline int _get_py_major_ver() except -1: """Get the Python binding major version, initializing if needed.""" _lazy_init() return _py_major_ver -def _get_py_minor_ver(): +cdef inline int 
_get_py_minor_ver() except -1: """Get the Python binding minor version, initializing if needed.""" _lazy_init() return _py_minor_ver -def _get_driver_ver(): +cdef inline int _get_driver_ver() except -1: """Get the CUDA driver version, initializing if needed.""" _lazy_init() return _driver_ver -def _get_kernel_ctypes(): +cdef inline tuple _get_kernel_ctypes(): """Get the kernel ctypes tuple, initializing if needed.""" _lazy_init() return _kernel_ctypes +cdef inline bint _is_paraminfo_supported() except -1: + """Return True if cuKernelGetParamInfo is available (driver >= 12.4).""" + _lazy_init() + return _paraminfo_supported + + @functools.cache def _is_cukernel_get_library_supported() -> bool: """Return True when cuKernelGetLibrary is available for inverse kernel-to-library lookup. @@ -109,95 +132,110 @@ def _is_cukernel_get_library_supported() -> bool: ) -def _make_dummy_library_handle(): - """Create a non-null placeholder CUlibrary handle to disable lazy loading.""" - return driver.CUlibrary(1) if hasattr(driver, "CUlibrary") else 1 +cdef inline LibraryHandle _make_empty_library_handle(): + """Create an empty LibraryHandle to indicate no library loaded.""" + return LibraryHandle() # Empty shared_ptr -class KernelAttributes: - def __new__(self, *args, **kwargs): - raise RuntimeError("KernelAttributes cannot be instantiated directly. Please use Kernel APIs.") +cdef class KernelAttributes: + """Provides access to kernel attributes.""" - slots = ("_kernel", "_cache", "_loader") + def __init__(self, *args, **kwargs): + raise RuntimeError("KernelAttributes cannot be instantiated directly. 
Please use Kernel APIs.") - @classmethod - def _init(cls, kernel): - self = super().__new__(cls) - self._kernel = weakref.ref(kernel) + @staticmethod + cdef KernelAttributes _init(KernelHandle h_kernel): + cdef KernelAttributes self = KernelAttributes.__new__(KernelAttributes) + self._h_kernel = h_kernel self._cache = {} - - # Ensure backend is initialized before setting loader _lazy_init() - self._loader = _backend return self - def _get_cached_attribute(self, device_id: Device | int, attribute: driver.CUfunction_attribute) -> int: + cdef int _get_cached_attribute(self, int device_id, cydriver.CUfunction_attribute attribute) except? -1: """Helper function to get a cached attribute or fetch and cache it if not present.""" - device_id = Device(device_id).device_id - cache_key = device_id, attribute - result = self._cache.get(cache_key, cache_key) - if result is not cache_key: - return result - kernel = self._kernel() - if kernel is None: - raise RuntimeError("Cannot access kernel attributes for expired Kernel object") - result = handle_return(self._loader["attribute"](attribute, kernel._handle, device_id)) + cdef tuple cache_key = (device_id, attribute) + cached = self._cache.get(cache_key, cache_key) + if cached is not cache_key: + return cached + cdef int result + with nogil: + HANDLE_RETURN(cydriver.cuKernelGetAttribute(&result, attribute, as_cu(self._h_kernel), device_id)) self._cache[cache_key] = result return result + cdef inline int _resolve_device_id(self, device_id) except? -1: + """Convert Device or int to device_id int.""" + return Device(device_id).device_id + def max_threads_per_block(self, device_id: Device | int = None) -> int: """int : The maximum number of threads per block. 
This attribute is read-only.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK ) def shared_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of statically-allocated shared memory required by this function. This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES + ) def const_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of user-allocated constant memory required by this function. This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES + ) def local_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of local memory used by each thread of this function. This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES + ) def num_regs(self, device_id: Device | int = None) -> int: """int : The number of registers used by each thread of this function. 
This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_NUM_REGS + ) def ptx_version(self, device_id: Device | int = None) -> int: """int : The PTX virtual architecture version for which the function was compiled. This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PTX_VERSION) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_PTX_VERSION + ) def binary_version(self, device_id: Device | int = None) -> int: """int : The binary architecture version for which the function was compiled. This attribute is read-only.""" - return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_BINARY_VERSION) + return self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_BINARY_VERSION + ) def cache_mode_ca(self, device_id: Device | int = None) -> bool: """bool : Whether the function has been compiled with user specified option "-Xptxas --dlcm=ca" set. 
This attribute is read-only.""" - return bool(self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CACHE_MODE_CA)) + return bool( + self._get_cached_attribute( + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_CACHE_MODE_CA + ) + ) def max_dynamic_shared_size_bytes(self, device_id: Device | int = None) -> int: """int : The maximum size in bytes of dynamically-allocated shared memory that can be used by this function.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES ) def preferred_shared_memory_carveout(self, device_id: Device | int = None) -> int: """int : The shared memory carveout preference, in percent of the total shared memory.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT ) def cluster_size_must_be_set(self, device_id: Device | int = None) -> bool: @@ -205,61 +243,60 @@ def cluster_size_must_be_set(self, device_id: Device | int = None) -> bool: This attribute is read-only.""" return bool( self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_CLUSTER_SIZE_MUST_BE_SET ) ) def required_cluster_width(self, device_id: Device | int = None) -> int: """int : The required cluster width in blocks.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH ) def required_cluster_height(self, device_id: Device | int = None) -> int: """int : The required cluster height in blocks.""" return 
self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT ) def required_cluster_depth(self, device_id: Device | int = None) -> int: """int : The required cluster depth in blocks.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH + self._resolve_device_id(device_id), cydriver.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH ) def non_portable_cluster_size_allowed(self, device_id: Device | int = None) -> bool: """bool : Whether the function can be launched with non-portable cluster size.""" return bool( self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED + self._resolve_device_id(device_id), + cydriver.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, ) ) def cluster_scheduling_policy_preference(self, device_id: Device | int = None) -> int: """int : The block scheduling policy of a function.""" return self._get_cached_attribute( - device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE + self._resolve_device_id(device_id), + cydriver.CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE, ) MaxPotentialBlockSizeOccupancyResult = namedtuple("MaxPotential", ("min_grid_size", "max_block_size")) -class KernelOccupancy: +cdef class KernelOccupancy: """This class offers methods to query occupancy metrics that help determine optimal launch parameters such as block size, grid size, and shared memory usage. """ - def __new__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): raise RuntimeError("KernelOccupancy cannot be instantiated directly. 
Please use Kernel APIs.") - slots = ("_handle",) - - @classmethod - def _init(cls, handle): - self = super().__new__(cls) - self._handle = handle - + @staticmethod + cdef KernelOccupancy _init(KernelHandle h_kernel): + cdef KernelOccupancy self = KernelOccupancy.__new__(KernelOccupancy) + self._h_kernel = h_kernel return self def max_active_blocks_per_multiprocessor(self, block_size: int, dynamic_shared_memory_size: int) -> int: @@ -287,12 +324,18 @@ def max_active_blocks_per_multiprocessor(self, block_size: int, dynamic_shared_m theoretical multiprocessor utilization (occupancy). """ - return handle_return( - driver.cuOccupancyMaxActiveBlocksPerMultiprocessor(self._handle, block_size, dynamic_shared_memory_size) - ) + cdef int num_blocks + cdef int c_block_size = block_size + cdef size_t c_shmem_size = dynamic_shared_memory_size + cdef cydriver.CUfunction func = as_cu(self._h_kernel) + with nogil: + HANDLE_RETURN(cydriver.cuOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks, func, c_block_size, c_shmem_size + )) + return num_blocks def max_potential_block_size( - self, dynamic_shared_memory_needed: Union[int, driver.CUoccupancyB2DSize], block_size_limit: int + self, dynamic_shared_memory_needed: int | driver.CUoccupancyB2DSize, block_size_limit: int ) -> MaxPotentialBlockSizeOccupancyResult: """MaxPotentialBlockSizeOccupancyResult: Suggested launch configuration for reasonable occupancy. @@ -323,18 +366,23 @@ def max_potential_block_size( Interpreter Lock may lead to deadlocks. 
""" + cdef int min_grid_size, max_block_size + cdef cydriver.CUfunction func = as_cu(self._h_kernel) + cdef cydriver.CUoccupancyB2DSize callback + cdef size_t c_shmem_size + cdef int c_block_size_limit = block_size_limit if isinstance(dynamic_shared_memory_needed, int): - min_grid_size, max_block_size = handle_return( - driver.cuOccupancyMaxPotentialBlockSize( - self._handle, None, dynamic_shared_memory_needed, block_size_limit - ) - ) + c_shmem_size = dynamic_shared_memory_needed + with nogil: + HANDLE_RETURN(cydriver.cuOccupancyMaxPotentialBlockSize( + &min_grid_size, &max_block_size, func, NULL, c_shmem_size, c_block_size_limit + )) elif isinstance(dynamic_shared_memory_needed, driver.CUoccupancyB2DSize): - min_grid_size, max_block_size = handle_return( - driver.cuOccupancyMaxPotentialBlockSize( - self._handle, dynamic_shared_memory_needed.getPtr(), 0, block_size_limit - ) - ) + # Callback may require GIL, so don't use nogil here + callback = dynamic_shared_memory_needed.getPtr() + HANDLE_RETURN(cydriver.cuOccupancyMaxPotentialBlockSize( + &min_grid_size, &max_block_size, func, callback, 0, c_block_size_limit + )) else: raise TypeError( "dynamic_shared_memory_needed expected to have type int, or CUoccupancyB2DSize, " @@ -359,9 +407,15 @@ def available_dynamic_shared_memory_per_block(self, num_blocks_per_multiprocesso int Dynamic shared memory available per block for given launch configuration. 
""" - return handle_return( - driver.cuOccupancyAvailableDynamicSMemPerBlock(self._handle, num_blocks_per_multiprocessor, block_size) - ) + cdef size_t dynamic_smem_size + cdef int c_num_blocks = num_blocks_per_multiprocessor + cdef int c_block_size = block_size + cdef cydriver.CUfunction func = as_cu(self._h_kernel) + with nogil: + HANDLE_RETURN(cydriver.cuOccupancyAvailableDynamicSMemPerBlock( + &dynamic_smem_size, func, c_num_blocks, c_block_size + )) + return dynamic_smem_size def max_potential_cluster_size(self, config: LaunchConfig, stream: Stream | None = None) -> int: """Maximum potential cluster size. @@ -380,10 +434,16 @@ def max_potential_cluster_size(self, config: LaunchConfig, stream: Stream | None int The maximum cluster size that can be launched for this kernel and launch configuration. """ - drv_cfg = _to_native_launch_config(config) + cdef cydriver.CUlaunchConfig drv_cfg = (config)._to_native_launch_config() + cdef Stream s if stream is not None: - drv_cfg.hStream = stream.handle - return handle_return(driver.cuOccupancyMaxPotentialClusterSize(self._handle, drv_cfg)) + s = stream + drv_cfg.hStream = as_cu(s._h_stream) + cdef int cluster_size + cdef cydriver.CUfunction func = as_cu(self._h_kernel) + with nogil: + HANDLE_RETURN(cydriver.cuOccupancyMaxPotentialClusterSize(&cluster_size, func, &drv_cfg)) + return cluster_size def max_active_clusters(self, config: LaunchConfig, stream: Stream | None = None) -> int: """Maximum number of active clusters on the target device. @@ -402,16 +462,22 @@ def max_active_clusters(self, config: LaunchConfig, stream: Stream | None = None int The maximum number of clusters that could co-exist on the target device. 
""" - drv_cfg = _to_native_launch_config(config) + cdef cydriver.CUlaunchConfig drv_cfg = (config)._to_native_launch_config() + cdef Stream s if stream is not None: - drv_cfg.hStream = stream.handle - return handle_return(driver.cuOccupancyMaxActiveClusters(self._handle, drv_cfg)) + s = stream + drv_cfg.hStream = as_cu(s._h_stream) + cdef int num_clusters + cdef cydriver.CUfunction func = as_cu(self._h_kernel) + with nogil: + HANDLE_RETURN(cydriver.cuOccupancyMaxActiveClusters(&num_clusters, func, &drv_cfg)) + return num_clusters ParamInfo = namedtuple("ParamInfo", ["offset", "size"]) -class Kernel: +cdef class Kernel: """Represent a compiled kernel that had been loaded onto the device. Kernel instances can execution when passed directly into the @@ -422,18 +488,13 @@ class Kernel: """ - __slots__ = ("_handle", "_module", "_attributes", "_occupancy", "__weakref__") - - def __new__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): raise RuntimeError("Kernel objects cannot be instantiated directly. 
Please use ObjectCode APIs.") - @classmethod - def _from_obj(cls, obj, mod): - assert_type(obj, _get_kernel_ctypes()) - assert_type(mod, ObjectCode) - ker = super().__new__(cls) - ker._handle = obj - ker._module = mod + @staticmethod + cdef Kernel _from_obj(KernelHandle h_kernel): + cdef Kernel ker = Kernel.__new__(Kernel) + ker._h_kernel = h_kernel ker._attributes = None ker._occupancy = None return ker @@ -442,29 +503,31 @@ def _from_obj(cls, obj, mod): def attributes(self) -> KernelAttributes: """Get the read-only attributes of this kernel.""" if self._attributes is None: - self._attributes = KernelAttributes._init(self) + self._attributes = KernelAttributes._init(self._h_kernel) return self._attributes - def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]: - attr_impl = self.attributes - if "paraminfo" not in attr_impl._loader: + cdef tuple _get_arguments_info(self, bint param_info=False): + if not _is_paraminfo_supported(): driver_ver = _get_driver_ver() raise NotImplementedError( "Driver version 12.4 or newer is required for this function. 
" f"Using driver version {driver_ver // 1000}.{(driver_ver % 1000) // 10}" ) - arg_pos = 0 - param_info_data = [] + cdef size_t arg_pos = 0 + cdef list param_info_data = [] + cdef cydriver.CUkernel cu_kernel = as_cu(self._h_kernel) + cdef size_t param_offset, param_size + cdef cydriver.CUresult err while True: - result = attr_impl._loader["paraminfo"](self._handle, arg_pos) - if result[0] != driver.CUresult.CUDA_SUCCESS: + with nogil: + err = cydriver.cuKernelGetParamInfo(cu_kernel, arg_pos, &param_offset, &param_size) + if err != cydriver.CUDA_SUCCESS: break if param_info: - p_info = ParamInfo(offset=result[1], size=result[2]) - param_info_data.append(p_info) + param_info_data.append(ParamInfo(offset=param_offset, size=param_size)) arg_pos = arg_pos + 1 - if result[0] != driver.CUresult.CUDA_ERROR_INVALID_VALUE: - handle_return(result) + if err != cydriver.CUDA_ERROR_INVALID_VALUE: + HANDLE_RETURN(err) return arg_pos, param_info_data @property @@ -483,11 +546,22 @@ def arguments_info(self) -> list[ParamInfo]: def occupancy(self) -> KernelOccupancy: """Get the occupancy information for launching this kernel.""" if self._occupancy is None: - self._occupancy = KernelOccupancy._init(self._handle) + self._occupancy = KernelOccupancy._init(self._h_kernel) return self._occupancy + @property + def handle(self): + """Return the underlying kernel handle object. + + .. caution:: + + This handle is a Python object. To get the memory address of the underlying C + handle, call ``int(Kernel.handle)``. + """ + return as_py(self._h_kernel) + @staticmethod - def from_handle(handle: int, mod: "ObjectCode" = None) -> "Kernel": + def from_handle(handle, mod: ObjectCode = None) -> Kernel: """Creates a new :obj:`Kernel` object from a foreign kernel handle. Uses a CUkernel pointer address to create a new :obj:`Kernel` object. 
@@ -507,31 +581,36 @@ def from_handle(handle: int, mod: "ObjectCode" = None) -> "Kernel": if not isinstance(handle, int): raise TypeError(f"handle must be an integer, got {type(handle).__name__}") - # Convert the integer handle to CUkernel driver type - kernel_obj = driver.CUkernel(handle) + # Convert the integer handle to CUkernel + cdef cydriver.CUkernel cu_kernel = handle + cdef KernelHandle h_kernel + cdef cydriver.CUlibrary cu_library + cdef cydriver.CUresult err - # If no module provided, create a placeholder + # If no module provided, create a placeholder and try to get the library if mod is None: - # For CUkernel, we can (optionally) inverse-lookup the owning CUlibrary via - # cuKernelGetLibrary (added in CUDA 12.5). If the API is not available, we fall - # back to a non-null dummy handle purely to disable lazy loading. mod = ObjectCode._init(b"", "cubin") if _is_cukernel_get_library_supported(): - try: - mod._handle = handle_return(driver.cuKernelGetLibrary(kernel_obj)) - except (CUDAError, RuntimeError): - # Best-effort: don't fail construction if inverse lookup fails. - mod._handle = _make_dummy_library_handle() - else: - mod._handle = _make_dummy_library_handle() + # Try to get the owning library via cuKernelGetLibrary + with nogil: + err = cydriver.cuKernelGetLibrary(&cu_library, cu_kernel) + if err == cydriver.CUDA_SUCCESS: + mod._h_library = create_library_handle_ref(cu_library) - return Kernel._from_obj(kernel_obj, mod) + # Create kernel handle with library dependency + h_kernel = create_kernel_handle_ref(cu_kernel, mod._h_library) + if not h_kernel: + HANDLE_RETURN(get_last_error()) + return Kernel._from_obj(h_kernel) -CodeTypeT = Union[bytes, bytearray, str] +CodeTypeT = bytes | bytearray | str -class ObjectCode: +cdef tuple _supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library") + + +cdef class ObjectCode: """Represent a compiled program to be loaded onto the device. 
This object provides a unified interface for different types of @@ -545,10 +624,7 @@ class ObjectCode: :class:`~cuda.core.Program` """ - __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name") - _supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library") - - def __new__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): raise RuntimeError( "ObjectCode objects cannot be instantiated directly. " "Please use ObjectCode APIs (from_cubin, from_ptx) or Program APIs (compile)." @@ -556,33 +632,30 @@ def __new__(self, *args, **kwargs): @classmethod def _init(cls, module, code_type, *, name: str = "", symbol_mapping: dict | None = None): - self = super().__new__(cls) - assert code_type in self._supported_code_type, f"{code_type=} is not supported" + assert code_type in _supported_code_type, f"{code_type=} is not supported" + cdef ObjectCode self = ObjectCode.__new__(ObjectCode) - # handle is assigned during _lazy_load - self._handle = None - - # Ensure backend is initialized before setting loader + # _h_library is assigned during _lazy_load_module + self._h_library = LibraryHandle() # Empty handle _lazy_init() - self._loader = _backend self._code_type = code_type self._module = module self._sym_map = {} if symbol_mapping is None else symbol_mapping - self._name = name + self._name = name if name else "" return self @classmethod - def _reduce_helper(self, module, code_type, name, symbol_mapping): + def _reduce_helper(cls, module, code_type, name, symbol_mapping): # just for forwarding kwargs - return ObjectCode._init(module, code_type, name=name, symbol_mapping=symbol_mapping) + return cls._init(module, code_type, name=name if name else "", symbol_mapping=symbol_mapping) def __reduce__(self): return ObjectCode._reduce_helper, (self._module, self._code_type, self._name, self._sym_map) @staticmethod - def from_cubin(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def 
from_cubin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing cubin. Parameters @@ -600,7 +673,7 @@ def from_cubin(module: Union[bytes, str], *, name: str = "", symbol_mapping: dic return ObjectCode._init(module, "cubin", name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_ptx(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def from_ptx(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing PTX. Parameters @@ -618,7 +691,7 @@ def from_ptx(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict return ObjectCode._init(module, "ptx", name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_ltoir(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def from_ltoir(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing LTOIR. Parameters @@ -636,7 +709,7 @@ def from_ltoir(module: Union[bytes, str], *, name: str = "", symbol_mapping: dic return ObjectCode._init(module, "ltoir", name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_fatbin(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def from_fatbin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing fatbin. 
Parameters @@ -654,7 +727,7 @@ def from_fatbin(module: Union[bytes, str], *, name: str = "", symbol_mapping: di return ObjectCode._init(module, "fatbin", name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_object(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def from_object(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing object code. Parameters @@ -672,7 +745,7 @@ def from_object(module: Union[bytes, str], *, name: str = "", symbol_mapping: di return ObjectCode._init(module, "object", name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_library(module: Union[bytes, str], *, name: str = "", symbol_mapping: dict | None = None) -> "ObjectCode": + def from_library(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing library. Parameters @@ -691,20 +764,26 @@ def from_library(module: Union[bytes, str], *, name: str = "", symbol_mapping: d # TODO: do we want to unload in a finalizer? Probably not.. 
- def _lazy_load_module(self, *args, **kwargs): - if self._handle is not None: - return + cdef int _lazy_load_module(self) except -1: + if self._h_library: + return 0 module = self._module assert_type_str_or_bytes_like(module) + cdef bytes path_bytes if isinstance(module, str): - self._handle = handle_return(self._loader["file"](module.encode(), [], [], 0, [], [], 0)) - return + path_bytes = module.encode() + self._h_library = create_library_handle_from_file(path_bytes) + if not self._h_library: + HANDLE_RETURN(get_last_error()) + return 0 if isinstance(module, (bytes, bytearray)): - self._handle = handle_return(self._loader["data"](module, [], [], 0, [], [], 0)) - return + self._h_library = create_library_handle_from_data(module) + if not self._h_library: + HANDLE_RETURN(get_last_error()) + return 0 raise_code_path_meant_to_be_unreachable() + return -1 - @precondition(_lazy_load_module) def get_kernel(self, name) -> Kernel: """Return the :obj:`~_module.Kernel` of a specified name from this object code. @@ -719,6 +798,7 @@ def get_kernel(self, name) -> Kernel: Newly created kernel object. 
""" + self._lazy_load_module() supported_code_types = ("cubin", "ptx", "fatbin") if self._code_type not in supported_code_types: raise RuntimeError(f'Unsupported code type "{self._code_type}" ({supported_code_types=})') @@ -727,8 +807,10 @@ def get_kernel(self, name) -> Kernel: except KeyError: name = name.encode() - data = handle_return(self._loader["kernel"](self._handle, name)) - return Kernel._from_obj(data, self) + cdef KernelHandle h_kernel = create_kernel_handle(self._h_library, name) + if not h_kernel: + HANDLE_RETURN(get_last_error()) + return Kernel._from_obj(h_kernel) @property def code(self) -> CodeTypeT: @@ -746,7 +828,11 @@ def code_type(self) -> str: return self._code_type @property - @precondition(_lazy_load_module) + def symbol_mapping(self) -> dict: + """Return a copy of the symbol mapping dictionary.""" + return dict(self._sym_map) + + @property def handle(self): """Return the underlying handle object. @@ -755,4 +841,5 @@ def handle(self): This handle is a Python object. To get the memory address of the underlying C handle, call ``int(ObjectCode.handle)``. 
""" - return self._handle + self._lazy_load_module() + return as_py(self._h_library) diff --git a/cuda_core/cuda/core/_resource_handles.pxd b/cuda_core/cuda/core/_resource_handles.pxd index 7a634f3a82..b146d93aa7 100644 --- a/cuda_core/cuda/core/_resource_handles.pxd +++ b/cuda_core/cuda/core/_resource_handles.pxd @@ -21,6 +21,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": ctypedef shared_ptr[const cydriver.CUevent] EventHandle ctypedef shared_ptr[const cydriver.CUmemoryPool] MemoryPoolHandle ctypedef shared_ptr[const cydriver.CUdeviceptr] DevicePtrHandle + ctypedef shared_ptr[const cydriver.CUlibrary] LibraryHandle + ctypedef shared_ptr[const cydriver.CUkernel] KernelHandle # as_cu() - extract the raw CUDA handle (inline C++) cydriver.CUcontext as_cu(ContextHandle h) noexcept nogil @@ -28,6 +30,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": cydriver.CUevent as_cu(EventHandle h) noexcept nogil cydriver.CUmemoryPool as_cu(MemoryPoolHandle h) noexcept nogil cydriver.CUdeviceptr as_cu(DevicePtrHandle h) noexcept nogil + cydriver.CUlibrary as_cu(LibraryHandle h) noexcept nogil + cydriver.CUkernel as_cu(KernelHandle h) noexcept nogil # as_intptr() - extract handle as intptr_t for Python interop (inline C++) intptr_t as_intptr(ContextHandle h) noexcept nogil @@ -35,6 +39,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": intptr_t as_intptr(EventHandle h) noexcept nogil intptr_t as_intptr(MemoryPoolHandle h) noexcept nogil intptr_t as_intptr(DevicePtrHandle h) noexcept nogil + intptr_t as_intptr(LibraryHandle h) noexcept nogil + intptr_t as_intptr(KernelHandle h) noexcept nogil # as_py() - convert handle to Python driver wrapper object (inline C++; requires GIL) object as_py(ContextHandle h) @@ -42,6 +48,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": object as_py(EventHandle h) object as_py(MemoryPoolHandle h) object as_py(DevicePtrHandle h) + object as_py(LibraryHandle h) 
+ object as_py(KernelHandle h) # ============================================================================= @@ -94,3 +102,13 @@ cdef DevicePtrHandle deviceptr_import_ipc( MemoryPoolHandle h_pool, const void* export_data, StreamHandle h_stream) nogil except+ cdef StreamHandle deallocation_stream(const DevicePtrHandle& h) noexcept nogil cdef void set_deallocation_stream(const DevicePtrHandle& h, StreamHandle h_stream) noexcept nogil + +# Library handles +cdef LibraryHandle create_library_handle_from_file(const char* path) nogil except+ +cdef LibraryHandle create_library_handle_from_data(const void* data) nogil except+ +cdef LibraryHandle create_library_handle_ref(cydriver.CUlibrary library) nogil except+ + +# Kernel handles +cdef KernelHandle create_kernel_handle(LibraryHandle h_library, const char* name) nogil except+ +cdef KernelHandle create_kernel_handle_ref( + cydriver.CUkernel kernel, LibraryHandle h_library) nogil except+ diff --git a/cuda_core/cuda/core/_resource_handles.pyx b/cuda_core/cuda/core/_resource_handles.pyx index 7989cd1bb0..022929f7e3 100644 --- a/cuda_core/cuda/core/_resource_handles.pyx +++ b/cuda_core/cuda/core/_resource_handles.pyx @@ -21,6 +21,8 @@ from ._resource_handles cimport ( EventHandle, MemoryPoolHandle, DevicePtrHandle, + LibraryHandle, + KernelHandle, ) import cuda.bindings.cydriver as cydriver @@ -91,6 +93,20 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": void set_deallocation_stream "cuda_core::set_deallocation_stream" ( const DevicePtrHandle& h, StreamHandle h_stream) noexcept nogil + # Library handles + LibraryHandle create_library_handle_from_file "cuda_core::create_library_handle_from_file" ( + const char* path) nogil except+ + LibraryHandle create_library_handle_from_data "cuda_core::create_library_handle_from_data" ( + const void* data) nogil except+ + LibraryHandle create_library_handle_ref "cuda_core::create_library_handle_ref" ( + cydriver.CUlibrary library) nogil except+ + + # Kernel handles + 
KernelHandle create_kernel_handle "cuda_core::create_kernel_handle" ( + LibraryHandle h_library, const char* name) nogil except+ + KernelHandle create_kernel_handle_ref "cuda_core::create_kernel_handle_ref" ( + cydriver.CUkernel kernel, LibraryHandle h_library) nogil except+ + # ============================================================================= # CUDA Driver API capsule @@ -152,6 +168,12 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core": # IPC void* p_cuMemPoolImportPointer "reinterpret_cast(cuda_core::p_cuMemPoolImportPointer)" + # Library + void* p_cuLibraryLoadFromFile "reinterpret_cast(cuda_core::p_cuLibraryLoadFromFile)" + void* p_cuLibraryLoadData "reinterpret_cast(cuda_core::p_cuLibraryLoadData)" + void* p_cuLibraryUnload "reinterpret_cast(cuda_core::p_cuLibraryUnload)" + void* p_cuLibraryGetKernel "reinterpret_cast(cuda_core::p_cuLibraryGetKernel)" + # Initialize driver function pointers from cydriver.__pyx_capi__ at module load cdef void* _get_driver_fn(str name): @@ -195,3 +217,9 @@ p_cuMemFreeHost = _get_driver_fn("cuMemFreeHost") # IPC p_cuMemPoolImportPointer = _get_driver_fn("cuMemPoolImportPointer") + +# Library +p_cuLibraryLoadFromFile = _get_driver_fn("cuLibraryLoadFromFile") +p_cuLibraryLoadData = _get_driver_fn("cuLibraryLoadData") +p_cuLibraryUnload = _get_driver_fn("cuLibraryUnload") +p_cuLibraryGetKernel = _get_driver_fn("cuLibraryGetKernel") diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 8851a4600a..47091995e7 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -1162,7 +1162,7 @@ def test_mempool_attributes_repr(memory_resource_factory): def test_mempool_attributes_ownership(memory_resource_factory): - """Ensure the attributes bundle handles references correctly for all memory resource types.""" + """Ensure the attributes bundle keeps the pool alive via the handle.""" MR, MRops = memory_resource_factory device = Device() @@ -1190,21 +1190,9 @@ 
def test_mempool_attributes_ownership(memory_resource_factory): mr.close() del mr - # After deleting the memory resource, the attributes suite is disconnected. - with pytest.raises(RuntimeError, match="is expired"): - _ = attributes.used_mem_high - - # Even when a new object is created (we found a case where the same - # mempool handle was really reused). - if MR is DeviceMemoryResource: - mr = MR(device, dict(max_size=POOL_SIZE)) # noqa: F841 - elif MR is PinnedMemoryResource: - mr = MR(dict(max_size=POOL_SIZE)) # noqa: F841 - elif MR is ManagedMemoryResource: - mr = create_managed_memory_resource_or_skip(dict()) # noqa: F841 - - with pytest.raises(RuntimeError, match="is expired"): - _ = attributes.used_mem_high + # The attributes bundle keeps the pool alive via MemoryPoolHandle, + # so accessing attributes still works even after the MR is deleted. + _ = attributes.used_mem_high # Should not raise # Ensure that memory views dellocate their reference to dlpack tensors diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index f9bbcd3e4c..72591b54d5 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -79,7 +79,7 @@ def get_saxpy_kernel_ptx(init_cuda): "ptx", name_expressions=("saxpy", "saxpy"), ) - ptx = mod._module + ptx = mod.code return ptx, mod @@ -100,10 +100,10 @@ def test_get_kernel(init_cuda): if any("The CUDA driver version is older than the backend version" in str(warning.message) for warning in w): pytest.skip("PTX version too new for current driver") - assert object_code._handle is None + # Verify lazy loading: get_kernel triggers module loading and returns a valid kernel kernel = object_code.get_kernel("ABC") - assert object_code._handle is not None - assert kernel._handle is not None + assert object_code.handle is not None + assert kernel.handle is not None @pytest.mark.parametrize( @@ -143,7 +143,7 @@ def test_read_only_kernel_attributes(get_saxpy_kernel_cubin, attr, expected_type def 
test_object_code_load_ptx(get_saxpy_kernel_ptx): ptx, mod = get_saxpy_kernel_ptx - sym_map = mod._sym_map + sym_map = mod.symbol_mapping mod_obj = ObjectCode.from_ptx(ptx, symbol_mapping=sym_map) assert mod.code == ptx if not Program._can_load_generated_ptx(): @@ -153,7 +153,7 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx): def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): ptx, mod = get_saxpy_kernel_ptx - sym_map = mod._sym_map + sym_map = mod.symbol_mapping assert isinstance(ptx, bytes) ptx_file = tmp_path / "test.ptx" ptx_file.write_bytes(ptx) @@ -167,8 +167,8 @@ def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): def test_object_code_load_cubin(get_saxpy_kernel_cubin): _, mod = get_saxpy_kernel_cubin - cubin = mod._module - sym_map = mod._sym_map + cubin = mod.code + sym_map = mod.symbol_mapping assert isinstance(cubin, bytes) mod = ObjectCode.from_cubin(cubin, symbol_mapping=sym_map) assert mod.code == cubin @@ -177,8 +177,8 @@ def test_object_code_load_cubin(get_saxpy_kernel_cubin): def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path): _, mod = get_saxpy_kernel_cubin - cubin = mod._module - sym_map = mod._sym_map + cubin = mod.code + sym_map = mod.symbol_mapping assert isinstance(cubin, bytes) cubin_file = tmp_path / "test.cubin" cubin_file.write_bytes(cubin) @@ -194,14 +194,13 @@ def test_object_code_handle(get_saxpy_kernel_cubin): def test_object_code_load_ltoir(get_saxpy_kernel_ltoir): mod = get_saxpy_kernel_ltoir - ltoir = mod._module - sym_map = mod._sym_map + ltoir = mod.code + sym_map = mod.symbol_mapping assert isinstance(ltoir, bytes) mod_obj = ObjectCode.from_ltoir(ltoir, symbol_mapping=sym_map) assert mod_obj.code == ltoir assert mod_obj.code_type == "ltoir" # ltoir doesn't support kernel retrieval directly as it's used for linking - assert mod_obj._handle is None # Test that get_kernel fails for unsupported code type with pytest.raises(RuntimeError, match=r'Unsupported code 
type "ltoir"'): mod_obj.get_kernel("saxpy") @@ -209,8 +208,8 @@ def test_object_code_load_ltoir(get_saxpy_kernel_ltoir): def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path): mod = get_saxpy_kernel_ltoir - ltoir = mod._module - sym_map = mod._sym_map + ltoir = mod.code + sym_map = mod.symbol_mapping assert isinstance(ltoir, bytes) ltoir_file = tmp_path / "test.ltoir" ltoir_file.write_bytes(ltoir) @@ -218,7 +217,6 @@ def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path): assert mod_obj.code == str(ltoir_file) assert mod_obj.code_type == "ltoir" # ltoir doesn't support kernel retrieval directly as it's used for linking - assert mod_obj._handle is None def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check): @@ -231,7 +229,6 @@ def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check): _ = krn.num_arguments return - assert "ParamInfo" in str(type(krn).arguments_info.fget.__annotations__) arg_info = krn.arguments_info n_args = len(arg_info) assert n_args == krn.num_arguments @@ -418,7 +415,7 @@ def test_module_serialization_roundtrip(get_saxpy_kernel_cubin): assert isinstance(result, ObjectCode) assert objcode.code == result.code - assert objcode._sym_map == result._sym_map + assert objcode.symbol_mapping == result.symbol_mapping assert objcode.code_type == result.code_type @@ -427,7 +424,7 @@ def test_kernel_from_handle(get_saxpy_kernel_cubin): original_kernel, objcode = get_saxpy_kernel_cubin # Get the handle from the original kernel - handle = int(original_kernel._handle) + handle = int(original_kernel.handle) # Create a new Kernel from the handle kernel_from_handle = Kernel.from_handle(handle, objcode) @@ -444,7 +441,7 @@ def test_kernel_from_handle_no_module(get_saxpy_kernel_cubin): original_kernel, _ = get_saxpy_kernel_cubin # Get the handle from the original kernel - handle = int(original_kernel._handle) + handle = int(original_kernel.handle) # Create a new Kernel from the 
handle without a module # This is supported on CUDA 12+ backend (CUkernel) @@ -481,7 +478,7 @@ def test_kernel_from_handle_type_validation(invalid_value): def test_kernel_from_handle_invalid_module_type(get_saxpy_kernel_cubin): """Test Kernel.from_handle() with invalid module parameter""" original_kernel, _ = get_saxpy_kernel_cubin - handle = int(original_kernel._handle) + handle = int(original_kernel.handle) # Invalid module type (should fail type assertion in _from_obj) with pytest.raises((TypeError, AssertionError)): @@ -494,7 +491,7 @@ def test_kernel_from_handle_invalid_module_type(get_saxpy_kernel_cubin): def test_kernel_from_handle_multiple_instances(get_saxpy_kernel_cubin): """Test creating multiple Kernel instances from the same handle""" original_kernel, objcode = get_saxpy_kernel_cubin - handle = int(original_kernel._handle) + handle = int(original_kernel.handle) # Create multiple Kernel instances from the same handle kernel1 = Kernel.from_handle(handle, objcode) @@ -507,4 +504,57 @@ def test_kernel_from_handle_multiple_instances(get_saxpy_kernel_cubin): assert isinstance(kernel3, Kernel) # All should reference the same underlying CUDA kernel handle - assert int(kernel1._handle) == int(kernel2._handle) == int(kernel3._handle) == handle + assert int(kernel1.handle) == int(kernel2.handle) == int(kernel3.handle) == handle + + +def test_kernel_keeps_library_alive(init_cuda): + """Test that a Kernel keeps its underlying library alive after ObjectCode goes out of scope.""" + import gc + + import numpy as np + + def get_kernel_only(): + """Return a kernel, letting ObjectCode go out of scope.""" + code = """ + extern "C" __global__ void write_value(int* out) { + if (threadIdx.x == 0 && blockIdx.x == 0) { + *out = 42; + } + } + """ + program = Program(code, "c++") + object_code = program.compile("cubin") + kernel = object_code.get_kernel("write_value") + # ObjectCode goes out of scope here + return kernel + + kernel = get_kernel_only() + + # Force GC to ensure 
ObjectCode destructor runs + gc.collect() + + # Kernel should still be valid + assert kernel.handle is not None + assert kernel.num_arguments == 1 + + # Actually launch the kernel to prove library is still loaded + device = Device() + stream = device.create_stream() + + # Allocate pinned host buffer and device buffer + pinned_mr = cuda.core.LegacyPinnedMemoryResource() + host_buf = pinned_mr.allocate(4) # sizeof(int) + result = np.from_dlpack(host_buf).view(np.int32) + result[:] = 0 + + dev_buf = device.memory_resource.allocate(4) + + # Launch kernel + config = cuda.core.LaunchConfig(grid=1, block=1) + cuda.core.launch(stream, config, kernel, dev_buf) + + # Copy result back to host + dev_buf.copy_to(host_buf, stream=stream) + stream.sync() + + assert result[0] == 42, f"Expected 42, got {result[0]}" diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 9a9e4926ae..e2b3783dd7 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -335,7 +335,7 @@ def test_cpp_program_with_pch_options(init_cuda, tmp_path): @pytest.mark.parametrize("options", options) def test_ptx_program_with_various_options(init_cuda, ptx_code_object, options): - program = Program(ptx_code_object._module.decode(), "ptx", options=options) + program = Program(ptx_code_object.code.decode(), "ptx", options=options) assert program.backend == ("driver" if is_culink_backend else "nvJitLink") program.compile("cubin") program.close() @@ -378,7 +378,7 @@ def test_program_compile_valid_target_type(init_cuda): ptx_kernel = ptx_object_code.get_kernel("my_kernel") assert isinstance(ptx_kernel, Kernel) - program = Program(ptx_object_code._module.decode(), "ptx", options={"name": "24"}) + program = Program(ptx_object_code.code.decode(), "ptx", options={"name": "24"}) cubin_object_code = program.compile("cubin") assert isinstance(cubin_object_code, ObjectCode) assert cubin_object_code.name == "24" diff --git a/cuda_core/tests/test_utils.py 
b/cuda_core/tests/test_utils.py index a3c62a6aee..dd9c52e817 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -345,7 +345,7 @@ def _get_ptr(array): for view_as in ["dlpack", "cai"] ], ) -def test_view_sliced_external(shape, slices, stride_order, view_as): +def test_view_sliced_external(init_cuda, shape, slices, stride_order, view_as): if view_as == "dlpack": if np is None: pytest.skip("NumPy is not installed") @@ -380,7 +380,7 @@ def test_view_sliced_external(shape, slices, stride_order, view_as): ("stride_order", "view_as"), [(stride_order, view_as) for stride_order in ["C", "F"] for view_as in ["dlpack", "cai"]], ) -def test_view_sliced_external_negative_offset(stride_order, view_as): +def test_view_sliced_external_negative_offset(init_cuda, stride_order, view_as): shape = (5,) if view_as == "dlpack": if np is None: @@ -422,7 +422,7 @@ def test_view_sliced_external_negative_offset(stride_order, view_as): ) @pytest.mark.parametrize("shape", [(0,), (0, 0), (0, 0, 0)]) @pytest.mark.parametrize("dtype", [np.int64, np.uint8, np.float64]) -def test_view_zero_size_array(api, shape, dtype): +def test_view_zero_size_array(init_cuda, api, shape, dtype): cp = pytest.importorskip("cupy") x = cp.empty(shape, dtype=dtype) @@ -446,7 +446,7 @@ def test_from_buffer_with_non_power_of_two_itemsize(): assert view.dtype == dtype -def test_struct_array(): +def test_struct_array(init_cuda): cp = pytest.importorskip("cupy") x = np.array([(1.0, 2), (2.0, 3)], dtype=[("array1", np.float64), ("array2", np.int64)])