From ceb4dd3f3c7ec3a1081db84e1ba27c55fcccb1a7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 26 Dec 2025 20:27:38 -0500 Subject: [PATCH 1/5] Initial Implementation and Style --- examples/README.jl | 2 +- src/FunctionImplementations.jl | 3 +- src/implementation.jl | 4 + src/style.jl | 185 +++++++++++++++++++++++++++++++++ test/test_basics.jl | 15 ++- 5 files changed, 205 insertions(+), 4 deletions(-) create mode 100644 src/implementation.jl create mode 100644 src/style.jl diff --git a/examples/README.jl b/examples/README.jl index 2dbfd29..b7fa27f 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # FunctionImplementations.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/FunctionImplementations.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/FunctionImplementations.jl/dev/) # [![Build Status](https://github.com/ITensor/FunctionImplementations.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/FunctionImplementations.jl/actions/workflows/Tests.yml?query=branch%3Amain) diff --git a/src/FunctionImplementations.jl b/src/FunctionImplementations.jl index d34cdb5..186c693 100644 --- a/src/FunctionImplementations.jl +++ b/src/FunctionImplementations.jl @@ -1,5 +1,6 @@ module FunctionImplementations -# Write your package code here. +include("implementation.jl") +include("style.jl") end diff --git a/src/implementation.jl b/src/implementation.jl new file mode 100644 index 0000000..9c5314a --- /dev/null +++ b/src/implementation.jl @@ -0,0 +1,4 @@ +struct Implementation{F, Style} <: Function + f::F + style::Style +end diff --git a/src/style.jl b/src/style.jl new file mode 100644 index 0000000..215ee82 --- /dev/null +++ b/src/style.jl @@ -0,0 +1,185 @@ +### This is based on the BroadcastStyle code in +### https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl +### Objects with customized behavior for a certain function should declare a Style + +""" +`Style` is an abstract type and trait-function used to determine behavior of +objects. `Style(typeof(x))` returns the style associated +with `x`. To customize the behavior of a type, one can declare a style +by defining a type/method pair + + struct MyContainerStyle <: Style end + FunctionImplementations.Style(::Type{<:MyContainer}) = MyContainerStyle() + +""" +abstract type Style end + +struct UnknownStyle <: Style end +Style(::Type{Union{}}, slurp...) = UnknownStyle() # ambiguity resolution + +""" +`FunctionImplementations.AbstractArrayStyle{N} <: Style` is the abstract supertype for any style +associated with an `AbstractArray` type. +The `N` parameter is the dimensionality, which can be handy for AbstractArray types +that only support specific dimensionalities: + + struct SparseMatrixStyle <: FunctionImplementations.AbstractArrayStyle{2} end + FunctionImplementations.Style(::Type{<:SparseMatrixCSC}) = SparseMatrixStyle() + +For `AbstractArray` types that support arbitrary dimensionality, `N` can be set to `Any`: + + struct MyArrayStyle <: FunctionImplementations.AbstractArrayStyle{Any} end + FunctionImplementations.Style(::Type{<:MyArray}) = MyArrayStyle() + +In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track +of dimensionality, your style needs to support a [`Val`](@ref) constructor: + + struct MyArrayStyleDim{N} <: FunctionImplementations.AbstractArrayStyle{N} end + (::Type{<:MyArrayStyleDim})(::Val{N}) where N = MyArrayStyleDim{N}() + +Note that if two or more `AbstractArrayStyle` subtypes conflict, the resulting +style will fall back to that of `Array`s. If this is undesirable, you may need to +define binary [`Style`](@ref) rules to control the output type. + +See also [`FunctionImplementations.DefaultArrayStyle`](@ref). +""" +abstract type AbstractArrayStyle{N} <: Style end + +""" +`FunctionImplementations.ArrayStyle{MyArrayType}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object +behaves as an array. It presents a simple way to construct +[`FunctionImplementations.AbstractArrayStyle`](@ref)s for specific `AbstractArray` container types. +Styles created this way lose track of dimensionality; if keeping track is important +for your type, you should create your own custom [`FunctionImplementations.AbstractArrayStyle`](@ref). +""" +struct ArrayStyle{A <: AbstractArray} <: AbstractArrayStyle{Any} end +ArrayStyle{A}(::Val) where {A} = ArrayStyle{A}() + +""" +`FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object +behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is +used for any +`AbstractArray` type that hasn't defined a specialized style, and in the absence of +overrides from other arguments the resulting output type is `Array`. +When there are multiple inputs, `DefaultArrayStyle` "loses" to any other [`FunctionImplementations.ArrayStyle`](@ref). +""" +struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end +DefaultArrayStyle(::Val{N}) where {N} = DefaultArrayStyle{N}() +DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}() +const DefaultVectorStyle = DefaultArrayStyle{1} +const DefaultMatrixStyle = DefaultArrayStyle{2} +Style(::Type{<:AbstractArray{T, N}}) where {T, N} = DefaultArrayStyle{N}() +Style(::Type{T}) where {T} = DefaultArrayStyle{ndims(T)}() + +# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle` +# objects were supplied as arguments, and that no rule was defined for resolving the +# conflict. The resulting output is `Array`. While this is the same output type +# produced by `DefaultArrayStyle`, `ArrayConflict` "poisons" the Style so that +# 3 or more arguments still return an `ArrayConflict`. +struct ArrayConflict <: AbstractArrayStyle{Any} end +ArrayConflict(::Val) = ArrayConflict() + +### Binary Style rules +""" + Style(::Style1, ::Style2) = Style3() + +Indicate how to resolve different `Style`s. For example, + + Style(::Primary, ::Secondary) = Primary() + +would indicate that style `Primary` has precedence over `Secondary`. +You do not have to (and generally should not) define both argument orders. +The result does not have to be one of the input arguments, it could be a third type. +""" +Style(::S, ::S) where {S <: Style} = S() # homogeneous types preserved +# Fall back to UnknownStyle. This is necessary to implement argument-swapping +Style(::Style, ::Style) = UnknownStyle() +# UnknownStyle loses to everything +Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle() +Style(::S, ::UnknownStyle) where {S <: Style} = S() +# Precedence rules +Style(::A, ::A) where {A <: ArrayStyle} = A() +Style(::ArrayStyle, ::ArrayStyle) = UnknownStyle() +Style(::A, ::A) where {A <: AbstractArrayStyle} = A() +function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N} + if Base.typename(A) === Base.typename(B) + return A(Val(max(M, N))) + end + return UnknownStyle() +end +# Any specific array type beats DefaultArrayStyle +Style(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a +Style(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where {N} = a +Style(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M, N} = + typeof(a)(Val(max(M, N))) + +## logic for deciding the Style + +""" + combine_styles(cs...)::Style + +Decides which `Style` to use for any number of value arguments. +Uses [`Style`](@ref) to get the style for each argument, and uses +[`result_style`](@ref) to combine styles. + +# Examples +```jldoctest +julia> FunctionImplementations.combine_styles([1], [1 2; 3 4]) +FunctionImplementations.DefaultArrayStyle{2}() +``` +""" +function combine_styles end + +combine_styles() = DefaultArrayStyle{0}() +combine_styles(c) = result_style(Style(typeof(c))) +combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2)) +@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...)) + +""" + result_style(s1::Style[, s2::Style])::Style + +Takes one or two `Style`s and combines them using [`Style`](@ref) to +determine a common `Style`. + +# Examples + +```jldoctest +julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle{0}(), FunctionImplementations.DefaultArrayStyle{3}()) +FunctionImplementations.DefaultArrayStyle{3}() + +julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle{1}()) +FunctionImplementations.DefaultArrayStyle{1}() +``` +""" +function result_style end + +result_style(s::Style) = s +function result_style(s1::S, s2::S) where {S <: Style} + return s1 ≡ s2 ? s1 : error("inconsistent styles, custom rule needed") +end +# Test both orders so users typically only have to declare one order +result_style(s1, s2) = result_join(s1, s2, Style(s1, s2), Style(s2, s1)) + +# result_join is the final arbiter. Because `Style` for undeclared pairs results in UnknownStyle, +# we defer to any case where the result of `Style` is known. +result_join(::Any, ::Any, ::UnknownStyle, ::UnknownStyle) = UnknownStyle() +result_join(::Any, ::Any, ::UnknownStyle, s::Style) = s +result_join(::Any, ::Any, s::Style, ::UnknownStyle) = s +# For AbstractArray types with undefined precedence rules, +# we have to signal conflict. Because ArrayConflict is a subtype of AbstractArray, +# this will "poison" any future operations (if we instead returned `DefaultArrayStyle`, then for +# 3-array functions returned type would depend on argument order). +result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::UnknownStyle, ::UnknownStyle) = + ArrayConflict() +# Fallbacks in case users define `rule` for both argument-orders (not recommended) +result_join(::Any, ::Any, s1::S, s2::S) where {S <: Style} = result_style(s1, s2) + +@noinline function result_join(::S, ::T, ::U, ::V) where {S, T, U, V} + error( + """ + conflicting rules defined + FunctionImplementations.Style(::$S, ::$T) = $U() + FunctionImplementations.Style(::$T, ::$S) = $V() + One of these should be undefined (and thus return FunctionImplementations.UnknownStyle).""" + ) +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 0dd6556..f22a371 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,17 @@ -using FunctionImplementations: FunctionImplementations +import FunctionImplementations as FI using Test: @test, @testset @testset "FunctionImplementations" begin - # Tests go here. + @testset "Implementation" begin + struct MyAddAlgorithm end + f = FI.Implementation(+, MyAddAlgorithm()) + @test f.f ≡ + + @test f.style ≡ MyAddAlgorithm() + (::typeof(f))(x, y) = "My add" + @test f(2, 3) == "My add" + @test f.f ≡ + + @test f.style ≡ MyAddAlgorithm() + end + @testset "Style" begin + end end From fcc7ec09142400a046c93ae67b5295d1fcc520a5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 26 Dec 2025 20:34:54 -0500 Subject: [PATCH 2/5] Add tests --- test/test_basics.jl | 118 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/test/test_basics.jl b/test/test_basics.jl index f22a371..7b953b1 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -13,5 +13,123 @@ using Test: @test, @testset @test f.style ≡ MyAddAlgorithm() end @testset "Style" begin + # Test basic Style trait for different array types + @test FI.Style(typeof([1, 2, 3])) isa FI.DefaultArrayStyle{1} + @test FI.Style(typeof([1 2; 3 4])) isa FI.DefaultArrayStyle{2} + @test FI.Style(typeof(rand(2, 3, 4))) isa FI.DefaultArrayStyle{3} + + # Test custom Style definition + struct CustomStyle <: FI.Style end + struct CustomArray end + FI.Style(::Type{CustomArray}) = CustomStyle() + @test FI.Style(CustomArray) isa CustomStyle + + # Test custom AbstractArrayStyle definition + struct MyArray{T, N} <: AbstractArray{T, N} + data::Array{T, N} + end + struct MyArrayStyle <: FI.AbstractArrayStyle{Any} end + FI.Style(::Type{<:MyArray}) = MyArrayStyle() + @test FI.Style(MyArray) isa MyArrayStyle + + # Test style homogeneity rule (same type returns preserved) + s1 = FI.DefaultArrayStyle{1}() + s2 = FI.DefaultArrayStyle{1}() + @test FI.Style(s1, s2) ≡ s1 + + # Test UnknownStyle precedence + unknown = FI.UnknownStyle() + known = FI.DefaultArrayStyle{1}() + @test FI.Style(known, unknown) ≡ known + @test FI.Style(unknown, unknown) ≡ unknown + + # Test AbstractArrayStyle with different dimensions uses max + @test FI.Style( + FI.DefaultArrayStyle{1}(), + FI.DefaultArrayStyle{2}() + ) isa FI.DefaultArrayStyle{2} + + # Test ArrayStyle + arr_style = FI.ArrayStyle{Vector{Int}}() + @test arr_style isa FI.ArrayStyle{Vector{Int}} + @test arr_style isa FI.AbstractArrayStyle{Any} + + # Test that same ArrayStyle returns preserved + arr_style1 = FI.ArrayStyle{Vector{Int}}() + arr_style2 = FI.ArrayStyle{Vector{Int}}() + @test FI.Style(arr_style1, arr_style2) ≡ arr_style1 + + # Test different ArrayStyles result in UnknownStyle + arr_style_vec = FI.ArrayStyle{Vector{Int}}() + arr_style_mat = FI.ArrayStyle{Matrix{Int}}() + @test FI.Style(arr_style_vec, arr_style_mat) isa FI.UnknownStyle + + # Test ArrayStyle Val constructor + arr_style_val = FI.ArrayStyle{Vector{Int}}(Val(2)) + @test arr_style_val isa FI.ArrayStyle{Vector{Int}} + + # Test DefaultArrayStyle Val constructor preserves type when dimension matches + default_style = FI.DefaultArrayStyle{1}(Val(1)) + @test default_style isa FI.DefaultArrayStyle{1} + + # Test DefaultArrayStyle Val constructor changes dimension + default_style_change = FI.DefaultArrayStyle{1}(Val(2)) + @test default_style_change isa FI.DefaultArrayStyle{2} + + # Test const aliases + @test FI.DefaultVectorStyle ≡ FI.DefaultArrayStyle{1} + @test FI.DefaultMatrixStyle ≡ FI.DefaultArrayStyle{2} + + # Test ArrayConflict + conflict = FI.ArrayConflict() + @test conflict isa FI.ArrayConflict + @test conflict isa FI.AbstractArrayStyle{Any} + + # Test ArrayConflict Val constructor + conflict_val = FI.ArrayConflict(Val(3)) + @test conflict_val isa FI.ArrayConflict + + # Test combine_styles with no arguments + @test FI.combine_styles() isa FI.DefaultArrayStyle{0} + + # Test combine_styles with single argument + @test FI.combine_styles([1, 2]) isa FI.DefaultArrayStyle{1} + @test FI.combine_styles([1 2; 3 4]) isa FI.DefaultArrayStyle{2} + + # Test combine_styles with two arguments + result = FI.combine_styles([1, 2], [1 2; 3 4]) + @test result isa FI.DefaultArrayStyle{2} + + # Test combine_styles with same dimensions + result = FI.combine_styles([1], [2]) + @test result isa FI.DefaultArrayStyle{1} + + # Test combine_styles with multiple arguments + result = FI.combine_styles([1], [1 2], rand(2, 3, 4)) + @test result isa FI.DefaultArrayStyle{3} + + # Test result_style with single argument + @test FI.result_style(FI.DefaultArrayStyle{1}()) isa FI.DefaultArrayStyle{1} + + # Test result_style with two identical styles + s = FI.DefaultArrayStyle{2}() + @test FI.result_style(s, s) ≡ s + + # Test result_style with UnknownStyle + known = FI.DefaultArrayStyle{1}() + unknown = FI.UnknownStyle() + @test FI.result_style(known, unknown) ≡ known + @test FI.result_style(unknown, known) ≡ known + + # Test result_style with different dimension DefaultArrayStyle uses max + result = FI.result_style( + FI.DefaultArrayStyle{1}(), + FI.DefaultArrayStyle{2}() + ) + @test result isa FI.DefaultArrayStyle{2} + + # Test result_style with same shape behaves consistently + same_style = FI.DefaultArrayStyle{2}() + @test FI.result_style(same_style, same_style) ≡ same_style end end From c8a0c9f30796ab6d7d52cd569107c58bba3844a1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 26 Dec 2025 20:41:04 -0500 Subject: [PATCH 3/5] Delete ArrayStyle --- src/style.jl | 13 +------------ test/test_basics.jl | 19 ------------------- 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/src/style.jl b/src/style.jl index 215ee82..1d83864 100644 --- a/src/style.jl +++ b/src/style.jl @@ -45,16 +45,6 @@ See also [`FunctionImplementations.DefaultArrayStyle`](@ref). """ abstract type AbstractArrayStyle{N} <: Style end -""" -`FunctionImplementations.ArrayStyle{MyArrayType}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object -behaves as an array. It presents a simple way to construct -[`FunctionImplementations.AbstractArrayStyle`](@ref)s for specific `AbstractArray` container types. -Styles created this way lose track of dimensionality; if keeping track is important -for your type, you should create your own custom [`FunctionImplementations.AbstractArrayStyle`](@ref). -""" -struct ArrayStyle{A <: AbstractArray} <: AbstractArrayStyle{Any} end -ArrayStyle{A}(::Val) where {A} = ArrayStyle{A}() - """ `FunctionImplementations.DefaultArrayStyle{N}()` is a [`FunctionImplementations.Style`](@ref) indicating that an object behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is @@ -64,6 +54,7 @@ overrides from other arguments the resulting output type is `Array`. When there are multiple inputs, `DefaultArrayStyle` "loses" to any other [`FunctionImplementations.ArrayStyle`](@ref). """ struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end +DefaultArrayStyle() = DefaultArrayStyle{Any}() DefaultArrayStyle(::Val{N}) where {N} = DefaultArrayStyle{N}() DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}() const DefaultVectorStyle = DefaultArrayStyle{1} @@ -98,8 +89,6 @@ Style(::Style, ::Style) = UnknownStyle() Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle() Style(::S, ::UnknownStyle) where {S <: Style} = S() # Precedence rules -Style(::A, ::A) where {A <: ArrayStyle} = A() -Style(::ArrayStyle, ::ArrayStyle) = UnknownStyle() Style(::A, ::A) where {A <: AbstractArrayStyle} = A() function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N} if Base.typename(A) === Base.typename(B) diff --git a/test/test_basics.jl b/test/test_basics.jl index 7b953b1..63b40b8 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -49,25 +49,6 @@ using Test: @test, @testset FI.DefaultArrayStyle{2}() ) isa FI.DefaultArrayStyle{2} - # Test ArrayStyle - arr_style = FI.ArrayStyle{Vector{Int}}() - @test arr_style isa FI.ArrayStyle{Vector{Int}} - @test arr_style isa FI.AbstractArrayStyle{Any} - - # Test that same ArrayStyle returns preserved - arr_style1 = FI.ArrayStyle{Vector{Int}}() - arr_style2 = FI.ArrayStyle{Vector{Int}}() - @test FI.Style(arr_style1, arr_style2) ≡ arr_style1 - - # Test different ArrayStyles result in UnknownStyle - arr_style_vec = FI.ArrayStyle{Vector{Int}}() - arr_style_mat = FI.ArrayStyle{Matrix{Int}}() - @test FI.Style(arr_style_vec, arr_style_mat) isa FI.UnknownStyle - - # Test ArrayStyle Val constructor - arr_style_val = FI.ArrayStyle{Vector{Int}}(Val(2)) - @test arr_style_val isa FI.ArrayStyle{Vector{Int}} - # Test DefaultArrayStyle Val constructor preserves type when dimension matches default_style = FI.DefaultArrayStyle{1}(Val(1)) @test default_style isa FI.DefaultArrayStyle{1} From deb4f04021363d26e391ec22c7645b0257031dbb Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 26 Dec 2025 20:47:24 -0500 Subject: [PATCH 4/5] Combining styles with mixed dimensions results in Any dimension --- src/style.jl | 4 ++-- test/test_basics.jl | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/style.jl b/src/style.jl index 1d83864..073d943 100644 --- a/src/style.jl +++ b/src/style.jl @@ -92,7 +92,7 @@ Style(::S, ::UnknownStyle) where {S <: Style} = S() Style(::A, ::A) where {A <: AbstractArrayStyle} = A() function Style(a::A, b::B) where {A <: AbstractArrayStyle{M}, B <: AbstractArrayStyle{N}} where {M, N} if Base.typename(A) === Base.typename(B) - return A(Val(max(M, N))) + return A(Val(Any)) end return UnknownStyle() end @@ -100,7 +100,7 @@ end Style(a::AbstractArrayStyle{Any}, ::DefaultArrayStyle) = a Style(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where {N} = a Style(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M, N} = - typeof(a)(Val(max(M, N))) + typeof(a)(Val(Any)) ## logic for deciding the Style diff --git a/test/test_basics.jl b/test/test_basics.jl index 63b40b8..43f7d9f 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -47,15 +47,17 @@ using Test: @test, @testset @test FI.Style( FI.DefaultArrayStyle{1}(), FI.DefaultArrayStyle{2}() - ) isa FI.DefaultArrayStyle{2} + ) isa FI.DefaultArrayStyle{Any} # Test DefaultArrayStyle Val constructor preserves type when dimension matches default_style = FI.DefaultArrayStyle{1}(Val(1)) - @test default_style isa FI.DefaultArrayStyle{1} + @test FI.DefaultArrayStyle{1}(Val(1)) isa FI.DefaultArrayStyle{1} # Test DefaultArrayStyle Val constructor changes dimension - default_style_change = FI.DefaultArrayStyle{1}(Val(2)) - @test default_style_change isa FI.DefaultArrayStyle{2} + @test FI.DefaultArrayStyle{1}(Val(2)) isa FI.DefaultArrayStyle{2} + + # Test DefaultArrayStyle constructor defaults to Any dimension + @test FI.DefaultArrayStyle() isa FI.DefaultArrayStyle{Any} # Test const aliases @test FI.DefaultVectorStyle ≡ FI.DefaultArrayStyle{1} @@ -79,7 +81,7 @@ using Test: @test, @testset # Test combine_styles with two arguments result = FI.combine_styles([1, 2], [1 2; 3 4]) - @test result isa FI.DefaultArrayStyle{2} + @test result isa FI.DefaultArrayStyle{Any} # Test combine_styles with same dimensions result = FI.combine_styles([1], [2]) @@ -87,7 +89,7 @@ using Test: @test, @testset # Test combine_styles with multiple arguments result = FI.combine_styles([1], [1 2], rand(2, 3, 4)) - @test result isa FI.DefaultArrayStyle{3} + @test result isa FI.DefaultArrayStyle{Any} # Test result_style with single argument @test FI.result_style(FI.DefaultArrayStyle{1}()) isa FI.DefaultArrayStyle{1} @@ -107,7 +109,7 @@ using Test: @test, @testset FI.DefaultArrayStyle{1}(), FI.DefaultArrayStyle{2}() ) - @test result isa FI.DefaultArrayStyle{2} + @test result isa FI.DefaultArrayStyle{Any} # Test result_style with same shape behaves consistently same_style = FI.DefaultArrayStyle{2}() From 8f64aa6cbd55faed7acca981ecb5d8a6c85d1dba Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 26 Dec 2025 21:05:11 -0500 Subject: [PATCH 5/5] Make Style callable and constructable from instances --- src/implementation.jl | 5 +++++ src/style.jl | 19 ++++++++++++++----- test/test_basics.jl | 7 +++++++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/implementation.jl b/src/implementation.jl index 9c5314a..df44cf9 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -1,3 +1,8 @@ +""" +`FunctionImplementations.Implementation(f, s)` wraps a function `f` with a style `s`. +This can be used to create function implementations that behave differently +based on the style of their arguments. +""" struct Implementation{F, Style} <: Function f::F style::Style diff --git a/src/style.jl b/src/style.jl index 073d943..2430a5e 100644 --- a/src/style.jl +++ b/src/style.jl @@ -13,10 +13,21 @@ by defining a type/method pair """ abstract type Style end +Style(x) = Style(typeof(x)) +Style(::Type{T}) where {T} = throw(MethodError(Style, (T,))) struct UnknownStyle <: Style end Style(::Type{Union{}}, slurp...) = UnknownStyle() # ambiguity resolution +""" + (s::Style)(f) + +Calling a Style `s` with a function `f` as `s(f)` is a shorthand for creating a +[`FunctionImplementations.Implementation`](@ref) object wrapping the function `f` with +Style `s`. +""" +(s::Style)(f) = Implementation(f, s) + """ `FunctionImplementations.AbstractArrayStyle{N} <: Style` is the abstract supertype for any style associated with an `AbstractArray` type. @@ -32,7 +43,7 @@ For `AbstractArray` types that support arbitrary dimensionality, `N` can be set FunctionImplementations.Style(::Type{<:MyArray}) = MyArrayStyle() In cases where you want to be able to mix multiple `AbstractArrayStyle`s and keep track -of dimensionality, your style needs to support a [`Val`](@ref) constructor: +of dimensionality, your style needs to support a `Val` constructor: struct MyArrayStyleDim{N} <: FunctionImplementations.AbstractArrayStyle{N} end (::Type{<:MyArrayStyleDim})(::Val{N}) where N = MyArrayStyleDim{N}() @@ -51,7 +62,6 @@ behaves as an `N`-dimensional array. Specifically, `DefaultArrayStyle` is used for any `AbstractArray` type that hasn't defined a specialized style, and in the absence of overrides from other arguments the resulting output type is `Array`. -When there are multiple inputs, `DefaultArrayStyle` "loses" to any other [`FunctionImplementations.ArrayStyle`](@ref). """ struct DefaultArrayStyle{N} <: AbstractArrayStyle{N} end DefaultArrayStyle() = DefaultArrayStyle{Any}() @@ -60,7 +70,6 @@ DefaultArrayStyle{M}(::Val{N}) where {N, M} = DefaultArrayStyle{N}() const DefaultVectorStyle = DefaultArrayStyle{1} const DefaultMatrixStyle = DefaultArrayStyle{2} Style(::Type{<:AbstractArray{T, N}}) where {T, N} = DefaultArrayStyle{N}() -Style(::Type{T}) where {T} = DefaultArrayStyle{ndims(T)}() # `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle` # objects were supplied as arguments, and that no rule was defined for resolving the @@ -114,7 +123,7 @@ Uses [`Style`](@ref) to get the style for each argument, and uses # Examples ```jldoctest julia> FunctionImplementations.combine_styles([1], [1 2; 3 4]) -FunctionImplementations.DefaultArrayStyle{2}() +FunctionImplementations.DefaultArrayStyle{Any}() ``` """ function combine_styles end @@ -134,7 +143,7 @@ determine a common `Style`. ```jldoctest julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle{0}(), FunctionImplementations.DefaultArrayStyle{3}()) -FunctionImplementations.DefaultArrayStyle{3}() +FunctionImplementations.DefaultArrayStyle{Any}() julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle{1}()) FunctionImplementations.DefaultArrayStyle{1}() diff --git a/test/test_basics.jl b/test/test_basics.jl index 43f7d9f..4bb3488 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -12,9 +12,16 @@ using Test: @test, @testset @test f.f ≡ + @test f.style ≡ MyAddAlgorithm() end + @testset "(s::Style)(f)" begin + # Test the shorthand for creating an Implementation by calling a Style with a + # function. + @test FI.Style([1, 2, 3])(getindex) ≡ + FI.Implementation(getindex, FI.DefaultArrayStyle{1}()) + end @testset "Style" begin # Test basic Style trait for different array types @test FI.Style(typeof([1, 2, 3])) isa FI.DefaultArrayStyle{1} + @test FI.Style([1, 2, 3]) isa FI.DefaultArrayStyle{1} @test FI.Style(typeof([1 2; 3 4])) isa FI.DefaultArrayStyle{2} @test FI.Style(typeof(rand(2, 3, 4))) isa FI.DefaultArrayStyle{3}