diff --git a/Project.toml b/Project.toml index 2c6d02d..90689e7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.7.7" +version = "0.7.8" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -14,8 +14,15 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" + +[extensions] +SparseArraysBaseTensorAlgebraExt = ["TensorAlgebra", "SparseArrays"] + [compat] Accessors = "0.1.41" Adapt = "4.3.0" @@ -29,11 +36,16 @@ LinearAlgebra = "1.10" MapBroadcast = "0.1.5" Random = "1.10.0" SafeTestsets = "0.1" +SparseArrays = "1.10" Suppressor = "0.2" +TensorAlgebra = "0.6.2" Test = "1.10" TypeParameterAccessors = "0.4.3" julia = "1.10" +[workspace] +projects = ["benchmark", "dev", "docs", "examples", "test"] + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/docs/Project.toml b/docs/Project.toml index 700efd6..786b536 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" +[sources] +SparseArraysBase = {path = ".."} + [compat] Dictionaries = "0.4.4" Documenter = "1.8.1" diff --git a/examples/Project.toml b/examples/Project.toml index 76d70b2..9dbdee0 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -3,6 +3,9 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +SparseArraysBase = {path = ".."} + [compat] Dictionaries = "0.4.4" SparseArraysBase = "0.7.0" diff --git a/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl b/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl new file mode 100644 index 0000000..192fa85 --- /dev/null +++ b/ext/SparseArraysBaseTensorAlgebraExt/SparseArraysBaseTensorAlgebraExt.jl @@ -0,0 +1,28 @@ +module SparseArraysBaseTensorAlgebraExt + +using SparseArrays: SparseMatrixCSC +using SparseArraysBase: AnyAbstractSparseArray, AnyAbstractSparseMatrix, SparseArrayDOK +using TensorAlgebra: TensorAlgebra, BlockedTrivialPermutation, BlockedTuple, FusionStyle, + ReshapeFusion, matricize, unmatricize + +struct SparseArrayFusion <: FusionStyle end +TensorAlgebra.FusionStyle(::Type{<:AnyAbstractSparseArray}) = SparseArrayFusion() + +function TensorAlgebra.matricize( + style::SparseArrayFusion, a::AbstractArray, length_codomain::Val + ) + m = matricize(ReshapeFusion(), a, length_codomain) + return convert(SparseMatrixCSC, m) +end +function TensorAlgebra.unmatricize( + style::SparseArrayFusion, + m::AbstractMatrix, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + a = unmatricize(ReshapeFusion(), m, axes_codomain, axes_domain) + # TODO: Use `similar_type(m)` instead of hardcoding to `SparseArrayDOK`? + return convert(SparseArrayDOK, a) +end + +end diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index 5a09d24..9f75a13 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -25,5 +25,6 @@ include("wrappers.jl") include("abstractsparsearray.jl") include("sparsearraydok.jl") include("oneelementarray.jl") +include("sparsearrays.jl") end diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index b6d3ae5..3f13ccf 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -2,6 +2,8 @@ using Dictionaries: AbstractDictionary abstract type AbstractSparseArray{T, N} <: AbstractArray{T, N} end +Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a) + using DerivableInterfaces: @array_aliases # Define AbstractSparseVector, AnyAbstractSparseArray, etc. @array_aliases AbstractSparseArray diff --git a/src/sparsearrays.jl b/src/sparsearrays.jl new file mode 100644 index 0000000..ec03616 --- /dev/null +++ b/src/sparsearrays.jl @@ -0,0 +1,36 @@ +using SparseArrays: SparseArrays, AbstractSparseMatrixCSC, SparseMatrixCSC, findnz + +function eachstoredindex(m::AbstractSparseMatrixCSC) + I, J, V = findnz(m) + # TODO: This loses the compile time element type, is there a better lazy way? + return Iterators.map(CartesianIndex, zip(I, J)) +end +function eachstoredindex(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC}) + return @interface SparseArrayInterface() eachstoredindex(a) +end + +function SparseArrays.SparseMatrixCSC{Tv, Ti}(m::AnyAbstractSparseMatrix) where {Tv, Ti} + m′ = SparseMatrixCSC{Tv, Ti}(undef, size(m)) + for I in eachstoredindex(m) + m′[I] = m[I] + end + return m′ +end + +function SparseArrayDOK(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC}) + return SparseArrayDOK{eltype(a), ndims(a)}(a) +end +function SparseArrayDOK{T}( + a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC} + ) where {T} + return SparseArrayDOK{T, ndims(a)}(a) +end +function SparseArrayDOK{T, N}( + a::Base.ReshapedArray{<:Any, N, <:AbstractSparseMatrixCSC} + ) where {T, N} + a′ = SparseArrayDOK{T, N}(undef, size(a)) + for I in eachstoredindex(a) + a′[I] = a[I] + end + return a′ +end diff --git a/test/Project.toml b/test/Project.toml index 8b99b36..4ac9c6c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,11 +8,16 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +SparseArraysBase = {path = ".."} + [compat] Adapt = "4.2.0" Aqua = "0.8.11" @@ -23,7 +28,9 @@ JLArrays = "0.2.0, 0.3" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" SafeTestsets = "0.1.0" +SparseArrays = "1.10" SparseArraysBase = "0.7.0" StableRNGs = "1.0.2" Suppressor = "0.2.8" +TensorAlgebra = "0.6" Test = "<0.0.1, 1" diff --git a/test/test_dense.jl b/test/test_dense.jl index 3acc147..3542e2d 100644 --- a/test/test_dense.jl +++ b/test/test_dense.jl @@ -34,7 +34,6 @@ arrayts = (Array, JLArray) dev(elt[2, 4]), Dict([CartesianIndex(1, 2) => 1, CartesianIndex(3, 4) => 2]), (3, 4) ) d = dense(s) - @show typeof(d) @test d isa arrayt{elt, 2} @test d == dev(elt[0 2 0 0; 0 0 0 0; 0 0 0 4]) end diff --git a/test/test_tensoralgebraext.jl b/test/test_tensoralgebraext.jl new file mode 100644 index 0000000..8921e31 --- /dev/null +++ b/test/test_tensoralgebraext.jl @@ -0,0 +1,31 @@ +using SparseArrays: SparseMatrixCSC, findnz, nnz +using SparseArraysBase: SparseMatrixDOK, eachstoredindex, isstored, sparsezeros, + storedlength +using TensorAlgebra: contract, matricize +using Test: @testset, @test + +@testset "TensorAlgebraExt (eltype = $elt)" for elt in (Float32, ComplexF64) + a = sparsezeros(elt, (2, 2, 2)) + a[1, 1, 1] = 1 + a[2, 1, 2] = 2 + + # matricize + m = matricize(a, (1, 3), (2,)) + @test m isa SparseMatrixCSC{elt} + @test nnz(m) == 2 + @test isstored(m, 1, 1) + @test m[1, 1] ≡ elt(1) + @test isstored(m, 4, 1) + @test m[4, 1] ≡ elt(2) + @test issetequal(eachstoredindex(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)]) + for I in setdiff(CartesianIndices(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)]) + @test m[I] ≡ zero(elt) + end + + # contract + b, l = contract(a, ("i", "j", "k"), a, ("j", "k", "l")) + @test b isa SparseMatrixDOK{elt} + @test storedlength(b) == 1 + @test only(eachstoredindex(b)) == CartesianIndex(1, 1) + @test b[1, 1] ≡ elt(1) +end