From abda01fa69530374c05886736afca5564968b624 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 20 Jul 2025 14:43:43 -0300 Subject: [PATCH 1/3] Use GPUArrays accumulation implementation --- src/Metal.jl | 1 - src/accumulate.jl | 202 ---------------------------------------------- test/array.jl | 31 ------- 3 files changed, 234 deletions(-) delete mode 100644 src/accumulate.jl diff --git a/src/Metal.jl b/src/Metal.jl index 90d859d92..c58e731e0 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -64,7 +64,6 @@ include("linalg.jl") include("utilities.jl") include("broadcast.jl") include("mapreduce.jl") -include("accumulate.jl") include("indexing.jl") include("random.jl") include("gpuarrays.jl") diff --git a/src/accumulate.jl b/src/accumulate.jl deleted file mode 100644 index 31e2dc4fe..000000000 --- a/src/accumulate.jl +++ /dev/null @@ -1,202 +0,0 @@ -## COV_EXCL_START -function partial_scan(op::Function, output::AbstractArray{T}, input::AbstractArray, - Rdim, Rpre, Rpost, Rother, neutral, init, - ::Val{maxthreads}, ::Val{inclusive}=Val(true)) where {T, maxthreads, inclusive} - threads = threads_per_threadgroup_3d().x - thread = thread_position_in_threadgroup_3d().x - - temp = MtlThreadGroupArray(T, (Int32(2) * maxthreads,)) - - i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x - j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y - - if j > length(Rother) - return - end - - @inbounds begin - I = Rother[j] - Ipre = Rpre[I[1]] - Ipost = Rpost[I[2]] - end - - @inbounds temp[thread] = if i <= length(Rdim) - op(neutral, input[Ipre, i, Ipost]) - else - op(neutral, neutral) - end - - offset = one(thread) - d = threads >> 0x1 - while d > zero(d) - threadgroup_barrier(MemoryFlagThreadGroup) - @inbounds if thread <= d - ai = offset * (thread << 0x1 - 0x1) - bi = offset * (thread << 0x1) - temp[bi] = op(temp[ai], temp[bi]) - end - offset <<= 0x1 - d >>= 0x1 - end - - @inbounds if isone(thread) - temp[threads] = neutral - end - - d = one(thread) - while d < threads - offset >>= 0x1 - threadgroup_barrier(MemoryFlagThreadGroup) - @inbounds if thread <= d - ai = offset * (thread << 0x1 - 0x1) - bi = offset * (thread << 0x1) - - t = temp[ai] - temp[ai] = temp[bi] - temp[bi] = op(t, temp[bi]) - end - d <<= 0x1 - end - - threadgroup_barrier(MemoryFlagThreadGroup) - - @inbounds if i <= length(Rdim) - val = if inclusive - op(temp[thread], input[Ipre, i, Ipost]) - else - temp[thread] - end - if init !== nothing - val = op(something(init), val) - end - output[Ipre, i, Ipost] = val - end - - return -end - -function aggregate_partial_scan(op::Function, output::AbstractArray, aggregates::AbstractArray, Rdim, Rpre, Rpost, Rother, init) - block = threadgroup_position_in_grid_3d().x - - i = (threadgroup_position_in_grid_3d().x - Int32(1)) * threads_per_threadgroup_3d().x + thread_position_in_threadgroup_3d().x - j = (threadgroup_position_in_grid_3d().z - Int32(1)) * threadgroups_per_grid_3d().y + threadgroup_position_in_grid_3d().y - - @inbounds if i <= length(Rdim) && j <= length(Rother) - I = Rother[j] - Ipre = Rpre[I[1]] - Ipost = Rpost[I[2]] - - val = if block > 1 - op(aggregates[Ipre, block - Int32(1), Ipost], output[Ipre, i, Ipost]) - else - output[Ipre, i, Ipost] - end - - if init !== nothing - val = op(something(init), val) - end - - output[Ipre, i, Ipost] = val - end - - return -end -## COV_EXCL_STOP - -function scan!(f::Function, output::WrappedMtlArray{T}, input::WrappedMtlArray; - dims::Integer, init=nothing, neutral=GPUArrays.neutral_element(f, T)) where {T} - dims > 0 || throw(ArgumentError("dims must be a positive integer")) - inds_t = axes(input) - axes(output) == inds_t || throw(DimensionMismatch("shape of B must match A")) - dims > ndims(input) && return copyto!(output, input) - isempty(inds_t[dims]) && return output - - # iteration domain across the main dimension - Rdim = CartesianIndices((size(input, dims),)) - - # iteration domain for the other dimensions - Rpre = CartesianIndices(size(input)[1:dims-1]) - Rpost = CartesianIndices(size(input)[dims+1:end]) - Rother = CartesianIndices((length(Rpre), length(Rpost))) - - # the maximum number of threads is limited by the hardware - dev = device() - maxthreads = min(Int(dev.maxThreadsPerThreadgroup.width), - Int(dev.maxThreadgroupMemoryLength) ÷ sizeof(T) ÷ 2) - - # determine how many threads we can launch for the scan kernel - kernel = @metal launch=false partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(maxthreads), Val(true)) - threads = Int(kernel.pipeline.maxTotalThreadsPerThreadgroup) - - # determine the grid layout to cover the other dimensions - blocks_other = (length(Rother), 1) - - # does that suffice to scan the array in one go? - full = nextpow(2, length(Rdim)) - if full <= threads - @metal(threads=full, groups=(1, blocks_other...), - partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, init, Val(full), Val(true))) - else - # perform partial scans across the scanning dimension - # NOTE: don't set init here to avoid applying the value multiple times - partial = prevpow(2, threads) - blocks_dim = cld(length(Rdim), partial) - @metal(threads=partial, groups=(blocks_dim, blocks_other...), - partial_scan(f, output, input, Rdim, Rpre, Rpost, Rother, neutral, nothing, Val(partial), Val(true))) - - # get the total of each thread block (except the first) of the partial scans - aggregates = fill(neutral, Base.setindex(size(input), blocks_dim, dims)) - partials = selectdim(output, dims, partial:partial:length(Rdim)) - indices = CartesianIndices(partials) - copyto!(aggregates, indices, partials, indices) - - # scan these totals to get totals for the entire partial scan - accumulate!(f, aggregates, aggregates; dims=dims) - - # add those totals to the partial scan result - # NOTE: we assume that this kernel requires fewer resources than the scan kernel. - # if that does not hold, launch with fewer threads and calculate - # the aggregate block index within the kernel itself. - @metal(threads=partial, groups=(blocks_dim, blocks_other...), - aggregate_partial_scan(f, output, aggregates, Rdim, Rpre, Rpost, Rother, init)) - - unsafe_free!(aggregates) - end - - return output -end - - -## Base interface - -Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlVector, dims::Nothing, init::Nothing) = - scan!(op, output, input; dims=1) - -Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlArray, dims::Integer, init::Nothing) = - scan!(op, output, input; dims=dims) - -Base._accumulate!(op, output::WrappedMtlArray, input::MtlVector, dims::Nothing, init::Some) = - scan!(op, output, input; dims=1, init=init) - -Base._accumulate!(op, output::WrappedMtlArray, input::WrappedMtlArray, dims::Integer, init::Some) = - scan!(op, output, input; dims=dims, init=init) - -Base.accumulate_pairwise!(op, result::WrappedMtlVector, v::WrappedMtlVector) = accumulate!(op, result, v) - -# default behavior unless dims are specified by the user -function Base.accumulate(op, A::WrappedMtlArray; - dims::Union{Nothing,Integer}=nothing, kw...) - if dims === nothing && !(A isa AbstractVector) - # This branch takes care of the cases not handled by `_accumulate!`. - return reshape(accumulate(op, A[:]; kw...), size(A)) - end - nt = values(kw) - if isempty(kw) - out = similar(A, Base.promote_op(op, eltype(A), eltype(A))) - elseif keys(nt) === (:init,) - out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A))) - else - throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) - end - accumulate!(op, out, A; dims=dims, kw...) -end diff --git a/test/array.jl b/test/array.jl index e64dbf9bb..bbe56fb1c 100644 --- a/test/array.jl +++ b/test/array.jl @@ -478,37 +478,6 @@ end sum(reshape(PermutedDimsArray(reshape(Float32.(1:30), 5, 3, 2), (3, 1, 2)), (10, 3)); dims=1) end -@testset "accumulate" begin - for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not - @test testf(x->accumulate(+, x), rand(Float32, n)) - @test testf(x->accumulate(+, x), rand(Float32, n, 2)) - @test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(Float32)), rand(Float32, n)) - end - - # multidimensional - for (sizes, dims) in ((2,) => 2, - (3,4,5) => 2, - (1, 70, 50, 20) => 3,) - @test testf(x->accumulate(+, x; dims=dims), rand(-10:10, sizes)) - @test testf(x->accumulate(+, x), rand(-10:10, sizes)) - end - - # using initializer - for (sizes, dims) in ((2,) => 2, - (3,4,5) => 2, - (1, 70, 50, 20) => 3) - @test testf(Base.Fix2((x,y)->accumulate(+, x; dims=dims, init=y), rand(-10:10)), rand(-10:10, sizes)) - @test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(-10:10)), rand(-10:10, sizes)) - end - - # in place - @test testf(x->(accumulate!(+, x, copy(x)); x), rand(Float32, 2)) - - # specialized - @test testf(cumsum, rand(Float32, 2)) - @test testf(cumprod, rand(Float32, 2)) -end - @testset "findall" begin # 1D @test testf(x->findall(x), rand(Bool, 1000)) From 12deb2f45f676e263cb231384b665184dc1bfd15 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 20 Jul 2025 15:13:35 -0300 Subject: [PATCH 2/3] [REMOVE BEFORE MERGE] --- .buildkite/pipeline.yml | 1 + test/runtests.jl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 1d44ef482..68f03f977 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -104,6 +104,7 @@ steps: println("--- :julia: Instantiating project") Pkg.develop([PackageSpec(path=pwd())]) + Pkg.add(url="https://github.com/christiangnrd/GPUArrays.jl", rev="accumulatetests") println("+++ :julia: Benchmarking") include("perf/runbenchmarks.jl")' diff --git a/test/runtests.jl b/test/runtests.jl index 49389e328..9b6b0c3d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,9 @@ if parse(Bool, get(ENV, "BUILDKITE", "false")) end end +using Pkg +Pkg.add(url="https://github.com/christiangnrd/GPUArrays.jl", rev="accumulatetests") + # Quit without erroring if Metal loaded without issues on unsupported platforms if !Sys.isapple() @warn """Metal.jl succesfully loaded on non-macOS system. From 84f519aceed2e8741b3aed3ecac836adc19cd3f2 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sun, 20 Jul 2025 16:20:08 -0300 Subject: [PATCH 3/3] fgsnb --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 894b74671..5d013042e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9" ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"