Skip to content

Commit 0729a74

Browse files
authored
[ROCDL] Added s_wakeup_barrier (GFX1250) (#172320)
This PR adds the `s_wakeup_barrier` op for GFX1250. Additionally, the assembly formats of the split/named barrier ops were refactored, in particular how operand and result types are written.
1 parent 3c97829 commit 0729a74

File tree

4 files changed

+73
-48
lines changed

4 files changed

+73
-48
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -328,28 +328,28 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
328328
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
329329
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
330330

331-
def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
332-
Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr, I32Attr:$id)> {
331+
def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["memberCnt"]>,
332+
Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr, I32Attr:$memberCnt)> {
333333
let description = [{
334334
Available on gfx1250+.
335335
}];
336336
let results = (outs);
337-
let assemblyFormat = "$ptr `,` $id attr-dict";
337+
let assemblyFormat = "$ptr `member_cnt` `=` $memberCnt attr-dict `:` qualified(type($ptr))";
338338
}
339339

340340
def ROCDL_BarrierSignalOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal", [], 0, [0], ["id"]>,
341341
Arguments<(ins I32Attr:$id)> {
342342
let results = (outs);
343-
let assemblyFormat = "$id attr-dict";
343+
let assemblyFormat = "`id` `=` $id attr-dict";
344344
}
345345

346-
def ROCDL_BarrierSignalVarOp : ROCDL_IntrOp<"s.barrier.signal.var", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
347-
Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr, I32Attr:$id)> {
346+
def ROCDL_BarrierSignalVarOp : ROCDL_IntrOp<"s.barrier.signal.var", [], [], [], 0, 0, 0, 0, [1], ["memberCnt"]>,
347+
Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr, I32Attr:$memberCnt)> {
348348
let description = [{
349349
Available on gfx1250+.
350350
}];
351351
let results = (outs);
352-
let assemblyFormat = "$ptr `,` $id attr-dict";
352+
let assemblyFormat = "$ptr `member_cnt` `=` $memberCnt attr-dict `:` qualified(type($ptr))";
353353
}
354354

355355
def ROCDL_BarrierJoinOp : ROCDL_IntrOp<"s.barrier.join", [], [], [], 0>,
@@ -358,7 +358,7 @@ def ROCDL_BarrierJoinOp : ROCDL_IntrOp<"s.barrier.join", [], [], [], 0>,
358358
Available on gfx1250+.
359359
}];
360360
let results = (outs);
361-
let assemblyFormat = "$ptr attr-dict";
361+
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
362362
}
363363

364364
def ROCDL_BarrierLeaveOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.leave", [], 0, [0], ["id"]>,
@@ -367,13 +367,13 @@ def ROCDL_BarrierLeaveOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.leave", [], 0,
367367
Available on gfx1250+.
368368
}];
369369
let results = (outs);
370-
let assemblyFormat = "$id attr-dict";
370+
let assemblyFormat = "`id` `=` $id attr-dict";
371371
}
372372

373373
def ROCDL_BarrierWaitOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.wait", [], 0, [0], ["id"]>,
374374
Arguments<(ins I16Attr:$id)> {
375375
let results = (outs);
376-
let assemblyFormat = "$id attr-dict";
376+
let assemblyFormat = "`id` `=` $id attr-dict";
377377
}
378378

379379
def ROCDL_BarrierSignalIsfirstOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal.isfirst", [], 1, [0], ["id"]>,
@@ -382,7 +382,7 @@ def ROCDL_BarrierSignalIsfirstOp : ROCDL_ConcreteNonMemIntrOp<"s.barrier.signal.
382382
Available on gfx1250+.
383383
}];
384384
let results = (outs I1:$res);
385-
let assemblyFormat = "$id attr-dict `:` type($res)";
385+
let assemblyFormat = "`id` `=` $id attr-dict `->` type($res)";
386386
}
387387

388388
def ROCDL_GetBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.barrier.state", [], 1, [0], ["id"]>,
@@ -391,7 +391,7 @@ def ROCDL_GetBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.barrier.state",
391391
Available on gfx1250+.
392392
}];
393393
let results = (outs I32:$res);
394-
let assemblyFormat = "$id attr-dict `:` type($res)";
394+
let assemblyFormat = "`id` `=` $id attr-dict `->` type($res)";
395395
}
396396

397397
def ROCDL_GetNamedBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.named.barrier.state", [], 1, [], []>,
@@ -400,7 +400,18 @@ def ROCDL_GetNamedBarrierStateOp : ROCDL_ConcreteNonMemIntrOp<"s.get.named.barri
400400
Available on gfx1250+.
401401
}];
402402
let results = (outs I32:$res);
403-
let assemblyFormat = "$ptr attr-dict `:` type($res)";
403+
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)";
404+
}
405+
406+
def ROCDL_WakeupBarrierOp : ROCDL_ConcreteNonMemIntrOp<"s.wakeup.barrier", [], 0, [], []>,
407+
Arguments<(ins Arg<ROCDLBufferLDS, "", []>:$ptr)> {
408+
let description = [{
409+
Wakes up waves associated with a given named barrier. Note: this op does not release waves waiting
410+
at the barrier. It only signals other waves in the same work-group waiting on the indicated named barrier
411+
to wake up.
412+
Available on gfx1250+.
413+
}];
414+
let assemblyFormat = "$ptr attr-dict `:` qualified(type($ptr))";
404415
}
405416

406417
def ROCDL_WaitDscntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.dscnt", [], 0, [0], ["count"]>,

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,8 @@ func.func @lds_barrier() {
419419
// GFX942-NEXT: rocdl.s.barrier
420420
// GFX10-NEXT: rocdl.s.barrier
421421
// GFX11-NEXT: rocdl.s.barrier
422-
// GFX12-NEXT: rocdl.s.barrier.signal -1
423-
// GFX12-NEXT: rocdl.s.barrier.wait -1
422+
// GFX12-NEXT: rocdl.s.barrier.signal id = -1
423+
// GFX12-NEXT: rocdl.s.barrier.wait id = -1
424424
// CHECK-NEXT: llvm.fence syncscope("workgroup") acquire {llvm.mmra = #[[$MMRA_TAG]]}
425425
amdgpu.lds_barrier
426426
func.return

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,64 +1199,71 @@ llvm.func @rocdl.s.barrier() {
11991199

12001200
llvm.func @rocdl.s.barrier.init(%ptr : !llvm.ptr<3>) {
12011201
// CHECK-LABEL: rocdl.s.barrier.init
1202-
// CHECK: rocdl.s.barrier.init %[[PTR:.+]], 1
1203-
rocdl.s.barrier.init %ptr, 1
1202+
// CHECK: rocdl.s.barrier.init %{{.*}} member_cnt = 1 : !llvm.ptr<3>
1203+
rocdl.s.barrier.init %ptr member_cnt = 1 : !llvm.ptr<3>
12041204
llvm.return
12051205
}
12061206

12071207
llvm.func @rocdl.s.barrier.signal() {
12081208
// CHECK-LABEL: rocdl.s.barrier.signal
1209-
// CHECK: rocdl.s.barrier.signal -1
1210-
rocdl.s.barrier.signal -1
1209+
// CHECK: rocdl.s.barrier.signal id = -1
1210+
rocdl.s.barrier.signal id = -1
12111211
llvm.return
12121212
}
12131213

12141214
llvm.func @rocdl.s.barrier.signal.var(%ptr : !llvm.ptr<3>) {
12151215
// CHECK-LABEL: rocdl.s.barrier.signal.var
1216-
// CHECK: rocdl.s.barrier.signal.var %[[PTR:.+]], 1
1217-
rocdl.s.barrier.signal.var %ptr, 1
1216+
// CHECK: rocdl.s.barrier.signal.var %{{.*}} member_cnt = 1 : !llvm.ptr<3>
1217+
rocdl.s.barrier.signal.var %ptr member_cnt = 1 : !llvm.ptr<3>
12181218
llvm.return
12191219
}
12201220

12211221
llvm.func @rocdl.s.barrier.join(%ptr : !llvm.ptr<3>) {
12221222
// CHECK-LABEL: rocdl.s.barrier.join
1223-
// CHECK: rocdl.s.barrier.join %[[PTR:.+]]
1224-
rocdl.s.barrier.join %ptr
1223+
// CHECK: rocdl.s.barrier.join %{{.*}} : !llvm.ptr<3>
1224+
rocdl.s.barrier.join %ptr : !llvm.ptr<3>
12251225
llvm.return
12261226
}
12271227

12281228
llvm.func @rocdl.s.barrier.leave() {
12291229
// CHECK-LABEL: rocdl.s.barrier.leave
1230-
// CHECK: rocdl.s.barrier.leave 1
1231-
rocdl.s.barrier.leave 1
1230+
// CHECK: rocdl.s.barrier.leave id = 1
1231+
rocdl.s.barrier.leave id = 1
12321232
llvm.return
12331233
}
12341234

12351235
llvm.func @rocdl.s.barrier.wait() {
12361236
// CHECK-LABEL: rocdl.s.barrier.wait
1237-
// CHECK: rocdl.s.barrier.wait -1
1238-
rocdl.s.barrier.wait -1
1237+
// CHECK: rocdl.s.barrier.wait id = -1
1238+
rocdl.s.barrier.wait id = -1
12391239
llvm.return
12401240
}
12411241

12421242
llvm.func @rocdl.s.barrier.signal.isfirst() {
12431243
// CHECK-LABEL: rocdl.s.barrier.signal.isfirst
1244-
// CHECK: rocdl.s.barrier.signal.isfirst 1
1245-
%0 = rocdl.s.barrier.signal.isfirst 1 : i1
1244+
// CHECK: rocdl.s.barrier.signal.isfirst id = 1 -> i1
1245+
%0 = rocdl.s.barrier.signal.isfirst id = 1 -> i1
12461246
llvm.return
12471247
}
12481248

12491249
llvm.func @rocdl.s.get.barrier.state() {
12501250
// CHECK-LABEL: rocdl.s.get.barrier.state
1251-
// CHECK: rocdl.s.get.barrier.state 1
1252-
%0 = rocdl.s.get.barrier.state 1 : i32
1251+
// CHECK: rocdl.s.get.barrier.state id = 1 -> i32
1252+
%0 = rocdl.s.get.barrier.state id = 1 -> i32
12531253
llvm.return
12541254
}
12551255

12561256
llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) {
12571257
// CHECK-LABEL: rocdl.s.get.named.barrier.state
1258-
// CHECK: rocdl.s.get.named.barrier.state %[[PTR:.+]]
1259-
%0 = rocdl.s.get.named.barrier.state %ptr : i32
1258+
// CHECK: rocdl.s.get.named.barrier.state %{{.*}} : !llvm.ptr<3> -> i32
1259+
%0 = rocdl.s.get.named.barrier.state %ptr : !llvm.ptr<3> -> i32
1260+
llvm.return
1261+
}
1262+
1263+
llvm.func @rocdl.s.wakeup.barrier(%ptr : !llvm.ptr<3>) {
1264+
// CHECK-LABEL: rocdl.s.wakeup.barrier
1265+
// CHECK: rocdl.s.wakeup.barrier %{{.*}} : !llvm.ptr<3>
1266+
rocdl.s.wakeup.barrier %ptr : !llvm.ptr<3>
12601267
llvm.return
12611268
}
12621269

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -253,64 +253,71 @@ llvm.func @rocdl.barrier() {
253253

254254
llvm.func @rocdl.s.barrier.init(%ptr : !llvm.ptr<3>) {
255255
// CHECK-LABEL: rocdl.s.barrier.init
256-
// CHECK: call void @llvm.amdgcn.s.barrier.init(ptr addrspace(3) %[[PTR:.+]], i32 1)
257-
rocdl.s.barrier.init %ptr, 1
256+
// CHECK: call void @llvm.amdgcn.s.barrier.init(ptr addrspace(3) %{{.*}}, i32 1)
257+
rocdl.s.barrier.init %ptr member_cnt = 1 : !llvm.ptr<3>
258258
llvm.return
259259
}
260260

261261
llvm.func @rocdl.s.barrier.signal() {
262262
// CHECK-LABEL: rocdl.s.barrier.signal
263263
// CHECK-NEXT: call void @llvm.amdgcn.s.barrier.signal(i32 -1)
264-
rocdl.s.barrier.signal -1
264+
rocdl.s.barrier.signal id = -1
265265
llvm.return
266266
}
267267

268268
llvm.func @rocdl.s.barrier.signal.var(%ptr : !llvm.ptr<3>) {
269269
// CHECK-LABEL: rocdl.s.barrier.signal.var
270-
// CHECK: call void @llvm.amdgcn.s.barrier.signal.var(ptr addrspace(3) %[[PTR:.+]], i32 1)
271-
rocdl.s.barrier.signal.var %ptr, 1
270+
// CHECK: call void @llvm.amdgcn.s.barrier.signal.var(ptr addrspace(3) %{{.*}}, i32 1)
271+
rocdl.s.barrier.signal.var %ptr member_cnt = 1 : !llvm.ptr<3>
272272
llvm.return
273273
}
274274

275275
llvm.func @rocdl.s.barrier.join(%ptr : !llvm.ptr<3>) {
276276
// CHECK-LABEL: rocdl.s.barrier.join
277-
// CHECK: call void @llvm.amdgcn.s.barrier.join(ptr addrspace(3) %[[PTR:.+]])
278-
rocdl.s.barrier.join %ptr
277+
// CHECK: call void @llvm.amdgcn.s.barrier.join(ptr addrspace(3) %{{.*}})
278+
rocdl.s.barrier.join %ptr : !llvm.ptr<3>
279279
llvm.return
280280
}
281281

282282
llvm.func @rocdl.s.barrier.leave() {
283283
// CHECK-LABEL: rocdl.s.barrier.leave
284284
// CHECK: call void @llvm.amdgcn.s.barrier.leave(i16 1)
285-
rocdl.s.barrier.leave 1
285+
rocdl.s.barrier.leave id = 1
286286
llvm.return
287287
}
288288

289289
llvm.func @rocdl.s.barrier.wait() {
290290
// CHECK-LABEL: rocdl.s.barrier.wait
291291
// CHECK-NEXT: call void @llvm.amdgcn.s.barrier.wait(i16 -1)
292-
rocdl.s.barrier.wait -1
292+
rocdl.s.barrier.wait id = -1
293293
llvm.return
294294
}
295295

296296
llvm.func @rocdl.s.barrier.signal.isfirst() {
297297
// CHECK-LABEL: rocdl.s.barrier.signal.isfirst
298-
// CHECK: %[[OUT:.+]] = call i1 @llvm.amdgcn.s.barrier.signal.isfirst(i32 1)
299-
%0 = rocdl.s.barrier.signal.isfirst 1 : i1
298+
// CHECK: %{{.*}} = call i1 @llvm.amdgcn.s.barrier.signal.isfirst(i32 1)
299+
%0 = rocdl.s.barrier.signal.isfirst id = 1 -> i1
300300
llvm.return
301301
}
302302

303303
llvm.func @rocdl.s.get.barrier.state() {
304304
// CHECK-LABEL: rocdl.s.get.barrier.state
305-
// CHECK: %[[STATE:.+]] = call i32 @llvm.amdgcn.s.get.barrier.state(i32 1)
306-
%0 = rocdl.s.get.barrier.state 1 : i32
305+
// CHECK: %{{.*}} = call i32 @llvm.amdgcn.s.get.barrier.state(i32 1)
306+
%0 = rocdl.s.get.barrier.state id = 1 -> i32
307307
llvm.return
308308
}
309309

310310
llvm.func @rocdl.s.get.named.barrier.state(%ptr : !llvm.ptr<3>) {
311311
// CHECK-LABEL: rocdl.s.get.named.barrier.state
312-
// CHECK: %[[STATE:.+]] = call i32 @llvm.amdgcn.s.get.named.barrier.state(ptr addrspace(3) %[[PTR:.+]])
313-
%0 = rocdl.s.get.named.barrier.state %ptr : i32
312+
// CHECK: %{{.*}} = call i32 @llvm.amdgcn.s.get.named.barrier.state(ptr addrspace(3) %{{.*}})
313+
%0 = rocdl.s.get.named.barrier.state %ptr : !llvm.ptr<3> -> i32
314+
llvm.return
315+
}
316+
317+
llvm.func @rocdl.s.wakeup.barrier(%ptr : !llvm.ptr<3>) {
318+
// CHECK-LABEL: rocdl.s.wakeup.barrier
319+
// CHECK: call void @llvm.amdgcn.s.wakeup.barrier(ptr addrspace(3) %{{.*}})
320+
rocdl.s.wakeup.barrier %ptr : !llvm.ptr<3>
314321
llvm.return
315322
}
316323

0 commit comments

Comments
 (0)