# Patch summary (unified diff against src/metal.jl; both hunks sit inside
# `annotate_air_intrinsics!`, whose signature is truncated in the hunk headers).
#
# What the patch does, as visible in the +/- lines:
#   * Renames the local `attrs` to `fn_attrs` and the nested helper
#     `add_attributes(names...)` to `add_fn_attributes(names...)`, and updates
#     every call site shown in the second hunk (barrier, atomics, simdgroup).
#   * Adds a nested helper `add_param_attributes(idx, names...)` that fetches
#     `parameter_attributes(f, idx)`, pushes `EnumAttribute(name, 0)` for each
#     name, and sets `changed = true`.
#   * Adds a branch for functions whose name matches `r"^air.sincos"`, which
#     calls `add_param_attributes(2, "nocapture", "writeonly")`.
#
# NOTE(review): `add_fn_attributes` skips a list of attribute names when
#   `LLVM.version() >= v"16"` (the list is cut off by the hunk boundary);
#   `add_param_attributes` has no such filter. Neither "nocapture" nor
#   "writeonly" appears in the visible part of that list -- confirm against
#   the full list in the target file.
# NOTE(review): the dots in `r"^air.sincos"` (and in the atomic patterns) are
#   unescaped, so each one matches any character, not only a literal '.'.
# NOTE(review): only the `air.sincos` pattern and the
#   `(add|sub|min|max|and|or|xor)` atomic pattern are anchored with `^`; the
#   load/store/xchg/simdgroup patterns are unanchored.
# NOTE(review): parameter index 2 is presumably the pointer output argument
#   of the sincos intrinsic -- TODO confirm against the intrinsic's signature.
diff --git a/src/metal.jl b/src/metal.jl index 6f1749eb..fd571f10 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -1002,8 +1002,8 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod isdeclaration(f) || continue fn = LLVM.name(f) - attrs = function_attributes(f) - function add_attributes(names...) + fn_attrs = function_attributes(f) + function add_fn_attributes(names...) for name in names if LLVM.version() >= v"16" && name in ["argmemonly", "inaccessiblememonly", "inaccessiblemem_or_argmemonly", @@ -1011,36 +1011,48 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod # XXX: workaround for changes from https://reviews.llvm.org/D135780 continue end - push!(attrs, EnumAttribute(name, 0)) + push!(fn_attrs, EnumAttribute(name, 0)) + end + changed = true + end + + function add_param_attributes(idx, names...) + param_attrs = parameter_attributes(f, idx) + for name in names + push!(param_attrs, EnumAttribute(name, 0)) end changed = true end # synchronization if fn == "air.wg.barrier" || fn == "air.simdgroup.barrier" - add_attributes("nounwind", "mustprogress", "convergent", "willreturn") + add_fn_attributes("nounwind", "mustprogress", "convergent", "willreturn") + + # sincos + elseif match(r"^air.sincos", fn) !== nothing + add_param_attributes(2, "nocapture", "writeonly") # atomics elseif match(r"air.atomic.(local|global).load", fn) !== nothing # TODO: "memory(argmem: read)" on LLVM 16+ - add_attributes("argmemonly", "readonly", "nounwind") + add_fn_attributes("argmemonly", "readonly", "nounwind") elseif match(r"air.atomic.(local|global).store", fn) !== nothing # TODO: "memory(argmem: write)" on LLVM 16+ - add_attributes("argmemonly", "writeonly", "nounwind") + add_fn_attributes("argmemonly", "writeonly", "nounwind") elseif match(r"air.atomic.(local|global).(xchg|cmpxchg)", fn) !== nothing # TODO: "memory(argmem: readwrite)" on LLVM 16+ - add_attributes("argmemonly", "nounwind") + 
add_fn_attributes("argmemonly", "nounwind") elseif match(r"^air.atomic.(local|global).(add|sub|min|max|and|or|xor)", fn) !== nothing # TODO: "memory(argmem: readwrite)" on LLVM 16+ - add_attributes("argmemonly", "nounwind") + add_fn_attributes("argmemonly", "nounwind") # simdgroup elseif match(r"air.simdgroup_matrix_8x8_multiply_accumulate", fn) !== nothing - add_attributes("convergent", "mustprogress", "nounwind", "willreturn") + add_fn_attributes("convergent", "mustprogress", "nounwind", "willreturn") elseif match(r"air.simdgroup_matrix_8x8_load", fn) !== nothing - add_attributes("convergent", "mustprogress", "nofree", "nounwind", "readonly", "willreturn") + add_fn_attributes("convergent", "mustprogress", "nofree", "nounwind", "readonly", "willreturn") elseif match(r"air.simdgroup_matrix_8x8_store", fn) !== nothing - add_attributes("convergent", "mustprogress", "nounwind", "willreturn", "writeonly") + add_fn_attributes("convergent", "mustprogress", "nounwind", "willreturn", "writeonly") end end