Skip to content
Merged
1 change: 0 additions & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwarg
## LLVM IR

ir, ir_meta = emit_llvm(job)

if output == :llvm
if job.config.strip
@tracepoint "strip debug info" strip_debuginfo!(ir)
Expand Down
9 changes: 4 additions & 5 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,6 @@ end
end # !HAS_INTEGRATED_CACHE


## method overrides

Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)

# Implements a priority lookup for method tables, where the first match in the stack get's returned.
# An alternative to this would be to use a "Union" where we would query the parent method table and
# do a most-specific match.
Expand Down Expand Up @@ -490,7 +486,10 @@ CC.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
CC.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing

function CC.add_remark!(interp::GPUInterpreter, sv::CC.InferenceState, msg)
@safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg"
# NOTE: @safe_debug is disabled here because including logging/warning code causes
# CPU runtime functions (ccalls to Julia internals) to leak into the GPU IR,
# breaking AOT compilation. See PR #749 for details.
return nothing
end

CC.may_optimize(interp::GPUInterpreter) = true
Expand Down
1 change: 1 addition & 0 deletions src/rtlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ function emit_function!(mod, config::CompilerConfig, f, method)
new_mod, meta = compile_unhooked(:llvm, CompilerJob(source, config))
ft = function_type(meta.entry)
expected_ft = convert(LLVM.FunctionType, method)

if return_type(ft) != return_type(expected_ft)
error("Invalid return type for runtime function '$(method.name)': expected $(return_type(expected_ft)), got $(return_type(ft))")
end
Expand Down
7 changes: 5 additions & 2 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n
meth = RuntimeMethodInstance(def,
return_type, types, name,
llvm_return_type, llvm_types, llvm_name)

if haskey(methods, name)
error("Runtime function $name has already been registered!")
end
Expand All @@ -82,8 +83,10 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n
# using the new nonrecursive codegen to handle function lookup ourselves?
if def isa Symbol
args = [gensym() for typ in types]
@eval @inline $def($(args...)) =
ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...))
@eval GPUCompiler.@device_function($return_type,
@inline $def($(args...)) =
ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...))
)
end

return
Expand Down
26 changes: 26 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,29 @@ end
return inits
end
end
## method overrides

Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
using ExprTools: splitdef, combinedef
macro device_function(rt, ex)
ex = macroexpand(__module__, ex)
def = splitdef(ex)

# generate a function that warns and returns the expected type
# FIXME: The type may not have a default constructor, what do we do then?
# Currently we are using the constructor with an Int64(1) as an argument.
# NOTE: using Int64(1) is a bit odd. This is because Ptr(Int64(0)) == C_NULL, and julia code lowering
# seems to get rid of this automatically.
def[:body] = quote
$rt(1)
end

esc(quote
$(combinedef(def))

# NOTE: no use of `@consistent_overlay` here because the regular function errors
Base.Experimental.@overlay($(GPUCompiler).GLOBAL_METHOD_TABLE, $ex)
end)
end


37 changes: 37 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,40 @@ end
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
end

@testset "@device_function macro" begin
# Test that @device_function creates both CPU stub and overlay
# The macro should:
# 1. Define a CPU-visible function that returns the expected type
# 2. Register an overlay in GLOBAL_METHOD_TABLE for GPU compilation

# Create a test module to contain the device functions
test_mod = @eval module $(gensym("DeviceFunctionTest"))
using GPUCompiler

# Test with Ptr return type (common for runtime functions)
GPUCompiler.@device_function(Ptr{Nothing},
@inline test_device_ptr() = ccall("extern gpu_test", llvmcall, Ptr{Nothing}, ())
)

# Test with primitive return type
GPUCompiler.@device_function(Nothing,
@inline test_device_nothing() = ccall("extern gpu_test2", llvmcall, Nothing, ())
)
end

# Verify the functions are defined in the test module
@test isdefined(test_mod, :test_device_ptr)
@test isdefined(test_mod, :test_device_nothing)

# Verify the overlay exists in the global method table
mt_view = GPUCompiler.get_method_table_view(Base.get_world_counter(), GPUCompiler.GLOBAL_METHOD_TABLE)
sig_ptr = Tuple{typeof(test_mod.test_device_ptr)}
sig_nothing = Tuple{typeof(test_mod.test_device_nothing)}

# The overlay should be findable in the method table
result_ptr = findsup(sig_ptr, mt_view)
result_nothing = findsup(sig_nothing, mt_view)
@test result_ptr !== nothing
@test result_nothing !== nothing
end
Loading