Fix erf and a few other improvements #582
Merged
christiangnrd merged 3 commits into main on Apr 18, 2025
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:
diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl
index df00168d..8f3dc880 100644
--- a/ext/SpecialFunctionsExt.jl
+++ b/ext/SpecialFunctionsExt.jl
@@ -71,14 +71,14 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32)
if ix < 0x04000000 # |x|<0x1p-119
return (8 * x + efx8 * x) / 8 # avoid spurious underflow
end
- return x + efx*x
+ return x + efx * x
end
end
z = x * x
r = pp0 + z * (pp1 + z * pp2)
s = 1.0f0 + z * (qq1 + z * (qq2 + z * qq3))
- y = r / s
- return x + x*y
+ y = r / s
+ return x + x * y
end
if ix < 0x3fa00000 # 0.84375 <= |x| < 1.25
diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl
index 765d2ea3..1ea0771d 100644
--- a/test/device/intrinsics/math.jl
+++ b/test/device/intrinsics/math.jl
@@ -237,38 +237,38 @@ end
end
let # log1p
- arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)))
+ arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)))
buffer = MtlArray(arr)
- cpures = log1p.(arr)
- @test Array(log1p.(buffer)) ≈ log1p.(arr)
+ cpures = log1p.(arr)
+ @test Array(log1p.(buffer)) ≈ log1p.(arr)
end
let # erf
- arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
+ arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
buffer = MtlArray(arr)
- cpures = SpecialFunctions.erf.(arr)
- @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16)
+ cpures = SpecialFunctions.erf.(arr)
+ @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16)
end
let # erfc
- arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)))
+ arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)))
buffer = MtlArray(arr)
- cpures = SpecialFunctions.erfc.(arr)
- @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16)
+ cpures = SpecialFunctions.erfc.(arr)
+ @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16)
end
let # erfinv
- arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20)))
+ arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20)))
buffer = MtlArray(arr)
- cpures = SpecialFunctions.erfinv.(arr)
- @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures
+ cpures = SpecialFunctions.erfinv.(arr)
+ @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures
end
let # expm1
- arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)))
+ arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)))
buffer = MtlArray(arr)
- cpures = expm1.(arr)
- @test Array(expm1.(buffer)) ≈ cpures
+ cpures = expm1.(arr)
+ @test Array(expm1.(buffer)) ≈ cpures
end
christiangnrd
commented
Apr 17, 2025
Metal Benchmarks
| Benchmark suite | Current: 37c7915 | Previous: a5b56dc | Ratio |
|---|---|---|---|
| private array/construct | 26257 ns | 24288.25 ns | 1.08 |
| private array/broadcast | 462750 ns | 462666 ns | 1.00 |
| private array/random/randn/Float32 | 923792 ns | 809041.5 ns | 1.14 |
| private array/random/randn!/Float32 | 594541 ns | 634417 ns | 0.94 |
| private array/random/rand!/Int64 | 555396 ns | 566041 ns | 0.98 |
| private array/random/rand!/Float32 | 562291 ns | 605750 ns | 0.93 |
| private array/random/rand/Int64 | 875000.5 ns | 779834 ns | 1.12 |
| private array/random/rand/Float32 | 832437.5 ns | 655667 ns | 1.27 |
| private array/copyto!/gpu_to_gpu | 547187.5 ns | 657625 ns | 0.83 |
| private array/copyto!/cpu_to_gpu | 630896 ns | 637291.5 ns | 0.99 |
| private array/copyto!/gpu_to_cpu | 632625 ns | 834166.5 ns | 0.76 |
| private array/accumulate/1d | 1426917 ns | 1334500 ns | 1.07 |
| private array/accumulate/2d | 1495562.5 ns | 1396958 ns | 1.07 |
| private array/iteration/findall/int | 2275166.5 ns | 2077292 ns | 1.10 |
| private array/iteration/findall/bool | 2030584 ns | 1843270.5 ns | 1.10 |
| private array/iteration/findfirst/int | 1804812.5 ns | 1714708 ns | 1.05 |
| private array/iteration/findfirst/bool | 1715500 ns | 1664854 ns | 1.03 |
| private array/iteration/scalar | 2566708.5 ns | 3557500 ns | 0.72 |
| private array/iteration/logical | 3554000 ns | 3207375 ns | 1.11 |
| private array/iteration/findmin/1d | 1867792 ns | 1757000 ns | 1.06 |
| private array/iteration/findmin/2d | 1428666.5 ns | 1357875 ns | 1.05 |
| private array/reductions/reduce/1d | 910500 ns | 1041270.5 ns | 0.87 |
| private array/reductions/reduce/2d | 686125 ns | 655187.5 ns | 1.05 |
| private array/reductions/mapreduce/1d | 992478.5 ns | 1044250 ns | 0.95 |
| private array/reductions/mapreduce/2d | 682375 ns | 665167 ns | 1.03 |
| private array/permutedims/4d | 2636208 ns | 2525375 ns | 1.04 |
| private array/permutedims/2d | 1079917 ns | 1028666 ns | 1.05 |
| private array/permutedims/3d | 1798583 ns | 1591000 ns | 1.13 |
| private array/copy | 847083 ns | 581875 ns | 1.46 |
| latency/precompile | 9897694833 ns | 9742903667 ns | 1.02 |
| latency/ttfp | 3855217666.5 ns | 3743066375 ns | 1.03 |
| latency/import | 1284400395.5 ns | 1260251146 ns | 1.02 |
| integration/metaldevrt | 752958 ns | 725167 ns | 1.04 |
| integration/byval/slices=1 | 1644562.5 ns | 1637687.5 ns | 1.00 |
| integration/byval/slices=3 | 19545937.5 ns | 9931917 ns | 1.97 |
| integration/byval/reference | 1653354 ns | 1568625 ns | 1.05 |
| integration/byval/slices=2 | 2801166 ns | 2585875 ns | 1.08 |
| kernel/indexing | 455875 ns | 451375 ns | 1.01 |
| kernel/indexing_checked | 463000 ns | 455625 ns | 1.02 |
| kernel/launch | 9076.5 ns | 40828.125 ns | 0.22 |
| metal/synchronization/stream | 15041 ns | 14500 ns | 1.04 |
| metal/synchronization/context | 15458 ns | 14875 ns | 1.04 |
| shared array/construct | 24275 ns | 24486.083333333336 ns | 0.99 |
| shared array/broadcast | 455333 ns | 460791 ns | 0.99 |
| shared array/random/randn/Float32 | 917000 ns | 804834 ns | 1.14 |
| shared array/random/randn!/Float32 | 608208 ns | 637417 ns | 0.95 |
| shared array/random/rand!/Int64 | 553458 ns | 564000 ns | 0.98 |
| shared array/random/rand!/Float32 | 555271 ns | 606125 ns | 0.92 |
| shared array/random/rand/Int64 | 873833 ns | 750521 ns | 1.16 |
| shared array/random/rand/Float32 | 800813 ns | 625979.5 ns | 1.28 |
| shared array/copyto!/gpu_to_gpu | 80000 ns | 78667 ns | 1.02 |
| shared array/copyto!/cpu_to_gpu | 79959 ns | 81458.5 ns | 0.98 |
| shared array/copyto!/gpu_to_cpu | 79542 ns | 83625 ns | 0.95 |
| shared array/accumulate/1d | 1431417 ns | 1356083.5 ns | 1.06 |
| shared array/accumulate/2d | 1486875 ns | 1392416 ns | 1.07 |
| shared array/iteration/findall/int | 2013083 ns | 1859000 ns | 1.08 |
| shared array/iteration/findall/bool | 1722542 ns | 1594250 ns | 1.08 |
| shared array/iteration/findfirst/int | 1506521 ns | 1405917 ns | 1.07 |
| shared array/iteration/findfirst/bool | 1434583 ns | 1375333 ns | 1.04 |
| shared array/iteration/scalar | 161000 ns | 157041 ns | 1.03 |
| shared array/iteration/logical | 3249333 ns | 2995999.5 ns | 1.08 |
| shared array/iteration/findmin/1d | 1580083 ns | 1473166.5 ns | 1.07 |
| shared array/iteration/findmin/2d | 1437875 ns | 1367666 ns | 1.05 |
| shared array/reductions/reduce/1d | 749708.5 ns | 733833 ns | 1.02 |
| shared array/reductions/reduce/2d | 708333 ns | 672875 ns | 1.05 |
| shared array/reductions/mapreduce/1d | 732333 ns | 742208 ns | 0.99 |
| shared array/reductions/mapreduce/2d | 705000 ns | 666458.5 ns | 1.06 |
| shared array/permutedims/4d | 2662188 ns | 2549000 ns | 1.04 |
| shared array/permutedims/2d | 1094438 ns | 1027917 ns | 1.06 |
| shared array/permutedims/3d | 1785187.5 ns | 1583875 ns | 1.13 |
| shared array/copy | 213292 ns | 245500 ns | 0.87 |
This comment was automatically generated by workflow using github-action-benchmark.
tgymnich
approved these changes
Apr 17, 2025
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #582      +/-   ##
==========================================
- Coverage   80.73%   80.47%   -0.26%
==========================================
  Files          61       61
  Lines        2657     2658       +1
==========================================
- Hits         2145     2139       -6
- Misses        512      519       +7
maleadt
reviewed
Apr 18, 2025
maleadt
approved these changes
Apr 18, 2025
Tried to turn on KA SpecialFunctions tests without realizing we don't have `gamma` implemented, but that uncovered a bug with our openlibm port of `erf`.

Best looked at each commit individually. I slightly refactored some tests to make it easier to mark them broken, and changed the other tests around for consistency. It now also tests Float16 and Float32 instead of testing Float32 twice.
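For readers unfamiliar with the `broken = ...` keyword used in the refactored tests, here is a minimal CPU-only sketch of the pattern. It assumes Julia 1.7 or later and uses only the `Test` standard library. `fake_device_double` is a hypothetical stand-in, not part of Metal.jl or SpecialFunctions; it is deliberately wrong for Float16 to mimic a path that is known to be broken.

```julia
using Test

# Hypothetical stand-in for a device implementation (NOT part of Metal.jl):
# correct for Float32, deliberately off by one for Float16.
fake_device_double(x::T) where {T} = T == Float16 ? 2x + one(T) : 2x

@testset "fake_device_double ($T)" for T in (Float16, Float32)
    arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
    cpures = 2 .* arr  # reference result computed directly
    # With `broken = true` a failing comparison is recorded as Broken
    # instead of failing the suite; with `broken = false` this is a
    # plain `@test`.
    @test fake_device_double.(arr) ≈ cpures broken = (T == Float16)
end
```

One consequence of this form is that a test marked broken which starts passing is reported as an error, so the marker has to be removed once the underlying issue is fixed.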