diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index efa54931d3..455cc38a7c 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -92,6 +92,7 @@ class FMHAFwdEpilogue { Stride, E<0>>{}) )); using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype mismatched"); static auto default_tiled_copy_O_helper() { if constexpr (ReduceK{} == _1{}) @@ -149,7 +150,79 @@ class FMHAFwdEpilogue { using ElementA = typename FragA::element_type; // Reduce k-blocks of A and A_sum across WG, if needed. - auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + auto [rA, rA_max_unused, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + +#if 1 + if (ThreadIdxX() == 0) { + // cute::print("wg id: %d, rA(0): %f, rA_sum(0): %f\n", BlockIdxZ(), (float)rA(0),(float)rA_sum(0)); + } +#endif + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // splitK version + template + CUTLASS_DEVICE + void + operator()(TensorO2D const& O, // Global O tensor: (q,v) + FragA & tArA, // O accumulator: (q,v) + FragARow & tA_max, // Softmax row-wise max accumulator + FragARow & tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id, // Work-item ID + const TensorO2D & exp_sums, // Global exp sum tensor + const TensorO2D & max_logits, // Global max logits tensor + int idx_kv_split, + int head_q + ) { + + using namespace cute; + using ElementA = typename FragA::element_type; + +#if 0 + if (ThreadIdxX() == 0 && BlockIdxZ() == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print(" fwd epilogue tA_max(0): %f\n", float(tA_max(0))); + cute::print(" fwd epilogue tA_sum(0): %f\n", float(tA_sum(0))); + cute::print(" fwd epilogue tArA(0): %f\n", float(tArA(0))); + } +#endif + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + // store exp sum and max logits for current KV split + // assume seq_len_qo == 1 + if (ThreadIdxX() == 0) { + exp_sums(head_q,idx_kv_split) = rA_sum(0); + max_logits(head_q,idx_kv_split) = rA_max(0); + } /* Some subgroups may not have any work to do; if so, quit early. 
*/ if (!active) return; @@ -193,7 +266,8 @@ class FMHAFwdEpilogue { using namespace sycl::ext::oneapi::this_work_item; if constexpr (ReduceK{} == _1{}) { - return std::make_tuple(tArA, tA_sum, true); + ReduceFragARow rA_max; + return std::make_tuple(tArA, rA_max, tA_sum, true); } else { /* Identify A tile ID and k block for this subgroup. */ auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); @@ -285,7 +359,7 @@ class FMHAFwdEpilogue { } } } - return std::make_tuple(rA, rA_sum, active); + return std::make_tuple(rA, rA_max, rA_sum, active); } } }; diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index b7c400a63a..ed776d166d 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -124,6 +124,7 @@ struct FMHAFwdMainloop, CausalMask_, using SingleFragA = FragC; // (atom val,q',v') using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype mismatched"); using ElementA = typename TiledMMAPV::ValTypeD; static constexpr bool CausalMask = CausalMask_; @@ -175,7 +176,8 @@ struct FMHAFwdMainloop, CausalMask_, int thr_id, int seq_len, int full_tile_offset, - int discard_seq_coord) { + int discard_seq_coord, + bool need_init = false) { using namespace sycl::ext::oneapi::this_work_item; // Short dimension names: @@ -195,6 +197,12 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) +#if 0 + if (ThreadIdxX() == 0 && BlockIdxZ() == 0) { + print("Q 2D shape: "); print(Q_2D.shape()); print("\n"); + } +#endif + /* Partition global tensors into workgroup tiles */ Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D) Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) @@ -251,7 +259,7 @@ struct FMHAFwdMainloop, CausalMask_, /* Initialization steps for first block: Q/K prefetch, O init */ /* TODO: limit D prefetch for large head size, and reorder K prefetches */ - if (blk_k0 == 0) { + if (blk_k0 == 0 || need_init) { for (int D = 0; D < size<3>(pQgQ); D++) { prefetch(prefetch_q, pQgQ(_,_,_,D)); } diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index f5905f746a..7f49f730f4 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -134,6 +134,7 @@ class XeFMHAFwdKernel { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; }; // Kernel entry point API @@ -206,7 +207,7 @@ class XeFMHAFwdKernel { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { - auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto [blk_q, blk_v, head_q, idx_b, unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b) auto blk_qv = make_coord(blk_q, blk_v); int head = head_q / head_group_q; @@ -284,7 +285,7 @@ class XeFMHAFwdKernel { }; template -class XeFMHAFwdDynamicSplitKernel { +class XeFMHAFwdPersistentKernel { public: // @@ -319,7 +320,7 @@ class XeFMHAFwdDynamicSplitKernel { using ElementA = typename 
CollectiveMainloop::ElementA; // Tile scheduler derived types - static_assert(is_same_v); + static_assert(is_same_v); using TileScheduler = TileScheduler_; using TileSchedulerParams = typename TileScheduler::Params; @@ -367,6 +368,7 @@ class XeFMHAFwdDynamicSplitKernel { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; }; // Kernel entry point API @@ -590,7 +592,7 @@ class XeFMHAFwdDynamicSplitKernel { V(_,_,head_kv,idx_b), tArA, tA_max, tA_sum, blk_qv, start_blk, end_blk, local_k_blocks, - thr_id, s.seq_len_kv, /*for causal*/0, 0); + thr_id, s.seq_len_kv, /*for causal*/0, 0, /*need_init*/true); // partition id of start batch head id in current wg int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks); @@ -697,4 +699,331 @@ class XeFMHAFwdDynamicSplitKernel { } }; +template +class XeFMHAFwdSplitKVKernel { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) + : sizeof(SharedStorage); + + static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; + static constexpr int dpas_max_repeat_count = 8; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ *Q; + StrideQ dQ; + const ElementK *K; + StrideK dK; + const ElementV *V; + StrideV dV; + ElementO *O; + StrideO dO; + // TODO: whether same dtype as output or accum? + ElementO *Oaccum; + StrideO dOaccum; + ElementO *exp_sums; + StrideO dExp_sums; + ElementO *max_logits; + StrideO dMax_logits; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const &args) { + if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { + // decode only + return false; + } + + if (args.num_kv_splits > max_num_kv_splits) { + return false; + } + + // when GQA packing enabled, limit head group size to 8 + if (args.kernel.shape.num_heads_q / args.kernel.shape.num_heads_kv > dpas_max_repeat_count) { + return false; + } + + return CollectiveMainloop::can_implement(args.mainloop) + && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0,2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = 
params.scheduler.num_kv_splits_; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + // auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto blk_qv = make_coord(blk_q, blk_v); + int head_q_start = head * head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : seq_len_kv; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + auto kv_cumulative = s.seq_len_kv.cumulative_length; + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_k = s.num_heads_kv * s.head_size_qk * kv_cumulative[idx_b]; + offset_v = s.num_heads_kv * s.head_size_vo * kv_cumulative[idx_b]; + + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b] * num_kv_splits; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + // for gqa packing, seq_len_qo must be 1 + seq_len_qo = 1; + } + + auto batch_dim = is_var_len ? 
1 : s.batch; + auto shape_Q = make_shape(seq_len_qo * head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim); + // 4D shape + auto shape_K = make_shape(seq_len_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); + auto shape_O = make_shape(seq_len_qo * head_group_q, s.head_size_vo, s.num_heads_kv * num_kv_splits, batch_dim); + + auto shape_exp_sums = make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_max_logits = make_shape(seq_len_qo * head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + + int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); + int kv_split_offset = idx_kv_split * num_blocks_per_split; + int num_effective_kv_blocks = cute::min(k_blocks - kv_split_offset, num_blocks_per_split); + + if (num_effective_kv_blocks <= 0) { + // no need computation + continue; + } + +#if 0 + if (thr_id == 0) { + cute::print("\nidx_kv_split: %d, kv_split_offset: %d, num_effective_kv_blocks: %d, k_blocks: %d, num_blocks_per_split: %d\n", + idx_kv_split, kv_split_offset, num_effective_kv_blocks, k_blocks, num_blocks_per_split); + } +#endif + + auto dcQ = const_cast(p.Q + offset_q); + auto dcK = const_cast(p.K + offset_k); + auto dcV = const_cast(p.V + offset_v); + auto ptrO = p.Oaccum + offset_o; + auto ptrExp_sums = p.exp_sums + offset_exp_sums; + auto ptrMax_logits = p.max_logits + offset_max_logits; + + auto stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_Q); + auto stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_K); + auto stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_V); + auto stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_O); + auto stride_exp_sums = cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums); + auto stride_max_logits = cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits); + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); + Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); + Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + +#if 0 + if (thr_id == 0 && BlockIdxZ() == 0 && idx_kv_split == 0 && head_q_start == 0) { + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, O shape: ", idx_kv_split, idx_b, head_q_start);cute::print(O.shape());print("\n"); + cute::print("K shape: ");cute::print(K.shape());cute::print(K.stride());cute::print("\n"); + cute::print("V shape: ");cute::print(V.shape());cute::print(V.stride());cute::print("\n"); + cute::print("O stride: ");cute::print(O.stride());cute::print("\n"); + cute::print("stride_o: ");cute::print(stride_o);cute::print("\n"); + } +#endif + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + int l_coord = is_var_len ? 
0 : idx_b; + + int start_blk = kv_split_offset; + int end_blk = kv_split_offset + num_effective_kv_blocks; + +#if 0 + if (thr_id == 0) { + cute::print("\nidx_kv_split: %d, idx_b: %d, head_q_start: %d, start_blk: %d, end_blk: %d\n", idx_kv_split, idx_b, head_q_start, start_blk, end_blk); + } +#endif + + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + mainloop(Q(_,_,head,l_coord), + K(_,_,head,l_coord), + V(_,_,head,l_coord), + tArA, tA_max, tA_sum, + blk_qv, start_blk, end_blk, k_blocks, + thr_id, seq_len, + full_tile_offset, discard_seq_coord, + /*need_init*/true); + +#if 0 + // static_assert(is_same_v, "dtype mismatch"); + if (idx_kv_split == 0 && head == 0 && thr_id == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tA_sum(0): %f\n", float(tA_sum(0))); + cute::print("idx_kv_split: 0, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); + } + + if (idx_kv_split == 1 && head == 0 && thr_id == 0) { + // cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord))); + cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tA_max(0): %f\n", float(tA_max(0))); + cute::print("idx_kv_split: 1, head_q: 0, tid: 0, tArA(0): %f\n", float(tArA(0))); + } +#endif + + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + for (int q_head_cnt = 0; q_head_cnt < head_group_q; ++q_head_cnt) { + epilogue(O(_,_,idx_kv_split * s.num_heads_kv + head,l_coord), + tArA, tA_max, tA_sum, + blk_qv, thr_id, + exp_sums(_,_,head,l_coord), + max_logits(_,_,head,l_coord), + idx_kv_split, q_head_cnt); + } + } + } +}; + } // namespace cutlass::fmha::kernel diff --git a/applications/flash_attention_v2/kernel/xe_reduce_split_k.h b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h new file mode 100644 index 0000000000..e121d528b6 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" +#include "flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" +#include "cute/util/type_traits.hpp" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReduceSplitK { +public: + + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using TileScheduler = TileScheduler_; + static_assert(is_same_v, + "ReduceSplitK kernel requires XeReduceSplitKTileScheduler"); + using TileSchedulerParams = typename TileScheduler::Params; + + using ElementO = typename FMHAKernel_::ElementO; + using StrideO = typename FMHAKernel_::StrideO; + using TileShapeO = typename FMHAKernel_::TileShapeO; + using TileShapeQK = typename FMHAKernel_::TileShapeQK; + + using SGPerWG = typename FMHAKernel_::SGPerWG; + + // num values (head_dim) processed by each thread + constexpr static int num_vals_per_thread = int(get<1>(TileShapeO{}) / (SGPerWG::value * intel::sg_size)); + + // + // Types + // + + struct KernelArguments { + ProblemShape shape; + // outputs: + ElementO *O; + StrideO dO; + // below are inputs + // TODO: whether same dtype as output or accum? + const ElementO *Oaccum; + StrideO dOaccum; + const ElementO *exp_sums; + StrideO dExp_sums; + const ElementO *max_logits; + StrideO dMax_logits; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + /// Params structure + struct Params { + KernelParams kernel; + TileSchedulerParams scheduler; + }; + + struct SharedStorage { + cutlass::Array max_logits_slm_array; + cutlass::Array exp_sums_slm_array; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) + : sizeof(SharedStorage); + +public: + + static Params to_underlying_arguments(Arguments const &args, void *workspace) { + return {args.kernel, + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const &args) { + // only support decode + if (!is_var_len && args.kernel.shape.seq_len_qo > 1) { + return false; + } + + if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) { + return false; + } + return true; + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int tid_in_sg = thr_id % intel::sg_size; + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits; + + auto batch_dim = is_var_len ? 1 : s.batch; + auto num_heads_q = s.num_heads_q; + auto head_size_vo = s.head_size_vo; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + + // when varlen enabled, use largest seq_len_qo to decide work group num + if (seq_idx >= seq_len_qo) continue; + + const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{})); + int num_blocks_per_split = cute::ceil_div(k_blocks, num_kv_splits); + + int offset_o = 0, offset_o_accum = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_o_accum = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + } + + auto shape_O = make_shape(seq_len_qo, head_size_vo, num_heads_q, batch_dim); + auto shape_Oaccum = is_var_len ? 
make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim) : make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim); + + auto shape_exp_sums = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); + auto shape_max_logits = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); + + auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); + auto ptrExp_sums = const_cast(p.exp_sums + offset_exp_sums); + auto ptrMax_logits = const_cast(p.max_logits + offset_max_logits); + auto ptrO = p.O + offset_o; + + auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_o_accum = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_Oaccum) : p.dOaccum; + auto stride_exp_sums = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums; + auto stride_max_logits = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits; + + Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + + int l_coord = is_var_len ? 0 : idx_b; + + // Step 1: reduce max logits across different partitions + // store into SLM for later use + + ElementO global_max_logits = cutlass::platform::numeric_limits::lowest(); + ElementO global_exp_sums = 0; + // only first subgroup participates + if (thr_id < num_kv_splits) { + ElementO cur_max_logit = max_logits(seq_idx, thr_id, head_q, l_coord); + global_max_logits = sycl::max(global_max_logits, cur_max_logit); + shared_storage.max_logits_slm_array[thr_id] = cur_max_logit; + + ElementO cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, l_coord); + shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; + } + + // barrier for SLM writes finished + sycl::group_barrier(get_work_group<3>()); + + // reduce across wg + global_max_logits = reduce_over_group(get_work_group<1>(), global_max_logits, sycl::maximum<>()); + + // broadcast to all other threads + global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); + + // step 2: rescale Oaccum and write back to O + for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { + ElementO acc = 0; + for (int i = 0; i < num_kv_splits; ++i) { + if (i * num_blocks_per_split > k_blocks) { + break; + } + ElementO local_max_logit = shared_storage.max_logits_slm_array[i]; + ElementO local_exp_sum = shared_storage.exp_sums_slm_array[i]; + + ElementO rescale = sycl::native::exp2(local_max_logit - global_max_logits); + + // in FMHA epilogue, it's divided by local_exp_sum + // assume seq_len_q == 1 + ElementO adjusted_o_accum = Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord) * local_exp_sum; + acc += adjusted_o_accum * rescale; + + // update global exp sum + global_exp_sums += local_exp_sum * rescale; + } + + ElementO inv_global_exp_sums = 1. 
/ global_exp_sums; + acc *= inv_global_exp_sums; + O(seq_idx, idx, head_q, l_coord) = acc; + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index 24a686993c..753b431fe3 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -45,6 +45,8 @@ struct XeFHMAIndividualTileScheduler { struct Params { dim3 grid; FastDivmod divmod_num_heads; + FastDivmod divmod_batch; + int num_kv_splits_ = -1; }; bool valid_ = true; @@ -56,14 +58,23 @@ struct XeFHMAIndividualTileScheduler { template static Params to_underlying_arguments( ProblemShape const& shape, KernelHardwareInfo hw_info, - TileShape const& tile_shape) + TileShape const& tile_shape, const int &num_kv_splits = -1) { using namespace cute; dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q - size(shape.batch * shape.num_heads_q)); // (h,b) -- split later - return Params{grid, {shape.num_heads_q}}; + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + std::cout << "seq len qo: " << shape.seq_len_qo << ", seq_len_kv: " << shape.seq_len_kv << "\n"; + int num_head = shape.num_heads_q; + if (num_kv_splits > 1) { + // for splitKV, each wg handles group query heads + grid.z = size(shape.batch * shape.num_heads_kv); + grid.z *= num_kv_splits; + num_head = shape.num_heads_kv; + } + std::cout << "XeFHMAIndividualTileScheduler Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; + return Params{grid, {num_head}, {shape.batch * num_head}, num_kv_splits}; } template @@ -79,10 +90,18 @@ struct XeFHMAIndividualTileScheduler { CUTLASS_DEVICE auto get_block_coord() { using namespace cute; - int idx_b = BlockIdxZ(); - int head; + int idx_kv_split = BlockIdxZ(); + int head, idx_b; + + if (params.num_kv_splits_ > 1) { + params.divmod_batch(idx_kv_split, idx_b, idx_kv_split); + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, idx_kv_split); + } + + idx_b = idx_kv_split; params.divmod_num_heads(idx_b, head, idx_b); - return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, (int)-1); } CUTLASS_DEVICE @@ -92,7 +111,7 @@ struct XeFHMAIndividualTileScheduler { } }; -struct XeFHMAIndividualPersistentTileScheduler { +struct XeFHMAPersistentTileScheduler { struct Params { dim3 grid; @@ -107,7 +126,7 @@ struct XeFHMAIndividualPersistentTileScheduler { int num_batch_heads_; CUTLASS_DEVICE - XeFHMAIndividualPersistentTileScheduler(Params const& params, int kv_tile_size, + XeFHMAPersistentTileScheduler(Params const& params, int kv_tile_size, int local_num_kv_blocks, int num_batch_heads) : params(params), kv_tile_size_(kv_tile_size), local_num_kv_blocks_(local_num_kv_blocks), num_batch_heads_(num_batch_heads) {} @@ -154,7 +173,57 @@ struct XeFHMAIndividualPersistentTileScheduler { } CUTLASS_DEVICE - XeFHMAIndividualPersistentTileScheduler& operator++() { + XeFHMAPersistentTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +struct XeReduceSplitKTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + int num_kv_splits = -1; + }; + + bool valid_ = true; + 
Params params; + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + TileShape const& tile_shape, const int &num_kv_splits = -1) + { + using namespace cute; + + dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); + std::cout << "Reduce Split K Grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")\n"; + return Params{grid, {shape.num_heads_q}, num_kv_splits}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler& operator++() { valid_ = false; return *this; } diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index 9de908336f..6d4fdc6168 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -108,12 +108,22 @@ int main(int argc, const char **argv) { #endif #elif defined(DECODE) -#if PERSISTENT +#if defined(PERSISTENT) || defined(SPLITKV) #define NUM_SG _16 #define KV_TILE_SIZE _256 + +#if defined(SPLITKV) +// turn on gqa packing optimizations +// dpas maximum repeat count is 8 +#define Q_FUSED_TILE_SIZE _8 +#else +#define Q_FUSED_TILE_SIZE _1 +#endif + #else #define NUM_SG _8 #define KV_TILE_SIZE _512 +#define Q_FUSED_TILE_SIZE _1 #endif #if HEAD_DIM == 16 @@ -124,27 +134,27 @@ int main(int argc, const char **argv) { using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 64 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _64>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 96 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _96>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 128 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _128>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 192 - using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; - using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; - using ShapeOut = Shape<_1, _192>; + using ShapeQK = Shape; + using ShapePV = Shape; + using ShapeOut = Shape; using SubgroupLayoutQK = Layout>; #endif #else @@ -171,9 +181,11 @@ int main(int argc, const char **argv) { #endif #if PERSISTENT - return FMHAConfig::run(options); + return FMHAConfig::run(options); +#elif SPLITKV + return FMHAConfig::run(options); #else - return options.is_causal ? FMHAConfig::run(options) - : FMHAConfig::run(options); + return options.is_causal ? 
FMHAConfig::run(options) + : FMHAConfig::run(options); #endif } diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 435a65e6ea..af451a2a56 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -52,6 +52,12 @@ foreach(HEAD_DIM 64 96 128 192) ) endif() + # specific test for split reduction kernel + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_splitkv_${INPUT_TYPE}_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + if(INPUT_TYPE STREQUAL "bfloat16_t") set(INPUT_MACRO "IS_BFLOAT16") elseif(INPUT_TYPE STREQUAL "float_e5m2_t") @@ -65,6 +71,7 @@ foreach(HEAD_DIM 64 96 128 192) if (NOT HEAD_DIM STREQUAL 192) target_compile_definitions(06_xe_fmha_fwd_decode_persistent_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE PERSISTENT SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) endif() + target_compile_definitions(06_xe_fmha_fwd_decode_splitkv_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SPLITKV SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO}) endforeach() cutlass_example_add_executable( diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index b1e9f0284d..3c211d9f1e 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -37,6 +37,7 @@ #include "flash_attention_v2/collective/fmha_fusion.hpp" #include "flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" #include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" +#include "flash_attention_v2/kernel/xe_reduce_split_k.h" #include "cutlass/util/GPU_Clock.hpp" #include "cutlass/util/sycl_event_manager.hpp" #include @@ -64,10 +65,11 @@ struct Options { int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; float softmax_scale; + int num_kv_splits; Options() : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), - seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), num_kv_splits(1), scheduler("Individual") {} // Parses the command line void parse(int argc, char const **args) { @@ -88,7 +90,7 @@ struct Options { cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); -#ifdef PERSISTENT +#if defined(PERSISTENT) || defined(SPLITKV) cmd.get_cmd_line_argument("batch", batch, 1); cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 8); cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, 1); @@ -101,6 +103,11 @@ struct Options { #endif #ifdef DECODE cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 1); +#ifdef SPLITKV + cmd.get_cmd_line_argument("num_kv_splits", num_kv_splits, 2); +#else + cmd.get_cmd_line_argument("num_kv_splits", num_kv_splits, 1); +#endif #else cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, seq_len_kv); #endif @@ -127,6 +134,7 @@ struct Options { << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" + << " --num_kv_splits= Sets the Number of 
Key-Value splits in Multi-Head Self Attention module\n" << " --iterations= Iterations\n\n"; return out; @@ -166,7 +174,8 @@ using LayoutK = cutlass::layout::ColumnMajor; using LayoutV = cutlass::layout::RowMajor; using LayoutO = cutlass::layout::RowMajor; -template struct ExampleRunner { +template +struct ExampleRunner { using StrideQ = typename FMHAKernel::StrideQ; using StrideK = typename FMHAKernel::StrideK; @@ -183,6 +192,8 @@ template struct ExampleRunner { using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + static_assert(!isSplitKV || !is_void_v, "Standalone reduction kernel is required if splitKV enabled"); + // // Data members // @@ -192,12 +203,22 @@ template struct ExampleRunner { StrideK stride_K; StrideV stride_V; StrideO stride_O; + StrideO stride_Oaccum; + StrideO stride_exp_sums; + StrideO stride_max_logits; uint64_t seed = 0; + int num_kv_splits; + cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; cutlass::DeviceAllocation block_O; + // TODO: assume same dtype as outputs + cutlass::DeviceAllocation block_Oaccum; + cutlass::DeviceAllocation block_exp_sums; + cutlass::DeviceAllocation block_max_logits; + cutlass::DeviceAllocation block_ref_O; std::vector cumulative_seqlen_q; @@ -242,8 +263,17 @@ template struct ExampleRunner { int max_seqlen_kv = 0; for (int i = 0; i < num_batches; i++) { +#if defined(DECODE) + int seqlen_q = 1; + int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + // for test purpose + if (num_batches == 1) { + seqlen_kv = get<4>(problem_size); + } +#else int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); +#endif total_seqlen_q += seqlen_q; total_seqlen_kv += seqlen_kv; @@ -449,6 +479,33 @@ template struct ExampleRunner { compat::wait(); +#if 0 + std::vector vec_Oaccum(block_Oaccum.size()); + block_Oaccum.copy_to_host(vec_Oaccum.data()); + for (size_t i = 0; i < vec_Oaccum.size(); i++) { + std::cout << "Oaccum[" << i << "] = " << vec_Oaccum[i] << std::endl; + if (i > 20) break; + } + + std::vector vec_exp_sums(block_exp_sums.size()); + std::vector vec_max_logits(block_max_logits.size()); + block_exp_sums.copy_to_host(vec_exp_sums.data()); + block_max_logits.copy_to_host(vec_max_logits.data()); + for (size_t i = 0; i < vec_exp_sums.size(); i++) { + // if (i < 8 * 10) continue; + std::cout << "exp_sums[" << i << "]: " << vec_exp_sums[i] << ", max_logits[" << i << "]: " << vec_max_logits[i] << std::endl; + } + + std::vector vec_O(block_O.size()); + block_O.copy_to_host(vec_O.data()); + std::vector vec_ref_O(block_ref_O.size()); + block_ref_O.copy_to_host(vec_ref_O.data()); + for (size_t i = 0; i < vec_ref_O.size(); i++) { + if (i < 5 * 128) continue; + std::cout << "ref_O[" << i << "] = " << vec_ref_O[i] << " vs. 
O[" << i << "] = " << vec_O[i] << std::endl; + } +#endif + // Check if output from CUTLASS kernel and reference kernel are equal or not bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), block_O.size(), ElementO{0.05}, ElementO{0.05}); @@ -460,6 +517,7 @@ template struct ExampleRunner { ProblemShapeType initialize(const Options &options) { auto problem_shape_in = cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo); ProblemShapeType shape; + num_kv_splits = options.num_kv_splits; decltype(problem_shape_in) problem_size; @@ -490,6 +548,18 @@ template struct ExampleRunner { block_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + if constexpr (isSplitKV) { + stride_Oaccum = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch)); + block_Oaccum.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo * num_kv_splits); + + // assume seq_len_qo==1 + stride_exp_sums = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + block_exp_sums.reset(static_cast(batch) * num_heads_q * seq_len_qo * num_kv_splits); + + stride_max_logits = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + block_max_logits.reset(static_cast(batch) * num_heads_q * seq_len_qo * num_kv_splits); + } + initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); @@ -540,26 +610,101 @@ template struct ExampleRunner { EventManager::getInstance().addEvent(event); } + static void run(typename FMHAKernel::Params params, typename ReductionSplitKernel::Params reduce_params) + { + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + dim3 const block = FMHAKernel::get_block_shape(); + dim3 const grid = FMHAKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + // Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension + compat::experimental::launch_properties launch_props { + syclex::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + syclex::sub_group_size, + intelex::grf_size<256> + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch, FMHAKernel>(policy, params); + + dim3 const reduce_grid = ReductionSplitKernel::get_grid_shape(reduce_params); + int reduce_smem_size = ReductionSplitKernel::SharedStorageSize; + const auto reduce_sycl_block = compat::dim3(block.x, block.y, block.z); + const auto reduce_sycl_grid = compat::dim3(reduce_grid.x, reduce_grid.y, reduce_grid.z); + compat::experimental::launch_properties launch_props_reduce { + syclex::work_group_scratch_size(reduce_smem_size), + }; + compat::experimental::launch_policy reduce_policy{reduce_sycl_grid, reduce_sycl_block, launch_props_reduce, kernel_props}; + + // wait for FA kernel finished + // maybe no need wait here if launched with 
in-order queue + event.wait(); + + auto reduce_event = compat::experimental::launch, ReductionSplitKernel>(reduce_policy, reduce_params); + + EventManager::getInstance().addEvent(event); + EventManager::getInstance().addEvent(reduce_event); + } + cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) { ProblemShapeType shape = initialize(options); - typename FMHAKernel::Arguments arguments{ - { + typename FMHAKernel::KernelArguments kernel_args; + if constexpr (isSplitKV) { + kernel_args = typename FMHAKernel::KernelArguments { + shape, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O, + block_Oaccum.get(), stride_Oaccum, + block_exp_sums.get(), stride_exp_sums, + block_max_logits.get(), stride_max_logits, + }; + } else { + kernel_args = typename FMHAKernel::KernelArguments { shape, block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V, block_O.get(), stride_O - }, + }; + } + + typename FMHAKernel::Arguments arguments { + kernel_args, {options.softmax_scale}, {}, - hw_info + hw_info, + options.num_kv_splits, + }; + + typename ReductionSplitKernel::Arguments reduce_arg { + {shape, + block_O.get(), stride_O, + block_Oaccum.get(), stride_Oaccum, + block_exp_sums.get(), stride_exp_sums, + block_max_logits.get(), stride_max_logits}, + hw_info, + options.num_kv_splits }; // Define device-global scratch memory size_t workspace_size = FMHAKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); + size_t reduce_workspace_size = 0; + if constexpr (isSplitKV) { + reduce_workspace_size = ReductionSplitKernel::get_workspace_size(reduce_arg); + } + cutlass::device_memory::allocation workspace(workspace_size + reduce_workspace_size); if (!FMHAKernel::can_implement(arguments)) { std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << @@ -573,9 +718,20 @@ template struct ExampleRunner { // Convert host-side arguments to device-side arguments to be passed to the kernel auto params = FMHAKernel::to_underlying_arguments(arguments, workspace.get()); + auto reduce_params = ReductionSplitKernel::to_underlying_arguments(reduce_arg, workspace.get() + workspace_size); // Run the GEMM - run(params); + if constexpr (isSplitKV) { + if (!ReductionSplitKernel::can_implement(reduce_arg)) { + std::cout << "Invalid Problem Size for ReductionSplitKernel" << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + CUTLASS_CHECK(ReductionSplitKernel::initialize_workspace(reduce_arg, workspace.get() + workspace_size)); + run(params, reduce_params); + } else { + run(params); + } compat::wait(); @@ -584,17 +740,21 @@ template struct ExampleRunner { std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if (!passed) { - return cutlass::Status::kErrorInternal; + // return cutlass::Status::kErrorInternal; } if (options.iterations > 0) { GPU_Clock timer; timer.start(); for (int i = 0; i < options.iterations; ++i) { - run(params); + if constexpr (isSplitKV) { + run(params, reduce_params); + } else { + run(params); + } } compat::wait(); - // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. + // when seq_len_qo is not equal to seq_len_kv we use bottom up approach for the masking. 
// Following changes will adjust the effective_seq_len_kv when masking applied for such cases auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); auto discard_seq_coord = options.seq_len_qo - offset; @@ -614,6 +774,7 @@ template struct ExampleRunner { std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") + << "\tKV Splits: " << options.num_kv_splits << "\t Scheduler: " << options.scheduler; printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); } @@ -630,21 +791,25 @@ template default */ int PipelineStages, bool persistent, + bool splitkv, typename ElementQ = bfloat16_t, typename ElementK = bfloat16_t, typename ElementV = bfloat16_t, typename ElementO = float, + typename ElementOaccum = float, typename MMAOperation_ = void, /* void -> default */ typename StrideQ = Stride, typename StrideK = Stride, typename StrideV = Stride<_1, int, int, int>, typename StrideO = Stride, + typename StrideOaccum = Stride, typename GmemTiledCopyQ = void, /* void -> default block 2D */ typename GmemTiledCopyK = void, typename GmemTiledCopyV = void, typename GmemTiledCopyO = void> struct FMHAConfig { + // when GQA used, fused group query heads into one single MMA call static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t, typename cute::conditional_t< @@ -657,7 +822,7 @@ struct FMHAConfig { decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), SubgroupLayoutPV_>; - template + template static int run(const Options &options) { // // Run examples @@ -685,7 +850,10 @@ struct FMHAConfig { using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); - using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideO{})); + using TensorO = conditional_t; // Mainloop using MainloopDispatchPolicy = cutlass::fmha::XeDefault; @@ -701,29 +869,39 @@ struct FMHAConfig { CollectiveMainloop, TileShapeOutput, TensorO, - GmemTiledCopyO + conditional_t >; static_assert(!(persistent & Causal), "persistent SDPA kernel not support Causal yet"); - using FMHAKernel = conditional_t, - cutlass::fmha::kernel::XeFMHAFwdDynamicSplitKernel< + using FMHAKernel = conditional_t, + cutlass::fmha::kernel::XeFMHAFwdPersistentKernel< ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler>, - cutlass::fmha::kernel::XeFMHAFwdKernel< - ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + conditional_t, + cutlass::fmha::kernel::XeFMHAFwdKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + > >; - ExampleRunner runner; + using ReduceSplitKernel = cutlass::reduction::kernel::ReduceSplitK< + ProblemShapeType, cutlass::fmha::kernel::XeReduceSplitKTileScheduler, FMHAKernel>; + + ExampleRunner runner; CUTLASS_CHECK(runner.run(options, hw_info)); + return 0; } static int run(const Options &options) { if (options.varlen) { - return run(options); + return splitkv ? run(options) : + run(options); } else { - return persistent ? 
run(options) : - run(options); + return persistent ? run(options) : + (splitkv ? run(options) : + run(options)); } } };
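
For reference, the per-row combination that the new ReduceSplitK kernel applies across KV splits is a standard log-sum-exp merge: each split stores a partial output that has already been divided by its local exp-sum (as the forward epilogue does), together with its local max logit and exp-sum; the reduction rescales every split to the global max, re-accumulates, and renormalizes. The following is a minimal host-side sketch of that rule only, not the kernel itself; combine_splits and all variable names are illustrative, and exp2 is used on the assumption that logits are kept in log2 space, consistent with the kernel's use of sycl::native::exp2.

#include <cmath>
#include <cstdio>
#include <vector>

// o[i]          : split i's partial output, already divided by its local exp-sum
// max_logit[i]  : split i's row-wise max logit
// exp_sum[i]    : split i's row-wise exp-sum
std::vector<float> combine_splits(const std::vector<std::vector<float>>& o,
                                  const std::vector<float>& max_logit,
                                  const std::vector<float>& exp_sum) {
  int num_splits = (int)o.size();
  int head_dim   = (int)o[0].size();

  // Global max over all splits.
  float global_max = max_logit[0];
  for (int i = 1; i < num_splits; ++i) global_max = std::fmax(global_max, max_logit[i]);

  std::vector<float> out(head_dim, 0.f);
  float global_sum = 0.f;
  for (int i = 0; i < num_splits; ++i) {
    float rescale = std::exp2f(max_logit[i] - global_max);   // <= 1 by construction
    global_sum += exp_sum[i] * rescale;
    for (int d = 0; d < head_dim; ++d)
      out[d] += o[i][d] * exp_sum[i] * rescale;              // undo the per-split division first
  }
  // Final normalization by the combined exp-sum.
  for (int d = 0; d < head_dim; ++d) out[d] /= global_sum;
  return out;
}

int main() {
  // Two splits over a toy head_dim of 4; the merged result matches a single-pass softmax
  // computed over the concatenated KV range.
  std::vector<std::vector<float>> o = {{0.1f, 0.2f, 0.3f, 0.4f}, {0.5f, 0.4f, 0.3f, 0.2f}};
  std::vector<float> max_logit = {2.0f, 3.5f};
  std::vector<float> exp_sum   = {4.0f, 6.0f};
  auto out = combine_splits(o, max_logit, exp_sum);
  for (float v : out) std::printf("%f ", v);
  std::printf("\n");
}

Assuming the CMake target added above, the new path can be exercised with something like ./06_xe_fmha_fwd_decode_splitkv_bfloat16_t_hdim128 --num_kv_splits=4 --seq_len_kv=8192 (flag name per the runner's option parsing; the exact binary name depends on the configured input type and head dimension).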