diff --git a/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp b/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp
index 675b8d1c43..3659af0060 100644
--- a/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp
+++ b/test/unit/cute/intel_xe/xe_copy_prefetch_2d.cpp
@@ -47,8 +47,8 @@ using namespace compat::experimental;
 #if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18)
 
-// Kernel name for unique identification
-template
+// Kernel name for unique identification - includes Bits to ensure uniqueness
+template
 class XEPrefetch2DKernelName;
 
 // Device kernel for XE_PREFETCH_2D testing
@@ -106,7 +106,7 @@ void test_xe_prefetch_2d() {
 
   // Initialize source with test pattern
   for (size_t i = 0; i < host_src.size(); ++i) {
-    host_src[i] = static_cast(i % 256);
+    host_src[i] = static_cast(static_cast(i % 256));
   }
 
   // Copy to device
@@ -122,7 +122,7 @@ void test_xe_prefetch_2d() {
   auto gridDim = compat::dim3(1);
 
   launch,
-         XEPrefetch2DKernelName>(
+         XEPrefetch2DKernelName>(
       launch_policy{
          gridDim, blockDim,
          kernel_properties{sycl_exp::sub_group_size}
@@ -150,6 +150,153 @@ TEST(CuTe_Xe, XE_PREFETCH_2D_float) {
   test_xe_prefetch_2d();
 }
+
+// Test 4: 8-bit Minimal Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_8bit_Minimal) {
+  test_xe_prefetch_2d();
+}
+
+// Test 5: 8-bit Small Height
+TEST(CuTe_Xe, XE_PREFETCH_2D_8bit_SmallHeight) {
+  test_xe_prefetch_2d();
+}
+
+// Test 6: 8-bit Medium Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_8bit_Medium) {
+  test_xe_prefetch_2d();
+}
+
+// Test 7: 8-bit Large Height
+TEST(CuTe_Xe, XE_PREFETCH_2D_8bit_LargeHeight) {
+  test_xe_prefetch_2d();
+}
+
+// Test 8: 8-bit Wide Configuration (respecting 512-bit width limit)
+TEST(CuTe_Xe, XE_PREFETCH_2D_8bit_Wide) {
+  test_xe_prefetch_2d();  // 8*64=512 bits (max)
+}
+
+// Test 9: 16-bit Minimal Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_16bit_Minimal) {
+  test_xe_prefetch_2d();
+}
+
+// Test 10: 16-bit Small Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_16bit_Small) {
+  test_xe_prefetch_2d();
+}
+
+// Test 11: 16-bit Medium Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_16bit_Medium) {
+  test_xe_prefetch_2d();
+}
+
+// Test 12: 16-bit Large Height
+TEST(CuTe_Xe, XE_PREFETCH_2D_16bit_LargeHeight) {
+  test_xe_prefetch_2d();
+}
+
+// Test 13: 16-bit Wide Configuration (respecting 512-bit width limit)
+TEST(CuTe_Xe, XE_PREFETCH_2D_16bit_Wide) {
+  test_xe_prefetch_2d();  // 16*32=512 bits (max)
+}
+
+// Test 14: 32-bit Minimal Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_32bit_Minimal) {
+  test_xe_prefetch_2d();  // 32*16=512 bits (max)
+}
+
+// Test 15: 32-bit Small Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_32bit_Small) {
+  test_xe_prefetch_2d();
+}
+
+// Test 16: 32-bit Medium Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_32bit_Medium) {
+  test_xe_prefetch_2d();
+}
+
+// Test 17: 32-bit Large Height
+TEST(CuTe_Xe, XE_PREFETCH_2D_32bit_LargeHeight) {
+  test_xe_prefetch_2d();
+}
+
+// Test 18: 32-bit Wide Configuration (respecting 512-bit width limit)
+TEST(CuTe_Xe, XE_PREFETCH_2D_32bit_Wide) {
+  test_xe_prefetch_2d();  // 32*16=512 bits (max)
+}
+
+// Test 19: 64-bit Small Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_64bit_Small) {
+  test_xe_prefetch_2d();  // 64*8=512 bits (max)
+}
+
+// Test 20: 64-bit Medium Configuration
+TEST(CuTe_Xe, XE_PREFETCH_2D_64bit_Medium) {
+  test_xe_prefetch_2d();  // 64*8=512 bits (max)
+}
+
+// Test 21: 64-bit Large Height
+TEST(CuTe_Xe, XE_PREFETCH_2D_64bit_LargeHeight) {
+  test_xe_prefetch_2d();  // 64*8=512 bits (max)
+}
+
+// Test 22: Mixed Data Types - Power of Two Heights
+TEST(CuTe_Xe, XE_PREFETCH_2D_PowerOfTwo_Heights) {
+  // 8-bit with power-of-two heights
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+
+  // 16-bit with power-of-two heights
+  test_xe_prefetch_2d();
+
+  // 32-bit with power-of-two heights
+  test_xe_prefetch_2d();
+}
+
+// Test 23: Various Width Configurations
+TEST(CuTe_Xe, XE_PREFETCH_2D_VariousWidths) {
+  // 8-bit with various widths
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+
+  // 16-bit with various widths
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+
+  // 32-bit with various widths
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+}
+
+// Test 24: Square Tiles
+TEST(CuTe_Xe, XE_PREFETCH_2D_SquareTiles) {
+  // 8-bit square (in memory view)
+  test_xe_prefetch_2d();
+
+  // 16-bit square
+  test_xe_prefetch_2d();
+
+  // 32-bit square
+  test_xe_prefetch_2d();
+}
+
+// Test 25: Tall Tiles (Height > Width)
+TEST(CuTe_Xe, XE_PREFETCH_2D_TallTiles) {
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+  test_xe_prefetch_2d();
+}
+
+// Test 26: Cache Line Optimization
+TEST(CuTe_Xe, XE_PREFETCH_2D_CacheOptimized) {
+  // Configurations aligned to cache lines (64 bytes)
+  test_xe_prefetch_2d();  // 64 bytes per row
+  test_xe_prefetch_2d();  // 64 bytes per row
+  test_xe_prefetch_2d();  // 64 bytes per row
+  test_xe_prefetch_2d();  // 64 bytes per row
+}
+
 #else
 
 // For the fallback case
diff --git a/test/unit/cute/intel_xe/xe_transpose_2d.cpp b/test/unit/cute/intel_xe/xe_transpose_2d.cpp
index d2375d2fc8..f7a9752f24 100644
--- a/test/unit/cute/intel_xe/xe_transpose_2d.cpp
+++ b/test/unit/cute/intel_xe/xe_transpose_2d.cpp
@@ -25,75 +25,196 @@
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF POSSIBILITY OF SUCH DAMAGE.
 **************************************************************************************************/
+#include "cutlass/detail/layout.hpp"
+
 #include
 #include
 #include
 #include
 #include
+#include
+
 #include "cutlass_unit_test.h"
+#include "utils.hpp"
 
 using namespace cute;
+using namespace cutlass;
+using namespace compat::experimental;
+
+#define SUBGROUP_SIZE (16)
+
+#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18)
+
+// Kernel name for unique identification
+template class XETranspose2DKernelName;
+
+// Device kernel for XE_LOAD_2D_TRANSPOSE testing
+// Note: Transpose load performs HW-level transpose during load operation
+// Memory layout (Height×Width) is transposed to register layout (Width×Height)
+template
+void xe_transpose_2d_kernel(SrcTensor src, DstTensor dst) {
+  using namespace cute;
+  using Element = typename SrcTensor::value_type;
+
+  // Only execute with the first subgroup to avoid race conditions
+  if (sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group(0) == 0) {
+    // Get thread/subgroup information
+    auto local_id = int(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_local_id(0));
+
+    // Create block 2D transpose load inside kernel (device-only operation)
+    using TransposeOp = XE_LOAD_2D_TRANSPOSE;
+    auto tiled_transpose = make_block_2d_copy(TransposeOp{}, src);
+
+    // Get thread slice of the tiled transpose
+    auto thr_transpose = tiled_transpose.get_slice(local_id);
+
+    // Create coordinate tensor for a single tile
+    // Note: coordinates are in memory space (Height×Width)
+    auto coord_shape = make_shape(Int{}, Int>{});
+    Tensor coord_tile = make_identity_tensor(coord_shape);
+
+    // Partition source coordinates for transpose load
+    auto thr_src_coord = thr_transpose.partition_S(coord_tile);
+
+    // Create destination fragment - transpose changes the layout in registers
+    auto thr_dst_frag = thr_transpose.partition_fragment_D(coord_tile);
+
+    // Perform the transpose load operation from global memory to registers
+    // Data is transposed during this operation by hardware
+    copy(tiled_transpose, thr_src_coord, thr_dst_frag);
+
+    // For verification, we need to store the transposed data back
+    // Note: Output will be in transposed layout (Width×Height in memory)
+    // We store to the transposed destination shape
+    auto dst_coord_shape = make_shape(Int>{}, Int{});
+    Tensor dst_coord_tile = make_identity_tensor(dst_coord_shape);
+
+    using StoreOp = XE_STORE_2D;  // Swapped dimensions
+    auto tiled_store = make_block_2d_copy(StoreOp{}, dst);
+    auto thr_store = tiled_store.get_slice(local_id);
+
+    // Create destination coordinates for the store operation
+    auto thr_dst_coord = thr_store.partition_D(dst_coord_tile);
+    auto thr_src_frag = thr_store.partition_fragment_S(dst_coord_tile);
+
+    // Copy from transpose fragment to store fragment
+    copy(thr_dst_frag, thr_src_frag);
+
+    // Perform the store operation from registers to global memory
+    copy(tiled_store, thr_src_frag, thr_dst_coord);
+
+    // Synchronize to ensure all threads complete their operations
+    sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group());
+  }
+}
+
+// Host test function template for transpose operations
+template
+void test_xe_transpose_2d() {
+  using namespace cute;
+
+  // Source matrix dimensions (Height×Width in memory)
+  constexpr int M = Height;
+  constexpr int N = Width * sizeof_bits_v / Bits;
+
+  // Destination will be transposed (Width×Height in memory)
+  constexpr int M_dst = N;
+  constexpr int N_dst = M;
+
+  // Ensure proper alignment
+  constexpr int elem_alignment = 16 / sizeof(Element);
+  constexpr int aligned_N = ((N + elem_alignment - 1) / elem_alignment) * elem_alignment;
+  constexpr int aligned_M_dst = ((M_dst + elem_alignment - 1) / elem_alignment) * elem_alignment;
+
+  // Allocate host memory
+  cutlass::host_vector host_src(M * aligned_N);
+  cutlass::host_vector host_dst(M_dst * aligned_M_dst);
+
+  // Initialize source with test pattern
+  for (int i = 0; i < M; ++i) {
+    for (int j = 0; j < N; ++j) {
+      Element val;
+      if constexpr (std::is_floating_point_v ||
+                    std::is_same_v ||
+                    std::is_same_v) {
+        val = Element(static_cast(i * N + j) / 100.0f);
+      } else {
+        val = static_cast((i * N + j) % 256);
+      }
+      host_src[i * aligned_N + j] = val;
+    }
+  }
+
+  // Copy to device
+  cutlass::device_vector device_src = host_src;
+  cutlass::device_vector device_dst(M_dst * aligned_M_dst);
+
+  // Create source tensor (Height×Width)
+  Tensor tensor_src =
+      make_tensor(make_gmem_ptr(device_src.data()),
+                  make_layout(Shape, Int>{},
+                              Stride, _1>{}));
+
+  // Create destination tensor (Width×Height) - transposed shape
+  Tensor tensor_dst =
+      make_tensor(make_gmem_ptr(device_dst.data()),
+                  make_layout(Shape, Int>{},
+                              Stride, _1>{}));
+
+  // Launch kernel
+  auto blockDim = compat::dim3(SUBGROUP_SIZE);
+  auto gridDim = compat::dim3(1);
+
+  launch,
+         XETranspose2DKernelName>(
+      launch_policy{
+          gridDim, blockDim,
+          kernel_properties{sycl_exp::sub_group_size}
+      },
+      tensor_src, tensor_dst);
+
+  compat::wait_and_throw();
+  host_dst = device_dst;
+
+  // Verify transpose: dst[j][i] should equal src[i][j]
+  for (int i = 0; i < M; ++i) {
+    for (int j = 0; j < N; ++j) {
+      Element src_val = host_src[i * aligned_N + j];
+      Element dst_val = host_dst[j * aligned_M_dst + i];
+      EXPECT_EQ(dst_val, src_val)
+          << "Mismatch at src[" << i << "][" << j << "] vs dst[" << j << "][" << i << "]";
+    }
+  }
+}
+
+// Test 32-bit transpose operations (Width ≤ 8 constraint)
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_float_4x8) {
+  test_xe_transpose_2d();
+}
+
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_float_8x8) {
+  test_xe_transpose_2d();
+}
+
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_float_4x4) {
+  test_xe_transpose_2d();
+}
-#if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18)
-
-TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_API_Declaration) {
-  // Template: XE_LOAD_2D_TRANSPOSE
-  // Constraints: Bits == 32 || Bits == 64, Width <= 8
-  // For 64-bit: Height == 8 && Width < 4
-
-  // Test 32-bit transpose operations
-  using TransposeOp_32bit_2x4 = XE_LOAD_2D_TRANSPOSE<32, 2, 4>;
-  using TransposeOp_32bit_4x8 = XE_LOAD_2D_TRANSPOSE<32, 4, 8>;
-  using TransposeOp_32bit_8x2 = XE_LOAD_2D_TRANSPOSE<32, 8, 2>;
-
-  // Test 64-bit transpose operations (limited constraints)
-  using TransposeOp_64bit_8x2 = XE_LOAD_2D_TRANSPOSE<64, 8, 2>;
-  using TransposeOp_64bit_8x3 = XE_LOAD_2D_TRANSPOSE<64, 8, 3>;
-
-  // Test that the operations have the required static members from XE_Copy_Op_2D_Base
-  static_assert(TransposeOp_32bit_2x4::AtomHeight == 2);
-  static_assert(TransposeOp_32bit_2x4::AtomWidth == 4);
-  static_assert(TransposeOp_32bit_2x4::CopyBits == 32);
-
-  static_assert(TransposeOp_32bit_4x8::AtomHeight == 4);
-  static_assert(TransposeOp_32bit_4x8::AtomWidth == 8);
-  static_assert(TransposeOp_32bit_4x8::CopyBits == 32);
-
-  static_assert(TransposeOp_64bit_8x2::AtomHeight == 8);
-  static_assert(TransposeOp_64bit_8x2::AtomWidth == 2);
-  static_assert(TransposeOp_64bit_8x2::CopyBits == 64);
-
-  EXPECT_TRUE(true) << "XE_LOAD_2D_TRANSPOSE API types declared successfully";
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_int32_4x8) {
+  test_xe_transpose_2d();
 }
 
-TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_Constraints) {
-  // Test that the compile-time constraints are enforced
-
-  // Valid 32-bit operations
-  using Valid32_1 = XE_LOAD_2D_TRANSPOSE<32, 1, 1>;
-  using Valid32_2 = XE_LOAD_2D_TRANSPOSE<32, 16, 8>;  // Width <= 8
-
-  // Valid 64-bit operations (Height == 8 && Width < 4)
-  using Valid64_1 = XE_LOAD_2D_TRANSPOSE<64, 8, 1>;
-  using Valid64_2 = XE_LOAD_2D_TRANSPOSE<64, 8, 2>;
-  using Valid64_3 = XE_LOAD_2D_TRANSPOSE<64, 8, 3>;
-
-  static_assert(Valid32_1::CopyBits == 32);
-  static_assert(Valid32_2::CopyBits == 32);
-  static_assert(Valid64_1::CopyBits == 64);
-  static_assert(Valid64_2::CopyBits == 64);
-  static_assert(Valid64_3::CopyBits == 64);
-
-  EXPECT_TRUE(true) << "XE_LOAD_2D_TRANSPOSE constraint validation successful";
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_uint32_4x8) {
+  test_xe_transpose_2d();
 }
 
 #else
 
-TEST(CuTe_Xe, XE_LOAD_2D_TRANSPOSE_SKIPPED) {
+TEST(CuTe_Xe, XE_TRANSPOSE_2D_SKIPPED) {
   GTEST_SKIP() << "XE_LOAD_2D_TRANSPOSE tests require IGC version 2.18 or higher. skipped";
 }
diff --git a/test/unit/cute/intel_xe/xe_vnni_2d.cpp b/test/unit/cute/intel_xe/xe_vnni_2d.cpp
index 2112e474b0..45100aa9ff 100644
--- a/test/unit/cute/intel_xe/xe_vnni_2d.cpp
+++ b/test/unit/cute/intel_xe/xe_vnni_2d.cpp
@@ -26,44 +26,235 @@
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF POSSIBILITY OF SUCH DAMAGE.
+ *
 **************************************************************************************************/
+
+/*
+ * VNNI Usage Summary:
+ *
+ * This file demonstrates XE_LOAD_2D_VNNI usage in kernel context.
+ *
+ * Key points:
+ * 1. VNNI is used to load B matrix in GEMM operations
+ * 2. Hardware performs interleaving during load (free transformation)
+ * 3. VNNI data flows directly to DPAS operations
+ * 4. Only 8-bit and 16-bit data types supported
+ * 5. BlockWidth parameter creates multiple blocks (block_count = Width/BlockWidth)
+ *
+ * Real-world usage pattern:
+ *   auto copy_b = make_block_2d_copy_B(XE_LOAD_2D_VNNI<16, 32, 16, 16>{}, mma, gB);
+ *   copy(copy_b, tBgB, tBrB);     // Load in VNNI format
+ *   gemm(mma, tCrA, tBrB, tCrC);  // DPAS consumes VNNI data
+ *
+ * See examples/12_bmg_moe_gemm_cute_interface/ for full GEMM implementation.
+ *
+ */
+
+#include "cutlass/detail/layout.hpp"
+
 #include
 #include
 #include
 #include
+#include
+#include
 #include
+#include
+
 #include "cutlass_unit_test.h"
+#include "utils.hpp"
 
 using namespace cute;
+using namespace cutlass;
+using namespace compat::experimental;
+
+#define SUBGROUP_SIZE (16)
 
 #if (IGC_VERSION_MAJOR > 2) || (IGC_VERSION_MAJOR == 2 && IGC_VERSION_MINOR >= 18)
 
-TEST(CuTe_Xe, XE_LOAD_2D_VNNI_API_Declaration) {
-  // Template: XE_LOAD_2D_VNNI
+// Kernel name for unique identification
+template class XEVnniLoadKernelName;
+
+// VNNI load demonstration kernel
+// Note: VNNI is designed for B matrix in GEMM context with DPAS consumption
+// This simplified test only verifies the load operation executes without errors
+template
+void xe_vnni_load_kernel(SrcTensor src, DstTensor dst) {
+  using namespace cute;
+  using Element = typename SrcTensor::value_type;
+
+  // Only execute with the first subgroup to avoid race conditions
+  if (sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group(0) == 0) {
+    // Get thread/subgroup information
+    auto local_id = int(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_local_id(0));
+
+    // ============================================
+    // Use VNNI load instead of regular XE_LOAD_2D
+    // ============================================
+    // Note: VNNI is typically used with make_block_2d_copy_B in GEMM context
+    // But for demonstration, we show the raw VNNI operation
+    using VnniOp = XE_LOAD_2D_VNNI;  // BlockWidth = Width for single block
+    auto tiled_copy = make_block_2d_copy(VnniOp{}, src);
+
+    // Get thread slice of the tiled copy
+    auto thr_copy = tiled_copy.get_slice(local_id);
+
+    // Create coordinate tensor for a single tile
+    auto coord_shape = make_shape(Int{}, Int>{});
+    Tensor coord_tile = make_identity_tensor(coord_shape);
+
+    // Partition source coordinates and create destination fragment
+    auto thr_src_coord = thr_copy.partition_S(coord_tile);
+    auto thr_dst_frag = thr_copy.partition_fragment_D(coord_tile);
+
+    // ============================================
+    // THIS IS THE VNNI LOAD
+    // Hardware performs interleaving during this load
+    // Data in thr_dst_frag is now in VNNI interleaved format
+    // ============================================
+    copy(tiled_copy, thr_src_coord, thr_dst_frag);
+
+    // For verification, store back to destination
+    // Note: In real usage, thr_dst_frag would go directly to gemm(mma, tCrA, thr_dst_frag, tCrC)
+    using StoreOp = XE_STORE_2D;
+    auto tiled_store = make_block_2d_copy(StoreOp{}, dst);
+    auto thr_store = tiled_store.get_slice(local_id);
+
+    // Create destination coordinates for the store operation
+    auto thr_dst_coord = thr_store.partition_D(coord_tile);
+    auto thr_src_frag = thr_store.partition_fragment_S(coord_tile);
+
+    // Copy the loaded data from registers to the fragment for storing
+    copy(thr_dst_frag, thr_src_frag);
+
+    // Perform the store operation from registers to global memory
+    copy(tiled_store, thr_src_frag, thr_dst_coord);
+
+    // Synchronize to ensure all threads complete their operations
+    sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_nd_item<1>().get_group());
+  }
+}
+
+// Host test function for VNNI load operation
+template
+void test_xe_vnni_load() {
+  using namespace cute;
+
+  // Matrix dimensions - must be compatible with block 2D constraints
+  constexpr int M = Height;
+  constexpr int N = Width * sizeof_bits_v / Bits;
+
+  // Ensure proper alignment (required for block 2D operations)
+  constexpr int elem_alignment = 16 / sizeof(Element);
+  constexpr int aligned_N = ((N + elem_alignment - 1) / elem_alignment) * elem_alignment;
+
+  // Allocate and initialize host data
+  cutlass::host_vector host_src(M * aligned_N);
+  cutlass::host_vector host_dst(M * aligned_N);
+
+  // Initialize source with test pattern
+  for (size_t i = 0; i < host_src.size(); ++i) {
+    // Use a safe conversion that works for all numeric types
+    if constexpr (std::is_floating_point_v ||
+                  std::is_same_v ||
+                  std::is_same_v ||
+                  std::is_same_v) {
+      // For floating-point types, convert through float
+      float val = static_cast(i % 256) / 255.0f;  // Normalize to [0,1]
+      host_src[i] = Element(val);
+    } else {
+      // For integer types (including uint64_t) and char, direct conversion is safe
+      host_src[i] = static_cast(i % 256);
+    }
+  }
+
+  // Copy to device
+  cutlass::device_vector device_src = host_src;
+  cutlass::device_vector device_dst(M * aligned_N);
+
+  // Create tensors with proper layout
+  Tensor tensor_src =
+      make_tensor(make_gmem_ptr(device_src.data()),
+                  make_layout(Shape, Int>{}, Stride, _1>{}));
 
-  // Test that the VNNI operation types can be declared
-  using VNNIOp_8bit_2x32 = XE_LOAD_2D_VNNI<8, 2, 32>;
-  using VNNIOp_8bit_4x32 = XE_LOAD_2D_VNNI<8, 4, 32>;
-  using VNNIOp_16bit_2x16 = XE_LOAD_2D_VNNI<16, 2, 16>;
-  using VNNIOp_16bit_4x16 = XE_LOAD_2D_VNNI<16, 4, 16>;
+  Tensor tensor_dst =
+      make_tensor(make_gmem_ptr(device_dst.data()),
+                  make_layout(Shape, Int>{}, Stride, _1>{}));
 
-  // Test that the operations have the required static members from XE_Copy_Op_2D_Base
-  static_assert(VNNIOp_8bit_2x32::AtomHeight == 2);
-  static_assert(VNNIOp_8bit_2x32::AtomWidth == 32);
-  static_assert(VNNIOp_8bit_2x32::CopyBits == 8);
+  // Launch kernel - VNNI load demonstration
+  auto blockDim = compat::dim3(SUBGROUP_SIZE);
+  auto gridDim = compat::dim3(1);
 
-  static_assert(VNNIOp_16bit_2x16::AtomHeight == 2);
-  static_assert(VNNIOp_16bit_2x16::AtomWidth == 16);
-  static_assert(VNNIOp_16bit_2x16::CopyBits == 16);
+  launch,
+         XEVnniLoadKernelName>(
+      launch_policy{
+          gridDim, blockDim,
+          kernel_properties{sycl_exp::sub_group_size}
+      },
+      tensor_src, tensor_dst);
 
-  EXPECT_TRUE(true) << "XE_LOAD_2D_VNNI API types declared successfully";
+  compat::wait_and_throw();
+
+  // Note: We do NOT verify data matches because VNNI performs interleaving transformation
+  // The loaded data is in VNNI format (hardware-interleaved for DPAS consumption)
+  // When stored back to memory, the interleaved pattern is visible
+  // In real usage, VNNI data goes directly to gemm()/DPAS, never stored back
+  // This test verifies that VNNI load operation executes without errors
 }
 
+// ============================================
+// VNNI Tests - Only 8-bit and 16-bit supported
+// ============================================
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_uint8) {
+  // VNNI is used for B matrix in GEMM - typically with BlockWidth creating multiple blocks
+  test_xe_vnni_load();  // 4 blocks of 16
+  test_xe_vnni_load();  // 2 blocks of 32
+  test_xe_vnni_load();  // 1 block of 64
+}
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_int8) {
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+}
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_uint16) {
+  test_xe_vnni_load();  // 2 blocks of 16
+  test_xe_vnni_load();  // 2 blocks of 16
+  test_xe_vnni_load();  // 1 block of 32
+}
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_int16) {
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+}
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_half) {
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+}
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_bfloat16) {
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+  test_xe_vnni_load();
+}
+
+// Note: 32-bit and 64-bit types are NOT supported by VNNI
+// VNNI only works with 8-bit and 16-bit data types
+
 #else
 
-TEST(CuTe_Xe, XE_LOAD_2D_VNNI_SKIPPED) {
-  GTEST_SKIP() << "XE_LOAD_2D_VNNI tests require IGC version 2.18 or higher. skipped";
+// For the fallback case
+#include "cutlass_unit_test.h"
+
+TEST(PVC_CuTe_Xe, XE_VNNI_2D_SKIPPED) {
+  GTEST_SKIP() << "XE_VNNI_2D tests require IGC version 2.18 or higher. skipped";
 }
 
 #endif