From dafd6379e4f7b4ed2bc25fa784dd17e95a1715be Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Fri, 19 Dec 2025 18:18:31 +0100 Subject: [PATCH 1/9] A, perhaps hacky, hiding of the runtime in the GLOBAL_METHOD_TABLE overlay --- src/jlgen.jl | 6 ++---- src/runtime.jl | 25 ++++++++++++++++++++++++- src/utils.jl | 3 +++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/jlgen.jl b/src/jlgen.jl index d6116cb5..6f6b2ccd 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -293,10 +293,6 @@ end end # !HAS_INTEGRATED_CACHE -## method overrides - -Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) - # Implements a priority lookup for method tables, where the first match in the stack get's returned. # An alternative to this would be to use a "Union" where we would query the parent method table and # do a most-specific match. @@ -314,6 +310,7 @@ CC.isoverlayed(::StackedMethodTable) = true # https://github.com/JuliaLang/julia/pull/51078 # same API as before but without returning isoverlayed flag function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) + println("findall: sig: $(sig), mt: $(table)") result = CC._findall(sig, table.mt, table.world, limit) result === nothing && return nothing # to many matches nr = CC.length(result) @@ -335,6 +332,7 @@ CC.isoverlayed(::StackedMethodTable) = true end function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) + println("findall: sig: $(sig), mt: $(table)") match, valid_worlds = CC._findsup(sig, table.mt, table.world) match !== nothing && return match, valid_worlds parent_match, parent_valid_worlds = CC.findsup(sig, table.parent) diff --git a/src/runtime.jl b/src/runtime.jl index 2b11d915..92349466 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -12,6 +12,26 @@ module Runtime using ..GPUCompiler using LLVM using LLVM.Interop +using ExprTools: splitdef, combinedef + + +macro device_function(ex) + ex = macroexpand(__module__, ex) + def = splitdef(ex) + + # generate a function that errors + def[:body] = quote + error("This function is not intended for use on the CPU") + end + + esc(quote + $(combinedef(def)) + + # NOTE: no use of `@consistent_overlay` here because the regular function errors + Base.Experimental.@overlay($(GPUCompiler).GLOBAL_METHOD_TABLE, $ex) + end) +end + ## representation of a runtime method instance @@ -71,6 +91,8 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n meth = RuntimeMethodInstance(def, return_type, types, name, llvm_return_type, llvm_types, llvm_name) + println("Compile called for def $(def)") + if haskey(methods, name) error("Runtime function $name has already been registered!") end @@ -81,8 +103,9 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n # work around that by generating an llvmcall stub. can we do better by # using the new nonrecursive codegen to handle function lookup ourselves? if def isa Symbol + println("Symbol passed to compile: $(def)") args = [gensym() for typ in types] - @eval @inline $def($(args...)) = + @eval @device_function @inline $def($(args...)) = ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...)) end diff --git a/src/utils.jl b/src/utils.jl index 674d8f9b..d00366d4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -238,3 +238,6 @@ end return inits end end +## method overrides + +Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) From 5f4ac40a8d5e5914ea1a96ea8209b63403451ef7 Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Mon, 22 Dec 2025 15:35:56 +0100 Subject: [PATCH 2/9] cleanup some debugging --- src/runtime.jl | 22 ---------------------- src/utils.jl | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/runtime.jl b/src/runtime.jl index 92349466..36bf8ef8 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -12,26 +12,6 @@ module Runtime using ..GPUCompiler using LLVM using LLVM.Interop -using ExprTools: splitdef, combinedef - - -macro device_function(ex) - ex = macroexpand(__module__, ex) - def = splitdef(ex) - - # generate a function that errors - def[:body] = quote - error("This function is not intended for use on the CPU") - end - - esc(quote - $(combinedef(def)) - - # NOTE: no use of `@consistent_overlay` here because the regular function errors - Base.Experimental.@overlay($(GPUCompiler).GLOBAL_METHOD_TABLE, $ex) - end) -end - ## representation of a runtime method instance @@ -91,7 +71,6 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n meth = RuntimeMethodInstance(def, return_type, types, name, llvm_return_type, llvm_types, llvm_name) - println("Compile called for def $(def)") if haskey(methods, name) error("Runtime function $name has already been registered!") @@ -103,7 +82,6 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n # work around that by generating an llvmcall stub. can we do better by # using the new nonrecursive codegen to handle function lookup ourselves? if def isa Symbol - println("Symbol passed to compile: $(def)") args = [gensym() for typ in types] @eval @device_function @inline $def($(args...)) = ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...)) diff --git a/src/utils.jl b/src/utils.jl index d00366d4..9bf12e2d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -241,3 +241,22 @@ end ## method overrides Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) +using ExprTools: splitdef, combinedef +macro device_function(ex) + ex = macroexpand(__module__, ex) + def = splitdef(ex) + + # generate a function that errors + def[:body] = quote + error("This function is not intended for use on the CPU") + end + + esc(quote + $(combinedef(def)) + + # NOTE: no use of `@consistent_overlay` here because the regular function errors + Base.Experimental.@overlay($(GPUCompiler).GLOBAL_METHOD_TABLE, $ex) + end) +end + + From 9412c780e733d49485eb2f4ee84a1a508473e91c Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Mon, 22 Dec 2025 16:48:42 +0100 Subject: [PATCH 3/9] fix issue caused by moving device_function --- src/runtime.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime.jl b/src/runtime.jl index 36bf8ef8..d2fdf012 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -83,7 +83,7 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n # using the new nonrecursive codegen to handle function lookup ourselves? if def isa Symbol args = [gensym() for typ in types] - @eval @device_function @inline $def($(args...)) = + @eval GPUCompiler.@device_function @inline $def($(args...)) = ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...)) end From 6bd1ece7576a8fd00f2ff7c00952ff39fc225db8 Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Tue, 23 Dec 2025 17:32:56 +0100 Subject: [PATCH 4/9] dummy CPU functions now seem to get us further but check_ir is kicking us out of some KernelAbstractions compilations in e.g OpenCL.jl --- src/driver.jl | 1 - src/jlgen.jl | 6 +++--- src/rtlib.jl | 6 ++++++ src/runtime.jl | 6 ++++-- src/utils.jl | 11 ++++++++--- src/validation.jl | 3 +++ 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index b1d55d33..30c2b1ec 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -93,7 +93,6 @@ function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwarg ## LLVM IR ir, ir_meta = emit_llvm(job) - if output == :llvm if job.config.strip @tracepoint "strip debug info" strip_debuginfo!(ir) diff --git a/src/jlgen.jl b/src/jlgen.jl index 6f6b2ccd..15174520 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -310,7 +310,7 @@ CC.isoverlayed(::StackedMethodTable) = true # https://github.com/JuliaLang/julia/pull/51078 # same API as before but without returning isoverlayed flag function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) - println("findall: sig: $(sig), mt: $(table)") + #println("findall: sig: $(sig), mt: $(table)") result = CC._findall(sig, table.mt, table.world, limit) result === nothing && return nothing # to many matches nr = CC.length(result) @@ -332,7 +332,7 @@ CC.isoverlayed(::StackedMethodTable) = true end function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) - println("findall: sig: $(sig), mt: $(table)") + #println("findsup: sig: $(sig), mt: $(table)") match, valid_worlds = CC._findsup(sig, table.mt, table.world) match !== nothing && return match, valid_worlds parent_match, parent_valid_worlds = CC.findsup(sig, table.parent) @@ -488,7 +488,7 @@ CC.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing CC.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing function CC.add_remark!(interp::GPUInterpreter, sv::CC.InferenceState, msg) - @safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg" + #@safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg" end CC.may_optimize(interp::GPUInterpreter) = true diff --git a/src/rtlib.jl b/src/rtlib.jl index 91b4c71c..616041ed 100644 --- a/src/rtlib.jl +++ b/src/rtlib.jl @@ -77,6 +77,9 @@ function emit_function!(mod, config::CompilerConfig, f, method) new_mod, meta = compile_unhooked(:llvm, CompilerJob(source, config)) ft = function_type(meta.entry) expected_ft = convert(LLVM.FunctionType, method) + + println("emit_function!: source: $(source)") + #println(code_typed(CompilerJob(source, config))) if return_type(ft) != return_type(expected_ft) error("Invalid return type for runtime function '$(method.name)': expected $(return_type(expected_ft)), got $(return_type(ft))") end @@ -108,12 +111,15 @@ function build_runtime(@nospecialize(job::CompilerJob)) config = CompilerConfig(job.config; kernel=false, toplevel=false, only_entry=false, strip=false) for method in values(Runtime.methods) + #println("build_runtime: method.def: $(method.def)") + #println("build_runtime: method.name: $(method.name)") def = if isa(method.def, Symbol) isdefined(runtime_module(job), method.def) || continue getfield(runtime_module(job), method.def) else method.def end + println("build_runtime: def: $(def)") emit_function!(mod, config, typeof(def), method) end diff --git a/src/runtime.jl b/src/runtime.jl index d2fdf012..2f7312e1 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -83,8 +83,10 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n # using the new nonrecursive codegen to handle function lookup ourselves? if def isa Symbol args = [gensym() for typ in types] - @eval GPUCompiler.@device_function @inline $def($(args...)) = - ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...)) + @eval GPUCompiler.@device_function($return_type, + @inline $def($(args...)) = + ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...)) + ) end return diff --git a/src/utils.jl b/src/utils.jl index 9bf12e2d..9959938b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -242,13 +242,18 @@ end Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) using ExprTools: splitdef, combinedef -macro device_function(ex) +macro device_function(rt, ex) ex = macroexpand(__module__, ex) def = splitdef(ex) - # generate a function that errors + # generate a function that warns and returns the expected type + # FIXME: The type may not have a default constructor, what do we do then? + # Currently we are using the constructor with an Int64(1) as an argument. + # NOTE: using Int64(1) is a bit odd. This is because Ptr(Int64(0)) == C_NULL, and julia code lowering + # seems to get rid of this automatically. def[:body] = quote - error("This function is not intended for use on the CPU") + @warn "This function is not intended for use on the CPU something may have gone wrong" + $rt(1) end esc(quote diff --git a/src/validation.jl b/src/validation.jl index 0190d1c9..7045194c 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -177,6 +177,9 @@ function check_ir!(job, errors::Vector{IRError}, mod::LLVM.Module) # custom validation append!(errors, validate_ir(job, mod)) + if !isempty(errors) + write("error_ir.ll", string(mod)) + end return errors end From cf1b44266e8f9557c95af948e152627d97486daa Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Sun, 4 Jan 2026 15:08:45 +0100 Subject: [PATCH 5/9] @warn causing wierd things to be included --- src/utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9959938b..8be6a4b4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -252,7 +252,6 @@ macro device_function(rt, ex) # NOTE: using Int64(1) is a bit odd. This is because Ptr(Int64(0)) == C_NULL, and julia code lowering # seems to get rid of this automatically. def[:body] = quote - @warn "This function is not intended for use on the CPU something may have gone wrong" $rt(1) end From 52518142c2b1c438007b46d5c7e27e14acaf6f09 Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Sun, 4 Jan 2026 15:23:21 +0100 Subject: [PATCH 6/9] remove debugging --- src/jlgen.jl | 2 -- src/validation.jl | 3 --- 2 files changed, 5 deletions(-) diff --git a/src/jlgen.jl b/src/jlgen.jl index 15174520..dcbe8bd2 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -310,7 +310,6 @@ CC.isoverlayed(::StackedMethodTable) = true # https://github.com/JuliaLang/julia/pull/51078 # same API as before but without returning isoverlayed flag function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) - #println("findall: sig: $(sig), mt: $(table)") result = CC._findall(sig, table.mt, table.world, limit) result === nothing && return nothing # to many matches nr = CC.length(result) @@ -332,7 +331,6 @@ CC.isoverlayed(::StackedMethodTable) = true end function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) - #println("findsup: sig: $(sig), mt: $(table)") match, valid_worlds = CC._findsup(sig, table.mt, table.world) match !== nothing && return match, valid_worlds parent_match, parent_valid_worlds = CC.findsup(sig, table.parent) diff --git a/src/validation.jl b/src/validation.jl index 7045194c..0190d1c9 100644 --- a/src/validation.jl +++ b/src/validation.jl @@ -177,9 +177,6 @@ function check_ir!(job, errors::Vector{IRError}, mod::LLVM.Module) # custom validation append!(errors, validate_ir(job, mod)) - if !isempty(errors) - write("error_ir.ll", string(mod)) - end return errors end From 538e94a3bd325062add1278558b1cd41f6e50b18 Mon Sep 17 00:00:00 2001 From: Anton Pozharskiy Date: Sun, 4 Jan 2026 16:04:41 +0100 Subject: [PATCH 7/9] remove more debugging --- src/rtlib.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/rtlib.jl b/src/rtlib.jl index 616041ed..0ec50df1 100644 --- a/src/rtlib.jl +++ b/src/rtlib.jl @@ -78,8 +78,6 @@ function emit_function!(mod, config::CompilerConfig, f, method) ft = function_type(meta.entry) expected_ft = convert(LLVM.FunctionType, method) - println("emit_function!: source: $(source)") - #println(code_typed(CompilerJob(source, config))) if return_type(ft) != return_type(expected_ft) error("Invalid return type for runtime function '$(method.name)': expected $(return_type(expected_ft)), got $(return_type(ft))") end @@ -111,15 +109,12 @@ function build_runtime(@nospecialize(job::CompilerJob)) config = CompilerConfig(job.config; kernel=false, toplevel=false, only_entry=false, strip=false) for method in values(Runtime.methods) - #println("build_runtime: method.def: $(method.def)") - #println("build_runtime: method.name: $(method.name)") def = if isa(method.def, Symbol) isdefined(runtime_module(job), method.def) || continue getfield(runtime_module(job), method.def) else method.def end - println("build_runtime: def: $(def)") emit_function!(mod, config, typeof(def), method) end From 7502a84dd667a84b5b7e5a0e5207550dec6b494b Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 30 Jan 2026 10:29:15 -0600 Subject: [PATCH 8/9] Add @device_function test --- test/utils.jl | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index 3b742795..26c189e2 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -193,3 +193,40 @@ end # Check that we can call this function from the CPU, to support deferred codegen for Enzyme. @test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3 end + +@testset "@device_function macro" begin + # Test that @device_function creates both CPU stub and overlay + # The macro should: + # 1. Define a CPU-visible function that returns the expected type + # 2. Register an overlay in GLOBAL_METHOD_TABLE for GPU compilation + + # Create a test module to contain the device functions + test_mod = @eval module $(gensym("DeviceFunctionTest")) + using GPUCompiler + + # Test with Ptr return type (common for runtime functions) + GPUCompiler.@device_function(Ptr{Nothing}, + @inline test_device_ptr() = ccall("extern gpu_test", llvmcall, Ptr{Nothing}, ()) + ) + + # Test with primitive return type + GPUCompiler.@device_function(Nothing, + @inline test_device_nothing() = ccall("extern gpu_test2", llvmcall, Nothing, ()) + ) + end + + # Verify the functions are defined in the test module + @test isdefined(test_mod, :test_device_ptr) + @test isdefined(test_mod, :test_device_nothing) + + # Verify the overlay exists in the global method table + mt_view = GPUCompiler.get_method_table_view(Base.get_world_counter(), GPUCompiler.GLOBAL_METHOD_TABLE) + sig_ptr = Tuple{typeof(test_mod.test_device_ptr)} + sig_nothing = Tuple{typeof(test_mod.test_device_nothing)} + + # The overlay should be findable in the method table + result_ptr = findsup(sig_ptr, mt_view) + result_nothing = findsup(sig_nothing, mt_view) + @test result_ptr !== nothing + @test result_nothing !== nothing +end From 0c60f0ca56bff3acb0cd061cc3671e9b2a98dc51 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Fri, 30 Jan 2026 10:29:26 -0600 Subject: [PATCH 9/9] Add comment --- src/jlgen.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/jlgen.jl b/src/jlgen.jl index dcbe8bd2..2812a02b 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -486,7 +486,10 @@ CC.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing CC.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing function CC.add_remark!(interp::GPUInterpreter, sv::CC.InferenceState, msg) - #@safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg" + # NOTE: @safe_debug is disabled here because including logging/warning code causes + # CPU runtime functions (ccalls to Julia internals) to leak into the GPU IR, + # breaking AOT compilation. See PR #749 for details. + return nothing end CC.may_optimize(interp::GPUInterpreter) = true