diff --git a/src/interface.jl b/src/interface.jl index fa73e97d..3c91e876 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -260,6 +260,46 @@ end buffer_for(::Function, args::Vararg{Type,N}) where {N} = nothing +function mutable_buffered_operate_to_fallback(::NotMutable, buffer, output, op::Function, args...) + throw( + ArgumentError( + "Cannot call `mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(output))` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.", + ), + ) +end + +function mutable_buffered_operate_to_fallback(::IsMutable, buffer, output, op::Function, args...) + error( + "`mutable_buffered_operate_to!(::$(typeof(buffer)), ::$(typeof(output)), $op, ::", + join(typeof.(args), ", ::"), + ")` is not implemented.", + ) +end + +function mutable_buffered_operate_to_fallback( + buffer, + output, + op::Function, + args::Vararg{Any,N}, +) where {N} + return mutable_buffered_operate_to_fallback( + mutability(output, op, args...), + buffer, + output, + op, + args... + ) +end + +function mutable_buffered_operate_to_fallback( + ::Nothing, + output, + op::Function, + args::Vararg{Any,N}, +) where {N} + return mutable_operate_to!(output, op, args...) +end + """ mutable_buffered_operate_to!(buffer, output, op::Function, args...) @@ -268,12 +308,49 @@ possibly modifying `buffer`. Can only be called if `mutability(output, op, args...)` returns `true`. """ function mutable_buffered_operate_to!( - ::Nothing, + buffer, output, op::Function, args::Vararg{Any,N}, ) where {N} - return mutable_operate_to!(output, op, args...) + return mutable_buffered_operate_to_fallback(buffer, output, op, args...) +end + +function mutable_buffered_operate_fallback(::NotMutable, buffer, op::Function, args...) + throw( + ArgumentError( + "Cannot call `mutable_buffered_operate!(::$(typeof(buffer)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(args[1]))` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.", + ), + ) +end + +function mutable_buffered_operate_fallback(::IsMutable, buffer, op::Function, args...) + error( + "`mutable_buffered_operate!(::$(typeof(buffer)), $op, ::", + join(typeof.(args), ", ::"), + ")` is not implemented.", + ) +end + +function mutable_buffered_operate_fallback( + buffer, + op::Function, + args::Vararg{Any,N}, +) where {N} + return mutable_buffered_operate_fallback( + mutability(args[1], op, args...), + buffer, + op, + args... + ) +end + +function mutable_buffered_operate_fallback( + ::Nothing, + op::Function, + args::Vararg{Any,N}, +) where {N} + return mutable_operate!(op, args...) end """ @@ -284,8 +361,8 @@ possibly modifying `buffer`. Can only be called if `mutability(args[1], op, args...)` returns `true`. """ function mutable_buffered_operate! end -function mutable_buffered_operate!(::Nothing, op::Function, args::Vararg{Any,N}) where {N} - return mutable_operate!(op, args...) +function mutable_buffered_operate!(buffer, op::Function, args::Vararg{Any,N}) where {N} + return mutable_buffered_operate_fallback(buffer, op, args...) end """ diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index dc2db067..4e8652e5 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -257,16 +257,15 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) end end -function _add_mul_array(C::Vector, A::AbstractMatrix, B::AbstractVector) +function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector) Astride = size(A, 1) # We need a buffer to hold the intermediate multiplication. - mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B)) @inbounds begin for k in eachindex(B) aoffs = (k - 1) * Astride b = B[k] for i in Base.OneTo(size(A, 1)) - C[i] = buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs+i], b) + C[i] = buffered_operate!(buffer, add_mul, C[i], A[aoffs+i], b) end end end # @inbounds @@ -275,13 +274,12 @@ end # This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the # same diagonal element. -function _add_mul_array(C::Matrix, A::AbstractMatrix, B::AbstractMatrix) - mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B)) +function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix) @inbounds begin for i = 1:size(A, 1), j = 1:size(B, 2) Ctmp = C[i, j] for k = 1:size(A, 2) - Ctmp = buffered_operate!(mul_buffer, add_mul, Ctmp, A[i, k], B[k, j]) + Ctmp = buffered_operate!(buffer, add_mul, Ctmp, A[i, k], B[k, j]) end C[i, j] = Ctmp end @@ -289,14 +287,33 @@ function _add_mul_array(C::Matrix, A::AbstractMatrix, B::AbstractMatrix) return C end -function mutable_operate!( +function mutable_buffered_operate!( + buffer, ::typeof(add_mul), C::VecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ) _dim_check(C, A, B) - _add_mul_array(C, A, B) + _add_mul_array(buffer, C, A, B) +end + +function buffer_for( + ::typeof(add_mul), + ::Type{<:VecOrMat{S}}, + ::Type{<:AbstractMatrix{T}}, + ::Type{<:AbstractVecOrMat{U}}, +) where {S,T,U} + return buffer_for(add_mul, S, T, U) +end +function mutable_operate!( + ::typeof(add_mul), + C::VecOrMat, + A::AbstractMatrix, + B::AbstractVecOrMat, +) + buffer = buffer_for(add_mul, typeof(C), typeof(A), typeof(B)) + return mutable_buffered_operate!(buffer, add_mul, C, A, B) end function mutable_operate!(::typeof(zero), C::Union{Vector,Matrix}) diff --git a/test/interface.jl b/test/interface.jl index 35309965..055bcebf 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -20,6 +20,14 @@ end "Cannot call `mutable_operate!(+, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.", ) @test_throws err MA.mutable_operate!(+, 0, 0) + err = ArgumentError( + "Cannot call `mutable_buffered_operate_to!(::$Int, ::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.", + ) + @test_throws err MA.mutable_buffered_operate_to!(0, 0, +, 0, 0) + err = ArgumentError( + "Cannot call `mutable_buffered_operate!(::$Int, +, ::$Int, ::$Int)` as objects of type `$Int` cannot be modifed to equal the result of the operation. Use `buffered_operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified.", + ) + @test_throws err MA.mutable_buffered_operate!(0, +, 0, 0) x = DummyMutable() err = ErrorException( "`mutable_operate_to!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented yet.", @@ -29,4 +37,12 @@ end "`mutable_operate!(+, ::DummyMutable, ::DummyMutable)` is not implemented yet.", ) @test_throws err MA.mutable_operate!(+, x, x) + err = ErrorException( + "`mutable_buffered_operate_to!(::DummyMutable, ::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.", + ) + @test_throws err MA.mutable_buffered_operate_to!(x, x, +, x, x) + err = ErrorException( + "`mutable_buffered_operate!(::DummyMutable, +, ::DummyMutable, ::DummyMutable)` is not implemented.", + ) + @test_throws err MA.mutable_buffered_operate!(x, +, x, x) end diff --git a/test/matmul.jl b/test/matmul.jl index 57f5fedd..ddd22f82 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -131,6 +131,9 @@ end alloc_test(() -> MA.operate_fallback!(MA.IsMutable(), MA.add_mul, y, A, x), n) alloc_test(() -> MA.operate!(MA.add_mul, y, A, x), n) alloc_test(() -> MA.mutable_operate!(MA.add_mul, y, A, x), n) + # Apparently, all allocations were on creating the buffer since this is allocation free: + buffer = MA.buffer_for(MA.add_mul, typeof(y), typeof(A), typeof(x)) + alloc_test(() -> MA.mutable_buffered_operate!(buffer, MA.add_mul, y, A, x), 0) end @testset "matrix-matrix product" begin A = [1 2 3; 4 5 6; 6 8 9]