@@ -30,6 +30,14 @@ uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
3030}
3131} // namespace detail
3232
33+ // According to bfloat16 format, NAN value's exponent field is 0xFF and
34+ // significand has non-zero bits.
35+ template <typename T>
36+ std::enable_if_t <std::is_same<T, bfloat16>::value, bool > isnan (T x) {
37+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
38+ return (((XBits & 0x7F80 ) == 0x7F80 ) && (XBits & 0x7F )) ? true : false ;
39+ }
40+
3341template <typename T>
3442std::enable_if_t <std::is_same<T, bfloat16>::value, T> fabs (T x) {
3543#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -74,20 +82,31 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
7482 oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
7583 return oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
7684#else
77- std::ignore = x;
78- std::ignore = y;
79- throw runtime_error (
80- " bfloat16 math functions are not currently supported on the host device." ,
81- PI_ERROR_INVALID_DEVICE);
85+ static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
86+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
87+ oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
88+ if (isnan (x) && isnan (y))
89+ return oneapi::detail::bitsToBfloat16 (CanonicalNan);
90+
91+ if (isnan (x))
92+ return y;
93+ if (isnan (y))
94+ return x;
95+ if (((XBits | YBits) ==
96+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
97+ !(XBits & YBits))
98+ return oneapi::detail::bitsToBfloat16 (
99+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 ));
100+
101+ return (x < y) ? x : y;
82102#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
83103}
84104
85105template <size_t N>
86106sycl::marray<bfloat16, N> fmin (sycl::marray<bfloat16, N> x,
87107 sycl::marray<bfloat16, N> y) {
88- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
89108 sycl::marray<bfloat16, N> res;
90-
109+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
91110 for (size_t i = 0 ; i < N / 2 ; i++) {
92111 auto partial_res = __clc_fmin (detail::to_uint32_t (x, i * 2 ),
93112 detail::to_uint32_t (y, i * 2 ));
@@ -101,15 +120,12 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
101120 oneapi::detail::bfloat16ToBits (y[N - 1 ]);
102121 res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
103122 }
104-
105- return res;
106123#else
107- std::ignore = x;
108- std::ignore = y;
109- throw runtime_error (
110- " bfloat16 math functions are not currently supported on the host device." ,
111- PI_ERROR_INVALID_DEVICE);
124+ for (size_t i = 0 ; i < N; i++) {
125+ res[i] = fmin (x[i], y[i]);
126+ }
112127#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128+ return res;
113129}
114130
115131template <typename T>
@@ -119,20 +135,30 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
119135 oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
120136 return oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
121137#else
122- std::ignore = x;
123- std::ignore = y;
124- throw runtime_error (
125- " bfloat16 math functions are not currently supported on the host device." ,
126- PI_ERROR_INVALID_DEVICE);
138+ static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
139+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
140+ oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
141+ if (isnan (x) && isnan (y))
142+ return oneapi::detail::bitsToBfloat16 (CanonicalNan);
143+
144+ if (isnan (x))
145+ return y;
146+ if (isnan (y))
147+ return x;
148+ if (((XBits | YBits) ==
149+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
150+ !(XBits & YBits))
151+ return oneapi::detail::bitsToBfloat16 (0 );
152+
153+ return (x > y) ? x : y;
127154#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128155}
129156
130157template <size_t N>
131158sycl::marray<bfloat16, N> fmax (sycl::marray<bfloat16, N> x,
132159 sycl::marray<bfloat16, N> y) {
133- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
134160 sycl::marray<bfloat16, N> res;
135-
161+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
136162 for (size_t i = 0 ; i < N / 2 ; i++) {
137163 auto partial_res = __clc_fmax (detail::to_uint32_t (x, i * 2 ),
138164 detail::to_uint32_t (y, i * 2 ));
@@ -146,14 +172,12 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
146172 oneapi::detail::bfloat16ToBits (y[N - 1 ]);
147173 res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
148174 }
149- return res;
150175#else
151- std::ignore = x;
152- std::ignore = y;
153- throw runtime_error (
154- " bfloat16 math functions are not currently supported on the host device." ,
155- PI_ERROR_INVALID_DEVICE);
176+ for (size_t i = 0 ; i < N; i++) {
177+ res[i] = fmax (x[i], y[i]);
178+ }
156179#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
180+ return res;
157181}
158182
159183template <typename T>
0 commit comments