From eeb38ccc9a05fde7c6510b9e72b92b778a737a30 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 21 Jan 2026 16:12:05 +0100 Subject: [PATCH 1/5] Switch to CompilerCaching.jl --- .gitignore | 2 +- Project.toml | 5 +- ext/CUDAExt.jl | 108 ++++----- src/compiler/codegen.jl | 2 +- src/compiler/codegen/kernel.jl | 24 +- src/compiler/codegen/utils.jl | 363 ++++++++++++++++++++++++++++++ src/compiler/interface.jl | 299 +++++++++++++++++++++++++ src/compiler/interpreter.jl | 112 ---------- src/compiler/reflection.jl | 95 -------- src/compiler/target.jl | 389 --------------------------------- src/cuTile.jl | 4 +- 11 files changed, 721 insertions(+), 682 deletions(-) create mode 100644 src/compiler/interface.jl delete mode 100644 src/compiler/interpreter.jl delete mode 100644 src/compiler/reflection.jl delete mode 100644 src/compiler/target.jl diff --git a/.gitignore b/.gitignore index daf5565..fd9f176 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ res/ *.cubin -Manifest.toml +Manifest*.toml LocalPreferences.toml CLAUDE.md AGENTS.md diff --git a/Project.toml b/Project.toml index 83cf4b8..90f1f81 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Tim Besard "] version = "0.1.0" [deps] +CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd" CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8" CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d" IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" @@ -12,6 +13,7 @@ IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [sources] +CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"} IRStructurizer = {path = "IRStructurizer"} [extensions] @@ -21,6 +23,3 @@ CUDAExt = "CUDA" julia = "1.11" CUDA_Compiler_jll = "0.4" CUDA_Tile_jll = "13.1" - -[workspace] -projects = ["test", "IRStructurizer", "FileCheck"] diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 0012c3d..ea789f8 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,15 +1,41 @@ module CUDAExt using cuTile -using cuTile: TileArray, Constant, emit_tileir +using cuTile: TileArray, Constant, CGOpts, emit_ir, emit_code + +using CompilerCaching: CacheHandle, cached_compilation, method_instance using CUDA: CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll public launch -# Compilation cache - stores CuFunction directly to avoid re-loading CuModule -const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction +""" + emit_executable(cache, mi, world, bytecode) -> CuFunction + +Executable phase: run tileiras on bytecode to produce CUBIN, then load into GPU memory. +This is the only session-dependent phase. +""" +function emit_executable(cache::CacheHandle, mi::Core.MethodInstance, world::UInt, bytecode::Vector{UInt8}) + opts = cache.keys + kernel_name = string(mi.def.name) + + # Run tileiras to produce CUBIN + input_path = tempname() * ".tile" + output_path = tempname() * ".cubin" + try + write(input_path, bytecode) + run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`) + cubin = read(output_path) + + # Load into GPU memory + cumod = CuModule(cubin) + return CuFunction(cumod, kernel_name) + finally + rm(input_path, force=true) + rm(output_path, force=true) + end +end """ launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) @@ -62,23 +88,19 @@ function cuTile.launch(@nospecialize(f), grid, args...; # Compute argument types from the converted arguments argtypes = Tuple{map(typeof, tile_args)...} - # Determine kernel name - kernel_name = name !== nothing ? name : string(nameof(f)) - - # Use method instance in case of a redefinition - method = which(f, argtypes) - - # Check compilation cache - returns CuFunction directly - cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) - cufunc = get(_compilation_cache, cache_key, nothing) - if cufunc === nothing || cuTile.compile_hook[] !== nothing - cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy) - if cufunc === nothing - cumod = CuModule(cubin) - cufunc = CuFunction(cumod, kernel_name) - _compilation_cache[cache_key] = cufunc - end - end + # Get world age and method instance + # Don't pass method_table - kernel functions are in the global table + # The overlay table is only used by the interpreter during inference + world = Base.get_world_counter() + mi = method_instance(f, argtypes; world) + mi === nothing && throw(MethodError(f, argtypes)) + + # Create cache handle with compilation options as sharding keys + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheHandle{CGOpts}(:cuTile, opts) + + # Run cached three-phase compilation + cufunc = cached_compilation(cache, mi, world; emit_ir, emit_code, emit_executable) # Flatten arguments for cudacall - Constant returns () so ghost types disappear flat_args = Tuple(Iterators.flatten(map(flatten, tile_args))) @@ -104,52 +126,6 @@ function cuTile.launch(@nospecialize(f), grid, args...; return nothing end -""" - compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) -> Vector{UInt8} - -Compile a Julia kernel function to a CUDA binary. -""" -function compile(@nospecialize(f), @nospecialize(argtypes); - name::Union{String, Nothing}=nothing, - sm_arch::String=default_sm_arch(), - opt_level::Int=3, - num_ctas::Union{Int, Nothing}=nothing, - occupancy::Union{Int, Nothing}=nothing) - tile_bytecode = emit_tileir(f, argtypes; name, sm_arch, - num_ctas, occupancy) - - # Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set - dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing) - if dump_dir !== nothing - mkpath(dump_dir) - # Get source location from the function's method - mi = first(methods(f, Base.to_tuple_type(argtypes))) - base_filename = basename(string(mi.file)) - base_filename = first(splitext(base_filename)) - # Find unique filename, adding counter if file exists - dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.line).cutile") - counter = 1 - while isfile(dump_path) - counter += 1 - dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.line).$(counter).cutile") - end - println(stderr, "Dumping TILEIR bytecode to file: $dump_path") - write(dump_path, tile_bytecode) - end - - input_path = tempname() * ".tile" - output_path = tempname() * ".cubin" - - try - write(input_path, tile_bytecode) - run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $sm_arch -O$opt_level`) - return read(output_path) - finally - rm(input_path, force=true) - rm(output_path, force=true) - end -end - """ default_sm_arch() -> String diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index 2c775d4..564aa8e 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -1,8 +1,8 @@ # Codegen: Julia IR -> Tile IR bytecode +include("codegen/utils.jl") include("codegen/kernel.jl") include("codegen/control_flow.jl") include("codegen/statements.jl") include("codegen/expressions.jl") include("codegen/values.jl") -include("codegen/utils.jl") diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 81492b1..b50fa6d 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -1,22 +1,22 @@ # kernel and argument handling """ - emit_kernel!(writer, func_buf, target; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing) + emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing) -Compile a TileTarget to Tile IR bytecode. +Compile a StructuredIRCode to Tile IR bytecode. """ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, - target::TileTarget; - name::String = string(target.mi.def.name), + sci::StructuredIRCode, rettype::Type; + name::String, sm_arch::Union{String, Nothing} = nothing, is_entry::Bool = true, num_ctas::Union{Int, Nothing} = nothing, occupancy::Union{Int, Nothing} = nothing) - ctx = CGCtx(writer, target, sm_arch) + ctx = CGCtx(writer, sci, sm_arch) tt = ctx.tt # Validate non-ghost argument types are concrete - for (i, argtype) in enumerate(target.sci.argtypes) + for (i, argtype) in enumerate(sci.argtypes) is_ghost_type(unwrap_type(argtype)) && continue require_concrete_type(argtype, "kernel argument $i") end @@ -25,7 +25,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, param_types = TypeId[] param_mapping = Tuple{Int, Union{Nothing, Symbol}}[] - for (i, argtype) in enumerate(target.sci.argtypes) + for (i, argtype) in enumerate(sci.argtypes) argtype_unwrapped = unwrap_type(argtype) if is_ghost_type(argtype_unwrapped) continue @@ -57,8 +57,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Return types result_types = TypeId[] - if target.rettype !== Nothing && target.rettype !== Union{} - push!(result_types, tile_type_for_julia!(ctx, target.rettype)) + if rettype !== Nothing && rettype !== Union{} + push!(result_types, tile_type_for_julia!(ctx, rettype)) end # Create entry hints if provided @@ -92,8 +92,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Regular argument - create concrete CGVal @assert length(values) == 1 val = values[1] - type_id = tile_type_for_julia!(ctx, target.sci.argtypes[arg_idx]) - tv = CGVal(val, type_id, target.sci.argtypes[arg_idx]) + type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx]) + tv = CGVal(val, type_id, sci.argtypes[arg_idx]) ctx[SlotNumber(arg_idx)] = tv ctx[Argument(arg_idx)] = tv end @@ -117,7 +117,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, ctx.token = encode_MakeTokenOp!(cb, token_type) # Emit the structured IR (uses original Julia SSA indices everywhere) - emit_block!(ctx, ctx.target.sci.entry) + emit_block!(ctx, ctx.sci.entry) finalize_function!(func_buf, cb, writer.debug_info) end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index abc2c97..4dd9f44 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -1,3 +1,366 @@ +# Codegen types and utilities +# +# Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. + +#============================================================================= + CGVal: Unified value representation (analogous to Julia's jl_cgval_t) +=============================================================================# + +""" + CGVal + +Represents a value during Tile IR code generation, bundling the IR value +with its type information and metadata. + +Similar to Julia compiler's `jl_cgval_t`, this provides a unified representation +for all values flowing through codegen. A CGVal can be either: +1. A concrete SSA value (v::Value) +2. A multi-value result from control flow ops (v::Vector{Value}) +3. A lazy argument reference chain (v is nothing, arg_ref tracks the access path) +4. A ghost value (v is nothing, zero-size singleton) +""" +struct CGVal + v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing + type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value) + jltype::Any # Original Julia type + shape::Vector{Int} # Tile shape (empty for scalars) + # Lazy argument reference: (arg_idx, [:field, index, ...]) + # e.g., (1, [:sizes, 2]) means "argument 1, field :sizes, index 2" + arg_ref::Union{Tuple{Int, Vector{Union{Symbol, Int}}}, Nothing} + constant::Union{Some, Nothing} # Nothing = no constant, Some(x) = constant value x + tuple::Union{Vector{Any}, Nothing} # For tuples: component refs (SSAValue, etc.) +end + +# Convenience constructors for concrete values +CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) = + CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing) + +CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) = + CGVal(v, type_id, jltype, shape, nothing, nothing, nothing) + +# Constructor for multi-value results (from loops, ifs) +CGVal(v::Vector{Value}, @nospecialize(jltype)) = + CGVal(v, nothing, jltype, Int[], nothing, nothing, nothing) + +# Constructor for lazy argument references +function arg_ref_value(arg_idx::Int, chain::Vector{Union{Symbol, Int}}, @nospecialize(jltype)) + CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing) +end + +""" + ghost_value(jltype[, constant]) -> CGVal + +Create a ghost value (zero-size singleton with no runtime representation). +Optionally stores a compile-time constant value. +""" +ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing) +ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing) + +""" + tuple_value(jltype, component_refs, component_constants) -> CGVal + +Create a tuple value with tracked component refs. Derives constant if all components have constants. +Used by intrinsics like cat() that need to access individual tuple elements. +""" +function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, component_constants::Vector{Any}) + # If all components have constants, derive the tuple constant + constant = if all(!isnothing, component_constants) + Some(Tuple(component_constants)) + else + nothing + end + CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs) +end + +""" + is_ghost(tv::CGVal) -> Bool + +Check if a CGVal is a ghost (no runtime representation). +""" +is_ghost(tv::CGVal) = tv.v === nothing && tv.arg_ref === nothing + +""" + is_arg_ref(tv::CGVal) -> Bool + +Check if a CGVal is a lazy argument reference. +""" +is_arg_ref(tv::CGVal) = tv.arg_ref !== nothing + +#============================================================================= + CGCtx: Compilation context +=============================================================================# + +""" + CGCtx + +Holds all state during Tile IR code generation for a kernel function. +Maps Julia SSA values to CGVals and manages bytecode emission. +""" +mutable struct CGCtx + # SSA value mapping: original Julia SSA index -> CGVal + # Uses global/original indices everywhere (no local renumbering) + # Loop/if ops store a CGVal with tuple_values field (extracted by getfield statements) + values::Dict{Int, CGVal} + args::Dict{Int, CGVal} # Argument index -> CGVal + slots::Dict{Int, CGVal} # Slot number -> CGVal + block_args::Dict{Int, CGVal} # BlockArg id -> CGVal (for control flow) + + # Destructured argument handling (for TileArray fields) + arg_flat_values::Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}} + arg_types::Dict{Int, Type} + + # Cached TensorViews for TileArray arguments (arg_idx -> (Value, TypeId)) + tensor_views::Dict{Int, Tuple{Value, TypeId}} + + # Bytecode infrastructure + cb::CodeBuilder + tt::TypeTable + sci::StructuredIRCode + + # Memory ordering token + token::Union{Value, Nothing} + token_type::Union{TypeId, Nothing} + + # Type cache: Julia type -> TypeId + type_cache::Dict{Type, TypeId} + + # Target architecture (e.g., :sm_100) + sm_arch::Union{String, Nothing} +end + +function CGCtx(writer::BytecodeWriter, sci::StructuredIRCode, sm_arch::Union{String, Nothing}=nothing) + CGCtx( + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(), + Dict{Int, Type}(), + Dict{Int, Tuple{Value, TypeId}}(), # tensor_views cache + CodeBuilder(writer.string_table, writer.constant_table, writer.type_table), + writer.type_table, + sci, + nothing, + nothing, + Dict{Type, TypeId}(), + sm_arch, + ) +end + +#============================================================================= + Value lookup via indexing syntax +=============================================================================# + +function Base.getindex(ctx::CGCtx, ssa::SSAValue) + # Simple lookup by original Julia SSA index + get(ctx.values, ssa.id, nothing) +end + +function Base.getindex(ctx::CGCtx, arg::Argument) + get(ctx.args, arg.n, nothing) +end + +function Base.getindex(ctx::CGCtx, slot::SlotNumber) + get(ctx.slots, slot.id, nothing) +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, ssa::SSAValue) + ctx.values[ssa.id] = tv +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, arg::Argument) + ctx.args[arg.n] = tv +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, slot::SlotNumber) + ctx.slots[slot.id] = tv +end + +function Base.getindex(ctx::CGCtx, block_arg::BlockArg) + get(ctx.block_args, block_arg.id, nothing) +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, block_arg::BlockArg) + ctx.block_args[block_arg.id] = tv +end + +#============================================================================= + Destructured argument helpers +=============================================================================# + +""" + get_arg_flat_values(ctx, arg_idx, field=nothing) -> Union{Vector{Value}, Nothing} + +Get the flat Tile IR values for an argument or its field. +""" +function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Union{Nothing, Symbol}=nothing) + get(ctx.arg_flat_values, (arg_idx, field), nothing) +end + +""" + is_destructured_arg(ctx, arg_idx) -> Bool + +Check if an argument was destructured into multiple flat parameters. +""" +is_destructured_arg(ctx::CGCtx, arg_idx::Int) = haskey(ctx.arg_types, arg_idx) + +""" + get_arg_type(ctx, arg_idx) -> Union{Type, Nothing} + +Get the original Julia type for a destructured argument. +""" +get_arg_type(ctx::CGCtx, arg_idx::Int) = get(ctx.arg_types, arg_idx, nothing) + +#============================================================================= + Type conversion utilities +=============================================================================# + +""" + unwrap_type(T) -> Type + +Unwrap type wrappers like Core.Const to get the actual type. +""" +function unwrap_type(@nospecialize(T)) + if T isa Core.Const + return typeof(T.val) + elseif T isa Core.PartialStruct + return T.typ + elseif T isa Type + return T + else + return T + end +end + +""" + require_concrete_type(T, context::String) + +Ensure a type is fully concrete (not a UnionAll). +""" +function require_concrete_type(@nospecialize(T), context::String) + T_unwrapped = unwrap_type(T) + if T_unwrapped isa UnionAll + error("Type must be fully concrete in $context, got partial type: $T") + end + return T_unwrapped +end + +""" + tile_type_for_julia!(ctx, T) -> TypeId + +Get or create a Tile IR type for a Julia type. +""" +function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T)) + actual_type = unwrap_type(T) + get!(ctx.type_cache, actual_type) do + _tile_type_for_julia!(ctx.tt, actual_type) + end +end + +function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) + # Scalar types -> 0-D tile + if T === Bool + return tile_type!(tt, I1(tt), Int[]) + elseif T === Int32 || T === UInt32 + return tile_type!(tt, I32(tt), Int[]) + elseif T === Int64 || T === UInt64 + return tile_type!(tt, I64(tt), Int[]) + elseif T === Float16 + return tile_type!(tt, F16(tt), Int[]) + elseif T === Float32 + return tile_type!(tt, F32(tt), Int[]) + elseif T === Float64 + return tile_type!(tt, F64(tt), Int[]) + elseif T === Nothing + return Token(tt) + end + + # Pointers -> 0-D tile of pointer type + if T <: Ptr + elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) + ptr_type = pointer_type!(tt, elem_dtype) + return tile_type!(tt, ptr_type, Int[]) + end + + # Tile{T, Shape} -> tile type with shape + if T <: Tile + if T isa UnionAll || !isa(T, DataType) || length(T.parameters) < 2 + error("Tile type must be fully specified with element type and shape, got: $T. " * + "This indicates type instability in the kernel - ensure all tile operations have inferrable shapes.") + end + elem_type = T.parameters[1] + shape_param = T.parameters[2] + if !(shape_param isa Tuple) + error("Tile shape must be a tuple, got: $shape_param") + end + elem_dtype = julia_to_tile_dtype!(tt, elem_type) + shape = collect(Int, shape_param) + return tile_type!(tt, elem_dtype, shape) + end + + error("Unsupported Julia type for Tile IR: $T") +end + +""" + tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int}) + +Get the Tile IR type and shape for a Julia type. +""" +function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) + actual_type = unwrap_type(T) + type_id = tile_type_for_julia!(ctx, actual_type) + + # Extract shape from Tile types + shape = Int[] + if actual_type <: Tile && length(actual_type.parameters) >= 2 + shape_param = actual_type.parameters[2] + if shape_param isa Tuple + shape = collect(Int, shape_param) + end + end + + return (type_id, shape) +end + +#============================================================================= + Struct destructuring helpers +=============================================================================# + +""" + is_ghost_type(T) -> Bool + +Check if a type is a ghost type (zero-size singleton). +""" +function is_ghost_type(@nospecialize(T)) + try + isbitstype(T) && sizeof(T) == 0 + catch + false + end +end + +""" + should_destructure(T) -> Bool + +Check if a type should be destructured into flat parameters. +""" +function should_destructure(@nospecialize(T)) + T = unwrap_type(T) + isstructtype(T) || return false + is_ghost_type(T) && return false + isprimitivetype(T) && return false + T <: TileArray && return true + return false +end + +""" + flat_field_count(T) -> Int + +Count flat parameters a type expands to. +""" +flat_field_count(::Type{<:NTuple{N, T}}) where {N, T} = N +flat_field_count(::Type) = 1 + #----------------------------------------------------------------------------- # Argument helpers #----------------------------------------------------------------------------- diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl new file mode 100644 index 0000000..5ed3c24 --- /dev/null +++ b/src/compiler/interface.jl @@ -0,0 +1,299 @@ +# Compilation interface for cuTile +# +# This file provides the public compilation API: +# - cuTileInterpreter: custom interpreter with overlay method table +# - emit_ir/emit_code: three-phase compilation callbacks (emit_executable in CUDAExt) +# - code_tiled/@code_tiled: reflection utilities + +export code_tiled, @code_tiled + +using CompilerCaching: CacheHandle, @setup_caching, get_ir, get_code, compile_hook, + compile_hook!, method_instance, populate! + +#============================================================================= + Interpreter +=============================================================================# + +Base.Experimental.@MethodTable cuTileMethodTable + +function get_method_table_view(world::UInt) + CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) +end + +""" +Custom interpreter that supports overlay method tables for cuTile compilation. +This is necessary because NativeInterpreter has a fixed method_table type parameter. +""" +struct cuTileInterpreter <: CC.AbstractInterpreter + cache::CacheHandle + world::UInt + method_table::CC.CachedMethodTable{CC.OverlayMethodTable} + inf_cache::Vector{CC.InferenceResult} + inf_params::CC.InferenceParams + opt_params::CC.OptimizationParams +end + +function cuTileInterpreter(cache::CacheHandle, world::UInt=Base.get_world_counter(); + always_inline::Bool=true) + method_table = get_method_table_view(world) + inf_cache = Vector{CC.InferenceResult}() + inf_params = CC.InferenceParams() + opt_params = if always_inline + CC.OptimizationParams(; inline_cost_threshold=typemax(Int)) + else + CC.OptimizationParams() + end + return cuTileInterpreter(cache, world, method_table, inf_cache, inf_params, opt_params) +end + +# Required AbstractInterpreter interface methods +CC.InferenceParams(interp::cuTileInterpreter) = interp.inf_params +CC.OptimizationParams(interp::cuTileInterpreter) = interp.opt_params +CC.get_inference_cache(interp::cuTileInterpreter) = interp.inf_cache + +# World age +@static if isdefined(CC, :get_inference_world) + CC.get_inference_world(interp::cuTileInterpreter) = interp.world +else + CC.get_world_counter(interp::cuTileInterpreter) = interp.world +end + +# Method table - this enables the overlays +CC.method_table(interp::cuTileInterpreter) = interp.method_table + +# Locking - not needed for non-cached compilation +CC.lock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing +CC.unlock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing + +# Setup caching - generates cache_owner and ipo_dataflow_analysis! methods +@setup_caching cuTileInterpreter.cache + +# Optimization flags +CC.may_optimize(::cuTileInterpreter) = true +CC.may_compress(::cuTileInterpreter) = true +CC.may_discard_trees(::cuTileInterpreter) = true + +# Disable semi-concrete interpretation (broken with overlays per JuliaLang/julia#47349) +function CC.concrete_eval_eligible(interp::cuTileInterpreter, + @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) + ret = @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, + f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) + if ret === :semi_concrete_eval + return :none + end + return ret +end + +""" + code_ircode(mi::MethodInstance; world, always_inline=true) -> (IRCode, rettype) + +Get optimized IRCode for a MethodInstance using cuTile's overlay method table. +If always_inline=true (default), forces all functions to be inlined. +""" +function code_ircode(mi::MethodInstance; world::UInt=Base.get_world_counter(), + always_inline::Bool=true) + cache = CacheHandle(:cuTile) + interp = cuTileInterpreter(cache, world; always_inline) + result = CC.typeinf_ircode(interp, mi, nothing) + + if result === nothing + error("Type inference failed for $mi") + end + + ir, rettype = result + return ir, rettype +end + +#============================================================================= + Compilation phases +=============================================================================# + +# Compilation options for cache sharding +const CGOpts = @NamedTuple{ + sm_arch::Union{String, Nothing}, + opt_level::Int, + num_ctas::Union{Int, Nothing}, + occupancy::Union{Int, Nothing} +} + +""" + emit_ir(cache, mi, world) -> (StructuredIRCode, rettype) + +IR phase: populate code cache with dependencies and return structured IR. +This phase uses cuTile's overlay method table for intrinsic substitution. +""" +function emit_ir(cache::CacheHandle, mi::Core.MethodInstance, world::UInt) + interp = cuTileInterpreter(cache, world) + populate!(cache, interp, mi) + + # Return StructuredIRCode for emit_code phase + ir, rettype = code_ircode(mi; world) + sci = StructuredIRCode(ir) + return (sci, rettype) +end + +""" + emit_code(cache, mi, world, ir_result) -> Vector{UInt8} + +Code phase: generate Tile IR bytecode from StructuredIRCode. +This phase is deterministic and does not require CUDA. + +Returns bytecode that can be compiled to CUBIN by tileiras in the emit_executable phase. +""" +function emit_code(cache::CacheHandle, mi::Core.MethodInstance, world::UInt, ir_result) + sci, rettype = ir_result + opts = cache.keys + + # Generate Tile IR bytecode + bytecode = write_bytecode!(1) do writer, func_buf + emit_kernel!(writer, func_buf, sci, rettype; + name = string(mi.def.name), + sm_arch = opts.sm_arch, + num_ctas = opts.num_ctas, + occupancy = opts.occupancy + ) + end + + # Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set + dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing) + if dump_dir !== nothing + mkpath(dump_dir) + base_filename = basename(string(mi.def.file)) + base_filename = first(splitext(base_filename)) + dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.def.line).cutile") + counter = 1 + while isfile(dump_path) + counter += 1 + dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.def.line).$(counter).cutile") + end + println(stderr, "Dumping TILEIR bytecode to file: $dump_path") + write(dump_path, bytecode) + end + + return bytecode +end + +#============================================================================= + Reflection utilities +=============================================================================# + +function disassemble_tileir(bytecode::Vector{UInt8})::String + mktempdir() do dir + input_path = joinpath(dir, "kernel.tile") + output_path = joinpath(dir, "kernel.disasm") + write(input_path, bytecode) + read(`$(cuda_tile_translate()) --cudatilebc-to-mlir $input_path`, String) + end +end + +""" + code_typed(f, argtypes; world, kwargs...) -> Vector{Any} + +Return typed code for a cuTile function.. Analogous to `Base.code_typed`. +""" +function code_typed(@nospecialize(f), @nospecialize(argtypes); + world::UInt=Base.get_world_counter(), kwargs...) + cache = CacheHandle(:cuTile) + interp = cuTileInterpreter(cache, world) + Base.code_typed(f, argtypes; world, interp, kwargs...) +end + +""" + code_structured(f, argtypes; kwargs...) -> StructuredIRCode + +Return the structured IR for a cuTile function. +""" +function code_structured(@nospecialize(f), @nospecialize(argtypes); + sm_arch::Union{String, Nothing}=nothing, + opt_level::Int=3, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing) + world = Base.get_world_counter() + mi = @something(method_instance(f, argtypes; world, method_table=cuTileMethodTable), + method_instance(f, argtypes; world), + throw(MethodError(f, argtypes))) + + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheHandle{CGOpts}(:cuTile, opts) + + _, ir_result = get_ir(cache, mi, world; emit_ir) + sci, rettype = ir_result + return sci +end + +""" + code_tiled(f, argtypes; sm_arch, opt_level, num_ctas, occupancy) -> String + +Return the CUDA Tile IR for a Julia function as a textual MLIR representation. +Analogous to `code_typed` or `code_structured`. + +Uses the same caching infrastructure as `launch`, benefiting from cached IR +and code results. +""" +function code_tiled(@nospecialize(f), @nospecialize(argtypes); + sm_arch::Union{String, Nothing}=nothing, + opt_level::Int=3, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing) + world = Base.get_world_counter() + mi = @something(method_instance(f, argtypes; world, method_table=cuTileMethodTable), + method_instance(f, argtypes; world), + throw(MethodError(f, argtypes))) + + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheHandle{CGOpts}(:cuTile, opts) + + _, bytecode = get_code(cache, mi, world; emit_ir, emit_code) + disassemble_tileir(bytecode) +end + +# compilation hooking: uses CompilerCaching's global hook +function emit_hooked_compilation(inner_hook, ex...) + user_code = ex[end] + user_kwargs = ex[1:end-1] + quote + # we only want to invoke the hook once for every compilation + seen = Set() + function outer_hook(cache, mi, world) + key = (cache, mi) + if !in(key, seen) + # the user hook might invoke the compiler again, so disable the hook + old_hook = $compile_hook() + try + $compile_hook!(nothing) + opts = cache.keys + $inner_hook(cache, mi, world; $(map(esc, user_kwargs)...)) + finally + $compile_hook!(old_hook) + end + push!(seen, key) + end + end + + # now invoke the user code with this hook in place + try + $compile_hook!(outer_hook) + $(esc(user_code)) + finally + $compile_hook!(nothing) + end + + if isempty(seen) + error("no kernels executed while evaluating the given expression") + end + + nothing + end +end + +macro code_tiled(ex...) + function hook(cache, mi, world; io::IO=stdout) + # The hook fires during cached_compilation, so bytecode is being cached + # at this moment - retrieve it via get_code + _, bytecode = get_code(cache, mi, world; emit_ir, emit_code) + println(io, "// $(mi.def.name)($(join(map(string, mi.specTypes.parameters[2:end]), ", ")))") + println(io) + println(io, disassemble_tileir(bytecode)) + end + emit_hooked_compilation(hook, ex...) +end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl deleted file mode 100644 index 1547375..0000000 --- a/src/compiler/interpreter.jl +++ /dev/null @@ -1,112 +0,0 @@ -using Base.Experimental: @MethodTable - -# Create the cuTile method table -@MethodTable cuTileMethodTable - -# Create cached method table type based on version compatibility -const cuTileMethodTableView = CC.CachedMethodTable{CC.OverlayMethodTable} - -function get_method_table_view(world::UInt) - CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) -end - -""" -Custom interpreter that supports overlay method tables for cuTile compilation. -This is necessary because NativeInterpreter has a fixed method_table type parameter. -""" -struct cuTileInterpreter <: CC.AbstractInterpreter - world::UInt - method_table::cuTileMethodTableView - inf_cache::Vector{CC.InferenceResult} - inf_params::CC.InferenceParams - opt_params::CC.OptimizationParams -end - -function cuTileInterpreter(world::UInt=Base.get_world_counter(); - always_inline::Bool=true) - method_table = get_method_table_view(world) - inf_cache = Vector{CC.InferenceResult}() - inf_params = CC.InferenceParams() - opt_params = if always_inline - CC.OptimizationParams(; inline_cost_threshold=typemax(Int)) - else - CC.OptimizationParams() - end - return cuTileInterpreter(world, method_table, inf_cache, inf_params, opt_params) -end - -# Required AbstractInterpreter interface methods -CC.InferenceParams(interp::cuTileInterpreter) = interp.inf_params -CC.OptimizationParams(interp::cuTileInterpreter) = interp.opt_params -CC.get_inference_cache(interp::cuTileInterpreter) = interp.inf_cache - -# World age -@static if isdefined(CC, :get_inference_world) - CC.get_inference_world(interp::cuTileInterpreter) = interp.world -else - CC.get_world_counter(interp::cuTileInterpreter) = interp.world -end - -# Method table - this enables the overlays -CC.method_table(interp::cuTileInterpreter) = interp.method_table - -# Locking - not needed for non-cached compilation -CC.lock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing -CC.unlock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing - -# Cache owner - use a unique identifier to cache inference results separately -CC.cache_owner(::cuTileInterpreter) = :cutile - -# Optimization flags -CC.may_optimize(::cuTileInterpreter) = true -CC.may_compress(::cuTileInterpreter) = true -CC.may_discard_trees(::cuTileInterpreter) = true - -# Disable semi-concrete interpretation (broken with overlays per JuliaLang/julia#47349) -function CC.concrete_eval_eligible(interp::cuTileInterpreter, - @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) - ret = @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, - f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) - if ret === :semi_concrete_eval - return :none - end - return ret -end - - -#============================================================================= - Public API -=============================================================================# - -""" - get_ircode(f, argtypes; world=Base.get_world_counter(), always_inline=true) -> (IRCode, return_type) - -Get the optimized IRCode for a function with the given argument types. -Uses cuTile's method overlay table to redirect Base operations to cuTile intrinsics. -If always_inline=true (default), forces all functions to be inlined. -""" -function get_ircode(@nospecialize(f), @nospecialize(argtypes); - world::UInt = Base.get_world_counter(), - always_inline::Bool = true) - interp = cuTileInterpreter(world; always_inline) - mi = get_method_instance(f, argtypes; world) - result = CC.typeinf_ircode(interp, mi, nothing) - - if result === nothing - error("Type inference failed for $f with argument types $argtypes") - end - - ir, rettype = result - return ir, rettype -end - -""" - get_method_instance(f, argtypes; world=Base.get_world_counter()) -> MethodInstance - -Get the MethodInstance for a function call. -""" -function get_method_instance(@nospecialize(f), @nospecialize(argtypes); world::UInt = Base.get_world_counter()) - tt = Base.signature_type(f, argtypes) - match = Base._which(tt; world) - return CC.specialize_method(match) -end diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl deleted file mode 100644 index 12c4cb3..0000000 --- a/src/compiler/reflection.jl +++ /dev/null @@ -1,95 +0,0 @@ -export code_tiled, @code_tiled - -""" - emit_tileir(f, argtypes; name, sm_arch, num_ctas, occupancy) -> Vector{UInt8} - -Compile a Julia function to Tile IR bytecode. -""" -function emit_tileir(@nospecialize(f), @nospecialize(argtypes); - name::Union{String, Nothing} = nothing, - sm_arch::Union{String, Nothing} = nothing, - num_ctas::Union{Int, Nothing} = nothing, - occupancy::Union{Int, Nothing} = nothing) - target = TileTarget(f, argtypes) - kernel_name = name === nothing ? string(target.mi.def.name) : name - - if compile_hook[] !== nothing - compile_hook[](f, argtypes; name=name) - end - - buf = write_bytecode!(1) do writer, func_buf - emit_kernel!(writer, func_buf, target; name=kernel_name, sm_arch, - num_ctas, occupancy) - end - - return buf -end - -function disassemble_tileir(bytecode::Vector{UInt8})::String - mktempdir() do dir - input_path = joinpath(dir, "kernel.tile") - output_path = joinpath(dir, "kernel.disasm") - write(input_path, bytecode) - read(`$(cuda_tile_translate()) --cudatilebc-to-mlir $input_path`, String) - end -end - -""" - code_tiled(f, argtypes; name, sm_arch, num_ctas, occupancy) -> String - -Return the CUDA Tile IR for a Julia function as a textual MLIR representation. -Analogous to `code_typed` or `code_structured`. -""" -function code_tiled(@nospecialize(f), @nospecialize(argtypes); - kwargs...) - bytecode = emit_tileir(f, argtypes; kwargs...) - disassemble_tileir(bytecode) -end - -# compilation hooking: taken from GPUCompiler.jl -const compile_hook = Ref{Union{Nothing,Function}}(nothing) -function emit_hooked_compilation(inner_hook, ex...) - user_code = ex[end] - user_kwargs = ex[1:end-1] - quote - # we only want to invoke the hook once for every compilation - seen = Set() - function outer_hook(f, tt; kwargs...) - key = (f, tt) - if !in(key, seen) - # the user hook might invoke the compiler again, so disable the hook - old_hook = $compile_hook[] - try - $compile_hook[] = nothing - $inner_hook(f, tt; kwargs..., $(map(esc, user_kwargs)...)) - finally - $compile_hook[] = old_hook - end - push!(seen, key) - end - end - - # now invoke the user code with this hook in place - try - $compile_hook[] = outer_hook - $(esc(user_code)) - finally - $compile_hook[] = nothing - end - - if isempty(seen) - error("no kernels executed while evaluating the given expression") - end - - nothing - end -end - -macro code_tiled(ex...) - function hook(f, tt; io::IO=stdout, kwargs...) - println(io, "// $f($(join(map(string, tt.parameters), ", ")))") - println(io) - println(io, code_tiled(f, tt; kwargs...)) - end - emit_hooked_compilation(hook, ex...) -end diff --git a/src/compiler/target.jl b/src/compiler/target.jl deleted file mode 100644 index af26ce7..0000000 --- a/src/compiler/target.jl +++ /dev/null @@ -1,389 +0,0 @@ -# TileTarget and CGCtx for cuTile compilation -# -# Holds compilation target and state for a kernel. - -#============================================================================= - TileTarget: Compilation target -=============================================================================# - -""" - TileTarget - -Holds everything about a function being compiled to Tile IR: -the structured IR and return type. -""" -struct TileTarget - mi::MethodInstance - sci::StructuredIRCode - rettype::Type -end - -function TileTarget(@nospecialize(f), @nospecialize(argtypes::Type{<:Tuple})) - ir, rettype = get_ircode(f, argtypes) - mi = get_method_instance(f, argtypes) - sci = StructuredIRCode(ir) - TileTarget(mi, sci, rettype) -end - -# Accessors -# Count non-ghost arguments (excludes function type for regular functions) -nargs(target::TileTarget) = count(!is_ghost_type ∘ unwrap_type, target.sci.argtypes) - -#============================================================================= - CGVal: Unified value representation (analogous to Julia's jl_cgval_t) -=============================================================================# - -""" - CGVal - -Represents a value during Tile IR code generation, bundling the IR value -with its type information and metadata. - -Similar to Julia compiler's `jl_cgval_t`, this provides a unified representation -for all values flowing through codegen. A CGVal can be either: -1. A concrete SSA value (v::Value) -2. A multi-value result from control flow ops (v::Vector{Value}) -3. A lazy argument reference chain (v is nothing, arg_ref tracks the access path) -4. A ghost value (v is nothing, zero-size singleton) -""" -struct CGVal - v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing - type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value) - jltype::Any # Original Julia type - shape::Vector{Int} # Tile shape (empty for scalars) - # Lazy argument reference: (arg_idx, [:field, index, ...]) - # e.g., (1, [:sizes, 2]) means "argument 1, field :sizes, index 2" - arg_ref::Union{Tuple{Int, Vector{Union{Symbol, Int}}}, Nothing} - constant::Union{Some, Nothing} # Nothing = no constant, Some(x) = constant value x - tuple::Union{Vector{Any}, Nothing} # For tuples: component refs (SSAValue, etc.) -end - -# Convenience constructors for concrete values -CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) = - CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing) - -CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) = - CGVal(v, type_id, jltype, shape, nothing, nothing, nothing) - -# Constructor for multi-value results (from loops, ifs) -CGVal(v::Vector{Value}, @nospecialize(jltype)) = - CGVal(v, nothing, jltype, Int[], nothing, nothing, nothing) - -# Constructor for lazy argument references -function arg_ref_value(arg_idx::Int, chain::Vector{Union{Symbol, Int}}, @nospecialize(jltype)) - CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing) -end - -""" - ghost_value(jltype[, constant]) -> CGVal - -Create a ghost value (zero-size singleton with no runtime representation). -Optionally stores a compile-time constant value. -""" -ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing) -ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing) - -""" - tuple_value(jltype, component_refs, component_constants) -> CGVal - -Create a tuple value with tracked component refs. Derives constant if all components have constants. -Used by intrinsics like cat() that need to access individual tuple elements. -""" -function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, component_constants::Vector{Any}) - # If all components have constants, derive the tuple constant - constant = if all(!isnothing, component_constants) - Some(Tuple(component_constants)) - else - nothing - end - CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs) -end - -""" - is_ghost(tv::CGVal) -> Bool - -Check if a CGVal is a ghost (no runtime representation). -""" -is_ghost(tv::CGVal) = tv.v === nothing && tv.arg_ref === nothing - -""" - is_arg_ref(tv::CGVal) -> Bool - -Check if a CGVal is a lazy argument reference. -""" -is_arg_ref(tv::CGVal) = tv.arg_ref !== nothing - -#============================================================================= - CGCtx: Compilation context -=============================================================================# - -""" - CGCtx - -Holds all state during Tile IR code generation for a kernel function. -Maps Julia SSA values to CGVals and manages bytecode emission. -""" -mutable struct CGCtx - # SSA value mapping: original Julia SSA index -> CGVal - # Uses global/original indices everywhere (no local renumbering) - # Loop/if ops store a CGVal with tuple_values field (extracted by getfield statements) - values::Dict{Int, CGVal} - args::Dict{Int, CGVal} # Argument index -> CGVal - slots::Dict{Int, CGVal} # Slot number -> CGVal - block_args::Dict{Int, CGVal} # BlockArg id -> CGVal (for control flow) - - # Destructured argument handling (for TileArray fields) - arg_flat_values::Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}} - arg_types::Dict{Int, Type} - - # Cached TensorViews for TileArray arguments (arg_idx -> (Value, TypeId)) - tensor_views::Dict{Int, Tuple{Value, TypeId}} - - # Bytecode infrastructure - cb::CodeBuilder - tt::TypeTable - target::TileTarget - - # Memory ordering token - token::Union{Value, Nothing} - token_type::Union{TypeId, Nothing} - - # Type cache: Julia type -> TypeId - type_cache::Dict{Type, TypeId} - - # Target architecture (e.g., :sm_100) - sm_arch::Union{String, Nothing} -end - -function CGCtx(writer::BytecodeWriter, target::TileTarget, sm_arch::Union{String, Nothing}=nothing) - CGCtx( - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(), - Dict{Int, Type}(), - Dict{Int, Tuple{Value, TypeId}}(), # tensor_views cache - CodeBuilder(writer.string_table, writer.constant_table, writer.type_table), - writer.type_table, - target, - nothing, - nothing, - Dict{Type, TypeId}(), - sm_arch, - ) -end - -#============================================================================= - Value lookup via indexing syntax -=============================================================================# - -function Base.getindex(ctx::CGCtx, ssa::SSAValue) - # Simple lookup by original Julia SSA index - get(ctx.values, ssa.id, nothing) -end - -function Base.getindex(ctx::CGCtx, arg::Argument) - get(ctx.args, arg.n, nothing) -end - -function Base.getindex(ctx::CGCtx, slot::SlotNumber) - get(ctx.slots, slot.id, nothing) -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, ssa::SSAValue) - ctx.values[ssa.id] = tv -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, arg::Argument) - ctx.args[arg.n] = tv -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, slot::SlotNumber) - ctx.slots[slot.id] = tv -end - -function Base.getindex(ctx::CGCtx, block_arg::BlockArg) - get(ctx.block_args, block_arg.id, nothing) -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, block_arg::BlockArg) - ctx.block_args[block_arg.id] = tv -end - -#============================================================================= - Destructured argument helpers -=============================================================================# - -""" - get_arg_flat_values(ctx, arg_idx, field=nothing) -> Union{Vector{Value}, Nothing} - -Get the flat Tile IR values for an argument or its field. -""" -function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Union{Nothing, Symbol}=nothing) - get(ctx.arg_flat_values, (arg_idx, field), nothing) -end - -""" - is_destructured_arg(ctx, arg_idx) -> Bool - -Check if an argument was destructured into multiple flat parameters. -""" -is_destructured_arg(ctx::CGCtx, arg_idx::Int) = haskey(ctx.arg_types, arg_idx) - -""" - get_arg_type(ctx, arg_idx) -> Union{Type, Nothing} - -Get the original Julia type for a destructured argument. -""" -get_arg_type(ctx::CGCtx, arg_idx::Int) = get(ctx.arg_types, arg_idx, nothing) - -#============================================================================= - Type conversion utilities -=============================================================================# - -""" - unwrap_type(T) -> Type - -Unwrap type wrappers like Core.Const to get the actual type. -""" -function unwrap_type(@nospecialize(T)) - if T isa Core.Const - return typeof(T.val) - elseif T isa Core.PartialStruct - return T.typ - elseif T isa Type - return T - else - return T - end -end - -""" - require_concrete_type(T, context::String) - -Ensure a type is fully concrete (not a UnionAll). -""" -function require_concrete_type(@nospecialize(T), context::String) - T_unwrapped = unwrap_type(T) - if T_unwrapped isa UnionAll - error("Type must be fully concrete in $context, got partial type: $T") - end - return T_unwrapped -end - -""" - tile_type_for_julia!(ctx, T) -> TypeId - -Get or create a Tile IR type for a Julia type. -""" -function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T)) - actual_type = unwrap_type(T) - get!(ctx.type_cache, actual_type) do - _tile_type_for_julia!(ctx.tt, actual_type) - end -end - -function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) - # Scalar types -> 0-D tile - if T === Bool - return tile_type!(tt, I1(tt), Int[]) - elseif T === Int32 || T === UInt32 - return tile_type!(tt, I32(tt), Int[]) - elseif T === Int64 || T === UInt64 - return tile_type!(tt, I64(tt), Int[]) - elseif T === Float16 - return tile_type!(tt, F16(tt), Int[]) - elseif T === Float32 - return tile_type!(tt, F32(tt), Int[]) - elseif T === Float64 - return tile_type!(tt, F64(tt), Int[]) - elseif T === Nothing - return Token(tt) - end - - # Pointers -> 0-D tile of pointer type - if T <: Ptr - elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) - ptr_type = pointer_type!(tt, elem_dtype) - return tile_type!(tt, ptr_type, Int[]) - end - - # Tile{T, Shape} -> tile type with shape - if T <: Tile - if T isa UnionAll || !isa(T, DataType) || length(T.parameters) < 2 - error("Tile type must be fully specified with element type and shape, got: $T. " * - "This indicates type instability in the kernel - ensure all tile operations have inferrable shapes.") - end - elem_type = T.parameters[1] - shape_param = T.parameters[2] - if !(shape_param isa Tuple) - error("Tile shape must be a tuple, got: $shape_param") - end - elem_dtype = julia_to_tile_dtype!(tt, elem_type) - shape = collect(Int, shape_param) - return tile_type!(tt, elem_dtype, shape) - end - - error("Unsupported Julia type for Tile IR: $T") -end - -""" - tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int}) - -Get the Tile IR type and shape for a Julia type. -""" -function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) - actual_type = unwrap_type(T) - type_id = tile_type_for_julia!(ctx, actual_type) - - # Extract shape from Tile types - shape = Int[] - if actual_type <: Tile && length(actual_type.parameters) >= 2 - shape_param = actual_type.parameters[2] - if shape_param isa Tuple - shape = collect(Int, shape_param) - end - end - - return (type_id, shape) -end - -#============================================================================= - Struct destructuring helpers -=============================================================================# - -""" - is_ghost_type(T) -> Bool - -Check if a type is a ghost type (zero-size singleton). -""" -function is_ghost_type(@nospecialize(T)) - try - isbitstype(T) && sizeof(T) == 0 - catch - false - end -end - -""" - should_destructure(T) -> Bool - -Check if a type should be destructured into flat parameters. -""" -function should_destructure(@nospecialize(T)) - T = unwrap_type(T) - isstructtype(T) || return false - is_ghost_type(T) && return false - isprimitivetype(T) && return false - T <: TileArray && return true - return false -end - -""" - flat_field_count(T) -> Int - -Count flat parameters a type expands to. -""" -flat_field_count(::Type{<:NTuple{N, T}}) where {N, T} = N -flat_field_count(::Type) = 1 diff --git a/src/cuTile.jl b/src/cuTile.jl index 375aaa2..6cd5545 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -23,11 +23,9 @@ include("bytecode/encodings.jl") include("language/types.jl") # Compiler implementation -include("compiler/interpreter.jl") -include("compiler/target.jl") +include("compiler/interface.jl") include("compiler/codegen.jl") include("compiler/intrinsics.jl") -include("compiler/reflection.jl") # Language implementation include("language/broadcast.jl") From e3c7cf13e411067d2de67e8090cdbb31242917a4 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 22 Jan 2026 11:29:10 +0100 Subject: [PATCH 2/5] Update to new CompilerCaching.jl API. --- ext/CUDAExt.jl | 12 ++++----- src/compiler/interface.jl | 52 +++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index ea789f8..75ac89e 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -3,7 +3,7 @@ module CUDAExt using cuTile using cuTile: TileArray, Constant, CGOpts, emit_ir, emit_code -using CompilerCaching: CacheHandle, cached_compilation, method_instance +using CompilerCaching: CacheView, cached_compilation, method_instance using CUDA: CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll @@ -11,12 +11,12 @@ using CUDA_Compiler_jll public launch """ - emit_executable(cache, mi, world, bytecode) -> CuFunction + emit_executable(cache, mi, bytecode) -> CuFunction Executable phase: run tileiras on bytecode to produce CUBIN, then load into GPU memory. This is the only session-dependent phase. """ -function emit_executable(cache::CacheHandle, mi::Core.MethodInstance, world::UInt, bytecode::Vector{UInt8}) +function emit_executable(cache::CacheView, mi::Core.MethodInstance, bytecode::Vector{UInt8}) opts = cache.keys kernel_name = string(mi.def.name) @@ -95,12 +95,12 @@ function cuTile.launch(@nospecialize(f), grid, args...; mi = method_instance(f, argtypes; world) mi === nothing && throw(MethodError(f, argtypes)) - # Create cache handle with compilation options as sharding keys + # Create cache view with compilation options as sharding keys opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheHandle{CGOpts}(:cuTile, opts) + cache = CacheView{CGOpts}(:cuTile, world, opts) # Run cached three-phase compilation - cufunc = cached_compilation(cache, mi, world; emit_ir, emit_code, emit_executable) + cufunc = cached_compilation(cache, mi; emit_ir, emit_code, emit_executable) # Flatten arguments for cudacall - Constant returns () so ghost types disappear flat_args = Tuple(Iterators.flatten(map(flatten, tile_args))) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 5ed3c24..b8715a7 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -7,7 +7,7 @@ export code_tiled, @code_tiled -using CompilerCaching: CacheHandle, @setup_caching, get_ir, get_code, compile_hook, +using CompilerCaching: CacheView, @setup_caching, get_ir, get_code, compile_hook, compile_hook!, method_instance, populate! #============================================================================= @@ -25,17 +25,15 @@ Custom interpreter that supports overlay method tables for cuTile compilation. This is necessary because NativeInterpreter has a fixed method_table type parameter. """ struct cuTileInterpreter <: CC.AbstractInterpreter - cache::CacheHandle - world::UInt + cache::CacheView method_table::CC.CachedMethodTable{CC.OverlayMethodTable} inf_cache::Vector{CC.InferenceResult} inf_params::CC.InferenceParams opt_params::CC.OptimizationParams end -function cuTileInterpreter(cache::CacheHandle, world::UInt=Base.get_world_counter(); - always_inline::Bool=true) - method_table = get_method_table_view(world) +function cuTileInterpreter(cache::CacheView; always_inline::Bool=true) + method_table = get_method_table_view(cache.world) inf_cache = Vector{CC.InferenceResult}() inf_params = CC.InferenceParams() opt_params = if always_inline @@ -43,7 +41,7 @@ function cuTileInterpreter(cache::CacheHandle, world::UInt=Base.get_world_counte else CC.OptimizationParams() end - return cuTileInterpreter(cache, world, method_table, inf_cache, inf_params, opt_params) + return cuTileInterpreter(cache, method_table, inf_cache, inf_params, opt_params) end # Required AbstractInterpreter interface methods @@ -53,9 +51,9 @@ CC.get_inference_cache(interp::cuTileInterpreter) = interp.inf_cache # World age @static if isdefined(CC, :get_inference_world) - CC.get_inference_world(interp::cuTileInterpreter) = interp.world + CC.get_inference_world(interp::cuTileInterpreter) = interp.cache.world else - CC.get_world_counter(interp::cuTileInterpreter) = interp.world + CC.get_world_counter(interp::cuTileInterpreter) = interp.cache.world end # Method table - this enables the overlays @@ -92,8 +90,8 @@ If always_inline=true (default), forces all functions to be inlined. """ function code_ircode(mi::MethodInstance; world::UInt=Base.get_world_counter(), always_inline::Bool=true) - cache = CacheHandle(:cuTile) - interp = cuTileInterpreter(cache, world; always_inline) + cache = CacheView(:cuTile, world) + interp = cuTileInterpreter(cache; always_inline) result = CC.typeinf_ircode(interp, mi, nothing) if result === nothing @@ -117,30 +115,30 @@ const CGOpts = @NamedTuple{ } """ - emit_ir(cache, mi, world) -> (StructuredIRCode, rettype) + emit_ir(cache, mi) -> (StructuredIRCode, rettype) IR phase: populate code cache with dependencies and return structured IR. This phase uses cuTile's overlay method table for intrinsic substitution. """ -function emit_ir(cache::CacheHandle, mi::Core.MethodInstance, world::UInt) - interp = cuTileInterpreter(cache, world) +function emit_ir(cache::CacheView, mi::Core.MethodInstance) + interp = cuTileInterpreter(cache) populate!(cache, interp, mi) # Return StructuredIRCode for emit_code phase - ir, rettype = code_ircode(mi; world) + ir, rettype = code_ircode(mi; world=cache.world) sci = StructuredIRCode(ir) return (sci, rettype) end """ - emit_code(cache, mi, world, ir_result) -> Vector{UInt8} + emit_code(cache, mi, ir_result) -> Vector{UInt8} Code phase: generate Tile IR bytecode from StructuredIRCode. This phase is deterministic and does not require CUDA. Returns bytecode that can be compiled to CUBIN by tileiras in the emit_executable phase. """ -function emit_code(cache::CacheHandle, mi::Core.MethodInstance, world::UInt, ir_result) +function emit_code(cache::CacheView, mi::Core.MethodInstance, ir_result) sci, rettype = ir_result opts = cache.keys @@ -193,8 +191,8 @@ Return typed code for a cuTile function.. Analogous to `Base.code_typed`. """ function code_typed(@nospecialize(f), @nospecialize(argtypes); world::UInt=Base.get_world_counter(), kwargs...) - cache = CacheHandle(:cuTile) - interp = cuTileInterpreter(cache, world) + cache = CacheView(:cuTile, world) + interp = cuTileInterpreter(cache) Base.code_typed(f, argtypes; world, interp, kwargs...) end @@ -214,9 +212,9 @@ function code_structured(@nospecialize(f), @nospecialize(argtypes); throw(MethodError(f, argtypes))) opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheHandle{CGOpts}(:cuTile, opts) + cache = CacheView{CGOpts}(:cuTile, world, opts) - _, ir_result = get_ir(cache, mi, world; emit_ir) + ir_result = get_ir(cache, mi; emit_ir) sci, rettype = ir_result return sci end @@ -241,9 +239,9 @@ function code_tiled(@nospecialize(f), @nospecialize(argtypes); throw(MethodError(f, argtypes))) opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheHandle{CGOpts}(:cuTile, opts) + cache = CacheView{CGOpts}(:cuTile, world, opts) - _, bytecode = get_code(cache, mi, world; emit_ir, emit_code) + bytecode = get_code(cache, mi; emit_ir, emit_code) disassemble_tileir(bytecode) end @@ -254,7 +252,7 @@ function emit_hooked_compilation(inner_hook, ex...) quote # we only want to invoke the hook once for every compilation seen = Set() - function outer_hook(cache, mi, world) + function outer_hook(cache, mi) key = (cache, mi) if !in(key, seen) # the user hook might invoke the compiler again, so disable the hook @@ -262,7 +260,7 @@ function emit_hooked_compilation(inner_hook, ex...) try $compile_hook!(nothing) opts = cache.keys - $inner_hook(cache, mi, world; $(map(esc, user_kwargs)...)) + $inner_hook(cache, mi; $(map(esc, user_kwargs)...)) finally $compile_hook!(old_hook) end @@ -287,10 +285,10 @@ function emit_hooked_compilation(inner_hook, ex...) end macro code_tiled(ex...) - function hook(cache, mi, world; io::IO=stdout) + function hook(cache, mi; io::IO=stdout) # The hook fires during cached_compilation, so bytecode is being cached # at this moment - retrieve it via get_code - _, bytecode = get_code(cache, mi, world; emit_ir, emit_code) + bytecode = get_code(cache, mi; emit_ir, emit_code) println(io, "// $(mi.def.name)($(join(map(string, mi.specTypes.parameters[2:end]), ", ")))") println(io) println(io, disassemble_tileir(bytecode)) From a6a5038fdde31bab6978d269e4a3edb53371d304 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 22 Jan 2026 17:37:19 +0100 Subject: [PATCH 3/5] Update to new CompilerCaching.jl API. --- Project.toml | 3 ++- ext/CUDAExt.jl | 11 ++++++----- src/compiler/interface.jl | 39 ++++++++++++++++++++++++--------------- 3 files changed, 32 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 90f1f81..129ed4b 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,8 @@ IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [sources] -CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"} +#CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"} +CompilerCaching = {path = "/Users/tim/Julia/pkg/CompilerCaching"} IRStructurizer = {path = "IRStructurizer"} [extensions] diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 75ac89e..b4cb379 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -3,7 +3,7 @@ module CUDAExt using cuTile using cuTile: TileArray, Constant, CGOpts, emit_ir, emit_code -using CompilerCaching: CacheView, cached_compilation, method_instance +using CompilerCaching: CacheView, method_instance using CUDA: CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll @@ -11,12 +11,13 @@ using CUDA_Compiler_jll public launch """ - emit_executable(cache, mi, bytecode) -> CuFunction + emit_executable(cache, mi) -> CuFunction Executable phase: run tileiras on bytecode to produce CUBIN, then load into GPU memory. This is the only session-dependent phase. """ -function emit_executable(cache::CacheView, mi::Core.MethodInstance, bytecode::Vector{UInt8}) +function emit_executable(cache::CacheView, mi::Core.MethodInstance) + bytecode = get!(emit_code, cache, mi, :code) opts = cache.keys kernel_name = string(mi.def.name) @@ -99,8 +100,8 @@ function cuTile.launch(@nospecialize(f), grid, args...; opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) cache = CacheView{CGOpts}(:cuTile, world, opts) - # Run cached three-phase compilation - cufunc = cached_compilation(cache, mi; emit_ir, emit_code, emit_executable) + # Run cached compilation + cufunc = get!(emit_executable, cache, mi, :executable) # Flatten arguments for cudacall - Constant returns () so ghost types disappear flat_args = Tuple(Iterators.flatten(map(flatten, tile_args))) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index b8715a7..76823ab 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -7,8 +7,8 @@ export code_tiled, @code_tiled -using CompilerCaching: CacheView, @setup_caching, get_ir, get_code, compile_hook, - compile_hook!, method_instance, populate! +using CompilerCaching: CacheView, @setup_caching, compile_hook, + compile_hook!, method_instance, typeinf! #============================================================================= Interpreter @@ -122,24 +122,34 @@ This phase uses cuTile's overlay method table for intrinsic substitution. """ function emit_ir(cache::CacheView, mi::Core.MethodInstance) interp = cuTileInterpreter(cache) - populate!(cache, interp, mi) + codeinfos = typeinf!(cache, interp, mi) + + # Get IRCode from the CodeInfo - no second inference needed + ci, codeinfo = first(codeinfos) + + # Get the MethodInstance from the CodeInstance for safety + ci_mi = @static if VERSION >= v"1.12-" + CC.get_ci_mi(ci) + else + ci.def::MethodInstance + end + + ir = CC.inflate_ir(codeinfo, ci_mi) - # Return StructuredIRCode for emit_code phase - ir, rettype = code_ircode(mi; world=cache.world) sci = StructuredIRCode(ir) - return (sci, rettype) + return (sci, ci.rettype) end """ - emit_code(cache, mi, ir_result) -> Vector{UInt8} + emit_code(cache, mi) -> Vector{UInt8} Code phase: generate Tile IR bytecode from StructuredIRCode. This phase is deterministic and does not require CUDA. Returns bytecode that can be compiled to CUBIN by tileiras in the emit_executable phase. """ -function emit_code(cache::CacheView, mi::Core.MethodInstance, ir_result) - sci, rettype = ir_result +function emit_code(cache::CacheView, mi::Core.MethodInstance) + sci, rettype = get!(emit_ir, cache, mi, :ir) opts = cache.keys # Generate Tile IR bytecode @@ -214,8 +224,7 @@ function code_structured(@nospecialize(f), @nospecialize(argtypes); opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) cache = CacheView{CGOpts}(:cuTile, world, opts) - ir_result = get_ir(cache, mi; emit_ir) - sci, rettype = ir_result + sci, rettype = get!(emit_ir, cache, mi, :ir) return sci end @@ -241,7 +250,7 @@ function code_tiled(@nospecialize(f), @nospecialize(argtypes); opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) cache = CacheView{CGOpts}(:cuTile, world, opts) - bytecode = get_code(cache, mi; emit_ir, emit_code) + bytecode = get!(emit_code, cache, mi, :code) disassemble_tileir(bytecode) end @@ -286,9 +295,9 @@ end macro code_tiled(ex...) function hook(cache, mi; io::IO=stdout) - # The hook fires during cached_compilation, so bytecode is being cached - # at this moment - retrieve it via get_code - bytecode = get_code(cache, mi; emit_ir, emit_code) + # The hook fires during get!, so bytecode is being cached + # at this moment - retrieve it via get! + bytecode = get!(emit_code, cache, mi, :code) println(io, "// $(mi.def.name)($(join(map(string, mi.specTypes.parameters[2:end]), ", ")))") println(io) println(io, disassemble_tileir(bytecode)) From 684105a987e697731bcb98516897b2b6896538b1 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 23 Jan 2026 11:41:30 +0100 Subject: [PATCH 4/5] Update to new CompilerCaching.jl API. --- ext/CUDAExt.jl | 25 ++++++++++----- src/compiler/interface.jl | 65 +++++++++++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index b4cb379..878451a 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,9 +1,9 @@ module CUDAExt using cuTile -using cuTile: TileArray, Constant, CGOpts, emit_ir, emit_code +using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_ir, emit_code -using CompilerCaching: CacheView, method_instance +using CompilerCaching: CacheView, method_instance, results using CUDA: CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll @@ -17,8 +17,17 @@ Executable phase: run tileiras on bytecode to produce CUBIN, then load into GPU This is the only session-dependent phase. """ function emit_executable(cache::CacheView, mi::Core.MethodInstance) - bytecode = get!(emit_code, cache, mi, :code) - opts = cache.keys + # First ensure code is cached + bytecode = emit_code(cache, mi) + + # Check if executable already cached + ci = get(cache, mi, nothing) + res = results(cache, ci) + if res.executable !== nothing + return res.executable + end + + opts = cache.owner[2] kernel_name = string(mi.def.name) # Run tileiras to produce CUBIN @@ -31,7 +40,9 @@ function emit_executable(cache::CacheView, mi::Core.MethodInstance) # Load into GPU memory cumod = CuModule(cubin) - return CuFunction(cumod, kernel_name) + cufunc = CuFunction(cumod, kernel_name) + res.executable = cufunc + return cufunc finally rm(input_path, force=true) rm(output_path, force=true) @@ -98,10 +109,10 @@ function cuTile.launch(@nospecialize(f), grid, args...; # Create cache view with compilation options as sharding keys opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheView{CGOpts}(:cuTile, world, opts) + cache = CacheView{CuTileResults}((:cuTile, opts), world) # Run cached compilation - cufunc = get!(emit_executable, cache, mi, :executable) + cufunc = emit_executable(cache, mi) # Flatten arguments for cudacall - Constant returns () so ghost types disappear flat_args = Tuple(Iterators.flatten(map(flatten, tile_args))) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 76823ab..87c5211 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -8,7 +8,7 @@ export code_tiled, @code_tiled using CompilerCaching: CacheView, @setup_caching, compile_hook, - compile_hook!, method_instance, typeinf! + compile_hook!, method_instance, typeinf!, results #============================================================================= Interpreter @@ -90,7 +90,7 @@ If always_inline=true (default), forces all functions to be inlined. """ function code_ircode(mi::MethodInstance; world::UInt=Base.get_world_counter(), always_inline::Bool=true) - cache = CacheView(:cuTile, world) + cache = CacheView{CuTileResults}(:cuTile, world) interp = cuTileInterpreter(cache; always_inline) result = CC.typeinf_ircode(interp, mi, nothing) @@ -114,6 +114,14 @@ const CGOpts = @NamedTuple{ occupancy::Union{Int, Nothing} } +# Results struct for caching compilation phases +mutable struct CuTileResults + ir::Any # (StructuredIRCode, rettype) + code::Any # Vector{UInt8} bytecode + executable::Any # CuFunction (populated by CUDAExt) + CuTileResults() = new(nothing, nothing, nothing) +end + """ emit_ir(cache, mi) -> (StructuredIRCode, rettype) @@ -121,10 +129,16 @@ IR phase: populate code cache with dependencies and return structured IR. This phase uses cuTile's overlay method table for intrinsic substitution. """ function emit_ir(cache::CacheView, mi::Core.MethodInstance) + # Check cache first + ci = get(cache, mi, nothing) + if ci !== nothing + res = results(cache, ci) + res.ir !== nothing && return res.ir + end + + # Cache miss - run inference (creates CI via @setup_caching) interp = cuTileInterpreter(cache) codeinfos = typeinf!(cache, interp, mi) - - # Get IRCode from the CodeInfo - no second inference needed ci, codeinfo = first(codeinfos) # Get the MethodInstance from the CodeInstance for safety @@ -135,9 +149,13 @@ function emit_ir(cache::CacheView, mi::Core.MethodInstance) end ir = CC.inflate_ir(codeinfo, ci_mi) - sci = StructuredIRCode(ir) - return (sci, ci.rettype) + + # Store in results + res = results(cache, ci) + res.ir = (sci, ci.rettype) + + return res.ir end """ @@ -149,8 +167,22 @@ This phase is deterministic and does not require CUDA. Returns bytecode that can be compiled to CUBIN by tileiras in the emit_executable phase. """ function emit_code(cache::CacheView, mi::Core.MethodInstance) - sci, rettype = get!(emit_ir, cache, mi, :ir) - opts = cache.keys + # First ensure IR is cached (this also populates CI) + sci, rettype = emit_ir(cache, mi) + + # Check if code already cached + ci = get(cache, mi, nothing) + res = results(cache, ci) + if res.code !== nothing + return res.code + end + + # Get options from owner (tuple: (:cuTile, opts) or just symbol) + opts = if cache.owner isa Tuple + cache.owner[2] + else + (sm_arch=nothing, opt_level=3, num_ctas=nothing, occupancy=nothing) + end # Generate Tile IR bytecode bytecode = write_bytecode!(1) do writer, func_buf @@ -178,6 +210,7 @@ function emit_code(cache::CacheView, mi::Core.MethodInstance) write(dump_path, bytecode) end + res.code = bytecode return bytecode end @@ -201,7 +234,7 @@ Return typed code for a cuTile function.. Analogous to `Base.code_typed`. """ function code_typed(@nospecialize(f), @nospecialize(argtypes); world::UInt=Base.get_world_counter(), kwargs...) - cache = CacheView(:cuTile, world) + cache = CacheView{CuTileResults}(:cuTile, world) interp = cuTileInterpreter(cache) Base.code_typed(f, argtypes; world, interp, kwargs...) end @@ -222,9 +255,9 @@ function code_structured(@nospecialize(f), @nospecialize(argtypes); throw(MethodError(f, argtypes))) opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheView{CGOpts}(:cuTile, world, opts) + cache = CacheView{CuTileResults}((:cuTile, opts), world) - sci, rettype = get!(emit_ir, cache, mi, :ir) + sci, rettype = emit_ir(cache, mi) return sci end @@ -248,9 +281,9 @@ function code_tiled(@nospecialize(f), @nospecialize(argtypes); throw(MethodError(f, argtypes))) opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) - cache = CacheView{CGOpts}(:cuTile, world, opts) + cache = CacheView{CuTileResults}((:cuTile, opts), world) - bytecode = get!(emit_code, cache, mi, :code) + bytecode = emit_code(cache, mi) disassemble_tileir(bytecode) end @@ -268,7 +301,7 @@ function emit_hooked_compilation(inner_hook, ex...) old_hook = $compile_hook() try $compile_hook!(nothing) - opts = cache.keys + opts = cache.owner[2] $inner_hook(cache, mi; $(map(esc, user_kwargs)...)) finally $compile_hook!(old_hook) @@ -295,9 +328,7 @@ end macro code_tiled(ex...) function hook(cache, mi; io::IO=stdout) - # The hook fires during get!, so bytecode is being cached - # at this moment - retrieve it via get! - bytecode = get!(emit_code, cache, mi, :code) + bytecode = emit_code(cache, mi) println(io, "// $(mi.def.name)($(join(map(string, mi.specTypes.parameters[2:end]), ", ")))") println(io) println(io, disassemble_tileir(bytecode)) From fe1a83a0e073ebf8676e565529a976722384e911 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 23 Jan 2026 14:39:45 +0100 Subject: [PATCH 5/5] Update to new CompilerCaching.jl API. --- ext/CUDAExt.jl | 16 +++-- src/compiler/interface.jl | 119 ++++++++++++-------------------------- 2 files changed, 45 insertions(+), 90 deletions(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 878451a..cf92437 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,7 +1,7 @@ module CUDAExt using cuTile -using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_ir, emit_code +using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code using CompilerCaching: CacheView, method_instance, results @@ -13,20 +13,18 @@ public launch """ emit_executable(cache, mi) -> CuFunction -Executable phase: run tileiras on bytecode to produce CUBIN, then load into GPU memory. -This is the only session-dependent phase. +Executable phase: compile bytecode to CUBIN and load into GPU memory. """ function emit_executable(cache::CacheView, mi::Core.MethodInstance) - # First ensure code is cached + # Delegate to previous phase (handles CI + IR + code) bytecode = emit_code(cache, mi) - # Check if executable already cached - ci = get(cache, mi, nothing) + # Check executable cache + ci = get(cache, mi) res = results(cache, ci) - if res.executable !== nothing - return res.executable - end + res.executable !== nothing && return res.executable + # Compile to CUBIN and load opts = cache.owner[2] kernel_name = string(mi.def.name) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 87c5211..eb6d372 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -7,8 +7,7 @@ export code_tiled, @code_tiled -using CompilerCaching: CacheView, @setup_caching, compile_hook, - compile_hook!, method_instance, typeinf!, results +using CompilerCaching: CacheView, @setup_caching, method_instance, typeinf!, results, get_source #============================================================================= Interpreter @@ -125,36 +124,27 @@ end """ emit_ir(cache, mi) -> (StructuredIRCode, rettype) -IR phase: populate code cache with dependencies and return structured IR. -This phase uses cuTile's overlay method table for intrinsic substitution. +IR phase: run inference if needed and return structured IR. """ function emit_ir(cache::CacheView, mi::Core.MethodInstance) - # Check cache first + # Ensure CI exists ci = get(cache, mi, nothing) - if ci !== nothing - res = results(cache, ci) - res.ir !== nothing && return res.ir + if ci === nothing + interp = cuTileInterpreter(cache) + typeinf!(cache, interp, mi) + ci = get(cache, mi) end - # Cache miss - run inference (creates CI via @setup_caching) - interp = cuTileInterpreter(cache) - codeinfos = typeinf!(cache, interp, mi) - ci, codeinfo = first(codeinfos) - - # Get the MethodInstance from the CodeInstance for safety - ci_mi = @static if VERSION >= v"1.12-" - CC.get_ci_mi(ci) - else - ci.def::MethodInstance - end + # Check IR cache + res = results(cache, ci) + res.ir !== nothing && return res.ir - ir = CC.inflate_ir(codeinfo, ci_mi) + # Compute IR from CodeInfo + src = @something get_source(ci) + ir = CC.inflate_ir(src, mi) sci = StructuredIRCode(ir) - # Store in results - res = results(cache, ci) res.ir = (sci, ci.rettype) - return res.ir end @@ -162,27 +152,18 @@ end emit_code(cache, mi) -> Vector{UInt8} Code phase: generate Tile IR bytecode from StructuredIRCode. -This phase is deterministic and does not require CUDA. - -Returns bytecode that can be compiled to CUBIN by tileiras in the emit_executable phase. """ function emit_code(cache::CacheView, mi::Core.MethodInstance) - # First ensure IR is cached (this also populates CI) + # Delegate to previous phase (handles CI + IR) sci, rettype = emit_ir(cache, mi) - # Check if code already cached - ci = get(cache, mi, nothing) + # Check code cache + ci = get(cache, mi) res = results(cache, ci) - if res.code !== nothing - return res.code - end + res.code !== nothing && return res.code - # Get options from owner (tuple: (:cuTile, opts) or just symbol) - opts = if cache.owner isa Tuple - cache.owner[2] - else - (sm_arch=nothing, opt_level=3, num_ctas=nothing, occupancy=nothing) - end + # Compute bytecode + opts = cache.owner[2] # Generate Tile IR bytecode bytecode = write_bytecode!(1) do writer, func_buf @@ -287,51 +268,27 @@ function code_tiled(@nospecialize(f), @nospecialize(argtypes); disassemble_tileir(bytecode) end -# compilation hooking: uses CompilerCaching's global hook -function emit_hooked_compilation(inner_hook, ex...) - user_code = ex[end] - user_kwargs = ex[1:end-1] - quote - # we only want to invoke the hook once for every compilation - seen = Set() - function outer_hook(cache, mi) - key = (cache, mi) - if !in(key, seen) - # the user hook might invoke the compiler again, so disable the hook - old_hook = $compile_hook() - try - $compile_hook!(nothing) - opts = cache.owner[2] - $inner_hook(cache, mi; $(map(esc, user_kwargs)...)) - finally - $compile_hook!(old_hook) - end - push!(seen, key) - end - end +""" + @code_tiled f(args...) - # now invoke the user code with this hook in place - try - $compile_hook!(outer_hook) - $(esc(user_code)) - finally - $compile_hook!(nothing) - end +Print the Tile IR for the kernel that would be launched by the given call. +This is a convenience macro that extracts the function and argument types. - if isempty(seen) - error("no kernels executed while evaluating the given expression") - end - - nothing +# Example +```julia +@code_tiled vadd_kernel(a, b, c) +``` +""" +macro code_tiled(call) + if !(call isa Expr && call.head === :call) + error("@code_tiled requires a function call expression") end -end - -macro code_tiled(ex...) - function hook(cache, mi; io::IO=stdout) - bytecode = emit_code(cache, mi) - println(io, "// $(mi.def.name)($(join(map(string, mi.specTypes.parameters[2:end]), ", ")))") - println(io) - println(io, disassemble_tileir(bytecode)) + f = call.args[1] + args = call.args[2:end] + quote + local f_val = $(esc(f)) + local args_val = ($(map(esc, args)...),) + local argtypes = Tuple{map(typeof, args_val)...} + code_tiled(f_val, argtypes) end - emit_hooked_compilation(hook, ex...) end