Use MultiplicativeInverse to speedup Linear to Cartesian indexing operations #539
Conversation
Benchmark Results
Benchmark Plots: A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Well, the only issue is that my benchmarks are angry at me...
From a quick look, this doesn't look to be a silver bullet. Some initial performance measurements are in JuliaGPU/GPUArrays.jl#565 (comment). Also, there's now some extra code from the checked `Int32` conversion (note the `checked_trunc_sint`/`throw_inexacterror` branches below):

```llvm
; ││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:31 within `getindex`
; │││││┌ @ tuple.jl:383 within `map`
; ││││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:32 within `#5`
; │││││││┌ @ array.jl:3065 within `getindex`
; ││││││││┌ @ range.jl:923 within `_getindex`
; │││││││││┌ @ range.jl:953 within `unsafe_getindex`
; ││││││││││┌ @ number.jl:7 within `convert`
; │││││││││││┌ @ boot.jl:891 within `Int32`
; ││││││││││││┌ @ boot.jl:801 within `toInt32`
; │││││││││││││┌ @ boot.jl:764 within `checked_trunc_sint`
%47 = add nsw i64 %43, -2147483647
%48 = icmp ult i64 %47, -4294967296
br i1 %48, label %L304, label %L313
L304: ; preds = %L219
call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
L313: ; preds = %L219
%49 = add nsw i64 %45, -2147483647
%50 = icmp ult i64 %49, -4294967296
br i1 %50, label %L335, label %L345
L335: ; preds = %L313
call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
  unreachable
```

Looks like this generates a significant amount of code. With the scalar broadcast from https://github.com/JuliaGPU/GPUArrays.jl/issues/565, here's the CUDA.jl version that simply uses hardware indices vs. the KA.jl version:

```llvm
define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L176, label %pass
L176: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%14 = add nsw i64 %10, -1
%15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
br label %L176
}
```

```llvm
define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 0
%.fca.1.0.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 1
%.fca.1.0.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 2
%.fca.1.0.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 3
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 1
%.fca.1.1.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 2
%.fca.1.1.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 3
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 1, 0
%3 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%5 = icmp ne i32 %3, 0
call void @llvm.assume(i1 %5)
%6 = zext i32 %3 to i64
%7 = sext i32 %.fca.1.0.0.0.1.extract to i64
%8 = mul nsw i64 %7, %6
%9 = lshr i64 %8, 32
%10 = trunc i64 %9 to i32
%11 = sext i8 %.fca.1.0.0.0.2.extract to i32
%12 = mul i32 %3, %11
%13 = add i32 %12, %10
%abs.i = call i32 @llvm.abs.i32(i32 %.fca.1.0.0.0.0.extract, i1 false)
%.not = icmp eq i32 %abs.i, 1
%14 = mul i32 %.fca.1.0.0.0.0.extract, %3
%narrow = call i8 @llvm.umin.i8(i8 %.fca.1.0.0.0.3.extract, i8 31)
%.v = zext i8 %narrow to i32
%15 = ashr i32 %13, %.v
%.lobit = lshr i32 %13, 31
%16 = add i32 %.lobit, %15
%17 = select i1 %.not, i32 %14, i32 %16
%18 = mul i32 %17, %.fca.1.0.0.0.0.extract
%19 = add nuw nsw i32 %3, 1
%20 = sub i32 %19, %18
%21 = add i32 %17, 1
%22 = sext i32 %20 to i64
%23 = sext i32 %21 to i64
%24 = icmp ne i32 %4, 0
call void @llvm.assume(i1 %24)
%25 = zext i32 %4 to i64
%26 = sext i32 %.fca.1.1.0.0.1.extract to i64
%27 = mul nsw i64 %26, %25
%28 = lshr i64 %27, 32
%29 = trunc i64 %28 to i32
%30 = sext i8 %.fca.1.1.0.0.2.extract to i32
%31 = mul nsw i32 %4, %30
%32 = add i32 %31, %29
%abs.i29 = call i32 @llvm.abs.i32(i32 %.fca.1.1.0.0.0.extract, i1 false)
%.not39 = icmp eq i32 %abs.i29, 1
%33 = mul i32 %.fca.1.1.0.0.0.extract, %4
%narrow37 = call i8 @llvm.umin.i8(i8 %.fca.1.1.0.0.3.extract, i8 31)
%.v36 = zext i8 %narrow37 to i32
%34 = ashr i32 %32, %.v36
%.lobit38 = lshr i32 %32, 31
%35 = add i32 %.lobit38, %34
%36 = select i1 %.not39, i32 %33, i32 %35
%37 = mul i32 %36, %.fca.1.1.0.0.0.extract
%38 = add nuw nsw i32 %4, 1
%39 = sub i32 %38, %37
%40 = add i32 %36, 1
%41 = sext i32 %39 to i64
%42 = sext i32 %40 to i64
%43 = add nsw i64 %22, -1
%44 = sext i32 %.fca.1.1.0.0.0.extract to i64
%45 = mul nsw i64 %43, %44
%46 = add nsw i64 %45, %41
%47 = add nsw i64 %23, -1
%48 = sext i32 %.fca.1.1.0.1.0.extract to i64
%49 = mul nsw i64 %47, %48
%50 = add nsw i64 %49, %42
%51 = icmp sgt i64 %46, 0
%52 = icmp sle i64 %46, %.fca.0.0.0.0.extract
%53 = and i1 %51, %52
%54 = icmp sgt i64 %50, 0
%55 = icmp sle i64 %50, %.fca.0.0.1.0.extract
%56 = and i1 %54, %55
%57 = and i1 %56, %53
br i1 %57, label %L340, label %L723
L340: ; preds = %conversion
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%58 = add nsw i64 %42, -1
%59 = add nsw i64 %58, %49
%60 = mul i64 %59, %.fca.2.0.extract
%61 = add nsw i64 %41, -1
%62 = add nsw i64 %61, %45
%63 = add i64 %62, %60
%64 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%65 = getelementptr inbounds float, float addrspace(1)* %64, i64 %63
store float %.fca.0.0.extract, float addrspace(1)* %65, align 4
br label %L723
L723: ; preds = %L340, %conversion
ret void
}
```

Of course, this looks extra bad because the base kernel is so simple. The changes to get rid of all exceptions:

```diff
diff --git a/src/nditeration.jl b/src/nditeration.jl
index b933c2b..e7f087c 100644
--- a/src/nditeration.jl
+++ b/src/nditeration.jl
@@ -29,7 +29,7 @@ end
@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where N
@boundscheck checkbounds(iter, I...)
index = map(iter.inverses, I) do inv, i
- @inbounds getindex(Base.OneTo(inv.divisor), i)
+ @inbounds getindex(Base.OneTo(inv.divisor), i%Int32)
end
CartesianIndex(index)
end
@@ -43,13 +43,15 @@ end
function _ind2sub_recurse(inds, ind)
Base.@_inline_meta
inv = inds[1]
+ Main.LLVM.Interop.assume(ind > 0)
indnext, f, l = _div(ind, inv)
(ind-l*indnext+f, _ind2sub_recurse(Base.tail(inds), indnext)...)
end
_lookup(ind, inv::SignedMultiplicativeInverse) = ind+1
function _div(ind, inv::SignedMultiplicativeInverse)
- inv.divisor == 0 && throw(DivideError())
+ #inv.divisor == 0 && throw(DivideError())
+ Main.LLVM.Interop.assume(ind >= 0)
div(ind%Int32, inv), 1, inv.divisor
 end
```
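For context on the `i % Int32` change in that diff: constructor-style conversion (`Int32(i)`) is range-checked and can throw `InexactError`, which is exactly what produces the `checked_trunc_sint`/trap branches in the first IR dump, whereas `rem`-style conversion truncates unconditionally. A minimal standalone sketch (plain Base Julia, not PR code):

```julia
using InteractiveUtils: code_llvm

# Checked conversion: emits a range check plus a throwing branch.
checked(x::Int64) = Int32(x)

# Wrapping conversion: lowers to a single truncation, no branch.
wrapping(x::Int64) = x % Int32

# code_llvm(checked, (Int64,))   # shows the checked_trunc_sint path
# code_llvm(wrapping, (Int64,))  # shows a bare `trunc`
```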
I'm very interested to know if a silver-bullet solution emerges. We basically had to redefine parts of broadcast and force linear indexing for certain CliMA CUDA kernels.
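For anyone wanting to do something similar: one way to sidestep the linear-to-Cartesian division entirely is to index the flat storage with `@index(Global, Linear)`. A hedged sketch (the kernel name and launch parameters are illustrative, not taken from the CliMA code):

```julia
using KernelAbstractions

# Illustrative kernel that forces linear indexing: it reads a single global
# linear index and writes into flat (vector) storage, so no div is needed.
@kernel function fill_linear!(dst, val)
    i = @index(Global, Linear)
    @inbounds dst[i] = val
end

# Usage (CPU backend shown; GPU backends launch the same way):
# A = zeros(Float32, 32, 32)
# fill_linear!(CPU(), 256)(vec(A), 1.0f0; ndrange = length(A))
# KernelAbstractions.synchronize(CPU())
```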

Related to JuliaGPU/Metal.jl#101.
Using the idea from @N5N3 in JuliaGPU/Metal.jl#101 (comment).
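For readers unfamiliar with the underlying trick: `Base.MultiplicativeInverses.SignedMultiplicativeInverse` precomputes a magic multiplier and shift for a fixed divisor, so the repeated `div` in the linear-to-Cartesian conversion becomes a multiply/shift sequence. A standalone sketch of the idea (the `FastIndices`/`linear_to_cartesian` names are illustrative, not the PR's `FastCartesianIndices` implementation):

```julia
using Base.MultiplicativeInverses: SignedMultiplicativeInverse

# Precompute one multiplicative inverse per (fixed) array dimension.
struct FastIndices{N}
    inverses::NTuple{N, SignedMultiplicativeInverse{Int32}}
end
FastIndices(dims::NTuple{N, Int}) where {N} =
    FastIndices{N}(map(d -> SignedMultiplicativeInverse{Int32}(Int32(d)), dims))

# 2D case: convert a 1-based linear index into (i, j) using div-by-inverse,
# which lowers to multiply/shift instead of a hardware integer division.
function linear_to_cartesian(fi::FastIndices{2}, ind::Int32)
    inv1 = fi.inverses[1]
    q = div(ind - Int32(1), inv1)      # (ind - 1) ÷ size(A, 1), but without `idiv`
    i = ind - q * inv1.divisor         # remainder, reconstructed from the quotient
    j = q + Int32(1)
    return (i, j)
end

# linear_to_cartesian(FastIndices((5, 4)), Int32(7)) == (Int32(2), Int32(2))
```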