From 3c898898e656516d70104cb34d03bc579dbccfbc Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 26 Nov 2025 19:40:33 -0800 Subject: [PATCH 01/12] Add split reduction kernel --- .../collective/xe_fmha_fwd_epilogue.hpp | 57 +++- .../kernel/xe_fhma_fwd_kernel.hpp | 313 +++++++++++++++++- .../kernel/xe_reduce_split_k.h | 267 +++++++++++++++ .../kernel/xe_tile_scheduler.hpp | 83 ++++- .../06_bmg_flash_attention/06_xe_fmha_fwd.cpp | 8 +- .../06_bmg_flash_attention/CMakeLists.txt | 7 + .../xe_fmha_fwd_runner.hpp | 182 ++++++++-- 7 files changed, 878 insertions(+), 39 deletions(-) create mode 100644 applications/flash_attention_v2/kernel/xe_reduce_split_k.h diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index efa54931d3..38733d0cb2 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -92,6 +92,7 @@ class FMHAFwdEpilogue { Stride, E<0>>{}) )); using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype mismatched"); static auto default_tiled_copy_O_helper() { if constexpr (ReduceK{} == _1{}) @@ -149,7 +150,59 @@ class FMHAFwdEpilogue { using ElementA = typename FragA::element_type; // Reduce k-blocks of A and A_sum across WG, if needed. - auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + auto [rA, rA_max_unused, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // splitK version + template + CUTLASS_DEVICE + void + operator()(TensorO2D const& O, // Global O tensor: (q,v) + FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id, // Work-item ID + const TensorO2D & exp_sums, // Global exp sum tensor + const TensorO2D & max_logits // Global max logits tensor + ) { + + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + // store exp sum and max logits for current KV split + exp_sums(0) = rA_sum(0); + max_logits(0) = rA_max(0); /* Some subgroups may not have any work to do; if so, quit early. 
*/
       if (!active) return;
@@ -285,7 +338,7 @@ class FMHAFwdEpilogue {
         }
       }
     }
-    return std::make_tuple(rA, rA_sum, active);
+    return std::make_tuple(rA, rA_max, rA_sum, active);
   }
 }
};
diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp
index f5905f746a..2e3a822d57 100644
--- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp
+++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp
@@ -58,7 +58,7 @@ struct FMHAProblemShape {

 ///////////////////////////////////////////////////////////////////////////////

-template
+template
 class XeFMHAFwdKernel {

 public:
@@ -134,6 +134,7 @@ class XeFMHAFwdKernel {
     MainloopArguments mainloop{};
     EpilogueArguments epilogue{};
     KernelHardwareInfo hw_info{};
+    int num_kv_splits = -1;
   };

   // Kernel entry point API
@@ -206,7 +207,7 @@ class XeFMHAFwdKernel {

     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
-      auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+      auto [blk_q, blk_v, head_q, idx_b, unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b); KV-split index unused here

       auto blk_qv = make_coord(blk_q, blk_v);
       int head = head_q / head_group_q;
@@ -284,7 +285,7 @@
 };

 template
-class XeFMHAFwdDynamicSplitKernel {
+class XeFMHAFwdPersistentKernel {

 public:
   //
@@ -319,7 +320,7 @@
   using ElementA = typename CollectiveMainloop::ElementA;

   // Tile scheduler derived types
-  static_assert(is_same_v<TileScheduler_, XeFHMAIndividualPersistentTileScheduler>);
+  static_assert(is_same_v<TileScheduler_, XeFHMAPersistentTileScheduler>);
   using TileScheduler = TileScheduler_;
   using TileSchedulerParams = typename TileScheduler::Params;
@@ -367,6 +368,7 @@
     MainloopArguments mainloop{};
     EpilogueArguments epilogue{};
     KernelHardwareInfo hw_info{};
+    int num_kv_splits = -1;
   };

   // Kernel entry point API
@@ -697,4 +699,307 @@
   }
 };

+template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_,
+          class TileScheduler_, bool SplitKV = true>
+class XeFMHAFwdSplitKVKernel {
+
+public:
+  //
+  // Type Aliases
+  //
+  using ProblemShape = ProblemShape_;
+  using VariableLength = cutlass::fmha::collective::VariableLength;
+  static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v;
+  // TODO: support later
+  static_assert(SplitKV && !is_var_len,
+                "XeFMHAFwdSplitKVKernel requires KV splitting and does not support variable sequence length yet");
+  // Mainloop derived types
+  using CollectiveMainloop = CollectiveMainloop_;
+  using MainloopArguments = typename CollectiveMainloop::Arguments;
+  using MainloopParams = typename CollectiveMainloop::Params;
+
+  using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK;
+  using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV;
+  using TileShapeQK = typename CollectiveMainloop::TileShapeQK;
+  using TileShapePV = typename CollectiveMainloop::TileShapePV;
+  using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK;
+  using ElementQ = typename CollectiveMainloop::TensorQ::element_type;
+  using ElementK = typename CollectiveMainloop::TensorK::element_type;
+  using ElementV = typename CollectiveMainloop::TensorV::element_type;
+
+  using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{}));
+  using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{}));
+  using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{}));
+
+  using SGPerWG = typename CollectiveMainloop::SGPerWG;
+
+  using FragA = typename CollectiveMainloop::FragA;
+  using FragARow = typename CollectiveMainloop::FragARow;
+
+  // Tile scheduler derived types
+  using TileScheduler
= TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? size_t(0) + : sizeof(SharedStorage); + + static constexpr int max_num_kv_splits = 8; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ *Q; + StrideQ dQ; + const ElementK *K; + StrideK dK; + const ElementV *V; + StrideV dV; + ElementO *O; + StrideO dO; + // TODO: whether same dtype as output or accum? + ElementO *Oaccum; + StrideO dOaccum; + ElementO *exp_sums; + StrideO dExp_sums; + ElementO *max_logits; + StrideO dMax_logits; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const &args) { + if (args.kernel.shape.seq_len_qo != 1) { + // decode only + return false; + } + + if (args.num_kv_splits > max_num_kv_splits) { + return false; + } + + return CollectiveMainloop::can_implement(args.mainloop) + && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int 
head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + // auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto [blk_q, blk_v, head_q, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto blk_qv = make_coord(blk_q, blk_v); + int head = head_q / head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : seq_len_kv; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + auto kv_cumulative = s.seq_len_kv.cumulative_length; + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_k = s.num_heads_kv * s.head_size_qk * kv_cumulative[idx_b]; + offset_v = s.num_heads_kv * s.head_size_vo * kv_cumulative[idx_b]; + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + } + + auto batch_dim = is_var_len ? 
1 : s.batch; + auto shape_Q = make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + // 4D shape + decltype(shape_Q) shape_K, shape_V, shape_O, shape_exp_sums, shape_max_logits; + // auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + // auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); + // auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); + + int num_blocks_per_split, kv_split_offset, num_effective_kv_blocks; + + if constexpr (SplitKV) { + num_blocks_per_split = cute::ceil_div(k_blocks, params.scheduler.num_kv_splits_); + kv_split_offset = idx_kv_split * num_blocks_per_split; + num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); + + shape_K = make_shape(num_effective_kv_blocks * get<1>(TileShapeQK{}), s.head_size_qk, s.num_heads_kv, batch_dim); + shape_V = make_shape(s.head_size_vo, num_effective_kv_blocks * get<1>(TileShapeQK{}), s.num_heads_kv, batch_dim); + shape_O = make_shape(seq_len_qo, s.head_size_vo, 1, s.num_heads_q * batch_dim); + + shape_exp_sums = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); + shape_max_logits = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); + + // TODO: adapt for var length + // offset_k = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_qk) * s.num_heads_kv * batch_dim; + offset_k = kv_split_offset * get<1>(TileShapeQK{}); + // offset_v = s.num_heads_kv * s.head_size_vo * (idx_b * seq_len + kv_split_offset * get<1>(TileShapeQK{})); + // offset_v = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_vo) * s.num_heads_kv * batch_dim; + offset_v = kv_split_offset * get<1>(TileShapeQK{}); + + // assume: Oaccum is allocated with shape (batch * num_heads_q, num_kv_splits, seq_len_qo, head_size_vo) + offset_o = s.head_size_vo * seq_len_qo * idx_kv_split; + + offset_exp_sums = idx_kv_split; + offset_max_logits = idx_kv_split; + } else { + shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); + shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); + } + + auto dcQ = const_cast(p.Q + offset_q); + auto dcK = const_cast(p.K + offset_k); + auto dcV = const_cast(p.V + offset_v); + auto ptrO = (SplitKV ? p.Oaccum : p.O) + offset_o; + auto ptrExp_sums = p.exp_sums + offset_exp_sums; + auto ptrMax_logits = p.max_logits + offset_max_logits; + + auto stride_q = is_var_len ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; + auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; + auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; + auto stride_o = is_var_len ? 
cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_exp_sums = p.dExp_sums; + auto stride_max_logits = p.dMax_logits; + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); + Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); + Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + int l_coord = is_var_len ? 0 : idx_b; + + int start_blk = SplitKV ? kv_split_offset : 0; + int end_blk = SplitKV ? (kv_split_offset + num_effective_kv_blocks) : k_blocks; + + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + mainloop(Q(_,_,head_q,l_coord), + K(_,_,head,l_coord), + V(_,_,head,l_coord), + tArA, tA_max, tA_sum, + blk_qv, start_blk, end_blk, k_blocks, + thr_id, seq_len, + full_tile_offset, discard_seq_coord); + + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + if constexpr (SplitKV) { + epilogue(O(_,_,0,head_q + idx_b * s.num_heads_q), + tArA, tA_max, tA_sum, + blk_qv, thr_id, + exp_sums(_,_,head_q,l_coord), + max_logits(_,_,head_q,l_coord)); + } else { + epilogue(O(_,_,head_q,l_coord), + tArA, tA_max, tA_sum, + blk_qv, thr_id); + } + } +}; + +}; + } // namespace cutlass::fmha::kernel diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h new file mode 100644 index 0000000000..8802ba3f35 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Kernel performing the split-KV reduction, combining the per-split
+           partial FMHA outputs stored in global memory
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/kernel_hardware_info.hpp"
+
+#include "flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp"
+#include "flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp"
+#include "cute/util/type_traits.hpp"
+#include "flash_attention_v2/collective/fmha_fusion.hpp"
+#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace reduction {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <class ProblemShape_, class TileScheduler_, class FMHAKernel_>
+class ReduceSplitK {
+public:
+
+  using ProblemShape = ProblemShape_;
+  using TileScheduler = TileScheduler_;
+  static_assert(is_same_v<TileScheduler_, cutlass::fmha::kernel::XeReduceSplitKTileScheduler>,
+                "ReduceSplitK kernel requires XeReduceSplitKTileScheduler");
+  using TileSchedulerParams = typename TileScheduler::Params;
+
+  using ElementO = typename FMHAKernel_::ElementO;
+  using StrideO = typename FMHAKernel_::StrideO;
+  using TileShapeO = typename FMHAKernel_::TileShapeO;
+
+  using SGPerWG = typename FMHAKernel_::SGPerWG;
+
+  // number of values (along head_size_vo) processed by each work-item
+  constexpr static int num_vals_per_thread = int(get<1>(TileShapeO{}) / (SGPerWG::value * intel::sg_size));
+
+  //
+  // Types
+  //
+
+  struct KernelArguments {
+    ProblemShape shape;
+    // outputs:
+    ElementO *O;
+    StrideO dO;
+    // below are inputs
+    // TODO: should these use the same dtype as the output or as the accumulator?
+    ElementO *Oaccum;
+    StrideO dOaccum;
+    ElementO *exp_sums;
+    StrideO dExp_sums;
+    ElementO *max_logits;
+    StrideO dMax_logits;
+  };
+  using KernelParams = KernelArguments;
+
+  struct Arguments {
+    KernelArguments kernel{};
+    KernelHardwareInfo hw_info{};
+    int num_kv_splits = -1; // no split by default
+  };
+
+  /// Params structure
+  struct Params {
+    KernelParams kernel;
+    TileSchedulerParams scheduler;
+  };
+
+  struct SharedStorage {
+    cutlass::Array<ElementO, FMHAKernel_::max_num_kv_splits> max_logits_slm_array;
+    cutlass::Array<ElementO, FMHAKernel_::max_num_kv_splits> exp_sums_slm_array;
+  };
+
+  static constexpr int SharedStorageSize = is_empty_v<SharedStorage> ?
size_t(0) + : sizeof(SharedStorage); + +public: + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const &args) { + // only support decode + if (args.kernel.shape.seq_len_qo > 1) { + return false; + } + + if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) { + return false; + } + return true; + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int tid_in_sg = thr_id % intel::sg_size; + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); + + int offset_o = 0, offset_o_accum = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + + auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, 1); + auto shape_Oaccum = make_shape(s.seq_len_qo, s.head_size_vo, num_kv_splits); + + auto shape_exp_sums = make_shape(s.seq_len_qo, num_kv_splits, 1); + auto shape_max_logits = make_shape(s.seq_len_qo, num_kv_splits, 1); + + // assume: Oaccum is allocated with shape (batch * num_heads_q, num_kv_splits, seq_len_qo, head_size_vo) + offset_o_accum = (idx_b * s.num_heads_q + head_q) * num_kv_splits * s.seq_len_qo * s.head_size_vo; + offset_o = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo * s.head_size_vo; + + offset_exp_sums = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; + offset_max_logits = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; + auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); + auto ptrO = p.O + offset_o; + auto ptrExp_sums = p.exp_sums + offset_exp_sums; + auto ptrMax_logits = p.max_logits + offset_max_logits; + + using Stride_O = cute::Stride, int64_t>; + using Stride_Oaccum = Stride_O; + using Stride_Exp_sums = Stride_O; + + // 3D + // static_assert(is_same_v, "dtype mismatched"); + // static_assert(is_same_v(StrideO{})), float>, "dtype mismatched"); + auto stride_o_accum = cutlass::make_cute_packed_stride(Stride_Oaccum{}, shape_Oaccum); + // 2D + auto stride_o = cutlass::make_cute_packed_stride(Stride_O{}, shape_O); + auto stride_exp_sums = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_exp_sums); + auto stride_max_logits = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_max_logits); + + Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, 
stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + + // static_assert(is_same_v, "dtype mismatched"); + + // Step 1: reduce max logits across different partitions + // store into SLM for later use + + ElementO global_max_logits = cutlass::platform::numeric_limits::lowest(); + ElementO global_exp_sums = 0; + // only first subgroup participates + if (thr_id < num_kv_splits) { + ElementO cur_max_logit = max_logits(0, thr_id, 0); + global_max_logits = sycl::max(global_max_logits, cur_max_logit); + shared_storage.max_logits_slm_array[thr_id] = cur_max_logit; + + ElementO cur_exp_sum = exp_sums(0, thr_id, 0); + shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; + global_exp_sums = cur_exp_sum; + } + + if (sub_group_id == 0) { + // reduce within subgroup + global_max_logits = reduce_over_group(get_sub_group(), global_max_logits, sycl::maximum<>()); + global_exp_sums = reduce_over_group(get_sub_group(), global_exp_sums, sycl::plus<>()); + + // broadcast to other threads + sycl::group_broadcast(get_work_group<3>(), global_max_logits, 0); + sycl::group_broadcast(get_work_group<3>(), global_exp_sums, 0); + } + + // barrier for SLM writes finished + sycl::group_barrier(get_work_group<3>()); + + ElementO inv_global_exp_sums = 1. / global_exp_sums; + + // step 2: rescale Oaccum and write back to O + for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { + ElementO acc = 0; + for (int i = 0; i < num_kv_splits; ++i) { + ElementO local_max_logit = shared_storage.max_logits_slm_array[i]; + ElementO local_exp_sum = shared_storage.exp_sums_slm_array[i]; + + ElementO rescale = sycl::native::exp2(local_max_logit - global_max_logits); + + // in FMHA epilogue, it's divided by local_exp_sum + ElementO adjusted_o_accum = Oaccum(0, idx, i) * local_exp_sum; + acc += adjusted_o_accum * rescale; + } + acc *= inv_global_exp_sums; + O(0, idx, 0) = acc; + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 24a686993c..2197c26b0c 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -45,6 +45,8 @@ struct XeFHMAIndividualTileScheduler { struct Params { dim3 grid; FastDivmod divmod_num_heads; + FastDivmod divmod_batch; + int num_kv_splits_ = -1; }; bool valid_ = true; @@ -56,14 +58,18 @@ struct XeFHMAIndividualTileScheduler { template static Params to_underlying_arguments( ProblemShape const& shape, KernelHardwareInfo hw_info, - TileShape const& tile_shape) + TileShape const& tile_shape, const int &num_kv_splits = -1) { using namespace cute; dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q size(shape.batch * shape.num_heads_q)); // (h,b) -- split later - return Params{grid, {shape.num_heads_q}}; + if (num_kv_splits > 0) { + grid.z *= num_kv_splits; + } + std::cout << "XeFHMAIndividualTileScheduler Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; + return Params{grid, {shape.num_heads_q}, {shape.batch * shape.num_heads_q}, num_kv_splits}; } template @@ -79,10 +85,18 @@ struct XeFHMAIndividualTileScheduler { CUTLASS_DEVICE auto get_block_coord() { 
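+    // With KV splits enabled, the z grid dimension packs three indices:
+    //   BlockIdxZ() = idx_kv_split * (batch * num_heads_q) + idx_b * num_heads_q + head
+    // and the two FastDivmods below peel them apart in that order. Worked
+    // example (illustrative numbers, not from the source): batch=2,
+    // num_heads_q=16, num_kv_splits=2 gives grid.z = 64, and z = 35 decodes
+    // to idx_kv_split=1, idx_b=0, head=3.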
using namespace cute; - int idx_b = BlockIdxZ(); - int head; + int idx_kv_split = BlockIdxZ(); + int head, idx_b; + + if (params.num_kv_splits_ > 1) { + params.divmod_batch(idx_kv_split, idx_b, idx_kv_split); + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, idx_kv_split); + } + + idx_b = idx_kv_split; params.divmod_num_heads(idx_b, head, idx_b); - return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, (int)-1); } CUTLASS_DEVICE @@ -92,7 +106,7 @@ struct XeFHMAIndividualTileScheduler { } }; -struct XeFHMAIndividualPersistentTileScheduler { +struct XeFHMAPersistentTileScheduler { struct Params { dim3 grid; @@ -107,7 +121,7 @@ struct XeFHMAIndividualPersistentTileScheduler { int num_batch_heads_; CUTLASS_DEVICE - XeFHMAIndividualPersistentTileScheduler(Params const& params, int kv_tile_size, + XeFHMAPersistentTileScheduler(Params const& params, int kv_tile_size, int local_num_kv_blocks, int num_batch_heads) : params(params), kv_tile_size_(kv_tile_size), local_num_kv_blocks_(local_num_kv_blocks), num_batch_heads_(num_batch_heads) {} @@ -154,7 +168,60 @@ struct XeFHMAIndividualPersistentTileScheduler { } CUTLASS_DEVICE - XeFHMAIndividualPersistentTileScheduler& operator++() { + XeFHMAPersistentTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +struct XeReduceSplitKTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + int num_kv_splits = -1; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + TileShape const& tile_shape, const int &num_kv_splits = -1) + { + using namespace cute; + + // dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + // size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + // size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); + std::cout << "Reduce Split K Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; + return Params{grid, {shape.num_heads_q}, num_kv_splits}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler& operator++() { valid_ = false; return *this; } diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index 9de908336f..4fdfa6c0ae 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -171,9 +171,11 @@ int main(int argc, const char **argv) { #endif #if PERSISTENT - return FMHAConfig::run(options); + return FMHAConfig::run(options); +#elif SPLITKV + return FMHAConfig::run(options); #else - return options.is_causal ? FMHAConfig::run(options) - : FMHAConfig::run(options); + return options.is_causal ? 
FMHAConfig::run(options) + : FMHAConfig::run(options); #endif } diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 435a65e6ea..89dfe27be7 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -52,6 +52,12 @@ foreach(HEAD_DIM 64 96 128 192) ) endif() + # specific test for split kernel + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_splitkv_${INPUT_TYPE}_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + if(INPUT_TYPE STREQUAL "bfloat16_t") set(INPUT_MACRO "IS_BFLOAT16") elseif(INPUT_TYPE STREQUAL "float_e5m2_t") @@ -65,6 +71,7 @@ foreach(HEAD_DIM 64 96 128 192) if (NOT HEAD_DIM STREQUAL 192) target_compile_definitions(06_xe_fmha_fwd_decode_persistent_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE PERSISTENT SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) endif() + target_compile_definitions(06_xe_fmha_fwd_decode_splitkv_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SPLITKV SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) endforeach() cutlass_example_add_executable( diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index b1e9f0284d..c6b7c11d56 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -37,6 +37,7 @@ #include "flash_attention_v2/collective/fmha_fusion.hpp" #include "flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" #include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" +#include "flash_attention_v2/kernel/xe_reduce_split_k.h" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/sycl_event_manager.hpp" #include @@ -64,10 +65,11 @@ struct Options { int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; float softmax_scale; + int num_kv_splits; Options() : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), - seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), num_kv_splits(1), scheduler("Individual") {} // Parses the command line void parse(int argc, char const **args) { @@ -101,6 +103,11 @@ struct Options { #endif #ifdef DECODE cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 1); +#ifdef SPLITKV + cmd.get_cmd_line_argument("num_kv_splits", num_kv_splits, 2); +#else + cmd.get_cmd_line_argument("num_kv_splits", num_kv_splits, 1); +#endif #else cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, seq_len_kv); #endif @@ -127,6 +134,7 @@ struct Options { << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" + << " --num_kv_splits= Sets the Number of Key-Value splits in Multi-Head Self Attention module\n" << " --iterations= Iterations\n\n"; return out; @@ -166,7 +174,8 @@ using LayoutK = cutlass::layout::ColumnMajor; using LayoutV = cutlass::layout::RowMajor; using LayoutO = cutlass::layout::RowMajor; -template struct ExampleRunner { +template +struct ExampleRunner { using StrideQ = typename 
FMHAKernel::StrideQ; using StrideK = typename FMHAKernel::StrideK; @@ -183,6 +192,8 @@ template struct ExampleRunner { using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + static_assert(!isSplitKV || !is_void_v, "Standalone reduction kernel is required if splitKV enabled"); + // // Data members // @@ -192,12 +203,22 @@ template struct ExampleRunner { StrideK stride_K; StrideV stride_V; StrideO stride_O; + StrideO stride_Oaccum; + StrideO stride_exp_sums; + StrideO stride_max_logits; uint64_t seed = 0; + int num_kv_splits; + cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; cutlass::DeviceAllocation block_O; + // TODO: assume same dtype as outputs + cutlass::DeviceAllocation block_Oaccum; + cutlass::DeviceAllocation block_exp_sums; + cutlass::DeviceAllocation block_max_logits; + cutlass::DeviceAllocation block_ref_O; std::vector cumulative_seqlen_q; @@ -460,6 +481,7 @@ template struct ExampleRunner { ProblemShapeType initialize(const Options &options) { auto problem_shape_in = cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo); ProblemShapeType shape; + num_kv_splits = options.num_kv_splits; decltype(problem_shape_in) problem_size; @@ -490,6 +512,18 @@ template struct ExampleRunner { block_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + if constexpr (isSplitKV) { + stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_kv_splits, num_heads_q * batch)); + block_Oaccum.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo * num_kv_splits); + + // assume seq_len_qo==1 + stride_exp_sums = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + block_exp_sums.reset(static_cast(batch) * num_heads_q * seq_len_qo * num_kv_splits); + + stride_max_logits = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + block_max_logits.reset(static_cast(batch) * num_heads_q * seq_len_qo * num_kv_splits); + } + initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); @@ -540,26 +574,100 @@ template struct ExampleRunner { EventManager::getInstance().addEvent(event); } + static void run(typename FMHAKernel::Params params, typename ReductionSplitKernel::Params reduce_params) + { + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + dim3 const block = FMHAKernel::get_block_shape(); + dim3 const grid = FMHAKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension + compat::experimental::launch_properties launch_props { + syclex::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + syclex::sub_group_size, + intelex::grf_size<256> + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch, FMHAKernel>(policy, params); + + 
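+
+    // For reference, a sketch of the combine step that the reduction kernel
+    // launched below applies per output element, with per-split max logit m_i,
+    // exp-sum l_i and partial output O_i (already divided by l_i in the FMHA
+    // epilogue):
+    //   m = max_i m_i
+    //   O = sum_i(O_i * l_i * exp2(m_i - m)) / sum_i(l_i * exp2(m_i - m))
+    // (base-2 exponentials, matching the sycl::native::exp2 used in ReduceSplitK)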
dim3 const reduce_grid = ReductionSplitKernel::get_grid_shape(reduce_params); + int reduce_smem_size = ReductionSplitKernel::SharedStorageSize; + const auto reduce_sycl_block = compat::dim3(block.x, block.y, block.z); + const auto reduce_sycl_grid = compat::dim3(reduce_grid.x, reduce_grid.y, reduce_grid.z); + compat::experimental::launch_properties launch_props_reduce { + syclex::work_group_scratch_size(reduce_smem_size), + }; + compat::experimental::launch_policy reduce_policy{reduce_sycl_grid, reduce_sycl_block, launch_props_reduce, kernel_props}; + + // wait for FA kernel finished + event.wait(); + + auto reduce_event = compat::experimental::launch, ReductionSplitKernel>(reduce_policy, reduce_params); + + EventManager::getInstance().addEvent(event); + EventManager::getInstance().addEvent(reduce_event); + } + cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) { ProblemShapeType shape = initialize(options); - typename FMHAKernel::Arguments arguments{ - { + typename FMHAKernel::KernelArguments kernel_args; + if constexpr (isSplitKV) { + kernel_args = typename FMHAKernel::KernelArguments { + shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O, + block_Oaccum.get(), stride_Oaccum, + block_exp_sums.get(), stride_exp_sums, + block_max_logits.get(), stride_max_logits, + }; + } else { + kernel_args = typename FMHAKernel::KernelArguments { shape, block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V, block_O.get(), stride_O - }, + }; + } + + typename FMHAKernel::Arguments arguments { + kernel_args, {options.softmax_scale}, {}, - hw_info + hw_info, + options.num_kv_splits, + }; + + typename ReductionSplitKernel::Arguments reduce_arg { + {shape, + block_O.get(), stride_O, + block_Oaccum.get(), stride_Oaccum, + block_exp_sums.get(), stride_exp_sums, + block_max_logits.get(), stride_max_logits}, + hw_info, + options.num_kv_splits }; // Define device-global scratch memory size_t workspace_size = FMHAKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); + size_t reduce_workspace_size = 0; + if constexpr (isSplitKV) { + reduce_workspace_size = ReductionSplitKernel::get_workspace_size(reduce_arg); + } + cutlass::device_memory::allocation workspace(workspace_size + reduce_workspace_size); if (!FMHAKernel::can_implement(arguments)) { std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << @@ -573,9 +681,20 @@ template struct ExampleRunner { // Convert host-side arguments to device-side arguments to be passed to the kernel auto params = FMHAKernel::to_underlying_arguments(arguments, workspace.get()); + auto reduce_params = ReductionSplitKernel::to_underlying_arguments(reduce_arg, workspace.get() + workspace_size); // Run the GEMM - run(params); + if constexpr (isSplitKV) { + if (!ReductionSplitKernel::can_implement(reduce_arg)) { + std::cout << "Invalid Problem Size for ReductionSplitKernel" << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + CUTLASS_CHECK(ReductionSplitKernel::initialize_workspace(reduce_arg, workspace.get() + workspace_size)); + run(params, reduce_params); + } else { + run(params); + } compat::wait(); @@ -584,17 +703,21 @@ template struct ExampleRunner { std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if (!passed) { - return cutlass::Status::kErrorInternal; + // return cutlass::Status::kErrorInternal; } if (options.iterations > 0) { GPU_Clock timer; timer.start(); for (int i = 0; i < options.iterations; ++i) { - run(params); + if constexpr (isSplitKV) { + run(params, reduce_params); + } else { + run(params); + } } compat::wait(); - // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. + // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. // Following changes will adjust the effective_seq_len_kv when masking applied for such cases auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); auto discard_seq_coord = options.seq_len_qo - offset; @@ -630,15 +753,18 @@ template default */ int PipelineStages, bool persistent, + bool splitkv, typename ElementQ = bfloat16_t, typename ElementK = bfloat16_t, typename ElementV = bfloat16_t, typename ElementO = float, + typename ElementOaccum = float, typename MMAOperation_ = void, /* void -> default */ typename StrideQ = Stride, typename StrideK = Stride, typename StrideV = Stride<_1, int, int, int>, typename StrideO = Stride, + typename StrideOaccum = Stride, typename GmemTiledCopyQ = void, /* void -> default block 2D */ typename GmemTiledCopyK = void, typename GmemTiledCopyV = void, @@ -657,7 +783,7 @@ struct FMHAConfig { decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), SubgroupLayoutPV_>; - template + template static int run(const Options &options) { // // Run examples @@ -685,7 +811,10 @@ struct FMHAConfig { using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); - using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideO{})); + using TensorO = conditional_t; // Mainloop using MainloopDispatchPolicy = cutlass::fmha::XeDefault; @@ -701,29 +830,38 @@ struct FMHAConfig { CollectiveMainloop, TileShapeOutput, TensorO, - GmemTiledCopyO + conditional_t >; static_assert(!(persistent & Causal), "persistent SDPA kernel not support Causal yet"); - using FMHAKernel = conditional_t, - cutlass::fmha::kernel::XeFMHAFwdDynamicSplitKernel< + using FMHAKernel = conditional_t, + cutlass::fmha::kernel::XeFMHAFwdPersistentKernel< ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler>, - cutlass::fmha::kernel::XeFMHAFwdKernel< - ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + conditional_t, + cutlass::fmha::kernel::XeFMHAFwdKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + > >; - ExampleRunner runner; + using ReduceSplitKernel = cutlass::reduction::kernel::ReduceSplitK< + ProblemShapeType, cutlass::fmha::kernel::XeReduceSplitKTileScheduler, FMHAKernel>; + + ExampleRunner runner; CUTLASS_CHECK(runner.run(options, hw_info)); + return 0; } static int run(const Options &options) { if (options.varlen) { - return run(options); + return run(options); } else { - return persistent ? run(options) : - run(options); + return persistent ? run(options) : + (splitkv ? 
run(options) : + run(options)); } } }; From 1f069312f6fd9ac0e6226e4490e209aa76b40911 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Sun, 7 Dec 2025 21:44:59 -0800 Subject: [PATCH 02/12] fix return type --- .../flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index 38733d0cb2..6efceb674d 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -246,7 +246,8 @@ class FMHAFwdEpilogue { using namespace sycl::ext::oneapi::this_work_item; if constexpr (ReduceK{} == _1{}) { - return std::make_tuple(tArA, tA_sum, true); + ReduceFragARow rA_sum; + return std::make_tuple(tArA, tA_sum, rA_sum, true); } else { /* Identify A tile ID and k block for this subgroup. */ auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); From 7e784befdf487a9f86e4ad060d764b3f04492dfc Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Fri, 12 Dec 2025 05:14:59 -0800 Subject: [PATCH 03/12] debugging accuracy passed performance WIP --- .../collective/xe_fmha_fwd_epilogue.hpp | 23 +++++- .../collective/xe_fmha_fwd_mainloop.hpp | 5 +- .../kernel/xe_fhma_fwd_kernel.hpp | 77 ++++++++++++++++--- .../kernel/xe_reduce_split_k.h | 33 ++++---- .../kernel/xe_tile_scheduler.hpp | 2 +- .../xe_fmha_fwd_runner.hpp | 25 ++++++ 6 files changed, 138 insertions(+), 27 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index 6efceb674d..511570af61 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -164,6 +164,12 @@ class FMHAFwdEpilogue { for (int i = 0; i < rA.size(); i++) rA(i) *= broadcast<0>(rA_sum, rA, i); +#if 1 + if (ThreadIdxX() == 0) { + // cute::print("wg id: %d, rA(0): %f, rA_sum(0): %f\n", BlockIdxZ(), (float)rA(0),(float)rA_sum(0)); + } +#endif + /* Tile output */ Tensor cO = make_identity_tensor(O.shape()); // (q,v) Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) @@ -197,12 +203,25 @@ class FMHAFwdEpilogue { using namespace cute; using ElementA = typename FragA::element_type; +#if 0 + if (ThreadIdxX() == 0 && BlockIdxZ() == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print(" fwd epilogue tA_max(0): %f\n", float(tA_max(0))); + cute::print(" fwd epilogue tA_sum(0): %f\n", float(tA_sum(0))); + cute::print(" fwd epilogue tArA(0): %f\n", float(tArA(0))); + } +#endif + // Reduce k-blocks of A and A_sum across WG, if needed. auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); // store exp sum and max logits for current KV split - exp_sums(0) = rA_sum(0); - max_logits(0) = rA_max(0); + // assume seq_len_qo == 1 + if (ThreadIdxX() == 0) { + static_assert(size(FragARow{}) == 1, "only size 1 of FragARow is now supported"); + exp_sums(0,0) = rA_sum(0); + max_logits(0,0) = rA_max(0); + } /* Some subgroups may not have any work to do; if so, quit early. 
*/ if (!active) return; diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index b7c400a63a..290b8a33c7 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -175,7 +175,8 @@ struct FMHAFwdMainloop, CausalMask_, int thr_id, int seq_len, int full_tile_offset, - int discard_seq_coord) { + int discard_seq_coord, + bool need_init = false) { using namespace sycl::ext::oneapi::this_work_item; // Short dimension names: @@ -251,7 +252,7 @@ struct FMHAFwdMainloop, CausalMask_, /* Initialization steps for first block: Q/K prefetch, O init */ /* TODO: limit D prefetch for large head size, and reorder K prefetches */ - if (blk_k0 == 0) { + if (blk_k0 == 0 || need_init) { for (int D = 0; D < size<3>(pQgQ); D++) { prefetch(prefetch_q, pQgQ(_,_,_,D)); } diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 2e3a822d57..523ab7c285 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -592,7 +592,7 @@ class XeFMHAFwdPersistentKernel { V(_,_,head_kv,idx_b), tArA, tA_max, tA_sum, blk_qv, start_blk, end_blk, local_k_blocks, - thr_id, s.seq_len_kv, /*for causal*/0, 0); + thr_id, s.seq_len_kv, /*for causal*/0, 0, /*need_init*/true); // partition id of start batch head id in current wg int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks); @@ -905,49 +905,64 @@ class XeFMHAFwdSplitKVKernel { // auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); // auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); - int num_blocks_per_split, kv_split_offset, num_effective_kv_blocks; + int num_blocks_per_split, kv_split_offset, num_effective_kv_blocks, effective_kv_seq_length; if constexpr (SplitKV) { num_blocks_per_split = cute::ceil_div(k_blocks, params.scheduler.num_kv_splits_); kv_split_offset = idx_kv_split * num_blocks_per_split; num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); + effective_kv_seq_length = num_effective_kv_blocks * get<1>(TileShapeQK{}); + effective_kv_seq_length = seq_len; - shape_K = make_shape(num_effective_kv_blocks * get<1>(TileShapeQK{}), s.head_size_qk, s.num_heads_kv, batch_dim); - shape_V = make_shape(s.head_size_vo, num_effective_kv_blocks * get<1>(TileShapeQK{}), s.num_heads_kv, batch_dim); +#if 0 + if (thr_id == 0) { + cute::print("\nidx_kv_split: %d, kv_split_offset: %d, num_effective_kv_blocks: %d, k_blocks: %d, num_blocks_per_split: %d\n", + idx_kv_split, kv_split_offset, num_effective_kv_blocks, k_blocks, num_blocks_per_split); + } +#endif + + shape_K = make_shape(effective_kv_seq_length, s.head_size_qk, s.num_heads_kv, batch_dim); + shape_V = make_shape(s.head_size_vo, effective_kv_seq_length, s.num_heads_kv, batch_dim); shape_O = make_shape(seq_len_qo, s.head_size_vo, 1, s.num_heads_q * batch_dim); + // shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); shape_exp_sums = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); shape_max_logits = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); // TODO: adapt for var length // offset_k = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_qk) * s.num_heads_kv * batch_dim; - offset_k = 
kv_split_offset * get<1>(TileShapeQK{}); + offset_k = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; // offset_v = s.num_heads_kv * s.head_size_vo * (idx_b * seq_len + kv_split_offset * get<1>(TileShapeQK{})); // offset_v = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_vo) * s.num_heads_kv * batch_dim; - offset_v = kv_split_offset * get<1>(TileShapeQK{}); + offset_v = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; // assume: Oaccum is allocated with shape (batch * num_heads_q, num_kv_splits, seq_len_qo, head_size_vo) offset_o = s.head_size_vo * seq_len_qo * idx_kv_split; offset_exp_sums = idx_kv_split; offset_max_logits = idx_kv_split; + + // offset_o = offset_k = offset_v = offset_exp_sums = offset_max_logits = 0; + offset_k = offset_v = 0; } else { shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); - shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); + shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); } auto dcQ = const_cast(p.Q + offset_q); auto dcK = const_cast(p.K + offset_k); auto dcV = const_cast(p.V + offset_v); auto ptrO = (SplitKV ? p.Oaccum : p.O) + offset_o; + // auto ptrO = p.O + offset_o; auto ptrExp_sums = p.exp_sums + offset_exp_sums; auto ptrMax_logits = p.max_logits + offset_max_logits; auto stride_q = is_var_len ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; - auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : (SplitKV ? p.dOaccum : p.dO); + // auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; auto stride_exp_sums = p.dExp_sums; auto stride_max_logits = p.dMax_logits; @@ -956,6 +971,16 @@ class XeFMHAFwdSplitKVKernel { Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); +#if 0 + if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q == 0) { + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q: %d, O shape: ", idx_kv_split, idx_b, head_q);cute::print(O.shape());print("\n"); + cute::print("K shape: ");cute::print(K.shape());cute::print(K.stride());cute::print("\n"); + cute::print("V shape: ");cute::print(V.shape());cute::print(V.stride());cute::print("\n"); + cute::print("O stride: ");cute::print(O.stride());cute::print("\n"); + cute::print("stride_o: ");cute::print(stride_o);cute::print("\n"); + } +#endif + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); @@ -969,6 +994,12 @@ class XeFMHAFwdSplitKVKernel { int start_blk = SplitKV ? kv_split_offset : 0; int end_blk = SplitKV ? 
(kv_split_offset + num_effective_kv_blocks) : k_blocks; +#if 0 + if (thr_id == 0) { + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q, start_blk, end_blk); + } +#endif + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); mainloop(Q(_,_,head_q,l_coord), @@ -977,7 +1008,24 @@ class XeFMHAFwdSplitKVKernel { tArA, tA_max, tA_sum, blk_qv, start_blk, end_blk, k_blocks, thr_id, seq_len, - full_tile_offset, discard_seq_coord); + full_tile_offset, discard_seq_coord, + /*need_init*/true); + +#if 0 + // static_assert(is_same_v, "dtype mismatch"); + if (idx_kv_split == 0 && head_q == 0 && thr_id == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_sum(0): %f\n", float(tA_sum(0))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); + } + + if (idx_kv_split == 1 && head_q == 0 && thr_id == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); + cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); + } +#endif if constexpr (!is_empty_v && !is_empty_v) { sycl::group_barrier(get_work_group<3>()); @@ -992,11 +1040,22 @@ class XeFMHAFwdSplitKVKernel { blk_qv, thr_id, exp_sums(_,_,head_q,l_coord), max_logits(_,_,head_q,l_coord)); + // epilogue(O(_,_,0,head_q + idx_b * s.num_heads_q), + // tArA, tA_max, tA_sum, + // blk_qv, thr_id); } else { epilogue(O(_,_,head_q,l_coord), tArA, tA_max, tA_sum, blk_qv, thr_id); } + +#if 1 + if (thr_id == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, O(0,0): %f, exp_sums(0,0): \n", idx_kv_split, idx_b, head_q, float(O(0,0,0,1))); + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, max_logits(0,0): %f\n", idx_kv_split, idx_b, head_q, max_logits(0,0,head_q,l_coord)); + } +#endif + } }; diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index 8802ba3f35..326e87eef5 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -86,11 +86,11 @@ class ReduceSplitK { StrideO dO; // below are inputs // TODO: whether same dtype as output or accum? 
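+    // (in the example runner these currently coincide: Oaccum, exp_sums and
+    //  max_logits are all allocated with ElementO == float, the accumulator
+    //  type, so the distinction only starts to matter once O becomes 16-bit)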
- ElementO *Oaccum; + const ElementO *Oaccum; StrideO dOaccum; - ElementO *exp_sums; + const ElementO *exp_sums; StrideO dExp_sums; - ElementO *max_logits; + const ElementO *max_logits; StrideO dMax_logits; }; using KernelParams = KernelArguments; @@ -185,8 +185,8 @@ class ReduceSplitK { offset_max_logits = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); auto ptrO = p.O + offset_o; - auto ptrExp_sums = p.exp_sums + offset_exp_sums; - auto ptrMax_logits = p.max_logits + offset_max_logits; + auto ptrExp_sums = const_cast(p.exp_sums + offset_exp_sums); + auto ptrMax_logits = const_cast(p.max_logits + offset_max_logits); using Stride_O = cute::Stride, int64_t>; using Stride_Oaccum = Stride_O; @@ -222,24 +222,26 @@ class ReduceSplitK { ElementO cur_exp_sum = exp_sums(0, thr_id, 0); shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; - global_exp_sums = cur_exp_sum; } + // barrier for SLM writes finished + sycl::group_barrier(get_work_group<3>()); + if (sub_group_id == 0) { // reduce within subgroup + // here assume num_kv_splits not exceed subgroup size 16 global_max_logits = reduce_over_group(get_sub_group(), global_max_logits, sycl::maximum<>()); - global_exp_sums = reduce_over_group(get_sub_group(), global_exp_sums, sycl::plus<>()); - - // broadcast to other threads - sycl::group_broadcast(get_work_group<3>(), global_max_logits, 0); - sycl::group_broadcast(get_work_group<3>(), global_exp_sums, 0); + // global_exp_sums = reduce_over_group(get_sub_group(), global_exp_sums, sycl::plus<>()); } + // broadcast to other threads + global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); + + // global_exp_sums = sycl::group_broadcast(get_work_group<1>(), global_exp_sums, 0); + // barrier for SLM writes finished sycl::group_barrier(get_work_group<3>()); - ElementO inv_global_exp_sums = 1. / global_exp_sums; - // step 2: rescale Oaccum and write back to O for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { ElementO acc = 0; @@ -252,7 +254,12 @@ class ReduceSplitK { // in FMHA epilogue, it's divided by local_exp_sum ElementO adjusted_o_accum = Oaccum(0, idx, i) * local_exp_sum; acc += adjusted_o_accum * rescale; + + // update global exp sum + global_exp_sums += local_exp_sum * rescale; } + + ElementO inv_global_exp_sums = 1. 
/ global_exp_sums; acc *= inv_global_exp_sums; O(0, idx, 0) = acc; } diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 2197c26b0c..5642552e4a 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -65,7 +65,7 @@ struct XeFHMAIndividualTileScheduler { dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q size(shape.batch * shape.num_heads_q)); // (h,b) -- split later - if (num_kv_splits > 0) { + if (num_kv_splits > 1) { grid.z *= num_kv_splits; } std::cout << "XeFHMAIndividualTileScheduler Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index c6b7c11d56..8341eaa97d 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -698,6 +698,30 @@ struct ExampleRunner { compat::wait(); +#if 0 + std::vector vec_Oaccum(block_Oaccum.size()); + block_Oaccum.copy_to_host(vec_Oaccum.data()); + for (size_t i = 0; i < vec_Oaccum.size(); i++) { + std::cout << "Oaccum[" << i << "] = " << vec_Oaccum[i] << std::endl; + if (i > 20) break; + } + + std::vector vec_exp_sums(block_exp_sums.size()); + std::vector vec_max_logits(block_max_logits.size()); + block_exp_sums.copy_to_host(vec_exp_sums.data()); + block_max_logits.copy_to_host(vec_max_logits.data()); + for (size_t i = 0; i < vec_exp_sums.size(); i++) { + std::cout << "exp_sums[" << i << "]: " << vec_exp_sums[i] << ", max_logits[" << i << "]: " << vec_max_logits[i] << std::endl; + if (i > 20) break; + } + + std::vector vec_O(block_O.size()); + block_O.copy_to_host(vec_O.data()); + for (size_t i = 0; i < vec_O.size(); i++) { + std::cout << "O[" << i << "] = " << vec_O[i] << std::endl; + if (i > 20) break; + } +#endif // Verify that the result is correct bool passed = verify(shape, options.is_causal); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; @@ -737,6 +761,7 @@ struct ExampleRunner { std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? 
"true" : "false") + << "\tKV Splits: " << options.num_kv_splits << "\t Scheduler: " << options.scheduler; printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); } From 015bba4828f3ed5a370a6a8553660223d410f4a0 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 15 Dec 2025 01:08:56 -0800 Subject: [PATCH 04/12] GQA packing each work group handles whole group query heads and packing group query heads into single MMA call --- .../collective/xe_fmha_fwd_epilogue.hpp | 9 +- .../collective/xe_fmha_fwd_mainloop.hpp | 7 ++ .../kernel/xe_fhma_fwd_kernel.hpp | 106 +++++++++--------- .../kernel/xe_reduce_split_k.h | 53 ++++----- .../kernel/xe_tile_scheduler.hpp | 4 +- .../06_bmg_flash_attention/06_xe_fmha_fwd.cpp | 19 +++- .../xe_fmha_fwd_runner.hpp | 57 +++++----- 7 files changed, 143 insertions(+), 112 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index 511570af61..6399c4fce8 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -197,7 +197,9 @@ class FMHAFwdEpilogue { QVCoord blk_qv, // WG tile indices: (q,v) int thr_id, // Work-item ID const TensorO2D & exp_sums, // Global exp sum tensor - const TensorO2D & max_logits // Global max logits tensor + const TensorO2D & max_logits, // Global max logits tensor + int idx_kv_split, + int head_q ) { using namespace cute; @@ -218,9 +220,8 @@ class FMHAFwdEpilogue { // store exp sum and max logits for current KV split // assume seq_len_qo == 1 if (ThreadIdxX() == 0) { - static_assert(size(FragARow{}) == 1, "only size 1 of FragARow is now supported"); - exp_sums(0,0) = rA_sum(0); - max_logits(0,0) = rA_max(0); + exp_sums(head_q,idx_kv_split) = rA_sum(0); + max_logits(head_q,idx_kv_split) = rA_max(0); } /* Some subgroups may not have any work to do; if so, quit early. 
*/ diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index 290b8a33c7..ed776d166d 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -124,6 +124,7 @@ struct FMHAFwdMainloop, CausalMask_, using SingleFragA = FragC; // (atom val,q',v') using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype mismatched"); using ElementA = typename TiledMMAPV::ValTypeD; static constexpr bool CausalMask = CausalMask_; @@ -196,6 +197,12 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) +#if 0 + if (ThreadIdxX() == 0 && BlockIdxZ() == 0) { + print("Q 2D shape: "); print(Q_2D.shape()); print("\n"); + } +#endif + /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 523ab7c285..1bfb8fc711 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -699,7 +699,7 @@ class XeFMHAFwdPersistentKernel { } }; -template +template class XeFMHAFwdSplitKVKernel { public: @@ -758,7 +758,8 @@ class XeFMHAFwdSplitKVKernel { static constexpr int SharedStorageSize = is_empty_v ? size_t(0) : sizeof(SharedStorage); - static constexpr int max_num_kv_splits = 8; + static constexpr int max_num_kv_splits = intel::sg_size; + static constexpr int dpas_max_repeat_count = 8; // Device side arguments struct KernelArguments { @@ -818,6 +819,11 @@ class XeFMHAFwdSplitKVKernel { return false; } + // when GQA packing enabled, limit head group size to 8 + if (GqaPacking && (args.kernel.shape.num_heads_q / args.kernel.shape.num_heads_kv > dpas_max_repeat_count)) { + return false; + } + return CollectiveMainloop::can_implement(args.mainloop) && CollectiveEpilogue::can_implement(args.epilogue); } @@ -865,13 +871,14 @@ class XeFMHAFwdSplitKVKernel { auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits_; CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { // auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) - auto [blk_q, blk_v, head_q, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b) auto blk_qv = make_coord(blk_q, blk_v); - int head = head_q / head_group_q; + int head_q_start = head * head_group_q; auto sequence_length_shape = get_sequence_length_shape(s, idx_b); auto [seq_len_qo, seq_len_kv] = sequence_length_shape; @@ -898,7 +905,7 @@ class XeFMHAFwdSplitKVKernel { } auto batch_dim = is_var_len ? 1 : s.batch; - auto shape_Q = make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + auto shape_Q = GqaPacking ? 
make_shape(seq_len_qo * head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); // 4D shape decltype(shape_Q) shape_K, shape_V, shape_O, shape_exp_sums, shape_max_logits; // auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); @@ -908,7 +915,7 @@ class XeFMHAFwdSplitKVKernel { int num_blocks_per_split, kv_split_offset, num_effective_kv_blocks, effective_kv_seq_length; if constexpr (SplitKV) { - num_blocks_per_split = cute::ceil_div(k_blocks, params.scheduler.num_kv_splits_); + num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); kv_split_offset = idx_kv_split * num_blocks_per_split; num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); effective_kv_seq_length = num_effective_kv_blocks * get<1>(TileShapeQK{}); @@ -923,27 +930,28 @@ class XeFMHAFwdSplitKVKernel { shape_K = make_shape(effective_kv_seq_length, s.head_size_qk, s.num_heads_kv, batch_dim); shape_V = make_shape(s.head_size_vo, effective_kv_seq_length, s.num_heads_kv, batch_dim); - shape_O = make_shape(seq_len_qo, s.head_size_vo, 1, s.num_heads_q * batch_dim); - // shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); + // shape_O = make_shape(seq_len_qo, s.head_size_vo, 1, s.num_heads_q * batch_dim); + shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * batch_dim, num_kv_splits) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * batch_dim, num_kv_splits); - shape_exp_sums = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); - shape_max_logits = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); + // shape_exp_sums = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); + // shape_max_logits = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); + shape_exp_sums = GqaPacking ? make_shape(s.seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(s.seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); + shape_max_logits = GqaPacking ? 
make_shape(s.seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(s.seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); // TODO: adapt for var length // offset_k = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_qk) * s.num_heads_kv * batch_dim; - offset_k = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; - // offset_v = s.num_heads_kv * s.head_size_vo * (idx_b * seq_len + kv_split_offset * get<1>(TileShapeQK{})); - // offset_v = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_vo) * s.num_heads_kv * batch_dim; - offset_v = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; + // offset_k = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; + // // offset_v = s.num_heads_kv * s.head_size_vo * (idx_b * seq_len + kv_split_offset * get<1>(TileShapeQK{})); + // // offset_v = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_vo) * s.num_heads_kv * batch_dim; + // offset_v = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; - // assume: Oaccum is allocated with shape (batch * num_heads_q, num_kv_splits, seq_len_qo, head_size_vo) - offset_o = s.head_size_vo * seq_len_qo * idx_kv_split; + // // assume: Oaccum is allocated with shape (num_kv_splits, batch * num_heads_q, seq_len_qo, head_size_vo) + // offset_o = s.head_size_vo * seq_len_qo * idx_kv_split * s.num_heads_q * batch_dim; - offset_exp_sums = idx_kv_split; - offset_max_logits = idx_kv_split; + // offset_exp_sums = idx_kv_split; + // offset_max_logits = idx_kv_split; - // offset_o = offset_k = offset_v = offset_exp_sums = offset_max_logits = 0; - offset_k = offset_v = 0; + offset_o = offset_k = offset_v = offset_exp_sums = offset_max_logits = 0; } else { shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); @@ -954,17 +962,17 @@ class XeFMHAFwdSplitKVKernel { auto dcK = const_cast(p.K + offset_k); auto dcV = const_cast(p.V + offset_v); auto ptrO = (SplitKV ? p.Oaccum : p.O) + offset_o; - // auto ptrO = p.O + offset_o; auto ptrExp_sums = p.exp_sums + offset_exp_sums; auto ptrMax_logits = p.max_logits + offset_max_logits; - auto stride_q = is_var_len ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; + auto stride_q = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; - auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : (SplitKV ? p.dOaccum : p.dO); - // auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; - auto stride_exp_sums = p.dExp_sums; - auto stride_max_logits = p.dMax_logits; + auto stride_o = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : (SplitKV ? p.dOaccum : p.dO); + // auto stride_exp_sums = p.dExp_sums; + // auto stride_max_logits = p.dMax_logits; + auto stride_exp_sums = GqaPacking ? cutlass::make_cute_packed_stride(StrideQ{}, shape_exp_sums) : p.dExp_sums; + auto stride_max_logits = GqaPacking ? 
cutlass::make_cute_packed_stride(StrideQ{}, shape_max_logits) : p.dMax_logits; Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); @@ -972,8 +980,8 @@ class XeFMHAFwdSplitKVKernel { Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); #if 0 - if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q == 0) { - cute::print("\nidx_kv_split: %d, idx_b: %d, head_q: %d, O shape: ", idx_kv_split, idx_b, head_q);cute::print(O.shape());print("\n"); + if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q_start == 0) { + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, O shape: ", idx_kv_split, idx_b, head_q_start);cute::print(O.shape());print("\n"); cute::print("K shape: ");cute::print(K.shape());cute::print(K.stride());cute::print("\n"); cute::print("V shape: ");cute::print(V.shape());cute::print(V.stride());cute::print("\n"); cute::print("O stride: ");cute::print(O.stride());cute::print("\n"); @@ -996,13 +1004,14 @@ class XeFMHAFwdSplitKVKernel { #if 0 if (thr_id == 0) { - cute::print("\nidx_kv_split: %d, idx_b: %d, head_q: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q, start_blk, end_blk); + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q_start, start_blk, end_blk); } #endif CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); - mainloop(Q(_,_,head_q,l_coord), + // for GQA packing + mainloop(Q(_,_,head,l_coord), K(_,_,head,l_coord), V(_,_,head,l_coord), tArA, tA_max, tA_sum, @@ -1034,28 +1043,23 @@ class XeFMHAFwdSplitKVKernel { // Epilogue CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; - if constexpr (SplitKV) { - epilogue(O(_,_,0,head_q + idx_b * s.num_heads_q), - tArA, tA_max, tA_sum, - blk_qv, thr_id, - exp_sums(_,_,head_q,l_coord), - max_logits(_,_,head_q,l_coord)); - // epilogue(O(_,_,0,head_q + idx_b * s.num_heads_q), - // tArA, tA_max, tA_sum, - // blk_qv, thr_id); - } else { - epilogue(O(_,_,head_q,l_coord), - tArA, tA_max, tA_sum, - blk_qv, thr_id); - } - -#if 1 - if (thr_id == 0) { - // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, O(0,0): %f, exp_sums(0,0): \n", idx_kv_split, idx_b, head_q, float(O(0,0,0,1))); - // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, max_logits(0,0): %f\n", idx_kv_split, idx_b, head_q, max_logits(0,0,head_q,l_coord)); + // for GQA packing + for (int q_head_cnt = 0; q_head_cnt < head_group_q; ++q_head_cnt) { + int head_q_curr = head_q_start + q_head_cnt; + if constexpr (SplitKV) { + epilogue(O(_,_,idx_b * s.num_heads_kv + head, idx_kv_split), + tArA, tA_max, tA_sum, + blk_qv, thr_id, + exp_sums(_,_,head,l_coord), + max_logits(_,_,head,l_coord), + idx_kv_split, q_head_cnt); + } else { + // FIXME: use correct head q indx + epilogue(O(_,_,head_q_curr,l_coord), + tArA, tA_max, tA_sum, + blk_qv, thr_id); } -#endif - + } } }; diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index 326e87eef5..ab0e3fb788 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -164,6 +164,11 @@ class ReduceSplitK { TileScheduler tile_scheduler{params.scheduler}; auto num_kv_splits = params.scheduler.num_kv_splits; + auto seq_len_qo = s.seq_len_qo; + auto batch_dim = s.batch; + auto num_heads_q = s.num_heads_q; + auto 
head_size_vo = s.head_size_vo; + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); @@ -171,18 +176,18 @@ class ReduceSplitK { int offset_o = 0, offset_o_accum = 0; int offset_exp_sums = 0, offset_max_logits = 0; - auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, 1); - auto shape_Oaccum = make_shape(s.seq_len_qo, s.head_size_vo, num_kv_splits); + auto shape_O = make_shape(seq_len_qo, head_size_vo, num_heads_q, batch_dim); + auto shape_Oaccum = make_shape(seq_len_qo, head_size_vo, num_heads_q * batch_dim, num_kv_splits); - auto shape_exp_sums = make_shape(s.seq_len_qo, num_kv_splits, 1); - auto shape_max_logits = make_shape(s.seq_len_qo, num_kv_splits, 1); + auto shape_exp_sums = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); + auto shape_max_logits = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); - // assume: Oaccum is allocated with shape (batch * num_heads_q, num_kv_splits, seq_len_qo, head_size_vo) - offset_o_accum = (idx_b * s.num_heads_q + head_q) * num_kv_splits * s.seq_len_qo * s.head_size_vo; - offset_o = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo * s.head_size_vo; + // assume: Oaccum is allocated with shape (num_kv_splits, batch * num_heads_q, seq_len_qo, head_size_vo) + // offset_o_accum = (idx_b * s.num_heads_q + head_q) * num_kv_splits * s.seq_len_qo * s.head_size_vo; + // offset_o = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo * s.head_size_vo; - offset_exp_sums = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; - offset_max_logits = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; + // offset_exp_sums = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; + // offset_max_logits = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); auto ptrO = p.O + offset_o; auto ptrExp_sums = const_cast(p.exp_sums + offset_exp_sums); @@ -195,17 +200,17 @@ class ReduceSplitK { // 3D // static_assert(is_same_v, "dtype mismatched"); // static_assert(is_same_v(StrideO{})), float>, "dtype mismatched"); - auto stride_o_accum = cutlass::make_cute_packed_stride(Stride_Oaccum{}, shape_Oaccum); - // 2D - auto stride_o = cutlass::make_cute_packed_stride(Stride_O{}, shape_O); - auto stride_exp_sums = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_exp_sums); - auto stride_max_logits = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_max_logits); + // auto stride_o_accum = cutlass::make_cute_packed_stride(Stride_Oaccum{}, shape_Oaccum); + // // 2D + // auto stride_o = cutlass::make_cute_packed_stride(Stride_O{}, shape_O); + // auto stride_exp_sums = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_exp_sums); + // auto stride_max_logits = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_max_logits); - Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum)); - Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, p.dOaccum)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, p.dO)); - Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); - Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, p.dExp_sums)); + Tensor max_logits = 
make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, p.dMax_logits)); // static_assert(is_same_v, "dtype mismatched"); @@ -216,11 +221,11 @@ class ReduceSplitK { ElementO global_exp_sums = 0; // only first subgroup participates if (thr_id < num_kv_splits) { - ElementO cur_max_logit = max_logits(0, thr_id, 0); + ElementO cur_max_logit = max_logits(seq_idx, thr_id, head_q, idx_b); global_max_logits = sycl::max(global_max_logits, cur_max_logit); shared_storage.max_logits_slm_array[thr_id] = cur_max_logit; - ElementO cur_exp_sum = exp_sums(0, thr_id, 0); + ElementO cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, idx_b); shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; } @@ -231,14 +236,11 @@ class ReduceSplitK { // reduce within subgroup // here assume num_kv_splits not exceed subgroup size 16 global_max_logits = reduce_over_group(get_sub_group(), global_max_logits, sycl::maximum<>()); - // global_exp_sums = reduce_over_group(get_sub_group(), global_exp_sums, sycl::plus<>()); } // broadcast to other threads global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); - // global_exp_sums = sycl::group_broadcast(get_work_group<1>(), global_exp_sums, 0); - // barrier for SLM writes finished sycl::group_barrier(get_work_group<3>()); @@ -252,7 +254,8 @@ class ReduceSplitK { ElementO rescale = sycl::native::exp2(local_max_logit - global_max_logits); // in FMHA epilogue, it's divided by local_exp_sum - ElementO adjusted_o_accum = Oaccum(0, idx, i) * local_exp_sum; + // assume seq_len_q == 1 + ElementO adjusted_o_accum = Oaccum(seq_idx, idx, idx_b * num_heads_q + head_q, i) * local_exp_sum; acc += adjusted_o_accum * rescale; // update global exp sum @@ -261,7 +264,7 @@ class ReduceSplitK { ElementO inv_global_exp_sums = 1. 
/ global_exp_sums; acc *= inv_global_exp_sums; - O(0, idx, 0) = acc; + O(seq_idx, idx, head_q, idx_b) = acc; } } } diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 5642552e4a..5b43474014 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -64,12 +64,12 @@ struct XeFHMAIndividualTileScheduler { dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q - size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + size(shape.batch * shape.num_heads_kv)); // (h,b) -- split later if (num_kv_splits > 1) { grid.z *= num_kv_splits; } std::cout << "XeFHMAIndividualTileScheduler Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; - return Params{grid, {shape.num_heads_q}, {shape.batch * shape.num_heads_q}, num_kv_splits}; + return Params{grid, {shape.num_heads_kv}, {shape.batch * shape.num_heads_kv}, num_kv_splits}; } template diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index 4fdfa6c0ae..19e481a49f 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -108,12 +108,23 @@ int main(int argc, const char **argv) { #endif #elif defined(DECODE) -#if PERSISTENT +#if defined(PERSISTENT) || defined(SPLITKV) #define NUM_SG _16 #define KV_TILE_SIZE _256 + +// turn on gqa packing optimizations +#define GQA_PACKING +#if defined(GQA_PACKING) +// dpas maximum repeat count is 8 +#define Q_FUSED_TILE_SIZE _8 +#else +#define Q_FUSED_TILE_SIZE _1 +#endif + #else #define NUM_SG _8 #define KV_TILE_SIZE _512 +#define Q_FUSED_TILE_SIZE _1 #endif #if HEAD_DIM == 16 @@ -136,9 +147,9 @@ int main(int argc, const char **argv) { using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 128 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _128>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 192 diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 8341eaa97d..c39f6bf5ff 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -470,6 +470,33 @@ struct ExampleRunner { compat::wait(); +#if 0 + std::vector vec_Oaccum(block_Oaccum.size()); + block_Oaccum.copy_to_host(vec_Oaccum.data()); + for (size_t i = 0; i < vec_Oaccum.size(); i++) { + std::cout << "Oaccum[" << i << "] = " << vec_Oaccum[i] << std::endl; + if (i > 20) break; + } + + std::vector vec_exp_sums(block_exp_sums.size()); + std::vector vec_max_logits(block_max_logits.size()); + block_exp_sums.copy_to_host(vec_exp_sums.data()); + block_max_logits.copy_to_host(vec_max_logits.data()); + for (size_t i = 0; i < vec_exp_sums.size(); i++) { + // if (i < 8 * 10) continue; + std::cout << "exp_sums[" << i << "]: " << vec_exp_sums[i] << ", max_logits[" << i << "]: " << vec_max_logits[i] << std::endl; + } + + std::vector vec_O(block_O.size()); + block_O.copy_to_host(vec_O.data()); + std::vector vec_ref_O(block_ref_O.size()); + block_ref_O.copy_to_host(vec_ref_O.data()); + for (size_t i = 0; i < vec_ref_O.size(); i++) { + if (i < 5 * 128) continue; + std::cout << "ref_O[" << i 
<< "] = " << vec_ref_O[i] << " vs. O[" << i << "] = " << vec_O[i] << std::endl; + } +#endif + // Check if output from CUTLASS kernel and reference kernel are equal or not bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), block_O.size(), ElementO{0.05}, ElementO{0.05}); @@ -513,7 +540,7 @@ struct ExampleRunner { block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); if constexpr (isSplitKV) { - stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_kv_splits, num_heads_q * batch)); + stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * batch, num_kv_splits)); block_Oaccum.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo * num_kv_splits); // assume seq_len_qo==1 @@ -609,7 +636,8 @@ struct ExampleRunner { compat::experimental::launch_policy reduce_policy{reduce_sycl_grid, reduce_sycl_block, launch_props_reduce, kernel_props}; // wait for FA kernel finished - event.wait(); + // no need wait here if launched with same queue??? + // event.wait(); auto reduce_event = compat::experimental::launch, ReductionSplitKernel>(reduce_policy, reduce_params); @@ -698,30 +726,6 @@ struct ExampleRunner { compat::wait(); -#if 0 - std::vector vec_Oaccum(block_Oaccum.size()); - block_Oaccum.copy_to_host(vec_Oaccum.data()); - for (size_t i = 0; i < vec_Oaccum.size(); i++) { - std::cout << "Oaccum[" << i << "] = " << vec_Oaccum[i] << std::endl; - if (i > 20) break; - } - - std::vector vec_exp_sums(block_exp_sums.size()); - std::vector vec_max_logits(block_max_logits.size()); - block_exp_sums.copy_to_host(vec_exp_sums.data()); - block_max_logits.copy_to_host(vec_max_logits.data()); - for (size_t i = 0; i < vec_exp_sums.size(); i++) { - std::cout << "exp_sums[" << i << "]: " << vec_exp_sums[i] << ", max_logits[" << i << "]: " << vec_max_logits[i] << std::endl; - if (i > 20) break; - } - - std::vector vec_O(block_O.size()); - block_O.copy_to_host(vec_O.data()); - for (size_t i = 0; i < vec_O.size(); i++) { - std::cout << "O[" << i << "] = " << vec_O[i] << std::endl; - if (i > 20) break; - } -#endif // Verify that the result is correct bool passed = verify(shape, options.is_causal); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; @@ -796,6 +800,7 @@ template struct FMHAConfig { + // when GQA used, fused group query heads into one single MMA call static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t, typename cute::conditional_t< From 4d6ed2f245e6a1283da0a15a96e16e8024b7f6ea Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 15 Dec 2025 18:07:34 -0800 Subject: [PATCH 05/12] fix NaN issue --- .../flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | 5 +++++ .../flash_attention_v2/kernel/xe_reduce_split_k.h | 8 ++++++++ examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | 4 ++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 1bfb8fc711..28481b86c5 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -1002,6 +1002,11 @@ class XeFMHAFwdSplitKVKernel { int start_blk = SplitKV ? kv_split_offset : 0; int end_blk = SplitKV ? 
(kv_split_offset + num_effective_kv_blocks) : k_blocks; + if (end_blk <= start_blk) { + // early exit + return; + } + #if 0 if (thr_id == 0) { cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q_start, start_blk, end_blk); diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index ab0e3fb788..f3ee5bffba 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -69,6 +69,7 @@ class ReduceSplitK { using ElementO = typename FMHAKernel_::ElementO; using StrideO = typename FMHAKernel_::StrideO; using TileShapeO = typename FMHAKernel_::TileShapeO; + using TileShapeQK = typename FMHAKernel_::TileShapeQK; using SGPerWG = typename FMHAKernel_::SGPerWG; @@ -165,10 +166,14 @@ class ReduceSplitK { auto num_kv_splits = params.scheduler.num_kv_splits; auto seq_len_qo = s.seq_len_qo; + auto seq_len_kv = s.seq_len_kv; auto batch_dim = s.batch; auto num_heads_q = s.num_heads_q; auto head_size_vo = s.head_size_vo; + const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{})); + int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); + CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); @@ -248,6 +253,9 @@ class ReduceSplitK { for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { ElementO acc = 0; for (int i = 0; i < num_kv_splits; ++i) { + if (i * num_blocks_per_split > k_blocks) { + break; + } ElementO local_max_logit = shared_storage.max_logits_slm_array[i]; ElementO local_exp_sum = shared_storage.exp_sums_slm_array[i]; diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index c39f6bf5ff..391d9eef30 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -636,8 +636,8 @@ struct ExampleRunner { compat::experimental::launch_policy reduce_policy{reduce_sycl_grid, reduce_sycl_block, launch_props_reduce, kernel_props}; // wait for FA kernel finished - // no need wait here if launched with same queue??? 
- // event.wait(); + // maybe no need wait here if launched with in-order queue + event.wait(); auto reduce_event = compat::experimental::launch, ReductionSplitKernel>(reduce_policy, reduce_params); From e16bc349599991bc560c595e91770dcd25e7cca8 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Mon, 15 Dec 2025 21:22:29 -0800 Subject: [PATCH 06/12] remove redundant barrier --- applications/flash_attention_v2/kernel/xe_reduce_split_k.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index f3ee5bffba..8ae765638d 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -246,9 +246,6 @@ class ReduceSplitK { // broadcast to other threads global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); - // barrier for SLM writes finished - sycl::group_barrier(get_work_group<3>()); - // step 2: rescale Oaccum and write back to O for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { ElementO acc = 0; From f516fef3fe089b3e5809fc7d1f7e86faa42ff9eb Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 17 Dec 2025 00:10:05 -0800 Subject: [PATCH 07/12] Add variable length support --- .../kernel/xe_fhma_fwd_kernel.hpp | 125 ++++++------------ .../kernel/xe_reduce_split_k.h | 89 +++++++------ .../kernel/xe_tile_scheduler.hpp | 15 ++- .../xe_fmha_fwd_runner.hpp | 13 +- 4 files changed, 115 insertions(+), 127 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 28481b86c5..7b21a1807f 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -709,9 +709,6 @@ class XeFMHAFwdSplitKVKernel { using ProblemShape = ProblemShape_; using VariableLength = cutlass::fmha::collective::VariableLength; static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; - // TODO: support later - static_assert(SplitKV && !is_var_len, "XeFMHAFwdSplitKVKernel only supports variable length without KV split"); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; @@ -810,7 +807,7 @@ class XeFMHAFwdSplitKVKernel { } static bool can_implement(Arguments const &args) { - if (args.kernel.shape.seq_len_qo != 1) { + if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { // decode only return false; } @@ -901,83 +898,63 @@ class XeFMHAFwdSplitKVKernel { offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; offset_k = s.num_heads_kv * s.head_size_qk * kv_cumulative[idx_b]; offset_v = s.num_heads_kv * s.head_size_vo * kv_cumulative[idx_b]; - offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b] * num_kv_splits; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + if constexpr (GqaPacking) { + // seq_len_qo must be 1 + seq_len_qo = 1; + } } auto batch_dim = is_var_len ? 1 : s.batch; auto shape_Q = GqaPacking ? 
make_shape(seq_len_qo * head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); // 4D shape - decltype(shape_Q) shape_K, shape_V, shape_O, shape_exp_sums, shape_max_logits; - // auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); - // auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); - // auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); - - int num_blocks_per_split, kv_split_offset, num_effective_kv_blocks, effective_kv_seq_length; + auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); + decltype(shape_Q) shape_O; + if constexpr (is_var_len) { + shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * num_kv_splits, batch_dim); + } else { + shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * num_kv_splits, batch_dim); + } + auto shape_exp_sums = GqaPacking ? make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); + auto shape_max_logits = GqaPacking ? make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); - if constexpr (SplitKV) { - num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); - kv_split_offset = idx_kv_split * num_blocks_per_split; - num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); - effective_kv_seq_length = num_effective_kv_blocks * get<1>(TileShapeQK{}); - effective_kv_seq_length = seq_len; + int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); + int kv_split_offset = idx_kv_split * num_blocks_per_split; + int num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); #if 0 - if (thr_id == 0) { - cute::print("\nidx_kv_split: %d, kv_split_offset: %d, num_effective_kv_blocks: %d, k_blocks: %d, num_blocks_per_split: %d\n", - idx_kv_split, kv_split_offset, num_effective_kv_blocks, k_blocks, num_blocks_per_split); - } -#endif - - shape_K = make_shape(effective_kv_seq_length, s.head_size_qk, s.num_heads_kv, batch_dim); - shape_V = make_shape(s.head_size_vo, effective_kv_seq_length, s.num_heads_kv, batch_dim); - // shape_O = make_shape(seq_len_qo, s.head_size_vo, 1, s.num_heads_q * batch_dim); - shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * batch_dim, num_kv_splits) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * batch_dim, num_kv_splits); - - // shape_exp_sums = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); - // shape_max_logits = make_shape(s.seq_len_qo, 1, s.num_heads_q, batch_dim); - shape_exp_sums = GqaPacking ? make_shape(s.seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(s.seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); - shape_max_logits = GqaPacking ? 
make_shape(s.seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(s.seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); - - // TODO: adapt for var length - // offset_k = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_qk) * s.num_heads_kv * batch_dim; - // offset_k = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; - // // offset_v = s.num_heads_kv * s.head_size_vo * (idx_b * seq_len + kv_split_offset * get<1>(TileShapeQK{})); - // // offset_v = ((kv_split_offset * get<1>(TileShapeQK{})) * s.head_size_vo) * s.num_heads_kv * batch_dim; - // offset_v = kv_split_offset * get<1>(TileShapeQK{}) * s.head_size_qk; - - // // assume: Oaccum is allocated with shape (num_kv_splits, batch * num_heads_q, seq_len_qo, head_size_vo) - // offset_o = s.head_size_vo * seq_len_qo * idx_kv_split * s.num_heads_q * batch_dim; - - // offset_exp_sums = idx_kv_split; - // offset_max_logits = idx_kv_split; - - offset_o = offset_k = offset_v = offset_exp_sums = offset_max_logits = 0; - } else { - shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); - shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); - shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); + if (thr_id == 0) { + cute::print("\nidx_kv_split: %d, kv_split_offset: %d, num_effective_kv_blocks: %d, k_blocks: %d, num_blocks_per_split: %d\n", + idx_kv_split, kv_split_offset, num_effective_kv_blocks, k_blocks, num_blocks_per_split); } +#endif auto dcQ = const_cast(p.Q + offset_q); auto dcK = const_cast(p.K + offset_k); auto dcV = const_cast(p.V + offset_v); - auto ptrO = (SplitKV ? p.Oaccum : p.O) + offset_o; + auto ptrO = p.Oaccum + offset_o; auto ptrExp_sums = p.exp_sums + offset_exp_sums; auto ptrMax_logits = p.max_logits + offset_max_logits; auto stride_q = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; - auto stride_o = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : (SplitKV ? p.dOaccum : p.dO); - // auto stride_exp_sums = p.dExp_sums; - // auto stride_max_logits = p.dMax_logits; - auto stride_exp_sums = GqaPacking ? cutlass::make_cute_packed_stride(StrideQ{}, shape_exp_sums) : p.dExp_sums; - auto stride_max_logits = GqaPacking ? cutlass::make_cute_packed_stride(StrideQ{}, shape_max_logits) : p.dMax_logits; + auto stride_o = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dOaccum; + auto stride_exp_sums = GqaPacking ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums; + auto stride_max_logits = GqaPacking ? 
cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits; Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + #if 0 if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q_start == 0) { @@ -989,9 +966,6 @@ class XeFMHAFwdSplitKVKernel { } #endif - Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); - Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); - // O accumulator types FragA tArA; FragARow tA_max, tA_sum; @@ -999,8 +973,8 @@ class XeFMHAFwdSplitKVKernel { // Main loop int l_coord = is_var_len ? 0 : idx_b; - int start_blk = SplitKV ? kv_split_offset : 0; - int end_blk = SplitKV ? (kv_split_offset + num_effective_kv_blocks) : k_blocks; + int start_blk = kv_split_offset; + int end_blk = kv_split_offset + num_effective_kv_blocks; if (end_blk <= start_blk) { // early exit @@ -1015,7 +989,6 @@ class XeFMHAFwdSplitKVKernel { CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); - // for GQA packing mainloop(Q(_,_,head,l_coord), K(_,_,head,l_coord), V(_,_,head,l_coord), @@ -1027,14 +1000,14 @@ class XeFMHAFwdSplitKVKernel { #if 0 // static_assert(is_same_v, "dtype mismatch"); - if (idx_kv_split == 0 && head_q == 0 && thr_id == 0) { + if (idx_kv_split == 0 && head == 0 && thr_id == 0) { // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_sum(0): %f\n", float(tA_sum(0))); cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); } - if (idx_kv_split == 1 && head_q == 0 && thr_id == 0) { + if (idx_kv_split == 1 && head == 0 && thr_id == 0) { // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); @@ -1048,26 +1021,16 @@ class XeFMHAFwdSplitKVKernel { // Epilogue CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; - // for GQA packing for (int q_head_cnt = 0; q_head_cnt < head_group_q; ++q_head_cnt) { - int head_q_curr = head_q_start + q_head_cnt; - if constexpr (SplitKV) { - epilogue(O(_,_,idx_b * s.num_heads_kv + head, idx_kv_split), - tArA, tA_max, tA_sum, - blk_qv, thr_id, - exp_sums(_,_,head,l_coord), - max_logits(_,_,head,l_coord), - idx_kv_split, q_head_cnt); - } else { - // FIXME: use correct head q indx - epilogue(O(_,_,head_q_curr,l_coord), - tArA, tA_max, tA_sum, - blk_qv, thr_id); + epilogue(O(_,_,idx_kv_split * s.num_heads_kv + head,l_coord), + tArA, tA_max, tA_sum, + blk_qv, thr_id, + exp_sums(_,_,head,l_coord), + max_logits(_,_,head,l_coord), + idx_kv_split, q_head_cnt); } } } }; -}; - } // namespace cutlass::fmha::kernel 
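The xe_reduce_split_k.h changes below implement the usual safe-softmax merge of the per-split partial results. A scalar model of that arithmetic, as an illustrative sketch only (the combine_splits helper is hypothetical, not part of the kernel's API; it assumes base-2 exponentials to match the sycl::native::exp2 calls, and that each split's partial output was already divided by its local exp-sum in the FMHA epilogue):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Merge per-split partial attention outputs for one output element.
    float combine_splits(const std::vector<float>& o,   // per-split partial outputs, o[i] = P_i / s[i]
                         const std::vector<float>& m,   // per-split running max logits (base-2 domain)
                         const std::vector<float>& s) { // per-split exp-sums (softmax denominators)
      float M = -INFINITY;
      for (float mi : m) M = std::max(M, mi);           // step 1: global max logit across splits
      float acc = 0.0f, S = 0.0f;
      for (std::size_t i = 0; i < o.size(); ++i) {
        float rescale = std::exp2(m[i] - M);            // shift split i onto the global max
        acc += (o[i] * s[i]) * rescale;                 // undo the epilogue's 1/s[i], then rescale
        S   += s[i] * rescale;                          // accumulate the global exp-sum
      }
      return acc / S;                                   // final softmax-normalized output element
    }

This is essentially what the reduce kernel computes for each head_size_vo element, with the per-split maxima and exp-sums staged through SLM rather than vectors.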
diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index 8ae765638d..e13f9e5d43 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -61,6 +61,8 @@ class ReduceSplitK { public: using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; using TileScheduler = TileScheduler_; static_assert(is_same_v, "ReduceSplitK kernel requires XeReduceSplitKTileScheduler"); @@ -120,14 +122,14 @@ class ReduceSplitK { static Params to_underlying_arguments(Arguments const &args, void *workspace) { return {args.kernel, - TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + TileScheduler::template to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; } static bool can_implement(Arguments const &args) { // only support decode - if (args.kernel.shape.seq_len_qo > 1) { - return false; - } + // if (args.kernel.shape.seq_len_qo > 1) { + // return false; + // } if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) { return false; @@ -148,6 +150,15 @@ class ReduceSplitK { static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + /// Perform a reduction CUTLASS_DEVICE void operator()(Params const ¶ms, char *smem_buf) { @@ -165,59 +176,59 @@ class ReduceSplitK { TileScheduler tile_scheduler{params.scheduler}; auto num_kv_splits = params.scheduler.num_kv_splits; - auto seq_len_qo = s.seq_len_qo; - auto seq_len_kv = s.seq_len_kv; - auto batch_dim = s.batch; + auto batch_dim = is_var_len ? 
1 : s.batch; auto num_heads_q = s.num_heads_q; auto head_size_vo = s.head_size_vo; - const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{})); - int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); - CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + + // when varlen enabled, use largest seq_len_qo to decide work group num + if (seq_idx >= seq_len_qo) continue; + + const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{})); + int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); + int offset_o = 0, offset_o_accum = 0; int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_o_accum = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + } + auto shape_O = make_shape(seq_len_qo, head_size_vo, num_heads_q, batch_dim); - auto shape_Oaccum = make_shape(seq_len_qo, head_size_vo, num_heads_q * batch_dim, num_kv_splits); + auto shape_Oaccum = is_var_len ? make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim) : make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim); auto shape_exp_sums = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); auto shape_max_logits = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); - // assume: Oaccum is allocated with shape (num_kv_splits, batch * num_heads_q, seq_len_qo, head_size_vo) - // offset_o_accum = (idx_b * s.num_heads_q + head_q) * num_kv_splits * s.seq_len_qo * s.head_size_vo; - // offset_o = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo * s.head_size_vo; - - // offset_exp_sums = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; - // offset_max_logits = (idx_b * s.num_heads_q + head_q) * s.seq_len_qo; auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); - auto ptrO = p.O + offset_o; auto ptrExp_sums = const_cast(p.exp_sums + offset_exp_sums); auto ptrMax_logits = const_cast(p.max_logits + offset_max_logits); + auto ptrO = p.O + offset_o; - using Stride_O = cute::Stride, int64_t>; - using Stride_Oaccum = Stride_O; - using Stride_Exp_sums = Stride_O; - - // 3D - // static_assert(is_same_v, "dtype mismatched"); - // static_assert(is_same_v(StrideO{})), float>, "dtype mismatched"); - // auto stride_o_accum = cutlass::make_cute_packed_stride(Stride_Oaccum{}, shape_Oaccum); - // // 2D - // auto stride_o = cutlass::make_cute_packed_stride(Stride_O{}, shape_O); - // auto stride_exp_sums = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_exp_sums); - // auto stride_max_logits = cutlass::make_cute_packed_stride(Stride_Exp_sums{}, shape_max_logits); + auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_o_accum = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_Oaccum) : p.dOaccum; + auto stride_exp_sums = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums; + auto stride_max_logits = is_var_len ? 
cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits; - Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, p.dOaccum)); - Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, p.dO)); + Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); - Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, p.dExp_sums)); - Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, p.dMax_logits)); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); - // static_assert(is_same_v, "dtype mismatched"); + int l_coord = is_var_len ? 0 : idx_b; // Step 1: reduce max logits across different partitions // store into SLM for later use @@ -226,11 +237,11 @@ class ReduceSplitK { ElementO global_exp_sums = 0; // only first subgroup participates if (thr_id < num_kv_splits) { - ElementO cur_max_logit = max_logits(seq_idx, thr_id, head_q, idx_b); + ElementO cur_max_logit = max_logits(seq_idx, thr_id, head_q, l_coord); global_max_logits = sycl::max(global_max_logits, cur_max_logit); shared_storage.max_logits_slm_array[thr_id] = cur_max_logit; - ElementO cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, idx_b); + ElementO cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, l_coord); shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; } @@ -260,7 +271,7 @@ class ReduceSplitK { // in FMHA epilogue, it's divided by local_exp_sum // assume seq_len_q == 1 - ElementO adjusted_o_accum = Oaccum(seq_idx, idx, idx_b * num_heads_q + head_q, i) * local_exp_sum; + ElementO adjusted_o_accum = Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord) * local_exp_sum; acc += adjusted_o_accum * rescale; // update global exp sum @@ -269,7 +280,7 @@ class ReduceSplitK { ElementO inv_global_exp_sums = 1. 
/ global_exp_sums; acc *= inv_global_exp_sums; - O(seq_idx, idx, head_q, idx_b) = acc; + O(seq_idx, idx, head_q, l_coord) = acc; } } } diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 5b43474014..e077b250c9 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -65,6 +65,7 @@ struct XeFHMAIndividualTileScheduler { dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q size(shape.batch * shape.num_heads_kv)); // (h,b) -- split later + std::cout << "seq len qo: " << shape.seq_len_qo << ", seq_len_kv: " << shape.seq_len_kv << "\n"; if (num_kv_splits > 1) { grid.z *= num_kv_splits; } @@ -188,17 +189,21 @@ struct XeReduceSplitKTileScheduler { CUTLASS_DEVICE XeReduceSplitKTileScheduler(Params const& params) : params(params) {} - template + template static Params to_underlying_arguments( ProblemShape const& shape, KernelHardwareInfo hw_info, TileShape const& tile_shape, const int &num_kv_splits = -1) { using namespace cute; - // dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V - // size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q - // size(shape.batch * shape.num_heads_q)); // (h,b) -- split later - dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); + int seq_len_qo; + if constexpr (is_var_len) { + seq_len_qo = shape.seq_len_qo; + } else { + seq_len_qo = shape.seq_len_qo; + } + + dim3 grid(seq_len_qo, shape.num_heads_q, shape.batch); std::cout << "Reduce Split K Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; return Params{grid, {shape.num_heads_q}, num_kv_splits}; } diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 391d9eef30..11fba5634a 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -263,8 +263,16 @@ struct ExampleRunner { int max_seqlen_kv = 0; for (int i = 0; i < num_batches; i++) { +#if defined(DECODE) + int seqlen_q = 1; + int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + if (num_batches == 1) { + seqlen_kv = get<4>(problem_size); + } +#else int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); +#endif total_seqlen_q += seqlen_q; total_seqlen_kv += seqlen_kv; @@ -540,7 +548,7 @@ struct ExampleRunner { block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); if constexpr (isSplitKV) { - stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * batch, num_kv_splits)); + stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch)); block_Oaccum.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo * num_kv_splits); // assume seq_len_qo==1 @@ -887,7 +895,8 @@ struct FMHAConfig { static int run(const Options &options) { if (options.varlen) { - return run(options); + return splitkv ? run(options) : + run(options); } else { return persistent ? run(options) : (splitkv ? 
run(options) : From dbab7fb7c987c8a77e9f0e03ef1d231e9a4d44ae Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 17 Dec 2025 01:13:05 -0800 Subject: [PATCH 08/12] GQA packing as default --- .../kernel/xe_fhma_fwd_kernel.hpp | 39 ++++++++----------- .../kernel/xe_reduce_split_k.h | 2 +- .../kernel/xe_tile_scheduler.hpp | 11 +----- .../xe_fmha_fwd_runner.hpp | 1 + 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 7b21a1807f..da3151fdac 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -58,7 +58,7 @@ struct FMHAProblemShape { /////////////////////////////////////////////////////////////////////////////// -template +template class XeFMHAFwdKernel { public: @@ -699,7 +699,7 @@ class XeFMHAFwdPersistentKernel { } }; -template +template class XeFMHAFwdSplitKVKernel { public: @@ -817,7 +817,7 @@ class XeFMHAFwdSplitKVKernel { } // when GQA packing enabled, limit head group size to 8 - if (GqaPacking && (args.kernel.shape.num_heads_q / args.kernel.shape.num_heads_kv > dpas_max_repeat_count)) { + if (args.kernel.shape.num_heads_q / args.kernel.shape.num_heads_kv > dpas_max_repeat_count) { return false; } @@ -903,25 +903,19 @@ class XeFMHAFwdSplitKVKernel { offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; - if constexpr (GqaPacking) { - // seq_len_qo must be 1 - seq_len_qo = 1; - } + // for gqa packing, seq_len_qo must be 1 + seq_len_qo = 1; } auto batch_dim = is_var_len ? 1 : s.batch; - auto shape_Q = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + auto shape_Q = make_shape(seq_len_qo * head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim); // 4D shape auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); - decltype(shape_Q) shape_O; - if constexpr (is_var_len) { - shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * num_kv_splits, batch_dim); - } else { - shape_O = GqaPacking ? make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim) : make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q * num_kv_splits, batch_dim); - } - auto shape_exp_sums = GqaPacking ? make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); - auto shape_max_logits = GqaPacking ? 
make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim) : make_shape(seq_len_qo, num_kv_splits, s.num_heads_q, batch_dim); + auto shape_O = make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim); + + auto shape_exp_sums = make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_max_logits = make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); int kv_split_offset = idx_kv_split * num_blocks_per_split; @@ -941,12 +935,12 @@ class XeFMHAFwdSplitKVKernel { auto ptrExp_sums = p.exp_sums + offset_exp_sums; auto ptrMax_logits = p.max_logits + offset_max_logits; - auto stride_q = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; - auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; - auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; - auto stride_o = (is_var_len || GqaPacking) ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dOaccum; - auto stride_exp_sums = GqaPacking ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums; - auto stride_max_logits = GqaPacking ? cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits; + auto stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_Q); + auto stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_K); + auto stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_V); + auto stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_O); + auto stride_exp_sums = cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums); + auto stride_max_logits = cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits); Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); @@ -955,7 +949,6 @@ class XeFMHAFwdSplitKVKernel { Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); - #if 0 if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q_start == 0) { cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, O shape: ", idx_kv_split, idx_b, head_q_start);cute::print(O.shape());print("\n"); diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index e13f9e5d43..2663003312 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -122,7 +122,7 @@ class ReduceSplitK { static Params to_underlying_arguments(Arguments const &args, void *workspace) { return {args.kernel, - TileScheduler::template to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; } static bool can_implement(Arguments const &args) { diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index e077b250c9..9258f9bfea 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ 
-189,21 +189,14 @@ struct XeReduceSplitKTileScheduler { CUTLASS_DEVICE XeReduceSplitKTileScheduler(Params const& params) : params(params) {} - template + template static Params to_underlying_arguments( ProblemShape const& shape, KernelHardwareInfo hw_info, TileShape const& tile_shape, const int &num_kv_splits = -1) { using namespace cute; - int seq_len_qo; - if constexpr (is_var_len) { - seq_len_qo = shape.seq_len_qo; - } else { - seq_len_qo = shape.seq_len_qo; - } - - dim3 grid(seq_len_qo, shape.num_heads_q, shape.batch); + dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); std::cout << "Reduce Split K Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; return Params{grid, {shape.num_heads_q}, num_kv_splits}; } diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 11fba5634a..7ae50df9b6 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -266,6 +266,7 @@ struct ExampleRunner { #if defined(DECODE) int seqlen_q = 1; int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + // for test purposes only if (num_batches == 1) { seqlen_kv = get<4>(problem_size); } From f3371559ffe147eca743233910edf7733160832d Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 17 Dec 2025 17:27:06 -0800 Subject: [PATCH 09/12] limit num kv splits to wg size --- .../kernel/xe_fhma_fwd_kernel.hpp | 12 ++++++------ .../flash_attention_v2/kernel/xe_reduce_split_k.h | 15 ++++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index da3151fdac..7f49f730f4 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -755,7 +755,7 @@ class XeFMHAFwdSplitKVKernel { static constexpr int SharedStorageSize = is_empty_v ?
size_t(0) : sizeof(SharedStorage); - static constexpr int max_num_kv_splits = intel::sg_size; + static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; static constexpr int dpas_max_repeat_count = 8; // Device side arguments @@ -921,6 +921,11 @@ class XeFMHAFwdSplitKVKernel { int kv_split_offset = idx_kv_split * num_blocks_per_split; int num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); + if (num_effective_kv_blocks <= 0) { + // nothing to compute for this split + continue; + } + #if 0 if (thr_id == 0) { cute::print("\nidx_kv_split: %d, kv_split_offset: %d, num_effective_kv_blocks: %d, k_blocks: %d, num_blocks_per_split: %d\n", @@ -969,11 +974,6 @@ class XeFMHAFwdSplitKVKernel { int start_blk = kv_split_offset; int end_blk = kv_split_offset + num_effective_kv_blocks; - if (end_blk <= start_blk) { - // early exit - return; - } - #if 0 if (thr_id == 0) { cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q_start, start_blk, end_blk); diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h index 2663003312..e121d528b6 100644 --- a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -127,9 +127,9 @@ class ReduceSplitK { static bool can_implement(Arguments const &args) { // only support decode - // if (args.kernel.shape.seq_len_qo > 1) { - // return false; - // } + if (!is_var_len && args.kernel.shape.seq_len_qo > 1) { + return false; + } if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) { return false; @@ -248,13 +248,10 @@ class ReduceSplitK { // barrier for SLM writes finished sycl::group_barrier(get_work_group<3>()); - if (sub_group_id == 0) { - // reduce within subgroup - // here assume num_kv_splits not exceed subgroup size 16 - global_max_logits = reduce_over_group(get_sub_group(), global_max_logits, sycl::maximum<>()); - } + // reduce across the whole work-group + global_max_logits = reduce_over_group(get_work_group<1>(), global_max_logits, sycl::maximum<>()); - // broadcast to other threads + // broadcast to all other threads global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); // step 2: rescale Oaccum and write back to O From 61649ceadd191aeae0b1b55cc84cb67ac3ef1b91 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Wed, 17 Dec 2025 17:36:59 -0800 Subject: [PATCH 10/12] fix tile scheduler --- .../flash_attention_v2/kernel/xe_tile_scheduler.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 9258f9bfea..753b431fe3 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -64,13 +64,17 @@ struct XeFHMAIndividualTileScheduler { dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q - size(shape.batch * shape.num_heads_kv)); // (h,b) -- split later + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later std::cout << "seq len qo: " << shape.seq_len_qo << ", seq_len_kv: " << shape.seq_len_kv << "\n"; int num_head = shape.num_heads_q; if (num_kv_splits > 1) { + // for splitKV, each wg handles a group of query heads + grid.z = size(shape.batch * shape.num_heads_kv); grid.z *= num_kv_splits; +
num_head = shape.num_heads_kv; } std::cout << "XeFHMAIndividualTileScheduler Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; - return Params{grid, {shape.num_heads_kv}, {shape.batch * shape.num_heads_kv}, num_kv_splits}; + return Params{grid, {num_head}, {shape.batch * num_head}, num_kv_splits}; } template From 386b28803e7407833388515891d10b9a91e97cef Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Sun, 21 Dec 2025 18:39:53 -0800 Subject: [PATCH 11/12] fix return order --- .../flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index 6399c4fce8..455cc38a7c 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -266,8 +266,8 @@ class FMHAFwdEpilogue { using namespace sycl::ext::oneapi::this_work_item; if constexpr (ReduceK{} == _1{}) { - ReduceFragARow rA_sum; - return std::make_tuple(tArA, tA_sum, rA_sum, true); + ReduceFragARow rA_max; + return std::make_tuple(tArA, rA_max, tA_sum, true); } else { /* Identify A tile ID and k block for this subgroup. */ auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); From c126c71e865930281a97c2b117ac8e0029505674 Mon Sep 17 00:00:00 2001 From: Wuxun Zhang Date: Sun, 21 Dec 2025 20:07:33 -0800 Subject: [PATCH 12/12] extend to all head_dim --- .../06_bmg_flash_attention/06_xe_fmha_fwd.cpp | 21 +++++++++---------- .../06_bmg_flash_attention/CMakeLists.txt | 2 +- .../xe_fmha_fwd_runner.hpp | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index 19e481a49f..6d4fdc6168 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -112,9 +112,8 @@ int main(int argc, const char **argv) { #define NUM_SG _16 #define KV_TILE_SIZE _256 +#if defined(SPLITKV) // turn on gqa packing optimizations -#define GQA_PACKING -#if defined(GQA_PACKING) // dpas maximum repeat count is 8 #define Q_FUSED_TILE_SIZE _8 #else @@ -135,15 +134,15 @@ int main(int argc, const char **argv) { using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 64 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _64>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 96 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _96>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 128 @@ -153,9 +152,9 @@ int main(int argc, const char **argv) { using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 192 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _192>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #endif #else diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 89dfe27be7..af451a2a56 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ 
b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -52,7 +52,7 @@ foreach(HEAD_DIM 64 96 128 192) ) endif() - # specific test for split kernel + # specific test for split reduction kernel cutlass_example_add_executable( 06_xe_fmha_fwd_decode_splitkv_${INPUT_TYPE}_hdim${HEAD_DIM} 06_xe_fmha_fwd.cpp diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 7ae50df9b6..3c211d9f1e 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -90,7 +90,7 @@ struct Options { cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); -#ifdef PERSISTENT +#if defined(PERSISTENT) || defined(SPLITKV) cmd.get_cmd_line_argument("batch", batch, 1); cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 8); cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, 1);
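---

Note on the split-KV merge implemented by ReduceSplitK (patches 01 and 09): each KV split writes a partial output Oaccum that the FMHA epilogue has already divided by that split's local exp-sum, plus the split's row max (max_logits) and exp-sum (exp_sums). The reduction kernel recovers the exact softmax output by rescaling every partial result to a common global max and renormalizing by the global exp-sum. The following is a minimal host-side C++ sketch of the same arithmetic for one output row; SplitPartial and merge_splits are illustrative names, not part of the patch:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    struct SplitPartial {
      std::vector<float> o;  // partial output, pre-divided by its local exp-sum
      float max_logit;       // row max observed by this KV split
      float exp_sum;         // sum of exp(logit - max_logit) over this split
    };

    std::vector<float> merge_splits(const std::vector<SplitPartial> &parts) {
      // Step 1: global max across splits (done on device via SLM + group reduce).
      float global_max = -INFINITY;
      for (const auto &p : parts)
        global_max = std::max(global_max, p.max_logit);

      // Step 2: undo each split's local normalization (multiply back by its
      // exp-sum), rescale by exp(local_max - global_max), and accumulate.
      std::vector<float> acc(parts[0].o.size(), 0.f);
      float global_exp_sum = 0.f;
      for (const auto &p : parts) {
        float rescale = std::exp(p.max_logit - global_max);
        for (size_t i = 0; i < acc.size(); ++i)
          acc[i] += p.o[i] * p.exp_sum * rescale;
        global_exp_sum += p.exp_sum * rescale;
      }

      // Step 3: final division by the global exp-sum, matching the kernel's
      // inv_global_exp_sums multiply.
      for (float &v : acc)
        v /= global_exp_sum;
      return acc;
    }

Up to floating-point reassociation, the result is independent of how the KV blocks are partitioned across splits, which is what allows num_kv_splits to be chosen purely for occupancy.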
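---

Note on the GQA packing made the default in patch 08: for decode (seq_len_qo == 1), the head_group_q query heads that share one KV head are stacked along the M dimension of a single Q tile, giving shape_Q = (seq_len_qo * head_group_q, head_size_qk, num_heads_kv, batch), so one DPAS pass covers the whole head group; can_implement() rejects group sizes above dpas_max_repeat_count == 8, and the SPLITKV build pins Q_FUSED_TILE_SIZE to _8 accordingly. Below is a small sketch of the coordinate mapping this layout implies, assuming consecutive query heads share a KV head; packed_q_coord is a hypothetical helper, not part of the patch:

    #include <cassert>
    #include <utility>

    // Map a flat query-head index to (row in the packed Q tile, KV head).
    std::pair<int, int> packed_q_coord(int head_q, int num_heads_q, int num_heads_kv) {
      int head_group_q = num_heads_q / num_heads_kv;  // query heads per KV head
      assert(head_group_q <= 8);                      // dpas_max_repeat_count limit
      return {head_q % head_group_q,   // M-dimension row inside the packed tile
              head_q / head_group_q};  // KV head this group attends to
    }

With the runner defaults (num_heads_q = 8, num_heads_kv = 1), all eight query heads land in rows 0..7 of a single Q tile against the one KV head.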