Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
res/
*.cubin
Manifest.toml
Manifest*.toml
LocalPreferences.toml
CLAUDE.md
AGENTS.md
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Tim Besard <tim.besard@gmail.com>"]
version = "0.1.0"

[deps]
CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd"
CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8"
CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d"
IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
Expand All @@ -12,6 +13,8 @@ IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[sources]
#CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"}
CompilerCaching = {path = "/Users/tim/Julia/pkg/CompilerCaching"}
IRStructurizer = {path = "IRStructurizer"}

[extensions]
Expand All @@ -21,6 +24,3 @@ CUDAExt = "CUDA"
julia = "1.11"
CUDA_Compiler_jll = "0.4"
CUDA_Tile_jll = "13.1"

[workspace]
projects = ["test", "IRStructurizer", "FileCheck"]
118 changes: 52 additions & 66 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,51 @@
module CUDAExt

using cuTile
using cuTile: TileArray, Constant, emit_tileir
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code

using CompilerCaching: CacheView, method_instance, results

using CUDA: CuModule, CuFunction, cudacall, device, capability
using CUDA_Compiler_jll

public launch

# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
"""
emit_executable(cache, mi) -> CuFunction

Executable phase: compile bytecode to CUBIN and load into GPU memory.
"""
function emit_executable(cache::CacheView, mi::Core.MethodInstance)
    # Always run the previous phase first (it handles CI + IR + code); its
    # return value is the bytecode that gets handed to tileiras below.
    tile_bytecode = emit_code(cache, mi)

    # A previously loaded function for this entry short-circuits everything else.
    entry = results(cache, get(cache, mi))
    if entry.executable !== nothing
        return entry.executable
    end

    # The compilation options are the second element of the cache owner.
    cgopts = cache.owner[2]
    entry_name = string(mi.def.name)

    # tileiras works on files, so round-trip the bytecode through temp paths.
    tile_file = tempname() * ".tile"
    cubin_file = tempname() * ".cubin"
    try
        write(tile_file, tile_bytecode)
        run(`$(CUDA_Compiler_jll.tileiras()) $tile_file -o $cubin_file --gpu-name $(cgopts.sm_arch) -O$(cgopts.opt_level)`)

        # Load the CUBIN and look up the kernel by the method's name.
        loaded = CuFunction(CuModule(read(cubin_file)), entry_name)

        # Remember the function on the results entry so later calls skip tileiras.
        entry.executable = loaded
        return loaded
    finally
        # Remove both temp files whether or not compilation succeeded.
        rm(tile_file, force=true)
        rm(cubin_file, force=true)
    end
end

"""
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing)
Expand Down Expand Up @@ -62,23 +98,19 @@ function cuTile.launch(@nospecialize(f), grid, args...;
# Compute argument types from the converted arguments
argtypes = Tuple{map(typeof, tile_args)...}

# Determine kernel name
kernel_name = name !== nothing ? name : string(nameof(f))

# Use method instance in case of a redefinition
method = which(f, argtypes)

# Check compilation cache - returns CuFunction directly
cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy)
cufunc = get(_compilation_cache, cache_key, nothing)
if cufunc === nothing || cuTile.compile_hook[] !== nothing
cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy)
if cufunc === nothing
cumod = CuModule(cubin)
cufunc = CuFunction(cumod, kernel_name)
_compilation_cache[cache_key] = cufunc
end
end
# Get world age and method instance
# Don't pass method_table - kernel functions are in the global table
# The overlay table is only used by the interpreter during inference
world = Base.get_world_counter()
mi = method_instance(f, argtypes; world)
mi === nothing && throw(MethodError(f, argtypes))

# Create cache view with compilation options as sharding keys
opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy)
cache = CacheView{CuTileResults}((:cuTile, opts), world)

# Run cached compilation
cufunc = emit_executable(cache, mi)

# Flatten arguments for cudacall - Constant returns () so ghost types disappear
flat_args = Tuple(Iterators.flatten(map(flatten, tile_args)))
Expand All @@ -104,52 +136,6 @@ function cuTile.launch(@nospecialize(f), grid, args...;
return nothing
end

"""
compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) -> Vector{UInt8}

Compile a Julia kernel function to a CUDA binary.
"""
function compile(@nospecialize(f), @nospecialize(argtypes);
                 name::Union{String, Nothing}=nothing,
                 sm_arch::String=default_sm_arch(),
                 opt_level::Int=3,
                 num_ctas::Union{Int, Nothing}=nothing,
                 occupancy::Union{Int, Nothing}=nothing)
    # Note: opt_level is not forwarded here; it is only passed to tileiras below.
    bytecode = emit_tileir(f, argtypes; name, sm_arch, num_ctas, occupancy)

    # Optional debugging aid: when JULIA_CUTILE_DUMP_BYTECODE names a directory,
    # store a copy of the bytecode there, named after the method's source location.
    dump_root = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing)
    if !isnothing(dump_root)
        mkpath(dump_root)
        meth = first(methods(f, Base.to_tuple_type(argtypes)))
        stem = first(splitext(basename(string(meth.file))))

        # Never overwrite an earlier dump: append .2, .3, ... until the name is free.
        target = joinpath(dump_root, "$(stem).ln$(meth.line).cutile")
        suffix = 1
        while isfile(target)
            suffix += 1
            target = joinpath(dump_root, "$(stem).ln$(meth.line).$(suffix).cutile")
        end
        println(stderr, "Dumping TILEIR bytecode to file: $target")
        write(target, bytecode)
    end

    # tileiras works on files, so round-trip the bytecode through temp paths.
    tile_file = tempname() * ".tile"
    cubin_file = tempname() * ".cubin"

    try
        write(tile_file, bytecode)
        run(`$(CUDA_Compiler_jll.tileiras()) $tile_file -o $cubin_file --gpu-name $sm_arch -O$opt_level`)
        return read(cubin_file)
    finally
        # Remove both temp files whether or not compilation succeeded.
        rm(tile_file, force=true)
        rm(cubin_file, force=true)
    end
end

"""
default_sm_arch() -> String

Expand Down
2 changes: 1 addition & 1 deletion src/compiler/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Codegen: Julia IR -> Tile IR bytecode

include("codegen/utils.jl")
include("codegen/kernel.jl")
include("codegen/control_flow.jl")
include("codegen/statements.jl")
include("codegen/expressions.jl")
include("codegen/values.jl")
include("codegen/utils.jl")
24 changes: 12 additions & 12 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# kernel and argument handling

"""
emit_kernel!(writer, func_buf, target; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing)
emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing)

Compile a TileTarget to Tile IR bytecode.
Compile a StructuredIRCode to Tile IR bytecode.
"""
function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
target::TileTarget;
name::String = string(target.mi.def.name),
sci::StructuredIRCode, rettype::Type;
name::String,
sm_arch::Union{String, Nothing} = nothing,
is_entry::Bool = true,
num_ctas::Union{Int, Nothing} = nothing,
occupancy::Union{Int, Nothing} = nothing)
ctx = CGCtx(writer, target, sm_arch)
ctx = CGCtx(writer, sci, sm_arch)
tt = ctx.tt

# Validate non-ghost argument types are concrete
for (i, argtype) in enumerate(target.sci.argtypes)
for (i, argtype) in enumerate(sci.argtypes)
is_ghost_type(unwrap_type(argtype)) && continue
require_concrete_type(argtype, "kernel argument $i")
end
Expand All @@ -25,7 +25,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
param_types = TypeId[]
param_mapping = Tuple{Int, Union{Nothing, Symbol}}[]

for (i, argtype) in enumerate(target.sci.argtypes)
for (i, argtype) in enumerate(sci.argtypes)
argtype_unwrapped = unwrap_type(argtype)
if is_ghost_type(argtype_unwrapped)
continue
Expand Down Expand Up @@ -57,8 +57,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},

# Return types
result_types = TypeId[]
if target.rettype !== Nothing && target.rettype !== Union{}
push!(result_types, tile_type_for_julia!(ctx, target.rettype))
if rettype !== Nothing && rettype !== Union{}
push!(result_types, tile_type_for_julia!(ctx, rettype))
end

# Create entry hints if provided
Expand Down Expand Up @@ -92,8 +92,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
# Regular argument - create concrete CGVal
@assert length(values) == 1
val = values[1]
type_id = tile_type_for_julia!(ctx, target.sci.argtypes[arg_idx])
tv = CGVal(val, type_id, target.sci.argtypes[arg_idx])
type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx])
tv = CGVal(val, type_id, sci.argtypes[arg_idx])
ctx[SlotNumber(arg_idx)] = tv
ctx[Argument(arg_idx)] = tv
end
Expand All @@ -117,7 +117,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
ctx.token = encode_MakeTokenOp!(cb, token_type)

# Emit the structured IR (uses original Julia SSA indices everywhere)
emit_block!(ctx, ctx.target.sci.entry)
emit_block!(ctx, ctx.sci.entry)

finalize_function!(func_buf, cb, writer.debug_info)
end
Expand Down
Loading
Loading