diff --git a/src/host/indexing.jl b/src/host/indexing.jl index 401780c6..77ad5279 100644 --- a/src/host/indexing.jl +++ b/src/host/indexing.jl @@ -167,8 +167,14 @@ end function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end -# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization. -function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N + +#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties. +function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N} + return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) +end + +#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties. +function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N} return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...) end diff --git a/test/testsuite/indexing.jl b/test/testsuite/indexing.jl index 2c44d21a..e04192cf 100644 --- a/test/testsuite/indexing.jl +++ b/test/testsuite/indexing.jl @@ -284,3 +284,67 @@ end @test compare(argmin, AT, -rand(Int, 10)) end end + +@testsuite "indexing combinatorial" (AT, eltypes) -> begin + @testset "Reshaped SubArray dispatch" for T in eltypes + @testset "3D slice assignment" begin + A = AT(ones(T, 4, 4, 4)) + @views V = A[:, :, 1:2] + @allowscalar begin + @test_nowarn V .= zero(T) + @test all(Array(V) .== zero(T)) + end + end + + @testset "Logical mask view (dim = 3) — GPU safe" begin + A = AT(ones(T, 4, 4, 4)) + idx = findall(Bool[true, false, true, false]) + @views V = A[:, :, idx] + @allowscalar begin + @test_nowarn V .+= T(2) + @test all(Array(V) .== T(3)) + end + end + + @testset "Nested Reshape" begin + A = AT(ones(T, 4, 4, 4)) + V = view(A, 1:2, 1:2, 1:2) + R1 = reshape(V, 4, 2) + R2 = reshape(R1, :) + @allowscalar begin + @test_nowarn R2 .+= one(T) + @test all(Array(R2) .== T(2)) + end + end + end + + @testset "Permuted and Reinterpreted Views" for T in eltypes + @testset "Reshaped PermutedDims" begin + A = AT(ones(T, 4, 4)) + P = PermutedDimsArray(A, (2, 1)) + R = reshape(P, :) + @allowscalar begin + @test_nowarn R[1:2] .= zero(T) + @test Array(R)[1] == zero(T) + end + end + + @testset "Reshaped Reinterpreted" begin + A = AT(ones(T, 4, 4)) + IT = T <: Complex ? Complex{Int16} : Int16 + R = reshape(reinterpret(IT, A), :) + @allowscalar begin + @test_nowarn R[1] = zero(IT) + @test Array(R)[1] == zero(IT) + end + end + end + + @testset "Data parity with compare() — GPU safe" for T in eltypes + @test compare(AT, rand(T, 8, 8, 8)) do A + @views V = A[:, idx, :] + @allowscalar V .+= one(T) + A + end + end +end