Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,14 @@ end
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N

#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties.
function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties.
function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
end

Expand Down
64 changes: 64 additions & 0 deletions test/testsuite/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,67 @@ end
@test compare(argmin, AT, -rand(Int, 10))
end
end

@testsuite "indexing combinatorial" (AT, eltypes) -> begin
@testset "Reshaped SubArray dispatch" for T in eltypes
@testset "3D slice assignment" begin
A = AT(ones(T, 4, 4, 4))
@views V = A[:, :, 1:2]
@allowscalar begin
@test_nowarn V .= zero(T)
@test all(Array(V) .== zero(T))
end
end

@testset "Logical mask view (dim = 3) — GPU safe" begin
A = AT(ones(T, 4, 4, 4))
idx = findall(Bool[true, false, true, false])
@views V = A[:, :, idx]
@allowscalar begin
@test_nowarn V .+= T(2)
@test all(Array(V) .== T(3))
end
end

@testset "Nested Reshape" begin
A = AT(ones(T, 4, 4, 4))
V = view(A, 1:2, 1:2, 1:2)
R1 = reshape(V, 4, 2)
R2 = reshape(R1, :)
@allowscalar begin
@test_nowarn R2 .+= one(T)
@test all(Array(R2) .== T(2))
end
end
end

@testset "Permuted and Reinterpreted Views" for T in eltypes
@testset "Reshaped PermutedDims" begin
A = AT(ones(T, 4, 4))
P = PermutedDimsArray(A, (2, 1))
R = reshape(P, :)
@allowscalar begin
@test_nowarn R[1:2] .= zero(T)
@test Array(R)[1] == zero(T)
end
end

@testset "Reshaped Reinterpreted" begin
A = AT(ones(T, 4, 4))
IT = T <: Complex ? Complex{Int16} : Int16
R = reshape(reinterpret(IT, A), :)
@allowscalar begin
@test_nowarn R[1] = zero(IT)
@test Array(R)[1] == zero(IT)
end
end
end

@testset "Data parity with compare() — GPU safe" for T in eltypes
@test compare(AT, rand(T, 8, 8, 8)) do A
@views V = A[:, idx, :]
@allowscalar V .+= one(T)
A
end
end
end
Loading