diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl index c1e9d29ba..df00168d7 100644 --- a/ext/SpecialFunctionsExt.jl +++ b/ext/SpecialFunctionsExt.jl @@ -71,9 +71,14 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32) if ix < 0x04000000 # |x|<0x1p-119 return (8 * x + efx8 * x) / 8 # avoid spurious underflow end - return x + efx*x; + return x + efx*x end end + z = x * x + r = pp0 + z * (pp1 + z * pp2) + s = 1.0f0 + z * (qq1 + z * (qq2 + z * qq3)) + y = r / s + return x + x*y end if ix < 0x3fa00000 # 0.84375 <= |x| < 1.25 @@ -152,7 +157,7 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32) Q = 1.0f0 + s * (qa1 + s * (qa2 + s * (qa3 + s * qa4))) if hx >= 0 z = 1.0f0 - erx - return z - P / Q; + return z - P / Q else z = erx + P / Q return 1.0f0 + z @@ -160,7 +165,7 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32) end if ix < 0x41300000 # |x|<28 - x = abs(x); + x = abs(x) s = 1.0f0 / (x * x) if ix < 0x4036DB6D # |x| < 1/.35 ~ 2.857143 R = ra0 + s * (ra1 + s * (ra2 + s * ra3)) diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl index ac1d7e9e4..765d2ea3c 100644 --- a/test/device/intrinsics/math.jl +++ b/test/device/intrinsics/math.jl @@ -237,38 +237,38 @@ end end let # log1p - arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) + arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))) buffer = MtlArray(arr) - vec = Array(log1p.(buffer)) - @test vec ≈ log1p.(arr) + cpures = log1p.(arr) + @test Array(log1p.(buffer)) ≈ log1p.(arr) end let # erf - arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5] buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erf.(buffer)) - @test vec ≈ SpecialFunctions.erf.(arr) + cpures = SpecialFunctions.erf.(arr) + @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16) end let # erfc - arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) + arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))) buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erfc.(buffer)) - @test vec ≈ SpecialFunctions.erfc.(arr) + cpures = SpecialFunctions.erfc.(arr) + @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16) end let # erfinv - arr = collect(LinRange(-1.0f0, 1.0f0, 20)) + arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20))) buffer = MtlArray(arr) - vec = Array(SpecialFunctions.erfinv.(buffer)) - @test vec ≈ SpecialFunctions.erfinv.(arr) + cpures = SpecialFunctions.erfinv.(arr) + @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures end let # expm1 - arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) + arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))) buffer = MtlArray(arr) - vec = Array(expm1.(buffer)) - @test vec ≈ expm1.(arr) + cpures = expm1.(arr) + @test Array(expm1.(buffer)) ≈ cpures end diff --git a/test/mps/ndarray.jl b/test/mps/ndarray.jl index 3fd691de0..86fef4e61 100644 --- a/test/mps/ndarray.jl +++ b/test/mps/ndarray.jl @@ -1,7 +1,7 @@ # # matrix descriptor # -using Metal,Test; +using Metal using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension, descriptor, resourceSize @static if Metal.macos_version() >= v"15" using .MPS: userBuffer