Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions src/compiler/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,160 @@ end
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args)
emit_atomic_rmw!(ctx, args, AtomicADD)
end

# ============================================================================
# Tile-wise atomic operations
# These take pre-computed pointer tiles, value tiles, and masks.
# Used by the public API for tile-indexed atomic operations.
# ============================================================================

# cuda_tile.atomic_cas_tko with tile pointers
@eval Intrinsics begin
"""
atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope)

Tile-wise atomic compare-and-swap.
Operates on a tile of pointers with a tile of expected/desired values.
Mask controls which elements are active (bounds checking).
Returns a tile of original values.
"""
@noinline function atomic_cas_tile(ptr_tile::Tile, expected::Tile{T, S},
desired::Tile{T, S}, mask::Tile,
memory_order::Int, memory_scope::Int) where {T, S}
donotdelete(ptr_tile, expected, desired, mask)
compilerbarrier(:const, expected)::Tile{T, S}
end
end

function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
cb = ctx.cb
tt = ctx.tt

# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
ptr_tv = emit_value!(ctx, args[1])
ptr_tv === nothing && error("atomic_cas_tile requires ptr_tile")
expected_tv = emit_value!(ctx, args[2])
expected_tv === nothing && error("atomic_cas_tile requires expected value")
desired_tv = emit_value!(ctx, args[3])
desired_tv === nothing && error("atomic_cas_tile requires desired value")
mask_tv = emit_value!(ctx, args[4])
mask_tv === nothing && error("atomic_cas_tile requires mask")

# Get memory order and scope
memory_order = @something get_constant(ctx, args[5]) error("atomic_cas_tile requires constant memory_order")
memory_scope = @something get_constant(ctx, args[6]) error("atomic_cas_tile requires constant memory_scope")

# Get shape and element type from expected tile
shape = expected_tv.shape
elem_type = expected_tv.jltype.parameters[1] # T from Tile{T, S}

# Create result type (tile with same shape as inputs)
dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, collect(shape))
token_type = Token(tt)

# Emit atomic CAS with mask
mem_ordering = memory_order_to_semantics(memory_order)
mem_scope = memory_scope_to_scope(memory_scope)

old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, expected_tv.v, desired_tv.v;
mask=mask_tv.v,
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
ctx.token = new_token

# Return Tile type with the same shape
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple(shape)}, collect(shape))
end

# Shared helper for tile-wise atomic RMW operations
function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
cb = ctx.cb
tt = ctx.tt

# args: (ptr_tile, val, mask, memory_order, memory_scope)
ptr_tv = emit_value!(ctx, args[1])
ptr_tv === nothing && error("atomic RMW tile requires ptr_tile")
val_tv = emit_value!(ctx, args[2])
val_tv === nothing && error("atomic RMW tile requires value")
mask_tv = emit_value!(ctx, args[3])
mask_tv === nothing && error("atomic RMW tile requires mask")

# Get memory order and scope
memory_order = @something get_constant(ctx, args[4]) error("atomic RMW tile requires constant memory_order")
memory_scope = @something get_constant(ctx, args[5]) error("atomic RMW tile requires constant memory_scope")

# Get shape and element type from value tile
shape = val_tv.shape
elem_type = val_tv.jltype.parameters[1] # T from Tile{T, S}

# Create result type (tile with same shape as inputs)
dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, collect(shape))
token_type = Token(tt)

# Use float add mode for floating point types
actual_mode = mode
if mode == AtomicADD && elem_type <: AbstractFloat
actual_mode = AtomicADDF
end

# Emit atomic RMW with mask
mem_ordering = memory_order_to_semantics(memory_order)
mem_scope = memory_scope_to_scope(memory_scope)

old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, val_tv.v, actual_mode;
mask=mask_tv.v,
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
ctx.token = new_token

# Return Tile type with the same shape
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple(shape)}, collect(shape))
end

# cuda_tile.atomic_rmw_tko with XCHG (tile version)
@eval Intrinsics begin
"""
atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope)

Tile-wise atomic exchange.
Operates on a tile of pointers with a tile of values.
Mask controls which elements are active (bounds checking).
Returns a tile of original values.
"""
@noinline function atomic_xchg_tile(ptr_tile::Tile, val::Tile{T, S}, mask::Tile,
memory_order::Int, memory_scope::Int) where {T, S}
donotdelete(ptr_tile, val, mask)
compilerbarrier(:const, val)::Tile{T, S}
end
end

function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args)
emit_atomic_rmw_tile!(ctx, args, AtomicXCHG)
end

# cuda_tile.atomic_rmw_tko with ADD (tile version)
@eval Intrinsics begin
"""
atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope)

Tile-wise atomic addition.
Operates on a tile of pointers with a tile of values.
Mask controls which elements are active (bounds checking).
Returns a tile of original values.
"""
@noinline function atomic_add_tile(ptr_tile::Tile, val::Tile{T, S}, mask::Tile,
memory_order::Int, memory_scope::Int) where {T, S}
donotdelete(ptr_tile, val, mask)
compilerbarrier(:const, val)::Tile{T, S}
end
end

function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args)
emit_atomic_rmw_tile!(ctx, args, AtomicADD)
end
10 changes: 10 additions & 0 deletions src/language/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,13 @@ for (op, pred) in ((:<, :CmpLessThan), (:>, :CmpGreaterThan),
_cmp_intrinsic(broadcast_to(Tile(T(a)), S), b, $pred)
end
end

# For index tile arithmetic:
@inline Base.Broadcast.broadcasted(::TileStyle, ::typeof(-), a::Tile{T,S}, ::Base.RefValue{One}) where {T<:Integer,S} =
a .- one(T)
@inline Base.Broadcast.broadcasted(::TileStyle, ::typeof(+), a::Tile{T,S}, ::Base.RefValue{One}) where {T<:Integer,S} =
a .+ one(T)
@inline Base.Broadcast.broadcasted(::TileStyle, ::typeof(-), ::Base.RefValue{One}, a::Tile{T,S}) where {T<:Integer,S} =
one(T) .- a
@inline Base.Broadcast.broadcasted(::TileStyle, ::typeof(+), ::Base.RefValue{One}, a::Tile{T,S}) where {T<:Integer,S} =
one(T) .+ a
Loading