diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index e04e48cf50..1bf0a81cb8 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -35,11 +35,10 @@ #include "blis.h" - // // Define BLAS-to-BLIS interfaces. // - +#define ENABLE_INDUCED_METHOD 0 #ifdef BLIS_BLAS3_CALLS_TAPI #undef GENTFUNC @@ -218,7 +217,6 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } - #endif #ifdef BLIS_ENABLE_BLAS diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt new file mode 100644 index 0000000000..fd3aabe70b --- /dev/null +++ b/kernels/zen/3/CMakeLists.txt @@ -0,0 +1,13 @@ +##Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_syrk_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c + ) + +add_subdirectory(sup) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c new file mode 100644 index 0000000000..ca33972a29 --- /dev/null +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -0,0 +1,1214 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" +#include "bli_gemm_sqp_kernels.h" + +#define SQP_THREAD_ENABLE 0// Currently disabled: simple threading model along m dimension. + // Works well when m is large and other dimensions are relatively smaller. +#define BLI_SQP_MAX_THREADS 128 +#define BLIS_LOADFIRST 0 +#define MEM_ALLOC 1 // malloc performs better than bli_malloc. Reason to be understood. 
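+
+/* For reference, MEM_ALLOC selects how bli_getaligned() (below) obtains
+   workspace: with MEM_ALLOC == 1 it calls malloc() with 128 bytes of slack and
+   rounds the address up to a 64-byte boundary, roughly
+       address += (-address) & 63;
+   with MEM_ALLOC == 0 the aligned buffer comes from bli_malloc_user(). */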
+ +#define SET_TRANS(X,Y)\ + Y = BLIS_NO_TRANSPOSE;\ + if(bli_obj_has_trans( a ))\ + {\ + Y = BLIS_TRANSPOSE;\ + if(bli_obj_has_conj(a))\ + {\ + Y = BLIS_CONJ_TRANSPOSE;\ + }\ + }\ + else if(bli_obj_has_conj(a))\ + {\ + Y = BLIS_CONJ_NO_TRANSPOSE;\ + } + +//Macro for 3m_sqp nx loop +#define BLI_SQP_ZGEMM_N(MX)\ + int j=0;\ + for(; j<=(n-nx); j+= nx)\ + {\ + status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\ + }\ + if(jreal; + double alpha_imag = alphap->imag; + double beta_real = betap->real; + double beta_imag = betap->imag; + if( (alpha_imag!=0)||(beta_imag!=0) ) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + //printf("zsqp "); + return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, transa, nt); + } + else if(dt == BLIS_DOUBLE) + { + double *alpha_cast, *beta_cast; + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); + + if((*beta_cast)!=1.0) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + if(((*alpha_cast)!=1.0)&&((*alpha_cast)!=-1.0)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + //printf("dsqp "); + // dgemm case only transpose or no-transpose is handled. + // conjugate_transpose and conjugate no transpose are not applicable. + return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); + } + + return BLIS_NOT_YET_IMPLEMENTED; +}; + +//sqp_dgemm k partition +BLIS_INLINE void bli_sqp_dgemm_kx( gint_t m, + gint_t n, + gint_t kx, + gint_t p, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + gint_t i, + bool pack_on, + double *aligned) +{ + inc_t j = 0; + double* ci = c + i; + double* aPacked; + //packing + if(pack_on==true) + { + aPacked = aligned; + double *pa = a + i + (p*lda); + if(isTransA==true) + { + pa = a + (i*lda) + p; + } + bli_sqp_prepackA(pa, aPacked, kx, lda, isTransA, alpha, mx); + } + else + { + aPacked = a+i + (p*lda); + } + + //compute + if(mx==8) + { + //printf("\n mx8i:%3ld ", i); + if (j <= (n - 6)) + { + j = bli_sqp_dgemm_kernel_8mx6n(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc); + } + if (j <= (n - 5)) + { + j = bli_sqp_dgemm_kernel_8mx5n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); + } + if (j <= (n - 4)) + { + j = bli_sqp_dgemm_kernel_8mx4n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); + } + if (j <= (n - 3)) + { + j = bli_sqp_dgemm_kernel_8mx3n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); + } + if (j <= (n - 2)) + { + j = bli_sqp_dgemm_kernel_8mx2n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); + } + if (j <= (n - 1)) + { + j = bli_sqp_dgemm_kernel_8mx1n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); + } + } + /* mx==4 to be implemented */ + else + { + // this residue kernel needs to be improved. 
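+        /* bli_sqp_dgemm_kernel_mxn() is a scalar fallback for mx < 8: for each
+           remaining column j it accumulates
+               c(ii, j) += sum_p aPacked(ii, p) * b(p, j),   0 <= ii < mx,
+           one element at a time, with no vectorization. */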
+ j = bli_sqp_dgemm_kernel_mxn(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc, mx); + } +} + +//sqp dgemm m loop +void bli_sqp_dgemm_m( gint_t i_start, + gint_t i_end, + gint_t m, + gint_t n, + gint_t k, + gint_t kx, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + bool pack_on, + double *aligned ) +{ +#if SQP_THREAD_ENABLE + if(pack_on==true) + { + err_t r_val; + //NEEDED IN THREADING CASE: + aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); + if(aligned==NULL) + { + return BLIS_MALLOC_RETURNED_NULL;// return to be removed + } + } +#endif//SQP_THREAD_ENABLE + + for (gint_t i = i_start; i <= (i_end-mx); i += mx) //this loop can be threaded. no of workitems = m/8 + { + int p = 0; + for(; p <= (k-kx); p += kx) + { + bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, isTransA, alpha, mx, i, pack_on, aligned); + }// k loop end + + if(pi_start, + arg->i_end, + arg->m, + arg->n, + arg->k, + arg->kx, + arg->a, + arg->lda, + arg->b, + arg->ldb, + arg->c, + arg->ldc, + arg->isTransA, + arg->alpha, + arg->mx, + arg->pack_on, + arg->aligned); +} + +// sqp_dgemm m loop +BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + gint_t* p_istart, + gint_t kx, + double *aligned) +{ + gint_t i; + if(kx > k) + { + kx = k; + } + + bool pack_on = false; + if((m!=mx)||(m!=lda)||isTransA) + { + pack_on = true; + } + +#if SQP_THREAD_ENABLE//ENABLE Threading + gint_t status = 0; + gint_t workitems = (m-(*p_istart))/mx; + gint_t inputThreadCount = bli_thread_get_num_threads(); + inputThreadCount = bli_min(inputThreadCount, BLI_SQP_MAX_THREADS); + // limit input thread count when workitems are lesser. + inputThreadCount = bli_min(inputThreadCount,workitems); + inputThreadCount = bli_max(inputThreadCount,1); + gint_t num_threads; + num_threads = bli_max(inputThreadCount,1); + //no of workitems per thread + gint_t mx_per_thread = workitems/num_threads; + + *p_istart, workitems,inputThreadCount,num_threads,mx_per_thread, mx); + + pthread_t ptid[BLI_SQP_MAX_THREADS]; + bli_sqp_thread_info thread_info[BLI_SQP_MAX_THREADS]; + + //create threads + for (gint_t t = 0; t < num_threads; t++) + { + gint_t i_end = ((mx_per_thread*(t+1))*mx)+(*p_istart); + if(i_end>m) + { + i_end = m; + } + + if(t==(num_threads-1))//for last thread allocate remaining work. + { + if((i_end+mx)==m) + { + i_end = m; + } + + if(mx==1) + { + i_end = m; + } + } + + thread_info[t].i_start = ((mx_per_thread*t)*mx)+(*p_istart); + thread_info[t].i_end = i_end; + + thread_info[t].m = m; + thread_info[t].n = n; + thread_info[t].k = k; + thread_info[t].kx = kx; + thread_info[t].a = a; + thread_info[t].lda = lda; + thread_info[t].b = b; + thread_info[t].ldb = ldb; + thread_info[t].c = c; + thread_info[t].ldc = ldc; + thread_info[t].isTransA = isTransA; + thread_info[t].alpha = alpha; + thread_info[t].mx = mx; + thread_info[t].pack_on = pack_on; + thread_info[t].aligned = aligned; + + if ((status = pthread_create(&ptid[t], NULL, bli_sqp_thread, (void*)&thread_info[t]))) + { + printf("error sqp pthread_create\n"); + return BLIS_FAILURE; + } + } + + //wait for completion + for (gint_t t = 0; t < num_threads; t++) + { + pthread_join(ptid[t], NULL); + } + + if(num_threads>0) + { + *p_istart = thread_info[(num_threads-1)].i_end; + } +#else//SQP_THREAD_ENABLE + + // single thread code segement. 
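+    /* Single-threaded blocking, in outline: step i over m in blocks of mx rows
+       and p over k in blocks of kx; for each (i, p) block, bli_sqp_dgemm_kx()
+       packs an mx x kx panel of A (when packing is needed) and runs the
+       mx-row kernels across all n columns. */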
+ if(pack_on==true) + { + if(aligned==NULL) + { + return BLIS_MALLOC_RETURNED_NULL; + } + } + + for (i = (*p_istart); i <= (m-mx); i += mx) + { + int p = 0; + for(; p <= (k-kx); p += kx) + { + bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, + isTransA, alpha, mx, i, pack_on, aligned); + }// k loop end + + if(pdata_size * mem_req->size; + if (memSize == 0) + { + return -1; + } + memSize += 128;// extra 128 bytes added for alignment. Could be minimized to 64. +#if MEM_ALLOC +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "malloc(): size %ld\n",( long )memSize); + fflush( stdout ); +#endif + mem_req->unalignedBuf = (double*)malloc(memSize); + if (mem_req->unalignedBuf == NULL) + { + return -1; + } + + int64_t address = (int64_t)mem_req->unalignedBuf; + address += (-address) & 63; //64 bytes alignment done. + mem_req->alignedBuf = (double*)address; +#else + err_t r_val; + mem_req->alignedBuf = bli_malloc_user( memSize, &r_val); + if (mem_req->alignedBuf == NULL) + { + return -1; + } +#endif + return 0; +} + +gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, mem_block *msx) +{ + //allocate workspace + mxr->data_size = mxi->data_size = msx->data_size = sizeof(double); + mxr->size = mxi->size = n * k; + msx->size = n * k; + mxr->alignedBuf = mxi->alignedBuf = msx->alignedBuf = NULL; + mxr->unalignedBuf = mxi->unalignedBuf = msx->unalignedBuf = NULL; + + if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) + { +#if MEM_ALLOC + if(mxr->unalignedBuf) + { + free(mxr->unalignedBuf); + } + if(mxi->unalignedBuf) + { + free(mxi->unalignedBuf); + } + if(msx->unalignedBuf) + { + free(msx->unalignedBuf); + } +#else + bli_free_user(mxr->alignedBuf); + bli_free_user(mxi->alignedBuf); + bli_free_user(msx->alignedBuf); +#endif + return -1; + } + return 0; +} + +//3m_sqp k loop +BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m, + gint_t n, + gint_t kx, + gint_t p, + double* a, + guint_t lda, + guint_t ldb, + double* c, + guint_t ldc, + trans_t transa, + double alpha, + double beta, + gint_t mx, + gint_t i, + double* ar, + double* ai, + double* as, + double* br, + double* bi, + double* bs, + double* cr, + double* ci, + double* w, + double *a_aligned) +{ + gint_t j; + + ////////////// operation 1 ///////////////// + /* Split a (ar, ai) and + compute as = ar + ai */ + double* par = ar; + double* pai = ai; + double* pas = as; + + /* a matrix real and imag packing and compute. */ + bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, transa, mx, p); + + double* pcr = cr; + double* pci = ci; + + //Split Cr and Ci and beta multiplication done. 
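+    /* The three real GEMMs below implement the 3m scheme: with As = Ar + Ai
+       and Bs = Br + Bi already packed,
+           Ci := As * Bs                             (operation 1)
+           Wr := Ar * Br;  Cr += Wr;  Ci -= Wr       (operation 2)
+           Wi := Ai * Bi;  Cr -= Wi;  Ci -= Wi       (operation 3)
+       giving Cr = Ar*Br - Ai*Bi and Ci = (Ar+Ai)*(Br+Bi) - Ar*Br - Ai*Bi,
+       the real and imaginary parts of (Ar + i*Ai)*(Br + i*Bi). */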
+ double* pc = c + i; + if(p==0) + { + bli_3m_sqp_packC_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); + } + //Ci := rgemm( SA, SB, Ci ) + gint_t istart = 0; + gint_t* p_is = &istart; + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, as, mx, bs, ldb, ci, mx, false, 1.0, mx, p_is, kx, a_aligned); + + ////////////// operation 2 ///////////////// + //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn + double* wr = w; + for (j = 0; j < n; j++) { + for (gint_t ii = 0; ii < mx; ii += 1) { + *wr = 0; + wr++; + } + } + wr = w; + + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, ar, mx, br, ldb, wr, mx, false, 1.0, mx, p_is, kx, a_aligned); + //Cr : = addm(Wr, Cr) + bli_add_m(mx, n, wr, cr); + //Ci : = subm(Wr, Ci) + bli_sub_m(mx, n, wr, ci); + + + ////////////// operation 3 ///////////////// + //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn + double* wi = w; + for (j = 0; j < n; j++) { + for (gint_t ii = 0; ii < mx; ii += 1) { + *wi = 0; + wi++; + } + } + wi = w; + + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, ai, mx, bi, ldb, wi, mx, false, 1.0, mx, p_is, kx, a_aligned); + //Cr : = subm(Wi, Cr) + bli_sub_m(mx, n, wi, cr); + //Ci : = subm(Wi, Ci) + bli_sub_m(mx, n, wi, ci); + + pcr = cr; + pci = ci; + + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + c[(j * ldc) + i + ii] = *pcr; + c[(j * ldc) + i + ii + 1] = *pci; + pcr++; pci++; + } + } +} + +/**************************************************************/ +/* workspace memory allocation for 3m_sqp algorithm for zgemm */ +/**************************************************************/ +err_t allocate_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp, + gint_t mx, + gint_t nx, + gint_t k, + gint_t kx ) +{ + //3m_sqp workspace Memory allocation + /* B matrix */ + // B matrix packed with n x k size. without kx smaller sizes for now. 
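+    /* Workspace summary: br/bi/bs hold the split and summed B panels
+       (nx x k doubles each), ar/ai/as the split A panels (mx x kx), w/cr/ci
+       the mx x nx accumulators, and aPacked a kx x mx A-packing buffer. */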
+ mem_block mbr, mbi, mbs; + if(bli_allocateWorkspace(nx, k, &mbr, &mbi, &mbs)!=0) + { + return BLIS_FAILURE; + } + mem_3m_sqp->br = (double*)mbr.alignedBuf; + mem_3m_sqp->bi = (double*)mbi.alignedBuf; + mem_3m_sqp->bs = (double*)mbs.alignedBuf; + mem_3m_sqp->br_unaligned = (double*)mbr.unalignedBuf; + mem_3m_sqp->bi_unaligned = (double*)mbi.unalignedBuf; + mem_3m_sqp->bs_unaligned = (double*)mbs.unalignedBuf; + + /* Workspace memory allocation currently done dynamically + This needs to be taken from already allocated memory pool in application for better performance */ + /* A matrix */ + mem_block mar, mai, mas; + if(bli_allocateWorkspace(mx, kx, &mar, &mai, &mas) !=0) + { + return BLIS_FAILURE; + } + mem_3m_sqp->ar = (double*)mar.alignedBuf; + mem_3m_sqp->ai = (double*)mai.alignedBuf; + mem_3m_sqp->as = (double*)mas.alignedBuf; + mem_3m_sqp->ar_unaligned = (double*)mar.unalignedBuf; + mem_3m_sqp->ai_unaligned = (double*)mai.unalignedBuf; + mem_3m_sqp->as_unaligned = (double*)mas.unalignedBuf; + + /* w matrix */ + mem_block mw; + mw.data_size = sizeof(double); + mw.size = mx * nx; + if (bli_getaligned(&mw) != 0) + { + return BLIS_FAILURE; + } + mem_3m_sqp->w = (double*)mw.alignedBuf; + mem_3m_sqp->w_unaligned = (double*)mw.unalignedBuf; + /* cr matrix */ + mem_block mcr; + mcr.data_size = sizeof(double); + mcr.size = mx * nx; + if (bli_getaligned(&mcr) != 0) + { + return BLIS_FAILURE; + } + mem_3m_sqp->cr = (double*)mcr.alignedBuf; + mem_3m_sqp->cr_unaligned = (double*)mcr.unalignedBuf; + + + /* ci matrix */ + mem_block mci; + mci.data_size = sizeof(double); + mci.size = mx * nx; + if (bli_getaligned(&mci) != 0) + { + return BLIS_FAILURE; + } + mem_3m_sqp->ci = (double*)mci.alignedBuf; + mem_3m_sqp->ci_unaligned = (double*)mci.unalignedBuf; + + // A packing buffer + err_t r_val; + mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); + if (mem_3m_sqp->aPacked == NULL) + { + return BLIS_FAILURE; + } + + return BLIS_SUCCESS; +} + +void free_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp) +{ + // A packing buffer free + bli_free_user(mem_3m_sqp->aPacked); + +#if MEM_ALLOC + if(mem_3m_sqp->ar_unaligned) + { + free(mem_3m_sqp->ar_unaligned); + } + if(mem_3m_sqp->ai_unaligned) + { + free(mem_3m_sqp->ai_unaligned); + } + if(mem_3m_sqp->as_unaligned) + { + free(mem_3m_sqp->as_unaligned); + } + + if(mem_3m_sqp->br_unaligned) + { + free(mem_3m_sqp->br_unaligned); + } + if(mem_3m_sqp->bi_unaligned) + { + free(mem_3m_sqp->bi_unaligned); + } + if(mem_3m_sqp->bs_unaligned) + { + free(mem_3m_sqp->bs_unaligned); + } + + if(mem_3m_sqp->w_unaligned) + { + free(mem_3m_sqp->w_unaligned); + } + if(mem_3m_sqp->cr_unaligned) + { + free(mem_3m_sqp->cr_unaligned); + } + if(mem_3m_sqp->ci_unaligned) + { + free(mem_3m_sqp->ci_unaligned); + } + +#else//MEM_ALLOC + /* free workspace buffers */ + bli_free_user(mem_3m_sqp->br); + bli_free_user(mem_3m_sqp->bi); + bli_free_user(mem_3m_sqp->bs); + bli_free_user(mem_3m_sqp->ar); + bli_free_user(mem_3m_sqp->ai); + bli_free_user(mem_3m_sqp->as); + bli_free_user(mem_3m_sqp->w); + bli_free_user(mem_3m_sqp->cr); + bli_free_user(mem_3m_sqp->ci); +#endif//MEM_ALLOC +} + +//3m_sqp m loop +BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + double alpha, + double beta, + trans_t transa, + gint_t mx, + gint_t* p_istart, + gint_t kx, + workspace_3m_sqp *mem_3m_sqp) +{ + inc_t m2 = m<<1; + inc_t mxmul2 = mx<<1; + + if((*p_istart) > (m2-mxmul2)) + { + 
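+        /* every full mx-wide block of rows has already been processed */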
return BLIS_SUCCESS; + } + inc_t i; + gint_t max_m = (m2-mxmul2); + + /* + get workspace for packing and + for storing sum of real and imag component */ + double* ar, * ai, * as; + ar = mem_3m_sqp->ar; + ai = mem_3m_sqp->ai; + as = mem_3m_sqp->as; + + double* br, * bi, * bs; + br = mem_3m_sqp->br; + bi = mem_3m_sqp->bi; + bs = mem_3m_sqp->bs; + + double* cr, * ci; + cr = mem_3m_sqp->cr; + ci = mem_3m_sqp->ci; + + double *w; + w = mem_3m_sqp->w; + + double* a_aligned; + a_aligned = mem_3m_sqp->aPacked; + + /* Split b (br, bi) and + compute bs = br + bi */ + double* pbr = br; + double* pbi = bi; + double* pbs = bs; + /* b matrix real and imag packing and compute. */ + bli_3m_sqp_packB_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); + + for (i = (*p_istart); i <= max_m; i += mxmul2) //this loop can be threaded. + { +#if KLP//kloop + int p = 0; + for(; p <= (k-kx); p += kx) + { + bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc, + transa, alpha, beta, mx, i, ar, ai, as, + br + p, bi + p, bs + p, cr, ci, w, a_aligned); + }// k loop end + + if(p>3)<<3); + + workspace_3m_sqp mem_3m_sqp; + + /* multiply lda, ldb and ldc by 2 to account for + real & imaginary components per dcomplex. */ + lda = lda * 2; + ldb = ldb * 2; + ldc = ldc * 2; + + /* user can set BLIS_MULTI_INSTANCE macro for + better performance while runing multi-instance use-case. + */ + dim_t multi_instance = bli_env_get_var( "BLIS_MULTI_INSTANCE", -1 ); + gint_t nx = n; + if(multi_instance>0) + { + //limited nx size helps in reducing memory footprint in multi-instance case. + nx = 84; + // 84 is derived based on tuning results + } + + if(nx>n) + { + nx = n; + } + + gint_t kx = k;// kx is configurable at run-time. +#if KLP + if (kx > k) + { + kx = k; + } + // for tn case there is a bug in handling k parts. To be fixed. + if(transa!=BLIS_NO_TRANSPOSE) + { + kx = k; + } +#else + kx = k; +#endif + //3m_sqp workspace Memory allocation + if(allocate_3m_Sqp_workspace(&mem_3m_sqp, mx, nx, k, kx)!=BLIS_SUCCESS) + { + return BLIS_FAILURE; + } + + BLI_SQP_ZGEMM_N(mx) + *p_istart = (m-m8rem)*2; + + if(m8rem!=0) + { + //complete residue m blocks + BLI_SQP_ZGEMM_N(m8rem) + } + + free_3m_Sqp_workspace(&mem_3m_sqp); + return status; +} + +/****************************************************************************/ +/*********************** dgemm_sqp implementation****************************/ +/****************************************************************************/ +/* dgemm_sqp implementation packs A matrix based on lda and m size. + dgemm_sqp focuses mainly on square matrixes but also supports non-square matrix. + Current support is limiteed to m multiple of 8 and column storage. + C = AxB and C = AtxB is handled in the design. + AtxB case is done by transposing A matrix while packing A. + In majority of use-case, alpha are +/-1, so instead of explicitly multiplying + alpha its done during packing itself by changing sign. 
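+   For example, with alpha = -1 the packing step stores
+       aPacked(i, p) = -a(i, p)
+   so the compute kernels themselves never multiply by alpha.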
+*/ +BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + double alpha, + double beta, + bool isTransA, + dim_t nt) +{ + gint_t istart = 0; + gint_t* p_istart = &istart; + *p_istart = 0; + err_t status = BLIS_SUCCESS; + dim_t m8rem = m - ((m>>3)<<3); + + /* dgemm implementation with 8mx5n major kernel and column preferred storage */ + gint_t mx = 8; + gint_t kx = k; + double* a_aligned = NULL; + + //single pack buffer allocated for single thread case + if(nt<=1) + { + err_t r_val; + a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); + } + + /* nx assigned to n, but can be assigned with lower value + to add partition along n dimension */ + gint_t nx = n; + if(nx>n) + { + nx = n; + } + + //mx==8 case for dgemm. + BLI_SQP_DGEMM_N(mx) + *p_istart = (m-m8rem); + + if(nt>1) + { + //2nd level thread for mx=8 + gint_t rem_m = m - (*p_istart); + if((rem_m>=mx)&&(status==BLIS_SUCCESS)) + { + status = bli_sqp_dgemm_m8( m, n, k, a, lda, b, ldb, c, ldc, + isTransA, alpha, mx, p_istart, kx, a_aligned); + } + } + + if(status==BLIS_SUCCESS) + { + if(m8rem!=0) + { + //complete residue m blocks + BLI_SQP_DGEMM_N(m8rem) + } + } + + //single pack buffer allocated for single thread case + if(nt<=1) + { + bli_free_user(a_aligned); + } + return status; +} \ No newline at end of file diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c new file mode 100644 index 0000000000..2f553708cd --- /dev/null +++ b/kernels/zen/3/bli_gemm_sqp_kernels.c @@ -0,0 +1,1636 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include "blis.h" +#include "immintrin.h" +#include "bli_gemm_sqp_kernels.h" + +#define BLIS_LOADFIRST 0 +#define BLIS_ENABLE_PREFETCH 1 + +#define BLIS_MX8 8 +#define BLIS_MX4 4 +#define BLIS_MX1 1 + +/****************************************************************************/ +/*************** dgemm kernels (8mxn) column preffered *********************/ +/****************************************************************************/ + +/* Main dgemm kernel 8mx6n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + + __m256d av0, av1; + __m256d bv0, bv1; + __m256d cv0, cv1, cv2, cv3, cv4, cv5; + __m256d cx0, cx1, cx2, cx3, cx4, cx5; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; + + for (j = 0; j <= (n - 6); j += 6) { + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + double* pcldc5 = pcldc4 + ldc; + + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; + double* pbldb5 = pbldb4 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); + cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); + cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; + bv0 = _mm256_broadcast_sd (pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cx0 = _mm256_fmadd_pd(av1, bv0, cx0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cx1 = _mm256_fmadd_pd(av1, bv1, cx1); + + bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; + bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; + cv2 = _mm256_fmadd_pd(av0, bv0, cv2); + cx2 = _mm256_fmadd_pd(av1, bv0, cx2); + cv3 = _mm256_fmadd_pd(av0, bv1, cv3); + cx3 = _mm256_fmadd_pd(av1, bv1, cx3); + + bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; + bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; + cv4 = _mm256_fmadd_pd(av0, bv0, cv4); + cx4 = 
_mm256_fmadd_pd(av1, bv0, cx4); + cv5 = _mm256_fmadd_pd(av0, bv1, cv5); + cx5 = _mm256_fmadd_pd(av1, bv1, cx5); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); + + av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); + cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + _mm256_storeu_pd(pcldc5, cv5); + _mm256_storeu_pd(pcldc5 + 4, cx5); + + pc += ldc6;pb += ldb6; + } + + return j; +} + +/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_8mx5n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; + + for (; j <= (n - 5); j += 5) { + + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < 
k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + cx4 = _mm256_fmadd_pd(av0, bv4, cx4); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + pc += ldc5;pb += ldb5; + } + + return j; +} + +/* residue dgemm kernel 8mx4n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_8mx4n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; + + for (; j <= (n - 4); j += 4) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + // better kernel to be written since more register are available. 
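+                // per k-iteration: broadcast one B element from each of the
+                // four columns, load the 8-row A panel as two ymm vectors and
+                // issue 8 FMAs into the 8x4 accumulator tile (cv0..cv3, cx0..cx3).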
+ bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + pc += ldc4;pb += ldb4; + }// j loop 4 multiple + + return j; +} + +/* residue dgemm kernel 8mx3n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_8mx3n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2; + __m256d cv0, cv1, cv2; + __m256d cx0, cx1, cx2; + double* pb, * pc; + + pb = b; + pc = c; + + inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; + + for (; j <= (n - 3); j += 3) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + } + } + + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + pc += ldc3;pb += ldb3; + }// j loop 3 multiple + + return j; +} + +/* residue dgemm kernel 8mx2n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
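+   It follows the same pattern as the wider kernels: each k-iteration broadcasts
+   one element from each of the two B columns, loads the 8-row A panel as two
+   ymm vectors and accumulates into an 8x2 tile (cv0/cv1 and cx0/cx1).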
+*/ +inc_t bli_sqp_dgemm_kernel_8mx2n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1; + __m256d cv0, cv1; + __m256d cx0, cx1; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; + + for (; j <= (n - 2); j += 2) { + double* pcldc = pc + ldc; + double* pbldb = pb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + pc += ldc2;pb += ldb2; + }// j loop 2 multiple + + return j; +} + +/* residue dgemm kernel 8mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_8mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + __m256d cx0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" 8x1:j:%d ", j); + return j; +} + +#if 0 +/************************************************************************************************************/ +/************************** dgemm kernels (4mxn) column preffered ******************************************/ +/************************************************************************************************************/ +/* Residue dgemm kernel 4mx10n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_4mx10n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + //tbd + return j; +} + +/* residue dgemm kernel 4mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_sqp_dgemm_kernel_4mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + //tbd + return j; +} +#endif +/************************************************************************************************************/ +/************************** dgemm kernels (1mxn) column preffered ******************************************/ +/************************************************************************************************************/ + +/* residue dgemm kernel 1mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_1mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + double a0; + double b0; + double c0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + c0 = *pc; + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + b0 = *pb0; pb0++; + a0 = *x; x++; + c0 += (a0 * b0); + } + *pc = c0; + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" 1x1:j:%d ", j); + return j; +} + +inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + gint_t mx) +{ + gint_t p; + double cx[7]; + + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + + for (int i = 0; i < mx; i++) + { + cx[i] = *(pc + i); + } + + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + + double b0 = *pb0; + pb0++; + for (int i = 0; i < mx; i++) + { + cx[i] += (*(x + i)) * b0; + } + x += mx; + } + for (int i = 0; i < mx; i++) + { + *(pc + i) = cx[i]; + } + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" mx1:j:%d ", j); + return j; +} + +void bli_sqp_prepackA( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + //printf(" pmx:%d ",mx); + if(mx==8) + { + bli_prepackA_8(pa,aPacked,k, lda,isTransA, alpha); + } + else if(mx==4) + { + bli_prepackA_4(pa,aPacked,k, lda,isTransA, alpha); + } + else if(mx>4) + { + bli_prepackA_G4(pa,aPacked,k, lda,isTransA, alpha, mx); + } + else + { + bli_prepackA_L4(pa,aPacked,k, lda,isTransA, alpha, mx); + } +} + +/* Ax8 packing subroutine */ +void bli_prepackA_8(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + __m256d av0, av1, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += BLIS_MX8; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += BLIS_MX8; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX8; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A 
Transpose case: + for (gint_t i = 0; i < BLIS_MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX8; + *(aPacked + sidx + i) = -ar_; + } + } + } + } +} + +/* Ax4 packing subroutine */ +void bli_prepackA_4(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + __m256d av0, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + +} + +/* A packing m>4 subroutine */ +void bli_prepackA_G4( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + __m256d av0, ymm0; + gint_t mrem = mx - 4; + + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); + _mm256_storeu_pd(aPacked, av0); + for (gint_t i = 0; i < mrem; i += 1) { + *(aPacked+4+i) = *(pa+4+i); + } + aPacked += mx;pa += lda; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + for (gint_t i = 0; i < mrem; i += 1) { + *(aPacked+4+i) = -*(pa+4+i); + } + aPacked += mx;pa += lda; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + +} + +/* A packing m<4 subroutine */ +void bli_prepackA_L4( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) + { + for (gint_t i = 0; i < mx; i += 1) + { + *(aPacked+i) = *(pa+i); + } + aPacked += mx;pa += lda; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 0; p < k; p += 1) + { + for (gint_t i = 0; i < mx; i += 1) + { + *(aPacked+i) = -*(pa+i); + } + aPacked += mx;pa += lda; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = ar_; + } + } + } + else 
if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + + +} + +/* Ax1 packing subroutine */ +void bli_prepackA_1(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = *pa; + pa += lda; + aPacked++; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = -(*pa); + pa += lda; + aPacked++; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = ar_; + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = -ar_; + } + } + } +} + + +void bli_add_m( gint_t m, + gint_t n, + double* w, + double* c) +{ + double* pc = c; + double* pw = w; + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_add_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) + { + *pc = *pc + *pw; + pc++; pw++; + } +} + +void bli_sub_m( gint_t m, + gint_t n, + double* w, + double* c) +{ + double* pc = c; + double* pw = w; + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_sub_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) + { + *pc = *pc - *pw; + pc++; pw++; + } +} + +/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ +void bli_3m_sqp_packC_real_imag(double* pc, + guint_t n, + guint_t m, + guint_t ldc, + double* pcr, + double* pci, + double mul, + gint_t mx) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1; + gint_t max_m = (m*2)-8; + + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) /* handles alpha or beta = 1.0 */ + { + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_m; p += 8) + { + double* pbp = pc + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + _mm256_storeu_pd(pcr, av0); pcr += 4; + _mm256_storeu_pd(pci, av1); pci += 4; + } + + for (; p < (m*2); p += 2)// (real + imag)*m + { + double br = *(pc + p) ; + double bi = *(pc + p + 1); + *pcr = br; + *pci = bi; + pcr++; pci++; + } + pc = pc + ldc; + } + } + else /* handles alpha or beta = - 1.0 */ + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_m; p += 8) + { + double* pbp = pc + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + 
_mm256_storeu_pd(pcr, av0); pcr += 4; + _mm256_storeu_pd(pci, av1); pci += 4; + } + + for (; p < (m*2); p += 2)// (real + imag)*m + { + double br = -*(pc + p) ; + double bi = -*(pc + p + 1); + *pcr = br; + *pci = bi; + pcr++; pci++; + } + pc = pc + ldc; + } + } + } + else if(mul==0) /* handles alpha or beta is equal to zero */ + { + double br_ = 0; + double bi_ = 0; + for (j = 0; j < n; j++) + { + for (p = 0; p < (m*2); p += 2)// (real + imag)*m + { + *pcr = br_; + *pci = bi_; + pcr++; pci++; + } + pc = pc + ldc; + } + } + else /* handles alpha or beta is not equal +/- 1.0 and zero */ + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (m*2); p += 2)// (real + imag)*m + { + double br_ = mul * (*(pc + p)); + double bi_ = mul * (*(pc + p + 1)); + *pcr = br_; + *pci = bi_; + pcr++; pci++; + } + pc = pc + ldc; + } + } +} + +/* Pack real and imaginary parts in separate buffers and compute sum of real and imaginary part */ +void bli_3m_sqp_packB_real_imag_sum(double* pb, + guint_t n, + guint_t k, + guint_t ldb, + double* pbr, + double* pbi, + double* pbs, + double mul, + gint_t mx) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1, sum; + gint_t max_k = (k*2) - 8; + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) + { + for (j = 0; j < n; j++) + { + for (p=0; p <= max_k; p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_k; p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } +} + +/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ +void bli_3m_sqp_packA_real_imag_sum(double *pa, + gint_t i, + guint_t k, + guint_t lda, + double *par, + double *pai, + double *pas, + trans_t transa, + gint_t mx, + 
gint_t p) +{ + __m256d av0, av1, av2, av3; + __m256d tv0, tv1, sum, zerov; + gint_t poffset = p; +#if KLP +#endif + if(mx==8) + { + if(transa == BLIS_NO_TRANSPOSE) + { + pa = pa +i; +#if KLP + pa = pa + (p*lda); +#else + p = 0; +#endif + for (; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. + av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } + } + else if(transa == BLIS_CONJ_NO_TRANSPOSE) + { + zerov = _mm256_setzero_pd(); + pa = pa +i; +#if KLP + pa = pa + (p*lda); +#else + p = 0; +#endif + for (; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. + av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + av1 = _mm256_sub_pd(zerov,av1);//negate imaginary component + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + av3 = _mm256_sub_pd(zerov,av3);//negate imaginary component + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } + } + else if(transa == BLIS_TRANSPOSE) + { + gint_t idx = (i/2) * lda; + pa = pa + idx; +#if KLP +#else + p = 0; +#endif + //A Transpose case: + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + gint_t pidx = 0; + gint_t max_k = (k*2) - 8; + for (p = poffset; p <= max_k; p += 8) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + double ar1_ = *(pa + idx + p + 2); + double ai1_ = *(pa + idx + p + 3); + + double ar2_ = *(pa + idx + p + 4); + double ai2_ = *(pa + idx + p + 5); + + double ar3_ = *(pa + idx + p + 6); + double ai3_ = *(pa + idx + p + 7); + + sidx = (pidx/2) * BLIS_MX8; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + sidx = ((pidx+2)/2) * BLIS_MX8; + *(par + sidx + ii) = ar1_; + *(pai + sidx + ii) = ai1_; + *(pas + sidx + ii) = ar1_ + ai1_; + + sidx = ((pidx+4)/2) * BLIS_MX8; + *(par + sidx + ii) = ar2_; + *(pai + sidx + ii) = ai2_; + *(pas + sidx + ii) = ar2_ + ai2_; + + sidx = ((pidx+6)/2) * BLIS_MX8; + *(par + sidx + ii) = ar3_; + *(pai + sidx + ii) = ai3_; + *(pas + sidx + ii) = ar3_ + ai3_; + pidx += 8; + + } + + for (; p < (k*2); p += 2) + { + double ar_ = 
+                    double ai_ = *(pa + idx + p + 1);
+                    gint_t sidx = (pidx/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar_;
+                    *(pai + sidx + ii) = ai_;
+                    *(pas + sidx + ii) = ar_ + ai_;
+                    pidx += 2;
+                }
+            }
+        }
+        else if(transa == BLIS_CONJ_TRANSPOSE)
+        {
+            gint_t idx = (i/2) * lda;
+            pa = pa + idx;
+#if KLP
+#else
+            p = 0;
+#endif
+            //A conjugate Transpose case:
+            for (gint_t ii = 0; ii < BLIS_MX8 ; ii++)
+            {
+                gint_t idx = ii * lda;
+                gint_t sidx;
+                gint_t pidx = 0;
+                gint_t max_k = (k*2) - 8;
+                for (p = poffset; p <= max_k; p += 8)
+                {
+                    double ar0_ = *(pa + idx + p);
+                    double ai0_ = -(*(pa + idx + p + 1));
+
+                    double ar1_ = *(pa + idx + p + 2);
+                    double ai1_ = -(*(pa + idx + p + 3));
+
+                    double ar2_ = *(pa + idx + p + 4);
+                    double ai2_ = -(*(pa + idx + p + 5));
+
+                    double ar3_ = *(pa + idx + p + 6);
+                    double ai3_ = -(*(pa + idx + p + 7));
+
+                    sidx = (pidx/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar0_;
+                    *(pai + sidx + ii) = ai0_;
+                    *(pas + sidx + ii) = ar0_ + ai0_;
+
+                    sidx = ((pidx+2)/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar1_;
+                    *(pai + sidx + ii) = ai1_;
+                    *(pas + sidx + ii) = ar1_ + ai1_;
+
+                    sidx = ((pidx+4)/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar2_;
+                    *(pai + sidx + ii) = ai2_;
+                    *(pas + sidx + ii) = ar2_ + ai2_;
+
+                    sidx = ((pidx+6)/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar3_;
+                    *(pai + sidx + ii) = ai3_;
+                    *(pas + sidx + ii) = ar3_ + ai3_;
+                    pidx += 8;
+                }
+
+                for (; p < (k*2); p += 2)
+                {
+                    double ar_ = *(pa + idx + p);
+                    double ai_ = -(*(pa + idx + p + 1));
+                    gint_t sidx = (pidx/2) * BLIS_MX8;
+                    *(par + sidx + ii) = ar_;
+                    *(pai + sidx + ii) = ai_;
+                    *(pas + sidx + ii) = ar_ + ai_;
+                    pidx += 2;
+                }
+            }
+        }
+    } //mx==8
+    else//mx==1
+    {
+        if(transa == BLIS_NO_TRANSPOSE)
+        {
+            pa = pa + i;
+#if KLP
+#else
+            p = 0;
+#endif
+            //A No transpose case:
+            for (; p < k; p += 1)
+            {
+                gint_t idx = p * lda;
+                for (gint_t ii = 0; ii < (mx*2) ; ii += 2)
+                { //real + imag : Rkernel needs 8 elements each.
+                    double ar_ = *(pa + idx + ii);
+                    double ai_ = *(pa + idx + ii + 1);
+                    *par = ar_;
+                    *pai = ai_;
+                    *pas = ar_ + ai_;
+                    par++; pai++; pas++;
+                }
+            }
+        }
+        else if(transa == BLIS_CONJ_NO_TRANSPOSE)
+        {
+            pa = pa + i;
+#if KLP
+#else
+            p = 0;
+#endif
+            //A conjugate No transpose case:
+            for (; p < k; p += 1)
+            {
+                gint_t idx = p * lda;
+                for (gint_t ii = 0; ii < (mx*2) ; ii += 2)
+                { //real + imag : Rkernel needs 8 elements each.
+                    double ar_ = *(pa + idx + ii);
+                    double ai_ = -(*(pa + idx + ii + 1));// conjugate: negate imaginary component
+                    *par = ar_;
+                    *pai = ai_;
+                    *pas = ar_ + ai_;
+                    par++; pai++; pas++;
+                }
+            }
+        }
+        else if(transa == BLIS_TRANSPOSE)
+        {
+            gint_t idx = (i/2) * lda;
+            pa = pa + idx;
+#if KLP
+#else
+            p = 0;
+#endif
+            //A Transpose case:
+            for (gint_t ii = 0; ii < mx ; ii++)
+            {
+                gint_t idx = ii * lda;
+                gint_t sidx;
+                gint_t pidx = 0;
+                for (p = poffset;p < (k*2); p += 2)
+                {
+                    double ar0_ = *(pa + idx + p);
+                    double ai0_ = *(pa + idx + p + 1);
+
+                    sidx = (pidx/2) * mx;
+                    *(par + sidx + ii) = ar0_;
+                    *(pai + sidx + ii) = ai0_;
+                    *(pas + sidx + ii) = ar0_ + ai0_;
+                    pidx += 2;
+
+                }
+            }
+        }
+        else if(transa == BLIS_CONJ_TRANSPOSE)
+        {
+            gint_t idx = (i/2) * lda;
+            pa = pa + idx;
+#if KLP
+#else
+            p = 0;
+#endif
+            //A conjugate Transpose case:
+            for (gint_t ii = 0; ii < mx ; ii++)
+            {
+                gint_t idx = ii * lda;
+                gint_t sidx;
+                gint_t pidx = 0;
+                for (p = poffset;p < (k*2); p += 2)
+                {
+                    double ar0_ = *(pa + idx + p);
+                    double ai0_ = -(*(pa + idx + p + 1));
+
+                    sidx = (pidx/2) * mx;
+                    *(par + sidx + ii) = ar0_;
+                    *(pai + sidx + ii) = ai0_;
+                    *(pas + sidx + ii) = ar0_ + ai0_;
+                    pidx += 2;
+
+                }
+            }
+        }
+    }//mx==1
+}
+
diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.h b/kernels/zen/3/bli_gemm_sqp_kernels.h
new file mode 100644
index 0000000000..588981fad0
--- /dev/null
+++ b/kernels/zen/3/bli_gemm_sqp_kernels.h
@@ -0,0 +1,65 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2021, Advanced Micro Devices, Inc.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+/* square packed (sqp) kernels */
+#define KLP 1 // k loop partition.
+
+/* sqp dgemm core kernels, targeted mainly for square sizes by default.
+   sqp framework allows tuning for other shapes. */
+inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc);
+inc_t bli_sqp_dgemm_kernel_mxn(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, gint_t mx);
+
+//add and sub kernels
+void bli_add_m(gint_t m,gint_t n,double* w,double* c);
+void bli_sub_m(gint_t m, gint_t n, double* w, double* c);
+
+//packing kernels
+//Pack A with alpha multiplication
+void bli_sqp_prepackA(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx);
+
+void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha);
+void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha);
+void bli_prepackA_G4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx);
+void bli_prepackA_L4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx);
+void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha);
+
+/* Pack real and imaginary parts into separate buffers and also multiply by the multiplication factor */
+void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx);
+void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx);
+void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, trans_t transa, gint_t mx, gint_t p);
\ No newline at end of file
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index 161bcef1aa..a93f89cfa7 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -115,7 +115,7 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8 )
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8 )
 
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4 )
-GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 )
+GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 )
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4 )
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4 )
 GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4 )
@@ -199,3 +199,26 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n )
 GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n )
 GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 )
 GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 )
+
+err_t bli_dgemm_small
+     (
+       obj_t*  alpha,
+       obj_t*  a,
+       obj_t*  b,
+       obj_t*  beta,
+       obj_t*  c,
+       cntx_t* cntx,
+       cntl_t* cntl
+     );
+
+// gemm implementation tailored to square matrix sizes
+err_t bli_gemm_sqp
+     (
+       obj_t*  alpha,
+       obj_t*  a,
+       obj_t*  b,
+       obj_t*  beta,
+       obj_t*  c,
+       cntx_t* cntx,
+       cntl_t* cntl
+     );
diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c
index 65f910f9b1..729b7aee5f 100644
--- a/testsuite/src/test_gemm.c
+++ b/testsuite/src/test_gemm.c
@@ -35,7 +35,7 @@
 #include "blis.h"
 #include "test_libblis.h"
 
-
+#define TEST_SQP 0 // Enable to test the sqp path.
 // Static variables.
 static char* op_str = "gemm";
@@ -243,6 +243,18 @@ void libblis_test_gemm_experiment
                               sc_str[0], m, n, &c_save );
 
     // Set alpha and beta.
+#if TEST_SQP
+    if ( bli_obj_is_real( &c ) )
+    {
+        bli_setsc( 1.0, 0.0, &alpha );
+        bli_setsc( 1.0, 0.0, &beta );
+    }
+    else
+    {
+        bli_setsc( 1.0, 0.0, &alpha );
+        bli_setsc( 1.0, 0.0, &beta );
+    }
+#else
     if ( bli_obj_is_real( &c ) )
     {
         bli_setsc( 1.2, 0.0, &alpha );
@@ -253,6 +265,7 @@ void libblis_test_gemm_experiment
         bli_setsc( 1.2, 0.8, &alpha );
         bli_setsc( 0.9, 1.0, &beta );
     }
+#endif
 
 #if 0
     //bli_setm( &BLIS_ONE, &a );
@@ -270,7 +283,7 @@ void libblis_test_gemm_experiment
     bli_obj_set_conjtrans( transa, &a );
    bli_obj_set_conjtrans( transb, &b );
 
-    // Repeat the experiment n_repeats times and record results.
+    // Repeat the experiment n_repeats times and record results.
     for ( i = 0; i < n_repeats; ++i )
     {
        bli_copym( &c_save, &c );
@@ -399,7 +412,7 @@ void libblis_test_gemm_md
     bli_obj_set_conjtrans( transa, &a );
     bli_obj_set_conjtrans( transb, &b );
 
-    // Repeat the experiment n_repeats times and record results.
+    // Repeat the experiment n_repeats times and record results.
     for ( i = 0; i < n_repeats; ++i )
     {
        bli_copym( &c_save, &c );
@@ -457,8 +470,17 @@ bli_printm( "c", c, "%5.2f", "" );
 //if ( bli_obj_length( b ) == 16 &&
 //     bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR )
 //bli_printm( "c before", c, "%6.3f", "" );
+
+#if TEST_SQP
+    if(bli_gemm_sqp(alpha,a,b,beta,c,NULL,NULL)!=BLIS_SUCCESS)
+    {
+        printf("\n configuration not supported or failed while calling bli_gemm_sqp."
+               "\n falling back to bli_gemm API");
     bli_gemm( alpha, a, b, beta, c );
-    //bls_gemm( alpha, a, b, beta, c );
+    }
+#else//TEST_SQP
+    bli_gemm( alpha, a, b, beta, c );
+#endif//TEST_SQP
 #if 0
 if ( bli_obj_dt( c ) == BLIS_DCOMPLEX )
 bli_printm( "c after", c, "%6.3f", "" );