Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 81 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,46 @@ end

buffer_for(::Function, args::Vararg{Type,N}) where {N} = nothing

function mutable_buffered_operate_to_fallback(::NotMutable, buffer, output, op::Function, args...)
throw(
ArgumentError(
"Cannot call `mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(output))` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
),
)
end

function mutable_buffered_operate_to_fallback(::IsMutable, buffer, output, op::Function, args...)
error(
"`mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::",
join(typeof.(args), ", ::"),
")` is not implemented.",
)
end

function mutable_buffered_operate_to_fallback(
buffer,
output,
op::Function,
args::Vararg{Any,N},
) where {N}
return mutable_buffered_operate_to_fallback(
mutability(output, op, args...),
buffer,
output,
op,
args...
)
end

function mutable_buffered_operate_to_fallback(
::Nothing,
output,
op::Function,
args::Vararg{Any,N},
) where {N}
return mutable_operate_to!(output, op, args...)
end

"""
mutable_buffered_operate_to!(buffer, output, op::Function, args...)

Expand All @@ -268,12 +308,49 @@ possibly modifying `buffer`. Can only be called if
`mutability(output, op, args...)` returns `true`.
"""
function mutable_buffered_operate_to!(
::Nothing,
buffer,
output,
op::Function,
args::Vararg{Any,N},
) where {N}
return mutable_operate_to!(output, op, args...)
return mutable_buffered_operate_to_fallback(buffer, output, op, args...)
end

function mutable_buffered_operate_fallback(::NotMutable, buffer, op::Function, args...)
throw(
ArgumentError(
"Cannot call `mutable_buffered_operate!(::$(typeof(buffer)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(args[1]))` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
),
)
end

function mutable_buffered_operate_fallback(::IsMutable, buffer, op::Function, args...)
error(
"`mutable_buffered_operate!(::$(typeof(buffer)), $op, ::",
join(typeof.(args), ", ::"),
")` is not implemented.",
)
end

function mutable_buffered_operate_fallback(
buffer,
op::Function,
args::Vararg{Any,N},
) where {N}
return mutable_buffered_operate_fallback(
mutability(args[1], op, args...),
buffer,
op,
args...
)
end

function mutable_buffered_operate_fallback(
::Nothing,
op::Function,
args::Vararg{Any,N},
) where {N}
return mutable_operate!(op, args...)
end

"""
Expand All @@ -284,8 +361,8 @@ possibly modifying `buffer`. Can only be called if
`mutability(args[1], op, args...)` returns `true`.
"""
function mutable_buffered_operate! end
function mutable_buffered_operate!(::Nothing, op::Function, args::Vararg{Any,N}) where {N}
return mutable_operate!(op, args...)
function mutable_buffered_operate!(buffer, op::Function, args::Vararg{Any,N}) where {N}
return mutable_buffered_operate_fallback(buffer, op, args...)
end

"""
Expand Down
33 changes: 25 additions & 8 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,16 +257,15 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
end
end

function _add_mul_array(C::Vector, A::AbstractMatrix, B::AbstractVector)
function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector)
Astride = size(A, 1)
# We need a buffer to hold the intermediate multiplication.
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))
@inbounds begin
for k in eachindex(B)
aoffs = (k - 1) * Astride
b = B[k]
for i in Base.OneTo(size(A, 1))
C[i] = buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs+i], b)
C[i] = buffered_operate!(buffer, add_mul, C[i], A[aoffs+i], b)
end
end
end # @inbounds
Expand All @@ -275,28 +274,46 @@ end

# This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the
# same diagonal element.
function _add_mul_array(C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))
function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
@inbounds begin
for i = 1:size(A, 1), j = 1:size(B, 2)
Ctmp = C[i, j]
for k = 1:size(A, 2)
Ctmp = buffered_operate!(mul_buffer, add_mul, Ctmp, A[i, k], B[k, j])
Ctmp = buffered_operate!(buffer, add_mul, Ctmp, A[i, k], B[k, j])
end
C[i, j] = Ctmp
end
end # @inbounds
return C
end

function mutable_operate!(
function mutable_buffered_operate!(
buffer,
::typeof(add_mul),
C::VecOrMat,
A::AbstractMatrix,
B::AbstractVecOrMat,
)
_dim_check(C, A, B)
_add_mul_array(C, A, B)
_add_mul_array(buffer, C, A, B)
end

function buffer_for(
::typeof(add_mul),
::Type{<:VecOrMat{S}},
::Type{<:AbstractMatrix{T}},
::Type{<:AbstractVecOrMat{U}},
) where {S,T,U}
return buffer_for(add_mul, S, T, U)
end
function mutable_operate!(
::typeof(add_mul),
C::VecOrMat,
A::AbstractMatrix,
B::AbstractVecOrMat,
)
buffer = buffer_for(add_mul, typeof(C), typeof(A), typeof(B))
return mutable_buffered_operate!(buffer, add_mul, C, A, B)
end

function mutable_operate!(::typeof(zero), C::Union{Vector,Matrix})
Expand Down
16 changes: 16 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ end
"Cannot call `mutable_operate!(+, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
)
@test_throws err MA.mutable_operate!(+, 0, 0)
err = ArgumentError(
"Cannot call `mutable_buffered_operate_to!(::$Int, ::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
)
@test_throws err MA.mutable_buffered_operate_to!(0, 0, +, 0, 0)
err = ArgumentError(
"Cannot call `mutable_buffered_operate!(::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.",
)
@test_throws err MA.mutable_buffered_operate!(0, +, 0, 0)
x = DummyMutable()
err = ErrorException(
"`mutable_operate_to!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented yet.",
Expand All @@ -29,4 +37,12 @@ end
"`mutable_operate!(+, ::DummyMutable, ::DummyMutable)` is not implemented yet.",
)
@test_throws err MA.mutable_operate!(+, x, x)
err = ErrorException(
"`mutable_buffered_operate_to!(::DummyMutable, ::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.",
)
@test_throws err MA.mutable_buffered_operate_to!(x, x, +, x, x)
err = ErrorException(
"`mutable_buffered_operate!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.",
)
@test_throws err MA.mutable_buffered_operate!(x, +, x, x)
end
3 changes: 3 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ end
alloc_test(() -> MA.operate_fallback!(MA.IsMutable(), MA.add_mul, y, A, x), n)
alloc_test(() -> MA.operate!(MA.add_mul, y, A, x), n)
alloc_test(() -> MA.mutable_operate!(MA.add_mul, y, A, x), n)
# Apparently, all allocations were on creating the buffer since this is allocation free:
buffer = MA.buffer_for(MA.add_mul, typeof(y), typeof(A), typeof(x))
alloc_test(() -> MA.mutable_buffered_operate!(buffer, MA.add_mul, y, A, x), 0)
end
@testset "matrix-matrix product" begin
A = [1 2 3; 4 5 6; 6 8 9]
Expand Down