From 780d18d4f4282193948cbf849f9dc7c0fed064d2 Mon Sep 17 00:00:00 2001 From: Ian McElhenny Date: Wed, 1 Apr 2020 14:52:27 -0500 Subject: [PATCH] Added AARCH64 support --- AnnService/CMakeLists.txt | 22 +- AnnService/inc/Core/Common/BKTree.h | 8 +- AnnService/inc/Core/Common/Dataset.h | 25 +- AnnService/inc/Core/Common/DistanceUtils.h | 361 +++++++++--------- AnnService/inc/Core/Common/KDTree.h | 5 +- AnnService/inc/Core/Common/malloc_aligned.hpp | 19 + AnnService/src/Core/BKT/BKTIndex.cpp | 10 +- AnnService/src/Core/KDT/KDTIndex.cpp | 5 +- CMakeLists.txt | 47 ++- 9 files changed, 276 insertions(+), 226 deletions(-) create mode 100644 AnnService/inc/Core/Common/malloc_aligned.hpp diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt index fffc5ce42..8ecd4f0b0 100644 --- a/AnnService/CMakeLists.txt +++ b/AnnService/CMakeLists.txt @@ -1,8 +1,26 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/VectorSetReaders/*.h) -file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/VectorSetReaders/*.cpp) +file(GLOB HDR_FILES + ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h + ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h + ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h + ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h + ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h + ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/VectorSetReaders/*.h + ${PROJECT_SOURCE_DIR}/simde/simde/*.h + ${PROJECT_SOURCE_DIR}/simde/simde/x86/*.h + ${PROJECT_SOURCE_DIR}/simde/simde/arm/*.h + ${PROJECT_SOURCE_DIR}/simde/simde/arm/neon/*.h + ) +file(GLOB SRC_FILES + ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp + ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp + ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp + ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp + ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp + ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/VectorSetReaders/*.cpp + ) include_directories(${PROJECT_SOURCE_DIR}/AnnService) diff --git a/AnnService/inc/Core/Common/BKTree.h b/AnnService/inc/Core/Common/BKTree.h index e33f7f9b4..69f4413a4 100644 --- a/AnnService/inc/Core/Common/BKTree.h +++ b/AnnService/inc/Core/Common/BKTree.h @@ -47,8 +47,8 @@ namespace SPTAG T* newTCenters; KmeansArgs(int k, DimensionType dim, SizeType datasize, int threadnum) : _K(k), _D(dim), _T(threadnum) { - centers = (T*)aligned_malloc(sizeof(T) * k * dim, ALIGN); - newTCenters = (T*)aligned_malloc(sizeof(T) * k * dim, ALIGN); + centers = (T*)simde_mm_malloc(sizeof(T) * k * dim, ALIGN); + newTCenters = (T*)simde_mm_malloc(sizeof(T) * k * dim, ALIGN); counts = new SizeType[k]; newCenters = new float[threadnum * k * dim]; newCounts = new SizeType[threadnum * k]; @@ -58,8 +58,8 @@ namespace SPTAG } ~KmeansArgs() { - aligned_free(centers); - aligned_free(newTCenters); + simde_mm_free(centers); + simde_mm_free(newTCenters); delete[] counts; delete[] newCenters; delete[] newCounts; diff --git a/AnnService/inc/Core/Common/Dataset.h b/AnnService/inc/Core/Common/Dataset.h index 07eef492c..8cae5208d 100644 --- a/AnnService/inc/Core/Common/Dataset.h +++ b/AnnService/inc/Core/Common/Dataset.h @@ -6,16 +6,17 @@ #include -#if defined(_MSC_VER) || defined(__INTEL_COMPILER) +// #if defined(_MSC_VER) || defined(__INTEL_COMPILER) #include -#else -#include -#endif // defined(__GNUC__) +// #else +// #include +#include "malloc_aligned.hpp" +// #endif // defined(__GNUC__) #define ALIGN 32 - -#define aligned_malloc(a, b) _mm_malloc(a, b) -#define aligned_free(a) _mm_free(a) +// +// #define aligned_malloc(a, b) _mm_malloc(a, b) +// #define aligned_free(a) _mm_free(a) #pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. @@ -48,8 +49,8 @@ namespace SPTAG } ~Dataset() { - if (ownData) aligned_free(data); - for (T* ptr : incBlocks) aligned_free(ptr); + if (ownData) simde_mm_free(data); + for (T* ptr : incBlocks) simde_mm_free(ptr); incBlocks.clear(); } void Initialize(SizeType rows_, DimensionType cols_, T* data_ = nullptr, bool transferOnwership_ = true) @@ -60,7 +61,7 @@ namespace SPTAG if (data_ == nullptr || !transferOnwership_) { ownData = true; - data = (T*)aligned_malloc(((size_t)rows) * cols * sizeof(T), ALIGN); + data = (T*)simde_mm_malloc(((size_t)rows) * cols * sizeof(T), ALIGN); if (data_ != nullptr) memcpy(data, data_, ((size_t)rows) * cols * sizeof(T)); else std::memset(data, -1, ((size_t)rows) * cols * sizeof(T)); } @@ -109,7 +110,7 @@ namespace SPTAG while (written < num) { SizeType curBlockIdx = (incRows + written) / rowsInBlock; if (curBlockIdx >= (SizeType)incBlocks.size()) { - T* newBlock = (T*)aligned_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN); + T* newBlock = (T*)simde_mm_malloc(((size_t)rowsInBlock) * cols * sizeof(T), ALIGN); if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; incBlocks.push_back(newBlock); } @@ -130,7 +131,7 @@ namespace SPTAG while (written < num) { SizeType curBlockIdx = (incRows + written) / rowsInBlock; if (curBlockIdx >= (SizeType)incBlocks.size()) { - T* newBlock = (T*)aligned_malloc(sizeof(T) * rowsInBlock * cols, ALIGN); + T* newBlock = (T*)simde_mm_malloc(sizeof(T) * rowsInBlock * cols, ALIGN); if (newBlock == nullptr) return ErrorCode::MemoryOverFlow; std::memset(newBlock, -1, sizeof(T) * rowsInBlock * cols); incBlocks.push_back(newBlock); diff --git a/AnnService/inc/Core/Common/DistanceUtils.h b/AnnService/inc/Core/Common/DistanceUtils.h index 4f190543b..cc71b7bea 100644 --- a/AnnService/inc/Core/Common/DistanceUtils.h +++ b/AnnService/inc/Core/Common/DistanceUtils.h @@ -4,10 +4,11 @@ #ifndef _SPTAG_COMMON_DISTANCEUTILS_H_ #define _SPTAG_COMMON_DISTANCEUTILS_H_ -#include #include #include "CommonUtils.h" +#include "../../../../simde/simde/x86/avx.h" +#include "../../../../simde/simde/x86/avx2.h" #if defined(__AVX2__) #define AVX2 @@ -21,6 +22,7 @@ #define SSE #endif + #ifndef _MSC_VER #define DIFF128 diff128 #define DIFF256 diff256 @@ -36,181 +38,182 @@ namespace SPTAG class DistanceUtils { public: -#if defined(SSE2) || defined(AVX2) - static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y) +#if defined(SSE2) || defined(AVX2) || defined(__SIMDE_OPTIMIZATION__) + + static inline simde__m128 simde_mm_mul_epi8(simde__m128i X, simde__m128i Y) { - __m128i zero = _mm_setzero_si128(); + simde__m128i zero = simde_mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i sign_y = _mm_cmplt_epi8(Y, zero); + simde__m128i sign_x = simde_mm_cmplt_epi8(X, zero); + simde__m128i sign_y = simde_mm_cmplt_epi8(Y, zero); - __m128i xlo = _mm_unpacklo_epi8(X, sign_x); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); - __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); + simde__m128i xlo = simde_mm_unpacklo_epi8(X, sign_x); + simde__m128i xhi = simde_mm_unpackhi_epi8(X, sign_x); + simde__m128i ylo = simde_mm_unpacklo_epi8(Y, sign_y); + simde__m128i yhi = simde_mm_unpackhi_epi8(Y, sign_y); - return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); + return simde_mm_cvtepi32_ps(simde_mm_add_epi32(simde_mm_madd_epi16(xlo, ylo), simde_mm_madd_epi16(xhi, yhi))); } - static inline __m128 _mm_sqdf_epi8(__m128i X, __m128i Y) + static inline simde__m128 simde_mm_sqdf_epi8(simde__m128i X, simde__m128i Y) { - __m128i zero = _mm_setzero_si128(); + simde__m128i zero = simde_mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi8(X, zero); - __m128i sign_y = _mm_cmplt_epi8(Y, zero); + simde__m128i sign_x = simde_mm_cmplt_epi8(X, zero); + simde__m128i sign_y = simde_mm_cmplt_epi8(Y, zero); - __m128i xlo = _mm_unpacklo_epi8(X, sign_x); - __m128i xhi = _mm_unpackhi_epi8(X, sign_x); - __m128i ylo = _mm_unpacklo_epi8(Y, sign_y); - __m128i yhi = _mm_unpackhi_epi8(Y, sign_y); + simde__m128i xlo = simde_mm_unpacklo_epi8(X, sign_x); + simde__m128i xhi = simde_mm_unpackhi_epi8(X, sign_x); + simde__m128i ylo = simde_mm_unpacklo_epi8(Y, sign_y); + simde__m128i yhi = simde_mm_unpackhi_epi8(Y, sign_y); - __m128i dlo = _mm_sub_epi16(xlo, ylo); - __m128i dhi = _mm_sub_epi16(xhi, yhi); + simde__m128i dlo = simde_mm_sub_epi16(xlo, ylo); + simde__m128i dhi = simde_mm_sub_epi16(xhi, yhi); - return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi))); + return simde_mm_cvtepi32_ps(simde_mm_add_epi32(simde_mm_madd_epi16(dlo, dlo), simde_mm_madd_epi16(dhi, dhi))); } - static inline __m128 _mm_mul_epu8(__m128i X, __m128i Y) + static inline simde__m128 simde_mm_mul_epu8(simde__m128i X, simde__m128i Y) { - __m128i zero = _mm_setzero_si128(); + simde__m128i zero = simde_mm_setzero_si128(); - __m128i xlo = _mm_unpacklo_epi8(X, zero); - __m128i xhi = _mm_unpackhi_epi8(X, zero); - __m128i ylo = _mm_unpacklo_epi8(Y, zero); - __m128i yhi = _mm_unpackhi_epi8(Y, zero); + simde__m128i xlo = simde_mm_unpacklo_epi8(X, zero); + simde__m128i xhi = simde_mm_unpackhi_epi8(X, zero); + simde__m128i ylo = simde_mm_unpacklo_epi8(Y, zero); + simde__m128i yhi = simde_mm_unpackhi_epi8(Y, zero); - return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi))); + return simde_mm_cvtepi32_ps(simde_mm_add_epi32(simde_mm_madd_epi16(xlo, ylo), simde_mm_madd_epi16(xhi, yhi))); } - static inline __m128 _mm_sqdf_epu8(__m128i X, __m128i Y) + static inline simde__m128 simde_mm_sqdf_epu8(simde__m128i X, simde__m128i Y) { - __m128i zero = _mm_setzero_si128(); + simde__m128i zero = simde_mm_setzero_si128(); - __m128i xlo = _mm_unpacklo_epi8(X, zero); - __m128i xhi = _mm_unpackhi_epi8(X, zero); - __m128i ylo = _mm_unpacklo_epi8(Y, zero); - __m128i yhi = _mm_unpackhi_epi8(Y, zero); + simde__m128i xlo = simde_mm_unpacklo_epi8(X, zero); + simde__m128i xhi = simde_mm_unpackhi_epi8(X, zero); + simde__m128i ylo = simde_mm_unpacklo_epi8(Y, zero); + simde__m128i yhi = simde_mm_unpackhi_epi8(Y, zero); - __m128i dlo = _mm_sub_epi16(xlo, ylo); - __m128i dhi = _mm_sub_epi16(xhi, yhi); + simde__m128i dlo = simde_mm_sub_epi16(xlo, ylo); + simde__m128i dhi = simde_mm_sub_epi16(xhi, yhi); - return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi))); + return simde_mm_cvtepi32_ps(simde_mm_add_epi32(simde_mm_madd_epi16(dlo, dlo), simde_mm_madd_epi16(dhi, dhi))); } - static inline __m128 _mm_mul_epi16(__m128i X, __m128i Y) + static inline simde__m128 simde_mm_mul_epi16(simde__m128i X, simde__m128i Y) { - return _mm_cvtepi32_ps(_mm_madd_epi16(X, Y)); + return simde_mm_cvtepi32_ps(simde_mm_madd_epi16(X, Y)); } - static inline __m128 _mm_sqdf_epi16(__m128i X, __m128i Y) + static inline simde__m128 simde_mm_sqdf_epi16(simde__m128i X, simde__m128i Y) { - __m128i zero = _mm_setzero_si128(); + simde__m128i zero = simde_mm_setzero_si128(); - __m128i sign_x = _mm_cmplt_epi16(X, zero); - __m128i sign_y = _mm_cmplt_epi16(Y, zero); + simde__m128i sign_x = simde_mm_cmplt_epi16(X, zero); + simde__m128i sign_y = simde_mm_cmplt_epi16(Y, zero); - __m128i xlo = _mm_unpacklo_epi16(X, sign_x); - __m128i xhi = _mm_unpackhi_epi16(X, sign_x); - __m128i ylo = _mm_unpacklo_epi16(Y, sign_y); - __m128i yhi = _mm_unpackhi_epi16(Y, sign_y); + simde__m128i xlo = simde_mm_unpacklo_epi16(X, sign_x); + simde__m128i xhi = simde_mm_unpackhi_epi16(X, sign_x); + simde__m128i ylo = simde_mm_unpacklo_epi16(Y, sign_y); + simde__m128i yhi = simde_mm_unpackhi_epi16(Y, sign_y); - __m128 dlo = _mm_cvtepi32_ps(_mm_sub_epi32(xlo, ylo)); - __m128 dhi = _mm_cvtepi32_ps(_mm_sub_epi32(xhi, yhi)); + simde__m128 dlo = simde_mm_cvtepi32_ps(simde_mm_sub_epi32(xlo, ylo)); + simde__m128 dhi = simde_mm_cvtepi32_ps(simde_mm_sub_epi32(xhi, yhi)); - return _mm_add_ps(_mm_mul_ps(dlo, dlo), _mm_mul_ps(dhi, dhi)); + return simde_mm_add_ps(simde_mm_mul_ps(dlo, dlo), simde_mm_mul_ps(dhi, dhi)); } #endif -#if defined(SSE) || defined(AVX) - static inline __m128 _mm_sqdf_ps(__m128 X, __m128 Y) +#if defined(SSE) || defined(AVX) || defined(__SIMDE_OPTIMIZATION__) + static inline simde__m128 simde_mm_sqdf_ps(simde__m128 X, simde__m128 Y) { - __m128 d = _mm_sub_ps(X, Y); - return _mm_mul_ps(d, d); + simde__m128 d = simde_mm_sub_ps(X, Y); + return simde_mm_mul_ps(d, d); } #endif #if defined(AVX2) - static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_mul_epi8(simde__m256i X, simde__m256i Y) { - __m256i zero = _mm256_setzero_si256(); + simde__m256i zero = simde_mm256_setzero_si256(); - __m256i sign_x = _mm256_cmpgt_epi8(zero, X); - __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); + simde__m256i sign_x = simde_mm256_cmpgt_epi8(zero, X); + simde__m256i sign_y = simde_mm256_cmpgt_epi8(zero, Y); - __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); - __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); - __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); - __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); + simde__m256i xlo = simde_mm256_unpacklo_epi8(X, sign_x); + simde__m256i xhi = simde_mm256_unpackhi_epi8(X, sign_x); + simde__m256i ylo = simde_mm256_unpacklo_epi8(Y, sign_y); + simde__m256i yhi = simde_mm256_unpackhi_epi8(Y, sign_y); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi))); + return simde_mm256_cvtepi32_ps(simde_mm256_add_epi32(simde_mm256_madd_epi16(xlo, ylo), simde_mm256_madd_epi16(xhi, yhi))); } - static inline __m256 _mm256_sqdf_epi8(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_sqdf_epi8(simde__m256i X, simde__m256i Y) { - __m256i zero = _mm256_setzero_si256(); + simde__m256i zero = simde_mm256_setzero_si256(); - __m256i sign_x = _mm256_cmpgt_epi8(zero, X); - __m256i sign_y = _mm256_cmpgt_epi8(zero, Y); + simde__m256i sign_x = simde_mm256_cmpgt_epi8(zero, X); + simde__m256i sign_y = simde_mm256_cmpgt_epi8(zero, Y); - __m256i xlo = _mm256_unpacklo_epi8(X, sign_x); - __m256i xhi = _mm256_unpackhi_epi8(X, sign_x); - __m256i ylo = _mm256_unpacklo_epi8(Y, sign_y); - __m256i yhi = _mm256_unpackhi_epi8(Y, sign_y); + simde__m256i xlo = simde_mm256_unpacklo_epi8(X, sign_x); + simde__m256i xhi = simde_mm256_unpackhi_epi8(X, sign_x); + simde__m256i ylo = simde_mm256_unpacklo_epi8(Y, sign_y); + simde__m256i yhi = simde_mm256_unpackhi_epi8(Y, sign_y); - __m256i dlo = _mm256_sub_epi16(xlo, ylo); - __m256i dhi = _mm256_sub_epi16(xhi, yhi); + simde__m256i dlo = simde_mm256_sub_epi16(xlo, ylo); + simde__m256i dhi = simde_mm256_sub_epi16(xhi, yhi); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi))); + return simde_mm256_cvtepi32_ps(simde_mm256_add_epi32(simde_mm256_madd_epi16(dlo, dlo), simde_mm256_madd_epi16(dhi, dhi))); } - static inline __m256 _mm256_mul_epu8(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_mul_epu8(simde__m256i X, simde__m256i Y) { - __m256i zero = _mm256_setzero_si256(); + simde__m256i zero = simde_mm256_setzero_si256(); - __m256i xlo = _mm256_unpacklo_epi8(X, zero); - __m256i xhi = _mm256_unpackhi_epi8(X, zero); - __m256i ylo = _mm256_unpacklo_epi8(Y, zero); - __m256i yhi = _mm256_unpackhi_epi8(Y, zero); + simde__m256i xlo = simde_mm256_unpacklo_epi8(X, zero); + simde__m256i xhi = simde_mm256_unpackhi_epi8(X, zero); + simde__m256i ylo = simde_mm256_unpacklo_epi8(Y, zero); + simde__m256i yhi = simde_mm256_unpackhi_epi8(Y, zero); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi))); + return simde_mm256_cvtepi32_ps(simde_mm256_add_epi32(simde_mm256_madd_epi16(xlo, ylo), simde_mm256_madd_epi16(xhi, yhi))); } - static inline __m256 _mm256_sqdf_epu8(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_sqdf_epu8(simde__m256i X, simde__m256i Y) { - __m256i zero = _mm256_setzero_si256(); + simde__m256i zero = simde_mm256_setzero_si256(); - __m256i xlo = _mm256_unpacklo_epi8(X, zero); - __m256i xhi = _mm256_unpackhi_epi8(X, zero); - __m256i ylo = _mm256_unpacklo_epi8(Y, zero); - __m256i yhi = _mm256_unpackhi_epi8(Y, zero); + simde__m256i xlo = simde_mm256_unpacklo_epi8(X, zero); + simde__m256i xhi = simde_mm256_unpackhi_epi8(X, zero); + simde__m256i ylo = simde_mm256_unpacklo_epi8(Y, zero); + simde__m256i yhi = simde_mm256_unpackhi_epi8(Y, zero); - __m256i dlo = _mm256_sub_epi16(xlo, ylo); - __m256i dhi = _mm256_sub_epi16(xhi, yhi); + simde__m256i dlo = simde_mm256_sub_epi16(xlo, ylo); + simde__m256i dhi = simde_mm256_sub_epi16(xhi, yhi); - return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi))); + return simde_mm256_cvtepi32_ps(simde_mm256_add_epi32(simde_mm256_madd_epi16(dlo, dlo), simde_mm256_madd_epi16(dhi, dhi))); } - static inline __m256 _mm256_mul_epi16(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_mul_epi16(simde__m256i X, simde__m256i Y) { - return _mm256_cvtepi32_ps(_mm256_madd_epi16(X, Y)); + return simde_mm256_cvtepi32_ps(simde_mm256_madd_epi16(X, Y)); } - static inline __m256 _mm256_sqdf_epi16(__m256i X, __m256i Y) + static inline simde__m256 simde_mm256_sqdf_epi16(simde__m256i X, simde__m256i Y) { - __m256i zero = _mm256_setzero_si256(); + simde__m256i zero = simde_mm256_setzero_si256(); - __m256i sign_x = _mm256_cmpgt_epi16(zero, X); - __m256i sign_y = _mm256_cmpgt_epi16(zero, Y); + simde__m256i sign_x = simde_mm256_cmpgt_epi16(zero, X); + simde__m256i sign_y = simde_mm256_cmpgt_epi16(zero, Y); - __m256i xlo = _mm256_unpacklo_epi16(X, sign_x); - __m256i xhi = _mm256_unpackhi_epi16(X, sign_x); - __m256i ylo = _mm256_unpacklo_epi16(Y, sign_y); - __m256i yhi = _mm256_unpackhi_epi16(Y, sign_y); + simde__m256i xlo = simde_mm256_unpacklo_epi16(X, sign_x); + simde__m256i xhi = simde_mm256_unpackhi_epi16(X, sign_x); + simde__m256i ylo = simde_mm256_unpacklo_epi16(Y, sign_y); + simde__m256i yhi = simde_mm256_unpackhi_epi16(Y, sign_y); - __m256 dlo = _mm256_cvtepi32_ps(_mm256_sub_epi32(xlo, ylo)); - __m256 dhi = _mm256_cvtepi32_ps(_mm256_sub_epi32(xhi, yhi)); + simde__m256 dlo = simde_mm256_cvtepi32_ps(simde_mm256_sub_epi32(xlo, ylo)); + simde__m256 dhi = simde_mm256_cvtepi32_ps(simde_mm256_sub_epi32(xhi, yhi)); - return _mm256_add_ps(_mm256_mul_ps(dlo, dlo), _mm256_mul_ps(dhi, dhi)); + return simde_mm256_add_ps(simde_mm256_mul_ps(dlo, dlo), simde_mm256_mul_ps(dhi, dhi)); } #endif #if defined(AVX) - static inline __m256 _mm256_sqdf_ps(__m256 X, __m256 Y) + static inline simde__m256 simde_mm256_sqdf_ps(simde__m256 X, simde__m256 Y) { - __m256 d = _mm256_sub_ps(X, Y); - return _mm256_mul_ps(d, d); + simde__m256 d = simde_mm256_sub_ps(X, Y); + return simde_mm256_mul_ps(d, d); } #endif /* @@ -240,23 +243,23 @@ namespace SPTAG const std::int8_t* pEnd4 = pX + ((length >> 2) << 2); const std::int8_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epi8, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epi8, simde_mm_add_ps, diff128) } while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epi8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 32, simde_mm256_loadu_si256, simde_mm256_sqdf_epi8, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 = simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epi8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -281,23 +284,23 @@ namespace SPTAG const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2); const std::uint8_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epu8, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epu8, simde_mm_add_ps, diff128) } while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epu8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 32, simde_mm256_loadu_si256, simde_mm256_sqdf_epu8, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 = simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_sqdf_epu8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -322,23 +325,23 @@ namespace SPTAG const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); const std::int16_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_sqdf_epi16, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_sqdf_epi16, simde_mm_add_ps, diff128) } while (pX < pEnd8) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_sqdf_epi16, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 16, simde_mm256_loadu_si256, simde_mm256_sqdf_epi16, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 = simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd8) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_sqdf_epi16, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -362,31 +365,31 @@ namespace SPTAG const float* pEnd16 = pX + ((length >> 4) << 4); const float* pEnd4 = pX + ((length >> 2) << 2); const float* pEnd1 = pX + length; -#if defined(SSE) - __m128 diff128 = _mm_setzero_ps(); +#if defined(SSE2) + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) } while (pX < pEnd4) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256) + REPEAT(simde__m256, const float, 8, simde_mm256_loadu_ps, simde_mm256_sqdf_ps, simde_mm256_add_ps, diff256) + REPEAT(simde__m256, const float, 8, simde_mm256_loadu_ps, simde_mm256_sqdf_ps, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 = simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd4) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_sqdf_ps, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -419,23 +422,23 @@ namespace SPTAG const std::int8_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epi8, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epi8, simde_mm_add_ps, diff128) } while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epi8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 32, simde_mm256_loadu_si256, simde_mm256_mul_epi8, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 = simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epi8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -459,23 +462,23 @@ namespace SPTAG const std::uint8_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 =simde_mm_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epu8, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epu8, simde_mm_add_ps, diff128) } while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16, simde_mm_loadu_si128, simde_mm_mul_epu8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd32) { - REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 32, simde_mm256_loadu_si256, simde_mm256_mul_epu8, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 =simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 16,simde_mm_loadu_si128, simde_mm_mul_epu8, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -498,24 +501,24 @@ namespace SPTAG const std::int16_t* pEnd4 = pX + ((length >> 2) << 2); const std::int16_t* pEnd1 = pX + length; #if defined(SSE2) - __m128 diff128 = _mm_setzero_ps(); + simde__m128 diff128 = simde_mm_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_mul_epi16, simde_mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_mul_epi16, simde_mm_add_ps, diff128) } while (pX < pEnd8) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8, simde_mm_loadu_si128, simde_mm_mul_epi16, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX2) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256) + REPEAT(simde__m256i, simde__m256i, 16, simde_mm256_loadu_si256, simde_mm256_mul_epi16, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 =simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd8) { - REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128) + REPEAT(simde__m128i, simde__m128i, 8,simde_mm_loadu_si128, simde_mm_mul_epi16, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else @@ -537,32 +540,32 @@ namespace SPTAG const float* pEnd16 = pX + ((length >> 4) << 4); const float* pEnd4 = pX + ((length >> 2) << 2); const float* pEnd1 = pX + length; -#if defined(SSE) - __m128 diff128 = _mm_setzero_ps(); +#if defined(SSE2) + simde__m128 diff128 =simde_mm_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) } while (pX < pEnd4) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #elif defined(AVX) - __m256 diff256 = _mm256_setzero_ps(); + simde__m256 diff256 = simde_mm256_setzero_ps(); while (pX < pEnd16) { - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) - REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256) + REPEAT(simde__m256, const float, 8, simde_mm256_loadu_ps, simde_mm256_mul_ps, simde_mm256_add_ps, diff256) + REPEAT(simde__m256, const float, 8, simde_mm256_loadu_ps, simde_mm256_mul_ps, simde_mm256_add_ps, diff256) } - __m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1)); + simde__m128 diff128 =simde_mm_add_ps(simde_mm256_castps256_ps128(diff256), simde_mm256_extractf128_ps(diff256, 1)); while (pX < pEnd4) { - REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128) + REPEAT(simde__m128, const float, 4, simde_mm_loadu_ps, simde_mm_mul_ps, simde_mm_add_ps, diff128) } float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3]; #else diff --git a/AnnService/inc/Core/Common/KDTree.h b/AnnService/inc/Core/Common/KDTree.h index fa50ade06..49270ad46 100644 --- a/AnnService/inc/Core/Common/KDTree.h +++ b/AnnService/inc/Core/Common/KDTree.h @@ -8,6 +8,7 @@ #include #include #include +#include "../../../../simde/simde/x86/sse4.2.h" #include "../VectorIndex.h" @@ -183,8 +184,8 @@ namespace SPTAG if (index >= p_index->GetNumSamples()) return; #ifdef PREFETCH const char* data = (const char *)(p_index->GetSample(index)); - _mm_prefetch(data, _MM_HINT_T0); - _mm_prefetch(data + 64, _MM_HINT_T0); + __builtin_prefetch(data); + __builtin_prefetch(data + 64); #endif if (p_space.CheckAndSet(index)) return; diff --git a/AnnService/inc/Core/Common/malloc_aligned.hpp b/AnnService/inc/Core/Common/malloc_aligned.hpp new file mode 100644 index 000000000..bc4bfb3f2 --- /dev/null +++ b/AnnService/inc/Core/Common/malloc_aligned.hpp @@ -0,0 +1,19 @@ +// +// Created by user on 3/25/20. +// + +#ifndef SPTAGLIB_MALLOC_ALIGNED_HPP +#define SPTAGLIB_MALLOC_ALIGNED_HPP + +static inline void* simde_mm_malloc (size_t size, size_t alignment) { + // This works on posix systems + // For Windows users: C11 should have aligned_alloc(...) that could replace simde_mm_malloc(...), but simde requires C99 + void *ptr; + if (alignment == 1) return malloc (size); + if (alignment == 2 || (sizeof (void *) == 8 && alignment == 4)) alignment = sizeof (void *); + if (posix_memalign (&ptr, alignment, size) == 0) return ptr; + else return NULL; +} +static inline void simde_mm_free (void * ptr) {free (ptr);} + +#endif //SPTAGLIB_MALLOC_ALIGNED_HPP diff --git a/AnnService/src/Core/BKT/BKTIndex.cpp b/AnnService/src/Core/BKT/BKTIndex.cpp index b26fbfadd..d8144a4b1 100644 --- a/AnnService/src/Core/BKT/BKTIndex.cpp +++ b/AnnService/src/Core/BKT/BKTIndex.cpp @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "inc/Core/BKT/Index.h" +#include "../../../../simde/simde/x86/sse4.2.h" + #pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. #pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data @@ -109,9 +111,9 @@ namespace SPTAG COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ SizeType tmpNode = gnode.node; \ const SizeType *node = m_pGraph[tmpNode]; \ - _mm_prefetch((const char *)node, _MM_HINT_T0); \ + __builtin_prefetch((const char *)node); \ for (DimensionType i = 0; i <= checkPos; i++) { \ - _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + __builtin_prefetch((const char *)(m_pSamples)[node[i]]); \ } \ if (gnode.distance <= p_query.worstDist()) { \ SizeType checkNode = node[checkPos]; \ @@ -163,9 +165,9 @@ namespace SPTAG COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ SizeType tmpNode = gnode.node; \ const SizeType *node = m_pGraph[tmpNode]; \ - _mm_prefetch((const char *)node, _MM_HINT_T0); \ + __builtin_prefetch((const char *)node); \ for (DimensionType i = 0; i <= checkPos; i++) { \ - _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + __builtin_prefetch((const char *)(m_pSamples)[node[i]]); \ } \ if (gnode.distance <= p_query.worstDist()) { \ SizeType checkNode = node[checkPos]; \ diff --git a/AnnService/src/Core/KDT/KDTIndex.cpp b/AnnService/src/Core/KDT/KDTIndex.cpp index 06a47916a..52cf99031 100644 --- a/AnnService/src/Core/KDT/KDTIndex.cpp +++ b/AnnService/src/Core/KDT/KDTIndex.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "inc/Core/KDT/Index.h" +#include "../../../../simde/simde/x86/sse4.2.h" #pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. #pragma warning(disable:4242) // '=' : conversion from 'int' to 'short', possible loss of data @@ -107,9 +108,9 @@ namespace SPTAG while (!p_space.m_NGQueue.empty()) { \ COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); \ const SizeType *node = m_pGraph[gnode.node]; \ - _mm_prefetch((const char *)node, _MM_HINT_T0); \ + __builtin_prefetch((const char *)node); \ for (DimensionType i = 0; i < m_pGraph.m_iNeighborhoodSize; i++) \ - _mm_prefetch((const char *)(m_pSamples)[node[i]], _MM_HINT_T0); \ + __builtin_prefetch((const char *)(m_pSamples)[node[i]]); \ CheckDeleted { \ if (!p_query.AddPoint(gnode.node, gnode.distance) && p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { \ p_query.SortResult(); return; \ diff --git a/CMakeLists.txt b/CMakeLists.txt index fedd6ed73..6efe2becd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,29 +23,34 @@ if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") if (CXX_COMPILER_VERSION VERSION_LESS 5.0) message(FATAL_ERROR "GCC version must be at least 5.0!") endif() - set (CMAKE_CXX_FLAGS_RELEASE "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -lm -lrt -DNDEBUG -std=c++14 -fopenmp -O3 -march=native") - set (CMAKE_CXX_FLAGS_DEBUG "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -ggdb -lm -lrt -DNDEBUG -std=c++14 -fopenmp -O3 -march=native") - EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO) - STRING(REGEX REPLACE "^.*(avx2).*$" "\\1" THERE ${CPUINFO}) - STRING(COMPARE EQUAL "avx2" "${THERE}" AVX2_SUPPORTED) - if (AVX2_SUPPORTED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2") - endif() - STRING(REGEX REPLACE "^.*(avx).*$" "\\1" THERE ${CPUINFO}) - STRING(COMPARE EQUAL "avx" "${THERE}" AVX_SUPPORTED) - if (AVX_SUPPORTED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx") - endif() - STRING(REGEX REPLACE "^.*(sse2).*$" "\\1" THERE ${CPUINFO}) - STRING(COMPARE EQUAL "sse2" "${THERE}" SSE2_SUPPORTED) - if (SSE2_SUPPORTED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse2") - endif() - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse") - message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + set (CMAKE_CXX_FLAGS_RELEASE "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -lm -lrt -DNDEBUG -std=c++14 -fopenmp -O3 -march=native -fopenmp-simd -DSIMDE_ENABLE_OPENMP") + set (CMAKE_CXX_FLAGS_DEBUG "-Wall -Wunreachable-code -Wno-reorder -Wno-sign-compare -Wno-unknown-pragmas -Wcast-align -ggdb -lm -lrt -DNDEBUG -std=c++14 -fopenmp -O3 -march=native -fopenmp-simd -DSIMDE_ENABLE_OPENMP") + IF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + message(STATUS "Building for ARM") + ELSE(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + message(STATUS "Building for x86") + EXEC_PROGRAM(cat ARGS "/proc/cpuinfo" OUTPUT_VARIABLE CPUINFO) + STRING(REGEX REPLACE "^.*(avx2).*$" "\\1" THERE ${CPUINFO}) + STRING(COMPARE EQUAL "avx2" "${THERE}" AVX2_SUPPORTED) + if (AVX2_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2") + endif() + STRING(REGEX REPLACE "^.*(avx).*$" "\\1" THERE ${CPUINFO}) + STRING(COMPARE EQUAL "avx" "${THERE}" AVX_SUPPORTED) + if (AVX_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx") + endif() + STRING(REGEX REPLACE "^.*(sse2).*$" "\\1" THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sse2" "${THERE}" SSE2_SUPPORTED) + if (SSE2_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse2") + endif() + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse") + message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + ENDIF(${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") elseif(WIN32) if(NOT MSVC14) - message(FATAL_ERROR "On Windows, only MSVC version 14 are supported!") + message(FATAL_ERROR "On Windows, only MSVC version 14 are supported!") endif() include(CheckCXXCompilerFlag) CHECK_CXX_COMPILER_FLAG("/arch:AVX2" AVX2_SUPPORTED)