Skip to content

Fix erf and a few other improvements#582

Merged
christiangnrd merged 3 commits intomainfrom
erf
Apr 18, 2025
Merged

Fix erf and a few other improvements#582
christiangnrd merged 3 commits intomainfrom
erf

Conversation

@christiangnrd
Copy link
Member

Tried to turn of KA SpecialFunctions tests without realizing we don't have gamma implemented, but that uncovered a bug with our openlibm port of erf.

Best looked at each commit individually. I slightly refactored some tests to make it easier to mark broken, and changed the other tests around for consistency. It now also tests Float16 and Float32 instead of testing Float32 twice.

@christiangnrd christiangnrd requested a review from tgymnich April 17, 2025 15:30
@github-actions
Copy link
Contributor

github-actions bot commented Apr 17, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/SpecialFunctionsExt.jl b/ext/SpecialFunctionsExt.jl
index df00168d..8f3dc880 100644
--- a/ext/SpecialFunctionsExt.jl
+++ b/ext/SpecialFunctionsExt.jl
@@ -71,14 +71,14 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32)
                 if ix < 0x04000000 	# |x|<0x1p-119
                     return (8 * x + efx8 * x) / 8	# avoid spurious underflow
                 end
-                return x + efx*x
+                return x + efx * x
             end
         end
         z = x * x
         r = pp0 + z * (pp1 + z * pp2)
         s = 1.0f0 + z * (qq1 + z * (qq2 + z * qq3))
-	    y = r / s
-	    return x + x*y
+        y = r / s
+        return x + x * y
     end
 
     if ix < 0x3fa00000 	# 0.84375 <= |x| < 1.25
diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl
index 765d2ea3..1ea0771d 100644
--- a/test/device/intrinsics/math.jl
+++ b/test/device/intrinsics/math.jl
@@ -237,38 +237,38 @@ end
     end
 
     let # log1p
-        arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)))
+            arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)))
         buffer = MtlArray(arr)
-        cpures = log1p.(arr)
-        @test Array(log1p.(buffer)) ≈ log1p.(arr)
+            cpures = log1p.(arr)
+            @test Array(log1p.(buffer)) ≈ log1p.(arr)
     end
 
     let # erf
-        arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
+            arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
         buffer = MtlArray(arr)
-        cpures = SpecialFunctions.erf.(arr)
-        @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16)
+            cpures = SpecialFunctions.erf.(arr)
+            @test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16)
     end
 
     let # erfc
-        arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)))
+            arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)))
         buffer = MtlArray(arr)
-        cpures = SpecialFunctions.erfc.(arr)
-        @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16)
+            cpures = SpecialFunctions.erfc.(arr)
+            @test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16)
     end
 
     let # erfinv
-        arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20)))
+            arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20)))
         buffer = MtlArray(arr)
-        cpures = SpecialFunctions.erfinv.(arr)
-        @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures
+            cpures = SpecialFunctions.erfinv.(arr)
+            @test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures
     end
 
     let # expm1
-        arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)))
+            arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)))
         buffer = MtlArray(arr)
-        cpures = expm1.(arr)
-        @test Array(expm1.(buffer)) ≈ cpures
+            cpures = expm1.(arr)
+            @test Array(expm1.(buffer)) ≈ cpures
     end
 
 

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Details
Benchmark suite Current: 37c7915 Previous: a5b56dc Ratio
private array/construct 26257 ns 24288.25 ns 1.08
private array/broadcast 462750 ns 462666 ns 1.00
private array/random/randn/Float32 923792 ns 809041.5 ns 1.14
private array/random/randn!/Float32 594541 ns 634417 ns 0.94
private array/random/rand!/Int64 555396 ns 566041 ns 0.98
private array/random/rand!/Float32 562291 ns 605750 ns 0.93
private array/random/rand/Int64 875000.5 ns 779834 ns 1.12
private array/random/rand/Float32 832437.5 ns 655667 ns 1.27
private array/copyto!/gpu_to_gpu 547187.5 ns 657625 ns 0.83
private array/copyto!/cpu_to_gpu 630896 ns 637291.5 ns 0.99
private array/copyto!/gpu_to_cpu 632625 ns 834166.5 ns 0.76
private array/accumulate/1d 1426917 ns 1334500 ns 1.07
private array/accumulate/2d 1495562.5 ns 1396958 ns 1.07
private array/iteration/findall/int 2275166.5 ns 2077292 ns 1.10
private array/iteration/findall/bool 2030584 ns 1843270.5 ns 1.10
private array/iteration/findfirst/int 1804812.5 ns 1714708 ns 1.05
private array/iteration/findfirst/bool 1715500 ns 1664854 ns 1.03
private array/iteration/scalar 2566708.5 ns 3557500 ns 0.72
private array/iteration/logical 3554000 ns 3207375 ns 1.11
private array/iteration/findmin/1d 1867792 ns 1757000 ns 1.06
private array/iteration/findmin/2d 1428666.5 ns 1357875 ns 1.05
private array/reductions/reduce/1d 910500 ns 1041270.5 ns 0.87
private array/reductions/reduce/2d 686125 ns 655187.5 ns 1.05
private array/reductions/mapreduce/1d 992478.5 ns 1044250 ns 0.95
private array/reductions/mapreduce/2d 682375 ns 665167 ns 1.03
private array/permutedims/4d 2636208 ns 2525375 ns 1.04
private array/permutedims/2d 1079917 ns 1028666 ns 1.05
private array/permutedims/3d 1798583 ns 1591000 ns 1.13
private array/copy 847083 ns 581875 ns 1.46
latency/precompile 9897694833 ns 9742903667 ns 1.02
latency/ttfp 3855217666.5 ns 3743066375 ns 1.03
latency/import 1284400395.5 ns 1260251146 ns 1.02
integration/metaldevrt 752958 ns 725167 ns 1.04
integration/byval/slices=1 1644562.5 ns 1637687.5 ns 1.00
integration/byval/slices=3 19545937.5 ns 9931917 ns 1.97
integration/byval/reference 1653354 ns 1568625 ns 1.05
integration/byval/slices=2 2801166 ns 2585875 ns 1.08
kernel/indexing 455875 ns 451375 ns 1.01
kernel/indexing_checked 463000 ns 455625 ns 1.02
kernel/launch 9076.5 ns 40828.125 ns 0.22
metal/synchronization/stream 15041 ns 14500 ns 1.04
metal/synchronization/context 15458 ns 14875 ns 1.04
shared array/construct 24275 ns 24486.083333333336 ns 0.99
shared array/broadcast 455333 ns 460791 ns 0.99
shared array/random/randn/Float32 917000 ns 804834 ns 1.14
shared array/random/randn!/Float32 608208 ns 637417 ns 0.95
shared array/random/rand!/Int64 553458 ns 564000 ns 0.98
shared array/random/rand!/Float32 555271 ns 606125 ns 0.92
shared array/random/rand/Int64 873833 ns 750521 ns 1.16
shared array/random/rand/Float32 800813 ns 625979.5 ns 1.28
shared array/copyto!/gpu_to_gpu 80000 ns 78667 ns 1.02
shared array/copyto!/cpu_to_gpu 79959 ns 81458.5 ns 0.98
shared array/copyto!/gpu_to_cpu 79542 ns 83625 ns 0.95
shared array/accumulate/1d 1431417 ns 1356083.5 ns 1.06
shared array/accumulate/2d 1486875 ns 1392416 ns 1.07
shared array/iteration/findall/int 2013083 ns 1859000 ns 1.08
shared array/iteration/findall/bool 1722542 ns 1594250 ns 1.08
shared array/iteration/findfirst/int 1506521 ns 1405917 ns 1.07
shared array/iteration/findfirst/bool 1434583 ns 1375333 ns 1.04
shared array/iteration/scalar 161000 ns 157041 ns 1.03
shared array/iteration/logical 3249333 ns 2995999.5 ns 1.08
shared array/iteration/findmin/1d 1580083 ns 1473166.5 ns 1.07
shared array/iteration/findmin/2d 1437875 ns 1367666 ns 1.05
shared array/reductions/reduce/1d 749708.5 ns 733833 ns 1.02
shared array/reductions/reduce/2d 708333 ns 672875 ns 1.05
shared array/reductions/mapreduce/1d 732333 ns 742208 ns 0.99
shared array/reductions/mapreduce/2d 705000 ns 666458.5 ns 1.06
shared array/permutedims/4d 2662188 ns 2549000 ns 1.04
shared array/permutedims/2d 1094438 ns 1027917 ns 1.06
shared array/permutedims/3d 1785187.5 ns 1583875 ns 1.13
shared array/copy 213292 ns 245500 ns 0.87

This comment was automatically generated by workflow using github-action-benchmark.

@codecov
Copy link

codecov bot commented Apr 17, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.47%. Comparing base (a5b56dc) to head (37c7915).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #582      +/-   ##
==========================================
- Coverage   80.73%   80.47%   -0.26%     
==========================================
  Files          61       61              
  Lines        2657     2658       +1     
==========================================
- Hits         2145     2139       -6     
- Misses        512      519       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@christiangnrd christiangnrd merged commit ab25f7b into main Apr 18, 2025
7 checks passed
@christiangnrd christiangnrd deleted the erf branch April 18, 2025 13:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants