From ff051b3c190235fa3636a635c395e0eb449bba09 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 13 Oct 2020 10:11:59 +0800 Subject: [PATCH 01/11] support pad to modulo in fwd --- driver/conv_driver.cpp | 2 +- driver/igemm_fwd_gtc_driver.h | 18 ++++++++--- igemm/algo/igemm_base.py | 2 +- igemm/algo/igemm_fwd_gtc.py | 60 ++++++++++++++++++++++------------- igemm/codegen/compile.py | 2 +- 5 files changed, 55 insertions(+), 29 deletions(-) diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 9114b27a..6475d5c7 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -43,7 +43,7 @@ #endif #ifndef USE_MAGIC_DIV -#define USE_MAGIC_DIV 1 +#define USE_MAGIC_DIV 0 #endif #ifndef USE_SOURCE_ACCESS_ENCODING_KERNEL_NAME diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 901fcba7..e66fde9e 100644 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -143,6 +143,7 @@ class igemm_fwd_gtc_t { int gemm_m = k; int gemm_n = n * ho * wo; + // this also valid considering pad to modulo int grid_size = utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block); return grid_size; @@ -176,10 +177,17 @@ class igemm_fwd_gtc_t { int gemm_n = n * ho * wo; int gemm_k = c * y * x; - if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0) || (gemm_k % gemm_k_per_block != 0)){ - // printf("tunable_is_valid false:: gemm_n is %d, gemm_n_per_block is %d, gemm_m is %d, gemm_m_per_block is %d\n", gemm_n,gemm_n_per_block,gemm_m,gemm_m_per_block); +#if 0 + // support pad to modulo, no need to check valid + if(gemm_n % gemm_n_per_block != 0) + return false; +#endif + + if(gemm_m % gemm_m_per_block != 0) + return false; + + if(gemm_k % gemm_k_per_block != 0) return false; - } if(gemm_n_per_block % tunable->nxb != 0){ // printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); @@ -191,9 +199,11 @@ class igemm_fwd_gtc_t { return false; } - if( (ho * wo) % tunable->nxb != 0){ +#if 0 + if((ho * wo) % tunable->nxb != 0){ return false; } +#endif if(tunable->nxe == 0){ if((x!=1)||(y!=1)||(stride_h!=1)||(stride_w!=1)||(dilation_h!=1)||(dilation_w!=1)||(pad_h!=0)||(pad_w!=0)){ diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 2061acca..11c65385 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -200,7 +200,7 @@ def __init__(self, tunable_dict): # assert type(self.opt_1x1) is bool assert self.direction in ('fwd', 'bwd', 'wrw') assert self.precision in ('fp32', 'fp16', 'bf16') - assert self.nxb in (1,4,8,16,32,64,256) + assert self.nxb in (1,4,8,16,32,64,128,256) assert self.nxe in (0,1) # TODO: better specify diff --git a/igemm/algo/igemm_fwd_gtc.py b/igemm/algo/igemm_fwd_gtc.py index a09268d1..9486c916 100644 --- a/igemm/algo/igemm_fwd_gtc.py +++ b/igemm/algo/igemm_fwd_gtc.py @@ -656,6 +656,9 @@ def __init__(self, mc, outer): self.s_gemm_k_num_y = sym_t("s_gemm_k_num_y" , self.s_y.value) self.s_gemm_k_num_x = sym_t("s_gemm_k_num_x" , self.s_x.value) + if outer.tunable.nxe != 0: + self.s_dim_b = sym_t("s_dim_b" , self.s_move_slice_k_y.value) + self.s_kitr = sym_t("s_kitr" ,1) if outer.tunable.precache_soffset: m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() @@ -720,7 +723,8 @@ def __init__(self, mc, outer): self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) self.v_in_os = sym_t("v_in_os" ,vseq(1)) self.v_in_os_base = sym_t("v_in_os_base" ,vseq(1)) - self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) + if outer.tunable.nxe != 0: + self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) @@ -738,20 +742,22 @@ def __init__(self, mc, outer): self.v_gtc_tb_in1 = sym_t("v_gtc_tb_in1" ,vseq(1)) self.v_gtc_tb_ib = sym_t("v_gtc_tb_ib" ,vseq(1)) if outer.tunable.nxe != 0: - self.v_gtc_tb_ic1 = sym_t("v_gtc_tb_ic1" ,vseq(1)) + self.v_gtc_tb_ic1 = sym_t("v_gtc_tb_ic1" ,vseq(1)) self.v_out_os = sym_t("v_out_os" ,vseq(1)) - self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) - self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) - self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) - - self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) - self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) - self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) - self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) if outer.tunable.nxe != 0: - self.v_in_iy = sym_t("v_in_iy" ,vseq(1)) - self.v_in_ix = sym_t("v_in_ix" ,vseq(1)) + self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) + self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) + self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) + self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) + + self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) + self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) + self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) + self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) + if outer.tunable.nxe != 0: + self.v_in_iy = sym_t("v_in_iy" ,vseq(1)) + self.v_in_ix = sym_t("v_in_ix" ,vseq(1)) self.v_move_slice_k_ic1 = sym_t("v_move_slice_k_ic1" , self.v_gtc_tb_ic1.value if outer.tunable.nxe != 0 else self.v_gtc_tb_ic1e.value) if outer.tunable.nxe != 0: @@ -1423,11 +1429,18 @@ def emit_kernel_prologue(self): else: self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") + # warp around the really dim_b length, in case pad + if self.tunable.nxe != 0: + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_out_stride_k()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") + + self._emit_empty_line() self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") if self.tunable.source_access_order == IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N: if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_k()}], s[{s.s_n()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_dim_b()}], s[{s.s_n()}]") else: self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_stride_hw()}], s[{s.s_n()}]") @@ -1455,12 +1468,12 @@ def emit_kernel_prologue(self): if gemm_n_unmerge_cluster == 0: if self.tunable.nxe != 0: if unmerge_sub_n1 == 1: - self._emit(f"s_lshr_b32 s[0], s[{s.s_out_stride_k()}], {igemm_log2(nb_n1b)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(nb_n1b)} ; total number of n1b") else: if unmerge_sub_n1 == nb_n1b: - self._emit(f"s_mov_b32 s[0], s[{s.s_out_stride_k()}] ; total number of n1b") + self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of n1b") else: - self._emit(f"s_lshr_b32 s[0], s[{s.s_out_stride_k()}], {igemm_log2(nb_n1b // unmerge_sub_n1)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(nb_n1b // unmerge_sub_n1)} ; total number of n1b") else: if unmerge_sub_n1 == 1: self._emit(f"s_lshr_b32 s[0], s[{s.s_stride_hw()}], {igemm_log2(nb_n1b)} ; total number of n1b") @@ -1472,7 +1485,7 @@ def emit_kernel_prologue(self): else: if self.tunable.nxe != 0: self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(nb_n0)}") - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_out_stride_k()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_dim_b()}], s[{s.s_tmp()}]") self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp(1)}], {igemm_log2(nb_n1b)}") else: self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(nb_n0)}") @@ -1526,11 +1539,11 @@ def emit_kernel_prologue(self): if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_tb_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_out_stride_k(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_tb_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") self._emit(m_mdiv_u32_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_tb_in1(), v.v_tmp(5), s.s_out_stride_k(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_tb_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) self._emit(f"v_mul_lo_u32 v[{v.v_in_iho()}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") self._emit(f"v_sub_i32 v[{v.v_in_iho()}], v[{v.v_in_iho()}], s[{s.s_pad_h()}]") @@ -1752,11 +1765,11 @@ def emit_kernel_prologue(self): if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_out_stride_k(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") self._emit(m_mdiv_u32_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_out_stride_k(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) self._emit_empty_line() else: @@ -1813,6 +1826,8 @@ def emit_kernel_prologue(self): self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], v[{v.v_out_iho()}]") self._emit(f"v_add3_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp(1)}], v[{v.v_out_iwo()}]") self._emit(f"v_lshlrev_b32 v[{v.v_out_os()}], {igemm_log2(data_byte)}, v[{v.v_out_os()}]") + if self.tunable.nxe != 0: + self._emit(m_set_flag_hw(v.v_out_flag(), v.v_out_iho(), v.v_out_iwo(), s.s_ho(), s.s_wo())) self._emit(f"; move slice stride") assert na_c0 * na_c1e == self.tunable.gemm_k_per_block and nb_c0 * nb_c1e == self.tunable.gemm_k_per_block @@ -2037,7 +2052,8 @@ def emit_kernel_epilogue(self): a = self.agpr self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, - s.s_out_stride_k0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_out_stride_k(), s.s_tmp())) + s.s_out_stride_k0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_out_stride_k(), s.s_tmp(), + v.v_out_flag() if self.tunable.nxe != 0 else None)) self._emit_front(f"{self.label_out}:") diff --git a/igemm/codegen/compile.py b/igemm/codegen/compile.py index 15561c29..df712ac6 100644 --- a/igemm/codegen/compile.py +++ b/igemm/codegen/compile.py @@ -141,7 +141,7 @@ def compile(self, **kwargs): if IGEMM_HOST_USE_XDNN: cmd += [f'-I{bytes.fromhex(xdnnroot).decode()}/include', '-DUSE_XDNN'] if IGEMM_HOST_USE_MAGIC_DIV: - cmd += ['-DUSE_MAGIC_DIV'] + cmd += ['-DUSE_MAGIC_DIV=1'] if 'cflags' in kwargs: cmd += kwargs['cflags'] if 'cxxflags' in kwargs: From 7bedfb54c049fd1b2e65fca672ba585b5f27fb17 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 13 Oct 2020 15:36:44 +0800 Subject: [PATCH 02/11] fix a bug in vgpr alloc, and check dim_b when nxe==0 --- driver/igemm_fwd_gtc_driver.h | 28 +++++++++++++++------------- igemm/algo/igemm_fwd_gtc.py | 8 ++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index e66fde9e..1db050ca 100644 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -139,11 +139,13 @@ class igemm_fwd_gtc_t { int gemm_m_per_block = tunable->gemm_m_per_block; int gemm_n_per_block = tunable->gemm_n_per_block; + int nxe = tunable->nxe; + int nxb = tunable->nxb; + int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 int gemm_m = k; - int gemm_n = n * ho * wo; + int gemm_n = n * b; - // this also valid considering pad to modulo int grid_size = utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block); return grid_size; @@ -173,15 +175,17 @@ class igemm_fwd_gtc_t { int gemm_n_per_block = tunable->gemm_n_per_block; int gemm_k_per_block = tunable->gemm_k_per_block; - int gemm_m = k; - int gemm_n = n * ho * wo; - int gemm_k = c * y * x; + int nxe = tunable->nxe; + int nxb = tunable->nxb; + int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + + int gemm_m = k; + int gemm_n = n * b; + int gemm_k = c * y * x; -#if 0 - // support pad to modulo, no need to check valid - if(gemm_n % gemm_n_per_block != 0) + // support pad to modulo, hence only check when nxe is 0 + if((nxe == 0) && (gemm_n % gemm_n_per_block != 0)) return false; -#endif if(gemm_m % gemm_m_per_block != 0) return false; @@ -199,13 +203,11 @@ class igemm_fwd_gtc_t { return false; } -#if 0 - if((ho * wo) % tunable->nxb != 0){ + if((nxe == 0) && ((ho * wo) % tunable->nxb != 0)){ return false; } -#endif - if(tunable->nxe == 0){ + if(nxe == 0){ if((x!=1)||(y!=1)||(stride_h!=1)||(stride_w!=1)||(dilation_h!=1)||(dilation_w!=1)||(pad_h!=0)||(pad_w!=0)){ return false; } diff --git a/igemm/algo/igemm_fwd_gtc.py b/igemm/algo/igemm_fwd_gtc.py index 17d6ffcb..5e5b9690 100644 --- a/igemm/algo/igemm_fwd_gtc.py +++ b/igemm/algo/igemm_fwd_gtc.py @@ -706,7 +706,7 @@ def __init__(self, mc, outer): else: v_c_resuable_num = outer.tunable.num_vgpr_accumulate_a + outer.tunable.num_vgpr_accumulate_b + \ outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ - 8 # from v_sst_a_os to v_wei_os + 16 # from v_sst_a_os to v_co_sst v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups v_c_needed = (v_c_coalescing_num - v_c_resuable_num) if (v_c_coalescing_num - v_c_resuable_num) > 0 else 0 @@ -727,9 +727,6 @@ def __init__(self, mc, outer): self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) - self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) - self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) - self.v_gtc_ta_ik1 = sym_t("v_gtc_ta_ik1" ,vseq(1)) self.v_gtc_ta_ik0 = sym_t("v_gtc_ta_ik0" ,vseq(1)) self.v_gtc_ta_ic1e = sym_t("v_gtc_ta_ic1e" ,vseq(1)) @@ -744,6 +741,9 @@ def __init__(self, mc, outer): if outer.tunable.nxe != 0: self.v_gtc_tb_ic1 = sym_t("v_gtc_tb_ic1" ,vseq(1)) + self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) + self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) + self.v_out_os = sym_t("v_out_os" ,vseq(1)) if outer.tunable.nxe != 0: self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) From 5f3045294b7d70ddc9a63c08e5f2402103d2b4c0 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 13 Oct 2020 16:19:23 +0800 Subject: [PATCH 03/11] pretty fastest_id --- driver/conv_driver.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 6475d5c7..e1a918ef 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -428,7 +428,7 @@ int main(int argc, char **argv) { if (need_fwd){ result_t fastest_result_fwd; fastest_result_fwd.duration_ms = FLT_MAX; - int fastest_id = 0; + int fastest_id = -1; float *device_output_to_host = NULL; if (need_verify) { // gen rand @@ -492,7 +492,10 @@ int main(int argc, char **argv) { } if(log_fastest_config){ dump_arg(&conv_args); - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", + if(fastest_id == -1) + printf(" fastest: no suitable kernel\n"); + else + printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", fastest_id, fastest_result_fwd.kernel_name.c_str(), fastest_result_fwd.duration_ms, @@ -507,7 +510,7 @@ int main(int argc, char **argv) { float *device_input_to_host = NULL; result_t fastest_result_bwd; fastest_result_bwd.duration_ms = FLT_MAX; - int fastest_id; + int fastest_id = -1; if (need_verify) { // gen rand gen_rand_vector(host_output, n * k * ho * wo, 0.0, 1.0); @@ -590,7 +593,10 @@ int main(int argc, char **argv) { } if(log_fastest_config){ dump_arg(&conv_args); - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", + if(fastest_id == -1) + printf(" fastest: no suitable kernel\n"); + else + printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", fastest_id, fastest_result_bwd.kernel_name.c_str(), fastest_result_bwd.duration_ms, From 22be6f10ceafee476ce4da3c5ea013d285ae0b53 Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Wed, 14 Oct 2020 10:46:16 +0000 Subject: [PATCH 04/11] support padding for odd-size image for bwd --- driver/igemm_bwd_gtc_driver.h | 13 ++++++++++--- igemm/algo/igemm_bwd_gtc.py | 23 ++++++++++++++++------- 2 files changed, 26 insertions(+), 10 deletions(-) mode change 100644 => 100755 igemm/algo/igemm_bwd_gtc.py diff --git a/driver/igemm_bwd_gtc_driver.h b/driver/igemm_bwd_gtc_driver.h index b26e0d9e..c438a2f3 100755 --- a/driver/igemm_bwd_gtc_driver.h +++ b/driver/igemm_bwd_gtc_driver.h @@ -247,8 +247,15 @@ class igemm_bwd_gtc_t { int w_tilda_slice = w_tilda_right - w_tilda_left; int gemm_m = c; + int gemm_n = n * h_tilda_slice * w_tilda_slice; + int nxe = tunable->nxe; + int nxb = tunable->nxb; + int b = h_tilda_slice * w_tilda_slice; + b = (nxe == 0) ? (b) : ((b + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + gemm_n = n * b; + int grid_size = utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block); int num_of_gemm = y_tilda * x_tilda; @@ -316,7 +323,7 @@ class igemm_bwd_gtc_t { int gemm_m = c; int gemm_n = n * h_tilda_slice * w_tilda_slice; - +/* if((gemm_n%gemm_n_per_block!=0)||(gemm_m%gemm_m_per_block!=0)){ // printf("tunable_is_valid false:: gemm_n is %d, gemm_n_per_block is %d, gemm_m is %d, gemm_m_per_block is %d\n", gemm_n,gemm_n_per_block,gemm_m,gemm_m_per_block); return false; @@ -334,7 +341,7 @@ class igemm_bwd_gtc_t { if( (h_tilda_slice * w_tilda_slice) % tunable->nxb != 0){ return false; } - +*/ bool gemm_k_valid = true; for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++){ int i_y_tilda = gemm_id / x_tilda; @@ -460,7 +467,7 @@ class igemm_bwd_gtc_t { hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); - // printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); + printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); HIP_CALL( hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); diff --git a/igemm/algo/igemm_bwd_gtc.py b/igemm/algo/igemm_bwd_gtc.py old mode 100644 new mode 100755 index 64aa45d4..643f971c --- a/igemm/algo/igemm_bwd_gtc.py +++ b/igemm/algo/igemm_bwd_gtc.py @@ -742,6 +742,9 @@ def __init__(self, mc, outer): self.s_stride_dslice_hw = sym_t("s_stride_dslice_hw" ,sseq(1)) self.s_stride_dslice_yx = sym_t("s_stride_dslice_yx" ,sseq(1)) + if outer.tunable.nxe != 0: + self.s_dslice_dim_b = sym_t("s_dslice_dim_b", self.s_stride_dslice_hw.value) + if outer.tunable.nxe != 0: self.s_out_stride_k_k1 = sym_t("s_out_stride_k_k1" ,self.s_stride_h.value) self.s_out_stride_k_k0_k1_diff = sym_t("s_out_stride_k_k0_k1_diff",self.s_stride_w.value) @@ -1489,6 +1492,12 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_c()}], s[{s.s_wei_stride_c()}]") self._emit(f"s_mul_i32 s[{s.s_stride_dslice_hw()}], s[{s.s_dslice_h()}], s[{s.s_dslice_w()}]") self._emit(f"s_mul_i32 s[{s.s_stride_dslice_yx()}], s[{s.s_dslice_y()}], s[{s.s_dslice_x()}]") + + self._emit(f"; change for padding") + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_stride_dslice_hw()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") + self._emit(f"s_lshl_b32 s[{s.s_dslice_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") + if t_k0 != 1: self._emit(f"s_lshl_b32 s[{s.s_out_stride_k0()}], s[{s.s_out_stride_k()}], {igemm_log2(unmerge_sub_k1)}") self._emit(f"s_lshl_b32 s[{s.s_wei_stride_k0()}], s[{s.s_wei_stride_k()}], {igemm_log2(unmerge_sub_k1)}") @@ -1551,7 +1560,7 @@ def emit_kernel_prologue(self): self._emit_empty_line() self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}") if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_stride_dslice_hw()}], s[{s.s_n()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_dslice_dim_b()}], s[{s.s_n()}]") else: self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_stride_hw()}], s[{s.s_n()}]") self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") @@ -1566,14 +1575,14 @@ def emit_kernel_prologue(self): if gemm_n_unmerge_cluster == 0: if self.tunable.nxe != 0: if unmerge_sub_n1 == 1: - self._emit(f"s_lshr_b32 s[0], s[{s.s_stride_dslice_hw()}], {igemm_log2(n_n1b)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dslice_dim_b()}], {igemm_log2(n_n1b)} ; total number of n1b") else: if unmerge_sub_n1 == n_n1b: - self._emit(f"s_mov_b32 s[0], s[{s.s_stride_dslice_hw()}] ; total number of n1b") + self._emit(f"s_mov_b32 s[0], s[{s.s_dslice_dim_b()}] ; total number of n1b") else: # self._emit(f"s_lshl_b32 s[{s.s_tmp()}], s[{s.s_stride_dslice_hw()}], {igemm_log2(unmerge_sub_n1)} ; total number of n1b") # self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp()}], {igemm_log2(n_n1b)}") - self._emit(f"s_lshr_b32 s[0], s[{s.s_stride_dslice_hw()}], {igemm_log2(n_n1b // unmerge_sub_n1)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dslice_dim_b()}], {igemm_log2(n_n1b // unmerge_sub_n1)} ; total number of n1b") else: if unmerge_sub_n1 == 1: self._emit(f"s_lshr_b32 s[0], s[{s.s_stride_hw()}], {igemm_log2(n_n1b)} ; total number of n1b") @@ -1587,7 +1596,7 @@ def emit_kernel_prologue(self): else: if self.tunable.nxe != 0: self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(n_n0)}") - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_stride_dslice_hw()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_dslice_dim_b()}], s[{s.s_tmp()}]") self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp(1)}], {igemm_log2(n_n1b)}") else: self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(n_n0)}") @@ -1608,7 +1617,7 @@ def emit_kernel_prologue(self): else: self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_in1b()}]") if self.tunable.nxe != 0: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_in1(), v.v_tmp(5), s.s_stride_dslice_hw(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_in1(), v.v_tmp(5), s.s_dslice_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_out_dslice_iw(), v.v_out_dslice_ih(), v.v_tmp(4), s.s_dslice_w(), v.v_tmp(), s.s_tmp())) self._emit_empty_line() self._emit(f"; iHTildaLeft, iWTildaLeft") @@ -1827,7 +1836,7 @@ def emit_kernel_prologue(self): self._emit(f"; compute from n1b") if self.tunable.nxe != 0: self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_in_in1b()}]") - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in1(), v.v_tmp(5), s.s_stride_dslice_hw(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in1(), v.v_tmp(5), s.s_dslice_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_dslice_iw(), v.v_in_dslice_ih(), v.v_tmp(4), s.s_dslice_w(), v.v_tmp(), s.s_tmp())) self._emit_empty_line() self._emit(f"v_add_u32 v[{v.v_in_dslice_ih()}], s[{s.s_dslice_h_left()}], v[{v.v_in_dslice_ih()}]") From f189e9572535155f7e1d389d3475a0d914991ad6 Mon Sep 17 00:00:00 2001 From: jane-zxy Date: Fri, 16 Oct 2020 00:52:42 +0800 Subject: [PATCH 05/11] modify for isValid function --- driver/igemm_bwd_gtc_driver.h | 18 +++++++++--------- igemm/algo/igemm_bwd_gtc.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/driver/igemm_bwd_gtc_driver.h b/driver/igemm_bwd_gtc_driver.h index c438a2f3..c8212080 100755 --- a/driver/igemm_bwd_gtc_driver.h +++ b/driver/igemm_bwd_gtc_driver.h @@ -247,14 +247,11 @@ class igemm_bwd_gtc_t { int w_tilda_slice = w_tilda_right - w_tilda_left; int gemm_m = c; - - int gemm_n = n * h_tilda_slice * w_tilda_slice; - int nxe = tunable->nxe; int nxb = tunable->nxb; int b = h_tilda_slice * w_tilda_slice; b = (nxe == 0) ? (b) : ((b + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - gemm_n = n * b; + int gemm_n = n * b; int grid_size = utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block); @@ -322,8 +319,12 @@ class igemm_bwd_gtc_t { int num_of_gemm = y_tilda * x_tilda; int gemm_m = c; - int gemm_n = n * h_tilda_slice * w_tilda_slice; -/* + int nxe = tunable->nxe; + int nxb = tunable->nxb; + int b = h_tilda_slice * w_tilda_slice; + b = (nxe == 0) ? (b) : ((b + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + int gemm_n = n * b; + if((gemm_n%gemm_n_per_block!=0)||(gemm_m%gemm_m_per_block!=0)){ // printf("tunable_is_valid false:: gemm_n is %d, gemm_n_per_block is %d, gemm_m is %d, gemm_m_per_block is %d\n", gemm_n,gemm_n_per_block,gemm_m,gemm_m_per_block); return false; @@ -338,10 +339,9 @@ class igemm_bwd_gtc_t { // printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); return false; } - if( (h_tilda_slice * w_tilda_slice) % tunable->nxb != 0){ + if( (tunable->nxe == 0)&& ((h_tilda_slice * w_tilda_slice) % tunable->nxb != 0) ){ return false; } -*/ bool gemm_k_valid = true; for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++){ int i_y_tilda = gemm_id / x_tilda; @@ -467,7 +467,7 @@ class igemm_bwd_gtc_t { hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); - printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); + //printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); HIP_CALL( hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); diff --git a/igemm/algo/igemm_bwd_gtc.py b/igemm/algo/igemm_bwd_gtc.py index 643f971c..2797cc5a 100755 --- a/igemm/algo/igemm_bwd_gtc.py +++ b/igemm/algo/igemm_bwd_gtc.py @@ -1493,7 +1493,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_stride_dslice_hw()}], s[{s.s_dslice_h()}], s[{s.s_dslice_w()}]") self._emit(f"s_mul_i32 s[{s.s_stride_dslice_yx()}], s[{s.s_dslice_y()}], s[{s.s_dslice_x()}]") - self._emit(f"; change for padding") + self._emit(f"; pad b into multiplier of nxb") self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_stride_dslice_hw()}]") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") self._emit(f"s_lshl_b32 s[{s.s_dslice_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") From b421f5184b3ae3402fef6c06322ee69bdcb66915 Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Fri, 30 Oct 2020 16:30:03 +0000 Subject: [PATCH 06/11] bwd-functional-ok --- config/igemm_bwd_generate.config | 4 + igemm/__init__.py | 1 + igemm/igemm_config_gen_driver.py | 206 +++++++++++++++++++++++++++++++ igemm_codegen.py | 11 +- 4 files changed, 216 insertions(+), 6 deletions(-) create mode 100644 config/igemm_bwd_generate.config create mode 100644 igemm/igemm_config_gen_driver.py diff --git a/config/igemm_bwd_generate.config b/config/igemm_bwd_generate.config new file mode 100644 index 00000000..f1f5d94a --- /dev/null +++ b/config/igemm_bwd_generate.config @@ -0,0 +1,4 @@ +[codegen] +arch = 'gfx908' +code_object = 'cov3' +mode = 'sequencer' diff --git a/igemm/__init__.py b/igemm/__init__.py index edc0344f..8db99acb 100644 --- a/igemm/__init__.py +++ b/igemm/__init__.py @@ -29,6 +29,7 @@ from .codegen import * from .algo import * from .igemm_codegen_driver import * +from .igemm_config_gen_driver import * if sys.hexversion < 0x30600f0: print("must use python 3.6+. current is {}".format(sys.version)) diff --git a/igemm/igemm_config_gen_driver.py b/igemm/igemm_config_gen_driver.py new file mode 100644 index 00000000..ed43a8f0 --- /dev/null +++ b/igemm/igemm_config_gen_driver.py @@ -0,0 +1,206 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2020 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# pylint: disable=maybe-no-member + +from .algo import * +from .codegen import * + +import os +import copy + +ctrl_xdlops_mapping_fp32_debug = [ + ctrl_xdlops_mapping_t( 256, 128, 64, 32, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 256, 32, 64, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + + ctrl_xdlops_mapping_t( 128, 128, 32, 32, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 128, 128, 32, 64, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + #ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + + ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 64 , 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 2, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 64 , 4 , 64, 2, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + + # 2waves, block_size=128 + ctrl_xdlops_mapping_t( 64 , 8 , 64, 4 , 2, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 8 , 64 , 4 , 64, 2, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 16 , 32, 8 , 2, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 32 , 8 , 32, 2, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + # 1 wave + ctrl_xdlops_mapping_t( 32 , 16 , 32, 8 , 1, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 32 , 8 , 32, 1, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 4 , 64, 4 , 1, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 4 , 64, 4 , 64, 1, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 16, 16, 16, 1, 1, 1, 1, 1, v_mfma_f32_4x4x1f32) + ] + +class igemm_config_gen_driver_t(): + def __init__(self, emitter, config_content): + self.emitter = emitter + self.emitter.open() + self.config_content = config_content + + def non1numbers(self, length_list): + n = 0 + for i in range(len(length_list)): + if length_list[i] > 1: + n += 1 + return n + def __del__(self): + self.emitter.close() + + def emit(self, s): + self.emitter.emit(s) + + def emit_one_valid_config(self): + self.emit(f"[igemm_bwd_gtc]") + self.emit(f"{'gemm_m_per_block':25}= {self.gemm_m_per_block}") + self.emit(f"{'gemm_n_per_block':25}= {self.gemm_n_per_block}") + self.emit(f"{'gemm_k_per_block':25}= {self.gemm_k_per_block}") + self.emit(f"{'wave_tile_m':25}= {self.wave_tile_m}") + self.emit(f"{'wave_step_m':25}= {self.wave_step_m}") + self.emit(f"{'wave_repeat_m':25}= {self.wave_repeat_m}") + self.emit(f"{'wave_tile_n':25}= {self.wave_tile_n}") + self.emit(f"{'wave_step_n':25}= {self.wave_step_n}") + self.emit(f"{'wave_repeat_n':25}= {self.wave_repeat_n}") + self.emit(f"{'tensor_a_thread_lengths':25}= [{self.t_k0}, {self.t_k1e}, {self.t_c0},{self.t_c1}]") + self.emit(f"{'tensor_a_cluster_lengths':25}= [{self.c_k0}, {self.c_k1e}, {self.c_c0},{self.c_c1}]") + self.emit(f"{'tensor_b_thread_lengths':25}= [{self.t_k0}, {self.t_k1e}, {self.t_n0},{self.t_n1b}]") + self.emit(f"{'tensor_b_cluster_lengths':25}= [{self.c_k0}, {self.c_k1e}, {self.c_n0},{self.c_n1b}]") + self.emit(f"{'direction':25}= '{self.direction}'") + self.emit(f"{'precision':25}= '{self.precision}'") + self.emit(f"{'nxb':25}= {self.nxb}") + self.emit(f"{'nxe':25}= {self.nxe}") + self.emit('') + + def __call__(self): + sec_root = self.config_content.get_section('codegen')[0] + + self.emit("[codegen]") + self.emit(f"arch = '{sec_root['arch']}'") + self.emit(f"code_object = '{sec_root['code_object']}'") + self.emit("mode = 'flat'") + self.emit('') + self.direction = 'bwd' + self.precision = 'fp32' + + for t in ctrl_xdlops_mapping_fp32_debug: + self.emit(f"### {t.macro_tile_m}x{t.macro_tile_n}") + self.gemm_m_per_block = t.macro_tile_m + self.gemm_n_per_block = t.macro_tile_n + self.wave_tile_m = t.wave_tile_m + self.wave_step_m = t.wave_step_m + self.wave_repeat_m = t.wave_repeat_m + self.wave_tile_n = t.wave_tile_n + self.wave_step_n = t.wave_step_n + self.wave_repeat_n = t.wave_repeat_n + waves_per_m = self.gemm_m_per_block // (self.wave_tile_m * self.wave_step_m * self.wave_repeat_m) + waves_per_n = self.gemm_n_per_block // (self.wave_tile_n * self.wave_step_n * self.wave_repeat_n) + self.block_size = waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE + + potential_nxb_list = [128,64,32,16,8,4,1] + potential_nxe_list = [0,1] + potential_k_list = [16,8,4] + self.gemm_k_per_block = 16 + for i_k in potential_k_list: + self.gemm_k_per_block = i_k + b_data_per_thread = (self.gemm_n_per_block*self.gemm_k_per_block)//self.block_size + a_data_per_thread = (self.gemm_m_per_block*self.gemm_k_per_block)//self.block_size + + # when nxe=0, nxb=[1,4,8,16,32,64,128], t_n1b self.gemm_n_per_block: + continue + for i_nxe in potential_nxe_list: + if i_nxe == 0: + potential_t_n1b_list = [4,2,1] + elif i_nxe == 1: + potential_t_n1b_list = [1] + + for i_t_n1b in potential_t_n1b_list: + self.t_n1b = i_t_n1b + self.nxb = i_nxb + self.nxe = i_nxe + self.t_k1e = 1 + self.t_c1 = 1 + self.c_k0 = 1 + self.c_n0 = 1 + self.c_c0 = 1 + self.c_n1b = self.gemm_n_per_block*2 + while self.c_n1b>1: + self.c_n1b = self.c_n1b//2 + if self.c_n1b*self.t_n1b > self.gemm_n_per_block: + continue + self.c_k1e = self.block_size//self.c_n1b #a? + if self.c_k1e > self.gemm_k_per_block: + continue + self.t_n0 = self.gemm_n_per_block//(self.c_n1b*self.t_n1b) #b? + self.t_k0 = self.gemm_k_per_block//self.c_k1e #c? + if self.t_k0 != (b_data_per_thread//(self.t_n0*self.t_n1b)): + continue + self.c_c1 = self.c_n1b #d? + if i_nxe == 0: + potential_t_c1_list = [4,2,1] + elif i_nxe == 1: + potential_t_c1_list = [1] + #assert unmerge_sub_n % n_n0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, n_n0:{n_n0}" + if (self.gemm_n_per_block//self.nxb % (self.t_n0*self.c_n0) != 0): + continue + if self.non1numbers([self.t_k0, self.t_n0, self.t_n1b]) > 2: #check [t_k0, t_k1e, t_c0] + continue + for i_t_c1 in potential_t_c1_list: #e? + self.t_c1 = i_t_c1 + self.t_c0 = self.gemm_m_per_block//(self.c_c1*self.t_c1) #f? + if self.t_c0 == 0: + continue + if self.t_k0*self.t_c0*self.t_c1 !=a_data_per_thread: + continue + if self.non1numbers([self.t_k0, self.t_c0, self.t_c1]) > 2: #check [t_k0, t_k1e, t_c0] + continue + self.emit_one_valid_config() + diff --git a/igemm_codegen.py b/igemm_codegen.py index e72db444..8b7a8c21 100755 --- a/igemm_codegen.py +++ b/igemm_codegen.py @@ -78,10 +78,9 @@ def igemm_out_tunable_param(output_file, config_content): list_emitter.emit(td_item.output()) list_emitter.close() -#def igemm_sequence(args, config_content): -# kseq = v4r1_dynamic_kernel_sequencer_t(amdgpu_get_gfx906_60cu(), -# config_content.get_section('v4r1_dynamic_kernel')[0].to_dict()) -# kseq() +def igemm_sequence(args, config_content): + emitter = mc_emit_to_file_t('new.config') + igemm_config_gen_driver_t(emitter, config_content)() if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -106,6 +105,6 @@ def igemm_out_tunable_param(output_file, config_content): igemm_flatten(args, config_content) if config_content.get_section('codegen')[0]['mode'] in ('seq', 'sequencer'): - # config_content.dump() - # igemm_sequence(args, config_content) + config_content.dump() + igemm_sequence(args, config_content) pass From e9120ba43c048acf35145d4c23bfb1cf80f2097a Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Thu, 5 Nov 2020 11:35:34 +0000 Subject: [PATCH 07/11] add several codegen_config restriction support: nxb,nxe,gemm_k,micro-tile,micro-tile_with_gemm_k_4 --- config/igemm_bwd_generate.config | 10 +- igemm/igemm_codegen_driver.py | 3 +- igemm/igemm_config_gen_driver.py | 216 ++++++++++++++++++++----------- igemm_codegen.py | 11 +- 4 files changed, 159 insertions(+), 81 deletions(-) diff --git a/config/igemm_bwd_generate.config b/config/igemm_bwd_generate.config index f1f5d94a..a23769ea 100644 --- a/config/igemm_bwd_generate.config +++ b/config/igemm_bwd_generate.config @@ -1,4 +1,12 @@ [codegen] arch = 'gfx908' code_object = 'cov3' -mode = 'sequencer' +mode = 'generate' +direction = 'bwd' + +[codegen_config] +nxb = '4,1' +nxe = '0,1' +gemm_k = '16,8,4' +micro_tile_with_gemm_k_4 = '16x32,32x16' +precision = 'fp32' diff --git a/igemm/igemm_codegen_driver.py b/igemm/igemm_codegen_driver.py index 49aefd11..fb880114 100755 --- a/igemm/igemm_codegen_driver.py +++ b/igemm/igemm_codegen_driver.py @@ -142,7 +142,7 @@ def get_kernel_per_inc_file_name(ker, origin_file_name): self.mc.emitter = emitter_per_inc_dict[kpi_file_name] if IGEMM_EMIT_KERNEL_METADATA_PER_INC_FILE: kinfo_per_inc_dict[kpi_file_name].append(kernel.get_kernel_info()) - + print('Jane debug, kernel name is'+kernel.name()) self._emit(';----------------------------------------------------------') self._emit('; starting of kernel {}'.format(kernel.name())) self._emit(kernel.tunable.serialize()) @@ -196,3 +196,4 @@ def do_compile(self): def __call__(self): self.do_emit() self.do_compile() + diff --git a/igemm/igemm_config_gen_driver.py b/igemm/igemm_config_gen_driver.py index ed43a8f0..508bc69d 100644 --- a/igemm/igemm_config_gen_driver.py +++ b/igemm/igemm_config_gen_driver.py @@ -31,30 +31,46 @@ import os import copy -ctrl_xdlops_mapping_fp32_debug = [ +# micro-tile: +# 128x128,256x128,256x64,128x256,64x256,128x64,64x128,256x32,32x256,64x64,128x32,32x128,256x16,16x256,128x16,16x128,64x32,32x64,32x32,64x16,16x64,32x16,16x32,64x8,8x64,16x16,64x4,4x64 +ctrl_xdlops_mapping_fp32_config = [ ctrl_xdlops_mapping_t( 256, 128, 64, 32, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 256, 32, 64, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 256, 64 , 64, 32, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), #add by jane ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 256, 32, 64, 4, 1, 1, 1, 2, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + #ctrl_xdlops_mapping_t( 32 , 256, 16 , 64, 4, 1, 1, 1, 2, v_mfma_f32_16x16x1f32), #add by jane can not because coleasing group assert + ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 128, 128, 32, 32, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 128, 128, 32, 64, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 64 , 64, 32 , 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 32 , 64, 16 , 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane + ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane + ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 32 , 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 64 , 32 , 32, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; + ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), @@ -114,6 +130,13 @@ def emit_one_valid_config(self): self.emit(f"{'nxe':25}= {self.nxe}") self.emit('') + def get_specific_ctrl_xdlops_mapping_t(self, macro_tile_m, macro_tile_n): + target_mfma_tiling = list() + for t in ctrl_xdlops_mapping_fp32_config: + if t.macro_tile_m == macro_tile_m and t.macro_tile_n == macro_tile_n: + target_mfma_tiling.append(t) + return target_mfma_tiling + def __call__(self): sec_root = self.config_content.get_section('codegen')[0] @@ -122,85 +145,130 @@ def __call__(self): self.emit(f"code_object = '{sec_root['code_object']}'") self.emit("mode = 'flat'") self.emit('') + + gen_config_lit = [sec.to_dict() for sec in self.config_content if sec.get_name().startswith('codegen_config')] + for gen_conf in gen_config_lit: + if sec_root['direction'] == 'bwd': + self.emit_bwd_configs(gen_conf) + + + + def emit_bwd_configs(self, gen_conf): self.direction = 'bwd' - self.precision = 'fp32' - - for t in ctrl_xdlops_mapping_fp32_debug: - self.emit(f"### {t.macro_tile_m}x{t.macro_tile_n}") - self.gemm_m_per_block = t.macro_tile_m - self.gemm_n_per_block = t.macro_tile_n - self.wave_tile_m = t.wave_tile_m - self.wave_step_m = t.wave_step_m - self.wave_repeat_m = t.wave_repeat_m - self.wave_tile_n = t.wave_tile_n - self.wave_step_n = t.wave_step_n - self.wave_repeat_n = t.wave_repeat_n - waves_per_m = self.gemm_m_per_block // (self.wave_tile_m * self.wave_step_m * self.wave_repeat_m) - waves_per_n = self.gemm_n_per_block // (self.wave_tile_n * self.wave_step_n * self.wave_repeat_n) - self.block_size = waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE + self.precision = gen_conf['precision'] + if 'micro_tile' in gen_conf: + micro_tile_array = gen_conf['micro_tile'].split(',') + potential_micro_tile_list = [i for i in micro_tile_array] + else: + potential_micro_tile_list = ['128x128','256x128','256x64','128x256','64x256','128x64','64x128','256x32','32x256','64x64','128x32','32x128','256x16','16x256','128x16','16x128','64x32','32x64','32x32','64x16','16x64','32x16','16x32','64x8','8x64','16x16','64x4','4x64'] + if 'nxb' in gen_conf: + nxb_array = gen_conf['nxb'].split(',') + potential_nxb_list = [int(i) for i in nxb_array] + else: potential_nxb_list = [128,64,32,16,8,4,1] + + if 'nxe' in gen_conf: + nxe_array = gen_conf['nxe'].split(',') + potential_nxe_list = [int(i) for i in nxe_array] + else: potential_nxe_list = [0,1] + + if 'gemm_k' in gen_conf: + gemm_k_array = gen_conf['gemm_k'].split(',') + potential_k_list = [int(i) for i in gemm_k_array] + else: potential_k_list = [16,8,4] - self.gemm_k_per_block = 16 - for i_k in potential_k_list: - self.gemm_k_per_block = i_k - b_data_per_thread = (self.gemm_n_per_block*self.gemm_k_per_block)//self.block_size - a_data_per_thread = (self.gemm_m_per_block*self.gemm_k_per_block)//self.block_size - - # when nxe=0, nxb=[1,4,8,16,32,64,128], t_n1b self.gemm_n_per_block: - continue - for i_nxe in potential_nxe_list: - if i_nxe == 0: - potential_t_n1b_list = [4,2,1] - elif i_nxe == 1: - potential_t_n1b_list = [1] - - for i_t_n1b in potential_t_n1b_list: - self.t_n1b = i_t_n1b - self.nxb = i_nxb - self.nxe = i_nxe - self.t_k1e = 1 - self.t_c1 = 1 - self.c_k0 = 1 - self.c_n0 = 1 - self.c_c0 = 1 - self.c_n1b = self.gemm_n_per_block*2 - while self.c_n1b>1: - self.c_n1b = self.c_n1b//2 - if self.c_n1b*self.t_n1b > self.gemm_n_per_block: - continue - self.c_k1e = self.block_size//self.c_n1b #a? - if self.c_k1e > self.gemm_k_per_block: - continue - self.t_n0 = self.gemm_n_per_block//(self.c_n1b*self.t_n1b) #b? - self.t_k0 = self.gemm_k_per_block//self.c_k1e #c? - if self.t_k0 != (b_data_per_thread//(self.t_n0*self.t_n1b)): - continue - self.c_c1 = self.c_n1b #d? - if i_nxe == 0: - potential_t_c1_list = [4,2,1] - elif i_nxe == 1: - potential_t_c1_list = [1] - #assert unmerge_sub_n % n_n0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, n_n0:{n_n0}" - if (self.gemm_n_per_block//self.nxb % (self.t_n0*self.c_n0) != 0): - continue - if self.non1numbers([self.t_k0, self.t_n0, self.t_n1b]) > 2: #check [t_k0, t_k1e, t_c0] - continue - for i_t_c1 in potential_t_c1_list: #e? - self.t_c1 = i_t_c1 - self.t_c0 = self.gemm_m_per_block//(self.c_c1*self.t_c1) #f? - if self.t_c0 == 0: + + if 'micro_tile_with_gemm_k_4' in gen_conf: + micro_tile_with_gemm_k_4_array = gen_conf['micro_tile_with_gemm_k_4'].split(',') + micro_tile_with_gemm_k_4_list = [i for i in micro_tile_with_gemm_k_4_array] + else: + micro_tile_with_gemm_k_4_list = ['32x16','16x32'] + + for item in potential_micro_tile_list: + tile = item.split('x') + self.gemm_m_per_block = int(tile[0]) + self.gemm_n_per_block = int(tile[1]) + target_xdlops_t_list = self.get_specific_ctrl_xdlops_mapping_t(self.gemm_m_per_block, self.gemm_n_per_block) + for t in target_xdlops_t_list: + self.emit(f"### {t.macro_tile_m}x{t.macro_tile_n}") + self.gemm_m_per_block = t.macro_tile_m + self.gemm_n_per_block = t.macro_tile_n + self.wave_tile_m = t.wave_tile_m + self.wave_step_m = t.wave_step_m + self.wave_repeat_m = t.wave_repeat_m + self.wave_tile_n = t.wave_tile_n + self.wave_step_n = t.wave_step_n + self.wave_repeat_n = t.wave_repeat_n + waves_per_m = self.gemm_m_per_block // (self.wave_tile_m * self.wave_step_m * self.wave_repeat_m) + waves_per_n = self.gemm_n_per_block // (self.wave_tile_n * self.wave_step_n * self.wave_repeat_n) + self.block_size = waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE + + for i_k in potential_k_list: + self.gemm_k_per_block = i_k + if i_k == 4: + cur_tile_str = f"{self.gemm_m_per_block}x{self.gemm_n_per_block}" + if cur_tile_str not in micro_tile_with_gemm_k_4_list: + break + + + b_data_per_thread = (self.gemm_n_per_block*self.gemm_k_per_block)//self.block_size + a_data_per_thread = (self.gemm_m_per_block*self.gemm_k_per_block)//self.block_size + + # when nxe=0, nxb=[1,4,8,16,32,64,128], t_n1b self.gemm_n_per_block: + continue + for i_nxe in potential_nxe_list: + if i_nxe == 0: + potential_t_n1b_list = [4,2,1] + elif i_nxe == 1: + potential_t_n1b_list = [1] + + for i_t_n1b in potential_t_n1b_list: + self.t_n1b = i_t_n1b + self.nxb = i_nxb + self.nxe = i_nxe + self.t_k1e = 1 + self.t_c1 = 1 + self.c_k0 = 1 + self.c_n0 = 1 + self.c_c0 = 1 + self.c_n1b = self.gemm_n_per_block*2 + while self.c_n1b>1: + self.c_n1b = self.c_n1b//2 + if self.c_n1b*self.t_n1b > self.gemm_n_per_block: + continue + self.c_k1e = self.block_size//self.c_n1b #a? + if self.c_k1e > self.gemm_k_per_block: + continue + self.t_n0 = self.gemm_n_per_block//(self.c_n1b*self.t_n1b) #b? + self.t_k0 = self.gemm_k_per_block//self.c_k1e #c? + if self.t_k0 != (b_data_per_thread//(self.t_n0*self.t_n1b)): continue - if self.t_k0*self.t_c0*self.t_c1 !=a_data_per_thread: + self.c_c1 = self.c_n1b #d? + if i_nxe == 0: + potential_t_c1_list = [4,2,1] + elif i_nxe == 1: + potential_t_c1_list = [1] + #assert unmerge_sub_n % n_n0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, n_n0:{n_n0}" + if (self.gemm_n_per_block//self.nxb % (self.t_n0*self.c_n0) != 0): continue - if self.non1numbers([self.t_k0, self.t_c0, self.t_c1]) > 2: #check [t_k0, t_k1e, t_c0] + if self.non1numbers([self.t_k0, self.t_n0, self.t_n1b]) > 2: #check [t_k0, t_k1e, t_c0] continue - self.emit_one_valid_config() + for i_t_c1 in potential_t_c1_list: #e? + self.t_c1 = i_t_c1 + self.t_c0 = self.gemm_m_per_block//(self.c_c1*self.t_c1) #f? + if self.t_c0 == 0: + continue + if self.t_k0*self.t_c0*self.t_c1 !=a_data_per_thread: + continue + if self.non1numbers([self.t_k0, self.t_c0, self.t_c1]) > 2: #check [t_k0, t_k1e, t_c0] + continue + self.emit_one_valid_config() diff --git a/igemm_codegen.py b/igemm_codegen.py index 8b7a8c21..4a4b5f72 100755 --- a/igemm_codegen.py +++ b/igemm_codegen.py @@ -78,8 +78,10 @@ def igemm_out_tunable_param(output_file, config_content): list_emitter.emit(td_item.output()) list_emitter.close() -def igemm_sequence(args, config_content): - emitter = mc_emit_to_file_t('new.config') +def igemm_generate(args, config_content): + sec_root = config_content.get_section('codegen')[0] + config_file_name = f"igemm_{sec_root['direction']}_gtc_{sec_root['arch']}.config" + emitter = mc_emit_to_file_t(config_file_name) igemm_config_gen_driver_t(emitter, config_content)() if __name__ == '__main__': @@ -104,7 +106,6 @@ def igemm_sequence(args, config_content): igemm_host_driver(args, config_content) igemm_flatten(args, config_content) - if config_content.get_section('codegen')[0]['mode'] in ('seq', 'sequencer'): - config_content.dump() - igemm_sequence(args, config_content) + if config_content.get_section('codegen')[0]['mode'] in ('gen', 'generate'): + igemm_generate(args, config_content) pass From 6a32f0362ff3345f93ecd74952251f153463a610 Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Thu, 5 Nov 2020 17:48:35 +0000 Subject: [PATCH 08/11] add more xdlops_mapping_t --- igemm/algo/xdlops_mapping.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index fcaaddc4..bbd430b4 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -253,9 +253,13 @@ def serialize(self): ctrl_xdlops_mapping_t( 256, 128, 64, 32, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 256, 32, 64, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 256, 64 , 64, 32, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 256, 32, 64, 4, 1, 1, 1, 2, v_mfma_f32_32x32x1f32), #add by jane ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + #ctrl_xdlops_mapping_t( 32 , 256, 16 , 64, 4, 1, 1, 1, 2, v_mfma_f32_16x16x1f32), #add by jane, can not because coleasing group assert ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), @@ -266,18 +270,24 @@ def serialize(self): ctrl_xdlops_mapping_t( 128, 128, 32, 64, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 64 , 64, 32 , 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 32 , 64, 16 , 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane, it's better ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better #ctrl_xdlops_mapping_t( 128, 16 , 64, 4 , 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 16 , 128, 4 , 64, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 32 , 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 64 , 32 , 32, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 256, 4 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm #ctrl_xdlops_mapping_t( 4 , 256, 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm From ddc75a685e60064d1e9c851dc7bf7106fb57799f Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Fri, 6 Nov 2020 14:26:02 +0000 Subject: [PATCH 09/11] remove some useless comments --- driver/igemm_fwd_gtc_driver.h | 15 --------------- igemm/igemm_codegen_driver.py | 1 - 2 files changed, 16 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index c2b19924..8b007297 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -187,20 +187,6 @@ class igemm_fwd_gtc_t { int nxb = tunable->nxb; int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 -<<<<<<< HEAD - int gemm_m = k; - int gemm_n = n * b; - int gemm_k = c * y * x; - - // support pad to modulo, hence only check when nxe is 0 - if((nxe == 0) && (gemm_n % gemm_n_per_block != 0)) - return false; - - if(gemm_m % gemm_m_per_block != 0) - return false; - - if(gemm_k % gemm_k_per_block != 0) -======= int gemm_m = k / group; int gemm_n = n * b; int gemm_k = (c / group) * y * x; @@ -209,7 +195,6 @@ class igemm_fwd_gtc_t { if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0) || (gemm_k % gemm_k_per_block != 0)) { ->>>>>>> master return false; if(gemm_n_per_block % tunable->nxb != 0){ diff --git a/igemm/igemm_codegen_driver.py b/igemm/igemm_codegen_driver.py index fb880114..3a294db9 100755 --- a/igemm/igemm_codegen_driver.py +++ b/igemm/igemm_codegen_driver.py @@ -142,7 +142,6 @@ def get_kernel_per_inc_file_name(ker, origin_file_name): self.mc.emitter = emitter_per_inc_dict[kpi_file_name] if IGEMM_EMIT_KERNEL_METADATA_PER_INC_FILE: kinfo_per_inc_dict[kpi_file_name].append(kernel.get_kernel_info()) - print('Jane debug, kernel name is'+kernel.name()) self._emit(';----------------------------------------------------------') self._emit('; starting of kernel {}'.format(kernel.name())) self._emit(kernel.tunable.serialize()) From 8f1771722b948974422d37f2ccb99343913bac8e Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Fri, 6 Nov 2020 14:52:26 +0000 Subject: [PATCH 10/11] change xdlops structure --- igemm/algo/xdlops_mapping.py | 66 ++++++++++++++++---------------- igemm/igemm_config_gen_driver.py | 5 ++- 2 files changed, 37 insertions(+), 34 deletions(-) mode change 100644 => 100755 igemm/igemm_config_gen_driver.py diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index 9309f5b7..90d5b151 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -259,45 +259,45 @@ def serialize(self): # mt_m,mt_n,wt_m,wt_n,wt_k,ws,r_m,r_n,s_m,s_n, inst_mfma ctrl_xdlops_mapping_fp32 = [ # ctrl_xdlops_mapping_t( 256, 256, 32, 64, 4, 2, 2, 2, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 256, 128, 64, 32, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 128, 256, 32, 64, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 256, 64 , 64, 32, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), #add by jane - - ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 64 , 256, 32, 64, 4, 1, 1, 1, 2, v_mfma_f32_32x32x1f32), #add by jane - ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), - #ctrl_xdlops_mapping_t( 32 , 256, 16 , 64, 4, 1, 1, 1, 2, v_mfma_f32_16x16x1f32), #add by jane, can not because coleasing group assert - ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 256, 128, 64, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 256, 32, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 256, 64 , 64, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 256, 64 , 64, 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), #add by jane + + ctrl_xdlops_mapping_t( 64 , 256, 16, 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 256, 32, 64, 1, 4, 1, 1, 1, 2, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 256, 32 , 64, 4 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 256, 4 , 64, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + #ctrl_xdlops_mapping_t( 32 , 256, 16 , 64, 1, 4, 1, 1, 1, 2, v_mfma_f32_16x16x1f32), #add by jane, can not because coleasing group assert + ctrl_xdlops_mapping_t( 256, 16 , 64, 4 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 16 , 256, 4 , 64, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 256, 16 , 64, 16, 2, 1, 1, 2, 1, v_mfma_f32_16x16x1f32), # TODO: this will fail in coalescing #ctrl_xdlops_mapping_t( 16 , 256, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), # TODO: this will fail in coalescing - ctrl_xdlops_mapping_t( 128, 128, 32, 32, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 128, 128, 32, 64, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 128, 64 , 64, 32 , 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), #add by jane - - ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), - ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 128, 32 , 64, 16 , 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane, it's better - ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better + ctrl_xdlops_mapping_t( 128, 128, 32, 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 128, 128, 32, 32, 2, 4, 2, 2, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 128, 128, 32, 64, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 64 , 64, 32 , 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), #add by jane + ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), + ctrl_xdlops_mapping_t( 128, 32 , 32, 8, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 32 , 64, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane, it's better + ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better #ctrl_xdlops_mapping_t( 128, 16 , 64, 4 , 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 16 , 128, 4 , 64, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), - ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 64 , 32 , 32, 32 , 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better - ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), - ctrl_xdlops_mapping_t( 32 , 64 , 32 , 32, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; - ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 128, 16 , 64, 16, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 16 , 128, 16, 64, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 32, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; it's better + ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), + ctrl_xdlops_mapping_t( 32 , 64 , 32 , 32, 1, 2, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), #add by jane; + ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), #ctrl_xdlops_mapping_t( 256, 4 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm #ctrl_xdlops_mapping_t( 4 , 256, 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), diff --git a/igemm/igemm_config_gen_driver.py b/igemm/igemm_config_gen_driver.py old mode 100644 new mode 100755 index 508bc69d..88614289 --- a/igemm/igemm_config_gen_driver.py +++ b/igemm/igemm_config_gen_driver.py @@ -33,6 +33,7 @@ # micro-tile: # 128x128,256x128,256x64,128x256,64x256,128x64,64x128,256x32,32x256,64x64,128x32,32x128,256x16,16x256,128x16,16x128,64x32,32x64,32x32,64x16,16x64,32x16,16x32,64x8,8x64,16x16,64x4,4x64 +''' ctrl_xdlops_mapping_fp32_config = [ ctrl_xdlops_mapping_t( 256, 128, 64, 32, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 256, 32, 64, 4, 2, 2, 1, 1, v_mfma_f32_32x32x1f32), @@ -91,6 +92,8 @@ ctrl_xdlops_mapping_t( 16 , 16, 16, 16, 1, 1, 1, 1, 1, v_mfma_f32_4x4x1f32) ] +''' + class igemm_config_gen_driver_t(): def __init__(self, emitter, config_content): self.emitter = emitter @@ -132,7 +135,7 @@ def emit_one_valid_config(self): def get_specific_ctrl_xdlops_mapping_t(self, macro_tile_m, macro_tile_n): target_mfma_tiling = list() - for t in ctrl_xdlops_mapping_fp32_config: + for t in ctrl_xdlops_mapping_fp32: if t.macro_tile_m == macro_tile_m and t.macro_tile_n == macro_tile_n: target_mfma_tiling.append(t) return target_mfma_tiling From b0442012286db453704862263b3fa2bcffe592f6 Mon Sep 17 00:00:00 2001 From: Jane-zxy Date: Fri, 6 Nov 2020 15:01:44 +0000 Subject: [PATCH 11/11] fix some format issue --- driver/igemm_fwd_gtc_driver.h | 1 + igemm/algo/igemm_fwd_gtc.py | 0 igemm/igemm_codegen_driver.py | 2 +- 3 files changed, 2 insertions(+), 1 deletion(-) mode change 100755 => 100644 driver/igemm_fwd_gtc_driver.h mode change 100755 => 100644 igemm/algo/igemm_fwd_gtc.py diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h old mode 100755 new mode 100644 index 8b007297..372d3648 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -196,6 +196,7 @@ class igemm_fwd_gtc_t { (gemm_k % gemm_k_per_block != 0)) { return false; + } if(gemm_n_per_block % tunable->nxb != 0){ //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); diff --git a/igemm/algo/igemm_fwd_gtc.py b/igemm/algo/igemm_fwd_gtc.py old mode 100755 new mode 100644 diff --git a/igemm/igemm_codegen_driver.py b/igemm/igemm_codegen_driver.py index 3a294db9..d9e2d671 100755 --- a/igemm/igemm_codegen_driver.py +++ b/igemm/igemm_codegen_driver.py @@ -142,6 +142,7 @@ def get_kernel_per_inc_file_name(ker, origin_file_name): self.mc.emitter = emitter_per_inc_dict[kpi_file_name] if IGEMM_EMIT_KERNEL_METADATA_PER_INC_FILE: kinfo_per_inc_dict[kpi_file_name].append(kernel.get_kernel_info()) + self._emit(';----------------------------------------------------------') self._emit('; starting of kernel {}'.format(kernel.name())) self._emit(kernel.tunable.serialize()) @@ -195,4 +196,3 @@ def do_compile(self): def __call__(self): self.do_emit() self.do_compile() -