SPIR-V: Add a pre-optimization pass to convert unreachable to return. #709
SPIR-V: Add a pre-optimization pass to convert unreachable to return. #709
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/spirv.jl b/src/spirv.jl
index 89bb54a..145685c 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -370,224 +370,238 @@ end
# results in `OpUnreachable` actually getting executed, which is undefined behavior.
# Instead, we transform unreachable instructions to returns with an error flag that's
# checked by the caller.
-function lower_unreachable_to_return!(@nospecialize(job::CompilerJob),
- mod::LLVM.Module, entry::LLVM.Function)
+function lower_unreachable_to_return!(
+ @nospecialize(job::CompilerJob),
+ mod::LLVM.Module, entry::LLVM.Function
+ )
job = current_job::CompilerJob
changed = false
@tracepoint "lower unreachable to return" begin
- already_transformed_functions = Set{LLVM.Function}()
+ already_transformed_functions = Set{LLVM.Function}()
- # The pass runs until all unreachable instructions are transformed. During each
- # iteration, we transform all unreachable instructions to returns, and transform all
- # callers to handle the flag, generating a new unreachable when it is set.
- while true
- # Find all functions with unreachable instructions
- functions_with_unreachable = Set{LLVM.Function}()
- for f in functions(mod)
- for bb in blocks(f), inst in instructions(bb)
- if inst isa LLVM.UnreachableInst
- push!(functions_with_unreachable, f)
- break
- end
- end
- end
- isempty(functions_with_unreachable) && break
-
- # Transform functions with unreachable to return a flag next to the original value
- transformed_functions = Dict{LLVM.Function, LLVM.Function}()
- for f in functions_with_unreachable
- ft = function_type(f)
- ret_type = return_type(ft)
- fn = LLVM.name(f)
-
- # in the case of the entry-point function, we cannot touch its type or returned
- # value, so simply replace the unreachable with a return.
- if f == entry
- @compiler_assert ret_type == LLVM.VoidType() job
-
- # find un reachables
- unreachables = LLVM.Value[]
+ # The pass runs until all unreachable instructions are transformed. During each
+ # iteration, we transform all unreachable instructions to returns, and transform all
+ # callers to handle the flag, generating a new unreachable when it is set.
+ while true
+ # Find all functions with unreachable instructions
+ functions_with_unreachable = Set{LLVM.Function}()
+ for f in functions(mod)
for bb in blocks(f), inst in instructions(bb)
if inst isa LLVM.UnreachableInst
- push!(unreachables, inst)
+ push!(functions_with_unreachable, f)
+ break
end
end
+ end
+ isempty(functions_with_unreachable) && break
+
+ # Transform functions with unreachable to return a flag next to the original value
+ transformed_functions = Dict{LLVM.Function, LLVM.Function}()
+ for f in functions_with_unreachable
+ ft = function_type(f)
+ ret_type = return_type(ft)
+ fn = LLVM.name(f)
+
+ # in the case of the entry-point function, we cannot touch its type or returned
+ # value, so simply replace the unreachable with a return.
+ if f == entry
+ @compiler_assert ret_type == LLVM.VoidType() job
+
+ # find un reachables
+ unreachables = LLVM.Value[]
+ for bb in blocks(f), inst in instructions(bb)
+ if inst isa LLVM.UnreachableInst
+ push!(unreachables, inst)
+ end
+ end
- # transform unreachable to return
- @dispose builder=IRBuilder() begin
- for inst in unreachables
- position!(builder, inst)
- ret!(builder)
- erase!(inst)
+ # transform unreachable to return
+ @dispose builder = IRBuilder() begin
+ for inst in unreachables
+ position!(builder, inst)
+ ret!(builder)
+ erase!(inst)
+ end
end
+
+ continue
end
- continue
- end
+ # If this is the first time looking at this function, we need to change its type
+ if !in(f, already_transformed_functions)
+ # Create new return type: {i1, original_type}
+ new_ret_type = if ret_type == LLVM.VoidType()
+ LLVM.StructType([LLVM.Int1Type()])
+ else
+ LLVM.StructType([LLVM.Int1Type(), ret_type])
+ end
- # If this is the first time looking at this function, we need to change its type
- if !in(f, already_transformed_functions)
- # Create new return type: {i1, original_type}
- new_ret_type = if ret_type == LLVM.VoidType()
- LLVM.StructType([LLVM.Int1Type()])
- else
- LLVM.StructType([LLVM.Int1Type(), ret_type])
- end
+ LLVM.name!(f, fn * ".old")
+ new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
+ new_f = LLVM.Function(mod, fn, new_ft)
+ linkage!(new_f, linkage(f))
+ for (i, param) in enumerate(parameters(f))
+ LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+ end
+
+ # clone the IR
+ value_map = Dict{LLVM.Value, LLVM.Value}(
+ param => parameters(new_f)[i] for (i, param) in enumerate(parameters(f))
+ )
+ clone_into!(
+ new_f, f; value_map,
+ changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+ )
+
+ # rewrite return instructions
+ returns = LLVM.Value[]
+ for bb in blocks(new_f), inst in instructions(bb)
+ if inst isa LLVM.RetInst
+ push!(returns, inst)
+ end
+ end
+ @dispose builder = IRBuilder() begin
+ for inst in returns
+ position!(builder, inst)
+ if ret_type == LLVM.VoidType()
+ # void function: return {false}
+ flag_and_val =
+ insert_value!(
+ builder, UndefValue(new_ret_type),
+ ConstantInt(LLVM.Int1Type(), false), 0
+ )
+ else
+ # non-void function: return {false, val}
+ val = only(operands(inst))
+ flag_and_val =
+ insert_value!(
+ builder, UndefValue(new_ret_type),
+ ConstantInt(LLVM.Int1Type(), false), 0
+ )
+ flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+ end
+ ret!(builder, flag_and_val)
+ erase!(inst)
+ end
+ end
- LLVM.name!(f, fn * ".old")
- new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
- new_f = LLVM.Function(mod, fn, new_ft)
- linkage!(new_f, linkage(f))
- for (i, param) in enumerate(parameters(f))
- LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+ transformed_functions[f] = new_f
+ push!(already_transformed_functions, new_f)
+ f = new_f
end
- # clone the IR
- value_map = Dict{LLVM.Value, LLVM.Value}(
- param => parameters(new_f)[i] for (i,param) in enumerate(parameters(f))
- )
- clone_into!(new_f, f; value_map,
- changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
-
- # rewrite return instructions
- returns = LLVM.Value[]
- for bb in blocks(new_f), inst in instructions(bb)
- if inst isa LLVM.RetInst
- push!(returns, inst)
+ # rewrite unreachable instructions
+ ret_type = return_type(function_type(f))
+ unreachables = LLVM.Value[]
+ for bb in blocks(f), inst in instructions(bb)
+ if inst isa LLVM.UnreachableInst
+ push!(unreachables, inst)
end
end
- @dispose builder=IRBuilder() begin
- for inst in returns
+ @dispose builder = IRBuilder() begin
+ for inst in unreachables
position!(builder, inst)
- if ret_type == LLVM.VoidType()
- # void function: return {false}
- flag_and_val =
- insert_value!(builder, UndefValue(new_ret_type),
- ConstantInt(LLVM.Int1Type(), false), 0)
+ if length(elements(ret_type)) == 1
+ # void function: return {true}
+ flag_and_val = insert_value!(
+ builder, UndefValue(ret_type),
+ ConstantInt(LLVM.Int1Type(), true), 0
+ )
else
- # non-void function: return {false, val}
- val = only(operands(inst))
- flag_and_val =
- insert_value!(builder, UndefValue(new_ret_type),
- ConstantInt(LLVM.Int1Type(), false), 0)
- flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+ # non-void function: return {true, undef}
+ val_type = elements(ret_type)[2]
+ flag_and_val = insert_value!(
+ builder, UndefValue(ret_type),
+ ConstantInt(LLVM.Int1Type(), true), 0
+ )
+ flag_and_val = insert_value!(
+ builder, flag_and_val,
+ UndefValue(val_type), 1
+ )
end
ret!(builder, flag_and_val)
erase!(inst)
end
end
- transformed_functions[f] = new_f
- push!(already_transformed_functions, new_f)
- f = new_f
+ changed = true
end
- # rewrite unreachable instructions
- ret_type = return_type(function_type(f))
- unreachables = LLVM.Value[]
- for bb in blocks(f), inst in instructions(bb)
- if inst isa LLVM.UnreachableInst
- push!(unreachables, inst)
- end
- end
- @dispose builder=IRBuilder() begin
- for inst in unreachables
- position!(builder, inst)
- if length(elements(ret_type)) == 1
- # void function: return {true}
- flag_and_val = insert_value!(builder, UndefValue(ret_type),
- ConstantInt(LLVM.Int1Type(), true), 0)
- else
- # non-void function: return {true, undef}
- val_type = elements(ret_type)[2]
- flag_and_val = insert_value!(builder, UndefValue(ret_type),
- ConstantInt(LLVM.Int1Type(), true), 0)
- flag_and_val = insert_value!(builder, flag_and_val,
- UndefValue(val_type), 1)
+ # Rewrite calls
+ for (old_f, new_f) in transformed_functions
+ calls_to_rewrite = LLVM.CallInst[]
+ for use in uses(old_f)
+ call_inst = user(use)
+ if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
+ push!(calls_to_rewrite, call_inst)
end
- ret!(builder, flag_and_val)
- erase!(inst)
end
- end
- changed = true
- end
+ @dispose builder = IRBuilder() begin
+ for call_inst in calls_to_rewrite
+ f = LLVM.parent(LLVM.parent(call_inst))
+ position!(builder, call_inst)
+
+ # Call the new function
+ new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
+ callconv!(new_call, callconv(call_inst))
+
+ # Split the block and branch based on the flag
+ flag = extract_value!(builder, new_call, 0)
+ error_block = BasicBlock(f, "error")
+ move_after(error_block, LLVM.parent(call_inst))
+ continue_block = BasicBlock(f, "continue")
+ move_after(continue_block, error_block)
+ br_inst = br!(builder, flag, error_block, continue_block)
+
+ # Extract the returned value in the continue block
+ position!(builder, continue_block)
+ if value_type(call_inst) != LLVM.VoidType()
+ value = extract_value!(builder, new_call, 1)
+ replace_uses!(call_inst, value)
+ end
+ @compiler_assert isempty(uses(call_inst)) job
+ erase!(call_inst)
+
+ # Move the remaining instructions over to the continue block
+ while true
+ inst = LLVM.nextinst(br_inst)
+ inst === nothing && break
+ remove!(inst)
+ insert!(builder, inst)
+ end
- # Rewrite calls
- for (old_f, new_f) in transformed_functions
- calls_to_rewrite = LLVM.CallInst[]
- for use in uses(old_f)
- call_inst = user(use)
- if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
- push!(calls_to_rewrite, call_inst)
+ # Generate an unreachable in the error block
+ position!(builder, error_block)
+ unreachable!(builder)
+ end
end
+
+ @compiler_assert isempty(uses(old_f)) job
+ erase!(old_f)
end
+ end
- @dispose builder=IRBuilder() begin
- for call_inst in calls_to_rewrite
- f = LLVM.parent(LLVM.parent(call_inst))
- position!(builder, call_inst)
-
- # Call the new function
- new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
- callconv!(new_call, callconv(call_inst))
-
- # Split the block and branch based on the flag
- flag = extract_value!(builder, new_call, 0)
- error_block = BasicBlock(f, "error")
- move_after(error_block, LLVM.parent(call_inst))
- continue_block = BasicBlock(f, "continue")
- move_after(continue_block, error_block)
- br_inst = br!(builder, flag, error_block, continue_block)
-
- # Extract the returned value in the continue block
- position!(builder, continue_block)
- if value_type(call_inst) != LLVM.VoidType()
- value = extract_value!(builder, new_call, 1)
- replace_uses!(call_inst, value)
- end
- @compiler_assert isempty(uses(call_inst)) job
- erase!(call_inst)
-
- # Move the remaining instructions over to the continue block
- while true
- inst = LLVM.nextinst(br_inst)
- inst === nothing && break
- remove!(inst)
- insert!(builder, inst)
- end
+ # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
+ if haskey(functions(mod), "llvm.trap")
+ trap = functions(mod)["llvm.trap"]
- # Generate an unreachable in the error block
- position!(builder, error_block)
- unreachable!(builder)
+ for use in uses(trap)
+ val = user(use)
+ if isa(val, LLVM.CallInst)
+ erase!(val)
+ changed = true
end
end
- @compiler_assert isempty(uses(old_f)) job
- erase!(old_f)
+ @compiler_assert isempty(uses(trap)) job
+ erase!(trap)
end
- end
-
- # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
- if haskey(functions(mod), "llvm.trap")
- trap = functions(mod)["llvm.trap"]
-
- for use in uses(trap)
- val = user(use)
- if isa(val, LLVM.CallInst)
- erase!(val)
- changed = true
- end
+ for f in functions(mod)
+ delete!(function_attributes(f), EnumAttribute("noreturn", 0))
end
- @compiler_assert isempty(uses(trap)) job
- erase!(trap)
- end
- for f in functions(mod)
- delete!(function_attributes(f), EnumAttribute("noreturn", 0))
- end
-
end
return changed
end |
|
I'm running into the following issue with this PR:
julia> OpenCL.code_llvm((Int,); kernel = true) do x
fldmod1(x, 10)
nothing
end
ERROR: LLVM error: Invalid struct return type!
{ i1 } ([2 x i64]*, i64, i64)* @julia_fldmod1_71344
Stacktrace:
[1] verify(mod::LLVM.Module)
@ LLVM ~/.julia/packages/LLVM/UFrs4/src/analysis.jl:19
[2] finish_module!(job::GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, mod::LLVM.Module, entry::LLVM.Function)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/spirv.jl:80
[3] macro expansion
@ ~/.julia/dev/GPUCompiler/src/driver.jl:183 [inlined]
[4] emit_llvm(job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:116
[5] emit_llvm(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:114
[6] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:95
[7] compile_unhooked
@ ~/.julia/dev/GPUCompiler/src/driver.jl:80 [inlined]
[8] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:67
[9] compile
@ ~/.julia/dev/GPUCompiler/src/driver.jl:55 [inlined]
[10] (::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}})(ctx::LLVM.Context)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:191
[11] JuliaContext(f::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:34
[12] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:25
[13] code_llvm(io::Base.TTY, job::GPUCompiler.CompilerJob; optimize::Bool, raw::Bool, debuginfo::Symbol, dump_module::Bool, kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:190
[14] code_llvm
@ ~/.julia/dev/GPUCompiler/src/reflection.jl:186 [inlined]
[15] code_llvm(io::Base.TTY, func::Any, types::Any; kernel::Bool, kwargs::@Kwargs{})
@ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:33
[16] code_llvm(func::Any, types::Any; kwargs::@Kwargs{kernel::Bool})
@ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:35
[17] top-level scope
@ REPL[8]:1 |
|
It's not unlikely some corner cases are handled incorrectly, as I wasn't able to validate the change due to the LLVM SPIR-V back-end not supporting struct return. I'll try to take a look. |
|
Actually, I just remembered pocl/pocl#1971 (comment) where it was determined that this approach isn't viable. We cannot have one thread [return early while the others keep executing]. One option is to use a global flag that all callers check, but that would probably kill performance, and it would not cover the case where a call to an exception-throwing function isn't made by all threads:
if work_item() == 0
fun_that_throws()
end
barrier() # deadlocks
So maybe the only way forwards is to require |
This is because SPIR-V doesn't have `trap`, meaning `unreachable` instructions get executed. That's UB, triggering issues with the PoCL driver (x-ref pocl/pocl#1971).
Sadly, this seems to be triggering bugs in the LLVM SPIR-V back-end, so looking into that.
EDIT: filed llvm/llvm-project#151344