Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ export MetalBackend

include("deprecated.jl")

include("warmup.jl")

include("precompile.jl")

end # module
19 changes: 15 additions & 4 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
try
dev = device()
return supports_family(dev, MTL.MTLGPUFamilyApple7) &&
supports_family(dev, MTL.MTLGPUFamilyMetal3)
supports_family(dev, MTL.MTLGPUFamilyMetal3)
catch
return false
end
end
else
# Becomes `nothing` once it has been determined that the device is on macOS
const _functional = Ref{Union{Nothing,Bool}}(false)
const _functional = Ref{Union{Nothing, Bool}}(false)

function functional()
if isnothing(_functional[])
Expand All @@ -24,6 +24,10 @@ else
end
end

# Async warmup system to reduce first-kernel JIT compilation latency
# Handle to the background warmup task; stays `nothing` unless `__init__()` starts one.
const _warmup_task = Ref{Union{Nothing, Task}}(nothing)
# Switch read from the "warmup" preference; defaults to `true` when the preference is unset.
const _warmup_enabled = @load_preference("warmup", true)

function __init__()
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
precompiling && return
Expand Down Expand Up @@ -63,7 +67,7 @@ function __init__()
_functional[] = nothing # VERSION <= v"1.12.0-DEV.1421"
end
catch err
@error "Failed to load Metal" exception=(err,catch_backtrace())
@error "Failed to load Metal" exception = (err, catch_backtrace())
return
end

Expand All @@ -72,10 +76,17 @@ function __init__()
if isdefined(Base, :active_repl_backend) && !isnothing(Base.active_repl_backend)
push!(Base.active_repl_backend.ast_transforms, synchronize_metal_tasks)
end

# Start async warmup to reduce first-kernel JIT compilation latency.
# Only run with multiple threads - with a single thread, the async task would
# block the main thread due to Julia's cooperative task runtime.
return if functional() && _warmup_enabled && Threads.nthreads() > 1
_warmup_task[] = errormonitor(Threads.@spawn _warmup_compilation())
end
end

function synchronize_metal_tasks(ex)
quote
return quote
try
$(ex)
finally
Expand Down
76 changes: 76 additions & 0 deletions src/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Async warmup to reduce first-kernel JIT compilation latency
#
# The first GPU kernel in a Metal.jl session takes ~1.75s due to one-time JIT
# compilation of GPUCompiler internals. By starting a minimal kernel compilation
# in the background during __init__(), we can reduce this to 0.035-0.20s for the
# user's first actual kernel—a 9-50x improvement.
#
# NOTE: Warmup only runs when multiple threads are available (Threads.nthreads() > 1).
# With a single thread, async warmup would block the main thread due to Julia's
# cooperative task runtime, potentially hurting perceived latency.

# Smallest useful kernel: compiling it drives the full compilation pipeline.
# Threads whose grid position lies past the end of `a` do nothing; the rest
# store a zero into their element.
function _warmup_kernel!(a)
    idx = thread_position_in_grid().x
    idx > length(a) && return nothing
    a[idx] = 0.0f0
    return nothing
end

# Compile a trivial kernel so that the one-time compiler start-up cost is paid
# here instead of by the user's first kernel. Scheduled from `__init__()` with
# `Threads.@spawn`.
#
# Never throws: warmup is a best-effort optimization, so a failure is reported
# at debug level only and otherwise ignored. Always returns `nothing`.
function _warmup_compilation()
    try
        # A single element is enough - only the compilation matters.
        arr = MtlArray{Float32}(undef, 1)
        try
            # `launch = false` compiles the kernel without executing it.
            @metal launch = false _warmup_kernel!(arr)
        finally
            # Release the buffer even when compilation fails.
            unsafe_free!(arr)
        end
    catch err
        # Non-critical, but keep the failure discoverable for debugging.
        @debug "Metal warmup failed" exception = (err, catch_backtrace())
    end
    return nothing
end

"""
    Metal.warmup(; blocking::Bool=true)

Wait for the background warmup of the GPU compilation pipeline.

The first GPU kernel of a Metal.jl session pays a one-time JIT compilation cost of
roughly 1.7 seconds. When Julia runs with more than one thread (e.g. `julia -t auto`),
Metal.jl starts a warmup task in the background as the package is loaded. This function
lets you wait for that task explicitly.

# Keywords

- `blocking`: when `true` (the default), return only once the warmup task has finished.
  When `false`, return immediately and let the warmup continue in the background.

# Returns

Always `nothing`.

# When to use

Call `Metal.warmup()` ahead of timing-sensitive code to get consistent measurements:

```julia
using Metal
Metal.warmup()            # wait for warmup to complete
@time @metal kernel!(a)   # consistently fast (~0.035s, not ~1.7s)
```

# Notes

- If no background warmup was started (single-threaded session, non-functional GPU, or
  warmup disabled), this call returns immediately without doing anything.
- Background warmup is limited to multi-threaded sessions because, with a single thread,
  the async task would block the main thread under Julia's cooperative task runtime.
- Calling this function is never required for correctness, only for consistent timing.
- Most programs never need it: the background warmup typically finishes during normal
  setup work (loading data, preprocessing, etc.).
"""
function warmup(; blocking::Bool = true)
    pending = _warmup_task[]
    # A missing task means warmup never started, so there is nothing to wait on.
    if !isnothing(pending) && blocking
        wait(pending)
    end
    return nothing
end
68 changes: 68 additions & 0 deletions test/warmup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
@testset "warmup" begin
@testset "warmup API" begin
    # Whatever the thread configuration, every call form must yield `nothing`.
    @test isnothing(Metal.warmup())
    for flag in (false, true)
        @test isnothing(Metal.warmup(blocking = flag))
    end

    # Calling it again, repeatedly, must remain harmless.
    for _ in 1:2
        @test isnothing(Metal.warmup())
    end
end

@testset "kernel compilation after warmup" begin
    Metal.warmup()

    # A kernel defined after warmup must still compile and run correctly.
    # Each in-range thread writes its own 1-based grid position.
    function test_kernel!(a)
        i = thread_position_in_grid().x
        if i <= length(a)
            a[i] = Float32(i)
        end
        return nothing
    end

    n = 256
    a = MtlArray{Float32}(undef, n)
    @metal threads = n test_kernel!(a)
    synchronize()

    # The buffer starts uninitialized, so check every element rather than a
    # few samples: an element the kernel failed to write must not slip through.
    @test Array(a) == Float32.(1:n)
end

@testset "concurrent kernel compilation" begin
    Metal.warmup()

    # Two distinct kernels, compiled back to back after warmup.
    function kernel_add!(a)
        i = thread_position_in_grid().x
        if i <= length(a)
            a[i] += 1.0f0
        end
        return nothing
    end

    function kernel_mul!(a)
        i = thread_position_in_grid().x
        if i <= length(a)
            a[i] *= 2.0f0
        end
        return nothing
    end

    # Start from 3 so the kernels give different results (3 + 1 = 4, 3 * 2 = 6).
    # Starting from 1 would make both yield 2, and running the wrong kernel
    # on either array would go undetected.
    a = MtlArray(fill(3.0f0, 64))
    b = MtlArray(fill(3.0f0, 64))

    # Compile and run both kernels
    @metal threads = 64 kernel_add!(a)
    @metal threads = 64 kernel_mul!(b)
    synchronize()

    # Every element must carry the result of its own kernel.
    @test all(==(4.0f0), Array(a))
    @test all(==(6.0f0), Array(b))
end
end