From 231a464a02fd3e07cda063dfaa1b1b200599e2c9 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Tue, 15 Dec 2020 12:55:53 +0530 Subject: [PATCH 01/13] sup zgemm improvement 1. In zgemm, mkernel outperforms nkernel for both m > n, and n > m. 2. Irrespective of mu and nu sizes, mkernel is forced for zgemm based on analysis done. Change-Id: Iafb7ddb2519c17cf2225da84d6cc74ed985cc21e AMD-Internal: [CPUPL-1352] --- frame/3/bli_l3_sup_int.c | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index e54e01d7c7..e1deb32907 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -181,6 +181,13 @@ err_t bli_gemmsup_int if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; + // In zgemm, mkernel outperforms nkernel for both m > n and n < m. + // mkernel is forced for zgemm. + if(bli_is_dcomplex(dt)) + { + use_bp = TRUE;//mkernel + } + // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. From 0abc6741d7a0681c5be11a00df952e28c3c7f674 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 3 Feb 2021 15:19:51 +0530 Subject: [PATCH 02/13] gemm_sqp(gemm_squarePacked): 3m_sqp and dgemm_sqp 1. SquarePacked algorithm focuses on efficient zgemm/dgemm implementation for square matrix sizes (m=k=n) 2. Variation of 3m algorithm (3m_sqp) is implemented to allow single load and store of C matrix in kernel. 3. Currently the method supports only m multiple of 8. Residues cases to be implemented later. 4. dgemm Real kernel (dgemm_sqp) implementation without alpha, beta multiple is done, since real alpha and beta scaling are in 3m_sqp framework. 5. gemm_sqp supports dgemm when alpha = +/-1.0 and beta = 1.0. Change-Id: I49becaf6079da4be29be5b06057ff4e50770a7d8 AMD-Internal: [CPUPL-1352] --- frame/compat/bla_gemm.c | 1 - kernels/zen/3/CMakeLists.txt | 11 + kernels/zen/3/bli_gemm_sqp.c | 1025 +++++++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 25 +- testsuite/src/test_gemm.c | 28 +- 5 files changed, 1084 insertions(+), 6 deletions(-) create mode 100644 kernels/zen/3/CMakeLists.txt create mode 100644 kernels/zen/3/bli_gemm_sqp.c diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index e04e48cf50..e374fc8d56 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -35,7 +35,6 @@ #include "blis.h" - // // Define BLAS-to-BLIS interfaces. // diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt new file mode 100644 index 0000000000..7363f7f173 --- /dev/null +++ b/kernels/zen/3/CMakeLists.txt @@ -0,0 +1,11 @@ +##Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_syrk_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c + ) + +add_subdirectory(sup) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c new file mode 100644 index 0000000000..ca2b339c58 --- /dev/null +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -0,0 +1,1025 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + +#define BLIS_LOADFIRST 0 +#define ENABLE_PREFETCH 1 + +#define MX8 8 +#define DEBUG_3M_SQP 0 + +typedef struct { + siz_t data_size; + siz_t size; + void* alignedBuf; +}mem_block; + +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA); +static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha); + +/* +* The bli_gemm_sqp (square packed) function would focus of square matrix sizes, where m=n=k. +* Custom 8mxn block kernels with single load and store of C matrix, to perform gemm computation. +*/ +err_t bli_gemm_sqp + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + // if row major format return. + if ((bli_obj_row_stride( a ) != 1) || + (bli_obj_row_stride( b ) != 1) || + (bli_obj_row_stride( c ) != 1)) + { + return BLIS_INVALID_ROW_STRIDE; + } + + if(bli_obj_has_conj(a) || bli_obj_has_conj(b)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (bli_obj_has_trans( b )) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + num_t dt = bli_obj_dt(c); + gint_t m = bli_obj_length( c ); // number of rows of Matrix C + gint_t n = bli_obj_width( c ); // number of columns of Matrix C + gint_t k = bli_obj_length( b ); // number of rows of Matrix B + + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. 
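+    // Note: lda/ldb/ldc are column strides in units of the storage datatype;
+    // for the dcomplex path they are doubled inside bli_zgemm_sqp_m8 so that the
+    // real-domain kernels can address the interleaved (real,imag) pairs as doubles.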
+ guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + + if((m==0)||(n==0)||(k==0)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + if((dt != BLIS_DCOMPLEX)&&(dt != BLIS_DOUBLE)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + + bool isTransA = false; + if (bli_obj_has_trans( a )) + { + isTransA = true; + } + + dim_t m8rem = m - ((m>>3)<<3); + if(m8rem!=0) + { + /* Residue kernel m4 and m1 to be implemented */ + return BLIS_NOT_YET_IMPLEMENTED; + } + + double* ap = ( double* )bli_obj_buffer( a ); + double* bp = ( double* )bli_obj_buffer( b ); + double* cp = ( double* )bli_obj_buffer( c ); + if(dt==BLIS_DCOMPLEX) + { + dcomplex* alphap = ( dcomplex* )bli_obj_buffer( alpha ); + dcomplex* betap = ( dcomplex* )bli_obj_buffer( beta ); + + //alpha and beta both real are implemented. alpha and beta with imaginary component to be implemented. + double alpha_real = alphap->real; + double alpha_imag = alphap->imag; + double beta_real = betap->real; + double beta_imag = betap->imag; + if( (alpha_imag!=0)||(beta_imag!=0) ) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA); + } + else if(dt == BLIS_DOUBLE) + { + double *alpha_cast, *beta_cast; + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); + + if((*beta_cast)!=1.0) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + if(((*alpha_cast)!=1.0)&&((*alpha_cast)!=-1.0)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + return bli_dgemm_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast)); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + return BLIS_NOT_YET_IMPLEMENTED; +}; + +/* core dgemm kernel 8mx5n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; + + for (j = 0; j <= (n - 5); j += 5) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; + +#if ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = 
_mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + cx4 = _mm256_fmadd_pd(av0, bv4, cx4); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); +#endif + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + pc += ldc5;pb += ldb5; + } + + return j; +} + +/* residue dgemm kernel 8mx4n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; + + for (; j <= (n - 4); j += 4) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + // better kernel to be written since more register are available. 
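+                // Each k iteration is a rank-1 update: one scalar from each of the
+                // four B columns is broadcast and FMA-accumulated against 8 packed
+                // A values, updating the 8x4 C block held in the cv*/cx* registers.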
+ bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + pc += ldc4;pb += ldb4; + }// j loop 4 multiple + return j; +} + +/* residue dgemm kernel 8mx3n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2; + __m256d cv0, cv1, cv2; + __m256d cx0, cx1, cx2; + double* pb, * pc; + + pb = b; + pc = c; + + inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; + + for (; j <= (n - 3); j += 3) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + } + } + + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + pc += ldc3;pb += ldb3; + }// j loop 3 multiple + return j; +} + +/* residue dgemm kernel 8mx2n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1; + __m256d cv0, cv1; + __m256d cx0, cx1; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; + + for (; j <= (n - 2); j += 2) { + double* pcldc = pc + ldc; + double* pbldb = pb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + pc += ldc2;pb += ldb2; + }// j loop 2 multiple + return j; +} + +/* residue dgemm kernel 8mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + __m256d cx0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + +/* Ax8 packing subroutine */ +void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) +{ + __m256d av0, av1, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += MX8; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += MX8; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * MX8; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * MX8; + *(aPacked + sidx + i) = -ar_; + } + } + } + } +} + +/* A8x4 packing subroutine */ +void bli_prepackA_8x4(double* pa, double* aPacked, gint_t k, guint_t lda) +{ + __m256d av00, av10; + __m256d av01, av11; + __m256d av02, av12; + __m256d av03, 
av13; + + for (gint_t p = 0; p < k; p += 4) { + av00 = _mm256_loadu_pd(pa); av10 = _mm256_loadu_pd(pa + 4); pa += lda; + av01 = _mm256_loadu_pd(pa); av11 = _mm256_loadu_pd(pa + 4); pa += lda; + av02 = _mm256_loadu_pd(pa); av12 = _mm256_loadu_pd(pa + 4); pa += lda; + av03 = _mm256_loadu_pd(pa); av13 = _mm256_loadu_pd(pa + 4); pa += lda; + + _mm256_storeu_pd(aPacked, av00); _mm256_storeu_pd(aPacked + 4, av10); + _mm256_storeu_pd(aPacked + 8, av01); _mm256_storeu_pd(aPacked + 12, av11); + _mm256_storeu_pd(aPacked + 16, av02); _mm256_storeu_pd(aPacked + 20, av12); + _mm256_storeu_pd(aPacked + 24, av03); _mm256_storeu_pd(aPacked + 28, av13); + + aPacked += 32; + } +} + +/* dgemm real kernel, which handles m multiple of 8. +m multiple of 4 and 1 to be implemented later */ +static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha) +{ + double* aPacked; + double* aligned = NULL; + + bool pack_on = false; + if((m!=MX8)||(m!=lda)||isTransA) + { + pack_on = true; + } + + if(pack_on==true) + { + aligned = (double*)bli_malloc_user(sizeof(double) * k * MX8); + } + + for (gint_t i = 0; i < m; i += MX8) //this loop can be threaded. no of workitems = m/8 + { + inc_t j = 0; + double* ci = c + i; + if(pack_on==true) + { + aPacked = aligned; + double *pa = a + i; + if(isTransA==true) + { + pa = a + (i*lda); + } + bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); + //bli_prepackA_8x4(a + i, aPacked, k, lda); + } + else + { + aPacked = a+i; + } + + j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + if (j <= n - 4) + { + j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 3) + { + j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 2) + { + j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 1) + { + j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + } + + if(pack_on==true) + { + bli_free_user(aligned); + } + + return BLIS_SUCCESS; +} + +gint_t bli_getaligned(mem_block* mem_req) +{ + + guint_t memSize = mem_req->data_size * mem_req->size; + if (memSize == 0) + { + return -1; + } + memSize += 128; + + mem_req->alignedBuf = bli_malloc_user( memSize ); + if (mem_req->alignedBuf == NULL) + { + return -1; + } + return 0; +} + +gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, mem_block *msx) +{ + //allocate workspace + mxr->data_size = mxi->data_size = msx->data_size = sizeof(double); + mxr->size = mxi->size = n * k; + msx->size = n * k; + mxr->alignedBuf = mxi->alignedBuf = msx->alignedBuf = NULL; + + if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) + { + bli_free_user(mxr->alignedBuf); + bli_free_user(mxi->alignedBuf); + bli_free_user(msx->alignedBuf); + return -1; + } + return 0; +} + +void bli_add_m(gint_t m,gint_t n,double* w,double* c) +{ + double* pc = c; + double* pw = w; + for (gint_t i = 0; i < m*n; i++) + { + *pc = *pc + *pw; + pc++; pw++; + } +} + +void bli_sub_m(gint_t m, gint_t n, double* w, double* c) +{ + double* pc = c; + double* pw = w; + for (gint_t i = 0; i < m * n; i++) + { + *pc = *pc - *pw; + pc++; pw++; + } +} + +/****************************************************************/ +/* mmm_sqp implementation, which calls dgemm_sqp as real kernel */ 
+/****************************************************************/ + + +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA) +{ + /* B matrix */ + double* br, * bi, * bs; + mem_block mbr, mbi, mbs; + if(bli_allocateWorkspace(n, k, &mbr, &mbi, &mbs)!=0) + { + return BLIS_FAILURE; + } + br = (double*)mbr.alignedBuf; + bi = (double*)mbi.alignedBuf; + bs = (double*)mbs.alignedBuf; + + //multiply lda, ldb and ldc by 2 to account for real and imaginary components per dcomplex. + lda = lda * 2; + ldb = ldb * 2; + ldc = ldc * 2; + +//debug to be removed. +#if DEBUG_3M_SQP + double ax[8][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, + {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, + {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, + {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3,7.3,-7.3,8.3,-8.3}, + + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; + + + double bx[6][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, + {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, + {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, + + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; + + double cx[8][12] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60}, + {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1}, + {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2}, + {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3}, + + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, + {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6} }; + + b = &bx[0][0]; + a = &ax[0][0]; + c = &cx[0][0]; +#endif + + /* Split b (br, bi) and + compute bs = br + bi */ + double* pbr = br; + double* pbi = bi; + double* pbs = bs; + + gint_t j, p; + + /* b matrix real and imag packing and compute to be vectorized. 
*/ + if((alpha ==1.0)||(alpha==-1.0)) + { + if(alpha ==1.0) + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = b[(j * ldb) + p]; + double bi_ = b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = -b[(j * ldb) + p]; + double bi_ = -b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = alpha * b[(j * ldb) + p]; + double bi_ = alpha * b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + + /* Workspace memory allocation currently done dynamically + This needs to be taken from already allocated memory pool in application for better performance */ + /* A matrix */ + double* ar, * ai, * as; + mem_block mar, mai, mas; + if(bli_allocateWorkspace(8, k, &mar, &mai, &mas) !=0) + { + return BLIS_FAILURE; + } + ar = (double*)mar.alignedBuf; + ai = (double*)mai.alignedBuf; + as = (double*)mas.alignedBuf; + + + /* w matrix */ + double* w; + mem_block mw; + mw.data_size = sizeof(double); + mw.size = 8 * n; + if (bli_getaligned(&mw) != 0) + { + return BLIS_FAILURE; + } + w = (double*)mw.alignedBuf; + + /* cr matrix */ + double* cr; + mem_block mcr; + mcr.data_size = sizeof(double); + mcr.size = 8 * n; + if (bli_getaligned(&mcr) != 0) + { + return BLIS_FAILURE; + } + cr = (double*)mcr.alignedBuf; + + + /* ci matrix */ + double* ci; + mem_block mci; + mci.data_size = sizeof(double); + mci.size = 8 * n; + if (bli_getaligned(&mci) != 0) + { + return BLIS_FAILURE; + } + ci = (double*)mci.alignedBuf; + + for (inc_t i = 0; i < (2*m); i += (2*MX8)) //this loop can be threaded. + { + ////////////// operation 1 ///////////////// + + /* Split a (ar, ai) and + compute as = ar + ai */ + double* par = ar; + double* pai = ai; + double* pas = as; + + /* a matrix real and imag packing and compute to be vectorized. */ + if(isTransA==false) + { + //A No transpose case: + for (gint_t p = 0; p < k; p += 1) { + for (gint_t ii = 0; ii < (2*MX8) ; ii += 2) { //real + imag : Rkernel needs 8 elements each. + double ar_ = a[(p * lda) + i + ii]; + double ai_ = a[(p * lda) + i + ii+1]; + *par = ar_; + *pai = ai_; + *pas = ar_ + ai_; + par++; pai++; pas++; + } + } + } + else + { + //A Transpose case: + for (gint_t ii = 0; ii < MX8 ; ii++) + { + gint_t idx = ((i/2) + ii) * lda; + for (gint_t s = 0; s < (k*2); s += 2) + { + double ar_ = a[ idx + s]; + double ai_ = a[ idx + s + 1]; + gint_t sidx = s * (MX8/2); + *(par + sidx + ii) = ar_; + *(pai + sidx + ii) = ai_; + *(pas + sidx + ii) = ar_ + ai_; + } + } + } + + double* pcr = cr; + double* pci = ci; + + //Split Cr and Ci and beta multiplication done. 
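+        // The current 8xn block of C is read exactly once here: it is split into
+        // separate real (cr) and imaginary (ci) planes with beta already applied,
+        // updated in place by the three real gemm calls below, and interleaved back
+        // into C only after operation 3 (single load and store of C per block).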
+ if((beta ==1.0)||(beta==-1.0)) + { + if(beta ==1.0) + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (2*MX8); ii += 2) + { + double cr_ = c[(j * ldc) + i + ii]; + double ci_ = c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + else + { + //beta = -1.0 + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (2*MX8); ii += 2) + { + double cr_ = -c[(j * ldc) + i + ii]; + double ci_ = -c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (2*MX8); ii += 2) + { + double cr_ = beta*c[(j * ldc) + i + ii]; + double ci_ = beta*c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + + //Ci := rgemm( SA, SB, Ci ) + bli_dgemm_m8(MX8, n, k, as, MX8, bs, k, ci, MX8, false, 1.0); + + + + ////////////// operation 2 ///////////////// + //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn + double* wr = w; + for (j = 0; j < n; j++) { + for (gint_t ii = 0; ii < MX8; ii += 1) { + *wr = 0; + wr++; + } + } + wr = w; + + bli_dgemm_m8(MX8, n, k, ar, MX8, br, k, wr, MX8, false, 1.0); + //Cr : = addm(Wr, Cr) + bli_add_m(MX8, n, wr, cr); + //Ci : = subm(Wr, Ci) + bli_sub_m(MX8, n, wr, ci); + + + + + ////////////// operation 3 ///////////////// + //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn + double* wi = w; + for (j = 0; j < n; j++) { + for (gint_t ii = 0; ii < MX8; ii += 1) { + *wi = 0; + wi++; + } + } + wi = w; + + bli_dgemm_m8(MX8, n, k, ai, MX8, bi, k, wi, MX8, false, 1.0); + //Cr : = subm(Wi, Cr) + bli_sub_m(MX8, n, wi, cr); + //Ci : = subm(Wi, Ci) + bli_sub_m(MX8, n, wi, ci); + + + pcr = cr; + pci = ci; + + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (2*MX8); ii += 2) + { + c[(j * ldc) + i + ii] = *pcr; + c[(j * ldc) + i + ii + 1] = *pci; + pcr++; pci++; + } + } + + } + +//debug to be removed. 
+#if DEBUG_3M_SQP + for (gint_t jj = 0; jj < n;jj++) + { + for (gint_t ii = 0; ii < m;ii++) + { + printf("( %4.2lf %4.2lf) ", *cr, *ci); + cr++;ci++; + } + printf("\n"); + } +#endif + + /* free workspace buffers */ + bli_free_user(mbr.alignedBuf); + bli_free_user(mbi.alignedBuf); + bli_free_user(mbs.alignedBuf); + bli_free_user(mar.alignedBuf); + bli_free_user(mai.alignedBuf); + bli_free_user(mas.alignedBuf); + bli_free_user(mw.alignedBuf); + bli_free_user(mcr.alignedBuf); + bli_free_user(mci.alignedBuf); + + return BLIS_SUCCESS; +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 161bcef1aa..a93f89cfa7 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -115,7 +115,7 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4 ) -GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4 ) @@ -199,3 +199,26 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) + +err_t bli_dgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +// gemm square matrix size friendly implementation +err_t bli_gemm_sqp + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index 65f910f9b1..b6d8125054 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -35,7 +35,7 @@ #include "blis.h" #include "test_libblis.h" - +#define TEST_SQP 0// ENABLE to test sqp path. // Static variables. static char* op_str = "gemm"; @@ -243,6 +243,18 @@ void libblis_test_gemm_experiment sc_str[0], m, n, &c_save ); // Set alpha and beta. +#if TEST_SQP + if ( bli_obj_is_real( &c ) ) + { + bli_setsc( 1.0, 0.0, &alpha ); + bli_setsc( 1.0, 0.0, &beta ); + } + else + { + bli_setsc( 1.0, 0.0, &alpha ); + bli_setsc( 1.0, 0.0, &beta ); + } +#else if ( bli_obj_is_real( &c ) ) { bli_setsc( 1.2, 0.0, &alpha ); @@ -253,6 +265,7 @@ void libblis_test_gemm_experiment bli_setsc( 1.2, 0.8, &alpha ); bli_setsc( 0.9, 1.0, &beta ); } +#endif #if 0 //bli_setm( &BLIS_ONE, &a ); @@ -270,7 +283,7 @@ void libblis_test_gemm_experiment bli_obj_set_conjtrans( transa, &a ); bli_obj_set_conjtrans( transb, &b ); - // Repeat the experiment n_repeats times and record results. + // Repeat the experiment n_repeats times and record results. for ( i = 0; i < n_repeats; ++i ) { bli_copym( &c_save, &c ); @@ -399,7 +412,7 @@ void libblis_test_gemm_md bli_obj_set_conjtrans( transa, &a ); bli_obj_set_conjtrans( transb, &b ); - // Repeat the experiment n_repeats times and record results. + // Repeat the experiment n_repeats times and record results. 
for ( i = 0; i < n_repeats; ++i ) { bli_copym( &c_save, &c ); @@ -457,8 +470,15 @@ bli_printm( "c", c, "%5.2f", "" ); //if ( bli_obj_length( b ) == 16 && // bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) //bli_printm( "c before", c, "%6.3f", "" ); + +#if TEST_SQP + if(bli_gemm_sqp(alpha,a,b,beta,c,NULL,NULL)!=BLIS_SUCCESS) + { bli_gemm( alpha, a, b, beta, c ); - //bls_gemm( alpha, a, b, beta, c ); + } +#else//TEST_SQP + bli_gemm( alpha, a, b, beta, c ); +#endif//TEST_SQP #if 0 if ( bli_obj_dt( c ) == BLIS_DCOMPLEX ) bli_printm( "c after", c, "%6.3f", "" ); From 5dc5ffa06b72e59bbdfa549871509d3846a547d4 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Mon, 15 Feb 2021 09:35:22 +0530 Subject: [PATCH 03/13] sqp commenting 1. Added comments. AMD-Internal: [CPUPL-1429] Change-Id: Ie37e24e58cd8bf836038a2258ebd09c3912fab9e --- kernels/zen/3/bli_gemm_sqp.c | 97 +++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index ca2b339c58..e010e0105d 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -35,9 +35,9 @@ #include "immintrin.h" #define BLIS_LOADFIRST 0 -#define ENABLE_PREFETCH 1 +#define BLIS_ENABLE_PREFETCH 1 -#define MX8 8 +#define BLIS_MX8 8 #define DEBUG_3M_SQP 0 typedef struct { @@ -50,8 +50,11 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha); /* -* The bli_gemm_sqp (square packed) function would focus of square matrix sizes, where m=n=k. -* Custom 8mxn block kernels with single load and store of C matrix, to perform gemm computation. +* The bli_gemm_sqp (square packed) function performs dgemm and 3m zgemm. +* It focuses on square matrix sizes, where m=n=k. But supports non-square matrix sizes as well. +* Currently works for m multiple of 8 & column major storage and kernels. It has custom dgemm +* 8mxn block column preferred kernels with single load and store of C matrix to perform dgemm +* , which is also used as real kernel in 3m complex gemm computation. 
*/ err_t bli_gemm_sqp ( @@ -133,6 +136,7 @@ err_t bli_gemm_sqp { return BLIS_NOT_YET_IMPLEMENTED; } + /* 3m zgemm implementation for C = AxB and C = AtxB */ return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA); } else if(dt == BLIS_DOUBLE) @@ -149,6 +153,7 @@ err_t bli_gemm_sqp { return BLIS_NOT_YET_IMPLEMENTED; } + /* dgemm implementation with 8mx5n major kernel and column preferred storage */ return bli_dgemm_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast)); } @@ -156,7 +161,10 @@ err_t bli_gemm_sqp return BLIS_NOT_YET_IMPLEMENTED; }; -/* core dgemm kernel 8mx5n with single load and store of C matrix block +/************************************************************************************************************/ +/************************** dgemm kernels (8mxn) column preffered ******************************************/ +/************************************************************************************************************/ +/* Main dgemm kernel 8mx5n with single load and store of C matrix block alpha = +/-1 and beta = +/-1,0 handled while packing.*/ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) { @@ -178,7 +186,7 @@ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; -#if ENABLE_PREFETCH +#if BLIS_ENABLE_PREFETCH _mm_prefetch((char*)(pc), _MM_HINT_T0); _mm_prefetch((char*)(pcldc), _MM_HINT_T0); _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); @@ -193,6 +201,7 @@ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); #endif + /* C matrix column major load */ #if BLIS_LOADFIRST cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); @@ -246,6 +255,7 @@ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); #endif + /* C matrix column major store */ _mm256_storeu_pd(pc, cv0); _mm256_storeu_pd(pc + 4, cx0); @@ -484,7 +494,7 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT for (gint_t p = 0; p < k; p += 1) { av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += MX8; + aPacked += BLIS_MX8; } } else if(alpha==-1.0) @@ -494,7 +504,7 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += MX8; + aPacked += BLIS_MX8; } } } @@ -503,13 +513,13 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT if(alpha==1.0) { //A Transpose case: - for (gint_t i = 0; i < MX8 ; i++) + for (gint_t i = 0; i < BLIS_MX8 ; i++) { gint_t idx = i * lda; for (gint_t p = 0; p < k; p ++) { double ar_ = *(pa+idx+p); - gint_t sidx = p * MX8; + gint_t sidx = p * BLIS_MX8; *(aPacked + 
sidx + i) = ar_; } } @@ -517,13 +527,13 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT else if(alpha==-1.0) { //A Transpose case: - for (gint_t i = 0; i < MX8 ; i++) + for (gint_t i = 0; i < BLIS_MX8 ; i++) { gint_t idx = i * lda; for (gint_t p = 0; p < k; p ++) { double ar_ = *(pa+idx+p); - gint_t sidx = p * MX8; + gint_t sidx = p * BLIS_MX8; *(aPacked + sidx + i) = -ar_; } } @@ -554,25 +564,32 @@ void bli_prepackA_8x4(double* pa, double* aPacked, gint_t k, guint_t lda) } } -/* dgemm real kernel, which handles m multiple of 8. -m multiple of 4 and 1 to be implemented later */ +/************************************************************************************************************/ +/***************************************** dgemm_sqp implementation******************************************/ +/************************************************************************************************************/ +/* dgemm_sqp implementation packs A matrix based on lda and m size. dgemm_sqp focuses mainly on square matrixes + but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. + C = AxB and C = AtxB is handled in the design. AtxB case is done by transposing A matrix while packing A. + In majority of use-case, alpha are +/-1, so instead of explicitly multiplying alpha its done + during packing itself by changing sign. +*/ static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha) { double* aPacked; double* aligned = NULL; bool pack_on = false; - if((m!=MX8)||(m!=lda)||isTransA) + if((m!=BLIS_MX8)||(m!=lda)||isTransA) { pack_on = true; } if(pack_on==true) { - aligned = (double*)bli_malloc_user(sizeof(double) * k * MX8); + aligned = (double*)bli_malloc_user(sizeof(double) * k * BLIS_MX8); } - for (gint_t i = 0; i < m; i += MX8) //this loop can be threaded. no of workitems = m/8 + for (gint_t i = 0; i < m; i += BLIS_MX8) //this loop can be threaded. no of workitems = m/8 { inc_t j = 0; double* ci = c + i; @@ -677,11 +694,13 @@ void bli_sub_m(gint_t m, gint_t n, double* w, double* c) } } -/****************************************************************/ -/* mmm_sqp implementation, which calls dgemm_sqp as real kernel */ -/****************************************************************/ - - +/************************************************************************************************************/ +/***************************************** 3m_sqp implementation ******************************************/ +/************************************************************************************************************/ +/* 3m_sqp implementation packs A, B and C matrix and uses dgemm_sqp real kernel implementation. + 3m_sqp focuses mainly on square matrixes but also supports non-square matrix. Current support is limiteed to + m multiple of 8 and column storage. +*/ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA) { /* B matrix */ @@ -845,7 +864,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l } ci = (double*)mci.alignedBuf; - for (inc_t i = 0; i < (2*m); i += (2*MX8)) //this loop can be threaded. + for (inc_t i = 0; i < (2*m); i += (2*BLIS_MX8)) //this loop can be threaded. 
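+        // Per 8-row block the 3m scheme computes, with As = Ar + Ai and Bs = Br + Bi:
+        //   operation 1:  Ci := rgemm(As, Bs, beta*Ci)
+        //   operation 2:  Wr := rgemm(Ar, Br, 0);  Cr += Wr;  Ci -= Wr
+        //   operation 3:  Wi := rgemm(Ai, Bi, 0);  Cr -= Wi;  Ci -= Wi
+        // giving Cr = beta*Cr + Ar*Br - Ai*Bi and Ci = beta*Ci + Ar*Bi + Ai*Br,
+        // i.e. three real gemms in place of the four of the straightforward method
+        // (alpha's sign having already been folded into the packed B planes).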
{ ////////////// operation 1 ///////////////// @@ -860,7 +879,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l { //A No transpose case: for (gint_t p = 0; p < k; p += 1) { - for (gint_t ii = 0; ii < (2*MX8) ; ii += 2) { //real + imag : Rkernel needs 8 elements each. + for (gint_t ii = 0; ii < (2*BLIS_MX8) ; ii += 2) { //real + imag : Rkernel needs 8 elements each. double ar_ = a[(p * lda) + i + ii]; double ai_ = a[(p * lda) + i + ii+1]; *par = ar_; @@ -873,14 +892,14 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l else { //A Transpose case: - for (gint_t ii = 0; ii < MX8 ; ii++) + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) { gint_t idx = ((i/2) + ii) * lda; for (gint_t s = 0; s < (k*2); s += 2) { double ar_ = a[ idx + s]; double ai_ = a[ idx + s + 1]; - gint_t sidx = s * (MX8/2); + gint_t sidx = s * (BLIS_MX8/2); *(par + sidx + ii) = ar_; *(pai + sidx + ii) = ai_; *(pas + sidx + ii) = ar_ + ai_; @@ -898,7 +917,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l { for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < (2*MX8); ii += 2) + for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) { double cr_ = c[(j * ldc) + i + ii]; double ci_ = c[(j * ldc) + i + ii + 1]; @@ -913,7 +932,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //beta = -1.0 for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < (2*MX8); ii += 2) + for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) { double cr_ = -c[(j * ldc) + i + ii]; double ci_ = -c[(j * ldc) + i + ii + 1]; @@ -928,7 +947,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l { for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < (2*MX8); ii += 2) + for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) { double cr_ = beta*c[(j * ldc) + i + ii]; double ci_ = beta*c[(j * ldc) + i + ii + 1]; @@ -940,7 +959,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l } //Ci := rgemm( SA, SB, Ci ) - bli_dgemm_m8(MX8, n, k, as, MX8, bs, k, ci, MX8, false, 1.0); + bli_dgemm_m8(BLIS_MX8, n, k, as, BLIS_MX8, bs, k, ci, BLIS_MX8, false, 1.0); @@ -948,18 +967,18 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn double* wr = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < MX8; ii += 1) { + for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { *wr = 0; wr++; } } wr = w; - bli_dgemm_m8(MX8, n, k, ar, MX8, br, k, wr, MX8, false, 1.0); + bli_dgemm_m8(BLIS_MX8, n, k, ar, BLIS_MX8, br, k, wr, BLIS_MX8, false, 1.0); //Cr : = addm(Wr, Cr) - bli_add_m(MX8, n, wr, cr); + bli_add_m(BLIS_MX8, n, wr, cr); //Ci : = subm(Wr, Ci) - bli_sub_m(MX8, n, wr, ci); + bli_sub_m(BLIS_MX8, n, wr, ci); @@ -968,18 +987,18 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn double* wi = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < MX8; ii += 1) { + for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { *wi = 0; wi++; } } wi = w; - bli_dgemm_m8(MX8, n, k, ai, MX8, bi, k, wi, MX8, false, 1.0); + bli_dgemm_m8(BLIS_MX8, n, k, ai, BLIS_MX8, bi, k, wi, BLIS_MX8, false, 1.0); //Cr : = subm(Wi, Cr) - bli_sub_m(MX8, n, wi, cr); + bli_sub_m(BLIS_MX8, n, wi, cr); //Ci : = subm(Wi, Ci) - bli_sub_m(MX8, n, wi, ci); + bli_sub_m(BLIS_MX8, n, wi, ci); pcr = cr; @@ -987,7 +1006,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l for (j = 0; 
j < n; j++) { - for (gint_t ii = 0; ii < (2*MX8); ii += 2) + for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) { c[(j * ldc) + i + ii] = *pcr; c[(j * ldc) + i + ii + 1] = *pci; From 87c123f4f3cf2fdbe1f111b381ce9b22f0f06138 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Mon, 8 Mar 2021 16:45:33 +0530 Subject: [PATCH 04/13] 3m_sqp vectorization 1. bli_malloc modified to normal malloc and address alignment within 3m_sqp. 2. function added to pack A real,imag and sum. 3. function added to pack B real,imag and sum. 4. function added to pack C real,imag and beta handling. 4. sum and sub vectorized. AMD-Internal: [CPUPL-1352] Change-Id: I514e9efb053d529caef2de413d74d0dac2ceca54 --- kernels/zen/3/bli_gemm_sqp.c | 611 ++++++++++++++++++++++++++--------- 1 file changed, 454 insertions(+), 157 deletions(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index e010e0105d..fab6950d25 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -37,6 +37,7 @@ #define BLIS_LOADFIRST 0 #define BLIS_ENABLE_PREFETCH 1 +#define MEM_ALLOC 1//malloc performs better than bli_malloc. #define BLIS_MX8 8 #define DEBUG_3M_SQP 0 @@ -44,6 +45,7 @@ typedef struct { siz_t data_size; siz_t size; void* alignedBuf; + void* unalignedBuf; }mem_block; static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA); @@ -67,7 +69,7 @@ err_t bli_gemm_sqp cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); // if row major format return. if ((bli_obj_row_stride( a ) != 1) || @@ -88,7 +90,7 @@ err_t bli_gemm_sqp } num_t dt = bli_obj_dt(c); - gint_t m = bli_obj_length( c ); // number of rows of Matrix C + gint_t m = bli_obj_length( c ); // number of rows of Matrix C gint_t n = bli_obj_width( c ); // number of columns of Matrix C gint_t k = bli_obj_length( b ); // number of rows of Matrix B @@ -157,7 +159,7 @@ err_t bli_gemm_sqp return bli_dgemm_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast)); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -578,36 +580,36 @@ static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* aPacked; double* aligned = NULL; - bool pack_on = false; - if((m!=BLIS_MX8)||(m!=lda)||isTransA) - { - pack_on = true; - } + bool pack_on = false; + if((m!=BLIS_MX8)||(m!=lda)||isTransA) + { + pack_on = true; + } - if(pack_on==true) - { - aligned = (double*)bli_malloc_user(sizeof(double) * k * BLIS_MX8); - } + if(pack_on==true) + { + aligned = (double*)bli_malloc_user(sizeof(double) * k * BLIS_MX8); + } for (gint_t i = 0; i < m; i += BLIS_MX8) //this loop can be threaded. no of workitems = m/8 { inc_t j = 0; double* ci = c + i; - if(pack_on==true) - { - aPacked = aligned; + if(pack_on==true) + { + aPacked = aligned; double *pa = a + i; if(isTransA==true) { pa = a + (i*lda); } - bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); - //bli_prepackA_8x4(a + i, aPacked, k, lda); - } - else + bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); + //bli_prepackA_8x4(a + i, aPacked, k, lda); + } + else { - aPacked = a+i; - } + aPacked = a+i; + } j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); if (j <= n - 4) @@ -644,13 +646,24 @@ gint_t bli_getaligned(mem_block* mem_req) { return -1; } - memSize += 128; + memSize += 128;// extra 128 bytes added for alignment. 
Could be minimized to 64. +#if MEM_ALLOC + mem_req->unalignedBuf = (double*)malloc(memSize); + if (mem_req->unalignedBuf == NULL) + { + return -1; + } + int64_t address = (int64_t)mem_req->unalignedBuf; + address += (-address) & 63; //64 bytes alignment done. + mem_req->alignedBuf = (double*)address; +#else mem_req->alignedBuf = bli_malloc_user( memSize ); if (mem_req->alignedBuf == NULL) { return -1; } +#endif return 0; } @@ -676,24 +689,419 @@ void bli_add_m(gint_t m,gint_t n,double* w,double* c) { double* pc = c; double* pw = w; - for (gint_t i = 0; i < m*n; i++) + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_add_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) { *pc = *pc + *pw; pc++; pw++; } + + } void bli_sub_m(gint_t m, gint_t n, double* w, double* c) { double* pc = c; double* pw = w; - for (gint_t i = 0; i < m * n; i++) + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_sub_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) { *pc = *pc - *pw; pc++; pw++; } } +void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1; + + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) + { + for (j = 0; j < n; j++) + { + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; + } + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + pbr++; pbi++; + } + pb = pb + ldb; + } + } +} + +void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1, sum; + + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) + { + for (j = 0; j < n; j++) + { + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 
0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } +} + +void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA) +{ + __m256d av0, av1, av2, av3; + __m256d tv0, tv1, sum; + gint_t p; + if(isTransA==false) + { + pa = pa +i; + for (p = 0; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
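+            // Each iteration loads the 8 interleaved (real,imag) pairs of one packed
+            // column (16 doubles in av0..av3); the permute2f128 + unpacklo/unpackhi
+            // shuffles split them into vectors of reals and imaginaries, and their
+            // element-wise sums fill the As = Ar + Ai buffer used by the 3m scheme.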
+ #if 1 + av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + #else //method 2 + __m128d high, low, real, img, sum; + av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + high = _mm256_extractf128_pd(av0, 1); + low = _mm256_castpd256_pd128(av0); + real = _mm_shuffle_pd(low, high, 0b00); + img = _mm_shuffle_pd(low, high, 0b11); + sum = _mm_add_pd(real, img); + _mm_storeu_pd(par, real); par += 2; + _mm_storeu_pd(pai, img); pai += 2; + _mm_storeu_pd(pas, sum); pas += 2; + + high = _mm256_extractf128_pd(av1, 1); + low = _mm256_castpd256_pd128(av1); + real = _mm_shuffle_pd(low, high, 0b00); + img = _mm_shuffle_pd(low, high, 0b11); + sum = _mm_add_pd(real, img); + _mm_storeu_pd(par, real); par += 2; + _mm_storeu_pd(pai, img); pai += 2; + _mm_storeu_pd(pas, sum); pas += 2; + + high = _mm256_extractf128_pd(av2, 1); + low = _mm256_castpd256_pd128(av2); + real = _mm_shuffle_pd(low, high, 0b00); + img = _mm_shuffle_pd(low, high, 0b11); + sum = _mm_add_pd(real, img); + _mm_storeu_pd(par, real); par += 2; + _mm_storeu_pd(pai, img); pai += 2; + _mm_storeu_pd(pas, sum); pas += 2; + + high = _mm256_extractf128_pd(av3, 1); + low = _mm256_castpd256_pd128(av3); + real = _mm_shuffle_pd(low, high, 0b00); + img = _mm_shuffle_pd(low, high, 0b11); + sum = _mm_add_pd(real, img); + _mm_storeu_pd(par, real); par += 2; + _mm_storeu_pd(pai, img); pai += 2; + _mm_storeu_pd(pas, sum); pas += 2; + #endif + pa = pa + lda; + } + } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; + +#if 0 + for (int p = 0; p <= ((2*k)-8); p += 8) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
+ av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + //transpose 4x4 + tv0 = _mm256_unpacklo_pd(av0, av1); + tv1 = _mm256_unpackhi_pd(av0, av1); + tv2 = _mm256_unpacklo_pd(av2, av3); + tv3 = _mm256_unpackhi_pd(av2, av3); + + av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); + av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); + av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); + av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); + + //get real, imag and sum + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } +#endif + //A Transpose case: + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + double ar1_ = *(pa + idx + p + 2); + double ai1_ = *(pa + idx + p + 3); + + double ar2_ = *(pa + idx + p + 4); + double ai2_ = *(pa + idx + p + 5); + + double ar3_ = *(pa + idx + p + 6); + double ai3_ = *(pa + idx + p + 7); + + sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + sidx = ((p+2)/2) * BLIS_MX8; + *(par + sidx + ii) = ar1_; + *(pai + sidx + ii) = ai1_; + *(pas + sidx + ii) = ar1_ + ai1_; + + sidx = ((p+4)/2) * BLIS_MX8; + *(par + sidx + ii) = ar2_; + *(pai + sidx + ii) = ai2_; + *(pas + sidx + ii) = ar2_ + ai2_; + + sidx = ((p+6)/2) * BLIS_MX8; + *(par + sidx + ii) = ar3_; + *(pai + sidx + ii) = ai3_; + *(pas + sidx + ii) = ar3_ + ai3_; + + } + + for (; p < (k*2); p += 2) + { + double ar_ = *(pa + idx + p); + double ai_ = *(pa + idx + p + 1); + gint_t sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar_; + *(pai + sidx + ii) = ai_; + *(pas + sidx + ii) = ar_ + ai_; + } + } + } +} + /************************************************************************************************************/ /***************************************** 3m_sqp implementation ******************************************/ /************************************************************************************************************/ @@ -755,66 +1163,16 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l c = &cx[0][0]; #endif - /* Split b (br, bi) and + /* Split b (br, bi) and compute bs = br + bi */ double* pbr = br; double* pbi = bi; double* pbs = bs; - gint_t j, p; + gint_t j; - /* b matrix real and imag packing and compute to be vectorized. 
*/ - if((alpha ==1.0)||(alpha==-1.0)) - { - if(alpha ==1.0) - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = b[(j * ldb) + p]; - double bi_ = b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = -b[(j * ldb) + p]; - double bi_ = -b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = alpha * b[(j * ldb) + p]; - double bi_ = alpha * b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } + /* b matrix real and imag packing and compute. */ + bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha); /* Workspace memory allocation currently done dynamically This needs to be taken from already allocated memory pool in application for better performance */ @@ -874,89 +1232,15 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* pai = ai; double* pas = as; - /* a matrix real and imag packing and compute to be vectorized. */ - if(isTransA==false) - { - //A No transpose case: - for (gint_t p = 0; p < k; p += 1) { - for (gint_t ii = 0; ii < (2*BLIS_MX8) ; ii += 2) { //real + imag : Rkernel needs 8 elements each. - double ar_ = a[(p * lda) + i + ii]; - double ai_ = a[(p * lda) + i + ii+1]; - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else - { - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ((i/2) + ii) * lda; - for (gint_t s = 0; s < (k*2); s += 2) - { - double ar_ = a[ idx + s]; - double ai_ = a[ idx + s + 1]; - gint_t sidx = s * (BLIS_MX8/2); - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - } - } - } + /* a matrix real and imag packing and compute. */ + bli_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA); double* pcr = cr; double* pci = ci; //Split Cr and Ci and beta multiplication done. 
- if((beta ==1.0)||(beta==-1.0)) - { - if(beta ==1.0) - { - for (j = 0; j < n; j++) - { - for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) - { - double cr_ = c[(j * ldc) + i + ii]; - double ci_ = c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } - } - } - else - { - //beta = -1.0 - for (j = 0; j < n; j++) - { - for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) - { - double cr_ = -c[(j * ldc) + i + ii]; - double ci_ = -c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) - { - double cr_ = beta*c[(j * ldc) + i + ii]; - double ci_ = beta*c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } - } - } + double* pc = c + i; + bli_packX_real_imag(pc, n, BLIS_MX8, ldc, pcr, pci, beta); //Ci := rgemm( SA, SB, Ci ) bli_dgemm_m8(BLIS_MX8, n, k, as, BLIS_MX8, bs, k, ci, BLIS_MX8, false, 1.0); @@ -1028,7 +1312,20 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l printf("\n"); } #endif +#if MEM_ALLOC + free(mar.unalignedBuf); + free(mai.unalignedBuf); + free(mas.unalignedBuf); + + free(mw.unalignedBuf); + free(mcr.unalignedBuf); + free(mci.unalignedBuf); + + free(mbr.unalignedBuf); + free(mbi.unalignedBuf); + free(mbs.unalignedBuf); +#else /* free workspace buffers */ bli_free_user(mbr.alignedBuf); bli_free_user(mbi.alignedBuf); @@ -1039,6 +1336,6 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l bli_free_user(mw.alignedBuf); bli_free_user(mcr.alignedBuf); bli_free_user(mci.alignedBuf); - +#endif return BLIS_SUCCESS; } \ No newline at end of file From 2bb4e873f75762dac13ead02cad3d9c7100d8ed6 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Fri, 12 Mar 2021 15:15:24 +0530 Subject: [PATCH 05/13] disabled zgemm induced and gemm sqp temporarily. 1. mx1, mx4 kernel addition and framework modification. 2. 8mx6n kernel addition. 3. NULL check added in dgemm_sqp malloc. 4. mem tracing added. 5. Restricted 3m_sqp to limited matrix sizes. 6. Induced methods disabled temporarily for debug. AMD-Internal: [CPUPL-1352] Change-Id: I31671859b32bfbb359687fb7c9056f9eb904c8b2 --- frame/compat/bla_gemm.c | 2 +- kernels/zen/3/bli_gemm_sqp.c | 1525 ++++++++++++++++++++++++---------- 2 files changed, 1093 insertions(+), 434 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index e374fc8d56..971a203f24 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -38,7 +38,7 @@ // // Define BLAS-to-BLIS interfaces. // - +#define ENABLE_INDUCED_METHOD 0 #ifdef BLIS_BLAS3_CALLS_TAPI #undef GENTFUNC diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index fab6950d25..84924a57f7 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -39,6 +39,8 @@ #define MEM_ALLOC 1//malloc performs better than bli_malloc. 
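+/* BLIS_MX8/BLIS_MX4/BLIS_MX1 are the m-block heights handled by the sqp
+   kernels: the main pass walks C in 8-row panels, any m % 8 remainder is
+   finished by a second pass with the 1-row kernel (the 4-row path is added
+   but not yet dispatched), and *p_istart carries the next unprocessed row
+   index between the two passes. A condensed sketch of the dgemm dispatch
+   in bli_gemm_sqp below:
+
+       gint_t istart = 0;
+       bli_dgemm_sqp_m8(m, n, k, ap, lda, bp, ldb, cp, ldc,
+                        isTransA, (*alpha_cast), 8, &istart); // rows [0, m - m8rem)
+       if (m8rem != 0)
+           bli_dgemm_sqp_m8(m, n, k, ap, lda, bp, ldb, cp, ldc,
+                            isTransA, (*alpha_cast), 1, &istart); // last m8rem rows
+*/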
#define BLIS_MX8 8 +#define BLIS_MX4 4 +#define BLIS_MX1 1 #define DEBUG_3M_SQP 0 typedef struct { @@ -48,8 +50,8 @@ typedef struct { void* unalignedBuf; }mem_block; -static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA); -static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha); +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA, gint_t mx, gint_t* p_istart); +static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha, gint_t mx, gint_t* p_istart); /* * The bli_gemm_sqp (square packed) function performs dgemm and 3m zgemm. @@ -115,15 +117,14 @@ err_t bli_gemm_sqp } dim_t m8rem = m - ((m>>3)<<3); - if(m8rem!=0) - { - /* Residue kernel m4 and m1 to be implemented */ - return BLIS_NOT_YET_IMPLEMENTED; - } double* ap = ( double* )bli_obj_buffer( a ); double* bp = ( double* )bli_obj_buffer( b ); double* cp = ( double* )bli_obj_buffer( c ); + gint_t istart = 0; + gint_t* p_istart = &istart; + *p_istart = 0; + err_t status; if(dt==BLIS_DCOMPLEX) { dcomplex* alphap = ( dcomplex* )bli_obj_buffer( alpha ); @@ -139,7 +140,21 @@ err_t bli_gemm_sqp return BLIS_NOT_YET_IMPLEMENTED; } /* 3m zgemm implementation for C = AxB and C = AtxB */ - return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA); +#if 0 + return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); +#else + status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); + if(m8rem==0) + { + return status;// No residue: done + } + else + { + //complete residue m blocks + status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 1, p_istart); + return status; + } +#endif } else if(dt == BLIS_DOUBLE) { @@ -156,7 +171,21 @@ err_t bli_gemm_sqp return BLIS_NOT_YET_IMPLEMENTED; } /* dgemm implementation with 8mx5n major kernel and column preferred storage */ - return bli_dgemm_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast)); + status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 8, p_istart); + if(status==BLIS_SUCCESS) + { + if(m8rem==0) + { + return status;// No residue: done + } + else + { + //complete residue m blocks + status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 1, p_istart); + return status; + } + } + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); @@ -166,7 +195,133 @@ err_t bli_gemm_sqp /************************************************************************************************************/ /************************** dgemm kernels (8mxn) column preffered ******************************************/ /************************************************************************************************************/ -/* Main dgemm kernel 8mx5n with single load and store of C matrix block + +/* Main dgemm kernel 8mx6n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t 
ldc) +{ + gint_t p; + + __m256d av0, av1; + __m256d bv0, bv1; + __m256d cv0, cv1, cv2, cv3, cv4, cv5; + __m256d cx0, cx1, cx2, cx3, cx4, cx5; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; + + for (j = 0; j <= (n - 6); j += 6) { + + //printf("x"); + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; double* pcldc5 = pcldc4 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; double* pbldb5 = pbldb4 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); + cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); + cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; + bv0 = _mm256_broadcast_sd (pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cx0 = _mm256_fmadd_pd(av1, bv0, cx0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cx1 = _mm256_fmadd_pd(av1, bv1, cx1); + + bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; + bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; + cv2 = _mm256_fmadd_pd(av0, bv0, cv2); + cx2 = _mm256_fmadd_pd(av1, bv0, cx2); + cv3 = _mm256_fmadd_pd(av0, bv1, cv3); + cx3 = _mm256_fmadd_pd(av1, bv1, cx3); + + bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; + bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; + cv4 = _mm256_fmadd_pd(av0, bv0, cv4); + cx4 = _mm256_fmadd_pd(av1, bv0, cx4); + cv5 = _mm256_fmadd_pd(av0, bv1, cv5); + cx5 = _mm256_fmadd_pd(av1, bv1, cx5); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 
= _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); + + av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); + cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + _mm256_storeu_pd(pcldc5, cv5); + _mm256_storeu_pd(pcldc5 + 4, cx5); + + pc += ldc6;pb += ldb6; + } + + return j; +} + +/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block alpha = +/-1 and beta = +/-1,0 handled while packing.*/ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) { @@ -485,6 +640,174 @@ inc_t bli_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld return j; } +#if 0 +/************************************************************************************************************/ +/************************** dgemm kernels (4mxn) column preffered ******************************************/ +/************************************************************************************************************/ +/* Residue dgemm kernel 4mx10n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_kernel_4mx10n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + /* incomplete */ + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; + + for (j = 0; j <= (n - 10); j += 10) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); + cv1 = _mm256_loadu_pd(pcldc); + cv2 = _mm256_loadu_pd(pcldc2); + cv3 = _mm256_loadu_pd(pcldc3); + cv4 = _mm256_loadu_pd(pcldc4); +#else + cv0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = 
_mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); + cv0 = _mm256_add_pd(cv0, bv0); + + bv2 = _mm256_loadu_pd(pcldc); + cv1 = _mm256_add_pd(cv1, bv2); + + bv0 = _mm256_loadu_pd(pcldc2); + cv2 = _mm256_add_pd(cv2, bv0); + + bv2 = _mm256_loadu_pd(pcldc3); + cv3 = _mm256_add_pd(cv3, bv2); + + bv0 = _mm256_loadu_pd(pcldc4); + cv4 = _mm256_add_pd(cv4, bv0); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc4, cv4); + + + pc += ldc10;pb += ldb10; + } + + return j; +} + +/* residue dgemm kernel 4mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_kernel_4mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + } + _mm256_storeu_pd(pc, cv0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + +#endif +/************************************************************************************************************/ +/************************** dgemm kernels (1mxn) column preffered ******************************************/ +/************************************************************************************************************/ + +/* residue dgemm kernel 1mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +{ + gint_t p; + double a0; + double b0; + double c0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + c0 = *pc; + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + b0 = *pb0; pb0++; + a0 = *x; x++; + c0 += (a0 * b0); + } + *pc = c0; + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + /* Ax8 packing subroutine */ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) { @@ -543,55 +866,140 @@ void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT } } -/* A8x4 packing subroutine */ -void bli_prepackA_8x4(double* pa, double* aPacked, gint_t k, guint_t lda) -{ - __m256d av00, av10; - __m256d av01, av11; - __m256d av02, av12; - __m256d av03, av13; - - for (gint_t p = 0; p < k; p += 4) { - av00 = _mm256_loadu_pd(pa); av10 = _mm256_loadu_pd(pa + 4); pa += lda; - av01 = _mm256_loadu_pd(pa); av11 = _mm256_loadu_pd(pa + 4); pa += lda; - av02 = _mm256_loadu_pd(pa); av12 = _mm256_loadu_pd(pa + 4); pa += lda; - av03 = _mm256_loadu_pd(pa); av13 = _mm256_loadu_pd(pa + 4); pa += lda; - - _mm256_storeu_pd(aPacked, av00); _mm256_storeu_pd(aPacked + 4, av10); - _mm256_storeu_pd(aPacked + 8, av01); _mm256_storeu_pd(aPacked + 12, av11); - _mm256_storeu_pd(aPacked + 16, av02); _mm256_storeu_pd(aPacked + 20, av12); - _mm256_storeu_pd(aPacked + 24, av03); _mm256_storeu_pd(aPacked + 28, av13); - - aPacked += 32; - } -} - -/************************************************************************************************************/ -/***************************************** dgemm_sqp implementation******************************************/ -/************************************************************************************************************/ -/* dgemm_sqp implementation packs A matrix based on lda and m size. dgemm_sqp focuses mainly on square matrixes - but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. - C = AxB and C = AtxB is handled in the design. AtxB case is done by transposing A matrix while packing A. - In majority of use-case, alpha are +/-1, so instead of explicitly multiplying alpha its done - during packing itself by changing sign. 
-*/ -static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha) +/* Ax4 packing subroutine */ +void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) { - double* aPacked; - double* aligned = NULL; - - bool pack_on = false; - if((m!=BLIS_MX8)||(m!=lda)||isTransA) - { - pack_on = true; - } - - if(pack_on==true) + __m256d av0, ymm0; + if(isTransA==false) { - aligned = (double*)bli_malloc_user(sizeof(double) * k * BLIS_MX8); - } + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = -ar_; + } + } + } + } +} - for (gint_t i = 0; i < m; i += BLIS_MX8) //this loop can be threaded. no of workitems = m/8 +/* Ax1 packing subroutine */ +void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = *pa; + pa += lda; + aPacked++; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = -(*pa); + pa += lda; + aPacked++; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = ar_; + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = -ar_; + } + } + } +} + +/************************************************************************************************************/ +/***************************************** dgemm_sqp implementation******************************************/ +/************************************************************************************************************/ +/* dgemm_sqp implementation packs A matrix based on lda and m size. dgemm_sqp focuses mainly on square matrixes + but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. + C = AxB and C = AtxB is handled in the design. AtxB case is done by transposing A matrix while packing A. + In majority of use-case, alpha are +/-1, so instead of explicitly multiplying alpha its done + during packing itself by changing sign. 
+*/ +static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha, gint_t mx, gint_t* p_istart) +{ + double* aPacked; + double* aligned = NULL; + gint_t i; + + bool pack_on = false; + if((m!=mx)||(m!=lda)||isTransA) + { + pack_on = true; + } + + if(pack_on==true) + { + aligned = (double*)bli_malloc_user(sizeof(double) * k * mx); + if(aligned==NULL) + { + return BLIS_MALLOC_RETURNED_NULL; + } + } + + for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 { inc_t j = 0; double* ci = c + i; @@ -603,31 +1011,57 @@ static err_t bli_dgemm_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, { pa = a + (i*lda); } - bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); - //bli_prepackA_8x4(a + i, aPacked, k, lda); + /* should be changed to func pointer */ + if(mx==8) + { + bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); + } + else if(mx==4) + { + bli_prepackA_4(pa, aPacked, k, lda, isTransA, alpha); + } + else if(mx==1) + { + bli_prepackA_1(pa, aPacked, k, lda, isTransA, alpha); + } } else { aPacked = a+i; } - - j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); - if (j <= n - 4) - { - j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= n - 3) + if(mx==8) { - j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= n - 2) - { - j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + //printf(" mx8i:%3ld ", i); + //8mx6n currently turned off to isolate a bug. + //j = bli_kernel_8mx6n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + if (j <= n - 5) + { + j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + } + if (j <= n - 4) + { + j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 3) + { + j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 2) + { + j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } + if (j <= n - 1) + { + j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + } } - if (j <= n - 1) + /* mx==4 to be implemented */ + else if(mx==1) { - j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); + //printf(" mx1i:%3ld ", i); + j = bli_kernel_1mx1n(n, k, j, aPacked, lda, b, ldb, ci, ldc); } + *p_istart = i + mx; } if(pack_on==true) @@ -648,6 +1082,10 @@ gint_t bli_getaligned(mem_block* mem_req) } memSize += 128;// extra 128 bytes added for alignment. Could be minimized to 64. 
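+/* A sketch of the usual pattern behind this over-allocation, assuming a
+   64-byte alignment target: after the malloc below succeeds, the raw
+   pointer is rounded up and both pointers are kept, e.g.
+
+       uintptr_t addr = (uintptr_t)mem_req->unalignedBuf;
+       mem_req->alignedBuf = (void*)((addr + 63) & ~(uintptr_t)63);
+
+   unalignedBuf is what free() later releases, and the extra 128 bytes
+   guarantee enough headroom for the round-up. */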
#if MEM_ALLOC +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "malloc(): size %ld\n",( long )memSize; + fflush( stdout ); +#endif mem_req->unalignedBuf = (double*)malloc(memSize); if (mem_req->unalignedBuf == NULL) { @@ -677,9 +1115,24 @@ gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) { +#if MEM_ALLOC + if(mxr->unalignedBuf) + { + free(mxr->unalignedBuf); + } + if(mxi->unalignedBuf) + { + free(mxi->unalignedBuf); + } + if(msx->unalignedBuf) + { + free(msx->unalignedBuf); + } +#else bli_free_user(mxr->alignedBuf); bli_free_user(mxi->alignedBuf); bli_free_user(msx->alignedBuf); +#endif return -1; } return 0; @@ -731,375 +1184,492 @@ void bli_sub_m(gint_t m, gint_t n, double* w, double* c) } } -void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul) +void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1; - - if((mul ==1.0)||(mul==-1.0)) + if(mx==8) { - if(mul ==1.0) + if((mul ==1.0)||(mul==-1.0)) { - for (j = 0; j < n; j++) + if(mul ==1.0) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (j = 0; j < n; j++) { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; } - - for (; p < (k*2); p += 2)// (real + imag)*k + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; + } + pb = pb + ldb; } - pb = pb + ldb; } } else { - zerov = _mm256_setzero_pd(); for (j = 0; j < n; j++) { - for (p = 0; p <= ((k*2)-8); p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - 
- _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; pbr++; pbi++; } pb = pb + ldb; } } - } - else + }//mx==8 +#if 0//already taken care in previous loop + else//mx==1 { - for (j = 0; j < n; j++) + if((mul ==1.0)||(mul==-1.0)) { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + if(mul ==1.0) { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - pbr++; pbi++; + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = c[(j * ldc) + i + ii]; + double ci_ = c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + else + { + //mul = -1.0 + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = -c[(j * ldc) + i + ii]; + double ci_ = -c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } } - pb = pb + ldb; } - } + else + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < (mx*2); ii += 2) + { + double cr_ = mul*c[(j * ldc) + i + ii]; + double ci_ = mul*c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + }//mx==1 +#endif } -void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul) +void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1, sum; - if((mul ==1.0)||(mul==-1.0)) + if(mx==8) { - if(mul ==1.0) + if((mul ==1.0)||(mul==-1.0)) { - for (j = 0; j < n; j++) + if(mul ==1.0) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (j = 0; j < n; j++) { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; } - - for (; p < (k*2); p += 2)// (real + imag)*k + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); + av1 = _mm256_loadu_pd(pbp+4); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = 
_mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; } - pb = pb + ldb; } } else { - zerov = _mm256_setzero_pd(); for (j = 0; j < n; j++) { - for (p = 0; p <= ((k*2)-8); p += 8) + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; pbr++; pbi++; pbs++; } pb = pb + ldb; } } - } + }//mx==8 +#if 0 else { - for (j = 0; j < n; j++) + if((alpha ==1.0)||(alpha==-1.0)) { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + if(alpha ==1.0) { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = b[(j * ldb) + p]; + double bi_ = b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = -b[(j * ldb) + p]; + double bi_ = -b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = alpha * b[(j * ldb) + p]; + double bi_ = alpha * b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; - pbr++; pbi++; pbs++; + pbr++; pbi++; pbs++; + } } - pb = pb + ldb; } } + #endif } -void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA) +void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx) { __m256d av0, av1, av2, av3; __m256d tv0, tv1, sum; gint_t p; - if(isTransA==false) - { - pa = pa +i; - for (p = 0; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
- #if 1 - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - #else //method 2 - __m128d high, low, real, img, sum; - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - high = _mm256_extractf128_pd(av0, 1); - low = _mm256_castpd256_pd128(av0); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; - - high = _mm256_extractf128_pd(av1, 1); - low = _mm256_castpd256_pd128(av1); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; - - high = _mm256_extractf128_pd(av2, 1); - low = _mm256_castpd256_pd128(av2); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; - - high = _mm256_extractf128_pd(av3, 1); - low = _mm256_castpd256_pd128(av3); - real = _mm_shuffle_pd(low, high, 0b00); - img = _mm_shuffle_pd(low, high, 0b11); - sum = _mm_add_pd(real, img); - _mm_storeu_pd(par, real); par += 2; - _mm_storeu_pd(pai, img); pai += 2; - _mm_storeu_pd(pas, sum); pas += 2; - #endif - pa = pa + lda; - } - } - else - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if 0 - for (int p = 0; p <= ((2*k)-8); p += 8) + if(mx==8) + { + if(isTransA==false) { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
- av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - //transpose 4x4 - tv0 = _mm256_unpacklo_pd(av0, av1); - tv1 = _mm256_unpackhi_pd(av0, av1); - tv2 = _mm256_unpacklo_pd(av2, av3); - tv3 = _mm256_unpackhi_pd(av2, av3); - - av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); - av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); - av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); - av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); - - //get real, imag and sum - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; + pa = pa +i; + for (p = 0; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. + av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } } -#endif - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + else { - gint_t idx = ii * lda; - gint_t sidx; - for (p = 0; p <= ((k*2)-8); p += 8) + gint_t idx = (i/2) * lda; + pa = pa + idx; + #if 0 + for (int p = 0; p <= ((2*k)-8); p += 8) { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
+ av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + //transpose 4x4 + tv0 = _mm256_unpacklo_pd(av0, av1); + tv1 = _mm256_unpackhi_pd(av0, av1); + tv2 = _mm256_unpacklo_pd(av2, av3); + tv3 = _mm256_unpackhi_pd(av2, av3); + + av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); + av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); + av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); + av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); + + //get real, imag and sum + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } + #endif + //A Transpose case: + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + for (p = 0; p <= ((k*2)-8); p += 8) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); - double ar1_ = *(pa + idx + p + 2); - double ai1_ = *(pa + idx + p + 3); + double ar1_ = *(pa + idx + p + 2); + double ai1_ = *(pa + idx + p + 3); - double ar2_ = *(pa + idx + p + 4); - double ai2_ = *(pa + idx + p + 5); + double ar2_ = *(pa + idx + p + 4); + double ai2_ = *(pa + idx + p + 5); - double ar3_ = *(pa + idx + p + 6); - double ai3_ = *(pa + idx + p + 7); + double ar3_ = *(pa + idx + p + 6); + double ai3_ = *(pa + idx + p + 7); - sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; + sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; - sidx = ((p+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; + sidx = ((p+2)/2) * BLIS_MX8; + *(par + sidx + ii) = ar1_; + *(pai + sidx + ii) = ai1_; + *(pas + sidx + ii) = ar1_ + ai1_; - sidx = ((p+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; + sidx = ((p+4)/2) * BLIS_MX8; + *(par + sidx + ii) = ar2_; + *(pai + sidx + ii) = ai2_; + *(pas + sidx + ii) = ar2_ + ai2_; - sidx = ((p+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; + sidx = ((p+6)/2) * BLIS_MX8; + *(par + sidx + ii) = ar3_; + *(pai + sidx + ii) = ai3_; + *(pas + sidx + ii) = ar3_ + ai3_; + + } + for (; p < (k*2); p += 2) + { + double ar_ = *(pa + idx + p); + double ai_ = *(pa + idx + p + 1); + gint_t sidx = (p/2) * BLIS_MX8; + *(par + sidx + ii) = ar_; + *(pai + sidx + ii) = ai_; + *(pas + sidx + ii) = ar_ + ai_; + } } + } + } //mx==8 + else//mx==1 + { + if(isTransA==false) + { + pa = pa +i; + //A No transpose case: + for (gint_t p = 0; p < k; p += 1) + { + gint_t idx = p * lda; + for (gint_t ii = 0; ii < (mx*2) ; ii += 2) + { //real + imag : Rkernel needs 8 elements each. 
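+                    /* For mx==1 this inner loop runs exactly once per k step
+                       (a single real/imag pair); it is kept in loop form over
+                       mx, presumably so the same path can also serve the
+                       planned 4-row residue block. */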
+ double ar_ = *(pa + idx + ii); + double ai_ = *(pa + idx + ii + 1); + *par = ar_; + *pai = ai_; + *pas = ar_ + ai_; + par++; pai++; pas++; + } + } + } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; - for (; p < (k*2); p += 2) + //A Transpose case: + for (gint_t ii = 0; ii < mx ; ii++) { - double ar_ = *(pa + idx + p); - double ai_ = *(pa + idx + p + 1); - gint_t sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; + gint_t idx = ii * lda; + gint_t sidx; + for (p = 0; p < (k*2); p += 2) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + sidx = (p/2) * mx; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + } } } - } + }//mx==1 } /************************************************************************************************************/ @@ -1109,8 +1679,14 @@ void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, doubl 3m_sqp focuses mainly on square matrixes but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. */ -static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA) +static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA, gint_t mx, gint_t* p_istart) { + inc_t m2 = m<<1; + inc_t mxmul2 = mx<<1; + if((*p_istart) > (m2-mxmul2)) + { + return BLIS_SUCCESS; + } /* B matrix */ double* br, * bi, * bs; mem_block mbr, mbi, mbs; @@ -1127,59 +1703,75 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l ldb = ldb * 2; ldc = ldc * 2; -//debug to be removed. -#if DEBUG_3M_SQP - double ax[8][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, - {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3,7.3,-7.3,8.3,-8.3}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; - - - double bx[6][16] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60,70,-70,80,-80}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1,7.1,-7.1,8.1,-8.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2,7.2,-7.2,8.2,-8.2}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6,7,-7,8,-8} }; - - double cx[8][12] = { {10,-10,20,-20,30,-30,40,-40,50,-50,60,-60}, - {1.1,-1.1,2.1,-2.1,3.1,-3.1,4.1,-4.1,5.1,-5.1,6.1,-6.1}, - {1.2,-1.2,2.2,-2.2,3.2,-3.2,4.2,-4.2,5.2,-5.2,6.2,-6.2}, - {1.3,-1.3,2.3,-2.3,3.3,-3.3,4.3,-4.3,5.3,-5.3,6.3,-6.3}, - - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6}, - {1,-1,2,-2,3,-3,4,-4,5,-5,6,-6} }; - - b = &bx[0][0]; - a = &ax[0][0]; - c = &cx[0][0]; -#endif - /* Split b (br, bi) and compute bs = br + bi */ double* pbr = br; double* pbi = bi; double* pbs = bs; - gint_t j; + gint_t j, p; /* b matrix real and imag packing and compute. 
*/ - bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha); + //bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); +#if 1//bug in above api to be fixed for mx = 1 + if((alpha ==1.0)||(alpha==-1.0)) + { + if(alpha ==1.0) + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = b[(j * ldb) + p]; + double bi_ = b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + pbr++; pbi++; pbs++; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = -b[(j * ldb) + p]; + double bi_ = -b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = alpha * b[(j * ldb) + p]; + double bi_ = alpha * b[(j * ldb) + p + 1]; + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + } + } +#endif /* Workspace memory allocation currently done dynamically This needs to be taken from already allocated memory pool in application for better performance */ /* A matrix */ double* ar, * ai, * as; mem_block mar, mai, mas; - if(bli_allocateWorkspace(8, k, &mar, &mai, &mas) !=0) + if(bli_allocateWorkspace(mx, k, &mar, &mai, &mas) !=0) { return BLIS_FAILURE; } @@ -1192,7 +1784,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* w; mem_block mw; mw.data_size = sizeof(double); - mw.size = 8 * n; + mw.size = mx * n; if (bli_getaligned(&mw) != 0) { return BLIS_FAILURE; @@ -1203,7 +1795,7 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* cr; mem_block mcr; mcr.data_size = sizeof(double); - mcr.size = 8 * n; + mcr.size = mx * n; if (bli_getaligned(&mcr) != 0) { return BLIS_FAILURE; @@ -1215,14 +1807,14 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* ci; mem_block mci; mci.data_size = sizeof(double); - mci.size = 8 * n; + mci.size = mx * n; if (bli_getaligned(&mci) != 0) { return BLIS_FAILURE; } ci = (double*)mci.alignedBuf; - - for (inc_t i = 0; i < (2*m); i += (2*BLIS_MX8)) //this loop can be threaded. + inc_t i; + for (i = (*p_istart); i <= (m2-mxmul2); i += mxmul2) //this loop can be threaded. { ////////////// operation 1 ///////////////// @@ -1233,17 +1825,67 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* pas = as; /* a matrix real and imag packing and compute. */ - bli_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA); + bli_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA, mx); double* pcr = cr; double* pci = ci; //Split Cr and Ci and beta multiplication done. 
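+        /* After the beta-scaled split of C into Cr/Ci below, the three real
+           GEMMs of this mx-row panel realize the 3m scheme: with A = Ar + i*Ai
+           and B = Br + i*Bi (alpha already folded into Br/Bi while packing B),
+
+               Ci := (Ar + Ai)(Br + Bi) + beta*Ci   // operation 1, uses the as/bs sum buffers
+               Wr := Ar*Br;  Cr += Wr;  Ci -= Wr    // operation 2
+               Wi := Ai*Bi;  Cr -= Wi;  Ci -= Wi    // operation 3
+
+           leaving Cr = beta*Cr + Ar*Br - Ai*Bi and Ci = beta*Ci + Ar*Bi + Ai*Br,
+           i.e. three real multiplications per complex multiplication instead
+           of four. */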
double* pc = c + i; - bli_packX_real_imag(pc, n, BLIS_MX8, ldc, pcr, pci, beta); - + //bli_packX_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); +#if 1 //bug in above api to be fixed for mx = 1 + if((beta ==1.0)||(beta==-1.0)) + { + if(beta ==1.0) + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = c[(j * ldc) + i + ii]; + double ci_ = c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + else + { + //beta = -1.0 + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = -c[(j * ldc) + i + ii]; + double ci_ = -c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (gint_t ii = 0; ii < mxmul2; ii += 2) + { + double cr_ = beta*c[(j * ldc) + i + ii]; + double ci_ = beta*c[(j * ldc) + i + ii + 1]; + *pcr = cr_; + *pci = ci_; + pcr++; pci++; + } + } + } +#endif //Ci := rgemm( SA, SB, Ci ) - bli_dgemm_m8(BLIS_MX8, n, k, as, BLIS_MX8, bs, k, ci, BLIS_MX8, false, 1.0); + gint_t istart = 0; + gint_t* p_is = &istart; + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, as, mx, bs, k, ci, mx, false, 1.0, mx, p_is); @@ -1251,18 +1893,19 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn double* wr = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { + for (gint_t ii = 0; ii < mx; ii += 1) { *wr = 0; wr++; } } wr = w; - bli_dgemm_m8(BLIS_MX8, n, k, ar, BLIS_MX8, br, k, wr, BLIS_MX8, false, 1.0); + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, ar, mx, br, k, wr, mx, false, 1.0, mx, p_is); //Cr : = addm(Wr, Cr) - bli_add_m(BLIS_MX8, n, wr, cr); + bli_add_m(mx, n, wr, cr); //Ci : = subm(Wr, Ci) - bli_sub_m(BLIS_MX8, n, wr, ci); + bli_sub_m(mx, n, wr, ci); @@ -1271,18 +1914,19 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn double* wi = w; for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < BLIS_MX8; ii += 1) { + for (gint_t ii = 0; ii < mx; ii += 1) { *wi = 0; wi++; } } wi = w; - bli_dgemm_m8(BLIS_MX8, n, k, ai, BLIS_MX8, bi, k, wi, BLIS_MX8, false, 1.0); + *p_is = 0; + bli_dgemm_sqp_m8(mx, n, k, ai, mx, bi, k, wi, mx, false, 1.0, mx, p_is); //Cr : = subm(Wi, Cr) - bli_sub_m(BLIS_MX8, n, wi, cr); + bli_sub_m(mx, n, wi, cr); //Ci : = subm(Wi, Ci) - bli_sub_m(BLIS_MX8, n, wi, ci); + bli_sub_m(mx, n, wi, ci); pcr = cr; @@ -1290,41 +1934,56 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < (2*BLIS_MX8); ii += 2) + for (gint_t ii = 0; ii < mxmul2; ii += 2) { c[(j * ldc) + i + ii] = *pcr; c[(j * ldc) + i + ii + 1] = *pci; pcr++; pci++; } } - + *p_istart = i + mxmul2; } -//debug to be removed. 
-#if DEBUG_3M_SQP - for (gint_t jj = 0; jj < n;jj++) +#if MEM_ALLOC + if(mar.unalignedBuf) { - for (gint_t ii = 0; ii < m;ii++) - { - printf("( %4.2lf %4.2lf) ", *cr, *ci); - cr++;ci++; - } - printf("\n"); + free(mar.unalignedBuf); + } + if(mai.unalignedBuf) + { + free(mai.unalignedBuf); + } + if(mas.unalignedBuf) + { + free(mas.unalignedBuf); + } + if(mw.unalignedBuf) + { + free(mw.unalignedBuf); + } + if(mcr.unalignedBuf) + { + free(mcr.unalignedBuf); } -#endif -#if MEM_ALLOC - free(mar.unalignedBuf); - free(mai.unalignedBuf); - free(mas.unalignedBuf); - free(mw.unalignedBuf); + if(mci.unalignedBuf) + { + free(mci.unalignedBuf); + } + if(mbr.unalignedBuf) + { + free(mbr.unalignedBuf); + } - free(mcr.unalignedBuf); - free(mci.unalignedBuf); + if(mbi.unalignedBuf) + { + free(mbi.unalignedBuf); + } - free(mbr.unalignedBuf); - free(mbi.unalignedBuf); - free(mbs.unalignedBuf); + if(mbs.unalignedBuf) + { + free(mbs.unalignedBuf); + } #else /* free workspace buffers */ bli_free_user(mbr.alignedBuf); From acfec6a44417c49e5d051e274cb86ffd8a333f00 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Mon, 5 Apr 2021 21:08:15 +0530 Subject: [PATCH 06/13] Enabling 3m_sqp and 3m1 methods 1. Re-enabling 3m methods for zgemm. 2. Vectorization of pack_sum routines re-enabled with bug fix. 3. 8mx6n kernel added. AMD-Internal: [CPUPL-1352] Change-Id: Id9f010ba763afc52d268c2e68805f069919b8810 --- frame/compat/bla_gemm.c | 2 +- kernels/zen/3/bli_gemm_sqp.c | 438 ++++++++++++++--------------------- 2 files changed, 176 insertions(+), 264 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 971a203f24..0364da46f8 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -38,7 +38,7 @@ // // Define BLAS-to-BLIS interfaces. // -#define ENABLE_INDUCED_METHOD 0 +#define ENABLE_INDUCED_METHOD 1 #ifdef BLIS_BLAS3_CALLS_TAPI #undef GENTFUNC diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index 84924a57f7..ccc9c4aedf 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -213,10 +213,17 @@ inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; for (j = 0; j <= (n - 6); j += 6) { + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + double* pcldc5 = pcldc4 + ldc; - //printf("x"); - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; double* pcldc5 = pcldc4 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; double* pbldb5 = pbldb4 + ldb; + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; + double* pbldb5 = pbldb4 + ldb; #if BLIS_ENABLE_PREFETCH _mm_prefetch((char*)(pc), _MM_HINT_T0); @@ -317,7 +324,7 @@ inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc += ldc6;pb += ldb6; } - + //printf(" 8x6:j:%d ", j); return j; } @@ -326,7 +333,6 @@ inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) { gint_t p; - __m256d av0; __m256d bv0, bv1, bv2, bv3; __m256d cv0, cv1, cv2, cv3; @@ -338,10 +344,17 @@ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc = c; inc_t 
ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; - for (j = 0; j <= (n - 5); j += 5) { + for (; j <= (n - 5); j += 5) { - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; #if BLIS_ENABLE_PREFETCH _mm_prefetch((char*)(pc), _MM_HINT_T0); @@ -430,7 +443,7 @@ inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc += ldc5;pb += ldb5; } - + //printf(" 8x5:j:%d ", j); return j; } @@ -494,6 +507,7 @@ inc_t bli_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc += ldc4;pb += ldb4; }// j loop 4 multiple + //printf(" 8x4:j:%d ", j); return j; } @@ -552,6 +566,7 @@ inc_t bli_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc += ldc3;pb += ldb3; }// j loop 3 multiple + //printf(" 8x3:j:%d ", j); return j; } @@ -601,6 +616,7 @@ inc_t bli_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld pc += ldc2;pb += ldb2; }// j loop 2 multiple + //printf(" 8x2:j:%d ", j); return j; } @@ -637,6 +653,7 @@ inc_t bli_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld _mm256_storeu_pd(pc + 4, cx0); pc += ldc;pb += ldb; }// j loop 1 multiple + //printf(" 8x1:j:%d ", j); return j; } @@ -805,6 +822,7 @@ inc_t bli_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t ld *pc = c0; pc += ldc;pb += ldb; }// j loop 1 multiple + //printf(" 1x1:j:%d ", j); return j; } @@ -998,7 +1016,6 @@ static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l return BLIS_MALLOC_RETURNED_NULL; } } - for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 { inc_t j = 0; @@ -1016,10 +1033,12 @@ static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l { bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); } +#if 0//mx=4, kernels not yet implemented. else if(mx==4) { bli_prepackA_4(pa, aPacked, k, lda, isTransA, alpha); } +#endif//0 else if(mx==1) { bli_prepackA_1(pa, aPacked, k, lda, isTransA, alpha); @@ -1031,26 +1050,25 @@ static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l } if(mx==8) { - //printf(" mx8i:%3ld ", i); - //8mx6n currently turned off to isolate a bug. 
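                // Note on the dispatch below: the kernels form a cascade on the n dimension.
                // Each call consumes as many full column blocks as it can and returns the next
                // unprocessed column index j. For example, n = 1000 lets the 8x6 kernel cover
                // j = 0..995 (166 blocks of 6 columns) and the 8x4 kernel then finishes the
                // remaining 4 columns; the 5/3/2/1-column kernels stay idle for that size.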
- //j = bli_kernel_8mx6n(n, k, j, aPacked, lda, b, ldb, ci, ldc); - if (j <= n - 5) + //printf("\n mx8i:%3ld ", i); + j = bli_kernel_8mx6n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + if (j <= (n - 5)) { - j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b, ldb, ci, ldc); + j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); } - if (j <= n - 4) + if (j <= (n - 4)) { j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); } - if (j <= n - 3) + if (j <= (n - 3)) { j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); } - if (j <= n - 2) + if (j <= (n - 2)) { j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); } - if (j <= n - 1) + if (j <= (n - 1)) { j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); } @@ -1058,7 +1076,7 @@ static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l /* mx==4 to be implemented */ else if(mx==1) { - //printf(" mx1i:%3ld ", i); + //printf("\n mx1i:%3ld ", i); j = bli_kernel_1mx1n(n, k, j, aPacked, lda, b, ldb, ci, ldc); } *p_istart = i + mx; @@ -1184,309 +1202,201 @@ void bli_sub_m(gint_t m, gint_t n, double* w, double* c) } } +/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1; - if(mx==8) + gint_t max_k = (k*2)-8; + + if((mul ==1.0)||(mul==-1.0)) { - if((mul ==1.0)||(mul==-1.0)) + if(mul ==1.0) /* handles alpha or beta = 1.0 */ { - if(mul ==1.0) + for (j = 0; j < n; j++) { - for (j = 0; j < n; j++) + for (p = 0; p <= max_k; p += 8) { - for (p = 0; p <= ((k*2)-8); p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; - } - pb = pb + ldb; + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 + // + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; } - } - else - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= ((k*2)-8); p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; - } - pb = pb + ldb; - } - } - } - else 
- { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + for (; p < (k*2); p += 2)// (real + imag)*k { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; pbr++; pbi++; } pb = pb + ldb; } } - }//mx==8 -#if 0//already taken care in previous loop - else//mx==1 - { - if((mul ==1.0)||(mul==-1.0)) + else /* handles alpha or beta = - 1.0 */ { - if(mul ==1.0) + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) { - for (j = 0; j < n; j++) + for (p = 0; p <= max_k; p += 8) { - for (gint_t ii = 0; ii < (mx*2); ii += 2) - { - double cr_ = c[(j * ldc) + i + ii]; - double ci_ = c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; } - } - else - { - //mul = -1.0 - for (j = 0; j < n; j++) + + for (; p < (k*2); p += 2)// (real + imag)*k { - for (gint_t ii = 0; ii < (mx*2); ii += 2) - { - double cr_ = -c[(j * ldc) + i + ii]; - double ci_ = -c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + pbr++; pbi++; } + pb = pb + ldb; } } - else + } + else /* handles alpha or beta is not equal +/- 1.0 */ + { + for (j = 0; j < n; j++) { - for (j = 0; j < n; j++) + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - for (gint_t ii = 0; ii < (mx*2); ii += 2) - { - double cr_ = mul*c[(j * ldc) + i + ii]; - double ci_ = mul*c[(j * ldc) + i + ii + 1]; - *pcr = cr_; - *pci = ci_; - pcr++; pci++; - } + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + pbr++; pbi++; } + pb = pb + ldb; } - }//mx==1 -#endif + } } +/* Pack real and imaginary parts in separate buffers and compute sum of real and imaginary part */ void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx) { gint_t j, p; __m256d av0, av1, zerov; __m256d tv0, tv1, sum; - - if(mx==8) + gint_t max_k = (k*2) - 8; + if((mul ==1.0)||(mul==-1.0)) { - if((mul ==1.0)||(mul==-1.0)) + if(mul ==1.0) { - if(mul ==1.0) + for (j = 0; j < n; j++) { - for (j = 0; j < n; j++) + for (p=0; p <= max_k; p += 8) { - for (p = 0; p <= ((k*2)-8); p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = 
_mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; } - } - else - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= ((k*2)-8); p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); - av1 = _mm256_loadu_pd(pbp+4); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; - pbr++; pbi++; pbs++; - } - pb = pb + ldb; + pbr++; pbi++; pbs++; } + pb = pb + ldb; } } else { + zerov = _mm256_setzero_pd(); for (j = 0; j < n; j++) { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k + for (p = 0; p <= max_k; p += 8) { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; pbr++; pbi++; pbs++; } pb = pb + ldb; } } - }//mx==8 -#if 0 + } else { - if((alpha ==1.0)||(alpha==-1.0)) - { - if(alpha ==1.0) - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = b[(j * ldb) + p]; - double bi_ = b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = -b[(j * ldb) + p]; - double bi_ = -b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - } - else + for (j = 0; j < n; j++) { - for (j = 0; j < n; j++) + for (p = 0; p < (k*2); p += 2)// (real + imag)*k { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = alpha * b[(j * ldb) + p]; - double bi_ = alpha * b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + 
bi_; - pbr++; pbi++; pbs++; - } + pbr++; pbi++; pbs++; } + pb = pb + ldb; } } - #endif } +/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx) { __m256d av0, av1, av2, av3; @@ -1578,7 +1488,8 @@ void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, doubl { gint_t idx = ii * lda; gint_t sidx; - for (p = 0; p <= ((k*2)-8); p += 8) + gint_t max_k = (k*2) - 8; + for (p = 0; p <= max_k; p += 8) { double ar0_ = *(pa + idx + p); double ai0_ = *(pa + idx + p + 1); @@ -1709,11 +1620,11 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l double* pbi = bi; double* pbs = bs; - gint_t j, p; + gint_t j; /* b matrix real and imag packing and compute. */ - //bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); -#if 1//bug in above api to be fixed for mx = 1 + bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); +#if 0//bug in above api to be fixed for mx = 1 if((alpha ==1.0)||(alpha==-1.0)) { if(alpha ==1.0) @@ -1814,7 +1725,8 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l } ci = (double*)mci.alignedBuf; inc_t i; - for (i = (*p_istart); i <= (m2-mxmul2); i += mxmul2) //this loop can be threaded. + gint_t max_m = (m2-mxmul2); + for (i = (*p_istart); i <= max_m; i += mxmul2) //this loop can be threaded. { ////////////// operation 1 ///////////////// @@ -1832,8 +1744,8 @@ static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t l //Split Cr and Ci and beta multiplication done. double* pc = c + i; - //bli_packX_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); -#if 1 //bug in above api to be fixed for mx = 1 + bli_packX_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); +#if 0 //bug in above api to be fixed for mx = 1 if((beta ==1.0)||(beta==-1.0)) { if(beta ==1.0) From 74800cfbc6f07c14d0f8746768b804568cfdc61d Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 5 May 2021 15:42:39 +0530 Subject: [PATCH 07/13] squarePacked(sqp) framework and multi-instance handling 1. kx partitions added to k loop for dgemm and zgemm. 2. mx loop based threading model added for dgemm as prototype of zgemm. 3. nx loop added for 3m_sqp and dgemm_sqp. 4. single 3m_sqp workspace allocation with smaller memory footprint. 5. sqp framework done from dgemm and zgemm. 6. sqp kernels moved to seperate kernel file. 7. residue kernel core added to handle mx<8. 8. multi-instance tuning for 3m_sqp done. 9. user can set env "BLIS_MULTI_INSTANCE" to 1 for better multi-instance behavior of 3m_sqp. AMD-Internal: [CPUPL-1521] Change-Id: Ibef50a8a37fe99f164edb4621acb44fc0c86514c --- frame/compat/bla_gemm.c | 1 - kernels/zen/3/CMakeLists.txt | 2 + kernels/zen/3/bli_gemm_sqp.c | 2235 +++++++++----------------- kernels/zen/3/bli_gemm_sqp_kernels.c | 1588 ++++++++++++++++++ kernels/zen/3/bli_gemm_sqp_kernels.h | 65 + 5 files changed, 2408 insertions(+), 1483 deletions(-) create mode 100644 kernels/zen/3/bli_gemm_sqp_kernels.c create mode 100644 kernels/zen/3/bli_gemm_sqp_kernels.h diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 0364da46f8..3eafad1903 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -217,7 +217,6 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ } - #endif #ifdef BLIS_ENABLE_BLAS diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index 7363f7f173..fd3aabe70b 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -5,6 +5,8 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_syrk_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c ) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index ccc9c4aedf..1622c551c4 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -33,15 +33,36 @@ */ #include "blis.h" #include "immintrin.h" +#include "bli_gemm_sqp_kernels.h" +#define SQP_THREAD_ENABLE 0//currently disabled +#define BLI_SQP_MAX_THREADS 128 #define BLIS_LOADFIRST 0 -#define BLIS_ENABLE_PREFETCH 1 - #define MEM_ALLOC 1//malloc performs better than bli_malloc. -#define BLIS_MX8 8 -#define BLIS_MX4 4 -#define BLIS_MX1 1 -#define DEBUG_3M_SQP 0 + +//Macro for 3m_sqp n loop +#define BLI_SQP_ZGEMM_N(MX)\ + int j=0;\ + for(; j<=(n-nx); j+= nx)\ + {\ + status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\ + }\ + if(j>3)<<3); + dim_t nt = bli_thread_get_num_threads(); // get number of threads double* ap = ( double* )bli_obj_buffer( a ); double* bp = ( double* )bli_obj_buffer( b ); double* cp = ( double* )bli_obj_buffer( c ); - gint_t istart = 0; - gint_t* p_istart = &istart; - *p_istart = 0; - err_t status; + if(dt==BLIS_DCOMPLEX) { dcomplex* alphap = ( dcomplex* )bli_obj_buffer( alpha ); @@ -139,22 +236,8 @@ err_t bli_gemm_sqp { return BLIS_NOT_YET_IMPLEMENTED; } - /* 3m zgemm implementation for C = AxB and C = AtxB */ -#if 0 - return bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); -#else - status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 8, p_istart); - if(m8rem==0) - { - return status;// No residue: done - } - else - { - //complete residue m blocks - status = bli_zgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, 1, p_istart); - return status; - } -#endif + //printf("zsqp "); + return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, nt); } else if(dt == BLIS_DOUBLE) { @@ -170,922 +253,299 @@ err_t bli_gemm_sqp { return BLIS_NOT_YET_IMPLEMENTED; } - /* dgemm implementation with 8mx5n major kernel and column preferred storage */ - status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 8, p_istart); - if(status==BLIS_SUCCESS) - { - if(m8rem==0) - { - return status;// No residue: done - } - else - { - //complete residue m blocks - status = bli_dgemm_sqp_m8( m, n, k, ap, lda, bp, ldb, cp, ldc, isTransA, (*alpha_cast), 1, p_istart); - return status; - } - } - + //printf("dsqp "); + return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; -/************************************************************************************************************/ -/************************** dgemm kernels (8mxn) column preffered ******************************************/ 
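/*
   A minimal scalar model (editorial sketch; the helper name sqp_8x6_ref is
   hypothetical) of what the 8xN register-blocked kernels compute: A is
   pre-packed into 8-wide column panels, B and C remain column-major, and the
   whole 8x6 block of C is accumulated in registers so that C is loaded and
   stored only once per block.
*/
static void sqp_8x6_ref(gint_t k, const double* aPacked,
                        const double* b, guint_t ldb,
                        double* c, guint_t ldc)
{
    double acc[6][8] = {{0.0}};               /* the 8x6 accumulator block   */
    for (gint_t p = 0; p < k; p++)            /* k rank-1 updates            */
        for (gint_t jj = 0; jj < 6; jj++)     /* 6 columns of B (broadcasts) */
            for (gint_t ii = 0; ii < 8; ii++) /* 8 packed rows of A          */
                acc[jj][ii] += aPacked[p * 8 + ii] * b[jj * ldb + p];
    for (gint_t jj = 0; jj < 6; jj++)         /* single load/add/store of C  */
        for (gint_t ii = 0; ii < 8; ii++)
            c[jj * ldc + ii] += acc[jj][ii];
}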
-/************************************************************************************************************/ - -/* Main dgemm kernel 8mx6n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) +//sqp_dgemm k partition +BLIS_INLINE void bli_sqp_dgemm_kx( gint_t m, + gint_t n, + gint_t kx, + gint_t p, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + gint_t i, + bool pack_on, + double *aligned) { - gint_t p; - - __m256d av0, av1; - __m256d bv0, bv1; - __m256d cv0, cv1, cv2, cv3, cv4, cv5; - __m256d cx0, cx1, cx2, cx3, cx4, cx5; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; - - for (j = 0; j <= (n - 6); j += 6) { - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - double* pcldc5 = pcldc4 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - double* pbldb5 = pbldb4 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); - cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); - cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; - bv0 = _mm256_broadcast_sd (pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cx0 = _mm256_fmadd_pd(av1, bv0, cx0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cx1 = _mm256_fmadd_pd(av1, bv1, cx1); - - bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; - bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; - cv2 = _mm256_fmadd_pd(av0, bv0, cv2); - cx2 = _mm256_fmadd_pd(av1, bv0, cx2); - cv3 = _mm256_fmadd_pd(av0, bv1, cv3); - cx3 = _mm256_fmadd_pd(av1, bv1, cx3); - - bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; - bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; - cv4 = _mm256_fmadd_pd(av0, bv0, cv4); - cx4 = _mm256_fmadd_pd(av1, bv0, cx4); - cv5 = _mm256_fmadd_pd(av0, bv1, cv5); - cx5 = 
_mm256_fmadd_pd(av1, bv1, cx5); + inc_t j = 0; + double* ci = c + i; + double* aPacked; + //packing + if(pack_on==true) + { + aPacked = aligned; + double *pa = a + i + (p*lda); + if(isTransA==true) + { + pa = a + (i*lda) + p; } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); - - av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); - cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - _mm256_storeu_pd(pcldc5, cv5); - _mm256_storeu_pd(pcldc5 + 4, cx5); - - pc += ldc6;pb += ldb6; + bli_sqp_prepackA(pa, aPacked, kx, lda, isTransA, alpha, mx); } - //printf(" 8x6:j:%d ", j); - return j; -} - -/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; - - for (; j <= (n - 5); j += 5) { - - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = 
_mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - cx4 = _mm256_fmadd_pd(av0, bv4, cx4); - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - pc += ldc5;pb += ldb5; + else + { + aPacked = a+i + (p*lda); } - //printf(" 8x5:j:%d ", j); - return j; -} -/* residue dgemm kernel 8mx4n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; - - for (; j <= (n - 4); j += 4) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + //compute + if(mx==8) + { + //printf("\n mx8i:%3ld ", i); + if (j <= (n - 6)) { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - // better kernel to be written since more register are available. 
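            // AVX2 provides 16 ymm registers; this 8x4 block uses 8 accumulators
            // (cv0-cv3, cx0-cx3), 4 B broadcasts and 1 A vector, i.e. 13 of 16,
            // whereas the 8x6 blocking uses all 16. That is why the wider kernels
            // are preferred whenever enough columns remain.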
- bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - } + j = bli_sqp_dgemm_kernel_8mx6n(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc); } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - pc += ldc4;pb += ldb4; - }// j loop 4 multiple - //printf(" 8x4:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx3n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2; - __m256d cv0, cv1, cv2; - __m256d cx0, cx1, cx2; - double* pb, * pc; - - pb = b; - pc = c; - - inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; - - for (; j <= (n - 3); j += 3) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + if (j <= (n - 5)) { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - } + j = bli_sqp_dgemm_kernel_8mx5n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); } - - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - pc += ldc3;pb += ldb3; - }// j loop 3 multiple - //printf(" 8x3:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx2n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1; - __m256d cv0, cv1; - __m256d cx0, cx1; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; - - for (; j <= (n - 2); j += 2) { - double* pcldc = pc + ldc; - double* pbldb = pb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + if (j <= (n - 4)) { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - } - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - pc += ldc2;pb += ldb2; - }// j loop 2 multiple - //printf(" 8x2:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - __m256d cx0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 8x1:j:%d ", j); - return j; -} - -#if 0 -/************************************************************************************************************/ -/************************** dgemm kernels (4mxn) column preffered ******************************************/ -/************************************************************************************************************/ -/* Residue dgemm kernel 4mx10n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_kernel_4mx10n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - /* incomplete */ - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; - - for (j = 0; j <= (n - 10); j += 10) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), 
_MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); - cv1 = _mm256_loadu_pd(pcldc); - cv2 = _mm256_loadu_pd(pcldc2); - cv3 = _mm256_loadu_pd(pcldc3); - cv4 = _mm256_loadu_pd(pcldc4); -#else - cv0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); - cv0 = _mm256_add_pd(cv0, bv0); - - bv2 = _mm256_loadu_pd(pcldc); - cv1 = _mm256_add_pd(cv1, bv2); - - bv0 = _mm256_loadu_pd(pcldc2); - cv2 = _mm256_add_pd(cv2, bv0); - - bv2 = _mm256_loadu_pd(pcldc3); - cv3 = _mm256_add_pd(cv3, bv2); - - bv0 = _mm256_loadu_pd(pcldc4); - cv4 = _mm256_add_pd(cv4, bv0); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc4, cv4); - - - pc += ldc10;pb += ldb10; - } - - return j; -} - -/* residue dgemm kernel 4mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_kernel_4mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + j = bli_sqp_dgemm_kernel_8mx4n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); } - _mm256_storeu_pd(pc, cv0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - return j; -} - -#endif -/************************************************************************************************************/ -/************************** dgemm kernels (1mxn) column preffered ******************************************/ -/************************************************************************************************************/ - -/* residue dgemm kernel 1mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc) -{ - gint_t p; - double a0; - double b0; - double c0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - c0 = *pc; - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - b0 = *pb0; pb0++; - a0 = *x; x++; - c0 += (a0 * b0); + if (j <= (n - 3)) + { + j = bli_sqp_dgemm_kernel_8mx3n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); } - *pc = c0; - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 1x1:j:%d ", j); - return j; -} - -/* Ax8 packing subroutine */ -void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) -{ - __m256d av0, av1, ymm0; - if(isTransA==false) - { - if(alpha==1.0) + if (j <= (n - 2)) { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } + j = bli_sqp_dgemm_kernel_8mx2n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); } - else if(alpha==-1.0) + if (j <= (n - 1)) { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } + j = bli_sqp_dgemm_kernel_8mx1n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); } } + /* mx==4 to be implemented */ else { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = -ar_; - } - } - } + // this residue kernel needs to be improved. 
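            // Schematic view of the generic residue path (mx < 8); loop order is
            // chosen for clarity rather than speed:
            //   for (jj = 0; jj < n; jj++)
            //     for (ii = 0; ii < mx; ii++)
            //       for (pp = 0; pp < kx; pp++)
            //         ci[jj*ldc + ii] += aPacked[pp*mx + ii] * (b + p)[jj*ldb + pp];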
+ j = bli_sqp_dgemm_kernel_mxn(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc, mx); } } -/* Ax4 packing subroutine */ -void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha) +//sqp dgemm m loop +void bli_sqp_dgemm_m( gint_t i_start, + gint_t i_end, + gint_t m, + gint_t n, + gint_t k, + gint_t kx, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + bool pack_on, + double *aligned ) { - __m256d av0, ymm0; - if(isTransA==false) +#if SQP_THREAD_ENABLE + if(pack_on==true) { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - _mm256_storeu_pd(aPacked, av0); - aPacked += BLIS_MX4; - } - } - else if(alpha==-1.0) + //NEEDED IN THREADING CASE: + aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); + if(aligned==NULL) { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); - aPacked += BLIS_MX4; - } + return BLIS_MALLOC_RETURNED_NULL;// return to be removed } } - else +#endif//SQP_THREAD_ENABLE + + for (gint_t i = i_start; i <= (i_end-mx); i += mx) //this loop can be threaded. no of workitems = m/8 { - if(alpha==1.0) + int p = 0; + for(; p <= (k-kx); p += kx) { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX4 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX4; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) + bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, isTransA, alpha, mx, i, pack_on, aligned); + }// k loop end + + if(pi_start, + arg->i_end, + arg->m, + arg->n, + arg->k, + arg->kx, + arg->a, + arg->lda, + arg->b, + arg->ldb, + arg->c, + arg->ldc, + arg->isTransA, + arg->alpha, + arg->mx, + arg->pack_on, + arg->aligned); +} + +// sqp_dgemm m loop +BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + gint_t mx, + gint_t* p_istart, + gint_t kx, + double *aligned) +{ + gint_t i; + if(kx > k) { - if(alpha==1.0) + kx = k; + } + + bool pack_on = false; + if((m!=mx)||(m!=lda)||isTransA) + { + pack_on = true; + } + +#if 0//SQP_THREAD_ENABLE//ENABLE Threading + gint_t status = 0; + gint_t workitems = (m-(*p_istart))/mx; + gint_t inputThreadCount = bli_thread_get_num_threads(); + inputThreadCount = bli_min(inputThreadCount, BLI_SQP_MAX_THREADS); + inputThreadCount = bli_min(inputThreadCount,workitems);// limit input thread count when workitems are lesser. 
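    // Worked example of the partitioning below: m = 1024, mx = 8, *p_istart = 0
    // gives workitems = 128 row panels; with num_threads = 8, mx_per_thread = 16,
    // so thread t handles rows i = 128*t .. 128*(t+1)-1 (i_start = 128*t,
    // i_end = 128*(t+1)), and the last thread's i_end is extended to m to absorb
    // any leftover panels.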
+ inputThreadCount = bli_max(inputThreadCount,1); + gint_t num_threads; + num_threads = bli_max(inputThreadCount,1); + gint_t mx_per_thread = workitems/num_threads;//no of workitems per thread + //printf("\nistart %d workitems %d inputThreadCount %d num_threads %d mx_per_thread %d mx %d " , + *p_istart, workitems,inputThreadCount,num_threads,mx_per_thread, mx); + + pthread_t ptid[BLI_SQP_MAX_THREADS]; + bli_sqp_thread_info thread_info[BLI_SQP_MAX_THREADS]; + + //create threads + for (gint_t t = 0; t < num_threads; t++) + { + //ptid[t].tid = t; + gint_t i_end = ((mx_per_thread*(t+1))*mx)+(*p_istart); + if(i_end>m) { - for (gint_t p = 0; p < k; p += 1) { - *aPacked = *pa; - pa += lda; - aPacked++; - } + i_end = m; } - else if(alpha==-1.0) + + if(t==(num_threads-1)) { - for (gint_t p = 0; p < k; p += 1) { - *aPacked = -(*pa); - pa += lda; - aPacked++; + if((i_end+mx)==m) + { + i_end = m; } - } - } - else - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) + + if(mx==1) { - double ar_ = *(pa+p); - *(aPacked + p) = ar_; + i_end = m; } } - else if(alpha==-1.0) + + thread_info[t].i_start = ((mx_per_thread*t)*mx)+(*p_istart); + thread_info[t].i_end = i_end; + //printf("\n threadid %d istart %d iend %d m %d mx %d", t, thread_info[t].i_start, i_end, m, mx); + thread_info[t].m = m; + thread_info[t].n = n; + thread_info[t].k = k; + thread_info[t].kx = kx; + thread_info[t].a = a; + thread_info[t].lda = lda; + thread_info[t].b = b; + thread_info[t].ldb = ldb; + thread_info[t].c = c; + thread_info[t].ldc = ldc; + thread_info[t].isTransA = isTransA; + thread_info[t].alpha = alpha; + thread_info[t].mx = mx; + thread_info[t].pack_on = pack_on; + thread_info[t].aligned = aligned; +#if 1 + if ((status = pthread_create(&ptid[t], NULL, bli_sqp_thread, (void*)&thread_info[t]))) { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+p); - *(aPacked + p) = -ar_; - } + printf("error sqp pthread_create\n"); + return BLIS_FAILURE; } +#else + //simulate thread for debugging.. + bli_sqp_thread((void*)&thread_info[t]); +#endif } -} -/************************************************************************************************************/ -/***************************************** dgemm_sqp implementation******************************************/ -/************************************************************************************************************/ -/* dgemm_sqp implementation packs A matrix based on lda and m size. dgemm_sqp focuses mainly on square matrixes - but also supports non-square matrix. Current support is limiteed to m multiple of 8 and column storage. - C = AxB and C = AtxB is handled in the design. AtxB case is done by transposing A matrix while packing A. - In majority of use-case, alpha are +/-1, so instead of explicitly multiplying alpha its done - during packing itself by changing sign. 
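   For example, with alpha = -1.0 the packing step simply stores the negated
   panel (aPacked[idx] = -a[idx]), so C += (alpha*A)*B reduces to C += (-A)*B
   and the inner FMA kernels never multiply by alpha at all.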
-*/ -static err_t bli_dgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, bool isTransA, double alpha, gint_t mx, gint_t* p_istart) -{ - double* aPacked; - double* aligned = NULL; - gint_t i; + //wait for completion + for (gint_t t = 0; t < num_threads; t++) + { + pthread_join(ptid[t], NULL); + } - bool pack_on = false; - if((m!=mx)||(m!=lda)||isTransA) + if(num_threads>0) { - pack_on = true; + *p_istart = thread_info[(num_threads-1)].i_end; } +#else//SQP_THREAD_ENABLE if(pack_on==true) { - aligned = (double*)bli_malloc_user(sizeof(double) * k * mx); + //aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); // allocation moved to top. if(aligned==NULL) { return BLIS_MALLOC_RETURNED_NULL; } } + for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 { - inc_t j = 0; - double* ci = c + i; - if(pack_on==true) + int p = 0; + for(; p <= (k-kx); p += kx) { - aPacked = aligned; - double *pa = a + i; - if(isTransA==true) - { - pa = a + (i*lda); - } - /* should be changed to func pointer */ - if(mx==8) - { - bli_prepackA_8(pa, aPacked, k, lda, isTransA, alpha); - } -#if 0//mx=4, kernels not yet implemented. - else if(mx==4) - { - bli_prepackA_4(pa, aPacked, k, lda, isTransA, alpha); - } -#endif//0 - else if(mx==1) - { - bli_prepackA_1(pa, aPacked, k, lda, isTransA, alpha); - } - } - else - { - aPacked = a+i; - } - if(mx==8) - { - //printf("\n mx8i:%3ld ", i); - j = bli_kernel_8mx6n(n, k, j, aPacked, lda, b, ldb, ci, ldc); - if (j <= (n - 5)) - { - j = bli_kernel_8mx5n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 4)) - { - j = bli_kernel_8mx4n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 3)) - { - j = bli_kernel_8mx3n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 2)) - { - j = bli_kernel_8mx2n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 1)) - { - j = bli_kernel_8mx1n(n, k, j, aPacked, lda, b + (j * ldb), ldb, ci + (j * ldc), ldc); - } - } - /* mx==4 to be implemented */ - else if(mx==1) + bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, + isTransA, alpha, mx, i, pack_on, aligned); + }// k loop end + + if(punalignedBuf = (double*)malloc(memSize); @@ -1130,6 +590,7 @@ gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, mxr->size = mxi->size = n * k; msx->size = n * k; mxr->alignedBuf = mxi->alignedBuf = msx->alignedBuf = NULL; + mxr->unalignedBuf = mxi->unalignedBuf = msx->unalignedBuf = NULL; if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) { @@ -1156,578 +617,326 @@ gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, return 0; } -void bli_add_m(gint_t m,gint_t n,double* w,double* c) +//3m_sqp k loop +BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m, + gint_t n, + gint_t kx, + gint_t p, + double* a, + guint_t lda, + guint_t ldb, + double* c, + guint_t ldc, + bool isTransA, + double alpha, + double beta, + gint_t mx, + gint_t i, + double* ar, + double* ai, + double* as, + double* br, + double* bi, + double* bs, + double* cr, + double* ci, + double* w, + double *a_aligned) { - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; - - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_add_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 
4; - } - for (; i < count; i++) - { - *pc = *pc + *pw; - pc++; pw++; - } - + gint_t j; -} + ////////////// operation 1 ///////////////// + /* Split a (ar, ai) and + compute as = ar + ai */ + double* par = ar; + double* pai = ai; + double* pas = as; -void bli_sub_m(gint_t m, gint_t n, double* w, double* c) -{ - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; + /* a matrix real and imag packing and compute. */ + bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, isTransA, mx, p); - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_sub_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 4; - } - for (; i < count; i++) - { - *pc = *pc - *pw; - pc++; pw++; - } -} + double* pcr = cr; + double* pci = ci; -/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ -void bli_packX_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1; - gint_t max_k = (k*2)-8; - - if((mul ==1.0)||(mul==-1.0)) + //Split Cr and Ci and beta multiplication done. + double* pc = c + i; + if(p==0) { - if(mul ==1.0) /* handles alpha or beta = 1.0 */ - { - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 - // - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; - } - pb = pb + ldb; - } - } - else /* handles alpha or beta = - 1.0 */ - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - pbr++; pbi++; - } - pb = pb + ldb; - } - } + bli_3m_sqp_packC_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); } - else /* handles alpha or beta is not equal +/- 1.0 */ - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - pbr++; pbi++; - } - pb = pb + ldb; + //Ci := rgemm( SA, SB, Ci ) + gint_t istart = 0; + gint_t* p_is = &istart; + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, as, mx, bs, ldb, ci, mx, false, 1.0, mx, p_is, kx, a_aligned); + + ////////////// operation 2 ///////////////// + //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn + double* wr = w; + for 
(j = 0; j < n; j++) { + for (gint_t ii = 0; ii < mx; ii += 1) { + *wr = 0; + wr++; } } -} - -/* Pack real and imaginary parts in separate buffers and compute sum of real and imaginary part */ -void bli_packX_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1, sum; - gint_t max_k = (k*2) - 8; - if((mul ==1.0)||(mul==-1.0)) - { - if(mul ==1.0) - { - for (j = 0; j < n; j++) - { - for (p=0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } - else - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } + wr = w; + + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, ar, mx, br, ldb, wr, mx, false, 1.0, mx, p_is, kx, a_aligned); + //Cr : = addm(Wr, Cr) + bli_add_m(mx, n, wr, cr); + //Ci : = subm(Wr, Ci) + bli_sub_m(mx, n, wr, ci); + + + ////////////// operation 3 ///////////////// + //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn + double* wi = w; + for (j = 0; j < n; j++) { + for (gint_t ii = 0; ii < mx; ii += 1) { + *wi = 0; + wi++; } } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; + wi = w; - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } -} + *p_is = 0; + bli_sqp_dgemm_m8(mx, n, kx, ai, mx, bi, ldb, wi, mx, false, 1.0, mx, p_is, kx, a_aligned); + //Cr : = subm(Wi, Cr) + bli_sub_m(mx, n, wi, cr); + //Ci : = subm(Wi, Ci) + bli_sub_m(mx, n, wi, ci); -/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ -void bli_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx) -{ - __m256d av0, av1, av2, av3; - __m256d tv0, tv1, sum; - gint_t p; + pcr = cr; + pci = ci; - if(mx==8) + for (j = 0; j < n; j++) { - 
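    // Summary of the 3m decomposition carried out in operations 1-3 above:
    //   with t1 = Ar*Br, t2 = Ai*Bi, t3 = (Ar+Ai)*(Br+Bi),
    //   Cr += t1 - t2   and   Ci += t3 - t1 - t2,
    // which is why the packing routines also produce the sum panels
    // as = ar + ai and bs = br + bi alongside the split real/imaginary panels.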
if(isTransA==false) + for (gint_t ii = 0; ii < (mx*2); ii += 2) { - pa = pa +i; - for (p = 0; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } + c[(j * ldc) + i + ii] = *pcr; + c[(j * ldc) + i + ii + 1] = *pci; + pcr++; pci++; } - else - { - gint_t idx = (i/2) * lda; - pa = pa + idx; - #if 0 - for (int p = 0; p <= ((2*k)-8); p += 8) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - //transpose 4x4 - tv0 = _mm256_unpacklo_pd(av0, av1); - tv1 = _mm256_unpackhi_pd(av0, av1); - tv2 = _mm256_unpacklo_pd(av2, av3); - tv3 = _mm256_unpackhi_pd(av2, av3); - - av0 = _mm256_permute2f128_pd(tv0, tv2, 0x20); - av1 = _mm256_permute2f128_pd(tv1, tv3, 0x20); - av2 = _mm256_permute2f128_pd(tv0, tv2, 0x31); - av3 = _mm256_permute2f128_pd(tv1, tv3, 0x31); - - //get real, imag and sum - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } - #endif - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t max_k = (k*2) - 8; - for (p = 0; p <= max_k; p += 8) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = *(pa + idx + p + 3); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = *(pa + idx + p + 5); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = *(pa + idx + p + 7); - - sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((p+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((p+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((p+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double 
ai_ = *(pa + idx + p + 1); - gint_t sidx = (p/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - } - } - } - } //mx==8 - else//mx==1 - { - if(isTransA==false) - { - pa = pa +i; - //A No transpose case: - for (gint_t p = 0; p < k; p += 1) - { - gint_t idx = p * lda; - for (gint_t ii = 0; ii < (mx*2) ; ii += 2) - { //real + imag : Rkernel needs 8 elements each. - double ar_ = *(pa + idx + ii); - double ai_ = *(pa + idx + ii + 1); - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else - { - gint_t idx = (i/2) * lda; - pa = pa + idx; - - //A Transpose case: - for (gint_t ii = 0; ii < mx ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - for (p = 0; p < (k*2); p += 2) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - sidx = (p/2) * mx; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - } - } - } - }//mx==1 + } } -/************************************************************************************************************/ -/***************************************** 3m_sqp implementation ******************************************/ -/************************************************************************************************************/ -/* 3m_sqp implementation packs A, B and C matrix and uses dgemm_sqp real kernel implementation. - 3m_sqp focuses mainly on square matrixes but also supports non-square matrix. Current support is limiteed to - m multiple of 8 and column storage. -*/ -static err_t bli_zgemm_sqp_m8(gint_t m, gint_t n, gint_t k, double* a, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, double alpha, double beta, bool isTransA, gint_t mx, gint_t* p_istart) +/**************************************************************/ +/* workspace memory allocation for 3m_sqp algorithm for zgemm */ +/**************************************************************/ +err_t allocate_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp, + gint_t mx, + gint_t nx, + gint_t k, + gint_t kx ) { - inc_t m2 = m<<1; - inc_t mxmul2 = mx<<1; - if((*p_istart) > (m2-mxmul2)) - { - return BLIS_SUCCESS; - } + //3m_sqp workspace Memory allocation /* B matrix */ - double* br, * bi, * bs; + // B matrix packed with n x k size. without kx smaller sizes for now. mem_block mbr, mbi, mbs; - if(bli_allocateWorkspace(n, k, &mbr, &mbi, &mbs)!=0) + if(bli_allocateWorkspace(nx, k, &mbr, &mbi, &mbs)!=0) { return BLIS_FAILURE; } - br = (double*)mbr.alignedBuf; - bi = (double*)mbi.alignedBuf; - bs = (double*)mbs.alignedBuf; - - //multiply lda, ldb and ldc by 2 to account for real and imaginary components per dcomplex. - lda = lda * 2; - ldb = ldb * 2; - ldc = ldc * 2; - - /* Split b (br, bi) and - compute bs = br + bi */ - double* pbr = br; - double* pbi = bi; - double* pbs = bs; + mem_3m_sqp->br = (double*)mbr.alignedBuf; + mem_3m_sqp->bi = (double*)mbi.alignedBuf; + mem_3m_sqp->bs = (double*)mbs.alignedBuf; + mem_3m_sqp->br_unaligned = (double*)mbr.unalignedBuf; + mem_3m_sqp->bi_unaligned = (double*)mbi.unalignedBuf; + mem_3m_sqp->bs_unaligned = (double*)mbs.unalignedBuf; - gint_t j; - - /* b matrix real and imag packing and compute. 
*/ - bli_packX_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); -#if 0//bug in above api to be fixed for mx = 1 - if((alpha ==1.0)||(alpha==-1.0)) - { - if(alpha ==1.0) - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = b[(j * ldb) + p]; - double bi_ = b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = -b[(j * ldb) + p]; - double bi_ = -b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = alpha * b[(j * ldb) + p]; - double bi_ = alpha * b[(j * ldb) + p + 1]; - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - } - } -#endif /* Workspace memory allocation currently done dynamically This needs to be taken from already allocated memory pool in application for better performance */ /* A matrix */ - double* ar, * ai, * as; mem_block mar, mai, mas; - if(bli_allocateWorkspace(mx, k, &mar, &mai, &mas) !=0) + if(bli_allocateWorkspace(mx, kx, &mar, &mai, &mas) !=0) { return BLIS_FAILURE; } - ar = (double*)mar.alignedBuf; - ai = (double*)mai.alignedBuf; - as = (double*)mas.alignedBuf; - + mem_3m_sqp->ar = (double*)mar.alignedBuf; + mem_3m_sqp->ai = (double*)mai.alignedBuf; + mem_3m_sqp->as = (double*)mas.alignedBuf; + mem_3m_sqp->ar_unaligned = (double*)mar.unalignedBuf; + mem_3m_sqp->ai_unaligned = (double*)mai.unalignedBuf; + mem_3m_sqp->as_unaligned = (double*)mas.unalignedBuf; /* w matrix */ - double* w; mem_block mw; mw.data_size = sizeof(double); - mw.size = mx * n; + mw.size = mx * nx; if (bli_getaligned(&mw) != 0) { return BLIS_FAILURE; } - w = (double*)mw.alignedBuf; - + mem_3m_sqp->w = (double*)mw.alignedBuf; + mem_3m_sqp->w_unaligned = (double*)mw.unalignedBuf; /* cr matrix */ - double* cr; mem_block mcr; mcr.data_size = sizeof(double); - mcr.size = mx * n; + mcr.size = mx * nx; if (bli_getaligned(&mcr) != 0) { return BLIS_FAILURE; } - cr = (double*)mcr.alignedBuf; + mem_3m_sqp->cr = (double*)mcr.alignedBuf; + mem_3m_sqp->cr_unaligned = (double*)mcr.unalignedBuf; /* ci matrix */ - double* ci; mem_block mci; mci.data_size = sizeof(double); - mci.size = mx * n; + mci.size = mx * nx; if (bli_getaligned(&mci) != 0) { return BLIS_FAILURE; } - ci = (double*)mci.alignedBuf; + mem_3m_sqp->ci = (double*)mci.alignedBuf; + mem_3m_sqp->ci_unaligned = (double*)mci.unalignedBuf; + + // A packing buffer + mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx); + if (mem_3m_sqp->aPacked == NULL) + { + return BLIS_FAILURE; + } + + return BLIS_SUCCESS; +} + +void free_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp) +{ + // A packing buffer free + bli_free_user(mem_3m_sqp->aPacked); + +#if MEM_ALLOC + if(mem_3m_sqp->ar_unaligned) + { + free(mem_3m_sqp->ar_unaligned); + } + if(mem_3m_sqp->ai_unaligned) + { + free(mem_3m_sqp->ai_unaligned); + } + if(mem_3m_sqp->as_unaligned) + { + free(mem_3m_sqp->as_unaligned); + } + + if(mem_3m_sqp->br_unaligned) + { + free(mem_3m_sqp->br_unaligned); + } + if(mem_3m_sqp->bi_unaligned) + { + free(mem_3m_sqp->bi_unaligned); + } + if(mem_3m_sqp->bs_unaligned) + { + free(mem_3m_sqp->bs_unaligned); + } + + if(mem_3m_sqp->w_unaligned) + { + free(mem_3m_sqp->w_unaligned); + } + if(mem_3m_sqp->cr_unaligned) + { + free(mem_3m_sqp->cr_unaligned); + } + 
if(mem_3m_sqp->ci_unaligned) + { + free(mem_3m_sqp->ci_unaligned); + } + +#else//MEM_ALLOC + /* free workspace buffers */ + bli_free_user(mem_3m_sqp->br); + bli_free_user(mem_3m_sqp->bi); + bli_free_user(mem_3m_sqp->bs); + bli_free_user(mem_3m_sqp->ar); + bli_free_user(mem_3m_sqp->ai); + bli_free_user(mem_3m_sqp->as); + bli_free_user(mem_3m_sqp->w); + bli_free_user(mem_3m_sqp->cr); + bli_free_user(mem_3m_sqp->ci); +#endif//MEM_ALLOC +} + +//3m_sqp m loop +BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + double alpha, + double beta, + bool isTransA, + gint_t mx, + gint_t* p_istart, + gint_t kx, + workspace_3m_sqp *mem_3m_sqp) +{ + inc_t m2 = m<<1; + inc_t mxmul2 = mx<<1; + + if((*p_istart) > (m2-mxmul2)) + { + return BLIS_SUCCESS; + } inc_t i; gint_t max_m = (m2-mxmul2); + + //get workspace + double* ar, * ai, * as; + ar = mem_3m_sqp->ar; + ai = mem_3m_sqp->ai; + as = mem_3m_sqp->as; + + double* br, * bi, * bs; + br = mem_3m_sqp->br; + bi = mem_3m_sqp->bi; + bs = mem_3m_sqp->bs; + + double* cr, * ci; + cr = mem_3m_sqp->cr; + ci = mem_3m_sqp->ci; + + double *w; + w = mem_3m_sqp->w; + + double* a_aligned; + a_aligned = mem_3m_sqp->aPacked; + + /* Split b (br, bi) and + compute bs = br + bi */ + double* pbr = br; + double* pbi = bi; + double* pbs = bs; + /* b matrix real and imag packing and compute. */ + bli_3m_sqp_packB_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); + for (i = (*p_istart); i <= max_m; i += mxmul2) //this loop can be threaded. { +#if KLP//kloop + int p = 0; + for(; p <= (k-kx); p += kx) + { + bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc, + isTransA, alpha, beta, mx, i, ar, ai, as, + br + p, bi + p, bs + p, cr, ci, w, a_aligned); + }// k loop end + + if(p>3)<<3); + + workspace_3m_sqp mem_3m_sqp; + + /* multiply lda, ldb and ldc by 2 to account for + real & imaginary components per dcomplex. */ + lda = lda * 2; + ldb = ldb * 2; + ldc = ldc * 2; + + /* user can set BLIS_MULTI_INSTANCE macro for + better performance while runing multi-instance use-case. + */ + dim_t multi_instance = bli_env_get_var( "BLIS_MULTI_INSTANCE", -1 ); + gint_t nx = n; + if(multi_instance>0) { - free(mar.unalignedBuf); + //limited nx size helps in reducing memory footprint in multi-instance case. + nx = 84; + // 84 is derived based on tuning results } - if(mai.unalignedBuf) + + if(nx>n) { - free(mai.unalignedBuf); + nx = n; } - if(mas.unalignedBuf) + + gint_t kx = k;// kx is configurable at run-time. +#if KLP + if (kx > k) { - free(mas.unalignedBuf); + kx = k; } - if(mw.unalignedBuf) + // for tn case there is a bug in handling k parts. To be fixed. + if(isTransA==true) { - free(mw.unalignedBuf); + kx = k; } - if(mcr.unalignedBuf) +#else + kx = k; +#endif + //3m_sqp workspace Memory allocation + if(allocate_3m_Sqp_workspace(&mem_3m_sqp, mx, nx, k, kx)!=BLIS_SUCCESS) { - free(mcr.unalignedBuf); + return BLIS_FAILURE; } - if(mci.unalignedBuf) + BLI_SQP_ZGEMM_N(mx) + *p_istart = (m-m8rem)*2; + + if(m8rem!=0) { - free(mci.unalignedBuf); + //complete residue m blocks + BLI_SQP_ZGEMM_N(m8rem) } - if(mbr.unalignedBuf) + + free_3m_Sqp_workspace(&mem_3m_sqp); + return status; +} + +/****************************************************************************/ +/*********************** dgemm_sqp implementation****************************/ +/****************************************************************************/ +/* dgemm_sqp implementation packs A matrix based on lda and m size. 
+ dgemm_sqp focuses mainly on square matrixes but also supports non-square matrix. + Current support is limiteed to m multiple of 8 and column storage. + C = AxB and C = AtxB is handled in the design. + AtxB case is done by transposing A matrix while packing A. + In majority of use-case, alpha are +/-1, so instead of explicitly multiplying + alpha its done during packing itself by changing sign. +*/ +BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, + gint_t n, + gint_t k, + double* a, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + double alpha, + double beta, + bool isTransA, + dim_t nt) +{ + gint_t istart = 0; + gint_t* p_istart = &istart; + *p_istart = 0; + err_t status = BLIS_SUCCESS; + dim_t m8rem = m - ((m>>3)<<3); + + /* dgemm implementation with 8mx5n major kernel and column preferred storage */ + gint_t mx = 8; + gint_t kx = k; + double* a_aligned = NULL; + + if(nt<=1)//single pack buffer allocated for single thread case { - free(mbr.unalignedBuf); + a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); } - if(mbi.unalignedBuf) + gint_t nx = n;//MAX; + if(nx>n) { - free(mbi.unalignedBuf); + nx = n; } - if(mbs.unalignedBuf) + //mx==8 case for dgemm. + BLI_SQP_DGEMM_N(mx) + *p_istart = (m-m8rem); + + if(nt>1) { - free(mbs.unalignedBuf); + //2nd level thread for mx=8 + gint_t rem_m = m - (*p_istart); + if((rem_m>=mx)&&(status==BLIS_SUCCESS)) + { + status = bli_sqp_dgemm_m8( m, n, k, a, lda, b, ldb, c, ldc, + isTransA, alpha, mx, p_istart, kx, a_aligned); + } } -#else - /* free workspace buffers */ - bli_free_user(mbr.alignedBuf); - bli_free_user(mbi.alignedBuf); - bli_free_user(mbs.alignedBuf); - bli_free_user(mar.alignedBuf); - bli_free_user(mai.alignedBuf); - bli_free_user(mas.alignedBuf); - bli_free_user(mw.alignedBuf); - bli_free_user(mcr.alignedBuf); - bli_free_user(mci.alignedBuf); -#endif - return BLIS_SUCCESS; + + if(status==BLIS_SUCCESS) + { + if(m8rem!=0) + { + //complete residue m blocks + BLI_SQP_DGEMM_N(m8rem) + } + } + + if(nt<=1)//single pack buffer allocated for single thread case + { + bli_free_user(a_aligned); + } + return status; } \ No newline at end of file diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c new file mode 100644 index 0000000000..4762fb8314 --- /dev/null +++ b/kernels/zen/3/bli_gemm_sqp_kernels.c @@ -0,0 +1,1588 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" +#include "bli_gemm_sqp_kernels.h" + +#define BLIS_LOADFIRST 0 +#define BLIS_ENABLE_PREFETCH 1 + +#define BLIS_MX8 8 +#define BLIS_MX4 4 +#define BLIS_MX1 1 + +/****************************************************************************/ +/*************** dgemm kernels (8mxn) column preffered *********************/ +/****************************************************************************/ + +/* Main dgemm kernel 8mx6n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + + __m256d av0, av1; + __m256d bv0, bv1; + __m256d cv0, cv1, cv2, cv3, cv4, cv5; + __m256d cx0, cx1, cx2, cx3, cx4, cx5; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; + + for (j = 0; j <= (n - 6); j += 6) { + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + double* pcldc5 = pcldc4 + ldc; + + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; + double* pbldb5 = pbldb4 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); + cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); + cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; + bv0 = _mm256_broadcast_sd (pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + cv0 = _mm256_fmadd_pd(av0, bv0, 
cv0); + cx0 = _mm256_fmadd_pd(av1, bv0, cx0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cx1 = _mm256_fmadd_pd(av1, bv1, cx1); + + bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; + bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; + cv2 = _mm256_fmadd_pd(av0, bv0, cv2); + cx2 = _mm256_fmadd_pd(av1, bv0, cx2); + cv3 = _mm256_fmadd_pd(av0, bv1, cv3); + cx3 = _mm256_fmadd_pd(av1, bv1, cx3); + + bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; + bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; + cv4 = _mm256_fmadd_pd(av0, bv0, cv4); + cx4 = _mm256_fmadd_pd(av1, bv0, cx4); + cv5 = _mm256_fmadd_pd(av0, bv1, cv5); + cx5 = _mm256_fmadd_pd(av1, bv1, cx5); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); + + av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); + cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + _mm256_storeu_pd(pcldc5, cv5); + _mm256_storeu_pd(pcldc5 + 4, cx5); + + pc += ldc6;pb += ldb6; + } + //printf(" 8x6:j:%d ", j); + return j; +} + +/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_8mx5n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; + + for (; j <= (n - 5); j += 5) { + + double* pcldc = pc + ldc; + double* pcldc2 = pcldc + ldc; + double* pcldc3 = pcldc2 + ldc; + double* pcldc4 = pcldc3 + ldc; + + double* pbldb = pb + ldb; + double* pbldb2 = pbldb + ldb; + double* pbldb3 = pbldb2 + ldb; + double* pbldb4 = pbldb3 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = 
_mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); +#else + cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + cx4 = _mm256_fmadd_pd(av0, bv4, cx4); + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); + cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); + + bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); + cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); + + bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); + cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); + + bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); + cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); + + bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); + cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + _mm256_storeu_pd(pcldc4, cv4); + _mm256_storeu_pd(pcldc4 + 4, cx4); + + pc += ldc5;pb += ldb5; + } + //printf(" 8x5:j:%d ", j); + return j; +} + +/* residue dgemm kernel 8mx4n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_sqp_dgemm_kernel_8mx4n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; + + for (; j <= (n - 4); j += 4) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + // better kernel to be written since more register are available. + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + cx3 = _mm256_fmadd_pd(av0, bv3, cx3); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc3 + 4, cx3); + + pc += ldc4;pb += ldb4; + }// j loop 4 multiple + //printf(" 8x4:j:%d ", j); + return j; +} + +/* residue dgemm kernel 8mx3n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_sqp_dgemm_kernel_8mx3n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1, bv2; + __m256d cv0, cv1, cv2; + __m256d cx0, cx1, cx2; + double* pb, * pc; + + pb = b; + pc = c; + + inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; + + for (; j <= (n - 3); j += 3) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + cx2 = _mm256_fmadd_pd(av0, bv2, cx2); + } + } + + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc2 + 4, cx2); + + pc += ldc3;pb += ldb3; + }// j loop 3 multiple + //printf(" 8x3:j:%d ", j); + return j; +} + +/* residue dgemm kernel 8mx2n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_8mx2n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0, bv1; + __m256d cv0, cv1; + __m256d cx0, cx1; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; + + for (; j <= (n - 2); j += 2) { + double* pcldc = pc + ldc; + double* pbldb = pb + ldb; + + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); + { + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + cx1 = _mm256_fmadd_pd(av0, bv1, cx1); + } + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc + 4, cx1); + + pc += ldc2;pb += ldb2; + }// j loop 2 multiple + //printf(" 8x2:j:%d ", j); + return j; +} + +/* residue dgemm kernel 8mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_sqp_dgemm_kernel_8mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + __m256d cx0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + + av0 = _mm256_loadu_pd(x); x += 4; + cx0 = _mm256_fmadd_pd(av0, bv0, cx0); + } + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pc + 4, cx0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" 8x1:j:%d ", j); + return j; +} + +#if 0 +/************************************************************************************************************/ +/************************** dgemm kernels (4mxn) column preffered ******************************************/ +/************************************************************************************************************/ +/* Residue dgemm kernel 4mx10n with single load and store of C matrix block + alpha = +/-1 and beta = +/-1,0 handled while packing.*/ +inc_t bli_sqp_dgemm_kernel_4mx10n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + /* incomplete */ + __m256d av0; + __m256d bv0, bv1, bv2, bv3; + __m256d cv0, cv1, cv2, cv3; + __m256d cx0, cx1, cx2, cx3; + __m256d bv4, cv4, cx4; + double* pb, * pc; + + pb = b; + pc = c; + inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; + + for (j = 0; j <= (n - 10); j += 10) { + + double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; + double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; + +#if BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(pc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); + _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); + + _mm_prefetch((char*)(aPacked), _MM_HINT_T0); + + _mm_prefetch((char*)(pb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); + _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); +#endif + /* C matrix column major load */ +#if BLIS_LOADFIRST + cv0 = _mm256_loadu_pd(pc); + cv1 = _mm256_loadu_pd(pcldc); + cv2 = _mm256_loadu_pd(pcldc2); + cv3 = _mm256_loadu_pd(pcldc3); + cv4 = _mm256_loadu_pd(pcldc4); +#else + cv0 = _mm256_setzero_pd(); + cv1 = _mm256_setzero_pd(); + cv2 = _mm256_setzero_pd(); + cv3 = _mm256_setzero_pd(); + cv4 = _mm256_setzero_pd(); +#endif + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + bv1 = _mm256_broadcast_sd(pbldb); pbldb++; + bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; + bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; + bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; + + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cv1 = _mm256_fmadd_pd(av0, bv1, cv1); + cv2 = _mm256_fmadd_pd(av0, bv2, cv2); + cv3 = _mm256_fmadd_pd(av0, bv3, cv3); + cv4 = _mm256_fmadd_pd(av0, bv4, cv4); + + } +#if BLIS_LOADFIRST +#else + bv0 = _mm256_loadu_pd(pc); + cv0 = _mm256_add_pd(cv0, bv0); + + bv2 = 
_mm256_loadu_pd(pcldc); + cv1 = _mm256_add_pd(cv1, bv2); + + bv0 = _mm256_loadu_pd(pcldc2); + cv2 = _mm256_add_pd(cv2, bv0); + + bv2 = _mm256_loadu_pd(pcldc3); + cv3 = _mm256_add_pd(cv3, bv2); + + bv0 = _mm256_loadu_pd(pcldc4); + cv4 = _mm256_add_pd(cv4, bv0); +#endif + /* C matrix column major store */ + _mm256_storeu_pd(pc, cv0); + _mm256_storeu_pd(pcldc, cv1); + _mm256_storeu_pd(pcldc2, cv2); + _mm256_storeu_pd(pcldc3, cv3); + _mm256_storeu_pd(pcldc4, cv4); + + + pc += ldc10;pb += ldb10; + } + + return j; +} + +/* residue dgemm kernel 4mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. +*/ +inc_t bli_sqp_dgemm_kernel_4mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + __m256d av0; + __m256d bv0; + __m256d cv0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + cv0 = _mm256_loadu_pd(pc); + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + bv0 = _mm256_broadcast_sd(pb0); pb0++; + av0 = _mm256_loadu_pd(x); x += 4; + cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + } + _mm256_storeu_pd(pc, cv0); + pc += ldc;pb += ldb; + }// j loop 1 multiple + return j; +} + +#endif +/************************************************************************************************************/ +/************************** dgemm kernels (1mxn) column preffered ******************************************/ +/************************************************************************************************************/ + +/* residue dgemm kernel 1mx1n with single load and store of C matrix block + Code could be optimized further, complete ymm register set is not used. + Being residue kernel, its of lesser priority. 
+*/ +inc_t bli_sqp_dgemm_kernel_1mx1n( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc) +{ + gint_t p; + double a0; + double b0; + double c0; + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + c0 = *pc; + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + b0 = *pb0; pb0++; + a0 = *x; x++; + c0 += (a0 * b0); + } + *pc = c0; + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" 1x1:j:%d ", j); + return j; +} + +inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, + gint_t k, + gint_t j, + double* aPacked, + guint_t lda, + double* b, + guint_t ldb, + double* c, + guint_t ldc, + gint_t mx) +{ + gint_t p; + double cx[7]; + + double* pb, * pc; + + pb = b; + pc = c; + + for (; j <= (n - 1); j += 1) { + //cv0 = _mm256_loadu_pd(pc); + for (int i = 0; i < mx; i++) + { + cx[i] = *(pc + i); + } + + double* x = aPacked; + double* pb0 = pb; + for (p = 0; p < k; p += 1) { + //bv0 = _mm256_broadcast_sd(pb0); + double b0 = *pb0; + pb0++; + for (int i = 0; i < mx; i++) + { + cx[i] += (*(x + i)) * b0;//cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + } + //av0 = _mm256_loadu_pd(x); + x += mx; + } + //_mm256_storeu_pd(pc, cv0); + for (int i = 0; i < mx; i++) + { + *(pc + i) = cx[i]; + } + pc += ldc;pb += ldb; + }// j loop 1 multiple + //printf(" mx1:j:%d ", j); + return j; +} + +void bli_sqp_prepackA( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + //printf(" pmx:%d ",mx); + if(mx==8) + { + bli_prepackA_8(pa,aPacked,k, lda,isTransA, alpha); + } + else if(mx==4) + { + bli_prepackA_4(pa,aPacked,k, lda,isTransA, alpha); + } + else if(mx>4) + { + bli_prepackA_G4(pa,aPacked,k, lda,isTransA, alpha, mx); + } + else + { + bli_prepackA_L4(pa,aPacked,k, lda,isTransA, alpha, mx); + } +} + +/* Ax8 packing subroutine */ +void bli_prepackA_8(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + __m256d av0, av1, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += BLIS_MX8; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); + aPacked += BLIS_MX8; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX8; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX8 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX8; + *(aPacked + sidx + i) = -ar_; + } + } + } + } +} + +/* Ax4 packing subroutine */ +void bli_prepackA_4(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + __m256d av0, ymm0; + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + _mm256_storeu_pd(aPacked, av0); + aPacked += 
BLIS_MX4; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); pa += lda; + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + aPacked += BLIS_MX4; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < BLIS_MX4 ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * BLIS_MX4; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + +} + +/* A packing m>4 subroutine */ +void bli_prepackA_G4( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + __m256d av0, ymm0; + gint_t mrem = mx - 4; + + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); + _mm256_storeu_pd(aPacked, av0); + for (gint_t i = 0; i < mrem; i += 1) { + *(aPacked+4+i) = *(pa+4+i); + } + aPacked += mx;pa += lda; + } + } + else if(alpha==-1.0) + { + ymm0 = _mm256_setzero_pd();//set zero + for (gint_t p = 0; p < k; p += 1) { + av0 = _mm256_loadu_pd(pa); + av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; + _mm256_storeu_pd(aPacked, av0); + for (gint_t i = 0; i < mrem; i += 1) { + *(aPacked+4+i) = -*(pa+4+i); + } + aPacked += mx;pa += lda; + } + } + } + else //subroutine below to be optimized + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + +} + +/* A packing m<4 subroutine */ +void bli_prepackA_L4( double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha, + gint_t mx) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) + { + for (gint_t i = 0; i < mx; i += 1) + { + *(aPacked+i) = *(pa+i); + } + aPacked += mx;pa += lda; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 0; p < k; p += 1) + { + for (gint_t i = 0; i < mx; i += 1) + { + *(aPacked+i) = -*(pa+i); + } + aPacked += mx;pa += lda; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = ar_; + } + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t i = 0; i < mx ; i++) + { + gint_t idx = i * lda; + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+idx+p); + gint_t sidx = p * mx; + *(aPacked + sidx + i) = -ar_; + } + } + } + } + + +} + +/* Ax1 packing subroutine */ +void bli_prepackA_1(double* pa, + double* aPacked, + gint_t k, + guint_t lda, + bool isTransA, + double alpha) +{ + if(isTransA==false) + { + if(alpha==1.0) + { + for (gint_t p = 0; p < k; p += 1) { + *aPacked = *pa; + pa += lda; + aPacked++; + } + } + else if(alpha==-1.0) + { + for (gint_t p = 
0; p < k; p += 1) { + *aPacked = -(*pa); + pa += lda; + aPacked++; + } + } + } + else + { + if(alpha==1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = ar_; + } + } + else if(alpha==-1.0) + { + //A Transpose case: + for (gint_t p = 0; p < k; p ++) + { + double ar_ = *(pa+p); + *(aPacked + p) = -ar_; + } + } + } +} + + +void bli_add_m( gint_t m, + gint_t n, + double* w, + double* c) +{ + double* pc = c; + double* pw = w; + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_add_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) + { + *pc = *pc + *pw; + pc++; pw++; + } +} + +void bli_sub_m( gint_t m, + gint_t n, + double* w, + double* c) +{ + double* pc = c; + double* pw = w; + gint_t count = m*n; + gint_t i = 0; + __m256d cv0, wv0; + + for (; i <= (count-4); i+=4) + { + cv0 = _mm256_loadu_pd(pc); + wv0 = _mm256_loadu_pd(pw); pw += 4; + cv0 = _mm256_sub_pd(cv0,wv0); + _mm256_storeu_pd(pc, cv0); pc += 4; + } + for (; i < count; i++) + { + *pc = *pc - *pw; + pc++; pw++; + } +} + +/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ +void bli_3m_sqp_packC_real_imag(double* pc, + guint_t n, + guint_t m, + guint_t ldc, + double* pcr, + double* pci, + double mul, + gint_t mx) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1; + gint_t max_m = (m*2)-8; + + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) /* handles alpha or beta = 1.0 */ + { + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_m; p += 8) + { + double* pbp = pc + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + _mm256_storeu_pd(pcr, av0); pcr += 4; + _mm256_storeu_pd(pci, av1); pci += 4; + } + + for (; p < (m*2); p += 2)// (real + imag)*m + { + double br = *(pc + p) ; + double bi = *(pc + p + 1); + *pcr = br; + *pci = bi; + pcr++; pci++; + } + pc = pc + ldc; + } + } + else /* handles alpha or beta = - 1.0 */ + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_m; p += 8) + { + double* pbp = pc + p; + av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + _mm256_storeu_pd(pcr, av0); pcr += 4; + _mm256_storeu_pd(pci, av1); pci += 4; + } + + for (; p < (m*2); p += 2)// (real + imag)*m + { + double br = -*(pc + p) ; + double bi = -*(pc + p + 1); + *pcr = br; + *pci = bi; + pcr++; pci++; + } + pc = pc + ldc; + } + } + } + else /* handles alpha or beta is not equal +/- 1.0 */ + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (m*2); p += 2)// (real + imag)*m + { + double br_ = mul * (*(pc + p)); + double bi_ = mul * (*(pc + p + 1)); + *pcr = br_; + *pci = bi_; + pcr++; pci++; + } + pc = pc + ldc; + } + } +} + +/* Pack real and imaginary parts in 
separate buffers and compute sum of real and imaginary part */ +void bli_3m_sqp_packB_real_imag_sum(double* pb, + guint_t n, + guint_t k, + guint_t ldb, + double* pbr, + double* pbi, + double* pbs, + double mul, + gint_t mx) +{ + gint_t j, p; + __m256d av0, av1, zerov; + __m256d tv0, tv1, sum; + gint_t max_k = (k*2) - 8; + if((mul ==1.0)||(mul==-1.0)) + { + if(mul ==1.0) + { + for (j = 0; j < n; j++) + { + for (p=0; p <= max_k; p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = *(pb + p) ; + double bi = *(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + else + { + zerov = _mm256_setzero_pd(); + for (j = 0; j < n; j++) + { + for (p = 0; p <= max_k; p += 8) + { + double* pbp = pb + p; + av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 + av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 + av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 + av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 + + //negate + av0 = _mm256_sub_pd(zerov,av0); + av1 = _mm256_sub_pd(zerov,av1); + + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(pbr, av0); pbr += 4; + _mm256_storeu_pd(pbi, av1); pbi += 4; + _mm256_storeu_pd(pbs, sum); pbs += 4; + } + + for (; p < (k*2); p += 2)// (real + imag)*k + { + double br = -*(pb + p) ; + double bi = -*(pb + p + 1); + *pbr = br; + *pbi = bi; + *pbs = br + bi; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } + } + else + { + for (j = 0; j < n; j++) + { + for (p = 0; p < (k*2); p += 2)// (real + imag)*k + { + double br_ = mul * (*(pb + p)); + double bi_ = mul * (*(pb + p + 1)); + *pbr = br_; + *pbi = bi_; + *pbs = br_ + bi_; + + pbr++; pbi++; pbs++; + } + pb = pb + ldb; + } + } +} + +/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ +void bli_3m_sqp_packA_real_imag_sum(double *pa, + gint_t i, + guint_t k, + guint_t lda, + double *par, + double *pai, + double *pas, + bool isTransA, + gint_t mx, + gint_t p) +{ + __m256d av0, av1, av2, av3; + __m256d tv0, tv1, sum; + gint_t poffset = p; +#if KLP + //k = p + k; +#endif + if(mx==8) + { + if(isTransA==false) + { + pa = pa +i; +#if KLP + pa = pa + (p*lda); +#else + p = 0; +#endif + //printf("packA from p_%d to p_%d \n", p, k); + for (; p < k; p += 1) + { + //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
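+                // Load eight dcomplex elements of A (one 8-row slice) and
+                // de-interleave the (re,im) pairs into vectors of real parts
+                // and vectors of imaginary parts; their element-wise sums are
+                // also stored into the "as" buffer used by the 3m computation.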
+ av0 = _mm256_loadu_pd(pa); + av1 = _mm256_loadu_pd(pa+4); + av2 = _mm256_loadu_pd(pa+8); + av3 = _mm256_loadu_pd(pa+12); + + tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); + tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); + av0 = _mm256_unpacklo_pd(tv0, tv1); + av1 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av0, av1); + _mm256_storeu_pd(par, av0); par += 4; + _mm256_storeu_pd(pai, av1); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); + tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); + av2 = _mm256_unpacklo_pd(tv0, tv1); + av3 = _mm256_unpackhi_pd(tv0, tv1); + sum = _mm256_add_pd(av2, av3); + _mm256_storeu_pd(par, av2); par += 4; + _mm256_storeu_pd(pai, av3); pai += 4; + _mm256_storeu_pd(pas, sum); pas += 4; + + pa = pa + lda; + } + } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; +#if KLP + //pa = pa + p; +#else + p = 0; +#endif + //A Transpose case: + for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + gint_t pidx = 0; + gint_t max_k = (k*2) - 8; + for (p = poffset; p <= max_k; p += 8) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + double ar1_ = *(pa + idx + p + 2); + double ai1_ = *(pa + idx + p + 3); + + double ar2_ = *(pa + idx + p + 4); + double ai2_ = *(pa + idx + p + 5); + + double ar3_ = *(pa + idx + p + 6); + double ai3_ = *(pa + idx + p + 7); + + sidx = (pidx/2) * BLIS_MX8; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + + sidx = ((pidx+2)/2) * BLIS_MX8; + *(par + sidx + ii) = ar1_; + *(pai + sidx + ii) = ai1_; + *(pas + sidx + ii) = ar1_ + ai1_; + + sidx = ((pidx+4)/2) * BLIS_MX8; + *(par + sidx + ii) = ar2_; + *(pai + sidx + ii) = ai2_; + *(pas + sidx + ii) = ar2_ + ai2_; + + sidx = ((pidx+6)/2) * BLIS_MX8; + *(par + sidx + ii) = ar3_; + *(pai + sidx + ii) = ai3_; + *(pas + sidx + ii) = ar3_ + ai3_; + pidx += 8; + + } + + for (; p < (k*2); p += 2) + { + double ar_ = *(pa + idx + p); + double ai_ = *(pa + idx + p + 1); + gint_t sidx = (pidx/2) * BLIS_MX8; + *(par + sidx + ii) = ar_; + *(pai + sidx + ii) = ai_; + *(pas + sidx + ii) = ar_ + ai_; + pidx += 2; + } + } + } + } //mx==8 + else//mx==1 + { + if(isTransA==false) + { + pa = pa + i; +#if KLP + //pa = pa + (p*lda); done below.. not needed +#else + p = 0; +#endif + //printf(" packAx1 from p_%d to p_%d ",p,k-1); + //A No transpose case: + for (; p < k; p += 1) + { + gint_t idx = p * lda; + for (gint_t ii = 0; ii < (mx*2) ; ii += 2) + { //real + imag : Rkernel needs 8 elements each. + double ar_ = *(pa + idx + ii); + double ai_ = *(pa + idx + ii + 1); + *par = ar_; + *pai = ai_; + *pas = ar_ + ai_; + par++; pai++; pas++; + } + } + } + else + { + gint_t idx = (i/2) * lda; + pa = pa + idx; +#if KLP + //pa = pa + p; done below.. 
not needed +#else + p = 0; +#endif + //A Transpose case: + for (gint_t ii = 0; ii < mx ; ii++) + { + gint_t idx = ii * lda; + gint_t sidx; + gint_t pidx = 0; + for (p = poffset;p < (k*2); p += 2) + { + double ar0_ = *(pa + idx + p); + double ai0_ = *(pa + idx + p + 1); + + sidx = (pidx/2) * mx; + *(par + sidx + ii) = ar0_; + *(pai + sidx + ii) = ai0_; + *(pas + sidx + ii) = ar0_ + ai0_; + pidx += 2; + + } + } + } + }//mx==1 +} + diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.h b/kernels/zen/3/bli_gemm_sqp_kernels.h new file mode 100644 index 0000000000..5d5405e0bb --- /dev/null +++ b/kernels/zen/3/bli_gemm_sqp_kernels.h @@ -0,0 +1,65 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +/* square packed (sqp) kernels */ +#define KLP 1// k loop partition. + +/* sqp dgemm core kernels, targetted mainly for square sizes by default. 
+ sqp framework allows tunning for other shapes.*/ +inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); +inc_t bli_sqp_dgemm_kernel_mxn(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, gint_t mx); + +//add and sub kernels +void bli_add_m(gint_t m,gint_t n,double* w,double* c); +void bli_sub_m(gint_t m, gint_t n, double* w, double* c); + +//packing kernels +//Pack A with alpha multiplication +void bli_sqp_prepackA(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); + +void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); +void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); +void bli_prepackA_G4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); +void bli_prepackA_L4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); +void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); + +/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ +void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx); +void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx); +void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx, gint_t p); \ No newline at end of file From 35ad5d8dc5b68adb41246f2bbd6783b96a7e74c9 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Mon, 5 Jul 2021 18:40:34 +0530 Subject: [PATCH 08/13] 3m_sqp conjugate support added 1. 3m_sqp support for A matrix with conjugate_no_transpose and conjugate_transpose added. AMD-Internal: [CPUPL-1521] Change-Id: Ie6e5c49cf86f7d3b95d78705cf445e57f20b3d1f --- kernels/zen/3/bli_gemm_sqp.c | 46 +++++-- kernels/zen/3/bli_gemm_sqp_kernels.c | 171 +++++++++++++++++++++++++-- kernels/zen/3/bli_gemm_sqp_kernels.h | 2 +- 3 files changed, 193 insertions(+), 26 deletions(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index 1622c551c4..ceab622bf3 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -40,16 +40,31 @@ #define BLIS_LOADFIRST 0 #define MEM_ALLOC 1//malloc performs better than bli_malloc. 
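+// SET_TRANS collapses the transpose and conjugation flags of matrix A into a
+// single trans_t value (BLIS_NO_TRANSPOSE, BLIS_TRANSPOSE,
+// BLIS_CONJ_NO_TRANSPOSE or BLIS_CONJ_TRANSPOSE), which is then passed down
+// the 3m_sqp call path in place of the earlier isTransA flag.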
+#define SET_TRANS(X,Y)\ + Y = BLIS_NO_TRANSPOSE;\ + if(bli_obj_has_trans( a ))\ + {\ + Y = BLIS_TRANSPOSE;\ + if(bli_obj_has_conj(a))\ + {\ + Y = BLIS_CONJ_TRANSPOSE;\ + }\ + }\ + else if(bli_obj_has_conj(a))\ + {\ + Y = BLIS_CONJ_NO_TRANSPOSE;\ + } + //Macro for 3m_sqp n loop #define BLI_SQP_ZGEMM_N(MX)\ int j=0;\ for(; j<=(n-nx); j+= nx)\ {\ - status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\ + status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\ }\ if(j Date: Fri, 10 Sep 2021 09:12:59 +0530 Subject: [PATCH 09/13] Induced method turned off, fix for beta=0 & C = NAN 1. Induced Method turned off, till the path fully tested for different alpha,beta conditions. 2. Fix for Beta =0, and C = NAN done. Change-Id: I5a7bd1393ac245c2ebb72f9a634728af4c0d4000 --- frame/compat/bla_gemm.c | 2 +- kernels/zen/3/bli_gemm_sqp_kernels.c | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 3eafad1903..1bf0a81cb8 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -38,7 +38,7 @@ // // Define BLAS-to-BLIS interfaces. // -#define ENABLE_INDUCED_METHOD 1 +#define ENABLE_INDUCED_METHOD 0 #ifdef BLIS_BLAS3_CALLS_TAPI #undef GENTFUNC diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c index 9cac5e83eb..0f20c0a956 100644 --- a/kernels/zen/3/bli_gemm_sqp_kernels.c +++ b/kernels/zen/3/bli_gemm_sqp_kernels.c @@ -1278,7 +1278,22 @@ void bli_3m_sqp_packC_real_imag(double* pc, } } } - else /* handles alpha or beta is not equal +/- 1.0 */ + else if(mul==0) /* handles alpha or beta is equal to zero */ + { + double br_ = 0; + double bi_ = 0; + for (j = 0; j < n; j++) + { + for (p = 0; p < (m*2); p += 2)// (real + imag)*m + { + *pcr = br_; + *pci = bi_; + pcr++; pci++; + } + pc = pc + ldc; + } + } + else /* handles alpha or beta is not equal +/- 1.0 and zero */ { for (j = 0; j < n; j++) { From 7cd79683d4764d908bda20011386fcd20ef639cb Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 15 Dec 2021 14:29:57 +0530 Subject: [PATCH 10/13] compile error fixes 1. New err_t param in bli_malloc_user added. 2. AOCL_DTL log removed. --- kernels/zen/3/bli_gemm_sqp.c | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index ceab622bf3..34636ccaea 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -186,7 +186,7 @@ err_t bli_gemm_sqp cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + //AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); // if row major format return. 
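/*
 * Editorial note (not part of this patch): the mul == 0 branch added to
 * bli_3m_sqp_packC_real_imag() above exists because scaling C by zero does
 * not clear NaN/Inf values already present in C (IEEE-754: 0.0 * NaN == NaN).
 * For beta == 0 the packed planes are therefore overwritten with zeros rather
 * than multiplied.  Minimal sketch; the helper name is hypothetical:
 */
static void sqp_apply_beta( double* c, gint_t len, double beta )
{
    if ( beta == 0.0 )
    {
        /* Overwrite: stale NaN/Inf in C is discarded, as BLAS semantics require. */
        for ( gint_t i = 0; i < len; i++ ) c[ i ] = 0.0;
    }
    else
    {
        /* Scale: if this path were also used for beta == 0.0, a NaN already
           stored in c[i] would survive, which is the bug this patch fixes. */
        for ( gint_t i = 0; i < len; i++ ) c[ i ] *= beta;
    }
}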
if ((bli_obj_row_stride( a ) != 1) || @@ -277,7 +277,7 @@ err_t bli_gemm_sqp return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + //AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -377,8 +377,9 @@ void bli_sqp_dgemm_m( gint_t i_start, #if SQP_THREAD_ENABLE if(pack_on==true) { + err_t r_val; //NEEDED IN THREADING CASE: - aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); + aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); if(aligned==NULL) { return BLIS_MALLOC_RETURNED_NULL;// return to be removed @@ -594,7 +595,8 @@ gint_t bli_getaligned(mem_block* mem_req) address += (-address) & 63; //64 bytes alignment done. mem_req->alignedBuf = (double*)address; #else - mem_req->alignedBuf = bli_malloc_user( memSize ); + err_t r_val; + mem_req->alignedBuf = bli_malloc_user( memSize, &r_val); if (mem_req->alignedBuf == NULL) { return -1; @@ -814,7 +816,8 @@ err_t allocate_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp, mem_3m_sqp->ci_unaligned = (double*)mci.unalignedBuf; // A packing buffer - mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx); + err_t r_val; + mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); if (mem_3m_sqp->aPacked == NULL) { return BLIS_FAILURE; @@ -1162,7 +1165,8 @@ BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, if(nt<=1)//single pack buffer allocated for single thread case { - a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); + err_t r_val; + a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); } gint_t nx = n;//MAX; From 59029ee1739baf05eecebc492814adc944c3d35c Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 15 Dec 2021 15:06:43 +0530 Subject: [PATCH 11/13] Revert "sup zgemm improvement" This reverts commit 231a464a02fd3e07cda063dfaa1b1b200599e2c9. --- frame/3/bli_l3_sup_int.c | 7 ------- 1 file changed, 7 deletions(-) diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index e1deb32907..e54e01d7c7 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -181,13 +181,6 @@ err_t bli_gemmsup_int if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; - // In zgemm, mkernel outperforms nkernel for both m > n and n < m. - // mkernel is forced for zgemm. - if(bli_is_dcomplex(dt)) - { - use_bp = TRUE;//mkernel - } - // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. From b3e82baf62f9d2492286b341c95a71b1a63c2c71 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 15 Dec 2021 16:24:59 +0530 Subject: [PATCH 12/13] code clean and comments added --- kernels/zen/3/bli_gemm_sqp.c | 77 ++++++++------- kernels/zen/3/bli_gemm_sqp_kernels.c | 134 ++------------------------- 2 files changed, 52 insertions(+), 159 deletions(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index 34636ccaea..60a7aa9436 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -35,10 +35,11 @@ #include "immintrin.h" #include "bli_gemm_sqp_kernels.h" -#define SQP_THREAD_ENABLE 0//currently disabled +#define SQP_THREAD_ENABLE 0// Currently disabled: simple threading model along m dimension. + // Works well when m is large and other dimensions are relatively smaller. 
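/*
 * Editorial illustration (not part of this patch): the m-dimension work split
 * referred to in the comment above.  Rows are processed in blocks of mx, so
 * one mx-row block is the unit of work; each thread receives a contiguous
 * range of blocks and the last thread also absorbs any remainder.  Sketch
 * only -- the in-tree code additionally clamps the thread count against
 * BLI_SQP_MAX_THREADS, the number of work items, and the caller's p_istart:
 */
static void sqp_m_ranges( gint_t m, gint_t mx, gint_t num_threads,
                          gint_t* i_start, gint_t* i_end )
{
    gint_t workitems     = m / mx;                   /* number of mx-row blocks */
    gint_t mx_per_thread = workitems / num_threads;  /* blocks per thread       */

    for ( gint_t t = 0; t < num_threads; t++ )
    {
        i_start[ t ] = ( mx_per_thread * t )         * mx;
        i_end  [ t ] = ( mx_per_thread * ( t + 1 ) ) * mx;

        if ( t == ( num_threads - 1 ) ) i_end[ t ] = m; /* last thread: leftovers */
    }
}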
#define BLI_SQP_MAX_THREADS 128 #define BLIS_LOADFIRST 0 -#define MEM_ALLOC 1//malloc performs better than bli_malloc. +#define MEM_ALLOC 1 // malloc performs better than bli_malloc. Reason to be understood. #define SET_TRANS(X,Y)\ Y = BLIS_NO_TRANSPOSE;\ @@ -55,7 +56,7 @@ Y = BLIS_CONJ_NO_TRANSPOSE;\ } -//Macro for 3m_sqp n loop +//Macro for 3m_sqp nx loop #define BLI_SQP_ZGEMM_N(MX)\ int j=0;\ for(; j<=(n-nx); j+= nx)\ @@ -67,7 +68,7 @@ status = bli_sqp_zgemm_m8( m, n-j, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\ } -//Macro for sqp_dgemm n loop +//Macro for sqp_dgemm nx loop #define BLI_SQP_DGEMM_N(MX)\ int j=0;\ for(; j<=(n-nx); j+= nx)\ @@ -88,20 +89,25 @@ typedef struct { // 3m_sqp workspace data-structure typedef struct { - double *ar; - double *ai; - double *as; - double *br; - double *bi; - double *bs; + //list of aligned buffers used in 3m_sqp algorithm + double *ar; //A matrix real component buffer pointer + double *ai; //A matrix imaginary component buffer pointer + double *as; //as = ar + ai - double *cr; - double *ci; + double *br; //B matrix real component buffer pointer + double *bi; //B matrix complex component buffer pointer + double *bs; //bs = br + bi - double *w; - double *aPacked; + double *cr; //C matrix real component buffer pointer + double *ci; //C matrix imaginary component buffer pointer + + double *w; //workspace buffer + double *aPacked; //A packed buffer + + /* Other unaligned buffer, which are allocated and + buffer pointers are stored to free at the end task completion. */ double *ar_unaligned; double *ai_unaligned; double *as_unaligned; @@ -117,7 +123,7 @@ typedef struct { }workspace_3m_sqp; -//sqp threading datastructure +//sqp threading datastructure: typedef struct bli_sqp_thread_info { gint_t i_start; @@ -186,7 +192,6 @@ err_t bli_gemm_sqp cntl_t* cntl ) { - //AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); // if row major format return. if ((bli_obj_row_stride( a ) != 1) || @@ -277,7 +282,6 @@ err_t bli_gemm_sqp return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); } - //AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -461,17 +465,19 @@ BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, pack_on = true; } -#if 0//SQP_THREAD_ENABLE//ENABLE Threading +#if SQP_THREAD_ENABLE//ENABLE Threading gint_t status = 0; gint_t workitems = (m-(*p_istart))/mx; gint_t inputThreadCount = bli_thread_get_num_threads(); inputThreadCount = bli_min(inputThreadCount, BLI_SQP_MAX_THREADS); - inputThreadCount = bli_min(inputThreadCount,workitems);// limit input thread count when workitems are lesser. + // limit input thread count when workitems are lesser. 
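/*
 * Editorial illustration (not part of this patch): how the ar/ai/as (and
 * br/bi/bs) planes of workspace_3m_sqp relate to the interleaved (re,im)
 * storage of a complex column -- the same split performed by the
 * bli_3m_sqp_pack*_real_imag_sum() routines earlier in this series.  Sketch
 * for one column, no transpose, unit multiplier; the helper name is
 * hypothetical:
 */
static void sqp_split_column( const double* col, /* interleaved re,im column   */
                              gint_t        k,   /* complex elements in column */
                              double*       pr,  /* real plane                 */
                              double*       pi,  /* imaginary plane            */
                              double*       ps ) /* sum plane: pr[i] + pi[i]   */
{
    for ( gint_t p = 0; p < k; p++ )
    {
        double re = col[ 2 * p ];
        double im = col[ 2 * p + 1 ];

        pr[ p ] = re;
        pi[ p ] = im;
        ps[ p ] = re + im;  /* feeds the third real product of the 3m method */
    }
}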
+ inputThreadCount = bli_min(inputThreadCount,workitems); inputThreadCount = bli_max(inputThreadCount,1); gint_t num_threads; num_threads = bli_max(inputThreadCount,1); - gint_t mx_per_thread = workitems/num_threads;//no of workitems per thread - //printf("\nistart %d workitems %d inputThreadCount %d num_threads %d mx_per_thread %d mx %d " , + //no of workitems per thread + gint_t mx_per_thread = workitems/num_threads; + *p_istart, workitems,inputThreadCount,num_threads,mx_per_thread, mx); pthread_t ptid[BLI_SQP_MAX_THREADS]; @@ -480,14 +486,13 @@ BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, //create threads for (gint_t t = 0; t < num_threads; t++) { - //ptid[t].tid = t; gint_t i_end = ((mx_per_thread*(t+1))*mx)+(*p_istart); if(i_end>m) { i_end = m; } - if(t==(num_threads-1)) + if(t==(num_threads-1))//for last thread allocate remaining work. { if((i_end+mx)==m) { @@ -502,7 +507,7 @@ BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, thread_info[t].i_start = ((mx_per_thread*t)*mx)+(*p_istart); thread_info[t].i_end = i_end; - //printf("\n threadid %d istart %d iend %d m %d mx %d", t, thread_info[t].i_start, i_end, m, mx); + thread_info[t].m = m; thread_info[t].n = n; thread_info[t].k = k; @@ -518,16 +523,12 @@ BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, thread_info[t].mx = mx; thread_info[t].pack_on = pack_on; thread_info[t].aligned = aligned; -#if 1 + if ((status = pthread_create(&ptid[t], NULL, bli_sqp_thread, (void*)&thread_info[t]))) { printf("error sqp pthread_create\n"); return BLIS_FAILURE; } -#else - //simulate thread for debugging.. - bli_sqp_thread((void*)&thread_info[t]); -#endif } //wait for completion @@ -542,16 +543,16 @@ BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, } #else//SQP_THREAD_ENABLE + // single thread code segement. if(pack_on==true) { - //aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); // allocation moved to top. if(aligned==NULL) { return BLIS_MALLOC_RETURNED_NULL; } } - for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. no of workitems = m/8 + for (i = (*p_istart); i <= (m-mx); i += mx) { int p = 0; for(; p <= (k-kx); p += kx) @@ -913,7 +914,9 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m, inc_t i; gint_t max_m = (m2-mxmul2); - //get workspace + /* + get workspace for packing and + for storing sum of real and imag component */ double* ar, * ai, * as; ar = mem_3m_sqp->ar; ai = mem_3m_sqp->ai; @@ -1163,13 +1166,16 @@ BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, gint_t kx = k; double* a_aligned = NULL; - if(nt<=1)//single pack buffer allocated for single thread case + //single pack buffer allocated for single thread case + if(nt<=1) { err_t r_val; a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx, &r_val); } - gint_t nx = n;//MAX; + /* nx assigned to n, but can be assigned with lower value + to add partition along n dimension */ + gint_t nx = n;. 
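/*
 * Editorial illustration (not part of this patch): the n-dimension blocking
 * that the BLI_SQP_DGEMM_N / BLI_SQP_ZGEMM_N macros expand to.  Full panels
 * of nx columns are processed first (operating on b + j*ldb and c + j*ldc),
 * then one call handles the remaining n - j columns.  With nx == n, as set
 * above, the loop runs once.  Hypothetical callback name, sketch only:
 */
static void sqp_n_blocked( gint_t n, gint_t nx,
                           void (*gemm_panel)( gint_t j, gint_t n_cur ) )
{
    gint_t j = 0;

    for ( ; j <= ( n - nx ); j += nx )
        gemm_panel( j, nx );       /* full nx-wide panels            */

    if ( j < n )
        gemm_panel( j, n - j );    /* trailing partial panel, if any */
}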
if(nx>n) { nx = n; @@ -1199,7 +1205,8 @@ BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, } } - if(nt<=1)//single pack buffer allocated for single thread case + //single pack buffer allocated for single thread case + if(nt<=1) { bli_free_user(a_aligned); } diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c index 0f20c0a956..2f553708cd 100644 --- a/kernels/zen/3/bli_gemm_sqp_kernels.c +++ b/kernels/zen/3/bli_gemm_sqp_kernels.c @@ -182,7 +182,7 @@ inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, pc += ldc6;pb += ldb6; } - //printf(" 8x6:j:%d ", j); + return j; } @@ -309,7 +309,7 @@ inc_t bli_sqp_dgemm_kernel_8mx5n( gint_t n, pc += ldc5;pb += ldb5; } - //printf(" 8x5:j:%d ", j); + return j; } @@ -381,7 +381,7 @@ inc_t bli_sqp_dgemm_kernel_8mx4n( gint_t n, pc += ldc4;pb += ldb4; }// j loop 4 multiple - //printf(" 8x4:j:%d ", j); + return j; } @@ -448,7 +448,7 @@ inc_t bli_sqp_dgemm_kernel_8mx3n( gint_t n, pc += ldc3;pb += ldb3; }// j loop 3 multiple - //printf(" 8x3:j:%d ", j); + return j; } @@ -506,7 +506,7 @@ inc_t bli_sqp_dgemm_kernel_8mx2n( gint_t n, pc += ldc2;pb += ldb2; }// j loop 2 multiple - //printf(" 8x2:j:%d ", j); + return j; } @@ -571,98 +571,7 @@ inc_t bli_sqp_dgemm_kernel_4mx10n( gint_t n, double* c, guint_t ldc) { - gint_t p; - /* incomplete */ - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; - - for (j = 0; j <= (n - 10); j += 10) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); - cv1 = _mm256_loadu_pd(pcldc); - cv2 = _mm256_loadu_pd(pcldc2); - cv3 = _mm256_loadu_pd(pcldc3); - cv4 = _mm256_loadu_pd(pcldc4); -#else - cv0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); - cv0 = _mm256_add_pd(cv0, bv0); - - bv2 = _mm256_loadu_pd(pcldc); - cv1 = _mm256_add_pd(cv1, bv2); - - bv0 = _mm256_loadu_pd(pcldc2); - cv2 = _mm256_add_pd(cv2, bv0); - - bv2 = _mm256_loadu_pd(pcldc3); - cv3 = _mm256_add_pd(cv3, 
bv2); - - bv0 = _mm256_loadu_pd(pcldc4); - cv4 = _mm256_add_pd(cv4, bv0); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc4, cv4); - - - pc += ldc10;pb += ldb10; - } - + //tbd return j; } @@ -680,30 +589,9 @@ inc_t bli_sqp_dgemm_kernel_4mx1n( gint_t n, double* c, guint_t ldc) { - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - } - _mm256_storeu_pd(pc, cv0); - pc += ldc;pb += ldb; - }// j loop 1 multiple + //tbd return j; } - #endif /************************************************************************************************************/ /************************** dgemm kernels (1mxn) column preffered ******************************************/ @@ -768,7 +656,7 @@ inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, pc = c; for (; j <= (n - 1); j += 1) { - //cv0 = _mm256_loadu_pd(pc); + for (int i = 0; i < mx; i++) { cx[i] = *(pc + i); @@ -777,17 +665,15 @@ inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, double* x = aPacked; double* pb0 = pb; for (p = 0; p < k; p += 1) { - //bv0 = _mm256_broadcast_sd(pb0); + double b0 = *pb0; pb0++; for (int i = 0; i < mx; i++) { - cx[i] += (*(x + i)) * b0;//cv0 = _mm256_fmadd_pd(av0, bv0, cv0); + cx[i] += (*(x + i)) * b0; } - //av0 = _mm256_loadu_pd(x); x += mx; } - //_mm256_storeu_pd(pc, cv0); for (int i = 0; i < mx; i++) { *(pc + i) = cx[i]; From 0f984c5818edd32165c5cd823f50f0da12367207 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Wed, 15 Dec 2021 16:31:20 +0530 Subject: [PATCH 13/13] bug fix and print added when bli_gemm_sqp fails --- kernels/zen/3/bli_gemm_sqp.c | 2 +- testsuite/src/test_gemm.c | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c index 60a7aa9436..ca33972a29 100644 --- a/kernels/zen/3/bli_gemm_sqp.c +++ b/kernels/zen/3/bli_gemm_sqp.c @@ -1175,7 +1175,7 @@ BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, /* nx assigned to n, but can be assigned with lower value to add partition along n dimension */ - gint_t nx = n;. + gint_t nx = n; if(nx>n) { nx = n; diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index b6d8125054..729b7aee5f 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -474,6 +474,8 @@ bli_printm( "c", c, "%5.2f", "" ); #if TEST_SQP if(bli_gemm_sqp(alpha,a,b,beta,c,NULL,NULL)!=BLIS_SUCCESS) { + printf("\n configuration not supported or failed while calling bli_gemm_sqp. + \n falling back to bli_gemm API"); bli_gemm( alpha, a, b, beta, c ); } #else//TEST_SQP