From c667bb52d0d3b971b756c4bdfe21578bb8436f9a Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 17 Apr 2025 12:23:32 -0300 Subject: [PATCH 1/3] Fix `erf` --- ext/SpecialFunctionsExt.jl | 5 +++++ test/device/intrinsics/math.jl | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl index c1e9d29ba..1ffac7d9f 100644 --- a/ext/SpecialFunctionsExt.jl +++ b/ext/SpecialFunctionsExt.jl @@ -74,6 +74,11 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32) return x + efx*x; end end + z = x * x + r = pp0 + z * (pp1 + z * pp2) + s = 1.0f0 + z * (qq1 + z * (qq2 + z * qq3)) + y = r / s + return x + x*y end if ix < 0x3fa00000 # 0.84375 <= |x| < 1.25 diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl index ac1d7e9e4..a031f217e 100644 --- a/test/device/intrinsics/math.jl +++ b/test/device/intrinsics/math.jl @@ -244,7 +244,7 @@ end end let # erf - arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + arr = Float32[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5] buffer = MtlArray(arr) vec = Array(SpecialFunctions.erf.(buffer)) @test vec ≈ SpecialFunctions.erf.(arr) From eafbf6681d0a6f7cb0a2db5e9507fb2f5cf5dcae Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 17 Apr 2025 12:24:09 -0300 Subject: [PATCH 2/3] [NFC] Semicolon cleanup --- ext/SpecialFunctionsExt.jl | 6 +++--- test/mps/ndarray.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl index 1ffac7d9f..df00168d7 100644 --- a/ext/SpecialFunctionsExt.jl +++ b/ext/SpecialFunctionsExt.jl @@ -71,7 +71,7 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32) if ix < 0x04000000 # |x|<0x1p-119 return (8 * x + efx8 * x) / 8 # avoid spurious underflow end - return x + efx*x; + return x + efx*x end end z = x * x @@ -157,7 +157,7 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32) Q = 1.0f0 + s * (qa1 + s * (qa2 + s * (qa3 + s * qa4))) if hx >= 0 z = 1.0f0 - erx - return z - P / Q; + return z - P / Q else z = erx + P / Q return 1.0f0 + z @@ -165,7 +165,7 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32) end if ix < 0x41300000 # |x|<28 - x = abs(x); + x = abs(x) s = 1.0f0 / (x * x) if ix < 0x4036DB6D # |x| < 1/.35 ~ 2.857143 R = ra0 + s * (ra1 + s * (ra2 + s * ra3)) diff --git a/test/mps/ndarray.jl b/test/mps/ndarray.jl index 3fd691de0..86fef4e61 100644 --- a/test/mps/ndarray.jl +++ b/test/mps/ndarray.jl @@ -1,7 +1,7 @@ # # matrix descriptor # -using Metal,Test; +using Metal using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension, descriptor, resourceSize @static if Metal.macos_version() >= v"15" using .MPS: userBuffer From 37c7915c2816e4e0c8ac9aa9ca554ab3da53ceb1 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Thu, 17 Apr 2025 12:25:05 -0300 Subject: [PATCH 3/3] Test proper type --- test/device/intrinsics/math.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl index a031f217e..765d2ea3c 100644 --- a/test/device/intrinsics/math.jl +++ b/test/device/intrinsics/math.jl @@ -237,38 +237,38 @@ end end let # log1p - arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) + arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))) buffer = MtlArray(arr) - vec = Array(log1p.(buffer)) - @test vec ≈ log1p.(arr) + cpures = log1p.(arr) + @test Array(log1p.(buffer)) ≈ log1p.(arr) end let # erf - arr = Float32[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5] + arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5] buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erf.(buffer)) - @test vec ≈ SpecialFunctions.erf.(arr) + cpures = SpecialFunctions.erf.(arr) + @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16) end let # erfc - arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))) buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erfc.(buffer)) - @test vec ≈ SpecialFunctions.erfc.(arr) + cpures = SpecialFunctions.erfc.(arr) + @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16) end let # erfinv - arr = collect(LinRange(-1.0f0, 1.0f0, 20)) + arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20))) buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erfinv.(buffer)) - @test vec ≈ SpecialFunctions.erfinv.(arr) + cpures = SpecialFunctions.erfinv.(arr) + @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures end let # expm1 - arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) + arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))) buffer = MtlArray(arr) - vec = Array(expm1.(buffer)) - @test vec ≈ expm1.(arr) + cpures = expm1.(arr) + @test Array(expm1.(buffer)) ≈ cpures end