From 9fb073d89f5bc62d4637c9471fe03fc3a5e5f8f7 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sat, 31 Jan 2026 20:18:24 +0100 Subject: [PATCH] Allow division between two ComplexF32 numbers --- src/device/intrinsics/math.jl | 15 +++++++++++++++ test/device/intrinsics/math.jl | 16 ++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 4a21ec9c3..1b5265fce 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -48,6 +48,21 @@ end @device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z) @device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z) +@device_override function Base.:(/)(z::Complex{Float32}, w::Complex{Float32}) + c, d = reim(w) # Avoid using widen(w) as in Base + a, b = reim(z) # Avoid using widen(z) as in Base + if (isinf(c) | isinf(d)) + if isfinite(z) + return complex(zero(Float32)*sign(real(z))*sign(real(w)), -zero(Float32)*sign(imag(z))*sign(imag(w))) + end + return Float32(NaN)+Float32(NaN)*im + end + mag = inv(muladd(c, c, d^2)) + re_part = muladd(a, c, b*d) + im_part = muladd(b, c, -a*d) + return oftype(z, Complex(re_part*mag, im_part*mag)) +end + @device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.acos(x::Float16) = ccall("extern air.acos.f16", llvmcall, Float16, (Float16,), x) diff --git a/test/device/intrinsics/math.jl b/test/device/intrinsics/math.jl index 8b62548b5..72fd4498c 100644 --- a/test/device/intrinsics/math.jl +++ b/test/device/intrinsics/math.jl @@ -181,6 +181,22 @@ end @test Array(mtlout) == clamp.(in, minval, maxval) end + let + N = 10 + + x = rand(ComplexF32, N) + y = rand(ComplexF32, N) + + dx = MtlArray(x) + dy = MtlArray(y) + + + z = x ./ y + dz = dx ./ dy + + @test Array(dz) ≈ z + end + let #pow N = 4 arr1 = rand(T, N)