diff --git a/ref_kernels/3/bli_gemm_ref.c b/ref_kernels/3/bli_gemm_ref.c
index 119aa7b590..8d563199f0 100644
--- a/ref_kernels/3/bli_gemm_ref.c
+++ b/ref_kernels/3/bli_gemm_ref.c
@@ -157,6 +157,109 @@ INSERT_GENTFUNCR_BASIC( gemm_gen, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
 // instructions via constant loop bounds + #pragma omp simd directives.
 // If compile-time MR/NR are not available (indicated by BLIS_[MN]R_x = -1),
 // then the non-unrolled version (above) is used.
+// First, the fastest case: full micro-tiles (m == mr, n == nr) with k > 0.
+// The kernel template below is instantiated four times per datatype:
+//   gemm_vect_r        - cs_c == 1 (row-major C),    beta != 0
+//   gemm_vect_r_beta0  - cs_c == 1 (row-major C),    beta == 0
+//   gemm_vect_c        - rs_c == 1 (column-major C), beta != 0
+//   gemm_vect_c_beta0  - rs_c == 1 (column-major C), beta == 0
+
+// Prefetch lead distance for the rows of C, in units of 4 k-iterations.
+#define TAIL_NITER 5
+// Cache line size in bytes assumed when computing the prefetch stride.
+#define CACHELINE_SIZE 64
+// Stand-in for bli_taxpbys in the beta == 0 instantiations: beta and the fifth
+// type character are dropped and the update is forwarded to bli_tscal2s.
+#define TAXPBYS_BETA0(ch1,ch2,ch3,ch4,ch5,alpha,ab,beta,c) bli_tscal2s(ch1,ch2,ch3,ch4,alpha,ab,c)
+
+// Template parameters beyond the usual ( ctype, ch, opname, arch, suf ):
+//   taxpbys             - macro used to write the accumulated tile back to c
+//   i_or_j, j_or_i      - loop index used to address a and b, respectively
+//   mr_or_nr, nr_or_mr  - register blocksizes (BLIS_xx_ch) giving the outer and
+//                         inner loop trip counts, respectively
+// s_c is the stride of c between consecutive outer-loop iterations; the inner
+// loop walks c with unit stride.
+#undef GENTFUNC
+#define GENTFUNC( ctype, ch, opname, arch, suf, taxpbys, i_or_j, j_or_i, mr_or_nr, nr_or_mr ) \
+\
+static void PASTEMAC(ch,ch,opname,arch,suf) \
+     ( \
+             dim_t  k, \
+       const ctype* alpha, \
+       const ctype* a, \
+       const ctype* b, \
+       const ctype* beta, \
+             ctype* c, inc_t s_c \
+     ) \
+{ \
+	const dim_t mr = PASTECH(BLIS_,mr_or_nr,_,ch); \
+	const dim_t nr = PASTECH(BLIS_,nr_or_mr,_,ch); \
+\
+	const inc_t cs_a = PASTECH(BLIS_PACKMR_,ch); \
+	const inc_t rs_b = PASTECH(BLIS_PACKNR_,ch); \
+\
+	char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \
+	ctype* ab = (ctype*)ab_; \
+	const inc_t s_ab = nr; \
+\
+	/* Initialize the accumulator elements in ab to zero. */ \
+	PRAGMA_SIMD \
+	for ( dim_t i = 0; i < mr * nr; ++i ) \
+	{ \
+		bli_tset0s( ch, ab[ i ] ); \
+	} \
+\
+	/* Perform a series of k rank-1 updates into ab. The dispatch code only \
+	   calls this kernel with k > 0, so a do-while loop is safe. */ \
+	dim_t l = 0; do \
+	{ \
+		/* Prefetch one row of c per iteration, timed so that the last row \
+		   is requested with TAIL_NITER*4 iterations still to go. The stride \
+		   is one cache line's worth of ctype elements. */ \
+		const dim_t i_pf = l + TAIL_NITER*4 + mr - k; \
+		if ( i_pf >= 0 && i_pf < mr ) \
+		{ \
+			for ( dim_t j = 0; j < nr; j += CACHELINE_SIZE/sizeof(ctype) ) \
+				bli_prefetch( &c[ i_pf*s_c + j ], 0, 3 ); \
+		} \
+\
+		for ( dim_t i = 0; i < mr; ++i ) \
+		{ \
+			PRAGMA_SIMD \
+			for ( dim_t j = 0; j < nr; ++j ) \
+			{ \
+				bli_tdots \
+				( \
+				  ch,ch,ch,ch, \
+				  a[ i_or_j ], \
+				  b[ j_or_i ], \
+				  ab[ i*s_ab + j ] \
+				); \
+			} \
+		} \
+\
+		a += cs_a; \
+		b += rs_b; \
+	} while ( ++l < k ); \
+\
+	/* Write the accumulated tile back to c via taxpbys( alpha, ab, beta, c ). */ \
+	for ( dim_t i = 0; i < mr; ++i ) \
+	PRAGMA_SIMD \
+	for ( dim_t j = 0; j < nr; ++j ) \
+	taxpbys \
+	( \
+	  ch,ch,ch,ch,ch, \
+	  *alpha, \
+	  ab[ i*s_ab + j ], \
+	  *beta, \
+	  c [ i*s_c + j ] \
+	); \
+}
+
+INSERT_GENTFUNC_BASIC( gemm_vect_r_beta0, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, TAXPBYS_BETA0, i, j, MR, NR )
+INSERT_GENTFUNC_BASIC( gemm_vect_r, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, bli_taxpbys, i, j, MR, NR )
+INSERT_GENTFUNC_BASIC( gemm_vect_c_beta0, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, TAXPBYS_BETA0, j, i, NR, MR )
+INSERT_GENTFUNC_BASIC( gemm_vect_c, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, bli_taxpbys, j, i, NR, MR )
 
 #undef GENTFUNC
 #define GENTFUNC( ctype, ch, opname, arch, suf ) \
@@ -210,6 +313,41 @@ void PASTEMAC(ch,ch,opname,arch,suf) \
 		); \
 		return; \
 	} \
+\
+	/* Fast path: a full mr x nr micro-tile with k > 0 and unit stride in \
+	   one dimension of c is handled by the gemm_vect_* kernels, with the \
+	   beta == 0 variants chosen when beta is zero. */ \
+	if ( m == mr && n == nr && k > 0 ) \
+	{ \
+		/* cs_c == 1 (row major): hand the kernel the row stride of c. */ \
+		if ( cs_c == 1 ) \
+		{ \
+			(bli_teq0s( ch, *beta ) ? PASTEMAC(ch,ch,gemm_vect_r_beta0,arch,suf) : PASTEMAC(ch,ch,gemm_vect_r,arch,suf)) \
+			( \
+			  k, \
+			  alpha, \
+			  a, \
+			  b, \
+			  beta, \
+			  c, rs_c \
+			); \
+			return; \
+		} \
+		/* rs_c == 1 (column major): hand the kernel the column stride of c. */ \
+		if ( rs_c == 1 ) \
+		{ \
+			(bli_teq0s( ch, *beta ) ? PASTEMAC(ch,ch,gemm_vect_c_beta0,arch,suf) : PASTEMAC(ch,ch,gemm_vect_c,arch,suf)) \
+			( \
+			  k, \
+			  alpha, \
+			  a, \
+			  b, \
+			  beta, \
+			  c, cs_c \
+			); \
+			return; \
+		} \
+	} \
 \
 	char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \
 	ctype* ab = (ctype*)ab_; \
@@ -382,5 +520,3 @@ void PASTEMAC(chab,chc,opname,arch,suf) \
 }
 
 INSERT_GENTFUNC2_MIX_P( gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX )
-
-