Skip to content

Commit df14096

Browse files
authored
[NFC][AMDGPU] Refactor the multiclass for WMMA_F8F6F4 instructions (#172245)
1 parent 2e2e48f commit df14096

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1814,21 +1814,42 @@ def F32_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f
18141814
// Two SWMMAC profile records built from VOP3PWMMA_Profile. Each passes a list of
// four value types followed by nine positional integer/bit arguments.
// NOTE(review): the meaning of the positional arguments is defined by
// VOP3PWMMA_Profile, whose definition is outside this excerpt — consult it before
// changing any of them. The two records differ in their first and last list
// types (v8f16 vs. v8i32) and in the third and fourth positional arguments.
def F16_FP8BF8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8f16, v8i32, v16i32, v8f16], 1, 32, 0, 1, 1, 0, 0, 0, 1>;
18151815
def I32_IU8X128_SWMMAC_w32 : VOP3PWMMA_Profile<[v8i32, v8i32, v16i32, v8i32], 1, 32, 1, 0, 1, 0, 0, 0, 1>;
18161816

1817-
// Pre-refactor form of the multiclass (removed side of this diff).
// Spells out nine profiles by hand, one per (A format, B format) pair drawn from
// {f8, f6, f4}. In every record the first and last list types are v8f32, and the
// second/third list types follow the format named in the record suffix:
// f8 -> v16i32, f6 -> v12i32, f4 -> v8i32.
// The three template bits are forwarded unchanged as the last three profile
// arguments; the six arguments before them are fixed at 0, 0, 0, 1, 1, 1.
multiclass WMMA_F8F6F4_Profiles<bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
1818-
def _f8_f8_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1819-
def _f8_f6_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1820-
def _f8_f4_w32 : VOP3PWMMA_Profile<[v8f32, v16i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1821-
def _f6_f8_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1822-
def _f6_f6_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1823-
def _f6_f4_w32 : VOP3PWMMA_Profile<[v8f32, v12i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1824-
def _f4_f8_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v16i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1825-
def _f4_f6_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v12i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1826-
def _f4_f4_w32 : VOP3PWMMA_Profile<[v8f32, v8i32, v8i32, v8f32], 0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1827-
}
1828-
1829-
// Pre-refactor instantiations (removed side of this diff). Arguments are, in
// order, HasMatrixScale, Scale16, HasMatrixReuse: the plain variant clears all
// three, _SCALE sets HasMatrixScale and HasMatrixReuse, and _SCALE16 sets all three.
defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<0, 0, 0>;
1830-
defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles<1, 0, 1>;
1831-
defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<1, 1, 1>;
1817+
// Helper class to compute the destination vector type of WMMA_F8F6F4 instructions
// from the destination element type and the M x N dimensions.
// The result is exposed as the `ret` field. It is chosen from the computed total
// bit size only: 256 -> v8f32, 1024 -> v64f16.
// NOTE(review): the selection does not look at DstEltTy itself, so an f32 element
// type whose size works out to 1024 bits would also yield v64f16 — confirm that
// only matching (element type, shape) combinations are ever passed in.
// NOTE(review): a size other than 256 or 1024 matches no !cond clause; TableGen is
// expected to report an error in that case — confirm before adding new shapes.
1818+
class getWMMAF8F6F4DstVTy<ValueType DstEltTy, int M, int N> {
1819+
// Size in bits = (M * N / 32) * element_size_in_bits.
// The divisor 32 is presumably the wave32 lane count — TODO confirm. !div is an
// integer division, so M * N is expected to be a multiple of 32.
1820+
defvar Size = !mul(!div(!mul(M, N), 32), DstEltTy.Size);
1821+
ValueType ret = !cond(!eq(Size, 256) : v8f32,
1822+
!eq(Size, 1024) : v64f16);
1823+
}
1824+
1825+
// Helper class to compute the vector type of matrix A or B of WMMA_F8F6F4
// instructions from the element format name and the two matrix dimensions.
// Fmt must be one of "f8", "f6" or "f4" (8, 6 and 4 bits per element); D1 and D2
// are the matrix dimensions (the caller passes M, K for A and K, N for B).
// The result is exposed as the `ret` field and is always an i32 vector:
// 256 bits -> v8i32, 384 -> v12i32, 512 -> v16i32, 1024 -> v32i32.
// NOTE(review): an unlisted Fmt string or an unlisted TypeSize matches no !cond
// clause; TableGen is expected to report an error in that case — confirm.
1826+
class getWMMAF8F6F4ABVTy<string Fmt, int D1, int D2> {
1827+
defvar FmtBits = !cond(!eq(Fmt, "f8") : 8,
1828+
!eq(Fmt, "f6") : 6,
1829+
!eq(Fmt, "f4") : 4);
1830+
// TypeSize in bits = (D1 * D2 / 32) * format_bits.
// The divisor 32 is presumably the wave32 lane count — TODO confirm. !div is an
// integer division, so D1 * D2 is expected to be a multiple of 32.
1831+
defvar TypeSize = !mul(!div(!mul(D1, D2), 32), FmtBits);
1832+
ValueType ret = !cond(!eq(TypeSize, 256) : v8i32,
1833+
!eq(TypeSize, 384) : v12i32,
1834+
!eq(TypeSize, 512) : v16i32,
1835+
!eq(TypeSize, 1024) : v32i32);
1836+
}
1837+
1838+
// Generates one VOP3PWMMA_Profile record per (A format, B format) pair drawn from
// {"f8", "f6", "f4"}, i.e. nine records per instantiation, each named
// <defm name>_<A format>_<B format>_w32.
//   DstEltTy, M, N  - select the destination vector type via getWMMAF8F6F4DstVTy;
//                     that type is used for both the first and last list entries.
//   K               - shared dimension: A is sized from (M, K), B from (K, N) via
//                     getWMMAF8F6F4ABVTy.
//   HasMatrixScale, Scale16, HasMatrixReuse - forwarded unchanged as the last
//                     three profile arguments.
// The six profile arguments before those are fixed at 0, 0, 0, 1, 1, 1.
// NOTE(review): their meaning is defined by VOP3PWMMA_Profile, outside this excerpt.
multiclass WMMA_F8F6F4_Profiles<ValueType DstEltTy, int M, int N, int K,
1839+
bit HasMatrixScale, bit Scale16, bit HasMatrixReuse> {
1840+
defvar DstTy = getWMMAF8F6F4DstVTy<DstEltTy, M, N>.ret;
1841+
foreach ATy = ["f8", "f6", "f4"] in {
1842+
foreach BTy = ["f8", "f6", "f4"] in {
1843+
def _#ATy#_#BTy#_w32 : VOP3PWMMA_Profile<
1844+
[DstTy, getWMMAF8F6F4ABVTy<ATy, M, K>.ret, getWMMAF8F6F4ABVTy<BTy, K, N>.ret, DstTy],
1845+
0, 0, 0, 1, 1, 1, HasMatrixScale, Scale16, HasMatrixReuse>;
1846+
}
1847+
}
1848+
}
1849+
1850+
// Three instantiations sharing f32 destination elements and M=16, N=16, K=128.
// With these arguments the helper classes yield v8f32 for the destination and
// v16i32 / v12i32 / v8i32 for f8 / f6 / f4 operands respectively.
// They differ only in the flag arguments: the plain variant clears all three,
// _SCALE sets HasMatrixScale and HasMatrixReuse, and _SCALE16 sets all three.
defm F32_16X16X128_F8F6F4 : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/0, /*Scale16=*/0, /*HasMatrixReuse=*/0>;
1851+
defm F32_16X16X128_F8F6F4_SCALE : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/0, /*HasMatrixReuse=*/1>;
1852+
defm F32_16X16X128_F8F6F4_SCALE16 : WMMA_F8F6F4_Profiles<f32, /*M=*/16, /*N=*/16, /*K=*/128, /*HasMatrixScale=*/1, /*Scale16=*/1, /*HasMatrixReuse=*/1>;
18321853

18331854
class VOP_WMMA_LD_SCALE<ValueType vt, RegisterOperand RC> : VOP3P_Profile<VOPProfile<[untyped, vt, vt, untyped]>> {
18341855
let HasMatrixScale = 1;

0 commit comments

Comments
 (0)