diff --git a/.gitignore b/.gitignore index daf5565..fd9f176 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ res/ *.cubin -Manifest.toml +Manifest*.toml LocalPreferences.toml CLAUDE.md AGENTS.md diff --git a/Project.toml b/Project.toml index 83cf4b8..129ed4b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Tim Besard "] version = "0.1.0" [deps] +CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd" CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8" CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d" IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" @@ -12,6 +13,8 @@ IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" [sources] +CompilerCaching = {url = "https://github.com/maleadt/CompilerCaching.jl", rev="main"} +# Local development override (machine-specific, keep commented out in commits): CompilerCaching = {path = "/Users/tim/Julia/pkg/CompilerCaching"} IRStructurizer = {path = "IRStructurizer"} [extensions] @@ -21,6 +24,3 @@ CUDAExt = "CUDA" julia = "1.11" CUDA_Compiler_jll = "0.4" CUDA_Tile_jll = "13.1" - -[workspace] -projects = ["test", "IRStructurizer", "FileCheck"] diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 0012c3d..cf92437 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,15 +1,51 @@ module CUDAExt using cuTile -using cuTile: TileArray, Constant, emit_tileir +using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code + +using CompilerCaching: CacheView, method_instance, results using CUDA: CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll public launch -# Compilation cache - stores CuFunction directly to avoid re-loading CuModule -const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction +""" + emit_executable(cache, mi) -> CuFunction + +Executable phase: compile bytecode to CUBIN and load into GPU memory. 
+""" +function emit_executable(cache::CacheView, mi::Core.MethodInstance) + # Delegate to previous phase (handles CI + IR + code) + bytecode = emit_code(cache, mi) + + # Check executable cache + ci = get(cache, mi) + res = results(cache, ci) + res.executable !== nothing && return res.executable + + # Compile to CUBIN and load + opts = cache.owner[2] + kernel_name = string(mi.def.name) + + # Run tileiras to produce CUBIN + input_path = tempname() * ".tile" + output_path = tempname() * ".cubin" + try + write(input_path, bytecode) + run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $(opts.sm_arch) -O$(opts.opt_level)`) + cubin = read(output_path) + + # Load into GPU memory + cumod = CuModule(cubin) + cufunc = CuFunction(cumod, kernel_name) + res.executable = cufunc + return cufunc + finally + rm(input_path, force=true) + rm(output_path, force=true) + end +end """ launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) @@ -62,23 +98,19 @@ function cuTile.launch(@nospecialize(f), grid, args...; # Compute argument types from the converted arguments argtypes = Tuple{map(typeof, tile_args)...} - # Determine kernel name - kernel_name = name !== nothing ? 
name : string(nameof(f)) - - # Use method instance in case of a redefinition - method = which(f, argtypes) - - # Check compilation cache - returns CuFunction directly - cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) - cufunc = get(_compilation_cache, cache_key, nothing) - if cufunc === nothing || cuTile.compile_hook[] !== nothing - cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy) - if cufunc === nothing - cumod = CuModule(cubin) - cufunc = CuFunction(cumod, kernel_name) - _compilation_cache[cache_key] = cufunc - end - end + # Get world age and method instance + # Don't pass method_table - kernel functions are in the global table + # The overlay table is only used by the interpreter during inference + world = Base.get_world_counter() + mi = method_instance(f, argtypes; world) + mi === nothing && throw(MethodError(f, argtypes)) + + # Create cache view with compilation options as sharding keys + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheView{CuTileResults}((:cuTile, opts), world) + + # Run cached compilation + cufunc = emit_executable(cache, mi) # Flatten arguments for cudacall - Constant returns () so ghost types disappear flat_args = Tuple(Iterators.flatten(map(flatten, tile_args))) @@ -104,52 +136,6 @@ function cuTile.launch(@nospecialize(f), grid, args...; return nothing end -""" - compile(f, argtypes; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing) -> Vector{UInt8} - -Compile a Julia kernel function to a CUDA binary. 
-""" -function compile(@nospecialize(f), @nospecialize(argtypes); - name::Union{String, Nothing}=nothing, - sm_arch::String=default_sm_arch(), - opt_level::Int=3, - num_ctas::Union{Int, Nothing}=nothing, - occupancy::Union{Int, Nothing}=nothing) - tile_bytecode = emit_tileir(f, argtypes; name, sm_arch, - num_ctas, occupancy) - - # Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set - dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing) - if dump_dir !== nothing - mkpath(dump_dir) - # Get source location from the function's method - mi = first(methods(f, Base.to_tuple_type(argtypes))) - base_filename = basename(string(mi.file)) - base_filename = first(splitext(base_filename)) - # Find unique filename, adding counter if file exists - dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.line).cutile") - counter = 1 - while isfile(dump_path) - counter += 1 - dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.line).$(counter).cutile") - end - println(stderr, "Dumping TILEIR bytecode to file: $dump_path") - write(dump_path, tile_bytecode) - end - - input_path = tempname() * ".tile" - output_path = tempname() * ".cubin" - - try - write(input_path, tile_bytecode) - run(`$(CUDA_Compiler_jll.tileiras()) $input_path -o $output_path --gpu-name $sm_arch -O$opt_level`) - return read(output_path) - finally - rm(input_path, force=true) - rm(output_path, force=true) - end -end - """ default_sm_arch() -> String diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index 2c775d4..564aa8e 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -1,8 +1,8 @@ # Codegen: Julia IR -> Tile IR bytecode +include("codegen/utils.jl") include("codegen/kernel.jl") include("codegen/control_flow.jl") include("codegen/statements.jl") include("codegen/expressions.jl") include("codegen/values.jl") -include("codegen/utils.jl") diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 81492b1..b50fa6d 100644 --- 
a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -1,22 +1,22 @@ # kernel and argument handling """ - emit_kernel!(writer, func_buf, target; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing) + emit_kernel!(writer, func_buf, sci, rettype; name, sm_arch=nothing, is_entry=true, num_ctas=nothing, occupancy=nothing) -Compile a TileTarget to Tile IR bytecode. +Compile a StructuredIRCode to Tile IR bytecode. """ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, - target::TileTarget; - name::String = string(target.mi.def.name), + sci::StructuredIRCode, rettype::Type; + name::String, sm_arch::Union{String, Nothing} = nothing, is_entry::Bool = true, num_ctas::Union{Int, Nothing} = nothing, occupancy::Union{Int, Nothing} = nothing) - ctx = CGCtx(writer, target, sm_arch) + ctx = CGCtx(writer, sci, sm_arch) tt = ctx.tt # Validate non-ghost argument types are concrete - for (i, argtype) in enumerate(target.sci.argtypes) + for (i, argtype) in enumerate(sci.argtypes) is_ghost_type(unwrap_type(argtype)) && continue require_concrete_type(argtype, "kernel argument $i") end @@ -25,7 +25,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, param_types = TypeId[] param_mapping = Tuple{Int, Union{Nothing, Symbol}}[] - for (i, argtype) in enumerate(target.sci.argtypes) + for (i, argtype) in enumerate(sci.argtypes) argtype_unwrapped = unwrap_type(argtype) if is_ghost_type(argtype_unwrapped) continue @@ -57,8 +57,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Return types result_types = TypeId[] - if target.rettype !== Nothing && target.rettype !== Union{} - push!(result_types, tile_type_for_julia!(ctx, target.rettype)) + if rettype !== Nothing && rettype !== Union{} + push!(result_types, tile_type_for_julia!(ctx, rettype)) end # Create entry hints if provided @@ -92,8 +92,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Regular argument - 
create concrete CGVal @assert length(values) == 1 val = values[1] - type_id = tile_type_for_julia!(ctx, target.sci.argtypes[arg_idx]) - tv = CGVal(val, type_id, target.sci.argtypes[arg_idx]) + type_id = tile_type_for_julia!(ctx, sci.argtypes[arg_idx]) + tv = CGVal(val, type_id, sci.argtypes[arg_idx]) ctx[SlotNumber(arg_idx)] = tv ctx[Argument(arg_idx)] = tv end @@ -117,7 +117,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, ctx.token = encode_MakeTokenOp!(cb, token_type) # Emit the structured IR (uses original Julia SSA indices everywhere) - emit_block!(ctx, ctx.target.sci.entry) + emit_block!(ctx, ctx.sci.entry) finalize_function!(func_buf, cb, writer.debug_info) end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index abc2c97..4dd9f44 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -1,3 +1,366 @@ +# Codegen types and utilities +# +# Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. + +#============================================================================= + CGVal: Unified value representation (analogous to Julia's jl_cgval_t) +=============================================================================# + +""" + CGVal + +Represents a value during Tile IR code generation, bundling the IR value +with its type information and metadata. + +Similar to Julia compiler's `jl_cgval_t`, this provides a unified representation +for all values flowing through codegen. A CGVal can be either: +1. A concrete SSA value (v::Value) +2. A multi-value result from control flow ops (v::Vector{Value}) +3. A lazy argument reference chain (v is nothing, arg_ref tracks the access path) +4. 
A ghost value (v is nothing, zero-size singleton) +""" +struct CGVal + v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing + type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value) + jltype::Any # Original Julia type + shape::Vector{Int} # Tile shape (empty for scalars) + # Lazy argument reference: (arg_idx, [:field, index, ...]) + # e.g., (1, [:sizes, 2]) means "argument 1, field :sizes, index 2" + arg_ref::Union{Tuple{Int, Vector{Union{Symbol, Int}}}, Nothing} + constant::Union{Some, Nothing} # Nothing = no constant, Some(x) = constant value x + tuple::Union{Vector{Any}, Nothing} # For tuples: component refs (SSAValue, etc.) +end + +# Convenience constructors for concrete values +CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) = + CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing) + +CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) = + CGVal(v, type_id, jltype, shape, nothing, nothing, nothing) + +# Constructor for multi-value results (from loops, ifs) +CGVal(v::Vector{Value}, @nospecialize(jltype)) = + CGVal(v, nothing, jltype, Int[], nothing, nothing, nothing) + +# Constructor for lazy argument references +function arg_ref_value(arg_idx::Int, chain::Vector{Union{Symbol, Int}}, @nospecialize(jltype)) + CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing) +end + +""" + ghost_value(jltype[, constant]) -> CGVal + +Create a ghost value (zero-size singleton with no runtime representation). +Optionally stores a compile-time constant value. +""" +ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing) +ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing) + +""" + tuple_value(jltype, component_refs, component_constants) -> CGVal + +Create a tuple value with tracked component refs. Derives constant if all components have constants. 
+Used by intrinsics like cat() that need to access individual tuple elements. +""" +function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, component_constants::Vector{Any}) + # If all components have constants, derive the tuple constant + constant = if all(!isnothing, component_constants) + Some(Tuple(component_constants)) + else + nothing + end + CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs) +end + +""" + is_ghost(tv::CGVal) -> Bool + +Check if a CGVal is a ghost (no runtime representation). +""" +is_ghost(tv::CGVal) = tv.v === nothing && tv.arg_ref === nothing + +""" + is_arg_ref(tv::CGVal) -> Bool + +Check if a CGVal is a lazy argument reference. +""" +is_arg_ref(tv::CGVal) = tv.arg_ref !== nothing + +#============================================================================= + CGCtx: Compilation context +=============================================================================# + +""" + CGCtx + +Holds all state during Tile IR code generation for a kernel function. +Maps Julia SSA values to CGVals and manages bytecode emission. 
+""" +mutable struct CGCtx + # SSA value mapping: original Julia SSA index -> CGVal + # Uses global/original indices everywhere (no local renumbering) + # Loop/if ops store a multi-value CGVal, v::Vector{Value} (extracted by getfield statements) + values::Dict{Int, CGVal} + args::Dict{Int, CGVal} # Argument index -> CGVal + slots::Dict{Int, CGVal} # Slot number -> CGVal + block_args::Dict{Int, CGVal} # BlockArg id -> CGVal (for control flow) + + # Destructured argument handling (for TileArray fields) + arg_flat_values::Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}} + arg_types::Dict{Int, Type} + + # Cached TensorViews for TileArray arguments (arg_idx -> (Value, TypeId)) + tensor_views::Dict{Int, Tuple{Value, TypeId}} + + # Bytecode infrastructure + cb::CodeBuilder + tt::TypeTable + sci::StructuredIRCode + + # Memory ordering token + token::Union{Value, Nothing} + token_type::Union{TypeId, Nothing} + + # Type cache: Julia type -> TypeId + type_cache::Dict{Type, TypeId} + + # Target architecture string (e.g., "sm_100") + sm_arch::Union{String, Nothing} +end + +function CGCtx(writer::BytecodeWriter, sci::StructuredIRCode, sm_arch::Union{String, Nothing}=nothing) + CGCtx( + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Int, CGVal}(), + Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(), + Dict{Int, Type}(), + Dict{Int, Tuple{Value, TypeId}}(), # tensor_views cache + CodeBuilder(writer.string_table, writer.constant_table, writer.type_table), + writer.type_table, + sci, + nothing, + nothing, + Dict{Type, TypeId}(), + sm_arch, + ) +end + +#============================================================================= + Value lookup via indexing syntax +=============================================================================# + +function Base.getindex(ctx::CGCtx, ssa::SSAValue) + # Simple lookup by original Julia SSA index + get(ctx.values, ssa.id, nothing) +end + +function Base.getindex(ctx::CGCtx, arg::Argument) + get(ctx.args, arg.n, 
nothing) +end + +function Base.getindex(ctx::CGCtx, slot::SlotNumber) + get(ctx.slots, slot.id, nothing) +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, ssa::SSAValue) + ctx.values[ssa.id] = tv +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, arg::Argument) + ctx.args[arg.n] = tv +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, slot::SlotNumber) + ctx.slots[slot.id] = tv +end + +function Base.getindex(ctx::CGCtx, block_arg::BlockArg) + get(ctx.block_args, block_arg.id, nothing) +end + +function Base.setindex!(ctx::CGCtx, tv::CGVal, block_arg::BlockArg) + ctx.block_args[block_arg.id] = tv +end + +#============================================================================= + Destructured argument helpers +=============================================================================# + +""" + get_arg_flat_values(ctx, arg_idx, field=nothing) -> Union{Vector{Value}, Nothing} + +Get the flat Tile IR values for an argument or its field. +""" +function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Union{Nothing, Symbol}=nothing) + get(ctx.arg_flat_values, (arg_idx, field), nothing) +end + +""" + is_destructured_arg(ctx, arg_idx) -> Bool + +Check if an argument was destructured into multiple flat parameters. +""" +is_destructured_arg(ctx::CGCtx, arg_idx::Int) = haskey(ctx.arg_types, arg_idx) + +""" + get_arg_type(ctx, arg_idx) -> Union{Type, Nothing} + +Get the original Julia type for a destructured argument. +""" +get_arg_type(ctx::CGCtx, arg_idx::Int) = get(ctx.arg_types, arg_idx, nothing) + +#============================================================================= + Type conversion utilities +=============================================================================# + +""" + unwrap_type(T) -> Type + +Unwrap type wrappers like Core.Const to get the actual type. 
+""" +function unwrap_type(@nospecialize(T)) + if T isa Core.Const + return typeof(T.val) + elseif T isa Core.PartialStruct + return T.typ + elseif T isa Type + return T + else + return T + end +end + +""" + require_concrete_type(T, context::String) + +Ensure a type is fully concrete (not a UnionAll). +""" +function require_concrete_type(@nospecialize(T), context::String) + T_unwrapped = unwrap_type(T) + if T_unwrapped isa UnionAll + error("Type must be fully concrete in $context, got partial type: $T") + end + return T_unwrapped +end + +""" + tile_type_for_julia!(ctx, T) -> TypeId + +Get or create a Tile IR type for a Julia type. +""" +function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T)) + actual_type = unwrap_type(T) + get!(ctx.type_cache, actual_type) do + _tile_type_for_julia!(ctx.tt, actual_type) + end +end + +function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) + # Scalar types -> 0-D tile + if T === Bool + return tile_type!(tt, I1(tt), Int[]) + elseif T === Int32 || T === UInt32 + return tile_type!(tt, I32(tt), Int[]) + elseif T === Int64 || T === UInt64 + return tile_type!(tt, I64(tt), Int[]) + elseif T === Float16 + return tile_type!(tt, F16(tt), Int[]) + elseif T === Float32 + return tile_type!(tt, F32(tt), Int[]) + elseif T === Float64 + return tile_type!(tt, F64(tt), Int[]) + elseif T === Nothing + return Token(tt) + end + + # Pointers -> 0-D tile of pointer type + if T <: Ptr + elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) + ptr_type = pointer_type!(tt, elem_dtype) + return tile_type!(tt, ptr_type, Int[]) + end + + # Tile{T, Shape} -> tile type with shape + if T <: Tile + if T isa UnionAll || !isa(T, DataType) || length(T.parameters) < 2 + error("Tile type must be fully specified with element type and shape, got: $T. 
" * + "This indicates type instability in the kernel - ensure all tile operations have inferrable shapes.") + end + elem_type = T.parameters[1] + shape_param = T.parameters[2] + if !(shape_param isa Tuple) + error("Tile shape must be a tuple, got: $shape_param") + end + elem_dtype = julia_to_tile_dtype!(tt, elem_type) + shape = collect(Int, shape_param) + return tile_type!(tt, elem_dtype, shape) + end + + error("Unsupported Julia type for Tile IR: $T") +end + +""" + tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int}) + +Get the Tile IR type and shape for a Julia type. +""" +function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) + actual_type = unwrap_type(T) + type_id = tile_type_for_julia!(ctx, actual_type) + + # Extract shape from Tile types + shape = Int[] + if actual_type <: Tile && length(actual_type.parameters) >= 2 + shape_param = actual_type.parameters[2] + if shape_param isa Tuple + shape = collect(Int, shape_param) + end + end + + return (type_id, shape) +end + +#============================================================================= + Struct destructuring helpers +=============================================================================# + +""" + is_ghost_type(T) -> Bool + +Check if a type is a ghost type (zero-size singleton). +""" +function is_ghost_type(@nospecialize(T)) + try + isbitstype(T) && sizeof(T) == 0 + catch + false + end +end + +""" + should_destructure(T) -> Bool + +Check if a type should be destructured into flat parameters. +""" +function should_destructure(@nospecialize(T)) + T = unwrap_type(T) + isstructtype(T) || return false + is_ghost_type(T) && return false + isprimitivetype(T) && return false + T <: TileArray && return true + return false +end + +""" + flat_field_count(T) -> Int + +Count flat parameters a type expands to. 
+""" +flat_field_count(::Type{<:NTuple{N, T}}) where {N, T} = N +flat_field_count(::Type) = 1 + #----------------------------------------------------------------------------- # Argument helpers #----------------------------------------------------------------------------- diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl new file mode 100644 index 0000000..eb6d372 --- /dev/null +++ b/src/compiler/interface.jl @@ -0,0 +1,294 @@ +# Compilation interface for cuTile +# +# This file provides the public compilation API: +# - cuTileInterpreter: custom interpreter with overlay method table +# - emit_ir/emit_code: three-phase compilation callbacks (emit_executable in CUDAExt) +# - code_tiled/@code_tiled: reflection utilities + +export code_tiled, @code_tiled + +using CompilerCaching: CacheView, @setup_caching, method_instance, typeinf!, results, get_source + +#============================================================================= + Interpreter +=============================================================================# + +Base.Experimental.@MethodTable cuTileMethodTable + +function get_method_table_view(world::UInt) + CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) +end + +""" +Custom interpreter that supports overlay method tables for cuTile compilation. +This is necessary because NativeInterpreter has a fixed method_table type parameter. 
+""" +struct cuTileInterpreter <: CC.AbstractInterpreter + cache::CacheView + method_table::CC.CachedMethodTable{CC.OverlayMethodTable} + inf_cache::Vector{CC.InferenceResult} + inf_params::CC.InferenceParams + opt_params::CC.OptimizationParams +end + +function cuTileInterpreter(cache::CacheView; always_inline::Bool=true) + method_table = get_method_table_view(cache.world) + inf_cache = Vector{CC.InferenceResult}() + inf_params = CC.InferenceParams() + opt_params = if always_inline + CC.OptimizationParams(; inline_cost_threshold=typemax(Int)) + else + CC.OptimizationParams() + end + return cuTileInterpreter(cache, method_table, inf_cache, inf_params, opt_params) +end + +# Required AbstractInterpreter interface methods +CC.InferenceParams(interp::cuTileInterpreter) = interp.inf_params +CC.OptimizationParams(interp::cuTileInterpreter) = interp.opt_params +CC.get_inference_cache(interp::cuTileInterpreter) = interp.inf_cache + +# World age +@static if isdefined(CC, :get_inference_world) + CC.get_inference_world(interp::cuTileInterpreter) = interp.cache.world +else + CC.get_world_counter(interp::cuTileInterpreter) = interp.cache.world +end + +# Method table - this enables the overlays +CC.method_table(interp::cuTileInterpreter) = interp.method_table + +# Locking - not needed for non-cached compilation +CC.lock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing +CC.unlock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing + +# Setup caching - generates cache_owner and ipo_dataflow_analysis! 
methods +@setup_caching cuTileInterpreter.cache + +# Optimization flags +CC.may_optimize(::cuTileInterpreter) = true +CC.may_compress(::cuTileInterpreter) = true +CC.may_discard_trees(::cuTileInterpreter) = true + +# Disable semi-concrete interpretation (broken with overlays per JuliaLang/julia#47349) +function CC.concrete_eval_eligible(interp::cuTileInterpreter, + @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) + ret = @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, + f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) + if ret === :semi_concrete_eval + return :none + end + return ret +end + +""" + code_ircode(mi::MethodInstance; world, always_inline=true) -> (IRCode, rettype) + +Get optimized IRCode for a MethodInstance using cuTile's overlay method table. +If always_inline=true (default), forces all functions to be inlined. +""" +function code_ircode(mi::MethodInstance; world::UInt=Base.get_world_counter(), + always_inline::Bool=true) + cache = CacheView{CuTileResults}(:cuTile, world) + interp = cuTileInterpreter(cache; always_inline) + result = CC.typeinf_ircode(interp, mi, nothing) + + if result === nothing + error("Type inference failed for $mi") + end + + ir, rettype = result + return ir, rettype +end + +#============================================================================= + Compilation phases +=============================================================================# + +# Compilation options for cache sharding +const CGOpts = @NamedTuple{ + sm_arch::Union{String, Nothing}, + opt_level::Int, + num_ctas::Union{Int, Nothing}, + occupancy::Union{Int, Nothing} +} + +# Results struct for caching compilation phases +mutable struct CuTileResults + ir::Any # (StructuredIRCode, rettype) + code::Any # Vector{UInt8} bytecode + executable::Any # CuFunction (populated by CUDAExt) + CuTileResults() = new(nothing, nothing, nothing) +end + +""" + emit_ir(cache, mi) -> 
(StructuredIRCode, rettype) + +IR phase: run inference if needed and return structured IR. +""" +function emit_ir(cache::CacheView, mi::Core.MethodInstance) + # Ensure CI exists + ci = get(cache, mi, nothing) + if ci === nothing + interp = cuTileInterpreter(cache) + typeinf!(cache, interp, mi) + ci = get(cache, mi) + end + + # Check IR cache + res = results(cache, ci) + res.ir !== nothing && return res.ir + + # Compute IR from CodeInfo + src = @something get_source(ci) + ir = CC.inflate_ir(src, mi) + sci = StructuredIRCode(ir) + + res.ir = (sci, ci.rettype) + return res.ir +end + +""" + emit_code(cache, mi) -> Vector{UInt8} + +Code phase: generate Tile IR bytecode from StructuredIRCode. +""" +function emit_code(cache::CacheView, mi::Core.MethodInstance) + # Delegate to previous phase (handles CI + IR) + sci, rettype = emit_ir(cache, mi) + + # Check code cache + ci = get(cache, mi) + res = results(cache, ci) + res.code !== nothing && return res.code + + # Compute bytecode + opts = cache.owner[2] + + # Generate Tile IR bytecode + bytecode = write_bytecode!(1) do writer, func_buf + emit_kernel!(writer, func_buf, sci, rettype; + name = string(mi.def.name), + sm_arch = opts.sm_arch, + num_ctas = opts.num_ctas, + occupancy = opts.occupancy + ) + end + + # Dump bytecode if JULIA_CUTILE_DUMP_BYTECODE is set + dump_dir = get(ENV, "JULIA_CUTILE_DUMP_BYTECODE", nothing) + if dump_dir !== nothing + mkpath(dump_dir) + base_filename = basename(string(mi.def.file)) + base_filename = first(splitext(base_filename)) + dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.def.line).cutile") + counter = 1 + while isfile(dump_path) + counter += 1 + dump_path = joinpath(dump_dir, "$(base_filename).ln$(mi.def.line).$(counter).cutile") + end + println(stderr, "Dumping TILEIR bytecode to file: $dump_path") + write(dump_path, bytecode) + end + + res.code = bytecode + return bytecode +end + +#============================================================================= + Reflection 
utilities +=============================================================================# + +function disassemble_tileir(bytecode::Vector{UInt8})::String + mktempdir() do dir + input_path = joinpath(dir, "kernel.tile") + output_path = joinpath(dir, "kernel.disasm") + write(input_path, bytecode) + read(`$(cuda_tile_translate()) --cudatilebc-to-mlir $input_path`, String) + end +end + +""" + code_typed(f, argtypes; world, kwargs...) -> Vector{Any} + +Return typed code for a cuTile function. Analogous to `Base.code_typed`. +""" +function code_typed(@nospecialize(f), @nospecialize(argtypes); + world::UInt=Base.get_world_counter(), kwargs...) + cache = CacheView{CuTileResults}(:cuTile, world) + interp = cuTileInterpreter(cache) + Base.code_typed(f, argtypes; world, interp, kwargs...) +end + +""" + code_structured(f, argtypes; kwargs...) -> StructuredIRCode + +Return the structured IR for a cuTile function. +""" +function code_structured(@nospecialize(f), @nospecialize(argtypes); + sm_arch::Union{String, Nothing}=nothing, + opt_level::Int=3, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing) + world = Base.get_world_counter() + mi = @something(method_instance(f, argtypes; world, method_table=cuTileMethodTable), + method_instance(f, argtypes; world), + throw(MethodError(f, argtypes))) + + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheView{CuTileResults}((:cuTile, opts), world) + + sci, rettype = emit_ir(cache, mi) + return sci +end + +""" + code_tiled(f, argtypes; sm_arch, opt_level, num_ctas, occupancy) -> String + +Return the CUDA Tile IR for a Julia function as a textual MLIR representation. +Analogous to `code_typed` or `code_structured`. + +Uses the same caching infrastructure as `launch`, benefiting from cached IR +and code results. 
+""" +function code_tiled(@nospecialize(f), @nospecialize(argtypes); + sm_arch::Union{String, Nothing}=nothing, + opt_level::Int=3, + num_ctas::Union{Int, Nothing}=nothing, + occupancy::Union{Int, Nothing}=nothing) + world = Base.get_world_counter() + mi = @something(method_instance(f, argtypes; world, method_table=cuTileMethodTable), + method_instance(f, argtypes; world), + throw(MethodError(f, argtypes))) + + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=num_ctas, occupancy=occupancy) + cache = CacheView{CuTileResults}((:cuTile, opts), world) + + bytecode = emit_code(cache, mi) + disassemble_tileir(bytecode) +end + +""" + @code_tiled f(args...) + +Print the Tile IR for the kernel that would be launched by the given call. +This is a convenience macro that extracts the function and argument types. + +# Example +```julia +@code_tiled vadd_kernel(a, b, c) +``` +""" +macro code_tiled(call) + if !(call isa Expr && call.head === :call) + error("@code_tiled requires a function call expression") + end + f = call.args[1] + args = call.args[2:end] + quote + local f_val = $(esc(f)) + local args_val = ($(map(esc, args)...),) + local argtypes = Tuple{map(typeof, args_val)...} + code_tiled(f_val, argtypes) + end +end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl deleted file mode 100644 index 1547375..0000000 --- a/src/compiler/interpreter.jl +++ /dev/null @@ -1,112 +0,0 @@ -using Base.Experimental: @MethodTable - -# Create the cuTile method table -@MethodTable cuTileMethodTable - -# Create cached method table type based on version compatibility -const cuTileMethodTableView = CC.CachedMethodTable{CC.OverlayMethodTable} - -function get_method_table_view(world::UInt) - CC.CachedMethodTable(CC.OverlayMethodTable(world, cuTileMethodTable)) -end - -""" -Custom interpreter that supports overlay method tables for cuTile compilation. -This is necessary because NativeInterpreter has a fixed method_table type parameter. 
-""" -struct cuTileInterpreter <: CC.AbstractInterpreter - world::UInt - method_table::cuTileMethodTableView - inf_cache::Vector{CC.InferenceResult} - inf_params::CC.InferenceParams - opt_params::CC.OptimizationParams -end - -function cuTileInterpreter(world::UInt=Base.get_world_counter(); - always_inline::Bool=true) - method_table = get_method_table_view(world) - inf_cache = Vector{CC.InferenceResult}() - inf_params = CC.InferenceParams() - opt_params = if always_inline - CC.OptimizationParams(; inline_cost_threshold=typemax(Int)) - else - CC.OptimizationParams() - end - return cuTileInterpreter(world, method_table, inf_cache, inf_params, opt_params) -end - -# Required AbstractInterpreter interface methods -CC.InferenceParams(interp::cuTileInterpreter) = interp.inf_params -CC.OptimizationParams(interp::cuTileInterpreter) = interp.opt_params -CC.get_inference_cache(interp::cuTileInterpreter) = interp.inf_cache - -# World age -@static if isdefined(CC, :get_inference_world) - CC.get_inference_world(interp::cuTileInterpreter) = interp.world -else - CC.get_world_counter(interp::cuTileInterpreter) = interp.world -end - -# Method table - this enables the overlays -CC.method_table(interp::cuTileInterpreter) = interp.method_table - -# Locking - not needed for non-cached compilation -CC.lock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing -CC.unlock_mi_inference(::cuTileInterpreter, ::MethodInstance) = nothing - -# Cache owner - use a unique identifier to cache inference results separately -CC.cache_owner(::cuTileInterpreter) = :cutile - -# Optimization flags -CC.may_optimize(::cuTileInterpreter) = true -CC.may_compress(::cuTileInterpreter) = true -CC.may_discard_trees(::cuTileInterpreter) = true - -# Disable semi-concrete interpretation (broken with overlays per JuliaLang/julia#47349) -function CC.concrete_eval_eligible(interp::cuTileInterpreter, - @nospecialize(f), result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) - ret = @invoke 
CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, - f::Any, result::CC.MethodCallResult, arginfo::CC.ArgInfo, sv::CC.InferenceState) - if ret === :semi_concrete_eval - return :none - end - return ret -end - - -#============================================================================= - Public API -=============================================================================# - -""" - get_ircode(f, argtypes; world=Base.get_world_counter(), always_inline=true) -> (IRCode, return_type) - -Get the optimized IRCode for a function with the given argument types. -Uses cuTile's method overlay table to redirect Base operations to cuTile intrinsics. -If always_inline=true (default), forces all functions to be inlined. -""" -function get_ircode(@nospecialize(f), @nospecialize(argtypes); - world::UInt = Base.get_world_counter(), - always_inline::Bool = true) - interp = cuTileInterpreter(world; always_inline) - mi = get_method_instance(f, argtypes; world) - result = CC.typeinf_ircode(interp, mi, nothing) - - if result === nothing - error("Type inference failed for $f with argument types $argtypes") - end - - ir, rettype = result - return ir, rettype -end - -""" - get_method_instance(f, argtypes; world=Base.get_world_counter()) -> MethodInstance - -Get the MethodInstance for a function call. -""" -function get_method_instance(@nospecialize(f), @nospecialize(argtypes); world::UInt = Base.get_world_counter()) - tt = Base.signature_type(f, argtypes) - match = Base._which(tt; world) - return CC.specialize_method(match) -end diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl deleted file mode 100644 index 12c4cb3..0000000 --- a/src/compiler/reflection.jl +++ /dev/null @@ -1,95 +0,0 @@ -export code_tiled, @code_tiled - -""" - emit_tileir(f, argtypes; name, sm_arch, num_ctas, occupancy) -> Vector{UInt8} - -Compile a Julia function to Tile IR bytecode. 
-""" -function emit_tileir(@nospecialize(f), @nospecialize(argtypes); - name::Union{String, Nothing} = nothing, - sm_arch::Union{String, Nothing} = nothing, - num_ctas::Union{Int, Nothing} = nothing, - occupancy::Union{Int, Nothing} = nothing) - target = TileTarget(f, argtypes) - kernel_name = name === nothing ? string(target.mi.def.name) : name - - if compile_hook[] !== nothing - compile_hook[](f, argtypes; name=name) - end - - buf = write_bytecode!(1) do writer, func_buf - emit_kernel!(writer, func_buf, target; name=kernel_name, sm_arch, - num_ctas, occupancy) - end - - return buf -end - -function disassemble_tileir(bytecode::Vector{UInt8})::String - mktempdir() do dir - input_path = joinpath(dir, "kernel.tile") - output_path = joinpath(dir, "kernel.disasm") - write(input_path, bytecode) - read(`$(cuda_tile_translate()) --cudatilebc-to-mlir $input_path`, String) - end -end - -""" - code_tiled(f, argtypes; name, sm_arch, num_ctas, occupancy) -> String - -Return the CUDA Tile IR for a Julia function as a textual MLIR representation. -Analogous to `code_typed` or `code_structured`. -""" -function code_tiled(@nospecialize(f), @nospecialize(argtypes); - kwargs...) - bytecode = emit_tileir(f, argtypes; kwargs...) - disassemble_tileir(bytecode) -end - -# compilation hooking: taken from GPUCompiler.jl -const compile_hook = Ref{Union{Nothing,Function}}(nothing) -function emit_hooked_compilation(inner_hook, ex...) - user_code = ex[end] - user_kwargs = ex[1:end-1] - quote - # we only want to invoke the hook once for every compilation - seen = Set() - function outer_hook(f, tt; kwargs...) 
- key = (f, tt) - if !in(key, seen) - # the user hook might invoke the compiler again, so disable the hook - old_hook = $compile_hook[] - try - $compile_hook[] = nothing - $inner_hook(f, tt; kwargs..., $(map(esc, user_kwargs)...)) - finally - $compile_hook[] = old_hook - end - push!(seen, key) - end - end - - # now invoke the user code with this hook in place - try - $compile_hook[] = outer_hook - $(esc(user_code)) - finally - $compile_hook[] = nothing - end - - if isempty(seen) - error("no kernels executed while evaluating the given expression") - end - - nothing - end -end - -macro code_tiled(ex...) - function hook(f, tt; io::IO=stdout, kwargs...) - println(io, "// $f($(join(map(string, tt.parameters), ", ")))") - println(io) - println(io, code_tiled(f, tt; kwargs...)) - end - emit_hooked_compilation(hook, ex...) -end diff --git a/src/compiler/target.jl b/src/compiler/target.jl deleted file mode 100644 index af26ce7..0000000 --- a/src/compiler/target.jl +++ /dev/null @@ -1,389 +0,0 @@ -# TileTarget and CGCtx for cuTile compilation -# -# Holds compilation target and state for a kernel. - -#============================================================================= - TileTarget: Compilation target -=============================================================================# - -""" - TileTarget - -Holds everything about a function being compiled to Tile IR: -the structured IR and return type. 
-""" -struct TileTarget - mi::MethodInstance - sci::StructuredIRCode - rettype::Type -end - -function TileTarget(@nospecialize(f), @nospecialize(argtypes::Type{<:Tuple})) - ir, rettype = get_ircode(f, argtypes) - mi = get_method_instance(f, argtypes) - sci = StructuredIRCode(ir) - TileTarget(mi, sci, rettype) -end - -# Accessors -# Count non-ghost arguments (excludes function type for regular functions) -nargs(target::TileTarget) = count(!is_ghost_type ∘ unwrap_type, target.sci.argtypes) - -#============================================================================= - CGVal: Unified value representation (analogous to Julia's jl_cgval_t) -=============================================================================# - -""" - CGVal - -Represents a value during Tile IR code generation, bundling the IR value -with its type information and metadata. - -Similar to Julia compiler's `jl_cgval_t`, this provides a unified representation -for all values flowing through codegen. A CGVal can be either: -1. A concrete SSA value (v::Value) -2. A multi-value result from control flow ops (v::Vector{Value}) -3. A lazy argument reference chain (v is nothing, arg_ref tracks the access path) -4. A ghost value (v is nothing, zero-size singleton) -""" -struct CGVal - v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing - type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value) - jltype::Any # Original Julia type - shape::Vector{Int} # Tile shape (empty for scalars) - # Lazy argument reference: (arg_idx, [:field, index, ...]) - # e.g., (1, [:sizes, 2]) means "argument 1, field :sizes, index 2" - arg_ref::Union{Tuple{Int, Vector{Union{Symbol, Int}}}, Nothing} - constant::Union{Some, Nothing} # Nothing = no constant, Some(x) = constant value x - tuple::Union{Vector{Any}, Nothing} # For tuples: component refs (SSAValue, etc.) 
-end - -# Convenience constructors for concrete values -CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) = - CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing) - -CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) = - CGVal(v, type_id, jltype, shape, nothing, nothing, nothing) - -# Constructor for multi-value results (from loops, ifs) -CGVal(v::Vector{Value}, @nospecialize(jltype)) = - CGVal(v, nothing, jltype, Int[], nothing, nothing, nothing) - -# Constructor for lazy argument references -function arg_ref_value(arg_idx::Int, chain::Vector{Union{Symbol, Int}}, @nospecialize(jltype)) - CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing) -end - -""" - ghost_value(jltype[, constant]) -> CGVal - -Create a ghost value (zero-size singleton with no runtime representation). -Optionally stores a compile-time constant value. -""" -ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing) -ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing) - -""" - tuple_value(jltype, component_refs, component_constants) -> CGVal - -Create a tuple value with tracked component refs. Derives constant if all components have constants. -Used by intrinsics like cat() that need to access individual tuple elements. -""" -function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, component_constants::Vector{Any}) - # If all components have constants, derive the tuple constant - constant = if all(!isnothing, component_constants) - Some(Tuple(component_constants)) - else - nothing - end - CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs) -end - -""" - is_ghost(tv::CGVal) -> Bool - -Check if a CGVal is a ghost (no runtime representation). 
-""" -is_ghost(tv::CGVal) = tv.v === nothing && tv.arg_ref === nothing - -""" - is_arg_ref(tv::CGVal) -> Bool - -Check if a CGVal is a lazy argument reference. -""" -is_arg_ref(tv::CGVal) = tv.arg_ref !== nothing - -#============================================================================= - CGCtx: Compilation context -=============================================================================# - -""" - CGCtx - -Holds all state during Tile IR code generation for a kernel function. -Maps Julia SSA values to CGVals and manages bytecode emission. -""" -mutable struct CGCtx - # SSA value mapping: original Julia SSA index -> CGVal - # Uses global/original indices everywhere (no local renumbering) - # Loop/if ops store a CGVal with tuple_values field (extracted by getfield statements) - values::Dict{Int, CGVal} - args::Dict{Int, CGVal} # Argument index -> CGVal - slots::Dict{Int, CGVal} # Slot number -> CGVal - block_args::Dict{Int, CGVal} # BlockArg id -> CGVal (for control flow) - - # Destructured argument handling (for TileArray fields) - arg_flat_values::Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}} - arg_types::Dict{Int, Type} - - # Cached TensorViews for TileArray arguments (arg_idx -> (Value, TypeId)) - tensor_views::Dict{Int, Tuple{Value, TypeId}} - - # Bytecode infrastructure - cb::CodeBuilder - tt::TypeTable - target::TileTarget - - # Memory ordering token - token::Union{Value, Nothing} - token_type::Union{TypeId, Nothing} - - # Type cache: Julia type -> TypeId - type_cache::Dict{Type, TypeId} - - # Target architecture (e.g., :sm_100) - sm_arch::Union{String, Nothing} -end - -function CGCtx(writer::BytecodeWriter, target::TileTarget, sm_arch::Union{String, Nothing}=nothing) - CGCtx( - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Int, CGVal}(), - Dict{Tuple{Int, Union{Nothing, Symbol}}, Vector{Value}}(), - Dict{Int, Type}(), - Dict{Int, Tuple{Value, TypeId}}(), # tensor_views cache - CodeBuilder(writer.string_table, 
writer.constant_table, writer.type_table), - writer.type_table, - target, - nothing, - nothing, - Dict{Type, TypeId}(), - sm_arch, - ) -end - -#============================================================================= - Value lookup via indexing syntax -=============================================================================# - -function Base.getindex(ctx::CGCtx, ssa::SSAValue) - # Simple lookup by original Julia SSA index - get(ctx.values, ssa.id, nothing) -end - -function Base.getindex(ctx::CGCtx, arg::Argument) - get(ctx.args, arg.n, nothing) -end - -function Base.getindex(ctx::CGCtx, slot::SlotNumber) - get(ctx.slots, slot.id, nothing) -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, ssa::SSAValue) - ctx.values[ssa.id] = tv -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, arg::Argument) - ctx.args[arg.n] = tv -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, slot::SlotNumber) - ctx.slots[slot.id] = tv -end - -function Base.getindex(ctx::CGCtx, block_arg::BlockArg) - get(ctx.block_args, block_arg.id, nothing) -end - -function Base.setindex!(ctx::CGCtx, tv::CGVal, block_arg::BlockArg) - ctx.block_args[block_arg.id] = tv -end - -#============================================================================= - Destructured argument helpers -=============================================================================# - -""" - get_arg_flat_values(ctx, arg_idx, field=nothing) -> Union{Vector{Value}, Nothing} - -Get the flat Tile IR values for an argument or its field. -""" -function get_arg_flat_values(ctx::CGCtx, arg_idx::Int, field::Union{Nothing, Symbol}=nothing) - get(ctx.arg_flat_values, (arg_idx, field), nothing) -end - -""" - is_destructured_arg(ctx, arg_idx) -> Bool - -Check if an argument was destructured into multiple flat parameters. 
-""" -is_destructured_arg(ctx::CGCtx, arg_idx::Int) = haskey(ctx.arg_types, arg_idx) - -""" - get_arg_type(ctx, arg_idx) -> Union{Type, Nothing} - -Get the original Julia type for a destructured argument. -""" -get_arg_type(ctx::CGCtx, arg_idx::Int) = get(ctx.arg_types, arg_idx, nothing) - -#============================================================================= - Type conversion utilities -=============================================================================# - -""" - unwrap_type(T) -> Type - -Unwrap type wrappers like Core.Const to get the actual type. -""" -function unwrap_type(@nospecialize(T)) - if T isa Core.Const - return typeof(T.val) - elseif T isa Core.PartialStruct - return T.typ - elseif T isa Type - return T - else - return T - end -end - -""" - require_concrete_type(T, context::String) - -Ensure a type is fully concrete (not a UnionAll). -""" -function require_concrete_type(@nospecialize(T), context::String) - T_unwrapped = unwrap_type(T) - if T_unwrapped isa UnionAll - error("Type must be fully concrete in $context, got partial type: $T") - end - return T_unwrapped -end - -""" - tile_type_for_julia!(ctx, T) -> TypeId - -Get or create a Tile IR type for a Julia type. 
-""" -function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T)) - actual_type = unwrap_type(T) - get!(ctx.type_cache, actual_type) do - _tile_type_for_julia!(ctx.tt, actual_type) - end -end - -function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) - # Scalar types -> 0-D tile - if T === Bool - return tile_type!(tt, I1(tt), Int[]) - elseif T === Int32 || T === UInt32 - return tile_type!(tt, I32(tt), Int[]) - elseif T === Int64 || T === UInt64 - return tile_type!(tt, I64(tt), Int[]) - elseif T === Float16 - return tile_type!(tt, F16(tt), Int[]) - elseif T === Float32 - return tile_type!(tt, F32(tt), Int[]) - elseif T === Float64 - return tile_type!(tt, F64(tt), Int[]) - elseif T === Nothing - return Token(tt) - end - - # Pointers -> 0-D tile of pointer type - if T <: Ptr - elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) - ptr_type = pointer_type!(tt, elem_dtype) - return tile_type!(tt, ptr_type, Int[]) - end - - # Tile{T, Shape} -> tile type with shape - if T <: Tile - if T isa UnionAll || !isa(T, DataType) || length(T.parameters) < 2 - error("Tile type must be fully specified with element type and shape, got: $T. " * - "This indicates type instability in the kernel - ensure all tile operations have inferrable shapes.") - end - elem_type = T.parameters[1] - shape_param = T.parameters[2] - if !(shape_param isa Tuple) - error("Tile shape must be a tuple, got: $shape_param") - end - elem_dtype = julia_to_tile_dtype!(tt, elem_type) - shape = collect(Int, shape_param) - return tile_type!(tt, elem_dtype, shape) - end - - error("Unsupported Julia type for Tile IR: $T") -end - -""" - tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int}) - -Get the Tile IR type and shape for a Julia type. 
-""" -function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) - actual_type = unwrap_type(T) - type_id = tile_type_for_julia!(ctx, actual_type) - - # Extract shape from Tile types - shape = Int[] - if actual_type <: Tile && length(actual_type.parameters) >= 2 - shape_param = actual_type.parameters[2] - if shape_param isa Tuple - shape = collect(Int, shape_param) - end - end - - return (type_id, shape) -end - -#============================================================================= - Struct destructuring helpers -=============================================================================# - -""" - is_ghost_type(T) -> Bool - -Check if a type is a ghost type (zero-size singleton). -""" -function is_ghost_type(@nospecialize(T)) - try - isbitstype(T) && sizeof(T) == 0 - catch - false - end -end - -""" - should_destructure(T) -> Bool - -Check if a type should be destructured into flat parameters. -""" -function should_destructure(@nospecialize(T)) - T = unwrap_type(T) - isstructtype(T) || return false - is_ghost_type(T) && return false - isprimitivetype(T) && return false - T <: TileArray && return true - return false -end - -""" - flat_field_count(T) -> Int - -Count flat parameters a type expands to. -""" -flat_field_count(::Type{<:NTuple{N, T}}) where {N, T} = N -flat_field_count(::Type) = 1 diff --git a/src/cuTile.jl b/src/cuTile.jl index 375aaa2..6cd5545 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -23,11 +23,9 @@ include("bytecode/encodings.jl") include("language/types.jl") # Compiler implementation -include("compiler/interpreter.jl") -include("compiler/target.jl") +include("compiler/interface.jl") include("compiler/codegen.jl") include("compiler/intrinsics.jl") -include("compiler/reflection.jl") # Language implementation include("language/broadcast.jl")