From cb15387a16785c21721f7f767b30f5d60991f41c Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Thu, 13 Feb 2025 21:16:00 +0100 Subject: [PATCH 1/5] Update Dependencies, fix tests and add Aqua.jl tests --- .github/workflows/ci.yml | 2 +- Project.toml | 23 ++++-- README.md | 1 + src/StructuredOptimization.jl | 5 ++ src/solvers/build_solve.jl | 50 ++++++------- src/syntax/expressions/addition.jl | 113 +++++++++++++---------------- src/syntax/variable.jl | 10 +-- test/runtests.jl | 52 ++++++++----- test/test_build_minimize.jl | 21 +++--- 9 files changed, 146 insertions(+), 131 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 76f776a..c483152 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: matrix: version: - '1' - - '1.6' + - '1.10' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index 73a52e2..9bf0b22 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "StructuredOptimization" uuid = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" -version = "0.4.0" +version = "0.5.0" [deps] AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" @@ -12,18 +13,24 @@ ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" [compat] -AbstractOperators = "0.3" -DSP = "0.5.1 - 0.7" +AbstractOperators = "0.4" +Aqua = "0.8" +DSP = "0.5.1 - 0.8" +DifferentiationInterface = "0.6" FFTW = "1" -ProximalAlgorithms = "0.5" -ProximalOperators = "0.15" -RecursiveArrayTools = "1 - 2" -julia = "1.4" +LinearAlgebra = "1" +ProximalAlgorithms = "0.7" +ProximalOperators = "0.16" +Random = "1" +RecursiveArrayTools = "1 - 3" +Test = "1" +julia = "1.10" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "Test", "Random"] +test = ["LinearAlgebra", "Test", "Random", "Aqua"] diff --git a/README.md b/README.md index 6ec97e5..f69ea04 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Build status](https://github.com/JuliaFirstOrder/StructuredOptimization.jl/workflows/CI/badge.svg)](https://github.com/JuliaFirstOrder/StructuredOptimization.jl/actions?query=workflow%3ACI) [![codecov](https://codecov.io/gh/JuliaFirstOrder/StructuredOptimization.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaFirstOrder/StructuredOptimization.jl) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliafirstorder.github.io/StructuredOptimization.jl/stable) [![](https://img.shields.io/badge/docs-latest-blue.svg)](https://juliafirstorder.github.io/StructuredOptimization.jl/latest) diff --git a/src/StructuredOptimization.jl b/src/StructuredOptimization.jl index e989dad..2b24d17 100644 --- a/src/StructuredOptimization.jl +++ b/src/StructuredOptimization.jl @@ -9,6 +9,11 @@ using ProximalAlgorithms import ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus export ZeroFPR, PANOC, PANOCplus +ProximalAlgorithms.value_and_gradient(f, x) = begin + y, fy = gradient(f, x) + return fy, y +end + include("syntax/syntax.jl") include("calculus/precomposeNonlinear.jl") # TODO move to ProximalOperators? include("arraypartition.jl") # TODO move to ProximalOperators? diff --git a/src/solvers/build_solve.jl b/src/solvers/build_solve.jl index b360902..99aa8b0 100644 --- a/src/solvers/build_solve.jl +++ b/src/solvers/build_solve.jl @@ -1,7 +1,5 @@ -export build - """ - parse_problem(terms::Tuple, solver::ForwardBackwardSolver) + parse_problem(terms::Tuple, solver::ForwardBackwardSolver) Takes as input a tuple containing the terms defining the problem and the solver. @@ -22,26 +20,26 @@ julia> StructuredOptimization.parse_problem(p, PANOCplus()); ``` """ function parse_problem(terms::Tuple, solver::T) where T <: ForwardBackwardSolver - x = extract_variables(terms) - # Separate smooth and nonsmooth - smooth, nonsmooth = split_smooth(terms) - if is_proximable(nonsmooth) - g = extract_proximable(x, nonsmooth) + x = extract_variables(terms) + # Separate smooth and nonsmooth + smooth, nonsmooth = split_smooth(terms) + if is_proximable(nonsmooth) + g = extract_proximable(x, nonsmooth) kwargs = Dict{Symbol, Any}(:g => g) - if !isempty(smooth) - if is_linear(smooth) - f = extract_functions(smooth) - A = extract_operators(x, smooth) - kwargs[:A] = A - else # ?? - f = extract_functions_nodisp(smooth) - A = extract_affines(x, smooth) - f = PrecomposeNonlinear(f, A) - end - kwargs[:f] = f - end - return (x, kwargs) - end + if !isempty(smooth) + if is_linear(smooth) + f = extract_functions(smooth) + A = extract_operators(x, smooth) + kwargs[:A] = A + else # ?? + f = extract_functions_nodisp(smooth) + A = extract_affines(x, smooth) + f = PrecomposeNonlinear(f, A) + end + kwargs[:f] = f + end + return (x, kwargs) + end error("Sorry, I cannot parse this problem for solver of type $(T)") end @@ -49,7 +47,7 @@ end export solve """ - solve(terms::Tuple, solver::ForwardBackwardSolver) + solve(terms::Tuple, solver::ForwardBackwardSolver) Takes as input a tuple containing the terms defining the problem and the solver options. @@ -71,8 +69,8 @@ julia> ~x ``` """ function solve(terms::Tuple, solver::ForwardBackwardSolver) - x, kwargs = parse_problem(terms, solver) + x, kwargs = parse_problem(terms, solver) x_star, it = solver(; x0 = ~x, kwargs...) - ~x .= x_star - return x, it + ~x .= x_star + return x, it end diff --git a/src/syntax/expressions/addition.jl b/src/syntax/expressions/addition.jl index bb700a0..aee3125 100644 --- a/src/syntax/expressions/addition.jl +++ b/src/syntax/expressions/addition.jl @@ -1,7 +1,7 @@ import Base: +, - """ - +(ex1::AbstractExpression, ex2::AbstractExpression) + +(ex1::AbstractExpression, ex2::AbstractExpression) Add two expressions. @@ -47,112 +47,97 @@ julia> ex3.+z function (+)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) - if variables(A) == variables(B) + if variables(A) == variables(B) return Expression{length(A.x)}(A.x,affine(A)+affine(B)) - else - opA = affine(A) - xA = variables(A) - opB = affine(B) - xB = variables(B) + else + opA = affine(A) + xA = variables(A) + opB = affine(B) + xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,true) return Expression{length(xNew)}(xNew,opNew) - end + end end # sum expressions function (-)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) - if variables(A) == variables(B) + if variables(A) == variables(B) return Expression{length(A.x)}(A.x,affine(A)-affine(B)) - else - opA = affine(A) - xA = variables(A) - opB = affine(B) - xB = variables(B) + else + opA = affine(A) + xA = variables(A) + opB = affine(B) + xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,false) return Expression{length(xNew)}(xNew,opNew) - end + end end #unsigned sum affines with single variables -function Usum_op(xA::Tuple{Variable}, - xB::Tuple{Variable}, - A::AbstractOperator, - B::AbstractOperator,sign::Bool) +function Usum_op(xA::Tuple{Variable}, xB::Tuple{Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool) xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - return xNew, opNew + return xNew, opNew end #unsigned sum: HCAT + AbstractOperator -function Usum_op(xA::NTuple{N,Variable}, - xB::Tuple{Variable}, - A::L1, - B::AbstractOperator,sign::Bool) where {N, M, L1<:HCAT{N}} - if xB[1] in xA +function Usum_op(xA::NTuple{N,Variable}, xB::Tuple{Variable}, A::HCAT{N}, B::AbstractOperator, sign::Bool) where {N} + if xB[1] in xA idx = findfirst(xA.==Ref(xB[1])) S = sign ? A[idx]+B : A[idx]-B - xNew = xA + xNew = xA opNew = hcat(A[1:idx-1],S,A[idx+1:N] ) - else + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end - return xNew, opNew + end + return xNew, opNew end #unsigned sum: AbstractOperator+HCAT -function Usum_op(xA::Tuple{Variable}, - xB::NTuple{N,Variable}, - A::AbstractOperator, - B::L2,sign::Bool) where {N, M, L2<:HCAT{N}} - if xA[1] in xB +function Usum_op(xA::Tuple{Variable}, xB::NTuple{N,Variable}, A::AbstractOperator, B::HCAT{N}, sign::Bool) where {N} + if xA[1] in xB idx = findfirst(xA.==Ref(xB[1])) S = sign ? A+B[idx] : B[idx]-A - xNew = xB + xNew = xB opNew = sign ? hcat(B[1:idx-1],S,B[idx+1:N] ) : -hcat(B[1:idx-1],S,B[idx+1:N] ) - else + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end + end - return xNew, opNew + return xNew, opNew end #unsigned sum: HCAT+HCAT -function Usum_op(xA::NTuple{NA,Variable}, - xB::NTuple{NB,Variable}, - A::L1, - B::L2,sign::Bool) where {NA,NB,M, - L1<:HCAT{NB}, - L2<:HCAT{NB} } - xNew = xA - opNew = A - for i in eachindex(xB) - xNew, opNew = Usum_op(xNew, (xB[i],), opNew, B[i], sign) - end +function Usum_op(xA::NTuple{NA,Variable}, xB::NTuple{NB,Variable}, A::HCAT{NB}, B::HCAT{NB}, sign::Bool) where {NA,NB} + xNew = xA + opNew = A + for i in eachindex(xB) + xNew, opNew = Usum_op(xNew, (xB[i],), opNew, B[i], sign) + end return xNew,opNew end #unsigned sum: multivar AbstractOperator + AbstractOperator -function Usum_op(xA::NTuple{N,Variable}, - xB::Tuple{Variable}, - A::AbstractOperator, - B::AbstractOperator,sign::Bool) where {N} - if xB[1] in xA - Z = Zeros(A) #this will be an HCAT +function Usum_op( + xA::NTuple{N,Variable}, xB::Tuple{Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N} + if xB[1] in xA + Z = Zeros(A) #this will be an HCAT xNew, opNew = Usum_op(xA,xB,Z,B,sign) - opNew += A - else + opNew += A + else xNew = (xA...,xB...) opNew = sign ? hcat(A,B) : hcat(A,-B) - end - return xNew, opNew + end + return xNew, opNew end """ - +(ex::AbstractExpression, b::Union{AbstractArray,Number}) + +(ex::AbstractExpression, b::Union{AbstractArray,Number}) Add a scalar or an `Array` to an expression: @@ -213,9 +198,9 @@ function Broadcast.broadcasted(::typeof(+),a::AbstractExpression, b::AbstractExp elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) A = Expression{length(A.x)}(variables(A), BroadCast(affine(A),size(affine(B),1))) - end + end return A+B - end + end return A+B end @@ -229,8 +214,8 @@ function Broadcast.broadcasted(::typeof(-),a::AbstractExpression, b::AbstractExp elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) A = Expression{length(A.x)}(variables(A), BroadCast(affine(A),size(affine(B),1))) - end + end return A-B - end + end return A-B end diff --git a/src/syntax/variable.jl b/src/syntax/variable.jl index d5ede3f..c3416c7 100644 --- a/src/syntax/variable.jl +++ b/src/syntax/variable.jl @@ -16,11 +16,12 @@ Returns a `Variable` of dimension `dims` initialized with an array of all zeros. Returns a `Variable` of dimension `size(x)` initialized with `x` """ -function Variable(T::Type, args::Vararg{I,N}) where {I <: Integer,N} - Variable{T,N,Array{T,N}}(zeros(T, args...)) +function Variable(T::Type, args::Int...) + N = length(args) + Variable{T,N,Array{T,N}}(zeros(T, args...)) end -function Variable(args::Vararg{I}) where {I <: Integer} +function Variable(args::Int...) Variable(zeros(args...)) end @@ -30,7 +31,6 @@ function Base.show(io::IO, x::Variable) print(io, "Variable($(eltype(x.x)), $(size(x.x)))") end - """ ~(x::Variable) @@ -46,7 +46,7 @@ size(x::Variable, [dim...]) Like `size(A::AbstractArray, [dims...])` returns the tuple containing the dimensions of the variable `x`. """ size(x::Variable) = size(x.x) -size(x::Variable, dim::I) where { I <: Integer} = size(x.x, dim) +size(x::Variable, dim::Integer) = size(x.x, dim) """ eltype(x::Variable) diff --git a/test/runtests.jl b/test/runtests.jl index b6731bd..a256eba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,30 +6,46 @@ using RecursiveArrayTools using LinearAlgebra, Random using DSP, FFTW using Test +using Aqua Random.seed!(0) @testset "StructuredOptimization" begin + @testset "Calculus" begin + include("test_proxstuff.jl") + end -@testset "Calculus" begin - include("test_proxstuff.jl") -end + @testset "Syntax" begin + include("test_variables.jl") + include("test_expressions.jl") + include("test_AbstractOp_binding.jl") + include("test_terms.jl") + end -@testset "Syntax" begin - include("test_variables.jl") - include("test_expressions.jl") - include("test_AbstractOp_binding.jl") - include("test_terms.jl") -end + @testset "Problem construction" begin + include("test_problem.jl") + include("test_build_minimize.jl") + end -@testset "Problem construction" begin - include("test_problem.jl") - include("test_build_minimize.jl") -end - -@testset "End-to-end tests" begin - include("test_usage_small.jl") - include("test_usage.jl") -end + @testset "End-to-end tests" begin + include("test_usage_small.jl") + include("test_usage.jl") + end + @testset "Aqua" begin + Aqua.test_all(StructuredOptimization; ambiguities=false, piracies=false) + Aqua.test_ambiguities( + StructuredOptimization; exclude=[Base.:(+), Base.:<=, Base.:>=], broken=true + ) + Aqua.test_piracies( + StructuredOptimization; + treat_as_own=[ + ProximalAlgorithms.value_and_gradient, + ProximalOperators.prox, + ProximalOperators.prox!, + ProximalOperators.gradient, + ProximalOperators.gradient!, + ], + ) + end end diff --git a/test/test_build_minimize.jl b/test/test_build_minimize.jl index 828e221..a4c2c18 100644 --- a/test/test_build_minimize.jl +++ b/test/test_build_minimize.jl @@ -42,15 +42,18 @@ xp = copy(~x) @test norm(xp-xpg) <= 1e-4 # test nonconvex Rosenbrock function with known minimum -solvers = [ZeroFPR(tol = 1e-6), PANOC(tol = 1e-6)] -for solver in solvers - x = Variable(1) - y = Variable(1) - a,b = 2.0, 100.0 +function test_solver(solver) + x = Variable(1) + y = Variable(1) + a, b = 2.0, 100.0 - cf = norm(x-a)^2+b*norm(pow(x,2)-y)^2 - @minimize cf+1e-10*norm(x,1)+1e-10*norm(y,1) with solver + cf = norm(x - a)^2 + b * norm(pow(x, 2) - y)^2 + @minimize cf + 1e-10 * norm(x, 1) + 1e-10 * norm(y, 1) with solver - @test norm(~x-[a]) < 1e-4 - @test norm(~y-[a^2]) < 1e-4 + @test norm(~x - [a]) < 1e-4 + @test norm(~y - [a^2]) < 1e-4 +end +solvers = [ZeroFPR(; tol=1e-6), PANOC(; tol=1e-6)] +for solver in solvers + test_solver(solver) end From dfd510128972da04ed3bb8f33eb315bcddd5044d Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 18 Mar 2025 16:24:41 +0100 Subject: [PATCH 2/5] minor adjustments --- src/solvers/terms_extract.jl | 17 ++++------------- src/syntax/terms/term.jl | 6 +++--- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/solvers/terms_extract.jl b/src/solvers/terms_extract.jl index 389dea6..ab57fd6 100644 --- a/src/solvers/terms_extract.jl +++ b/src/solvers/terms_extract.jl @@ -2,22 +2,13 @@ extract_variables(t::TermOrExpr) = variables(t) function extract_variables(t::NTuple{N,TermOrExpr}) where {N} - x = variables.(t) - xAll = x[1] - for i = 2:length(x) - for xi in x[i] - if (xi in xAll) == false - xAll = (xAll...,xi) - end - end - end - return xAll + return tuple(unique(variables.(t))...) end # extract functions from terms function extract_functions(t::Term) - f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, 1.0, displacement(t)) #for now I keep this - f = t.lambda == 1. ? f : Postcompose(f, t.lambda) #for now I keep this + f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, one(t.lambda), displacement(t)) #for now I keep this + f = t.lambda == 1 ? f : Postcompose(f, t.lambda) #for now I keep this #TODO change this return f end @@ -26,7 +17,7 @@ extract_functions(t::Tuple{Term}) = extract_functions(t[1]) # extract functions from terms without displacement function extract_functions_nodisp(t::Term) - f = t.lambda == 1. ? t.f : Postcompose(t.f, t.lambda) + f = t.lambda == 1 ? t.f : Postcompose(t.f, t.lambda) return f end extract_functions_nodisp(t::NTuple{N,Term}) where {N} = SeparableSum(extract_functions_nodisp.(t)) diff --git a/src/syntax/terms/term.jl b/src/syntax/terms/term.jl index 2e42973..0a9287f 100644 --- a/src/syntax/terms/term.jl +++ b/src/syntax/terms/term.jl @@ -7,7 +7,7 @@ end function Term(f, ex::AbstractExpression) A = convert(Expression,ex) - Term(1,f, A) + Term(one(real(codomainType(affine(A)))),f, A) end # Operations @@ -19,8 +19,8 @@ import Base: + (+)(a::Term,b::Term) = (a,b) (+)(a::NTuple{N,Term},b::Term) where {N} = (a...,b) (+)(a::Term,b::NTuple{N,Term}) where {N} = (a,b...) -(+)(a::NTuple{N,Term},b::Tuple{}) where {N} = a -(+)(a::Tuple{},b::NTuple{N,Term}) where {N} = b +(+)(a::NTuple{N,Term},::Tuple{}) where {N} = a +(+)(::Tuple{},b::NTuple{N,Term}) where {N} = b (+)(a::NTuple{N,Term},b::NTuple{M,Term}) where {N,M} = (a...,b...) # Define multiplication by constant From 94e4792e8f48ea983e4b1843d64b4b5082c0686a Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Wed, 7 May 2025 13:54:51 +0200 Subject: [PATCH 3/5] add support for all algorithms in ProximalAlgorithms --- Manifest.toml | 937 ++++++++++++++++++ Project.toml | 22 +- src/StructuredOptimization.jl | 19 +- src/arraypartition.jl | 36 - src/calculus/sqrNormL2WithNormalOp.jl | 88 ++ src/solvers/build_solve.jl | 114 ++- src/solvers/minimize.jl | 64 +- src/solvers/parse.jl | 442 +++++++++ src/solvers/solvers_options.jl | 5 - src/solvers/terms_extract.jl | 70 +- src/solvers/terms_properties.jl | 42 +- src/solvers/terms_splitting.jl | 31 - .../expressions/abstractOperator_bind.jl | 2 +- src/syntax/expressions/addition.jl | 38 +- .../expressions/addition_tricky_part.jl | 231 +++++ src/syntax/expressions/expression.jl | 13 +- src/syntax/expressions/multiplication.jl | 12 +- src/syntax/expressions/utils.jl | 2 +- src/syntax/problem.jl | 28 - src/syntax/syntax.jl | 8 - src/syntax/terms/proximalOperators_bind.jl | 63 +- src/syntax/terms/term.jl | 74 +- test/runtests.jl | 4 +- test/test_expressions.jl | 12 + test/test_terms.jl | 8 +- test/test_usage.jl | 3 +- 26 files changed, 2089 insertions(+), 279 deletions(-) create mode 100644 Manifest.toml delete mode 100644 src/arraypartition.jl create mode 100644 src/calculus/sqrNormL2WithNormalOp.jl create mode 100644 src/solvers/parse.jl delete mode 100644 src/solvers/solvers_options.jl delete mode 100644 src/solvers/terms_splitting.jl create mode 100644 src/syntax/expressions/addition_tricky_part.jl delete mode 100644 src/syntax/problem.jl delete mode 100644 src/syntax/syntax.jl diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 0000000..5b08d76 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,937 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.11.4" +manifest_format = "2.0" +project_hash = "d7d80843b7c63bcd8962a2e974300665e8f478dc" + +[[deps.ADTypes]] +git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32" +uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +version = "1.14.0" + + [deps.ADTypes.extensions] + ADTypesChainRulesCoreExt = "ChainRulesCore" + ADTypesConstructionBaseExt = "ConstructionBase" + ADTypesEnzymeCoreExt = "EnzymeCore" + + [deps.ADTypes.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + + [deps.AbstractFFTs.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.AbstractOperators]] +deps = ["DSP", "FFTW", "FastBroadcast", "LinearAlgebra", "OperatorCore", "RecursiveArrayTools"] +path = "../AbstractOperators" +uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +version = "0.4.0" + + [deps.AbstractOperators.extensions] + CudaExt = "CUDA" + NfftExt = "NFFT" + + [deps.AbstractOperators.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "MacroTools"] +git-tree-sha1 = "3b86719127f50670efe356bc11073d84b4ed7a5d" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.42" + + [deps.Accessors.extensions] + AxisKeysExt = "AxisKeys" + IntervalSetsExt = "IntervalSets" + LinearAlgebraExt = "LinearAlgebra" + StaticArraysExt = "StaticArrays" + StructArraysExt = "StructArrays" + TestExt = "Test" + UnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "f7817e2e585aa6d924fd714df1e2a84be7896c60" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.3.0" + + [deps.Adapt.extensions] + AdaptSparseArraysExt = "SparseArrays" + AdaptStaticArraysExt = "StaticArrays" + + [deps.Adapt.weakdeps] + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.2" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "017fcb757f8e921fb44ee063a7aafe5f89b86dd1" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.18.0" + + [deps.ArrayInterface.extensions] + ArrayInterfaceBandedMatricesExt = "BandedMatrices" + ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" + ArrayInterfaceCUDAExt = "CUDA" + ArrayInterfaceCUDSSExt = "CUDSS" + ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" + ArrayInterfaceChainRulesExt = "ChainRules" + ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceReverseDiffExt = "ReverseDiff" + ArrayInterfaceSparseArraysExt = "SparseArrays" + ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" + ArrayInterfaceTrackerExt = "Tracker" + + [deps.ArrayInterface.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" + +[[deps.BenchmarkTools]] +deps = ["Compat", "JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "e38fbc49a620f5d0b660d7f543db1009fe0f8336" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.6.0" + +[[deps.Bessels]] +git-tree-sha1 = "4435559dc39793d53a9e3d278e185e920b4619ef" +uuid = "0e736298-9ec6-45e8-9647-e4fc86a2fe38" +version = "0.2.8" + +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "f21cfd4950cb9f0587d5067e69405ad2acd27b87" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.6" + +[[deps.Bzip2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1b96ea4a01afe0ea4090c5c8039690672dd13f2e" +uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" +version = "1.0.9+0" + +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] +git-tree-sha1 = "5a97e67919535d6841172016c9530fd69494e5ec" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.2.6" + +[[deps.CloseOpenIntervals]] +deps = ["Static", "StaticArrayInterface"] +git-tree-sha1 = "05ba0d07cd4fd8b7a39541e31a7b0254704ea581" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.13" + +[[deps.CodecBzip2]] +deps = ["Bzip2_jll", "TranscodingStreams"] +git-tree-sha1 = "84990fa864b7f2b4901901ca12736e45ee79068c" +uuid = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" +version = "0.8.5" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "962834c22b66e32aa10f7611c08c8ca4e20749a9" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.8" + +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.1" + +[[deps.CommonWorldInvalidations]] +git-tree-sha1 = "ae52d1c52048455e85a387fbee9be553ec2b68d0" +uuid = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" +version = "1.0.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConstructionBase]] +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.8" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + +[[deps.DSP]] +deps = ["Bessels", "FFTW", "IterTools", "LinearAlgebra", "Polynomials", "Random", "Reexport", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "489db9d78b53e44fb753d225c58832632d74ab10" +uuid = "717857b8-e6f2-59f4-9121-6e50c889abd2" +version = "0.8.0" + + [deps.DSP.extensions] + OffsetArraysExt = "OffsetArrays" + + [deps.DSP.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.DifferentiationInterface]] +deps = ["ADTypes", "LinearAlgebra"] +git-tree-sha1 = "d86f29074367f1bb92957e8d0b77badd187a97bc" +uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +version = "0.6.32" + + [deps.DifferentiationInterface.extensions] + DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" + DifferentiationInterfaceDiffractorExt = "Diffractor" + DifferentiationInterfaceEnzymeExt = ["EnzymeCore", "Enzyme"] + DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" + DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" + DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" + DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] + DifferentiationInterfaceMooncakeExt = "Mooncake" + DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" + DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] + DifferentiationInterfaceSparseArraysExt = "SparseArrays" + DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" + DifferentiationInterfaceStaticArraysExt = "StaticArrays" + DifferentiationInterfaceSymbolicsExt = "Symbolics" + DifferentiationInterfaceTrackerExt = "Tracker" + DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] + + [deps.DifferentiationInterface.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" + Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" + FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" + FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" + PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] +git-tree-sha1 = "7de7c78d681078f027389e067864a8d53bd7c3c9" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.8.1" + +[[deps.FFTW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "4d81ed14783ec49ce9f2e168208a12ce1815aa25" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.10+3" + +[[deps.FastBroadcast]] +deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] +git-tree-sha1 = "ab1b34570bcdf272899062e1a56285a53ecaae08" +uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +version = "0.3.5" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "a2df1b776752e3f344e5116c06d75a10436ab853" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.38" + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + + [deps.ForwardDiff.weakdeps] + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "83cf05ab16a73219e5f6bd1bdfa9848fa24ac627" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.2.0" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.IntelOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"] +git-tree-sha1 = "0f14a5456bdc6b9731a5682f439a672750a09e48" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2025.0.4+0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" + +[[deps.InverseFunctions]] +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.17" +weakdeps = ["Dates", "Test"] + + [deps.InverseFunctions.extensions] + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "e2222959fbc6c19554dc15174c81bf7bf3aa691c" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.4" + +[[deps.IterTools]] +git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" +uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +version = "1.10.0" + +[[deps.IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "59545b0a2b27208b0650df0a46b8e3019f85055b" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.4" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.7.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.1" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "a9eaadb366f5493a5654e843864c13d8b107548c" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.17" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.6.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.7.2+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.29" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" + +[[deps.MKL_jll]] +deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "oneTBB_jll"] +git-tree-sha1 = "5de60bc6cb3899cd318d80d627560fae2e2d99ae" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2025.0.1+1" + +[[deps.MacroTools]] +git-tree-sha1 = "72aebe0b5051e5143a079a4685a46da330a40472" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.15" + +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" + +[[deps.MathOptInterface]] +deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON3", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test"] +git-tree-sha1 = "6723502b2135aa492a65be9633e694482a340ee7" +uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +version = "1.38.0" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.6+0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.12.12" + +[[deps.MutableArithmetics]] +deps = ["LinearAlgebra", "SparseArrays", "Test"] +git-tree-sha1 = "491bdcdc943fcbc4c005900d7463c9f216aabf4c" +uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" +version = "1.6.4" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "cc0a5deefdb12ab3a096f00a6d42133af4560d71" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.1.2" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OSQP]] +deps = ["Libdl", "LinearAlgebra", "MathOptInterface", "OSQP_jll", "SparseArrays"] +git-tree-sha1 = "50faf456a64ac1ca097b78bcdf288d94708adcdd" +uuid = "ab2f91bb-94b4-55e3-9ba0-7f65df51de79" +version = "0.8.1" + +[[deps.OSQP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "d0f73698c33e04e557980a06d75c2d82e3f0eb49" +uuid = "9c4f68bf-6205-5545-a508-2878b064d984" +version = "0.600.200+0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.27+1" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+4" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "1346c9208249809840c91b26703912dff463d335" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.6+0" + +[[deps.OperatorCore]] +path = "../OperatorCore" +uuid = "3945cd23-d97e-4db0-9df2-35342dbd287d" +version = "0.1.0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "cc4054e898b852042d7b503313f7ad03de99c3dd" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.8.0" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.11.0" + + [deps.Pkg.extensions] + REPLExt = "REPL" + + [deps.Pkg.weakdeps] + REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Polyester]] +deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] +git-tree-sha1 = "6d38fea02d983051776a856b7df75b30cf9a3c1f" +uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" +version = "0.7.16" + +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "645bed98cd47f72f67316fd42fc47dee771aefcd" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.2.2" + +[[deps.Polynomials]] +deps = ["LinearAlgebra", "OrderedCollections", "RecipesBase", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "555c272d20fc80a2658587fb9bbda60067b93b7c" +uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +version = "4.0.19" + + [deps.Polynomials.extensions] + PolynomialsChainRulesCoreExt = "ChainRulesCore" + PolynomialsFFTWExt = "FFTW" + PolynomialsMakieCoreExt = "MakieCore" + PolynomialsMutableArithmeticsExt = "MutableArithmetics" + + [deps.Polynomials.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" + MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" + +[[deps.Profile]] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +version = "1.11.0" + +[[deps.ProximalAlgorithms]] +deps = ["ADTypes", "DifferentiationInterface", "LinearAlgebra", "OperatorCore", "Printf", "ProximalCore"] +path = "../ProximalAlgorithms.jl" +uuid = "140ffc9f-1907-541a-a177-7475e0a401e9" +version = "0.8.0" + +[[deps.ProximalCore]] +deps = ["LinearAlgebra"] +path = "../ProximalCore.jl" +uuid = "dc4f5ac2-75d1-4f31-931e-60435d74994b" +version = "0.2.0" + +[[deps.ProximalOperators]] +deps = ["IterativeSolvers", "LinearAlgebra", "OSQP", "ProximalCore", "SparseArrays", "SuiteSparse", "TSVD"] +path = "../ProximalOperators.jl" +uuid = "a725b495-10eb-56fe-b38b-717eba820537" +version = "0.17.0" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "32f824db4e5bab64e25a12b22483a30a6b813d08" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "3.27.4" + + [deps.RecursiveArrayTools.extensions] + RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" + RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsMeasurementsExt = "Measurements" + RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] + RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] + RecursiveArrayToolsStructArraysExt = "StructArrays" + RecursiveArrayToolsTrackerExt = "Tracker" + RecursiveArrayToolsZygoteExt = "Zygote" + + [deps.RecursiveArrayTools.weakdeps] + FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "62389eeff14780bfe55195b7204c0d8738436d64" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.1" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "04c968137612c4a5629fa531334bb81ad5680f00" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.13" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "c5391c6ace3bc430ca630251d02ea9687169ca68" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.2" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.11.0" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "64cca0c26b4f31ba18f13f6c12af7c85f478cfde" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.5.0" + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + + [deps.SpecialFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[[deps.Static]] +deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools"] +git-tree-sha1 = "f737d444cb0ad07e61b3c1bef8eb91203c321eff" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "1.2.0" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] +git-tree-sha1 = "96381d50f1ce85f2663584c8e886a6ca97e60554" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.8.0" + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + + [deps.StaticArrayInterface.weakdeps] + OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] + +[[deps.StrideArraysCore]] +deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] +git-tree-sha1 = "f35f6ab602df8413a50c4a25ca14de821e8605fb" +uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" +version = "0.5.7" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.11.0" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.7.0+0" + +[[deps.SymbolicIndexingInterface]] +deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] +git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" +uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +version = "0.3.38" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TSVD]] +deps = ["Adapt", "LinearAlgebra"] +git-tree-sha1 = "c39caef6bae501e5607a6caf68dd9ac6e8addbcb" +uuid = "9449cd9e-2762-5aa3-a617-5413e99d722e" +version = "0.4.4" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" + +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.2" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.3" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.11.0+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.59.0+0" + +[[deps.oneTBB_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d5a767a3bb77135a99e433afe0eb14cd7f6914c3" +uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e" +version = "2022.0.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/Project.toml b/Project.toml index 9bf0b22..978bcc9 100644 --- a/Project.toml +++ b/Project.toml @@ -4,26 +4,41 @@ version = "0.5.0" [deps] AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OperatorCore = "3945cd23-d97e-4db0-9df2-35342dbd287d" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +[sources] +AbstractOperators = {path = "../AbstractOperators"} +OperatorCore = {path = "../OperatorCore"} +ProximalAlgorithms = {path = "../ProximalAlgorithms.jl"} +ProximalCore = {path = "../ProximalCore.jl"} +ProximalOperators = {path = "../ProximalOperators.jl"} +WaveletOperators = {path = "../AbstractOperators/WaveletOperators"} + [compat] AbstractOperators = "0.4" Aqua = "0.8" +Combinatorics = "1.0.2" DSP = "0.5.1 - 0.8" DifferentiationInterface = "0.6" FFTW = "1" LinearAlgebra = "1" -ProximalAlgorithms = "0.7" -ProximalOperators = "0.16" +OperatorCore = "0.1" +ProximalAlgorithms = "0.8" +ProximalCore = "0.2" +ProximalOperators = "0.17" Random = "1" RecursiveArrayTools = "1 - 3" Test = "1" +WaveletOperators = "0.1" julia = "1.10" [extras] @@ -31,6 +46,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WaveletOperators = "f3582904-6f60-4bbd-985d-55eab799bc9d" [targets] -test = ["LinearAlgebra", "Test", "Random", "Aqua"] +test = ["Aqua", "LinearAlgebra", "Random", "Test", "WaveletOperators"] diff --git a/src/StructuredOptimization.jl b/src/StructuredOptimization.jl index 2b24d17..cc1082c 100644 --- a/src/StructuredOptimization.jl +++ b/src/StructuredOptimization.jl @@ -2,29 +2,34 @@ module StructuredOptimization using LinearAlgebra using RecursiveArrayTools +using ProximalCore using AbstractOperators using ProximalOperators using ProximalAlgorithms - -import ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus -export ZeroFPR, PANOC, PANOCplus +using Combinatorics: permutations, powerset +using OperatorCore ProximalAlgorithms.value_and_gradient(f, x) = begin y, fy = gradient(f, x) return fy, y end +abstract type AbstractExpression end + +include("syntax/variable.jl") +include("syntax/expressions/expression.jl") +include("syntax/terms/term.jl") + +const TermOrExpr = Union{Term,AbstractExpression} -include("syntax/syntax.jl") include("calculus/precomposeNonlinear.jl") # TODO move to ProximalOperators? -include("arraypartition.jl") # TODO move to ProximalOperators? +include("calculus/sqrNormL2WithNormalOp.jl") # problem parsing include("solvers/terms_extract.jl") include("solvers/terms_properties.jl") -include("solvers/terms_splitting.jl") +include("solvers/parse.jl") # solver calls -include("solvers/solvers_options.jl") include("solvers/build_solve.jl") include("solvers/minimize.jl") diff --git a/src/arraypartition.jl b/src/arraypartition.jl deleted file mode 100644 index 06eff5e..0000000 --- a/src/arraypartition.jl +++ /dev/null @@ -1,36 +0,0 @@ -import ProximalOperators -import RecursiveArrayTools - -@inline function ProximalOperators.prox( - h, - x::RecursiveArrayTools.ArrayPartition, - gamma... -) - # unwrap - y, fy = ProximalOperators.prox(h, x.x, gamma...) - # wrap - return RecursiveArrayTools.ArrayPartition(y), fy -end - -@inline function ProximalOperators.gradient( - h, - x::RecursiveArrayTools.ArrayPartition -) - # unwrap - grad, fx = ProximalOperators.gradient(h, x.x) - # wrap - return RecursiveArrayTools.ArrayPartition(grad), fx -end - -@inline ProximalOperators.prox!( - y::RecursiveArrayTools.ArrayPartition, - h, - x::RecursiveArrayTools.ArrayPartition, - gamma... -) = ProximalOperators.prox!(y.x, h, x.x, gamma...) - -@inline ProximalOperators.gradient!( - y::RecursiveArrayTools.ArrayPartition, - h, - x::RecursiveArrayTools.ArrayPartition -) = ProximalOperators.gradient!(y.x, h, x.x) diff --git a/src/calculus/sqrNormL2WithNormalOp.jl b/src/calculus/sqrNormL2WithNormalOp.jl new file mode 100644 index 0000000..f84ab7b --- /dev/null +++ b/src/calculus/sqrNormL2WithNormalOp.jl @@ -0,0 +1,88 @@ +# squared L2 norm (times a constant, or weighted) precomposed with an operator + +""" + SqrNormL2WithNormalOp(λ=1, L::LinearOperator) + +With a nonnegative scalar `λ`, return the squared Euclidean norm +```math +f(x) = \\tfrac{λ}{2}\\|L * x\\|^2. +``` +With a nonnegative array `λ`, return the weighted squared Euclidean norm +```math +f(x) = \\tfrac{1}{2}∑_i λ_i y_i^2 where y = L * x. +``` + +This is a special case of the more general `Precompose(SqrNormL2(), L, 1, 0)` operator, +where `L` is a linear operator, and only the gradient is needed, not the proximal operator. +The gradient of the precomposed squared norm is +```math +\nabla f(x) = Lᴴ * L * x, +``` +and in many cases, there is an optimized implementation of the normal operator `Lᴴ * L` +that makes the compution of the gradient much faster than the naive implementation. + +A notable drawback of this method is that gradient! does not return the +squared norm of `L * x`, but rather the squared norm of `Lᴴ * L * x` (i.e. the +squared norm of the gradient). Most algorithms, however, tolerate this +difference, and it is much faster to compute. +""" +struct SqrNormL2WithNormalOp{T,SC,L<:AbstractOperator} + A::L + AᴴA::L + lambda::T + function SqrNormL2WithNormalOp(A, lambda) + @assert A isa AbstractOperator + @assert is_linear(A) + if any(lambda .< 0) + error("coefficients in λ must be nonnegative") + else + AᴴA = AbstractOperators.get_normal_op(A) + new{typeof(lambda),all(lambda .> 0),typeof(A)}(A, AᴴA, lambda) + end + end +end + +is_convex(::Type{<:SqrNormL2WithNormalOp}) = true +is_smooth(::Type{<:SqrNormL2WithNormalOp}) = true +is_separable(::Type{<:SqrNormL2WithNormalOp}) = true +is_generalized_quadratic(::Type{<:SqrNormL2WithNormalOp}) = true +is_strongly_convex(::Type{SqrNormL2WithNormalOp{T,SC}}) where {T,SC} = SC + +SqrNormL2WithNormalOp(A) = SqrNormL2WithNormalOp(A, 1) + +function (f::SqrNormL2WithNormalOp{S})(x) where {S <: Real} + y = f.A * x + return f.lambda / real(eltype(y))(2) * norm(y)^2 +end + +function (f::SqrNormL2WithNormalOp{<:AbstractArray})(x) + y = f.A * x + R = real(eltype(y)) + sqnorm = R(0) + for k in eachindex(y) + sqnorm += f.lambda[k] * abs2(y[k]) + end + return sqnorm / R(2) +end + +function gradient!(y, f::SqrNormL2WithNormalOp{<:Real}, x) + R = real(eltype(y)) + mul!(y, f.AᴴA, x) + sqnx = R(0) + for k in eachindex(y) + y[k] *= f.lambda + sqnx += abs2(y[k]) + end + return f.lambda / R(2) * sqnx +end + +function gradient!(y, f::SqrNormL2WithNormalOp{<:AbstractArray}, x) + R = real(eltype(y)) + mul!(y, f.AᴴA, x) + sqnx = R(0) + for k in eachindex(y) + y[k] *= f.lambda[k] + sqnx += f.lambda[k] * abs2(y[k]) + end + return sqnx / R(2) +end diff --git a/src/solvers/build_solve.jl b/src/solvers/build_solve.jl index 99aa8b0..64aa56a 100644 --- a/src/solvers/build_solve.jl +++ b/src/solvers/build_solve.jl @@ -1,3 +1,5 @@ +const ForwardBackwardSolver = ProximalAlgorithms.IterativeAlgorithm + """ parse_problem(terms::Tuple, solver::ForwardBackwardSolver) @@ -19,30 +21,78 @@ julia> p = problem( ls(A*x - b ) , norm(x) <= 1 ); julia> StructuredOptimization.parse_problem(p, PANOCplus()); ``` """ -function parse_problem(terms::Tuple, solver::T) where T <: ForwardBackwardSolver - x = extract_variables(terms) - # Separate smooth and nonsmooth - smooth, nonsmooth = split_smooth(terms) - if is_proximable(nonsmooth) - g = extract_proximable(x, nonsmooth) - kwargs = Dict{Symbol, Any}(:g => g) - if !isempty(smooth) - if is_linear(smooth) - f = extract_functions(smooth) - A = extract_operators(x, smooth) - kwargs[:A] = A - else # ?? - f = extract_functions_nodisp(smooth) - A = extract_affines(x, smooth) - f = PrecomposeNonlinear(f, A) - end - kwargs[:f] = f - end - return (x, kwargs) - end - error("Sorry, I cannot parse this problem for solver of type $(T)") +function parse_problem(terms::NTuple{N,StructuredOptimization.Term}, algorithm::T, return_partial::Bool = false) where {N,T <: ForwardBackwardSolver} + assumptions = ProximalAlgorithms.get_assumptions(algorithm) + variables = StructuredOptimization.extract_variables(terms) + remaining_terms = terms + kwargs = Dict{Symbol, Any}() + for assumption in assumptions + for term_selection in reverse(collect(powerset(remaining_terms, 1))) + term_selection = tuple(term_selection...) + preparation_result = StructuredOptimization.prepare(term_selection, assumption, variables) + if preparation_result !== nothing + term_selection = collect(term_selection) + remaining_terms = setdiff(remaining_terms, term_selection) + push!(kwargs, preparation_result...) + break + end + end + if isempty(remaining_terms) + return algorithm, kwargs, variables + end + end + return return_partial ? (kwargs, remaining_terms) : nothing +end + +function print_diagnostics(terms::NTuple{N,StructuredOptimization.Term}, algorithm::T) where {N,T <: ForwardBackwardSolver} + kwargs, remaining_terms = parse_problem(terms, algorithm, true) + print("The algorithm $algorithm assumes problem of form: ") + show(ProximalAlgorithms.get_assumptions(algorithm)) + if !isempty(kwargs) + println("Successfully prepared the following terms:") + for (key, value) in kwargs + println(" - $key: $value") + end + end + println("The following terms could not be prepared:") + for term in remaining_terms + println(" - $term") + end end +function parse_problem(terms::NTuple{N,StructuredOptimization.Term}) where {N} + for algorithm in ProximalAlgorithms.get_algorithms() + result = parse_problem(terms, algorithm) + if result !== nothing + return result + end + end + return nothing +end + +function suggest_algorithm(terms::NTuple{N,StructuredOptimization.Term}) where {N} + suitable_algs = [] + for algorithm in ProximalAlgorithms.get_algorithms() + result = parse_problem(terms, algorithm) + if result !== nothing + push!(suitable_algs, algorithm) + end + end + return suitable_algs +end + +function print_diagnostics(terms::NTuple{N,StructuredOptimization.Term}) where {N} + best_algorithm, best_algorithm_remaining_terms = nothing, Inf + for algorithm in ProximalAlgorithms.get_algorithms() + _, remaining_terms = parse_problem(terms, algorithm, true) + if length(remaining_terms) < best_algorithm_remaining_terms + best_algorithm_remaining_terms = length(remaining_terms) + best_algorithm = algorithm + end + end + println("The closest algorithm to the problem is $best_algorithm") + print_diagnostics(terms, best_algorithm) +end export solve @@ -69,7 +119,25 @@ julia> ~x ``` """ function solve(terms::Tuple, solver::ForwardBackwardSolver) - x, kwargs = parse_problem(terms, solver) + result = parse_problem(terms, solver) + if result === nothing + print_diagnostics(terms, solver) + error("Sorry, I cannot parse this problem for solver of type $(solver)") + end + _, kwargs, x = result + x_star, it = solver(; x0 = ~x, kwargs...) + ~x .= x_star isa Tuple ? x_star[1] : x_star + return x, it +end + +function solve(terms::Tuple) + result = parse_problem(terms) + if result === nothing + print_diagnostics(terms) + error("Sorry, I cannot find a suitable solver for this problem") + end + solver, kwargs, x = result + @show solver x_star, it = solver(; x0 = ~x, kwargs...) ~x .= x_star return x, it diff --git a/src/solvers/minimize.jl b/src/solvers/minimize.jl index b22a5e0..288b35a 100644 --- a/src/solvers/minimize.jl +++ b/src/solvers/minimize.jl @@ -1,4 +1,40 @@ -export @minimize +export problem, @minimize + +""" + problems(terms...) + +Constructs a problem. + +# Example + +```julia + +julia> x = Variable(4) +Variable(Float64, (4,)) + +julia> A, b = randn(10,4), randn(10); + +julia> p = problem(ls(A*x-b), norm(x) <= 1) + +``` + +""" +function problem(terms::Vararg) + cf = () + for i = 1:length(terms) + cf = (cf...,terms[i]...) + end + return cf +end + +function expand_terms_with_repr(expr) + if expr isa Expr && expr.head == :call && expr.args[1] == :+ + terms = map(t -> :(Term($(esc(t)), $(string(t)))), expr.args[2:end]) + return :(tuple($(terms...))) + else + return :(Term($(esc(expr)), $(string(expr)))) + end +end """ @minimize cost [st ctr] [with slv_opt] @@ -29,28 +65,28 @@ Returns as output a tuple containing the optimization variables and the number of iterations spent by the solver algorithm. """ macro minimize(cf::Union{Expr, Symbol}) - cost = esc(cf) - return :(solve(problem($(cost)), default_solver())) + cost = expand_terms_with_repr(cf) + return :(solve(problem($cost))) end macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}) - cost = esc(cf) - if s == :(st) - constraints = esc(cstr) - return :(solve(problem($(cost), $(constraints)), default_solver())) - elseif s == :(with) + cost = expand_terms_with_repr(cf) + if s == :st + constraints = expand_terms_with_repr(cstr) + return :(solve(problem($cost, $constraints))) + elseif s == :with solver = esc(cstr) - return :(solve(problem($(cost)), $(solver))) + return :(solve(problem($cost), $solver)) else error("wrong symbol after cost function! use `st` or `with`") end end macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}, w::Symbol, slv::Union{Expr, Symbol}) - cost = esc(cf) - s != :(st) && error("wrong symbol after cost function! use `st`") - constraints = esc(cstr) - w != :(with) && error("wrong symbol after constraints! use `with`") + cost = expand_terms_with_repr(cf) + s != :st && error("wrong symbol after cost function! use `st`") + constraints = expand_terms_with_repr(cstr) + w != :with && error("wrong symbol after constraints! use `with`") solver = esc(slv) - return :(solve(problem($(cost), $(constraints)), $(solver))) + return :(solve(problem($cost, $constraints), $solver)) end diff --git a/src/solvers/parse.jl b/src/solvers/parse.jl new file mode 100644 index 0000000..6336c36 --- /dev/null +++ b/src/solvers/parse.jl @@ -0,0 +1,442 @@ +function add_to_incompatibilities(incompatibilities, t1, t2) + if haskey(incompatibilities, t1) + push!(incompatibilities[t1], t2) + else + incompatibilities[t1] = Set([t2]) + end + if haskey(incompatibilities, t2) + push!(incompatibilities[t2], t1) + else + incompatibilities[t2] = Set([t1]) + end +end + +function group_by_variables(terms) + variable_bags = Dict{Variable, Vector{Any}}() + for term in terms + for var in variables(term) + if haskey(variable_bags, var) + push!(variable_bags[var], term) + else + variable_bags[var] = [term] + end + end + end + return variable_bags +end + +function can_be_separable_sum(variable_bags) + for (var, term_list) in variable_bags + if length(term_list) > 1 # more than one term for this variable + # Check if any of the terms are sliced + operators = [get_operators_for_var(term, var) for term in term_list] + slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators) + if OperatorCore.is_sliced(operators[i]) + # This operator is sliced, check if it is overlapping with any other sliced operator + for j in i+1:length(operators) + if OperatorCore.is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + return false + end + end + else # no slicing -> this term is incompatible with all others + return false + end + end + end + end + return true +end + +function get_unseparable_pairs(variable_bags) + incompatibilities = Dict{StructuredOptimization.Term, Set{StructuredOptimization.Term}}() + for (var, term_list) in variable_bags + if length(term_list) > 1 # more than one term for this variable + # Check if any of the terms are sliced + operators = [get_operators_for_var(term, var) for term in term_list] + slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators) + if OperatorCore.is_sliced(operators[i]) + # This operator is sliced, check if it is overlapping with any other sliced operator + for j in i+1:length(operators) + if OperatorCore.is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + add_to_incompatibilities(incompatibilities, term_list[i], term_list[j]) + end + end + else # no slicing -> this term is incompatible with all others + for j in i+1:length(operators) + add_to_incompatibilities(incompatibilities, term_list[i], term_list[j]) + end + end + end + end + end + return incompatibilities +end + +function merge_function_with_operator(op, f, disp, λ) + if is_eye(op) + f = disp == 0 ? f : PrecomposeDiagonal(f, 1.0, disp) + if size(op, 1) != size(op, 2) + f = ReshapeInput(f, size(op, 1)) + end + elseif is_diagonal(op) + if f isa SqrNormL2 + f = SqrNormL2(f.lambda .* diag(op) .^ 2) + else + f = PrecomposeDiagonal(f, diag(op), disp) + end + elseif is_AAc_diagonal(op) + f = Precompose(f, op, diag_AAc(op), disp) + else + # we assume that prox will not be called on this term because it will not give a valid result + f = Precompose(f, op, 1, disp) + end + return λ == 1 ? f : Postcompose(f, λ) +end + +unsatisfied_properties(term, assumptions::ProximalAlgorithms.AssumptionItem) = [property_func for property_func in assumptions.second if !property_func(term)] +does_satisfy(term, assumptions::ProximalAlgorithms.AssumptionItem) = all(property_func(term) for property_func in assumptions.second) + +function prepare(term::Term, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where N + if does_satisfy(term, assumption.func) && (!(ProximalCore.is_proximable in assumption.func.second) || OperatorCore.is_AAc_diagonal(term.A.L)) + op = extract_operators(variables, term) + disp = displacement(term) + return (assumption.func.first => merge_function_with_operator(op, term.f, disp, term.lambda),) + else + return nothing + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.SimpleTerm, ::NTuple{N, Variable}) where N + repr = term.repr !== nothing ? term.repr : string(term) + problematic_properties = unsatisfied_properties(term, assumption.func) + if length(problematic_properties) == 0 + println("Term $repr satisfies all required properties, but the following operator is not AAc diagonal: ", term.A.L) + else + println("Term $repr does not satisfy required property: $(join(problematic_properties, ", "))") + end +end + +function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple{M, Variable}) where {M} + fs = () + for var in variables + if haskey(variable_bags, var) + term_list = variable_bags[var] + if length(term_list) > 1 + #multiple terms per variable + #currently this happens only with GetIndex + fxi,idxs = (),() + for ti in term_list + op = operator(ti) + fxi = (fxi..., merge_function_with_operator(op, ti.f, displacement(ti), ti.lambda)) + if AbstractOperators.ndoms(op, 2) > 1 + op = op[findfirst(==(var), variables(ti))] + end + if typeof(op) <: Compose + idx = op.A[1].idx + else + idx = op.idx + end + idxs = (idxs..., OperatorCore.get_slicing_mask(op)) + end + fs = (fs..., SlicedSeparableSum(fxi,idxs)) + else + op = operator(term_list[1]) + disp = displacement(term_list[1]) + fs = (fs..., merge_function_with_operator(op, term_list[1].f, disp, term_list[1].lambda)) + end + else + fs = (fs..., IndFree()) + end + end + return SeparableSum(fs) +end + +function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{M, Variable}) where {N,M} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + if any(term -> !does_satisfy(term, assumption.func), terms) + return nothing + end + if ProximalCore.is_proximable in assumption.func.second + if any(!is_AAc_diagonal(affine(term)) for term in terms) + return nothing + end + variable_bags = group_by_variables(terms) + if !can_be_separable_sum(variable_bags) + return nothing + end + if all(length.(values(variable_bags)) .== 1) + # all terms references only one variable + return (assumption.func.first => prepare_proximable_single_var_per_term(variable_bags, variables),) + else + op = extract_operators(variables, terms) + idxs = OperatorCore.get_slicing_expr(op) + op = OperatorCore.remove_slicing(op) + hcat_ops = Tuple(op[i] for i in eachindex(op.A)) + μs = AbstractOperators.diag_AAc(op) + f = extract_functions(terms) + return (assumption.func.first => PrecomposedSlicedSeparableSum(f.fs, idxs, hcat_ops, μs),) + end + else + fs = () + for term in terms + if is_linear(term) + f = merge_function_with_operator(operator(term), term.f, displacement(term), term.lambda) + else + f = extract_functions(term) + op = extract_affines(variables, term) + f = PrecomposeNonlinear(f, op) + f = term.lambda == 1 ? f : Postcompose(f, term.lambda) + end + fs = (fs..., f) + end + return (assumption.func.first => SeparableSum(fs),) + end +end + +function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{M, Variable}) where {N,M} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + return + end + problematic_term_index = findfirst(term -> !does_satisfy(term, assumption.func), terms) + if problematic_term_index !== nothing + problematic_term = terms[problematic_term_index] + repr = problematic_term.repr !== nothing ? problematic_term.repr : string(problematic_term) + problematic_properties = unsatisfied_properties(problematic_term, assumption.func) + println("Term $repr does not satisfy required property: $(join(problematic_properties, ", "))") + elseif any(term -> !is_AAc_diagonal(affine(term)), terms) + println("The following terms contains operators that are not AAc diagonal:") + for term in terms + if !is_AAc_diagonal(affine(term)) + repr = term.repr !== nothing ? term.repr : string(term) + println(" - $repr") + end + end + else + variable_bags = group_by_variables(terms) + incompatibilities = get_unseparable_pairs(variable_bags) + println("The following terms are incompatible with each other:") + for (term, incompatible_terms) in incompatibilities + println(" - $term: $(join(incompatible_terms, ", "))") + end + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where N + op = extract_affines(variables, term) + if does_satisfy(op, assumption.operator) && does_satisfy(term.f, assumption.func) + return ( + assumption.func.first => term.lambda == 1 ? term.f : Postcompose(term.f, term.lambda), + assumption.operator.first => op + ) + else # try preparing as a simple term + tup = prepare(term, ProximalAlgorithms.SimpleTerm(assumption.func), variables) + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(Tuple(~var for var in variables)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where N + op = affine(term) + repr = term.repr !== nothing ? term.repr : string(term) + if OperatorCore.is_eye(op) + problematic_properties = unsatisfied_properties(term.f, assumption.func) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + else + println("A possible decomposition of term $repr:") + f = term.lambda == 1 ? term.f : Postcompose(term.f, term.lambda) + print(" - ", assumption.func.first, " = ", f) + if !does_satisfy(f, assumption.func) + problematic_properties = unsatisfied_properties(f, assumption.func) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func), variables) +end + +function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{M, Variable}) where {N,M} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + op = extract_affines(variables, terms) + f = extract_functions(terms) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func) + return ( + assumption.func.first => f, + assumption.operator.first => op + ) + else # try preparing as a simple term + return prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func), variables) + end +end + +function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{M, Variable}) where {N,M} + op = extract_affines(variables, terms) + f = extract_functions(terms) + repr = string(terms) + if OperatorCore.is_eye(op) + for term in terms + problematic_properties = unsatisfied_properties(term.f, assumption.func) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + end + else + println("A possible decomposition of terms $repr:") + print(" - ", assumption.func.first, " = ", f) + if !does_satisfy(f, assumption.func) + problematic_properties = unsatisfied_properties(f, assumption.func) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare terms as a simple function:") + print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func), variables) +end + +function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {M} + op = extract_affines(variables, term) + f = extract_functions(term) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₁) + return ( + assumption.func₁.first => f, + assumption.operator.first => op + ) + elseif does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₂) + return ( + assumption.func₂.first => f, + assumption.operator.first => affine(term) + ) + else + # try preparing as a simple term + tup = prepare(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(Tuple(~var for var in variables)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {M} + op = affine(term) + f = extract_functions(term) + repr = term.repr !== nothing ? term.repr : string(term) + if OperatorCore.is_eye(op) + problematic_properties = unsatisfied_properties(term.f, assumption.func₁) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + else + println("A possible decomposition of term $repr:") + print(" - ", assumption.func₁.first, " = ", f) + if !does_satisfy(f, assumption.func₁) + problematic_properties = unsatisfied_properties(f, assumption.func₁) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) +end + +function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {N,M} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + op = extract_affines(variables, terms) + f = extract_functions(terms) + if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₁) + return ( + assumption.func₁.first => f, + assumption.operator.first => op + ) + elseif does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₂) + return ( + assumption.func₂.first => f, + assumption.operator.first => affine(terms[1].A) + ) + else + # try preparing as a simple term + tup = prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) + if tup === nothing + tup = prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func₂), variables) + end + if tup !== nothing && length(variables) > 1 + example_input = ArrayPartition(Tuple(~var for var in variables)) + tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) + end + return tup + end +end + +function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {N,M} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + return + end + op = affine(terms[1].A) + f = extract_functions(terms) + repr = string(terms) + if OperatorCore.is_eye(op) + for term in terms + problematic_properties = unsatisfied_properties(term.f, assumption.func₁) + println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") + end + else + println("A possible decomposition of terms $repr:") + print(" - ", assumption.func₁.first, " = ", f) + if !does_satisfy(f, assumption.func₁) + problematic_properties = unsatisfied_properties(f, assumption.func₁) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + print(" - ", assumption.operator.first, " = ", op) + if !does_satisfy(op, assumption.operator) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + println("Alteratively, one can try to prepare the function part as:") + print(" - ", assumption.func₂.first, " = ", f) + if !does_satisfy(f, assumption.func₂) + problematic_properties = unsatisfied_properties(f, assumption.func₂) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + else + println() + end + end + end + println("When trying to prepare the term as a simple term:") + print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) +end diff --git a/src/solvers/solvers_options.jl b/src/solvers/solvers_options.jl deleted file mode 100644 index ff6b963..0000000 --- a/src/solvers/solvers_options.jl +++ /dev/null @@ -1,5 +0,0 @@ -using ProximalAlgorithms - -const ForwardBackwardSolver = ProximalAlgorithms.IterativeAlgorithm - -const default_solver = ProximalAlgorithms.PANOC diff --git a/src/solvers/terms_extract.jl b/src/solvers/terms_extract.jl index ab57fd6..5fbd207 100644 --- a/src/solvers/terms_extract.jl +++ b/src/solvers/terms_extract.jl @@ -2,7 +2,9 @@ extract_variables(t::TermOrExpr) = variables(t) function extract_variables(t::NTuple{N,TermOrExpr}) where {N} - return tuple(unique(variables.(t))...) + var_tuples = variables.(t) + vars = vcat(collect.(var_tuples)...) + return tuple(unique(vars)...) end # extract functions from terms @@ -41,7 +43,7 @@ function extract_operators(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) wh return vcat(ops...) end -sort_and_extract_operators(xAll::Tuple{Variable}, t::TermOrExpr) = operator(t) +sort_and_extract_operators(::Tuple{Variable}, t::TermOrExpr) = operator(t) function sort_and_extract_operators(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} p = zeros(Int,N) @@ -57,8 +59,7 @@ end # returns all affines with an order dictated by xAll #single term, single variable -extract_affines(xAll::Tuple{Variable}, t::TermOrExpr) = affine(t) - +extract_affines(::Tuple{Variable}, t::TermOrExpr) = affine(t) extract_affines(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} = extract_affines(xAll, (t,)) #multiple terms, multiple variables @@ -71,7 +72,7 @@ function extract_affines(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) wher return vcat(ops...) end -sort_and_extract_affines(xAll::Tuple{Variable}, t::TermOrExpr) = affine(t) +sort_and_extract_affines(::Tuple{Variable}, t::TermOrExpr) = affine(t) function sort_and_extract_affines(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} p = zeros(Int,N) @@ -110,62 +111,3 @@ function expand(xAll::NTuple{N,Variable}, ex::AbstractExpression) where {N} end return ex end - -# extract function and merge operator -function extract_merge_functions(t::Term) - if is_sliced(t) - if typeof(operator(t)) <: Compose - op = operator(t).A[2] - else - op = Eye(size(operator(t),1)...) - end - else - op = operator(t) - end - if is_eye(op) - f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, 1.0, displacement(t)) - elseif is_diagonal(op) - f = PrecomposeDiagonal(t.f, diag(op), displacement(t)) - elseif is_AAc_diagonal(op) - f = Precompose(t.f, op, diag_AAc(op), displacement(t)) - end - f = t.lambda == 1. ? f : Postcompose(f, t.lambda) #for now I keep this - #TODO change this - return f -end - -function extract_proximable(xAll::NTuple{N,Variable}, t::NTuple{M,Term}) where {N,M} - fs = () - for x in xAll - tx = () #terms containing x - for ti in t - if x in variables(ti) - tx = (tx...,ti) #collect terms containing x - end - end - if isempty(tx) - fx = IndFree() - elseif length(tx) == 1 #only one term per variable - fx = extract_proximable(x,tx[1]) - else - #multiple terms per variable - #currently this happens only with GetIndex - fxi,idxs = (),() - for ti in tx - fxi = (fxi..., extract_merge_functions(ti)) - idx = typeof(operator(ti)) <: Compose ? operator(ti).A[1].idx : operator(ti).idx - idxs = (idxs..., idx ) - end - fx = SlicedSeparableSum(fxi,idxs) - end - fs = (fs...,fx) - end - if length(fs) > 1 - return SeparableSum(fs) ##probably change constructor in Prox? - else - return fs[1] - end -end - -extract_proximable(xAll::Variable, t::Term) = extract_merge_functions(t) -extract_proximable(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_proximable(xAll,(t,)) diff --git a/src/solvers/terms_properties.jl b/src/solvers/terms_properties.jl index a95b4f3..fe987c6 100644 --- a/src/solvers/terms_properties.jl +++ b/src/solvers/terms_properties.jl @@ -1,25 +1,45 @@ is_proximable(term::Term) = is_AAc_diagonal(term) -function is_proximable(terms::Tuple) - # Check that each term is proximable - if any(is_proximable.(terms) .== false) - return false - end +function get_operators_for_var(term, var) + full_operator = affine(term) + if AbstractOperators.ndoms(full_operator, 2) == 1 + return full_operator + else + return full_operator[findfirst(==(var), variables(term))] + end +end + +function is_separable_sum(terms::NTuple{N,Term}) where {N} # Construct the set of occurring variables vars = Set() for term in terms union!(vars, variables(term)) end # Check that each variable occurs in only one term - for v in vars - tv = [t for t in terms if v in variables(t)] - if length(tv) != 1 - if all( is_sliced.(tv) ) && all( is_proximable.(tv) ) - return true - else + for var in vars + terms_with_var = [t for t in terms if var in variables(t)] + if length(terms_with_var) != 1 + # All terms must be either or have a single variable + if ! all( length(variables(term)) == 1 || is_separable(term.f) for term in terms_with_var ) return false end + # All terms must be sliced for this variable + operators = [get_operators_for_var(term, var) for term in terms_with_var] + if any(!OperatorCore.is_sliced(op) for op in operators) + return false + end + # The sliced operators must not overlap + slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + for i in eachindex(operators), j in i+1:length(operators) + if any(slicing_masks[i] .&& slicing_masks[j]) + return false + end + end end end return true end + +function is_proximable(terms::NTuple{N,Term}) where {N} + return all(is_proximable.(terms)) && is_separable_sum(terms) +end diff --git a/src/solvers/terms_splitting.jl b/src/solvers/terms_splitting.jl deleted file mode 100644 index a1dad74..0000000 --- a/src/solvers/terms_splitting.jl +++ /dev/null @@ -1,31 +0,0 @@ -# -# """ -# `split_smooth(cf::Vararg{Term}) -> (smooth, nonsmooth)` -# -# Splits cost function into `SmoothFunction` and `NonSmoothFunction` terms. -# """ -# split_smooth(cf::Vararg{Term}) = cf[findall(is_smooth(cf))],cf[findall((!).(is_smooth(cf)))] -# split_smooth{N}(cf::NTuple{N,Term}) = split_smooth(cf...) -# -# """ -# `split_AAc_diagonal(cf::Vararg{Term}) -> (proximable, non_proximable)` -# -# Splits cost function into terms with L'*L diagonal operator. -# """ -# split_AAc_diagonal(cf::Vararg{Term}) = cf[findall(is_AAc_diagonal(cf))],cf[findall((!).(is_AAc_diagonal(cf)))] -# split_AAc_diagonal{N}(cf::NTuple{N,Term}) = split_AAc_diagonal(cf...) -# -# #""" TODO -# #`split_Quadratic(cf::Vararg{Term}) -> (quadratic, non_quadratic)` -# # -# #Splits cost function into `QuadraticFunction` and non `QuadraticFunction` terms. -# #""" - -split_smooth(terms::Tuple) = - terms[findall(is_smooth.(terms))], terms[findall((!).(is_smooth.(terms)))] - -split_quadratic(terms::Tuple) = - terms[findall(is_quadratic.(terms))], terms[findall((!).(is_quadratic.(terms)))] - -split_AAc_diagonal(terms::Tuple) = - terms[findall(is_AAc_diagonal.(terms))], terms[findall((!).(is_AAc_diagonal.(terms)))] diff --git a/src/syntax/expressions/abstractOperator_bind.jl b/src/syntax/expressions/abstractOperator_bind.jl index c6edbb0..6f38a8c 100644 --- a/src/syntax/expressions/abstractOperator_bind.jl +++ b/src/syntax/expressions/abstractOperator_bind.jl @@ -19,7 +19,7 @@ julia> reshape(A*x-b,2,5) function reshape(a::AbstractExpression, dims...) A = convert(Expression,a) op = Reshape(A.L, dims...) - return Expression{length(A.x)}(A.x,op) + return Expression(A.x,op) end #Reshape diff --git a/src/syntax/expressions/addition.jl b/src/syntax/expressions/addition.jl index aee3125..e321e1b 100644 --- a/src/syntax/expressions/addition.jl +++ b/src/syntax/expressions/addition.jl @@ -48,14 +48,14 @@ function (+)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) if variables(A) == variables(B) - return Expression{length(A.x)}(A.x,affine(A)+affine(B)) + return Expression(A.x,affine(A)+affine(B)) else opA = affine(A) xA = variables(A) opB = affine(B) xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,true) - return Expression{length(xNew)}(xNew,opNew) + return Expression(xNew,opNew) end end # sum expressions @@ -64,14 +64,14 @@ function (-)(a::AbstractExpression, b::AbstractExpression) A = convert(Expression,a) B = convert(Expression,b) if variables(A) == variables(B) - return Expression{length(A.x)}(A.x,affine(A)-affine(B)) + return Expression(A.x,affine(A)-affine(B)) else opA = affine(A) xA = variables(A) opB = affine(B) xB = variables(B) xNew, opNew = Usum_op(xA,xB,opA,opB,false) - return Expression{length(xNew)}(xNew,opNew) + return Expression(xNew,opNew) end end @@ -112,7 +112,7 @@ function Usum_op(xA::Tuple{Variable}, xB::NTuple{N,Variable}, A::AbstractOperato end #unsigned sum: HCAT+HCAT -function Usum_op(xA::NTuple{NA,Variable}, xB::NTuple{NB,Variable}, A::HCAT{NB}, B::HCAT{NB}, sign::Bool) where {NA,NB} +function Usum_op(xA::NTuple{NA,Variable}, xB::NTuple{NB,Variable}, A::HCAT{NA}, B::HCAT{NB}, sign::Bool) where {NA,NB} xNew = xA opNew = A for i in eachindex(xB) @@ -136,6 +136,20 @@ function Usum_op( return xNew, opNew end +function Usum_op( + xA::Tuple{Variable}, xB::NTuple{N,Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N} + if xA[1] in xB + Z = Zeros(B) #this will be an HCAT + xNew, opNew = Usum_op(xA,xB,A,Z,sign) + opNew += B + else + xNew = (xA...,xB...) + opNew = sign ? hcat(A,B) : hcat(A,-B) + end + return xNew, opNew +end + """ +(ex::AbstractExpression, b::Union{AbstractArray,Number}) @@ -170,19 +184,19 @@ julia> ex + b """ function (+)(a::AbstractExpression, b::Union{AbstractArray,Number}) A = convert(Expression,a) - return Expression{length(A.x)}(A.x,AffineAdd(affine(A),b)) + return Expression(A.x,AffineAdd(affine(A),b)) end (+)(a::Union{AbstractArray,Number}, b::AbstractExpression) = b+a function (-)(a::AbstractExpression, b::Union{AbstractArray,Number}) A = convert(Expression,a) - return Expression{length(A.x)}(A.x,AffineAdd(affine(A),b,false)) + return Expression(A.x,AffineAdd(affine(A),b,false)) end function (-)(a::Union{AbstractArray,Number}, b::AbstractExpression) B = convert(Expression,b) - return Expression{length(B.x)}(B.x,-AffineAdd(affine(B),a)) + return Expression(B.x,-AffineAdd(affine(B),a)) end # sum with array/scalar @@ -193,10 +207,10 @@ function Broadcast.broadcasted(::typeof(+),a::AbstractExpression, b::AbstractExp B = convert(Expression,b) if size(affine(A),1) != size(affine(B),1) if prod(size(affine(A),1)) > prod(size(affine(B),1)) - B = Expression{length(B.x)}(variables(B), + B = Expression(variables(B), BroadCast(affine(B),size(affine(A),1))) elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) - A = Expression{length(A.x)}(variables(A), + A = Expression(variables(A), BroadCast(affine(A),size(affine(B),1))) end return A+B @@ -209,10 +223,10 @@ function Broadcast.broadcasted(::typeof(-),a::AbstractExpression, b::AbstractExp B = convert(Expression,b) if size(affine(A),1) != size(affine(B),1) if prod(size(affine(A),1)) > prod(size(affine(B),1)) - B = Expression{length(B.x)}(variables(B), + B = Expression(variables(B), BroadCast(affine(B),size(affine(A),1))) elseif prod(size(affine(B),1)) > prod(size(affine(A),1)) - A = Expression{length(A.x)}(variables(A), + A = Expression(variables(A), BroadCast(affine(A),size(affine(B),1))) end return A-B diff --git a/src/syntax/expressions/addition_tricky_part.jl b/src/syntax/expressions/addition_tricky_part.jl new file mode 100644 index 0000000..baeaca8 --- /dev/null +++ b/src/syntax/expressions/addition_tricky_part.jl @@ -0,0 +1,231 @@ +using Base.Iterators: flatten +abstract type OpStructure end + +struct HCatStructure{N} <: OpStructure + op::AbstractOperators.AbstractOperator + structure::NTuple{N,Any} +end + +struct SumStructure{N} <: OpStructure + op::AbstractOperators.AbstractOperator + structure::NTuple{N,Any} +end + +function get_structure(op::AbstractOperators.HCAT, vars) + if length(op.A) == AbstractOperators.ndoms(op, 2) # this is the deepest or only HCAT operator + return HCatStructure(op, vars) + else # there are more nested HCAT operators, let's recurse! + result = () + var_group_counter = 1 + for suboperator in op.A + subvars = vars[var_group_counter:var_group_counter+AbstractOperators.ndoms(suboperator, 2)-1] + if AbstractOperators.ndoms(suboperator, 2) == 1 + returned = subvars + else + returned = get_structure(suboperator, subvars) + @assert returned !== nothing + end + if returned isa Tuple + result = (result..., returned...) + else + result = (result..., returned) + end + var_group_counter += AbstractOperators.ndoms(suboperator, 2) + end + return HCatStructure(op, result) + end +end + +function get_structure(op::AbstractOperators.Sum, vars) + return SumStructure(op, tuple((get_structure(suboperator, vars) for suboperator in op.A)...)) +end + +function get_structure(op, vars) + if op isa AbstractOperators.AbstractOperator && AbstractOperators.ndoms(op, 2) == 1 + return SumStructure(op, vars) + else + for k in 1:fieldcount(typeof(op)) + value = getfield(op, k) + if value isa AbstractOperators.AbstractOperator + return get_structure(value, vars) + elseif value isa Tuple + for v in value + return get_structure(v, vars) + end + end + end + @assert false "This should never happen" + end +end + +function deep_flatten(structure::HCatStructure) + result = () + for item in structure.structure + if isa(item, OpStructure) + sub_flattened = deep_flatten(item) + if sub_flattened === nothing + return nothing + end + result = tuple(result..., sub_flattened...) + else + result = tuple(result..., item) + end + end + return result +end + +function deep_flatten(structure::SumStructure) + nested_structures = tuple((deep_flatten(item) for item in structure.structure)...) + if all(==(nested_structures[1]), nested_structures) + return nested_structures[1] + else + return nothing + end +end + +struct UnregularIndex{N} + max::NTuple{N, Int} + UnregularIndex(max) = any(max .< 1) ? error("max must be >= 1") : new{length(max)}(tuple(max...)) +end + +Base.first(iter::UnregularIndex) = tuple(fill(1, length(iter.max))...) +Base.length(iter::UnregularIndex) = sum(iter.max) + +function Base.iterate(iter::UnregularIndex) + state = first(iter) + return state, state +end + +function Base.iterate(iter::UnregularIndex{N}, state::NTuple{N, Int}) where {N} + if state == iter.max + return nothing + end + currentdim = findfirst(i -> state[i] != iter.max[i], 1:N) + nextstate = tuple((j < currentdim ? 1 : (j == currentdim ? state[j]+1 : state[j]) for j in 1:N)...) + return nextstate, nextstate +end + +get_structure_only(str) = str isa OpStructure ? tuple((get_structure_only(item) for item in str.structure)...) : str + +Base.length(str::OpStructure) = length(str.structure) +Base.getindex(str::OpStructure, i) = str.structure[i] + +permute_structure(str, perm) = tuple((str[i][perm[i]] for i in eachindex(str))...) + +function compute_permutations(st) + result = () + for perm in UnregularIndex(length.(st)) + result = (result..., permute_structure(st, perm)) + end + return result +end + +function get_all_permutations(structure::SumStructure) + product = [get_all_permutations(item) for item in structure.structure] + return tuple((SumStructure(structure.op, st) for st in compute_permutations(product))...) +end + +function get_all_permutations(structure::HCatStructure) + nested_perms = [isa(item, Int) ? (item,) : get_all_permutations(item) for item in structure.structure] + product = compute_permutations(nested_perms) + combinations = flatten(permutations(p) for p in product) + return tuple((HCatStructure(structure.op, tuple(p...)) for p in combinations)...) +end + +function find_feasible_permutation(vars, stA, stB) + stA_perms = get_all_permutations(stA) + stB_perms = get_all_permutations(stB) + stA_pairs = filter(pair -> pair[2] !== nothing, [(s, deep_flatten(s)) for s in stA_perms]) + stB_pairs = filter(pair -> pair[2] !== nothing, [(s, deep_flatten(s)) for s in stB_perms]) + for vars_perm in permutations(vars) + vars_perm = tuple(vars_perm...) + stA_perm = findfirst(pair -> pair[2] == vars_perm, stA_pairs) + if stA_perm === nothing + continue + end + stB_perm = findfirst(pair -> pair[2] == vars_perm, stB_pairs) + if stB_perm === nothing + continue + end + return vars_perm + end + return nothing +end + +function add_missing_vars(old_vars, op, vars) + missing_vars = setdiff(vars, old_vars) + if isempty(missing_vars) + return old_vars, op + end + dummy_ops = [AbstractOperators.Zeros(eltype(~var), size(~var), AbstractOperators.codomainType(op), size(op, 1)) for var in missing_vars] + new_vars = (old_vars..., missing_vars...) + new_op = AbstractOperators.HCAT(op, dummy_ops...) + return new_vars, new_op +end + +function Usum_op( + xA::NTuple{N,Variable}, xB::NTuple{M,Variable}, A::AbstractOperator, B::AbstractOperator, sign::Bool +) where {N,M} + xNew = tuple(unique((xA...,xB...))...) + xA, A = add_missing_vars(xA, A, xNew) + xB, B = add_missing_vars(xB, B, xNew) + vars_index = tuple((i for i in eachindex(xNew))...) + xA_index = tuple((findfirst(==(x), xNew) for x in xA)...) + xB_index = tuple((findfirst(==(x), xNew) for x in xB)...) + structureA = get_structure(A, xA_index) + structureB = get_structure(B, xB_index) + var_perm = find_feasible_permutation(vars_index, structureA, structureB) + if var_perm === nothing + error("No feasible permutation found") + end + if var_perm != xA_index + A = AbstractOperators.permute(A, invperm([xA_index...])) + end + if var_perm != xB_index + B = AbstractOperators.permute(B, invperm([xB_index...])) + end + opNew = sign ? A+B : A-B + return xNew, opNew +end + +#= +function _replace_in(obj, tasks) + for task in tasks + if obj === task.first + return task.second, filter(t -> t !== task, tasks) + end + end + return obj, tasks +end +function _replace_in(obj::Tuple, tasks) + new_tuple = [] + for o in obj + new_obj, tasks = _replace_in(o, tasks) + push!(new_tuple, new_obj) + end + return tuple(new_tuple...), tasks +end +function _replace_in(obj::AbstractOperators.AbstractOperator, tasks) + fields = [getfield(obj, name) for name in fieldnames(typeof(obj))] + new_fields = [_replace_in(field, searched_obj, new_obj) for field in fields] + maybe_new_obj = any(new_fields .!== fields) ? typeof(obj).name.wrapper(new_fields...) : obj + return maybe_new_obj, tasks +end +function permute_single_operator(op::AbstractOperators.HCAT, perm::Vector{Int}) + @show op + @show perm + return AbstractOperators.HCAT([op[i] for i in perm]...) +end +function permute_operator(op::AbstractOperators.AbstractOperator, permutations) + @show permutations + tasks = [(old_op => permute_single_operator(old_op, perm)) for (old_op, perm) in reverse(permutations)] + #=for (old_op, perm) in reverse(permutations) + new_op = permute_single_operator(old_op, perm) + @show op + @show old_op + @show new_op + op = _replace_in(op, old_op, new_op) + end=# + return _replace_in(op, tasks) + #return op +end=# diff --git a/src/syntax/expressions/expression.jl b/src/syntax/expressions/expression.jl index 08d1f53..fb619c3 100644 --- a/src/syntax/expressions/expression.jl +++ b/src/syntax/expressions/expression.jl @@ -1,7 +1,7 @@ struct Expression{N,A<:AbstractOperator} <: AbstractExpression x::NTuple{N,Variable} L::A - function Expression{N}(x::NTuple{N,Variable}, L::A) where {N,A<:AbstractOperator} + function Expression(x::NTuple{N,Variable}, L::A) where {N,A<:AbstractOperator} # checks on L ndoms(L,1) > 1 && throw(ArgumentError( "Cannot create expression with LinearOperator with `ndoms(L,1) > 1`" @@ -27,12 +27,21 @@ struct AdjointExpression{E <: AbstractExpression} <: AbstractExpression ex::E end -import Base: adjoint +import Base: adjoint, show adjoint(ex::AbstractExpression) = AdjointExpression(convert(Expression,ex)) adjoint(ex::AdjointExpression) = ex.ex +function show(io::IO, ex::Expression) + if length(ex.x) == 1 + print(io, AbstractOperators.fun_name(ex.L), " * ", ex.x[1]) + else + print(io, AbstractOperators.fun_name(ex.L), " * (", join(ex.x, ", "), ")") + end +end + include("utils.jl") include("multiplication.jl") include("addition.jl") +include("addition_tricky_part.jl") include("abstractOperator_bind.jl") diff --git a/src/syntax/expressions/multiplication.jl b/src/syntax/expressions/multiplication.jl index a99f84f..3f7ac8e 100644 --- a/src/syntax/expressions/multiplication.jl +++ b/src/syntax/expressions/multiplication.jl @@ -27,7 +27,7 @@ julia> affine(ex2) """ function (*)(L::AbstractOperator, a::AbstractExpression) A = convert(Expression,a) - Expression{length(A.x)}(A.x,L*affine(A)) + Expression(A.x,L*affine(A)) end """ @@ -94,7 +94,7 @@ d.*a function (*)(coeff::T1, a::T) where {T1<:Number, T<:AbstractExpression} A = convert(Expression,a) - return Expression{length(A.x)}(A.x,coeff*affine(A)) + return Expression(A.x,coeff*affine(A)) end (*)(a::T, coeff::T1) where {T1<:Number, T<:AbstractExpression} = coeff*a ##Scale @@ -132,7 +132,7 @@ function (*)(ex1::AbstractExpression, ex2::AbstractExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Ax_mul_Bx(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Ax_mul_Bx @@ -144,7 +144,7 @@ function (*)(ex1::AdjointExpression, ex2::AbstractExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Axt_mul_Bx(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Axt_mul_Bx @@ -156,7 +156,7 @@ function (*)(ex1::AbstractExpression, ex2::AdjointExpression) A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = Ax_mul_Bxt(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Ax_mul_Bxt @@ -168,7 +168,7 @@ function Broadcast.broadcasted(::typeof(*), ex1::AbstractExpression, ex2::Abstra A = extract_affines(x, ex1) B = extract_affines(x, ex2) op = HadamardProd(A,B) - exp3 = Expression{length(x)}(x,op) + exp3 = Expression(x,op) return exp3 end # Hadamard diff --git a/src/syntax/expressions/utils.jl b/src/syntax/expressions/utils.jl index 7c0af76..69f11b2 100644 --- a/src/syntax/expressions/utils.jl +++ b/src/syntax/expressions/utils.jl @@ -4,7 +4,7 @@ import Base: convert import AbstractOperators: displacement convert(::Type{Expression},x::Variable{T,N,A}) where {T,N,A} = -Expression{1}((x,),Eye(T,size(x))) +Expression((x,),Eye(T,size(x))) """ variables(ex::Expression) diff --git a/src/syntax/problem.jl b/src/syntax/problem.jl deleted file mode 100644 index 4387ddd..0000000 --- a/src/syntax/problem.jl +++ /dev/null @@ -1,28 +0,0 @@ -export problem - -""" - problems(terms...) - -Constructs a problem. - -# Example - -```julia - -julia> x = Variable(4) -Variable(Float64, (4,)) - -julia> A, b = randn(10,4), randn(10); - -julia> p = problem(ls(A*x-b), norm(x) <= 1) - -``` - -""" -function problem(terms::Vararg) - cf = () - for i = 1:length(terms) - cf = (cf...,terms[i]...) - end - return cf -end diff --git a/src/syntax/syntax.jl b/src/syntax/syntax.jl deleted file mode 100644 index 514514b..0000000 --- a/src/syntax/syntax.jl +++ /dev/null @@ -1,8 +0,0 @@ -abstract type AbstractExpression end - -include("variable.jl") -include("expressions/expression.jl") -include("terms/term.jl") -include("problem.jl") - -const TermOrExpr = Union{Term,AbstractExpression} diff --git a/src/syntax/terms/proximalOperators_bind.jl b/src/syntax/terms/proximalOperators_bind.jl index c507638..d3c9ba0 100644 --- a/src/syntax/terms/proximalOperators_bind.jl +++ b/src/syntax/terms/proximalOperators_bind.jl @@ -1,7 +1,7 @@ # Norms import LinearAlgebra: norm -export norm +export norm, mixednorm """ norm(x::AbstractExpression, p=2, [q,] [dim=1]) @@ -48,32 +48,76 @@ function norm(ex::AbstractExpression, ::typeof(*)) end # Mixed Norm -function norm(ex::AbstractExpression, p1::Int, p2::Int, dim::Int = 1 ) - if p1 == 2 && p2 == 1 - f = NormL21(1.0,dim) +""" + mixednorm(x, p::Int, q::Int) + +``l_{2,1}`` mixed norm (aka Sum-of-``l_2``-norms) +```math +f(\\mathbf{X}) = \\sum_i \\| \\mathbf{x}_i \\| +``` +where ``\\mathbf{x}_i`` is the ``i``-th column if `p == 2` and `q == 1` (or row if `p == 1` and `q == 2`) of ``\\mathbf{X}``. +""" +function mixednorm(ex::AbstractExpression, p::Int, q::Int) + if p == 2 && q == 1 + f = NormL21(1.0, 1) + elseif p == 1 && q == 2 + f = NormL21(1.0, 2) else error("function not implemented") end return Term(f, ex) end +function mixednorm(A::AbstractMatrix{T}, p::Int, q::Int) where {T} + if p == 2 && q == 1 + return NormL21(1.0, 1)(A) + elseif p == 1 && q == 2 + return NormL21(1.0, 2)(A) + else + error("function not implemented") + end + return result +end # Least square terms -export ls +export ls, normalop_ls """ ls(x::AbstractExpression) Returns the squared norm (least squares) of `x`: - ```math f (\\mathbf{x}) = \\frac{1}{2} \\| \\mathbf{x} \\|^2 ``` - (shorthand of `1/2*norm(x)^2`). """ ls(ex) = Term(SqrNormL2(), ex) +""" + normalop_ls(x::AbstractExpression) + +Returns the squared norm (least squares) of `L*x`: +```math +f (\\mathbf{L} * \\mathbf{x}) = \\frac{1}{2} \\| \\mathbf{L} * \\mathbf{x} \\|^2 +``` +(shorthand of `1/2*norm(x)^2`). + +The only difference with `ls` comes when gradient! is called. In this case, the +gradient is computed as usual, but the squared norm of the gradient (i.e. the +squared norm of `Lᴴ * L * x`) is returned instead of the squared norm of `L * x`. +This is much faster to compute, if `Lᴴ * L` has a fast implementation. +""" + +normalop_ls(::Variable) = error("normalop_ls does not work with Variables alone. Use ls instead.") +function normalop_ls(ex::Expression) + eye_op = if length(ex.x) == 1 + Eye(domainType(ex.L), size(ex.L, 2)) + else + HCAT([Eye(domainType(L), size(L, 2)) for L in ex.L]...) + end + return Term(SqrNormL2WithNormalOp(ex.L), Expression(ex.x, eye_op)) +end + import Base: ^ function (^)(t::Term{T1,T2,T3}, exp::Integer) where {T1, T2 <: NormL2, T3} @@ -138,13 +182,12 @@ Term(CrossEntropy(b), ex) export logisticloss """ - logbarrier(x::AbstractExpression, y::AbstractArray) + logisticloss(x::AbstractExpression, y::Array) Applies the logistic loss function: ```math -f(\\mathbf{x}) = \\sum_{i} \\log(1+ \\exp(-y_i x_i)), +f(\\mathbf{x}) = \\sum_i \\log(1 + \\exp(-y_i x_i)). ``` -where `y` is an array containing ``y_i``. """ logisticloss(ex::AbstractExpression, y::AbstractArray) = Term(LogisticLoss(y, 1.0), ex) diff --git a/src/syntax/terms/term.jl b/src/syntax/terms/term.jl index 0a9287f..c3c25ad 100644 --- a/src/syntax/terms/term.jl +++ b/src/syntax/terms/term.jl @@ -2,7 +2,11 @@ struct Term{T1 <: Real, T2, T3 <: AbstractExpression} lambda::T1 f::T2 A::T3 - Term(lambda::T1, f::T2, ex::T3) where {T1,T2,T3} = new{T1,T2,T3}(lambda,f,ex) + repr::Union{String,Nothing} +end + +function Term(lambda, f, ex::AbstractExpression) + return Term(lambda,f,ex,nothing) end function Term(f, ex::AbstractExpression) @@ -10,6 +14,37 @@ function Term(f, ex::AbstractExpression) Term(one(real(codomainType(affine(A)))),f, A) end +function Term(f, ex::AbstractExpression, repr::String) + A = convert(Expression,ex) + Term(one(real(codomainType(affine(A)))),f, A, repr) +end + +function Term(t::Term, repr::String) + Term(t.lambda, t.f, t.A, repr) +end + +import Base: ==, show + +# Ignore the repr when comparing terms +==(t1::Term, t2::Term) = t1.lambda == t2.lambda && t1.f == t2.f && t1.A == t2.A + +function show(io::IO, t::Term) + if t.repr !== nothing + print(io, t.repr) + else + print(io, t.lambda, " * ", t.f, "(", t.A, ")") + end +end + +function show(io::IO, t::NTuple{N,Term}) where {N} + for i in 1:N + show(io, t[i]) + if i < N + print(io, " + ") + end + end +end + # Operations # Define sum of terms simply as their vcat @@ -44,21 +79,37 @@ affine(t::Term) = affine(t.A) displacement(t::Term) = displacement(t.A) #importing properties from ProximalOperators -import ProximalOperators: - is_affine, - is_cone, +import ProximalCore: + is_affine_indicator, + is_cone_indicator, is_convex, is_generalized_quadratic, - is_prox_accurate, + is_proximable, is_quadratic, is_separable, - is_set, - is_singleton, + is_set_indicator, + is_singleton_indicator, is_smooth, + is_locally_smooth, is_strongly_convex +is_func_f = [ + :is_set_indicator, + :is_singleton_indicator, + :is_smooth, + :is_locally_smooth, + ] + +for f in is_func_f + @eval begin + import ProximalCore: $f + $f(t::Term) = $f(t.f) + $f(t::NTuple{N,Term}) where {N} = all($f.(t)) + end +end + #importing properties from AbstractOperators -is_f = [:is_linear, +is_op_f = [:is_linear, :is_eye, :is_null, :is_diagonal, @@ -71,7 +122,7 @@ is_f = [:is_linear, :is_sliced ] -for f in is_f +for f in is_op_f @eval begin import AbstractOperators: $f $f(t::Term) = $f(operator(t)) @@ -79,10 +130,13 @@ for f in is_f end end -is_smooth(t::Term) = is_smooth(t.f) +is_affine_indicator(t::Term) = is_affine_indicator(t.f) && is_linear(t) +is_cone_indicator(t::Term) = is_cone_indicator(t.f) && is_linear(t) is_convex(t::Term) = is_convex(t.f) && is_linear(t) is_quadratic(t::Term) = is_quadratic(t.f) && is_linear(t) +is_generalized_quadratic(t::Term) = is_generalized_quadratic(t.f) && is_linear(t) is_strongly_convex(t::Term) = is_strongly_convex(t.f) && is_full_column_rank(operator(t.A)) +is_separable(t::Term) = is_separable(t.f) && is_diagonal(operator(t.A)) include("proximalOperators_bind.jl") diff --git a/test/runtests.jl b/test/runtests.jl index a256eba..cf4986e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,7 @@ Random.seed!(0) include("test_terms.jl") end - @testset "Problem construction" begin + #=@testset "Problem construction" begin include("test_problem.jl") include("test_build_minimize.jl") end @@ -30,7 +30,7 @@ Random.seed!(0) @testset "End-to-end tests" begin include("test_usage_small.jl") include("test_usage.jl") - end + end=# @testset "Aqua" begin Aqua.test_all(StructuredOptimization; ambiguities=false, piracies=false) diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 890786d..707375e 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -316,3 +316,15 @@ ex3 = ex1-ex2 @test_throws DimensionMismatch MatrixOp(randn(10,20))*Variable(20)+randn(11) @test_throws ErrorException MatrixOp(randn(10,20))*Variable(20)+(3+im) +# Advanced (+) sum +x, y, z, w = Variable(10), Variable(20), Variable(30), Variable(40) +~x, ~y, ~z, ~w = rand(10), rand(20), rand(30), rand(40) +A = randn(10,10) +exA = (z[1:10]+x)+3*(x+z[1:10])+A*(w[1:10]+z[1:10])+(z[1:10]+w[1:10]) +exB = 5*w[1:10]+z[1:10]+z[1:10]+3*y[1:10]+z[1:10] +exC = exA+exB +op = operator(exC) +output = op*(~x,~y,~z,~w) +expected_output = 4*~x+3*~y[1:10]+8*~z[1:10]+6*~w[1:10]+A*(~w[1:10]+~z[1:10]) +@test norm(output-expected_output) < 1e-12 + diff --git a/test/test_terms.jl b/test/test_terms.jl index 6988194..6c165f4 100644 --- a/test/test_terms.jl +++ b/test/test_terms.jl @@ -39,15 +39,15 @@ cf = pi*norm(x,2) @test cf.lambda - pi == 0 @test cf.f(~x) == norm(~x) -cf = 3*norm(X,2,1) +cf = 3*mixednorm(X,2,1) @test cf.lambda - 3 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=1 )) ) -cf = 4*norm(X,2,1,2) +cf = 4*mixednorm(X,1,2) @test cf.lambda - 4 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=2 )) ) -@test_throws ErrorException 4*norm(X,1,2) +@test_throws ErrorException 4*mixednorm(X,1,3) cf = norm(x, 2) <= 2.3 @test cf.lambda == 1 @@ -175,7 +175,7 @@ end cf = 2*norm(x,1) ccf = conj(cf) @test ccf.A == cf.A -@test ccf.f == Conjugate(Postcompose(NormL1(),2)) +@test ccf.f == Conjugate(Postcompose(NormL1(),2.0)) @test_throws ErrorException conj(norm(randn(2,10)*x,1)) cf = 2*norm(x,1) diff --git a/test/test_usage.jl b/test/test_usage.jl index 8d5f2b8..ba8837e 100644 --- a/test/test_usage.jl +++ b/test/test_usage.jl @@ -5,7 +5,7 @@ Random.seed!(0) ################################################################################ println("Testing: regularized least squares, with two variable blocks to make things weird") - +begin m, n1, n2 = 30, 50, 100 A1 = randn(m, n1) @@ -20,6 +20,7 @@ lam2 = 1.0 x1_fpg = Variable(n1) x2_fpg = Variable(n2) expr = ls(A1*x1_fpg + A2*x2_fpg - b) + lam1*norm(x1_fpg, 1) + lam2*norm(x2_fpg, 2) +end prob = problem(expr) @time sol = solve(prob, PANOCplus(tol=1e-10, verbose=false,maxit=20000)) From ecbf7686f4bd94b5f4aab1105ef1e5348a5705d0 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Sat, 15 Nov 2025 15:56:41 +0100 Subject: [PATCH 4/5] Large commit with various changes - Use the new interface of AbstractOperators.jl v0.4 - Add new parent package (OperatorCore) and subpackages (FFTWOperators, DSPOperators) of AbstractOperators - Use TermSet instead of tuple of Terms - Implement parsing for LeastSquaresTerm - Add name field for Variable - rename back L1,2-norm from mixednorm to norm - separate Project.toml for test --- .gitignore | 2 + Manifest.toml | 353 ++++++++++-------- Project.toml | 28 +- src/StructuredOptimization.jl | 9 +- src/calculus/precomposeNonlinear.jl | 4 +- src/calculus/sqrNormL2WithNormalOp.jl | 8 +- src/solvers/build_solve.jl | 85 +++-- src/solvers/minimize.jl | 53 ++- src/solvers/parse.jl | 237 ++++++++++-- src/solvers/terms_extract.jl | 33 +- src/solvers/terms_properties.jl | 10 +- .../expressions/abstractOperator_bind.jl | 4 +- src/syntax/expressions/addition.jl | 2 +- .../expressions/addition_tricky_part.jl | 44 +-- src/syntax/expressions/expression.jl | 4 +- src/syntax/expressions/multiplication.jl | 6 +- src/syntax/terms/proximalOperators_bind.jl | 35 +- src/syntax/terms/term.jl | 184 +++++---- src/syntax/variable.jl | 35 +- test/Project.toml | 43 +++ test/runtests.jl | 7 +- test/test_AbstractOp_binding.jl | 6 +- test/test_build_minimize.jl | 4 +- test/test_expressions.jl | 58 +-- test/test_problem.jl | 160 -------- test/test_proxstuff.jl | 4 +- test/test_terms.jl | 27 +- test/test_usage.jl | 74 ++-- test/test_usage_small.jl | 16 +- 29 files changed, 804 insertions(+), 731 deletions(-) create mode 100644 test/Project.toml diff --git a/.gitignore b/.gitignore index 788274b..899cd07 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ demos/.ipynb_checkpoints/ docs/build/ docs/site/ docs/Manifest.toml + +Manifest.toml diff --git a/Manifest.toml b/Manifest.toml index 5b08d76..cb77f89 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,13 +1,13 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.4" +julia_version = "1.12.1" manifest_format = "2.0" -project_hash = "d7d80843b7c63bcd8962a2e974300665e8f478dc" +project_hash = "c8f5f45579604b7204fcaa029c0a41ea02d98e72" [[deps.ADTypes]] -git-tree-sha1 = "e2478490447631aedba0823d4d7a80b2cc8cdb32" +git-tree-sha1 = "27cecae79e5cc9935255f90c53bb831cc3c870d7" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -version = "1.14.0" +version = "1.18.0" [deps.ADTypes.extensions] ADTypesChainRulesCoreExt = "ChainRulesCore" @@ -34,18 +34,18 @@ version = "1.5.0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.AbstractOperators]] -deps = ["DSP", "FFTW", "FastBroadcast", "LinearAlgebra", "OperatorCore", "RecursiveArrayTools"] +deps = ["FastBroadcast", "LinearAlgebra", "OperatorCore", "Polyester", "Random", "RecursiveArrayTools"] path = "../AbstractOperators" uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" version = "0.4.0" [deps.AbstractOperators.extensions] - CudaExt = "CUDA" - NfftExt = "NFFT" + GpuExt = "GPUArrays" + LinearMapsExt = "LinearMaps" [deps.AbstractOperators.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" [[deps.Accessors]] deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "MacroTools"] @@ -73,9 +73,9 @@ version = "0.1.42" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "f7817e2e585aa6d924fd714df1e2a84be7896c60" +git-tree-sha1 = "7e35fca2bdfba44d797c53dfe63a51fabf39bfc0" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.3.0" +version = "4.4.0" [deps.Adapt.extensions] AdaptSparseArraysExt = "SparseArrays" @@ -91,18 +91,19 @@ version = "1.1.2" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra"] -git-tree-sha1 = "017fcb757f8e921fb44ee063a7aafe5f89b86dd1" +git-tree-sha1 = "d81ae5489e13bc03567d4fbbb06c546a5e53c857" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.18.0" +version = "7.22.0" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" - ArrayInterfaceCUDSSExt = "CUDSS" + ArrayInterfaceCUDSSExt = ["CUDSS", "CUDA"] ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" + ArrayInterfaceMetalExt = "Metal" ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceSparseArraysExt = "SparseArrays" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" @@ -116,6 +117,7 @@ version = "7.18.0" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -131,9 +133,9 @@ version = "1.11.0" [[deps.BenchmarkTools]] deps = ["Compat", "JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] -git-tree-sha1 = "e38fbc49a620f5d0b660d7f543db1009fe0f8336" +git-tree-sha1 = "7fecfb1123b8d0232218e2da0c213004ff15358d" uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -version = "1.6.0" +version = "1.6.3" [[deps.Bessels]] git-tree-sha1 = "4435559dc39793d53a9e3d278e185e920b4619ef" @@ -153,10 +155,10 @@ uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.9+0" [[deps.CPUSummary]] -deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] -git-tree-sha1 = "5a97e67919535d6841172016c9530fd69494e5ec" +deps = ["CpuId", "IfElse", "PrecompileTools", "Preferences", "Static"] +git-tree-sha1 = "f3a21d7fc84ba618a779d1ed2fcca2e682865bab" uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.2.6" +version = "0.2.7" [[deps.CloseOpenIntervals]] deps = ["Static", "StaticArrayInterface"] @@ -177,9 +179,9 @@ uuid = "944b1d66-785c-5afd-91f1-9de20f533193" version = "0.7.8" [[deps.Combinatorics]] -git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +git-tree-sha1 = "8010b6bb3388abe68d95743dcbea77650bb2eddf" uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -version = "1.0.2" +version = "1.0.3" [[deps.CommonSubexpressions]] deps = ["MacroTools"] @@ -194,9 +196,9 @@ version = "1.0.0" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +git-tree-sha1 = "9d8a54ce4b17aa5bdce0ea5c34bc5e7c340d16ad" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.16.0" +version = "4.18.1" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -205,7 +207,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" +version = "1.3.0+1" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -217,9 +219,9 @@ weakdeps = ["InverseFunctions"] CompositionsBaseInverseFunctionsExt = "InverseFunctions" [[deps.ConstructionBase]] -git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" +git-tree-sha1 = "b4b092499347b18a015186eae3042f72267106cb" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.8" +version = "1.6.0" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" @@ -239,9 +241,9 @@ version = "0.3.1" [[deps.DSP]] deps = ["Bessels", "FFTW", "IterTools", "LinearAlgebra", "Polynomials", "Random", "Reexport", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "489db9d78b53e44fb753d225c58832632d74ab10" +git-tree-sha1 = "5989debfc3b38f736e69724818210c67ffee4352" uuid = "717857b8-e6f2-59f4-9121-6e50c889abd2" -version = "0.8.0" +version = "0.8.4" [deps.DSP.extensions] OffsetArraysExt = "OffsetArrays" @@ -249,21 +251,17 @@ version = "0.8.0" [deps.DSP.weakdeps] OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" +[[deps.DSPOperators]] +deps = ["AbstractOperators", "DSP", "FFTW", "LinearAlgebra"] +path = "../AbstractOperators/DSPOperators" +uuid = "d5a72628-6e2f-430e-82f5-561df0bb8116" +version = "0.1.0" [[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +deps = ["OrderedCollections"] +git-tree-sha1 = "e357641bb3e0638d353c4b29ea0e40ea644066a6" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" +version = "0.19.3" [[deps.Dates]] deps = ["Printf"] @@ -284,9 +282,9 @@ version = "1.15.1" [[deps.DifferentiationInterface]] deps = ["ADTypes", "LinearAlgebra"] -git-tree-sha1 = "d86f29074367f1bb92957e8d0b77badd187a97bc" +git-tree-sha1 = "c8d85ecfcbaef899308706bebdd8b00107f3fb43" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -version = "0.6.32" +version = "0.6.54" [deps.DifferentiationInterface.extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" @@ -296,10 +294,13 @@ version = "0.6.32" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] + DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" + DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" - DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" + DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" + DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" DifferentiationInterfaceStaticArraysExt = "StaticArrays" DifferentiationInterfaceSymbolicsExt = "Symbolics" @@ -316,10 +317,13 @@ version = "0.6.32" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" + GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -327,10 +331,9 @@ version = "0.6.32" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +git-tree-sha1 = "7442a5dfe1ebb773c29cc2962a8980f47221d76c" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" +version = "0.9.5" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -343,16 +346,22 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" [[deps.FFTW]] -deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "7de7c78d681078f027389e067864a8d53bd7c3c9" +deps = ["AbstractFFTs", "FFTW_jll", "Libdl", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] +git-tree-sha1 = "97f08406df914023af55ade2f843c39e99c5d969" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.8.1" +version = "1.10.0" + +[[deps.FFTWOperators]] +deps = ["AbstractOperators", "FFTW", "LinearAlgebra", "Polyester"] +path = "../AbstractOperators/FFTWOperators" +uuid = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" +version = "0.1.0" [[deps.FFTW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4d81ed14783ec49ce9f2e168208a12ce1815aa25" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6d6219a004b8cf1e0b4dbe27a2860b8e04eba0be" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.10+3" +version = "3.3.11+0" [[deps.FastBroadcast]] deps = ["ArrayInterface", "LinearAlgebra", "Polyester", "Static", "StaticArrayInterface", "StrideArraysCore"] @@ -366,9 +375,9 @@ version = "1.11.0" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "a2df1b776752e3f344e5116c06d75a10436ab853" +git-tree-sha1 = "ba6ce081425d0afb2bedd00d9884464f764a9225" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.38" +version = "1.2.2" [deps.ForwardDiff.extensions] ForwardDiffStaticArraysExt = "StaticArrays" @@ -394,9 +403,9 @@ version = "0.1.1" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"] -git-tree-sha1 = "0f14a5456bdc6b9731a5682f439a672750a09e48" +git-tree-sha1 = "ec1debd61c300961f98064cfb21287613ad7f303" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2025.0.4+0" +version = "2025.2.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -414,9 +423,9 @@ weakdeps = ["Dates", "Test"] InverseFunctionsTestExt = "Test" [[deps.IrrationalConstants]] -git-tree-sha1 = "e2222959fbc6c19554dc15174c81bf7bf3aa691c" +git-tree-sha1 = "b2d91fe939cae05960e760110b328288867b5758" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.4" +version = "0.2.6" [[deps.IterTools]] git-tree-sha1 = "42d5f897009e7ff2cf88db414a389e5ed1bdd023" @@ -429,28 +438,29 @@ git-tree-sha1 = "59545b0a2b27208b0650df0a46b8e3019f85055b" uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" version = "0.9.4" -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "a007feb38b422fbdab534406aeca1b86823cb4d6" +git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.7.0" +version = "1.7.1" [[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +deps = ["Dates", "Logging", "Parsers", "PrecompileTools", "StructUtils", "UUIDs", "Unicode"] +git-tree-sha1 = "eb04df293213df64ddd720c86de3c431f5f8ccf1" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" +version = "1.2.1" + + [deps.JSON.extensions] + JSONArrowExt = ["ArrowTypes"] + + [deps.JSON.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b" +git-tree-sha1 = "411eccfe8aba0814ffa0fdf4860913ed09c34975" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.1" +version = "1.14.3" [deps.JSON3.extensions] JSON3ArrowExt = ["ArrowTypes"] @@ -458,6 +468,11 @@ version = "1.14.1" [deps.JSON3.weakdeps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" +[[deps.JuliaSyntaxHighlighting]] +deps = ["StyledStrings"] +uuid = "ac6e5ff7-fb65-4e79-a425-ec3bc9c03011" +version = "1.12.0" + [[deps.LayoutPointers]] deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] git-tree-sha1 = "a9eaadb366f5493a5654e843864c13d8b107548c" @@ -475,24 +490,24 @@ uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" version = "0.6.4" [[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.6.0+0" +version = "8.11.1+1" [[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" version = "1.11.0" [[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.7.2+0" +version = "1.9.0+0" [[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +deps = ["Artifacts", "Libdl", "OpenSSL_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" +version = "1.11.3+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -501,7 +516,7 @@ version = "1.11.0" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -version = "1.11.0" +version = "1.12.0" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] @@ -525,14 +540,14 @@ version = "1.11.0" [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "oneTBB_jll"] -git-tree-sha1 = "5de60bc6cb3899cd318d80d627560fae2e2d99ae" +git-tree-sha1 = "282cadc186e7b2ae0eeadbd7a4dffed4196ae2aa" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2025.0.1+1" +version = "2025.2.0+0" [[deps.MacroTools]] -git-tree-sha1 = "72aebe0b5051e5143a079a4685a46da330a40472" +git-tree-sha1 = "1e0228a030642014fe5cfe68c2c0a818f9e3f522" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.15" +version = "0.5.16" [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" @@ -540,20 +555,15 @@ uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" version = "0.1.8" [[deps.Markdown]] -deps = ["Base64"] +deps = ["Base64", "JuliaSyntaxHighlighting", "StyledStrings"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" version = "1.11.0" [[deps.MathOptInterface]] deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON3", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test"] -git-tree-sha1 = "6723502b2135aa492a65be9633e694482a340ee7" +git-tree-sha1 = "a2cbab4256690aee457d136752c404e001f27768" uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" -version = "1.38.0" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.6+0" +version = "1.46.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -561,23 +571,23 @@ version = "1.11.0" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.12.12" +version = "2025.5.20" [[deps.MutableArithmetics]] deps = ["LinearAlgebra", "SparseArrays", "Test"] -git-tree-sha1 = "491bdcdc943fcbc4c005900d7463c9f216aabf4c" +git-tree-sha1 = "22df8573f8e7c593ac205455ca088989d0a2c7a0" uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" -version = "1.6.4" +version = "1.6.7" [[deps.NaNMath]] deps = ["OpenLibm_jll"] -git-tree-sha1 = "cc0a5deefdb12ab3a096f00a6d42133af4560d71" +git-tree-sha1 = "9b8215b1ee9e78a293f99797cd31375471b2bcae" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.1.2" +version = "1.1.3" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" +version = "1.3.0" [[deps.OSQP]] deps = ["Libdl", "LinearAlgebra", "MathOptInterface", "OSQP_jll", "SparseArrays"] @@ -594,12 +604,17 @@ version = "0.600.200+0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.27+1" +version = "0.3.29+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+4" +version = "0.8.7+0" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.5.1+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] @@ -610,23 +625,23 @@ version = "0.5.6+0" [[deps.OperatorCore]] path = "../OperatorCore" uuid = "3945cd23-d97e-4db0-9df2-35342dbd287d" -version = "0.1.0" +version = "0.1.1" [[deps.OrderedCollections]] -git-tree-sha1 = "cc4054e898b852042d7b503313f7ad03de99c3dd" +git-tree-sha1 = "05868e21324cede2207c6f0f466b4bfef6d5e7ee" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.8.0" +version = "1.8.1" [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +git-tree-sha1 = "7d2f8f21da5db6a806faf7b9b292296da42b2810" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" +version = "2.8.3" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.11.0" +version = "1.12.0" [deps.Pkg.extensions] REPLExt = "REPL" @@ -636,9 +651,9 @@ version = "1.11.0" [[deps.Polyester]] deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Static", "StaticArrayInterface", "StrideArraysCore", "ThreadingUtilities"] -git-tree-sha1 = "6d38fea02d983051776a856b7df75b30cf9a3c1f" +git-tree-sha1 = "6f7cd22a802094d239824c57d94c8e2d0f7cfc7d" uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" -version = "0.7.16" +version = "0.7.18" [[deps.PolyesterWeave]] deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] @@ -648,33 +663,33 @@ version = "0.2.2" [[deps.Polynomials]] deps = ["LinearAlgebra", "OrderedCollections", "RecipesBase", "Requires", "Setfield", "SparseArrays"] -git-tree-sha1 = "555c272d20fc80a2658587fb9bbda60067b93b7c" +git-tree-sha1 = "972089912ba299fba87671b025cd0da74f5f54f7" uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" -version = "4.0.19" +version = "4.1.0" [deps.Polynomials.extensions] PolynomialsChainRulesCoreExt = "ChainRulesCore" PolynomialsFFTWExt = "FFTW" - PolynomialsMakieCoreExt = "MakieCore" + PolynomialsMakieExt = "Makie" PolynomialsMutableArithmeticsExt = "MutableArithmetics" [deps.Polynomials.weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" + Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +git-tree-sha1 = "07a921781cab75691315adc645096ed5e370cb77" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" +version = "1.3.3" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +git-tree-sha1 = "0f27480397253da18fe2c12a4ba4eb9eb208bf3d" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" +version = "1.5.0" [[deps.Printf]] deps = ["Unicode"] @@ -682,6 +697,7 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" version = "1.11.0" [[deps.Profile]] +deps = ["StyledStrings"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" version = "1.11.0" @@ -702,6 +718,10 @@ deps = ["IterativeSolvers", "LinearAlgebra", "OSQP", "ProximalCore", "SparseArra path = "../ProximalOperators.jl" uuid = "a725b495-10eb-56fe-b38b-717eba820537" version = "0.17.0" +weakdeps = ["RecursiveArrayTools"] + + [deps.ProximalOperators.extensions] + RecursiveArrayToolsExt = "RecursiveArrayTools" [[deps.Random]] deps = ["SHA"] @@ -715,30 +735,34 @@ uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" version = "1.3.4" [[deps.RecursiveArrayTools]] -deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "32f824db4e5bab64e25a12b22483a30a6b813d08" +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"] +git-tree-sha1 = "51bdb23afaaa551f923a0e990f7c44a4451a26f1" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.27.4" +version = "3.39.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" RecursiveArrayToolsForwardDiffExt = "ForwardDiff" + RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] RecursiveArrayToolsStructArraysExt = "StructArrays" + RecursiveArrayToolsTablesExt = ["Tables"] RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" [deps.RecursiveArrayTools.weakdeps] FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -755,9 +779,9 @@ version = "1.3.1" [[deps.RuntimeGeneratedFunctions]] deps = ["ExprTools", "SHA", "Serialization"] -git-tree-sha1 = "04c968137612c4a5629fa531334bb81ad5680f00" +git-tree-sha1 = "2f609ec2295c452685d3142bc4df202686e555d2" uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" -version = "0.5.13" +version = "0.5.16" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -768,6 +792,11 @@ git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" uuid = "94e857df-77ce-4151-89e5-788b33177be4" version = "0.1.0" +[[deps.SciMLPublic]] +git-tree-sha1 = "ed647f161e8b3f2973f24979ec074e8d084f1bee" +uuid = "431bcebd-1456-4ced-9d72-93c2757fff0b" +version = "1.0.0" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" version = "1.11.0" @@ -781,13 +810,13 @@ version = "1.1.2" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.11.0" +version = "1.12.0" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "64cca0c26b4f31ba18f13f6c12af7c85f478cfde" +git-tree-sha1 = "f2685b435df2613e25fc10ad8c26dddb8640f547" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.5.0" +version = "2.6.1" [deps.SpecialFunctions.extensions] SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -796,10 +825,10 @@ version = "2.5.0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [[deps.Static]] -deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools"] -git-tree-sha1 = "f737d444cb0ad07e61b3c1bef8eb91203c321eff" +deps = ["CommonWorldInvalidations", "IfElse", "PrecompileTools", "SciMLPublic"] +git-tree-sha1 = "49440414711eddc7227724ae6e570c7d5559a086" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "1.2.0" +version = "1.3.1" [[deps.StaticArrayInterface]] deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Static"] @@ -816,9 +845,9 @@ version = "1.8.0" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +git-tree-sha1 = "6ab403037779dae8c514bad259f32a447262455a" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" +version = "1.4.4" [[deps.Statistics]] deps = ["LinearAlgebra"] @@ -832,9 +861,9 @@ weakdeps = ["SparseArrays"] [[deps.StrideArraysCore]] deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] -git-tree-sha1 = "f35f6ab602df8413a50c4a25ca14de821e8605fb" +git-tree-sha1 = "83151ba8065a73f53ca2ae98bc7274d817aa30f2" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" -version = "0.5.7" +version = "0.5.8" [[deps.StructTypes]] deps = ["Dates", "UUIDs"] @@ -842,6 +871,30 @@ git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" version = "1.11.0" +[[deps.StructUtils]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "79529b493a44927dd5b13dde1c7ce957c2d049e4" +uuid = "ec057cc2-7a8d-4b58-b3b3-92acb9f63b42" +version = "2.6.0" + + [deps.StructUtils.extensions] + StructUtilsMeasurementsExt = ["Measurements"] + StructUtilsTablesExt = ["Tables"] + + [deps.StructUtils.weakdeps] + Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[[deps.StructuredOptimization]] +deps = ["AbstractOperators", "Combinatorics", "DSP", "DSPOperators", "DifferentiationInterface", "FFTW", "FFTWOperators", "LinearAlgebra", "ProximalAlgorithms", "ProximalCore", "ProximalOperators", "RecursiveArrayTools"] +path = "." +uuid = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" +version = "0.5.0" + +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -849,13 +902,19 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.7.0+0" +version = "7.8.3+2" [[deps.SymbolicIndexingInterface]] deps = ["Accessors", "ArrayInterface", "RuntimeGeneratedFunctions", "StaticArraysCore"] -git-tree-sha1 = "d6c04e26aa1c8f7d144e1a8c47f1c73d3013e289" +git-tree-sha1 = "94c58884e013efff548002e8dc2fdd1cb74dfce5" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.38" +version = "0.3.46" + + [deps.SymbolicIndexingInterface.extensions] + SymbolicIndexingInterfacePrettyTablesExt = "PrettyTables" + + [deps.SymbolicIndexingInterface.weakdeps] + PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" [[deps.TOML]] deps = ["Dates"] @@ -868,18 +927,6 @@ git-tree-sha1 = "c39caef6bae501e5607a6caf68dd9ac6e8addbcb" uuid = "9449cd9e-2762-5aa3-a617-5413e99d722e" version = "0.4.4" -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" @@ -892,9 +939,9 @@ version = "1.11.0" [[deps.ThreadingUtilities]] deps = ["ManualMemory"] -git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" +git-tree-sha1 = "d969183d3d244b6c33796b5ed01ab97328f2db85" uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" -version = "0.5.2" +version = "0.5.5" [[deps.TranscodingStreams]] git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" @@ -913,25 +960,25 @@ version = "1.11.0" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" +version = "1.3.1+2" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.11.0+0" +version = "5.15.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.59.0+0" +version = "1.64.0+1" [[deps.oneTBB_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d5a767a3bb77135a99e433afe0eb14cd7f6914c3" +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl"] +git-tree-sha1 = "1350188a69a6e46f799d3945beef36435ed7262f" uuid = "1317d2d5-d96f-522e-a858-c73665f53c3e" -version = "2022.0.0+0" +version = "2022.0.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" +version = "17.5.0+2" diff --git a/Project.toml b/Project.toml index 978bcc9..746b42a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,47 +6,27 @@ version = "0.5.0" AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" +DSPOperators = "d5a72628-6e2f-430e-82f5-561df0bb8116" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FFTWOperators = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -OperatorCore = "3945cd23-d97e-4db0-9df2-35342dbd287d" ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -[sources] -AbstractOperators = {path = "../AbstractOperators"} -OperatorCore = {path = "../OperatorCore"} -ProximalAlgorithms = {path = "../ProximalAlgorithms.jl"} -ProximalCore = {path = "../ProximalCore.jl"} -ProximalOperators = {path = "../ProximalOperators.jl"} -WaveletOperators = {path = "../AbstractOperators/WaveletOperators"} - [compat] AbstractOperators = "0.4" -Aqua = "0.8" Combinatorics = "1.0.2" DSP = "0.5.1 - 0.8" +DSPOperators = "0.1" DifferentiationInterface = "0.6" FFTW = "1" +FFTWOperators = "0.1" LinearAlgebra = "1" -OperatorCore = "0.1" ProximalAlgorithms = "0.8" ProximalCore = "0.2" ProximalOperators = "0.17" -Random = "1" RecursiveArrayTools = "1 - 3" -Test = "1" -WaveletOperators = "0.1" julia = "1.10" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -WaveletOperators = "f3582904-6f60-4bbd-985d-55eab799bc9d" - -[targets] -test = ["Aqua", "LinearAlgebra", "Random", "Test", "WaveletOperators"] diff --git a/src/StructuredOptimization.jl b/src/StructuredOptimization.jl index cc1082c..7a3f5bb 100644 --- a/src/StructuredOptimization.jl +++ b/src/StructuredOptimization.jl @@ -3,16 +3,21 @@ module StructuredOptimization using LinearAlgebra using RecursiveArrayTools using ProximalCore -using AbstractOperators +using AbstractOperators, DSPOperators, FFTWOperators using ProximalOperators using ProximalAlgorithms using Combinatorics: permutations, powerset -using OperatorCore +using ProximalAlgorithms: IterativeAlgorithm, override_parameters ProximalAlgorithms.value_and_gradient(f, x) = begin y, fy = gradient(f, x) return fy, y end +ProximalAlgorithms.value_and_gradient!(grad_f_x, f, x) = begin + fy = gradient!(grad_f_x, f, x) + return fy +end + abstract type AbstractExpression end include("syntax/variable.jl") diff --git a/src/calculus/precomposeNonlinear.jl b/src/calculus/precomposeNonlinear.jl index 19dec7c..110ad3a 100644 --- a/src/calculus/precomposeNonlinear.jl +++ b/src/calculus/precomposeNonlinear.jl @@ -15,9 +15,9 @@ struct PrecomposeNonlinear{P, end function PrecomposeNonlinear(g::P, G::T) where {P, T} - t, s = domainType(G), size(G,2) + t, s = domain_type(G), size(G,2) bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) - t, s = codomainType(G), size(G,1) + t, s = codomain_type(G), size(G,1) bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) bufC2 = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)) PrecomposeNonlinear{P, T, typeof(bufD), typeof(bufC)}(g, G, bufD, bufC, bufC2) diff --git a/src/calculus/sqrNormL2WithNormalOp.jl b/src/calculus/sqrNormL2WithNormalOp.jl index f84ab7b..50276aa 100644 --- a/src/calculus/sqrNormL2WithNormalOp.jl +++ b/src/calculus/sqrNormL2WithNormalOp.jl @@ -26,9 +26,9 @@ squared norm of `L * x`, but rather the squared norm of `Lᴴ * L * x` (i.e. the squared norm of the gradient). Most algorithms, however, tolerate this difference, and it is much faster to compute. """ -struct SqrNormL2WithNormalOp{T,SC,L<:AbstractOperator} +struct SqrNormL2WithNormalOp{T,SC,L<:AbstractOperator,L2<:AbstractOperator} A::L - AᴴA::L + AᴴA::L2 lambda::T function SqrNormL2WithNormalOp(A, lambda) @assert A isa AbstractOperator @@ -36,8 +36,8 @@ struct SqrNormL2WithNormalOp{T,SC,L<:AbstractOperator} if any(lambda .< 0) error("coefficients in λ must be nonnegative") else - AᴴA = AbstractOperators.get_normal_op(A) - new{typeof(lambda),all(lambda .> 0),typeof(A)}(A, AᴴA, lambda) + AᴴA = A' * A + new{typeof(lambda),all(lambda .> 0),typeof(A),typeof(AᴴA)}(A, AᴴA, lambda) end end end diff --git a/src/solvers/build_solve.jl b/src/solvers/build_solve.jl index 64aa56a..f55032a 100644 --- a/src/solvers/build_solve.jl +++ b/src/solvers/build_solve.jl @@ -1,11 +1,11 @@ -const ForwardBackwardSolver = ProximalAlgorithms.IterativeAlgorithm +export suggest_algorithm """ - parse_problem(terms::Tuple, solver::ForwardBackwardSolver) + parse_problem(terms::TermSet, solver::IterativeAlgorithm) -Takes as input a tuple containing the terms defining the problem and the solver. +Takes as input a TermSet containing the terms defining the problem and the solver. -Returns a tuple containing the optimization variables and the problem terms +Returns a TermSet containing the optimization variables and the problem terms to be fed into the solver. # Example @@ -21,15 +21,16 @@ julia> p = problem( ls(A*x - b ) , norm(x) <= 1 ); julia> StructuredOptimization.parse_problem(p, PANOCplus()); ``` """ -function parse_problem(terms::NTuple{N,StructuredOptimization.Term}, algorithm::T, return_partial::Bool = false) where {N,T <: ForwardBackwardSolver} +function parse_problem(terms::Union{Term,TermSet}, algorithm::T, return_partial::Bool = false) where {T <: IterativeAlgorithm} + terms = terms isa TermSet ? terms : TermSet(terms) assumptions = ProximalAlgorithms.get_assumptions(algorithm) - variables = StructuredOptimization.extract_variables(terms) + variables = extract_variables(terms) remaining_terms = terms kwargs = Dict{Symbol, Any}() for assumption in assumptions for term_selection in reverse(collect(powerset(remaining_terms, 1))) - term_selection = tuple(term_selection...) - preparation_result = StructuredOptimization.prepare(term_selection, assumption, variables) + term_selection = TermSet(term_selection...) + preparation_result = prepare(term_selection, assumption, variables) if preparation_result !== nothing term_selection = collect(term_selection) remaining_terms = setdiff(remaining_terms, term_selection) @@ -44,14 +45,16 @@ function parse_problem(terms::NTuple{N,StructuredOptimization.Term}, algorithm:: return return_partial ? (kwargs, remaining_terms) : nothing end -function print_diagnostics(terms::NTuple{N,StructuredOptimization.Term}, algorithm::T) where {N,T <: ForwardBackwardSolver} +function print_diagnostics(terms::Union{Term,TermSet}, algorithm::T) where {T <: IterativeAlgorithm} + terms = terms isa TermSet ? terms : TermSet(terms) kwargs, remaining_terms = parse_problem(terms, algorithm, true) - print("The algorithm $algorithm assumes problem of form: ") + print("The algorithm $(typeof(algorithm).name.name) assumes problem of form: ") show(ProximalAlgorithms.get_assumptions(algorithm)) + println() if !isempty(kwargs) println("Successfully prepared the following terms:") for (key, value) in kwargs - println(" - $key: $value") + println(" - $key: $(typeof(value))") end end println("The following terms could not be prepared:") @@ -60,7 +63,8 @@ function print_diagnostics(terms::NTuple{N,StructuredOptimization.Term}, algorit end end -function parse_problem(terms::NTuple{N,StructuredOptimization.Term}) where {N} +function parse_problem(terms::Union{Term,TermSet}) + terms = terms isa TermSet ? terms : TermSet(terms) for algorithm in ProximalAlgorithms.get_algorithms() result = parse_problem(terms, algorithm) if result !== nothing @@ -70,9 +74,10 @@ function parse_problem(terms::NTuple{N,StructuredOptimization.Term}) where {N} return nothing end -function suggest_algorithm(terms::NTuple{N,StructuredOptimization.Term}) where {N} +function suggest_algorithm(terms::Union{Term,TermSet}, algorithms = ProximalAlgorithms.get_algorithms()) + terms = terms isa TermSet ? terms : TermSet(terms) suitable_algs = [] - for algorithm in ProximalAlgorithms.get_algorithms() + for algorithm in algorithms result = parse_problem(terms, algorithm) if result !== nothing push!(suitable_algs, algorithm) @@ -81,7 +86,8 @@ function suggest_algorithm(terms::NTuple{N,StructuredOptimization.Term}) where { return suitable_algs end -function print_diagnostics(terms::NTuple{N,StructuredOptimization.Term}) where {N} +function print_diagnostics(terms::Union{Term,TermSet}) + terms = terms isa TermSet ? terms : TermSet(terms) best_algorithm, best_algorithm_remaining_terms = nothing, Inf for algorithm in ProximalAlgorithms.get_algorithms() _, remaining_terms = parse_problem(terms, algorithm, true) @@ -97,9 +103,11 @@ end export solve """ - solve(terms::Tuple, solver::ForwardBackwardSolver) + solve(terms::Union{Term,TermSet}; kwargs...) + solve(terms::Union{Term,TermSet}, solver::IterativeAlgorithm; kwargs...) + solve(terms::Union{Term,TermSet}, solvers::Union{AbstractVector,Tuple}; kwargs...) -Takes as input a tuple containing the terms defining the problem and the solver options. +Takes as input a Term/TermSet containing the terms defining the problem and the solver options. Solves the problem returning a tuple containing the iterations taken and the build solver. @@ -113,32 +121,57 @@ julia> A, b = randn(10,4), randn(10); julia> p = problem(ls(A*x - b ), norm(x) <= 1); -julia> solve(p, PANOCplus()); +julia> solve(p, PANOCplus(); maxiter=10); julia> ~x ``` """ -function solve(terms::Tuple, solver::ForwardBackwardSolver) +function solve(terms::Union{Term,TermSet}, solvers::Union{<:AbstractVector{IterativeAlgorithm},<:Tuple{Vararg{IterativeAlgorithm}}}; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) + for solver in solvers + result = parse_problem(terms, solver) + if result isa Nothing + continue + end + _, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) + ~x .= x_star isa Tuple ? x_star[1] : x_star + return x, it + end + if length(solvers) == 1 + print_diagnostics(terms, solvers[1]) + error("Sorry, I cannot parse this problem for solver of type $(typeof(solvers[1]).parameters[1])") + else + print_diagnostics(terms) + error("Sorry, I cannot parse this problem for any of the provided solvers") + end +end + +function solve(terms::Union{Term,TermSet}, solver::IterativeAlgorithm; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) result = parse_problem(terms, solver) if result === nothing print_diagnostics(terms, solver) - error("Sorry, I cannot parse this problem for solver of type $(solver)") + error("Sorry, I cannot parse this problem for solver of type $(typeof(solver).parameters[1])") end - _, kwargs, x = result - x_star, it = solver(; x0 = ~x, kwargs...) + _, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) ~x .= x_star isa Tuple ? x_star[1] : x_star return x, it end -function solve(terms::Tuple) +function solve(terms::Union{Term,TermSet}; kwargs...) + terms = terms isa TermSet ? terms : TermSet(terms) result = parse_problem(terms) if result === nothing print_diagnostics(terms) error("Sorry, I cannot find a suitable solver for this problem") end - solver, kwargs, x = result - @show solver - x_star, it = solver(; x0 = ~x, kwargs...) + solver, term_kwargs, x = result + solver = override_parameters(solver; kwargs...) + x_star, it = solver(; x0 = ~x, term_kwargs...) ~x .= x_star return x, it end diff --git a/src/solvers/minimize.jl b/src/solvers/minimize.jl index 288b35a..dc37c3b 100644 --- a/src/solvers/minimize.jl +++ b/src/solvers/minimize.jl @@ -1,4 +1,4 @@ -export problem, @minimize +export problem, @minimize, @term """ problems(terms...) @@ -19,23 +19,42 @@ julia> p = problem(ls(A*x-b), norm(x) <= 1) ``` """ -function problem(terms::Vararg) - cf = () - for i = 1:length(terms) - cf = (cf...,terms[i]...) - end - return cf +problem(terms...) = begin + flattened_terms = Term[] + for t in terms + if t isa TermSet + append!(flattened_terms, t.terms) + elseif t isa Term + push!(flattened_terms, t) + else + error("All arguments must be of type Term or TermSet") + end + end + TermSet(flattened_terms...) end function expand_terms_with_repr(expr) if expr isa Expr && expr.head == :call && expr.args[1] == :+ - terms = map(t -> :(Term($(esc(t)), $(string(t)))), expr.args[2:end]) - return :(tuple($(terms...))) + return Tuple(map(t -> :(Term($(esc(t)), $(string(t)))), expr.args[2:end])) + elseif expr isa Symbol + return (esc(expr),) + elseif expr isa Expr && expr.head == :tuple + return Tuple(first.(expand_terms_with_repr.(expr.args))) else - return :(Term($(esc(expr)), $(string(expr)))) + return (:(Term($(esc(expr)), $(string(expr)))),) end end +""" + @term expr + +Records the code representation of the term. Useful if later we want to print the term, e.g. when debugging. +""" +macro term(expr) + terms = expand_terms_with_repr(expr) + return Expr(:block, terms...) +end + """ @minimize cost [st ctr] [with slv_opt] @@ -66,17 +85,21 @@ of iterations spent by the solver algorithm. """ macro minimize(cf::Union{Expr, Symbol}) cost = expand_terms_with_repr(cf) - return :(solve(problem($cost))) + problem_expr = Expr(:call, :problem, cost...) + return :(solve($problem_expr)) end macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}) cost = expand_terms_with_repr(cf) if s == :st constraints = expand_terms_with_repr(cstr) - return :(solve(problem($cost, $constraints))) + terms = (cost..., constraints...) + problem_expr = Expr(:call, :problem, terms...) + return :(solve($problem_expr)) elseif s == :with solver = esc(cstr) - return :(solve(problem($cost), $solver)) + problem_expr = Expr(:call, :problem, cost...) + return :(solve($problem_expr, $solver)) else error("wrong symbol after cost function! use `st` or `with`") end @@ -88,5 +111,7 @@ macro minimize(cf::Union{Expr, Symbol}, s::Symbol, cstr::Union{Expr, Symbol}, w: constraints = expand_terms_with_repr(cstr) w != :with && error("wrong symbol after constraints! use `with`") solver = esc(slv) - return :(solve(problem($cost, $constraints), $solver)) + terms = (cost..., constraints...) + problem_expr = Expr(:call, :problem, terms...) + return :(solve($problem_expr, $solver)) end diff --git a/src/solvers/parse.jl b/src/solvers/parse.jl index 6336c36..3a6dc75 100644 --- a/src/solvers/parse.jl +++ b/src/solvers/parse.jl @@ -30,12 +30,12 @@ function can_be_separable_sum(variable_bags) if length(term_list) > 1 # more than one term for this variable # Check if any of the terms are sliced operators = [get_operators_for_var(term, var) for term in term_list] - slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] for i in eachindex(operators) - if OperatorCore.is_sliced(operators[i]) + if is_sliced(operators[i]) # This operator is sliced, check if it is overlapping with any other sliced operator for j in i+1:length(operators) - if OperatorCore.is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + if is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) return false end end @@ -49,17 +49,17 @@ function can_be_separable_sum(variable_bags) end function get_unseparable_pairs(variable_bags) - incompatibilities = Dict{StructuredOptimization.Term, Set{StructuredOptimization.Term}}() + incompatibilities = Dict{Term, Set{Term}}() for (var, term_list) in variable_bags if length(term_list) > 1 # more than one term for this variable # Check if any of the terms are sliced operators = [get_operators_for_var(term, var) for term in term_list] - slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] for i in eachindex(operators) - if OperatorCore.is_sliced(operators[i]) + if is_sliced(operators[i]) # This operator is sliced, check if it is overlapping with any other sliced operator for j in i+1:length(operators) - if OperatorCore.is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) + if is_sliced(operators[j]) && any(slicing_masks[i] .&& slicing_masks[j]) add_to_incompatibilities(incompatibilities, term_list[i], term_list[j]) end end @@ -88,9 +88,15 @@ function merge_function_with_operator(op, f, disp, λ) end elseif is_AAc_diagonal(op) f = Precompose(f, op, diag_AAc(op), disp) - else + elseif is_linear(op) # we assume that prox will not be called on this term because it will not give a valid result f = Precompose(f, op, 1, disp) + else + # we assume that prox will not be called on this term because it will not give a valid result + if disp != 0 + op = AbstractOperators.AffineAdd(op, disp) + end + f = PrecomposeNonlinear(f, op) end return λ == 1 ? f : Postcompose(f, λ) end @@ -99,7 +105,7 @@ unsatisfied_properties(term, assumptions::ProximalAlgorithms.AssumptionItem) = [ does_satisfy(term, assumptions::ProximalAlgorithms.AssumptionItem) = all(property_func(term) for property_func in assumptions.second) function prepare(term::Term, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where N - if does_satisfy(term, assumption.func) && (!(ProximalCore.is_proximable in assumption.func.second) || OperatorCore.is_AAc_diagonal(term.A.L)) + if does_satisfy(term, assumption.func) && (!(ProximalCore.is_proximable in assumption.func.second) || is_AAc_diagonal(term.A.L)) op = extract_operators(variables, term) disp = displacement(term) return (assumption.func.first => merge_function_with_operator(op, term.f, disp, term.lambda),) @@ -118,7 +124,7 @@ function print_diagnostics(term::Term, assumption::ProximalAlgorithms.SimpleTerm end end -function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple{M, Variable}) where {M} +function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple{N, Variable}) where {N} fs = () for var in variables if haskey(variable_bags, var) @@ -138,7 +144,7 @@ function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple else idx = op.idx end - idxs = (idxs..., OperatorCore.get_slicing_mask(op)) + idxs = (idxs..., get_slicing_mask(op)) end fs = (fs..., SlicedSeparableSum(fxi,idxs)) else @@ -153,7 +159,7 @@ function prepare_proximable_single_var_per_term(variable_bags, variables::NTuple return SeparableSum(fs) end -function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{M, Variable}) where {N,M} +function prepare(terms::TermSet, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where {N} if length(terms) == 1 return prepare(terms[1], assumption, variables) end @@ -173,9 +179,9 @@ function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTe return (assumption.func.first => prepare_proximable_single_var_per_term(variable_bags, variables),) else op = extract_operators(variables, terms) - idxs = OperatorCore.get_slicing_expr(op) - op = OperatorCore.remove_slicing(op) - hcat_ops = Tuple(op[i] for i in eachindex(op.A)) + idxs = get_slicing_expr(op) + op = remove_slicing(op) + hcat_ops = tuple([op[i] for i in eachindex(op.A)]...) μs = AbstractOperators.diag_AAc(op) f = extract_functions(terms) return (assumption.func.first => PrecomposedSlicedSeparableSum(f.fs, idxs, hcat_ops, μs),) @@ -184,7 +190,7 @@ function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTe fs = () for term in terms if is_linear(term) - f = merge_function_with_operator(operator(term), term.f, displacement(term), term.lambda) + f = merge_function_with_operator(extract_operators(variables, term), term.f, displacement(term), term.lambda) else f = extract_functions(term) op = extract_affines(variables, term) @@ -193,11 +199,11 @@ function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTe end fs = (fs..., f) end - return (assumption.func.first => SeparableSum(fs),) + return (assumption.func.first => ProximalOperators.Sum(fs),) end end -function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{M, Variable}) where {N,M} +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.SimpleTerm, variables::NTuple{N, Variable}) where {N} if length(terms) == 1 print_diagnostics(terms[1], assumption, variables) return @@ -246,7 +252,7 @@ end function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where N op = affine(term) repr = term.repr !== nothing ? term.repr : string(term) - if OperatorCore.is_eye(op) + if is_eye(op) problematic_properties = unsatisfied_properties(term.f, assumption.func) println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") else @@ -271,7 +277,7 @@ function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTe print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func), variables) end -function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{M, Variable}) where {N,M} +function prepare(terms::TermSet, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where {N} if length(terms) == 1 return prepare(terms[1], assumption, variables) end @@ -287,11 +293,11 @@ function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.Operator end end -function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{M, Variable}) where {N,M} +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.OperatorTerm, variables::NTuple{N, Variable}) where {N} op = extract_affines(variables, terms) f = extract_functions(terms) repr = string(terms) - if OperatorCore.is_eye(op) + if is_eye(op) for term in terms problematic_properties = unsatisfied_properties(term.f, assumption.func) println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") @@ -317,7 +323,7 @@ function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithm print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func), variables) end -function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {M} +function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} op = extract_affines(variables, term) f = extract_functions(term) if does_satisfy(op, assumption.operator) && does_satisfy(f, assumption.func₁) @@ -334,18 +340,18 @@ function prepare(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfi # try preparing as a simple term tup = prepare(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) if tup !== nothing && length(variables) > 1 - example_input = ArrayPartition(Tuple(~var for var in variables)) + example_input = ArrayPartition(tuple([~var for var in variables]...)) tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) end return tup end end -function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {M} +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} op = affine(term) f = extract_functions(term) repr = term.repr !== nothing ? term.repr : string(term) - if OperatorCore.is_eye(op) + if is_eye(op) problematic_properties = unsatisfied_properties(term.f, assumption.func₁) println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") else @@ -369,7 +375,7 @@ function print_diagnostics(term::Term, assumption::ProximalAlgorithms.OperatorTe print_diagnostics(term, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) end -function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {N,M} +function prepare(terms::TermSet, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} if length(terms) == 1 return prepare(terms[1], assumption, variables) end @@ -392,14 +398,14 @@ function prepare(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.Operator tup = prepare(terms, ProximalAlgorithms.SimpleTerm(assumption.func₂), variables) end if tup !== nothing && length(variables) > 1 - example_input = ArrayPartition(Tuple(~var for var in variables)) + example_input = ArrayPartition(tuple([~var for var in variables]...)) tup = (tup..., assumption.operator.first => AbstractOperators.Eye(example_input)) end return tup end end -function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{M, Variable}) where {N,M} +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.OperatorTermWithInfimalConvolution, variables::NTuple{N, Variable}) where {N} if length(terms) == 1 print_diagnostics(terms[1], assumption, variables) return @@ -407,7 +413,7 @@ function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithm op = affine(terms[1].A) f = extract_functions(terms) repr = string(terms) - if OperatorCore.is_eye(op) + if is_eye(op) for term in terms problematic_properties = unsatisfied_properties(term.f, assumption.func₁) println("Term $repr does not satisfy required properties: $(join(problematic_properties, ", "))") @@ -440,3 +446,174 @@ function print_diagnostics(terms::NTuple{N, Term}, assumption::ProximalAlgorithm println("When trying to prepare the term as a simple term:") print_diagnostics(terms, ProximalAlgorithms.SimpleTerm(assumption.func₁), variables) end + +function prepare(term::Term, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where N + f = term.f + f_is_ls = f isa ProximalOperators.LeastSquares || f isa ProximalOperators.SqrNormL2 || f isa SqrNormL2WithNormalOp + if !f_is_ls + return nothing + end + if f isa SqrNormL2WithNormalOp + lambda = term.lambda * f.lambda + op = term.f.A + b = displacement(op) + op = remove_displacement(op) + else + lambda = term.lambda + op = extract_operators(variables, term) + b = displacement(term) + end + if !does_satisfy(op, assumption.operator) + return nothing + end + if lambda != 1 + op = lambda * op + b = lambda * b + end + return ( + assumption.operator.first => op, + assumption.b => b, + ) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where N + op = extract_operators(variables, term) + b = displacement(term) + f = term.f + repr = term.repr !== nothing ? term.repr : string(term) + if !(f isa ProximalOperators.LeastSquares || f isa ProximalOperators.SqrNormL2) + println("Term $repr does not satisfy required property: it is not a least squares function") + else + println("A possible decomposition of term $repr:") + print(" - ", assumption.operator.first, " = ", op) + problematic_properties = unsatisfied_properties(op, assumption.operator) + println(" -> $(join(problematic_properties, ", ")) $(length(problematic_properties) == 1 ? "property is" : "properties are") not satisfied") + print(" - ", assumption.b.first, " = ", b) + end +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + return nothing +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.LeastSquaresTerm, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + else + println("Cannot prepare terms $terms as a least squares term: only a single term can be prepared as such.") + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where N + f = term.f + if displacement(term) != 0 || !(f isa ProximalOperators.SqrNormL2) + return nothing + end + λ = term.lambda * f.lambda + op = extract_affines(variables, term) + if is_eye(op) + return (assumption.λ => λ,) + elseif is_diagonal(op) + return (assumption.λ => λ * diag(op),) + else + return nothing + end +end + +function print_diagnostics(term::Term, ::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where N + repr = term.repr !== nothing ? term.repr : string(term) + if displacement(term) != 0 + println("Term $repr does not satisfy required property: it has non-zero displacement") + elseif !(term.f isa ProximalOperators.SqrNormL2) + println("Term $repr does not satisfy required property: it is not a squared L2 function") + else + println("Term $repr does not satisfy required property: the operator is not an identity or diagonal") + end +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + return prepare(terms[1], assumption, variables) + end + return nothing +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.SquaredL2Term, variables::NTuple{N, Variable}) where {N} + if length(terms) == 1 + print_diagnostics(terms[1], assumption, variables) + else + println("Cannot prepare terms $terms as a squared L2 term: only a single term can be prepared as such.") + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where N + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + return prepare(term, simple_assumption, variables) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where N + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + print_diagnostics(term, simple_assumption, variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where {N} + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + results = () + for term in terms + result = prepare(term, simple_assumption, variables) + if isnothing(result) + return nothing + end + results = (results..., result[1].second) + end + return (assumption.func.first => results,) +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.RepeatedSimpleTerm, variables::NTuple{N, Variable}) where {N} + simple_assumption = ProximalAlgorithms.SimpleTerm(assumption.func) + for term in terms + if prepare(term, simple_assumption, variables) === nothing + print_diagnostics(term, simple_assumption, variables) + end + end +end + +function prepare(term::Term, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where N + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + return prepare(term, operator_term_assumption, variables) +end + +function print_diagnostics(term::Term, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where N + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + print_diagnostics(term, operator_term_assumption, variables) +end + +function prepare(terms::TermSet, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where {N} + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + function_results = () + operator_results = () + for term in terms + result = prepare(term, operator_term_assumption, variables) + if isnothing(result) + return nothing + end + function_results = (function_results..., result[1].second) + operator_results = (operator_results..., result[2].second) + end + return ( + assumption.func.first => function_results, + assumption.operator.first => operator_results + ) +end + +function print_diagnostics(terms::TermSet, assumption::ProximalAlgorithms.RepeatedOperatorTerm, variables::NTuple{N, Variable}) where {N} + operator_term_assumption = ProximalAlgorithms.OperatorTerm(assumption.func, assumption.operator) + for term in terms + if prepare(term, operator_term_assumption, variables) === nothing + print_diagnostics(term, operator_term_assumption, variables) + end + end +end diff --git a/src/solvers/terms_extract.jl b/src/solvers/terms_extract.jl index 5fbd207..a7c583b 100644 --- a/src/solvers/terms_extract.jl +++ b/src/solvers/terms_extract.jl @@ -1,40 +1,41 @@ # returns all variables of a cost function, in terms of appearance extract_variables(t::TermOrExpr) = variables(t) -function extract_variables(t::NTuple{N,TermOrExpr}) where {N} +function extract_variables(t::Union{Tuple, TermSet}) var_tuples = variables.(t) - vars = vcat(collect.(var_tuples)...) + vars = collect(Base.Iterators.flatten(var_tuples)) return tuple(unique(vars)...) end # extract functions from terms function extract_functions(t::Term) - f = displacement(t) == 0 ? t.f : PrecomposeDiagonal(t.f, one(t.lambda), displacement(t)) #for now I keep this + disp = displacement(t) + f = disp == 0 ? t.f : PrecomposeDiagonal(t.f, one(t.lambda), disp) #for now I keep this f = t.lambda == 1 ? f : Postcompose(f, t.lambda) #for now I keep this #TODO change this return f end -extract_functions(t::NTuple{N,Term}) where {N} = SeparableSum(extract_functions.(t)) -extract_functions(t::Tuple{Term}) = extract_functions(t[1]) +extract_functions(t::TermSet) = SeparableSum(extract_functions.(t)) # extract functions from terms without displacement function extract_functions_nodisp(t::Term) f = t.lambda == 1 ? t.f : Postcompose(t.f, t.lambda) return f end -extract_functions_nodisp(t::NTuple{N,Term}) where {N} = SeparableSum(extract_functions_nodisp.(t)) -extract_functions_nodisp(t::Tuple{Term}) = extract_functions_nodisp(t[1]) +extract_functions_nodisp(t::TermSet) = SeparableSum(extract_functions_nodisp.(t)) # extract operators from terms # returns all operators with an order dictated by xAll #single term, single variable -extract_operators(xAll::Tuple{Variable}, t::TermOrExpr) = operator(t) -extract_operators(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} = extract_operators(xAll, (t,)) +extract_operators(::Tuple{Variable}, t::AbstractExpression) = operator(t) +extract_operators(::Tuple{Variable}, t::Term) = operator(t) +extract_operators(xAll::NTuple{N,Variable}, t::AbstractExpression) where {N} = extract_operators(xAll, (t,)) +extract_operators(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_operators(xAll, TermSet(t,)) #multiple terms, multiple variables -function extract_operators(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) where {N,M} +function extract_operators(xAll::NTuple{N,Variable}, t::TermSet) where {N} ops = () for ti in t tex = expand(xAll,ti) @@ -59,11 +60,13 @@ end # returns all affines with an order dictated by xAll #single term, single variable -extract_affines(::Tuple{Variable}, t::TermOrExpr) = affine(t) -extract_affines(xAll::NTuple{N,Variable}, t::TermOrExpr) where {N} = extract_affines(xAll, (t,)) +extract_affines(::Tuple{Variable}, t::AbstractExpression) = affine(t) +extract_affines(::Tuple{Variable}, t::Term) = affine(t) +extract_affines(xAll::NTuple{N,Variable}, t::AbstractExpression) where {N} = extract_affines(xAll, (t,)) +extract_affines(xAll::NTuple{N,Variable}, t::Term) where {N} = extract_affines(xAll, TermSet(t,)) #multiple terms, multiple variables -function extract_affines(xAll::NTuple{N,Variable}, t::NTuple{M,TermOrExpr}) where {N,M} +function extract_affines(xAll::NTuple{N,Variable}, t::TermSet) where {N} ops = () for ti in t tex = expand(xAll,ti) @@ -86,7 +89,7 @@ end # expand term domain dimensions function expand(xAll::NTuple{N,Variable}, t::Term) where {N} xt = variables(t) - C = codomainType(operator(t)) + C = codomain_type(operator(t)) size_out = size(operator(t),1) ex = t.A @@ -101,7 +104,7 @@ end function expand(xAll::NTuple{N,Variable}, ex::AbstractExpression) where {N} ex = convert(Expression,ex) xt = variables(ex) - C = codomainType(operator(ex)) + C = codomain_type(operator(ex)) size_out = size(operator(ex),1) for x in xAll diff --git a/src/solvers/terms_properties.jl b/src/solvers/terms_properties.jl index fe987c6..45c517c 100644 --- a/src/solvers/terms_properties.jl +++ b/src/solvers/terms_properties.jl @@ -1,4 +1,4 @@ -is_proximable(term::Term) = is_AAc_diagonal(term) +is_proximable(term::Term) = is_proximable(typeof(term.f)) && is_AAc_diagonal(term.A.L) function get_operators_for_var(term, var) full_operator = affine(term) @@ -9,7 +9,7 @@ function get_operators_for_var(term, var) end end -function is_separable_sum(terms::NTuple{N,Term}) where {N} +function is_separable_sum(terms::TermSet) # Construct the set of occurring variables vars = Set() for term in terms @@ -25,11 +25,11 @@ function is_separable_sum(terms::NTuple{N,Term}) where {N} end # All terms must be sliced for this variable operators = [get_operators_for_var(term, var) for term in terms_with_var] - if any(!OperatorCore.is_sliced(op) for op in operators) + if any(is_sliced(op) for op in operators) return false end # The sliced operators must not overlap - slicing_masks = [OperatorCore.is_sliced(op) ? OperatorCore.get_slicing_mask(op) : nothing for op in operators] + slicing_masks = [is_sliced(op) ? get_slicing_mask(op) : nothing for op in operators] for i in eachindex(operators), j in i+1:length(operators) if any(slicing_masks[i] .&& slicing_masks[j]) return false @@ -40,6 +40,6 @@ function is_separable_sum(terms::NTuple{N,Term}) where {N} return true end -function is_proximable(terms::NTuple{N,Term}) where {N} +function is_proximable(terms::TermSet) return all(is_proximable.(terms)) && is_separable_sum(terms) end diff --git a/src/syntax/expressions/abstractOperator_bind.jl b/src/syntax/expressions/abstractOperator_bind.jl index 6f38a8c..38d8b6d 100644 --- a/src/syntax/expressions/abstractOperator_bind.jl +++ b/src/syntax/expressions/abstractOperator_bind.jl @@ -33,7 +33,7 @@ imported = [ ] importedFFTW = [ - :fft :(AbstractOperators.DFT); + :fft :DFT; :rfft :RDFT; :irfft :IRDFT; :ifft :IDFT; @@ -90,7 +90,7 @@ for i = 1:size(fun,1) @eval begin function $f(a::AbstractExpression, args...) A = convert(Expression,a) - op = $fAbsOp(codomainType(operator(A)),size(operator(A),1), args...) + op = $fAbsOp(codomain_type(operator(A)),size(operator(A),1), args...) return op*A end end diff --git a/src/syntax/expressions/addition.jl b/src/syntax/expressions/addition.jl index e321e1b..9f8b4fd 100644 --- a/src/syntax/expressions/addition.jl +++ b/src/syntax/expressions/addition.jl @@ -174,7 +174,7 @@ julia> b = randn(10); julia> size(b), eltype(b) ((10,), Float64) -julia> size(affine(ex),1), codomainType(affine(ex)) +julia> size(affine(ex),1), codomain_type(affine(ex)) ((10,), Float64) julia> ex + b diff --git a/src/syntax/expressions/addition_tricky_part.jl b/src/syntax/expressions/addition_tricky_part.jl index baeaca8..dc7bbca 100644 --- a/src/syntax/expressions/addition_tricky_part.jl +++ b/src/syntax/expressions/addition_tricky_part.jl @@ -157,7 +157,7 @@ function add_missing_vars(old_vars, op, vars) if isempty(missing_vars) return old_vars, op end - dummy_ops = [AbstractOperators.Zeros(eltype(~var), size(~var), AbstractOperators.codomainType(op), size(op, 1)) for var in missing_vars] + dummy_ops = [AbstractOperators.Zeros(eltype(~var), size(~var), AbstractOperators.codomain_type(op), size(op, 1)) for var in missing_vars] new_vars = (old_vars..., missing_vars...) new_op = AbstractOperators.HCAT(op, dummy_ops...) return new_vars, new_op @@ -187,45 +187,3 @@ function Usum_op( opNew = sign ? A+B : A-B return xNew, opNew end - -#= -function _replace_in(obj, tasks) - for task in tasks - if obj === task.first - return task.second, filter(t -> t !== task, tasks) - end - end - return obj, tasks -end -function _replace_in(obj::Tuple, tasks) - new_tuple = [] - for o in obj - new_obj, tasks = _replace_in(o, tasks) - push!(new_tuple, new_obj) - end - return tuple(new_tuple...), tasks -end -function _replace_in(obj::AbstractOperators.AbstractOperator, tasks) - fields = [getfield(obj, name) for name in fieldnames(typeof(obj))] - new_fields = [_replace_in(field, searched_obj, new_obj) for field in fields] - maybe_new_obj = any(new_fields .!== fields) ? typeof(obj).name.wrapper(new_fields...) : obj - return maybe_new_obj, tasks -end -function permute_single_operator(op::AbstractOperators.HCAT, perm::Vector{Int}) - @show op - @show perm - return AbstractOperators.HCAT([op[i] for i in perm]...) -end -function permute_operator(op::AbstractOperators.AbstractOperator, permutations) - @show permutations - tasks = [(old_op => permute_single_operator(old_op, perm)) for (old_op, perm) in reverse(permutations)] - #=for (old_op, perm) in reverse(permutations) - new_op = permute_single_operator(old_op, perm) - @show op - @show old_op - @show new_op - op = _replace_in(op, old_op, new_op) - end=# - return _replace_in(op, tasks) - #return op -end=# diff --git a/src/syntax/expressions/expression.jl b/src/syntax/expressions/expression.jl index fb619c3..5d3fad4 100644 --- a/src/syntax/expressions/expression.jl +++ b/src/syntax/expressions/expression.jl @@ -13,11 +13,11 @@ struct Expression{N,A<:AbstractOperator} <: AbstractExpression check_sz && throw(ArgumentError( "Size of the operator domain $(size(L, 2)) must match size of the variable $(size.(x))" )) - dmL = domainType(L) + dmL = domain_type(L) dmx = eltype.(x) check_dm = length(dmx) == 1 ? dmx[1] != dmL : dmx != dmL check_dm && throw(ArgumentError( - "Type of the operator domain $(domainType(L)) must match type of the variable $(eltype.(x))" + "Type of the operator domain $(domain_type(L)) must match type of the variable $(eltype.(x))" )) new{N,A}(x,L) end diff --git a/src/syntax/expressions/multiplication.jl b/src/syntax/expressions/multiplication.jl index 3f7ac8e..5658422 100644 --- a/src/syntax/expressions/multiplication.jl +++ b/src/syntax/expressions/multiplication.jl @@ -71,21 +71,21 @@ julia> randn(10,5).*X """ function (*)(m::T, a::Union{AbstractVector,AbstractMatrix}) where {T<:AbstractExpression} M = convert(Expression,m) - op = LMatrixOp(codomainType(affine(M)),size(affine(M),1),a) + op = LMatrixOp(codomain_type(affine(M)),size(affine(M),1),a) return op*M end #LMatrixOp function (*)(M::AbstractMatrix, a::T) where {T<:AbstractExpression} A = convert(Expression,a) - op = MatrixOp(codomainType(affine(A)),size(affine(A),1),M) + op = MatrixOp(codomain_type(affine(A)),size(affine(A),1),M) return op*A end #MatrixOp function Broadcast.broadcasted(::typeof(*), d::D, a::T) where {D <: Union{Number,AbstractArray}, T<:AbstractExpression} A = convert(Expression,a) - op = DiagOp(codomainType(affine(A)),size(affine(A),1),d) + op = DiagOp(codomain_type(affine(A)),size(affine(A),1),d) return op*A end Broadcast.broadcasted(::typeof(*), a::T, d::D) where {D <: Union{Number,AbstractArray}, T<:AbstractExpression} = diff --git a/src/syntax/terms/proximalOperators_bind.jl b/src/syntax/terms/proximalOperators_bind.jl index d3c9ba0..fb5aa28 100644 --- a/src/syntax/terms/proximalOperators_bind.jl +++ b/src/syntax/terms/proximalOperators_bind.jl @@ -1,10 +1,10 @@ # Norms import LinearAlgebra: norm -export norm, mixednorm +export norm """ - norm(x::AbstractExpression, p=2, [q,] [dim=1]) + norm(x::AbstractExpression, p=2, [q]; [dim=1]) Returns the norm of `x`. @@ -48,35 +48,14 @@ function norm(ex::AbstractExpression, ::typeof(*)) end # Mixed Norm -""" - mixednorm(x, p::Int, q::Int) - -``l_{2,1}`` mixed norm (aka Sum-of-``l_2``-norms) -```math -f(\\mathbf{X}) = \\sum_i \\| \\mathbf{x}_i \\| -``` -where ``\\mathbf{x}_i`` is the ``i``-th column if `p == 2` and `q == 1` (or row if `p == 1` and `q == 2`) of ``\\mathbf{X}``. -""" -function mixednorm(ex::AbstractExpression, p::Int, q::Int) - if p == 2 && q == 1 - f = NormL21(1.0, 1) - elseif p == 1 && q == 2 - f = NormL21(1.0, 2) +function norm(ex::AbstractExpression, p1::Int, p2::Int; dim::Int = 1) + if p1 == 2 && p2 == 1 + f = NormL21(1.0, dim) else error("function not implemented") end return Term(f, ex) end -function mixednorm(A::AbstractMatrix{T}, p::Int, q::Int) where {T} - if p == 2 && q == 1 - return NormL21(1.0, 1)(A) - elseif p == 1 && q == 2 - return NormL21(1.0, 2)(A) - else - error("function not implemented") - end - return result -end # Least square terms @@ -111,9 +90,9 @@ This is much faster to compute, if `Lᴴ * L` has a fast implementation. normalop_ls(::Variable) = error("normalop_ls does not work with Variables alone. Use ls instead.") function normalop_ls(ex::Expression) eye_op = if length(ex.x) == 1 - Eye(domainType(ex.L), size(ex.L, 2)) + Eye(domain_type(ex.L), size(ex.L, 2)) else - HCAT([Eye(domainType(L), size(L, 2)) for L in ex.L]...) + HCAT([Eye(domain_type(L), size(L, 2)) for L in ex.L]...) end return Term(SqrNormL2WithNormalOp(ex.L), Expression(ex.x, eye_op)) end diff --git a/src/syntax/terms/term.jl b/src/syntax/terms/term.jl index c3c25ad..b986279 100644 --- a/src/syntax/terms/term.jl +++ b/src/syntax/terms/term.jl @@ -1,48 +1,82 @@ -struct Term{T1 <: Real, T2, T3 <: AbstractExpression} - lambda::T1 - f::T2 - A::T3 - repr::Union{String,Nothing} +struct Term{T1<:Real,T2,T3<:AbstractExpression} + lambda::T1 + f::T2 + A::T3 + repr::Union{String,Nothing} end function Term(lambda, f, ex::AbstractExpression) - return Term(lambda,f,ex,nothing) + return Term(lambda, f, ex, nothing) end function Term(f, ex::AbstractExpression) - A = convert(Expression,ex) - Term(one(real(codomainType(affine(A)))),f, A) + A = convert(Expression, ex) + Term(one(real(codomain_type(affine(A)))), f, A) end function Term(f, ex::AbstractExpression, repr::String) - A = convert(Expression,ex) - Term(one(real(codomainType(affine(A)))),f, A, repr) + A = convert(Expression, ex) + Term(one(real(codomain_type(affine(A)))), f, A, repr) end function Term(t::Term, repr::String) - Term(t.lambda, t.f, t.A, repr) + Term(t.lambda, t.f, t.A, repr) end +struct TermSet{N,T} + terms::T + function TermSet(terms...) + @assert all(t -> t isa Term, terms) "All elements must be of type Term" + new{length(terms), typeof(terms)}(terms) + end +end + +function Base.iterate(t::TermSet{N}, state=1) where {N} + if state > N + return nothing + else + return (t.terms[state], state + 1) + end +end + +Base.length(::TermSet{N}) where {N} = N +Base.getindex(t::TermSet{N}, i::Int) where {N} = t.terms[i] + +Term(t::TermSet, ::String) = t + import Base: ==, show # Ignore the repr when comparing terms ==(t1::Term, t2::Term) = t1.lambda == t2.lambda && t1.f == t2.f && t1.A == t2.A function show(io::IO, t::Term) - if t.repr !== nothing - print(io, t.repr) - else - print(io, t.lambda, " * ", t.f, "(", t.A, ")") - end + if t.repr !== nothing + print(io, t.repr) + else + print(io, t.lambda, " * ", t.f, "(", t.A, ")") + end end -function show(io::IO, t::NTuple{N,Term}) where {N} - for i in 1:N - show(io, t[i]) - if i < N - print(io, " + ") - end - end +function show(io::IO, t::TermSet) + non_indicator_terms = filter(x -> !is_set_indicator(x), t.terms) + indicator_terms = filter(is_set_indicator, t.terms) + for i in 1:length(non_indicator_terms) + show(io, non_indicator_terms[i]) + if i < length(non_indicator_terms) + print(io, " + ") + end + end + if !isempty(indicator_terms) + if !isempty(non_indicator_terms) + print(io, " s.t. ") + end + for i in 1:length(indicator_terms) + show(io, indicator_terms[i]) + if i < length(indicator_terms) + print(io, ", ") + end + end + end end # Operations @@ -51,24 +85,22 @@ end import Base: + -(+)(a::Term,b::Term) = (a,b) -(+)(a::NTuple{N,Term},b::Term) where {N} = (a...,b) -(+)(a::Term,b::NTuple{N,Term}) where {N} = (a,b...) -(+)(a::NTuple{N,Term},::Tuple{}) where {N} = a -(+)(::Tuple{},b::NTuple{N,Term}) where {N} = b -(+)(a::NTuple{N,Term},b::NTuple{M,Term}) where {N,M} = (a...,b...) +(+)(a::Term, b::Term) = TermSet(a, b) +(+)(a::TermSet, b::Term) = TermSet(a..., b) +(+)(a::Term, b::TermSet) = TermSet(a, b...) +(+)(a::TermSet, b::TermSet) = TermSet(a..., b...) # Define multiplication by constant import Base: * -function (*)(a::T1, t::Term{T,T2,T3}) where {T1<:Real, T, T2, T3} - coeff = *(promote(a,t.lambda)...) - Term(coeff, t.f, t.A) +function (*)(a::T1, t::Term{T,T2,T3}) where {T1<:Real,T,T2,T3} + coeff = *(promote(a, t.lambda)...) + Term(coeff, t.f, t.A) end -function (*)(a::T1, t::T2) where {T1<:Real, N, T2 <: Tuple{Vararg{<:Term,N}} } - return a.*t +function (*)(a::T1, t::TermSet) where {T1<:Real} + return a .* t end # Properties @@ -80,59 +112,55 @@ displacement(t::Term) = displacement(t.A) #importing properties from ProximalOperators import ProximalCore: - is_affine_indicator, - is_cone_indicator, - is_convex, - is_generalized_quadratic, - is_proximable, - is_quadratic, - is_separable, - is_set_indicator, - is_singleton_indicator, - is_smooth, - is_locally_smooth, - is_strongly_convex - -is_func_f = [ - :is_set_indicator, - :is_singleton_indicator, - :is_smooth, - :is_locally_smooth, - ] + is_affine_indicator, + is_cone_indicator, + is_convex, + is_generalized_quadratic, + is_proximable, + is_quadratic, + is_separable, + is_set_indicator, + is_singleton_indicator, + is_smooth, + is_locally_smooth, + is_strongly_convex + +is_func_f = [:is_set_indicator, :is_singleton_indicator, :is_smooth, :is_locally_smooth] for f in is_func_f - @eval begin - import ProximalCore: $f - $f(t::Term) = $f(t.f) - $f(t::NTuple{N,Term}) where {N} = all($f.(t)) - end + @eval begin + import ProximalCore: $f + $f(t::Term) = $f(t.f) + $f(t::TermSet) = all($f.(t.terms)) + end end #importing properties from AbstractOperators -is_op_f = [:is_linear, - :is_eye, - :is_null, - :is_diagonal, - :is_AcA_diagonal, - :is_AAc_diagonal, - :is_orthogonal, - :is_invertible, - :is_full_row_rank, - :is_full_column_rank, - :is_sliced - ] +is_op_f = [ + :is_linear, + :is_eye, + :is_null, + :is_diagonal, + :is_AcA_diagonal, + :is_AAc_diagonal, + :is_orthogonal, + :is_invertible, + :is_full_row_rank, + :is_full_column_rank, + :is_sliced, +] for f in is_op_f - @eval begin - import AbstractOperators: $f - $f(t::Term) = $f(operator(t)) - $f(t::NTuple{N,Term}) where {N} = all($f.(t)) - end + @eval begin + import AbstractOperators: $f + $f(t::Term) = $f(operator(t)) + $f(t::TermSet) = all($f.(t)) + end end is_affine_indicator(t::Term) = is_affine_indicator(t.f) && is_linear(t) is_cone_indicator(t::Term) = is_cone_indicator(t.f) && is_linear(t) -is_convex(t::Term) = is_convex(t.f) && is_linear(t) +is_convex(t::Term) = is_convex(t.f) && is_linear(t) is_quadratic(t::Term) = is_quadratic(t.f) && is_linear(t) is_generalized_quadratic(t::Term) = is_generalized_quadratic(t.f) && is_linear(t) is_strongly_convex(t::Term) = is_strongly_convex(t.f) && is_full_column_rank(operator(t.A)) @@ -142,5 +170,5 @@ include("proximalOperators_bind.jl") # other stuff, to make Term work with iterators import Base: iterate, isempty -iterate(t::Term, state = true) = state ? (t, false) : nothing -isempty(t::Term) = false +iterate(t::Term, state=true) = state ? (t, false) : nothing +isempty(t::Term) = false diff --git a/src/syntax/variable.jl b/src/syntax/variable.jl index c3416c7..159c698 100644 --- a/src/syntax/variable.jl +++ b/src/syntax/variable.jl @@ -1,34 +1,36 @@ import Base: convert, size, eltype, ~ -export Variable +export Variable, get_name struct Variable{T, N, A <: AbstractArray{T,N}} <: AbstractExpression x::A + name::String + function Variable(x::AbstractArray{T,N}; name::String="x") where {T,N} + A = typeof(x) + new{T,N,A}(x, name) + end end # constructors """ - Variable([T::Type,] dims...) + Variable([T::Type,] dims...; name::String="x") + Variable(x::AbstractArray; name::String="x") -Returns a `Variable` of dimension `dims` initialized with an array of all zeros. - -`Variable(x::AbstractArray)` - -Returns a `Variable` of dimension `size(x)` initialized with `x` +Creates an optimization variable of type `T` and dimensions `dims...`, or from the provided array `x`. +The optional `name` argument allows to specify a name for the variable, which is useful for display purposes. """ -function Variable(T::Type, args::Int...) - N = length(args) - Variable{T,N,Array{T,N}}(zeros(T, args...)) +function Variable(T::Type, args::Int...; name::String="x") + Variable(zeros(T, args...); name) end -function Variable(args::Int...) - Variable(zeros(args...)) +function Variable(args::Int...; name::String="x") + Variable(zeros(args...); name) end # Utils function Base.show(io::IO, x::Variable) - print(io, "Variable($(eltype(x.x)), $(size(x.x)))") + print(io, "Variable($(eltype(x.x)), $(size(x.x)), \"$(x.name)\")") end """ @@ -54,3 +56,10 @@ eltype(x::Variable) Like `eltype(x::AbstractArray)` returns the type of the elements of `x`. """ eltype(x::Variable) = eltype(x.x) + +""" +get_name(x::Variable) + +Returns the name of the variable `x`. If no name was provided at construction, returns `"x"`. +""" +get_name(x::Variable) = x.name diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..6d0eb10 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,43 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2" +DSPOperators = "d5a72628-6e2f-430e-82f5-561df0bb8116" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FFTWOperators = "c59a084b-ba08-4f3f-af9e-f4298d6caa94" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WaveletOperators = "f3582904-6f60-4bbd-985d-55eab799bc9d" +AbstractOperators = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c" +ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9" +ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b" +ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StructuredOptimization = "46cd3e9d-64ff-517d-a929-236bc1a1fc9d" + +[compat] +Aqua = "0.8" +DSP = "0.5.1 - 0.8" +DSPOperators = "0.1" +FFTW = "1" +FFTWOperators = "0.1" +LinearAlgebra = "1" +Random = "1" +Test = "1" +WaveletOperators = "0.1" +AbstractOperators = "0.4" +ProximalAlgorithms = "0.8" +ProximalCore = "0.2" +ProximalOperators = "0.17" +RecursiveArrayTools = "1 - 3" + +[sources] +StructuredOptimization = { path = "../" } +AbstractOperators = { path = "../../AbstractOperators" } +ProximalAlgorithms = { path = "../../ProximalAlgorithms.jl" } +ProximalCore = { path = "../../ProximalCore.jl" } +ProximalOperators = { path = "../../ProximalOperators.jl" } +DSPOperators = { path = "../../AbstractOperators/DSPOperators" } +WaveletOperators = { path = "../../AbstractOperators/WaveletOperators" } +FFTWOperators = { path = "../../AbstractOperators/FFTWOperators" } + diff --git a/test/runtests.jl b/test/runtests.jl index cf4986e..5e5a0d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using StructuredOptimization -using AbstractOperators +using AbstractOperators, DSPOperators, FFTWOperators using ProximalOperators using ProximalAlgorithms using RecursiveArrayTools @@ -22,7 +22,7 @@ Random.seed!(0) include("test_terms.jl") end - #=@testset "Problem construction" begin + @testset "Problem construction" begin include("test_problem.jl") include("test_build_minimize.jl") end @@ -30,7 +30,7 @@ Random.seed!(0) @testset "End-to-end tests" begin include("test_usage_small.jl") include("test_usage.jl") - end=# + end @testset "Aqua" begin Aqua.test_all(StructuredOptimization; ambiguities=false, piracies=false) @@ -41,6 +41,7 @@ Random.seed!(0) StructuredOptimization; treat_as_own=[ ProximalAlgorithms.value_and_gradient, + ProximalAlgorithms.value_and_gradient!, ProximalOperators.prox, ProximalOperators.prox!, ProximalOperators.gradient, diff --git a/test/test_AbstractOp_binding.jl b/test/test_AbstractOp_binding.jl index 0ebb192..9f5ab96 100644 --- a/test/test_AbstractOp_binding.jl +++ b/test/test_AbstractOp_binding.jl @@ -52,15 +52,15 @@ ex = x[1:2] # DFT n = 5 -op = AbstractOperators.DFT(Float64,(n,)) +op = DFT(Float64,(n,)) x = Variable(randn(n)) ex = fft(x) @test norm(operator(ex)*(~x)-op*(~x)) <1e-12 # IDFT n = 5 -op = IDFT(Float64,(n,)) -x = Variable(randn(n)) +op = IDFT(ComplexF64,(n,)) +x = Variable(randn(ComplexF64, n)) ex = ifft(x) @test norm(operator(ex)*(~x)-op*(~x)) <1e-12 diff --git a/test/test_build_minimize.jl b/test/test_build_minimize.jl index a4c2c18..510f7f7 100644 --- a/test/test_build_minimize.jl +++ b/test/test_build_minimize.jl @@ -1,4 +1,4 @@ -using ProximalAlgorithms +using ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus x = Variable(10) A = randn(5, 10) @@ -53,7 +53,7 @@ function test_solver(solver) @test norm(~x - [a]) < 1e-4 @test norm(~y - [a^2]) < 1e-4 end -solvers = [ZeroFPR(; tol=1e-6), PANOC(; tol=1e-6)] +solvers = [ZeroFPR(; tol=1e-6), PANOC(; tol=1e-6), PANOCplus(; tol=1e-6)] for solver in solvers test_solver(solver) end diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 707375e..b56b18d 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -191,43 +191,26 @@ ex3 = ex1+ex2 n = 3 b = randn(n) -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.+x2 @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.+(x2+2) @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == 2 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = (x1+2).+(x2+b) @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == (b.+2) -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.+x2 -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.+(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -@test displacement(ex1) == 2 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = (x1+b).+(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 -@test displacement(ex1) == (b.+2) - n,m =2,4 x1 = Variable(randn(n,m)) -x2 = Variable(randn(1,m)) +x2 = Variable(randn(n,m)) ex1 = x1.+x2+6 @test norm(operator(ex1)*(~variables(ex1))-((~x1).+(~x2))) < 1e-9 @test displacement(ex1) == 6 @@ -237,43 +220,26 @@ ex1 = x1.+x2+6 n = 3 b = randn(n) -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.-x2 @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = x1.-(x2+2) @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == -2 -x1 = Variable(randn(1)) +x1 = Variable(randn(n)) x2 = Variable(randn(n)) ex1 = (x1+2).-(x2+b) @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == (2 .-b) -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.-x2 -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = x1.-(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -@test displacement(ex1) == -2 - -x1 = Variable(randn(n)) -x2 = Variable(randn(1)) -ex1 = (x1+b).-(x2+2) -@test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 -@test displacement(ex1) == (b.-2) - n,m =2,4 x1 = Variable(randn(n,m)) -x2 = Variable(randn(1,m)) +x2 = Variable(randn(n,m)) ex1 = x1.-x2+6 @test norm(operator(ex1)*(~variables(ex1))-((~x1).-(~x2))) < 1e-9 @test displacement(ex1) == 6 @@ -317,14 +283,12 @@ ex3 = ex1-ex2 @test_throws ErrorException MatrixOp(randn(10,20))*Variable(20)+(3+im) # Advanced (+) sum -x, y, z, w = Variable(10), Variable(20), Variable(30), Variable(40) -~x, ~y, ~z, ~w = rand(10), rand(20), rand(30), rand(40) +x, y, z, w = Variable(rand(10)), Variable(rand(20)), Variable(rand(30)), Variable(rand(40)) A = randn(10,10) exA = (z[1:10]+x)+3*(x+z[1:10])+A*(w[1:10]+z[1:10])+(z[1:10]+w[1:10]) exB = 5*w[1:10]+z[1:10]+z[1:10]+3*y[1:10]+z[1:10] exC = exA+exB op = operator(exC) -output = op*(~x,~y,~z,~w) -expected_output = 4*~x+3*~y[1:10]+8*~z[1:10]+6*~w[1:10]+A*(~w[1:10]+~z[1:10]) +output = op*ArrayPartition(~z,~x,~w,~y) +expected_output = 4*(~x)+3*(~y)[1:10]+8*(~z)[1:10]+6*(~w)[1:10]+A*((~w)[1:10]+(~z)[1:10]) @test norm(output-expected_output) < 1e-12 - diff --git a/test/test_problem.jl b/test/test_problem.jl index 757e0a0..677071a 100644 --- a/test/test_problem.jl +++ b/test/test_problem.jl @@ -105,163 +105,3 @@ V = StructuredOptimization.extract_operators(xAll,cf) @test typeof(V[6][3]) <: Zeros @test typeof(V[6][4]) <: Zeros @test typeof(V[6][5]) <: Eye - -println("\nTesting splitting Terms\n") - -x = Variable(5) -y = Variable(5) -cf = ls(x)+10*norm(x,2)+ls(x+y) - -f, g = StructuredOptimization.split_smooth(cf) -@test f[1] == cf[1] -@test f[2] == cf[3] -@test g[1] == cf[2] - -cf = ls(x) -f, g = StructuredOptimization.split_smooth((cf,)) -@test f == (cf,) -@test g == () - -cf = norm(x,1)+norm(y,2)+norm(randn(5,5)*x+y,Inf) -xAll = StructuredOptimization.extract_variables(cf) -AAc, nonAAc = StructuredOptimization.split_AAc_diagonal(cf) -@test AAc[1] == cf[1] -@test AAc[2] == cf[2] -@test nonAAc[1] == cf[3] - -cf = ls(sigmoid(x)) + ls(x) -fq, fs = StructuredOptimization.split_quadratic(cf) -@test fs[1] == cf[1] -@test fq[1] == cf[2] - -println("\nTesting extracting Proximable functions\n") -# testing is_proximable -@test StructuredOptimization.is_proximable(AAc) == true -@test StructuredOptimization.is_proximable(nonAAc) == false - -cf = norm(x[1:2],1)+norm(x[3:5]) -xAll = StructuredOptimization.extract_variables(cf) - -@test all(StructuredOptimization.is_AAc_diagonal.(cf)) == true -@test StructuredOptimization.is_proximable(cf) == true - -cf = norm(x[1:2],1)+norm(x[3:5])+norm(x,Inf) -xAll = StructuredOptimization.extract_variables(cf) - -@test all(StructuredOptimization.is_AAc_diagonal.(cf)) == true -@test StructuredOptimization.is_proximable(cf) == false - -# testing extract_proximable -# single variable, single term -x = Variable(randn(5)) -b = randn(5) -cf = 10*norm(x-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(~x-b,1)) < 1e-12 - -# single variable, single term, diagonal term -x = Variable(randn(5)) -b = randn(5) -d = randn(5) -cf = 10*norm(d.*x-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(d.*~x-b,1)) < 1e-12 - -# single variable, single term, tight frame term -x = Variable(randn(5)) -b = randn(5) -d = randn(5) -cf = 10*norm(dct(x)-b,1) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true - -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - 10*norm(dct(~x)-b,1)) < 1e-12 - -# single variable, single term, tight frame term, fft -# TODO this not working (probably fix needed in ProxOp) -#x = Variable(randn(5)) -#b = randn(5) -#d = randn(5) -#cf = 10*norm(fft(x)-b,1) -#xAll = StructuredOptimization.extract_variables(cf) -#@test StructuredOptimization.is_proximable(cf) == true -# -#f = StructuredOptimization.extract_proximable(xAll,cf) -#@test norm(f(~x) - 10*norm(fft(~x)-b,1)) < 1e-12 - -# single variable, multiple terms with GetIndex -x = Variable(randn(5)) -b = randn(2) -cf = 10*norm(x[1:2]-b,1)+norm(x[3:5],2) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - sum([10*norm((~x)[1:2]-b,1);norm((~x)[3:5],2)])) < 1e-12 - -# single variable, multiple terms with GetIndex composed with dct -x = Variable(randn(5)) -b = randn(2) -cf = 10*norm(x[1:2]-b,1)+norm(dct(x[3:5]),2) -xAll = StructuredOptimization.extract_variables(cf) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f(~x) - sum([10*norm((~x)[1:2]-b,1);norm(dct((~x)[3:5]),2)])) < 1e-12 - -# multiple variables, multiple terms -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = 10*norm(x2-b2,1)+norm(x1+b1,2) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f.fs[1](~x1)-norm(~x1+b1,2) ) < 1e-12 -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 - -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(5)) -b2 = randn(5) - -# TODO fix this? -#cf = 10*norm(x2+x1+b2,1) -#xAll = (x1,x2) -#@test StructuredOptimization.is_proximable(cf) == true -#f = StructuredOptimization.extract_proximable(xAll,cf) -# TODO fix this! in ProxOp? -# @test norm(f((~x1,~x2))-10*norm(~x2+~x1+b2,1) ) < 1e-12 - -# multiple variables, missing terms -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = 10*norm(x2-b2,1) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test f.fs[1](~x1) == 0. -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 - -# multiple variables, multiple terms, with GetIndex -x1 = Variable(randn(5)) -b1 = randn(5) -x2 = Variable(randn(3)) -b2 = randn(3) - -cf = norm(x1[3:5]+b1[3:5],1)+10*norm(x2-b2,1)+norm(x1[1:2]+b1[1:2],2) -xAll = (x1,x2) -@test StructuredOptimization.is_proximable(cf) == true -f = StructuredOptimization.extract_proximable(xAll,cf) -@test norm(f.fs[1](~x1)-norm((~x1)[1:2]+b1[1:2],2)-norm((~x1)[3:5]+b1[3:5],1) ) < 1e-12 -@test norm(f.fs[2](~x2)-10*norm(~x2-b2,1) ) < 1e-12 diff --git a/test/test_proxstuff.jl b/test/test_proxstuff.jl index c4ce361..d1744d7 100644 --- a/test/test_proxstuff.jl +++ b/test/test_proxstuff.jl @@ -26,8 +26,8 @@ r = randn(l,n2) b = randn(l,n2) G = AffineAdd(Ax_mul_Bx( - HCAT(A,Zeros(codomainType(B), size(B,2), size(A,1) )), - HCAT(Zeros(codomainType(A), size(A,2), size(B,1) ),B) + HCAT(A,Zeros(codomain_type(B), size(B,2), size(A,1) )), + HCAT(Zeros(codomain_type(A), size(A,2), size(B,1) ),B) ), b,false) diff --git a/test/test_terms.jl b/test/test_terms.jl index 6c165f4..8295ef8 100644 --- a/test/test_terms.jl +++ b/test/test_terms.jl @@ -39,15 +39,15 @@ cf = pi*norm(x,2) @test cf.lambda - pi == 0 @test cf.f(~x) == norm(~x) -cf = 3*mixednorm(X,2,1) +cf = 3*norm(X,2,1) @test cf.lambda - 3 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=1 )) ) -cf = 4*mixednorm(X,1,2) +cf = 4*norm(X,2,1; dim=2) @test cf.lambda - 4 == 0 @test cf.f(~X) == sum( sqrt.(sum((~X).^2, dims=2 )) ) -@test_throws ErrorException 4*mixednorm(X,1,3) +@test_throws ErrorException 4*norm(X,1,2) cf = norm(x, 2) <= 2.3 @test cf.lambda == 1 @@ -192,21 +192,6 @@ cf = ls(x) + 10*norm(x, 1) @test cf[2].lambda == 10 @test cf[2].f(~x) == norm(~x,1) -x = Variable(10) -cf = () #empty cost function -cf += 10*norm(x, 1) -@test length(cf) == 1 -@test cf[1].lambda == 10 -@test cf[1].f(~x) == 10*norm(~x,1) - -x = Variable(10) -cf = () #empty cost function -cf += ls(x) + 10*norm(x, 1) -@test cf[1].lambda == 1 -@test cf[1].f(~x) == 0.5*norm(~x)^2 -@test cf[2].lambda == 10 -@test cf[2].f(~x) == norm(~x,1) - # More complex situations x = Variable(10) @@ -261,5 +246,7 @@ cf = norm(w + z)^2 @test StructuredOptimization.is_AcA_diagonal(cf) == false cf = norm(x, 1) + norm(y, 2) -@test StructuredOptimization.is_smooth.(cf) == (false,false) -@test StructuredOptimization.is_AcA_diagonal.(cf) == (true,true) +@test StructuredOptimization.is_smooth.(cf.terms) == (false,false) +@test StructuredOptimization.is_smooth(cf) == false +@test StructuredOptimization.is_AcA_diagonal.(cf.terms) == (true,true) +@test StructuredOptimization.is_AcA_diagonal(cf) == true diff --git a/test/test_usage.jl b/test/test_usage.jl index ba8837e..1ca344b 100644 --- a/test/test_usage.jl +++ b/test/test_usage.jl @@ -1,3 +1,5 @@ +using ProximalAlgorithms: PANOCplus, FastForwardBackward, ZeroFPR, PANOC + Random.seed!(0) ################################################################################ @@ -5,7 +7,6 @@ Random.seed!(0) ################################################################################ println("Testing: regularized least squares, with two variable blocks to make things weird") -begin m, n1, n2 = 30, 50, 100 A1 = randn(m, n1) @@ -17,56 +18,37 @@ lam2 = 1.0 # Solve with PANOC+ -x1_fpg = Variable(n1) -x2_fpg = Variable(n2) -expr = ls(A1*x1_fpg + A2*x2_fpg - b) + lam1*norm(x1_fpg, 1) + lam2*norm(x2_fpg, 2) -end -prob = problem(expr) -@time sol = solve(prob, PANOCplus(tol=1e-10, verbose=false,maxit=20000)) - -# Solve with ZeroFPR - -x1_zerofpr = Variable(n1) -x2_zerofpr = Variable(n2) -expr = ls(A1*x1_zerofpr + A2*x2_zerofpr - b) + lam1*norm(x1_zerofpr, 1) + lam2*norm(x2_zerofpr, 2) -prob = problem(expr) -@time sol = solve(prob, ZeroFPR(tol=1e-10, verbose=false)) - -# Solve with PANOC - -x1_panoc = Variable(n1) -x2_panoc = Variable(n2) -expr = ls(A1*x1_panoc + A2*x2_panoc - b) + lam1*norm(x1_panoc, 1) + lam2*norm(x2_panoc, 2) +x1_panocplus = Variable(n1) +x2_panocplus = Variable(n2) +expr = ls(A1*x1_panocplus + A2*x2_panocplus - b) + lam1*norm(x1_panocplus, 1) + lam2*norm(x2_panocplus, 2) prob = problem(expr) -@time sol = solve(prob, PANOC(tol=1e-10, verbose=false)) - -# Solve with minimize, use default solver/options - -x1 = Variable(n1) -x2 = Variable(n2) -@time sol = @minimize ls(A1*x1 + A2*x2 - b) + lam1*norm(x1, 1) + lam2*norm(x2, 2) +@time sol = solve(prob, PANOCplus()) -@test norm(~x1_fpg - ~x1_zerofpr, Inf)/(1+norm(~x1_zerofpr, Inf)) <= 1e-6 -@test norm(~x2_fpg - ~x2_zerofpr, Inf)/(1+norm(~x2_zerofpr, Inf)) <= 1e-6 -@test norm(~x1_fpg - ~x1_panoc, Inf)/(1+norm(~x1_panoc, Inf)) <= 1e-6 -@test norm(~x2_fpg - ~x2_panoc, Inf)/(1+norm(~x2_panoc, Inf)) <= 1e-6 -@test norm(~x1 - ~x1_zerofpr, Inf)/(1+norm(~x1_zerofpr, Inf)) <= 1e-3 -@test norm(~x2 - ~x2_zerofpr, Inf)/(1+norm(~x2_zerofpr, Inf)) <= 1e-3 - -res = A1*~x1_fpg + A2*~x2_fpg - b +res = A1*~x1_panocplus + A2*~x2_panocplus - b grad1 = A1'*res grad2 = A2'*res -ind1_zero = (~x1_fpg .== 0) -subgr1 = lam1*sign.(~x1_fpg) +ind1_zero = (~x1_panocplus .== 0) +subgr1 = lam1*sign.(~x1_panocplus) subdiff1_low, subdiff1_upp = copy(subgr1), copy(subgr1) subdiff1_low[ind1_zero] .= -lam1 subdiff1_upp[ind1_zero] .= +lam1 -subgr2 = lam2*(~x2_fpg/norm(~x2_fpg, 2)) +subgr2 = lam2*(~x2_panocplus/norm(~x2_panocplus, 2)) @test maximum(subdiff1_low + grad1) <= 1e-6 @test maximum(-subdiff1_upp - grad1) <= 1e-6 @test norm(grad2 + subgr2) <= 1e-6 +# Solve with FastForwardBackward + +x1_ffb = Variable(n1) +x2_ffb = Variable(n2) +expr = ls(A1*x1_ffb + A2*x2_ffb - b) + lam1*norm(x1_ffb, 1) + lam2*norm(x2_ffb, 2) +prob = problem(expr) +@time sol = solve(prob, FastForwardBackward()) + +@test norm(~x1_panocplus - ~x1_ffb, Inf)/(1+norm(~x1_ffb, Inf)) <= 1e-6 +@test norm(~x2_panocplus - ~x2_ffb, Inf)/(1+norm(~x2_ffb, Inf)) <= 1e-6 + ############################################################################### ## Lasso problem with known solution ############################################################################### @@ -165,13 +147,13 @@ prob = problem(expr) # Solve with minimize, default solver/options -x = Variable(n) -@time sol = @minimize smooth(norm(A*x - b, 2)) + lam*norm(x, 1) +#x = Variable(n) +#@time sol = @minimize smooth(norm(A*x - b, 2)) + lam*norm(x, 1) @test norm(~x_pg - ~x_fpg, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 @test norm(~x_pg - ~x_zerofpr, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 @test norm(~x_pg - ~x_panoc, Inf)/(1+norm(~x_pg, Inf)) <= 1e-4 -@test norm(~x_pg - ~x, Inf)/(1+norm(~x_pg, Inf)) <= 1e-3 +#@test norm(~x_pg - ~x, Inf)/(1+norm(~x_pg, Inf)) <= 1e-3 ################################################################################ ### Box-constrained least-squares @@ -228,11 +210,11 @@ prob = problem(expr, x_panoc in [lb, ub]) # Solve with minimize, default solver/options -x = Variable(n) -@time sol = @minimize ls(A*x - b) st x in [lb, ub] +#x = Variable(n) +#@time sol = @minimize ls(A*x - b) st x in [lb, ub] -@test norm(~x - max.(lb, min.(ub, ~x)), Inf) <= 1e-12 -@test norm(~x - max.(lb, min.(ub, ~x - A'*(A*~x - b))), Inf)/(1+norm(~x, Inf)) <= 1e-4 +#@test norm(~x - max.(lb, min.(ub, ~x)), Inf) <= 1e-12 +#@test norm(~x - max.(lb, min.(ub, ~x - A'*(A*~x - b))), Inf)/(1+norm(~x, Inf)) <= 1e-4 ################################################################################ ### Non-negative least-squares from a known solution diff --git a/test/test_usage_small.jl b/test/test_usage_small.jl index 8503707..14bbcd4 100644 --- a/test/test_usage_small.jl +++ b/test/test_usage_small.jl @@ -1,14 +1,24 @@ +using ProximalAlgorithms: ZeroFPR, PANOC, PANOCplus, ADMM, CGNR + A = randn(3,5) b = randn(3) x_zfpr = Variable(5) prob_zfpr = problem(ls(A*x_zfpr - b) + 1e-3*norm(x_zfpr, 1)) -sol_zfpr = solve(prob_zfpr, ZeroFPR()) +sol_zfpr = solve(prob_zfpr, ZeroFPR(maxit=10)) x_pnc = Variable(5) prob_pnc = problem(ls(A*x_pnc - b) + 1e-3*norm(x_pnc, 1)) -sol_pnc = solve(prob_pnc, PANOC()) +sol_pnc = solve(prob_pnc, PANOC(maxit=10)) x_pncp = Variable(5) prob_pncp = problem(ls(A*x_pncp - b) + 1e-3*norm(x_pncp, 1)) -sol_pncp = solve(prob_pncp, PANOCplus()) +sol_pncp = solve(prob_pncp, PANOCplus(maxit=10)) + +x_admm = Variable(5) +prob_admm = problem(ls(A*x_admm - b) + 1e-3*norm(x_admm, 1)) +sol_admm = solve(prob_admm, ADMM(maxit=10)) + +x_cg = Variable(5) +prob_cg = problem(ls(A*x_cg - b) + 1e-3*norm(x_cg, 2)^2) +sol_cg = solve(prob_cg, CGNR(maxit=10)) From 99bcebbad3e3ba07da65ec0abc3d8ad04024ceb4 Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 18 Nov 2025 20:34:28 +0100 Subject: [PATCH 5/5] fix float precision error on terms --- src/syntax/terms/term.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/syntax/terms/term.jl b/src/syntax/terms/term.jl index b986279..d72f4b6 100644 --- a/src/syntax/terms/term.jl +++ b/src/syntax/terms/term.jl @@ -3,6 +3,11 @@ struct Term{T1<:Real,T2,T3<:AbstractExpression} f::T2 A::T3 repr::Union{String,Nothing} + function Term(lambda::T1, f::T2, A::T3, repr::Union{String,Nothing}) where {T1<:Real,T2,T3<:AbstractExpression} + T1_ = real(codomain_type(affine(A))) + lambda = convert(T1_, lambda) + return new{T1_,T2,T3}(lambda, f, A, repr) + end end function Term(lambda, f, ex::AbstractExpression) @@ -11,12 +16,12 @@ end function Term(f, ex::AbstractExpression) A = convert(Expression, ex) - Term(one(real(codomain_type(affine(A)))), f, A) + Term(1, f, A) end function Term(f, ex::AbstractExpression, repr::String) A = convert(Expression, ex) - Term(one(real(codomain_type(affine(A)))), f, A, repr) + Term(1, f, A, repr) end function Term(t::Term, repr::String)