From 373b27fdc1cf2697cca83aa0c74c97d9e6fb6485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 3 Mar 2022 13:00:17 +0000 Subject: [PATCH 1/4] [SYCL][CUDA] add bf16 builtins --- sycl/include/CL/__spirv/spirv_ops.hpp | 22 ++++++ sycl/include/CL/sycl.hpp | 1 + .../sycl/ext/oneapi/bf16_storage_builtins.hpp | 79 +++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 4878cc4dd5db8..782d33bc7ca4f 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -755,6 +755,28 @@ __spirv_ocl_printf(const __attribute__((opencl_constant)) char *Format, ...); extern SYCL_EXTERNAL int __spirv_ocl_printf(const char *Format, ...); #endif +extern SYCL_EXTERNAL __SYCL_EXPORT uint16_t __clc_fabs(uint16_t) noexcept; + +#define __CLC_BF16(...) \ +extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs(__VA_ARGS__) noexcept; \ +extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin(__VA_ARGS__, __VA_ARGS__) noexcept; \ +extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax(__VA_ARGS__, __VA_ARGS__) noexcept; \ +extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma(__VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept; + +#define __CLC_BF16_SCAL_VEC(TYPE) \ +__CLC_BF16(TYPE) \ +__CLC_BF16(__ocl_vec_t) \ +__CLC_BF16(__ocl_vec_t) \ +__CLC_BF16(__ocl_vec_t) \ +__CLC_BF16(__ocl_vec_t) \ +__CLC_BF16(__ocl_vec_t) + +__CLC_BF16_SCAL_VEC(uint16_t) +__CLC_BF16_SCAL_VEC(uint32_t) + +#undef __CLC_BF16_SCAL_VEC +#undef __CLC_BF16 + #else // if !__SYCL_DEVICE_ONLY__ template diff --git a/sycl/include/CL/sycl.hpp b/sycl/include/CL/sycl.hpp index 96b9d2e7fd9f9..91cb8c18e1865 100644 --- a/sycl/include/CL/sycl.hpp +++ b/sycl/include/CL/sycl.hpp @@ -60,6 +60,7 @@ #if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO #include #endif +#include #include #include #include diff --git a/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp new file mode 100644 index 0000000000000..88737c6c668d5 --- /dev/null +++ b/sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +__SYCL_INLINE_NAMESPACE(cl) { +namespace sycl { +namespace ext { +namespace oneapi { + +namespace detail { + +template struct is_bf16_storage_type { + static constexpr int value = false; +}; + +template <> struct is_bf16_storage_type { + static constexpr int value = true; +}; + +template <> struct is_bf16_storage_type { + static constexpr int value = true; +}; + +template struct is_bf16_storage_type> { + static constexpr int value = true; +}; + +template struct is_bf16_storage_type> { + static constexpr int value = true; +}; + +} // namespace detail + +template +std::enable_if_t::value, T> fabs(T x) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fabs(x); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fmin(T x, T y) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fmin(x, y); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fmax(T x, T y) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fmax(x, y); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} +template +std::enable_if_t::value, T> fma(T x, T y, T z) { +#ifdef __SYCL_DEVICE_ONLY__ + return __clc_fma(x, y, z); +#else + throw runtime_error("bf16 is not supported on host device.", + PI_INVALID_DEVICE); +#endif +} + +} // namespace oneapi +} // namespace ext +} // namespace sycl +} // __SYCL_INLINE_NAMESPACE(cl) From e684326f2c75df0318913c8e2ee8fa2d51a972ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 7 Mar 2022 03:02:32 -0800 Subject: [PATCH 2/4] fix a bug in intrinsics --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index c980f6ed4bdc2..32ad5a3ede0a0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -855,13 +855,13 @@ def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs, // Abs, Neg bf16, bf16x2 // -def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $dst;", Int16Regs, +def INT_NVVM_ABS_BF16 : F_MATH_1<"abs.bf16 \t$dst, $src0;", Int16Regs, Int16Regs, int_nvvm_abs_bf16, [hasPTX70, hasSM80]>; -def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $dst;", Int32Regs, +def INT_NVVM_ABS_BF16X2 : F_MATH_1<"abs.bf16x2 \t$dst, $src0;", Int32Regs, Int32Regs, int_nvvm_abs_bf16x2, [hasPTX70, hasSM80]>; -def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $dst;", Int16Regs, +def INT_NVVM_NEG_BF16 : F_MATH_1<"neg.bf16 \t$dst, $src0;", Int16Regs, Int16Regs, int_nvvm_neg_bf16, [hasPTX70, hasSM80]>; -def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $dst;", Int32Regs, +def INT_NVVM_NEG_BF16X2 : F_MATH_1<"neg.bf16x2 \t$dst, $src0;", Int32Regs, Int32Regs, int_nvvm_neg_bf16x2, [hasPTX70, hasSM80]>; // From 0449fc9e003f26442ff0bdd22c04a81a980fbd9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 7 Mar 2022 03:06:00 -0800 Subject: [PATCH 3/4] remove redundant declaration --- sycl/include/CL/__spirv/spirv_ops.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 782d33bc7ca4f..5ff56a539fa9d 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -755,8 +755,6 @@ __spirv_ocl_printf(const __attribute__((opencl_constant)) char *Format, ...); extern SYCL_EXTERNAL int __spirv_ocl_printf(const char *Format, ...); #endif -extern SYCL_EXTERNAL __SYCL_EXPORT uint16_t __clc_fabs(uint16_t) noexcept; - #define __CLC_BF16(...) \ extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs(__VA_ARGS__) noexcept; \ extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin(__VA_ARGS__, __VA_ARGS__) noexcept; \ From 2f3afe4cde7980e949df3ef917d0aed8af268468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 7 Mar 2022 11:38:35 +0000 Subject: [PATCH 4/4] format --- sycl/include/CL/__spirv/spirv_ops.hpp | 30 +++++++++++++++------------ 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/sycl/include/CL/__spirv/spirv_ops.hpp b/sycl/include/CL/__spirv/spirv_ops.hpp index 5ff56a539fa9d..d0d778c998959 100644 --- a/sycl/include/CL/__spirv/spirv_ops.hpp +++ b/sycl/include/CL/__spirv/spirv_ops.hpp @@ -755,19 +755,23 @@ __spirv_ocl_printf(const __attribute__((opencl_constant)) char *Format, ...); extern SYCL_EXTERNAL int __spirv_ocl_printf(const char *Format, ...); #endif -#define __CLC_BF16(...) \ -extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs(__VA_ARGS__) noexcept; \ -extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin(__VA_ARGS__, __VA_ARGS__) noexcept; \ -extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax(__VA_ARGS__, __VA_ARGS__) noexcept; \ -extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma(__VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept; - -#define __CLC_BF16_SCAL_VEC(TYPE) \ -__CLC_BF16(TYPE) \ -__CLC_BF16(__ocl_vec_t) \ -__CLC_BF16(__ocl_vec_t) \ -__CLC_BF16(__ocl_vec_t) \ -__CLC_BF16(__ocl_vec_t) \ -__CLC_BF16(__ocl_vec_t) +#define __CLC_BF16(...) \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs( \ + __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin( \ + __VA_ARGS__, __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax( \ + __VA_ARGS__, __VA_ARGS__) noexcept; \ + extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma( \ + __VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept; + +#define __CLC_BF16_SCAL_VEC(TYPE) \ + __CLC_BF16(TYPE) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) \ + __CLC_BF16(__ocl_vec_t) __CLC_BF16_SCAL_VEC(uint16_t) __CLC_BF16_SCAL_VEC(uint32_t)