From 30eeaba1e60bf005a94152b67c15cf1260ed3831 Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Tue, 29 Aug 2023 09:00:33 -0700 Subject: [PATCH 1/2] [SYCL][Matrix spec] Do not delete assign op and copy ctor as they are needed for joint_matrix_mad --- .../sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index 94c2bebe04906..0778191d8ab16 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -109,16 +109,17 @@ template struct joint_matrix { joint_matrix(); - joint_matrix(const joint_matrix &) = delete; - joint_matrix &operator=(const joint_matrix &) = delete; + joint_matrix(const joint_matrix &); + joint_matrix &operator=(const joint_matrix &); }; } // namespace sycl::ext::oneapi::experimental::matrix ``` -The constructor for the `joint_matrix` type is a group function as -defined in section 4.17.3 of the core SYCL specification. It must be -encountered in converged control flow by all work-items in the -`Group`. +The constructors for the `joint_matrix` type and the assignment +operator are group functions as defined in section 4.17.3 of the core +SYCL specification. They must be encountered in converged control flow +by all work-items in the `Group`. Note that the assignment operator +and the copy constructor do not copy the entire matrix content. 
==== Group Memory Scope Most operations on the joint_matrix are group functions, meaning that From 8e139edf2930745aa75865e8e0f14da6a1e8c9c5 Mon Sep 17 00:00:00 2001 From: Dounia Khaldi Date: Fri, 15 Sep 2023 09:54:08 -0700 Subject: [PATCH 2/2] Keep ctors deleted, make D input in mad, remove deprecated API like multi_ptr --- .../sycl_ext_intel_matrix.asciidoc | 19 +++++--- .../sycl_ext_oneapi_matrix.asciidoc | 46 +++++++++---------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc index d6de22bdae391..d861db1e141fe 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc @@ -304,15 +304,20 @@ q.submit([&](sycl::handler& cgh) { joint_matrix tC; joint_matrix_fill(sg, tC, 0); for (int k = 0; k < K; k += tK) { - joint_matrix_load(sg, tA, accA + sg_startx * tM * K + k, K); - joint_matrix_load(sg, tB, accB + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_load(sg, tA, + accA.template get_multi_ptr() + + sg_startx * tM * K + k, K); + joint_matrix_load(sg, tB, + accB.template get_multi_ptr() + + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4); + joint_matrix_mad(sg, tC, tA, tB, tC); } - auto wi_data_c = ext::intel::experimental::matrix::get_wi_data(sg, tC); - for (int i = 0; i < wi_data_c.length(); i++) - wi_data_c[i] *= alpha; + joint_matrix_apply(sg, tC, [=](int8_t x) { + x *= alpha; + }); joint_matrix_store(sg, tC, - accC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); + accC.template get_multi_ptr() + + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); }); }); q.wait(); diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc 
b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index 0778191d8ab16..65c8508eab668 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -109,17 +109,16 @@ template struct joint_matrix { joint_matrix(); - joint_matrix(const joint_matrix &); - joint_matrix &operator=(const joint_matrix &); + joint_matrix(const joint_matrix &) = delete; + joint_matrix &operator=(const joint_matrix &) = delete; }; } // namespace sycl::ext::oneapi::experimental::matrix ``` -The constructors for the `joint_matrix` type and the assignment -operator are group functions as defined in section 4.17.3 of the core -SYCL specification. They must be encountered in converged control flow -by all work-items in the `Group`. Note that the assignment operator -and the copy constructor do not copy the entire matrix content. +The constructor for the `joint_matrix` type is a group function as +defined in section 4.17.3 of the core SYCL specification. It must be +encountered in converged control flow by all work-items in the +`Group`. ==== Group Memory Scope Most operations on the joint_matrix are group functions, meaning that @@ -275,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout. ```c++ namespace sycl::ext::oneapi::experimental::matrix { -template -joint_matrix -joint_matrix_mad(Group g, +template +void joint_matrix_mad(Group g, + joint_matrix &D, const joint_matrix &A, const joint_matrix &B, const joint_matrix &C); @@ -288,7 +287,7 @@ joint_matrix_mad(Group g, ``` The matrix multiply and add function performs the multiply operation on the matrices `A` and `B`, accumulates the result with `C` and returns -the result. +the result in the matrix `D`. Each device supports only certain combinations of types for the `A`, `B`, and `C` matrices. 
The application must use the query operations @@ -506,6 +505,12 @@ range<2> L = {1, SG_SIZE}; int8_t *memA = malloc_shared(M*K, q); int8_t *memB = malloc_shared(K*N, q); int32_t *memC = malloc_shared(M*N, q); +auto pA = address_space_cast(memA); +auto pB = address_space_cast(memB); +auto pC = address_space_cast(memC); q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item) [[sycl::reqd_sub_group_size(SG_SIZE)]] { const auto global_idx = item.get_global_id(0); @@ -518,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item) joint_matrix tC; joint_matrix_fill(sg, tC, 0); for (int k = 0; k < K; k += tK) { - joint_matrix_load(sg, tA, - multi_ptr(memA) + - sg_startx * tM * K + k, K); - joint_matrix_load(sg, tB, - multi_ptr(memB) + - k * N + sg_starty/SG_SIZE*tN, N); - tC = joint_matrix_mad(sg, tA, tB, tC); + joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K); + joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N); + joint_matrix_mad(sg, tC, tA, tB, tC); } joint_matrix_apply(sg, tC, [=](int8_t x) { x *= alpha; }); - joint_matrix_store(sg, tC, - multi_ptr(memC) + - sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major); + joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, + N, layout::row_major); }).wait(); ```