diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index b7c400a63a..522809a819 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -54,15 +54,20 @@ using namespace cute; template // Optional TiledCopy for loading V + class TiledCopyV_ = void, // Optional TiledCopy for loading V + class TiledCopyK_cache_ = void, + class TiledCopyV_cache_ = void> // Optional TiledCopy for loading V_cache struct FMHAFwdMainloop { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -70,14 +75,18 @@ struct FMHAFwdMainloop { ///////////////////////////////////////////////////////////////////////////////////////////////// template -struct FMHAFwdMainloop, CausalMask_, + class TensorK_cache_, class TensorV_cache_, + class TiledCopyQ_, class TiledCopyK_, class TiledCopyV_, + class TiledCopyK_cache_, class TiledCopyV_cache_> +struct FMHAFwdMainloop, CausalMask_, PagedKV_, TiledMMAQK_, TiledMMAPV_, VTiles_, TensorQ_, TensorK_, TensorV_, - TiledCopyQ_, TiledCopyK_, TiledCopyV_> { + TensorK_cache_, TensorV_cache_, + TiledCopyQ_, TiledCopyK_, TiledCopyV_, + TiledCopyK_cache_, TiledCopyV_cache_> { // // Type Aliases // @@ -100,6 +109,12 @@ struct FMHAFwdMainloop, CausalMask_, using TiledCopyQ = conditional_t, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>; using TiledCopyK = conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>; using TiledCopyV = conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>; + using TensorK_cache = TensorK_cache_; + using TensorV_cache = TensorV_cache_; + using TensorK_cache2D = decltype(TensorK_cache_{}(append>(make_coord(_,_),0))); + using TensorV_cache2D = decltype(TensorV_cache_{}(append>(make_coord(_,_),0))); + using TiledCopyK_cache = conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK_cache2D{})), TiledCopyK_cache_>; + using TiledCopyV_cache = conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV_cache2D{})), TiledCopyV_cache_>; // TODO: static_asserts on TiledMMAPV here... @@ -127,10 +142,14 @@ struct FMHAFwdMainloop, CausalMask_, using ElementA = typename TiledMMAPV::ValTypeD; static constexpr bool CausalMask = CausalMask_; + static constexpr bool PagedKV = PagedKV_; // User-facing arguments struct Arguments { ElementS const scale; + int const* ptr_page_table = nullptr; + int page_size = 0; + int const* num_pages_per_seq = nullptr; }; // Kernel-facing parameters @@ -151,7 +170,7 @@ struct FMHAFwdMainloop, CausalMask_, Params to_underlying_arguments(Arguments const &args, void * /* workspace */) { constexpr double kLog2e = 1.4426950408889634074; // log_2(e) ElementS val = args.scale * static_cast(kLog2e); - return Params{val}; + return Params{val, args.ptr_page_table, args.page_size, args.num_pages_per_seq}; } CUTLASS_HOST_DEVICE static @@ -159,6 +178,20 @@ struct FMHAFwdMainloop, CausalMask_, return true; } + CUTLASS_DEVICE + int get_physical_k_tile(int K, int l_coord, int seq_len_kv_cache) { + int next_page_logical_idx = K * get<1>(TileShapeQK{}) / params.page_size; + // get<1>(TileShapeQK{}) usually smaller than page_size. + // assuming page_size is multiple of get<1>(TileShapeQK{}) + int tiles_per_page = params.page_size / get<1>(TileShapeQK{}); + int batch_offset = params.num_pages_per_seq ? 
params.num_pages_per_seq[l_coord] : l_coord * (seq_len_kv_cache / params.page_size); + + return params.ptr_page_table[ + batch_offset + + next_page_logical_idx] * tiles_per_page + + K % tiles_per_page; + } + template CUTLASS_DEVICE void @@ -174,8 +207,12 @@ struct FMHAFwdMainloop, CausalMask_, int total_blk, // Total # of K blocks int thr_id, int seq_len, + int seq_len_kv_cache, + int l_coord, int full_tile_offset, - int discard_seq_coord) { + int discard_seq_coord, + TensorK_cache2D const& K_cache_2D = TensorK_cache2D{}, + TensorV_cache2D const& V_cache_2D = TensorV_cache2D{}) { using namespace sycl::ext::oneapi::this_work_item; // Short dimension names: @@ -193,6 +230,8 @@ struct FMHAFwdMainloop, CausalMask_, Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d) Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) + Tensor cK_cache = make_identity_tensor(K_cache_2D.shape()); // (k,d) + Tensor cV_cache = make_identity_tensor(V_cache_2D.shape()); // (v,k) Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k) /* Partition global tensors into workgroup tiles */ @@ -201,10 +240,16 @@ struct FMHAFwdMainloop, CausalMask_, Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K) Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_,_,0), Step{}); // (v,k,VV,K) + Tensor gK_cache = local_tile(cK_cache, TileShapeQK{}, make_coord(_,_,_), Step{}); // (k,d,K,D) + Tensor gV_cache = local_tile(cV_cache, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K) + Tensor gV_cache_split = local_tile(gV_cache, TileShapePV{}, make_coord(_,_,0), Step{}); // (v,k,VV,K) + /* Create global -> register copies */ TiledCopyQ copy_q{Q_2D}; TiledCopyK copy_k{K_2D}; TiledCopyV copy_v{V_2D}; + TiledCopyK_cache copy_k_cache{K_cache_2D}; + TiledCopyV_cache copy_v_cache{V_cache_2D}; /* Create MMAs */ TiledMMAQK mma_qk{}; @@ -214,6 +259,8 @@ struct FMHAFwdMainloop, CausalMask_, auto thr_copy_q = copy_q.get_slice(thr_id); auto thr_copy_k = copy_k.get_slice(thr_id); auto thr_copy_v = copy_v.get_slice(thr_id); + auto thr_copy_k_cache = copy_k_cache.get_slice(thr_id); + auto thr_copy_v_cache = copy_v_cache.get_slice(thr_id); auto thr_mma_qk = mma_qk.get_slice(thr_id); auto thr_mma_pv = mma_pv.get_slice(thr_id); @@ -221,6 +268,8 @@ struct FMHAFwdMainloop, CausalMask_, auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D) auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + auto tKgK_cache = thr_copy_k_cache.partition_S(gK_cache); + auto tVgV_cache = thr_copy_v_cache.partition_S(gV_cache_split); /* Create register fragments for MMA and copies */ auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0)); @@ -239,11 +288,15 @@ struct FMHAFwdMainloop, CausalMask_, auto prefetch_q = make_block_2d_prefetch(copy_q); auto prefetch_k = make_block_2d_prefetch(copy_k); auto prefetch_v = make_block_2d_prefetch(tile_shape_v, V_2D); + auto prefetch_k_cache = make_block_2d_prefetch(copy_k_cache); + auto prefetch_v_cache = make_block_2d_prefetch(tile_shape_v, V_cache_2D); /* Partition global tensors for prefetch */ auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV); + auto pKgK_cache = prefetch_k_cache.get_slice(thr_id).partition_S(gK_cache); + auto pVgV_cache = prefetch_v_cache.get_slice(thr_id).partition_S(gV_cache); // 
------ // Kernel @@ -251,6 +304,7 @@ struct FMHAFwdMainloop, CausalMask_, /* Initialization steps for first block: Q/K prefetch, O init */ /* TODO: limit D prefetch for large head size, and reorder K prefetches */ + int kblocks_cache = ceil_div(seq_len_kv_cache, get<1>(TileShapeQK{})); if (blk_k0 == 0) { for (int D = 0; D < size<3>(pQgQ); D++) { prefetch(prefetch_q, pQgQ(_,_,_,D)); @@ -259,7 +313,16 @@ struct FMHAFwdMainloop, CausalMask_, for (int D = 0; D < size<4>(pKgK); D++) { CUTLASS_PRAGMA_UNROLL for (int K = 0; K < Stages; K++) { - prefetch(prefetch_k, pKgK(_,_,_,K,D)); + if (K < kblocks_cache) { + if constexpr (PagedKV) { + int physical_K_tile = get_physical_k_tile(K, l_coord, seq_len_kv_cache); + prefetch(prefetch_k_cache, pKgK_cache(_,_,_,physical_K_tile,D)); + } else { + prefetch(prefetch_k_cache, pKgK_cache(_,_,_,K,D)); + } + } else { + prefetch(prefetch_k, pKgK(_,_,_,K - kblocks_cache,D)); + } } } @@ -276,11 +339,23 @@ struct FMHAFwdMainloop, CausalMask_, /* Split barrier to keep threads together */ barrier_arrive(ScopeWorkgroup); + bool is_cache = K < kblocks_cache; + int physical_K_tile = K; + if constexpr (PagedKV) { + if (is_cache) { + physical_K_tile = get_physical_k_tile(K, l_coord, seq_len_kv_cache); + } + } + /* GEMM 1: S = K * Q */ clear(tSrS); /* TODO: fuse w/ initial gemm call */ for (int D = 0; D < size<4>(tKgK); D++) { copy(copy_q, tQgQ(_,_,_,D), tQrQ); - copy(copy_k, tKgK(_,_,_,K,D), tKrK); + if (is_cache) { + copy(copy_k_cache, tKgK_cache(_,_,_,physical_K_tile,D), tKrK); + } else { + copy(copy_k, tKgK(_,_,_,K - kblocks_cache,D), tKrK); + } reorder(tQrQ, tSrQ); reorder(tKrK, tSrK); @@ -289,7 +364,11 @@ struct FMHAFwdMainloop, CausalMask_, } /* V prefetch for GEMM 2 */ - prefetch(prefetch_v, pVgV(_,_,_,K)); + if (is_cache) { + prefetch(prefetch_v_cache, pVgV_cache(_,_,_,physical_K_tile)); + } else { + prefetch(prefetch_v, pVgV(_,_,_,K-kblocks_cache)); + } /* Causal masking */ if constexpr (CausalMask) { @@ -302,7 +381,7 @@ struct FMHAFwdMainloop, CausalMask_, for (int i = 0; i < tSrS.size(); ++i) { int row_idx = get<0>(cS_thread(i)); int col_idx = get<1>(cS_thread(i)); - if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + if (col_idx - seq_len_kv_cache - full_tile_offset > row_idx - discard_seq_coord) { tSrS(i) = ElementS(-INFINITY); } } @@ -311,7 +390,11 @@ struct FMHAFwdMainloop, CausalMask_, /* k masking for remainder tiles */ if (check_remainder_k && K == total_blk - 1) { FragSRow k_rem_mask; - int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0]; + int k_val; + if (is_cache) k_val = get<0>(tKgK_cache(0,0,0,physical_K_tile,0)); + else k_val = get<0>(tKgK(0,0,0,K-kblocks_cache,0)) + kblocks_cache * get<1>(TileShapeQK{}); // Adjust global index from relative + + int k = k_val + get_sub_group().get_local_id()[0]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { k_rem_mask(i) = (k < seq_len) ? 
ElementS(sycl::nan(0u)) : ElementS(-INFINITY); @@ -329,14 +412,31 @@ struct FMHAFwdMainloop, CausalMask_, /* GEMM 2: A += P * V, split in v dimension */ CUTLASS_PRAGMA_UNROLL for (int VV = 0; VV < VTiles; VV++) { - copy(copy_v, tVgV(_,_,_,VV,K), tVrV); + if (is_cache) { + copy(copy_v_cache, tVgV_cache(_,_,_,VV,physical_K_tile), tVrV); + } else { + copy(copy_v, tVgV(_,_,_,VV,K-kblocks_cache), tVrV); + } reorder(tVrV, tArV); cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV)); } /* K prefetch */ for (int D = 0; D < size<4>(pKgK); D++) { - prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D)); + int K_next = K + Stages; + bool is_cache_next = K_next < kblocks_cache; + int physical_K_next = K_next; + if constexpr (PagedKV) { + if (is_cache_next) { + physical_K_next = get_physical_k_tile(K_next, l_coord, seq_len_kv_cache); + } + } + + if (is_cache_next) { + prefetch(prefetch_k_cache, pKgK_cache(_,_,_,physical_K_next,D)); + } else { + prefetch(prefetch_k, pKgK(_,_,_,K_next-kblocks_cache,D)); + } } barrier_wait(ScopeWorkgroup); diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index f5905f746a..dc17f48912 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -52,7 +52,7 @@ struct FMHAProblemShape { using SeqLenType = cute::conditional_t; int batch; int num_heads_q, num_heads_kv; - SeqLenType seq_len_qo, seq_len_kv; + SeqLenType seq_len_qo, seq_len_kv, seq_len_kv_cache; int head_size_qk, head_size_vo; }; @@ -126,6 +126,10 @@ class XeFMHAFwdKernel { StrideV dV; ElementO *O; StrideO dO; + const ElementK *K_cache; + StrideK dK_cache{}; + const ElementV *V_cache; + StrideV dV_cache{}; }; using KernelParams = KernelArguments; @@ -174,11 +178,11 @@ class XeFMHAFwdKernel { static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } CUTLASS_DEVICE - Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { if constexpr (is_var_len) { - return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}, batch); + return cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv, problem_shape.seq_len_kv_cache}, batch); } else { - return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv, problem_shape.seq_len_kv_cache}; } } @@ -211,7 +215,7 @@ class XeFMHAFwdKernel { int head = head_q / head_group_q; auto sequence_length_shape = get_sequence_length_shape(s, idx_b); - auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; auto offset = cute::min(seq_len_qo, seq_len_kv); @@ -220,18 +224,24 @@ class XeFMHAFwdKernel { int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; - const int seq_len = CollectiveMainloop::CausalMask ? full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : seq_len_kv; + const int seq_len_new = CollectiveMainloop::CausalMask ? 
full_tile_offset + cute::min(seq_len_kv, seq_coord - discard_seq_coord) + q_sg_tile : seq_len_kv; + const int seq_len = seq_len_new + seq_len_kv_cache; const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_k_cache = 0, offset_v_cache = 0; if constexpr (is_var_len) { - int group_heads_q = s.num_heads_q / s.num_heads_kv; auto qo_cumulative = s.seq_len_qo.cumulative_length; auto kv_cumulative = s.seq_len_kv.cumulative_length; offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; offset_k = s.num_heads_kv * s.head_size_qk * kv_cumulative[idx_b]; offset_v = s.num_heads_kv * s.head_size_vo * kv_cumulative[idx_b]; offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + if (s.seq_len_kv_cache.cumulative_length) { + auto kv_cumulative_cache = s.seq_len_kv_cache.cumulative_length; + offset_k_cache = s.num_heads_kv * s.head_size_qk * kv_cumulative_cache[idx_b]; + offset_v_cache = s.num_heads_kv * s.head_size_vo * kv_cumulative_cache[idx_b]; + } } auto batch_dim = is_var_len ? 1 : s.batch; @@ -240,19 +250,28 @@ class XeFMHAFwdKernel { auto shape_V = make_shape(s.head_size_vo, seq_len_kv, s.num_heads_kv, batch_dim); auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_kv, batch_dim); + auto shape_K_cache = make_shape(seq_len_kv_cache, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V_cache = make_shape(s.head_size_vo, seq_len_kv_cache, s.num_heads_kv, batch_dim); + auto dcQ = const_cast(p.Q + offset_q); auto dcK = const_cast(p.K + offset_k); auto dcV = const_cast(p.V + offset_v); + auto dcK_cache = const_cast(p.K_cache + offset_k_cache); + auto dcV_cache = const_cast(p.V_cache + offset_v_cache); auto ptrO = p.O + offset_o; auto stride_q = is_var_len ? cutlass::make_cute_packed_stride(StrideQ{}, shape_Q) : p.dQ; auto stride_k = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K) : p.dK; auto stride_v = is_var_len ? cutlass::make_cute_packed_stride(StrideV{}, shape_V) : p.dV; auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_k_cache = is_var_len ? cutlass::make_cute_packed_stride(StrideK{}, shape_K_cache) : p.dK_cache; + auto stride_v_cache = is_var_len ? 
cutlass::make_cute_packed_stride(StrideV{}, shape_V_cache) : p.dV_cache; Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, stride_q)); Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, stride_k)); Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, stride_v)); + Tensor K_cache = make_tensor(make_gmem_ptr(dcK_cache), make_layout(shape_K_cache, stride_k_cache)); + Tensor V_cache = make_tensor(make_gmem_ptr(dcV_cache), make_layout(shape_V_cache, stride_v_cache)); Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); @@ -268,8 +287,11 @@ class XeFMHAFwdKernel { V(_,_,head,l_coord), tArA, tA_max, tA_sum, blk_qv, 0, k_blocks, k_blocks, - thr_id, seq_len, - full_tile_offset, discard_seq_coord); + thr_id, seq_len, seq_len_kv_cache, idx_b, + full_tile_offset, discard_seq_coord, + K_cache(_,_,head,l_coord), + V_cache(_,_,head,l_coord)); + if constexpr (!is_empty_v && !is_empty_v) { sycl::group_barrier(get_work_group<3>()); } @@ -359,6 +381,10 @@ class XeFMHAFwdDynamicSplitKernel { StrideV dV; ElementO *O; StrideO dO; + const ElementK *K_cache = nullptr; + StrideK dK_cache{}; + const ElementV *V_cache = nullptr; + StrideV dV_cache{}; }; using KernelParams = KernelArguments; @@ -538,6 +564,13 @@ class XeFMHAFwdDynamicSplitKernel { Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, p.dV)); // (v,k,h,b) Tensor O = make_tensor(make_gmem_ptr(p.O), make_layout(shape_O, p.dO)); // (q,v,h,b) + auto shape_K_cache = make_shape(s.seq_len_kv_cache, s.head_size_qk, s.num_heads_kv, s.batch); + auto shape_V_cache = make_shape(s.head_size_vo, s.seq_len_kv_cache, s.num_heads_kv, s.batch); + auto dcK_cache = const_cast(p.K_cache); + auto dcV_cache = const_cast(p.V_cache); + Tensor K_cache = make_tensor(make_gmem_ptr(dcK_cache), make_layout(shape_K_cache, p.dK_cache)); + Tensor V_cache = make_tensor(make_gmem_ptr(dcV_cache), make_layout(shape_V_cache, p.dV_cache)); + // O accumulator types FragA tArA; FragARow tA_max, tA_sum; @@ -590,7 +623,7 @@ class XeFMHAFwdDynamicSplitKernel { V(_,_,head_kv,idx_b), tArA, tA_max, tA_sum, blk_qv, start_blk, end_blk, local_k_blocks, - thr_id, s.seq_len_kv, /*for causal*/0, 0); + thr_id, s.seq_len_kv, 0, 0, 0, 0); // partition id of start batch head id in current wg int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks); diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index b1e9f0284d..8d95167a62 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -60,14 +60,15 @@ struct Options { bool error; bool is_causal; bool varlen = false; + bool use_paged_kv = false; std::string scheduler; - int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; + int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, page_size, head_size_qk, head_size_vo, iterations, verify; float softmax_scale; Options() - : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), - seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + : help(false), error(false), is_causal(false), varlen(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), + seq_len_kv(512), seq_len_kv_cache(0), page_size(128), 
head_size_vo(128), iterations(100), softmax_scale(1.f), verify(1), scheduler("Individual") {} // Parses the command line void parse(int argc, char const **args) { @@ -98,6 +99,7 @@ struct Options { cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, 512); + cmd.get_cmd_line_argument("seq_len_kv_cache", seq_len_kv_cache, 0); #endif #ifdef DECODE cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 1); @@ -107,6 +109,21 @@ struct Options { cmd.get_cmd_line_argument("head_size_vo", head_size_vo, HEAD_DIM); cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("verify", verify, 1); + + if (cmd.check_cmd_line_flag("use_paged_kv")) { + use_paged_kv = true; + cmd.get_cmd_line_argument("page_size", page_size, 128); + + if (page_size % 128 != 0) { + std::cerr << "Invalid: page_size must be a multiple of 128" << std::endl; + return; + } + if (seq_len_kv_cache % page_size != 0) { + std::cerr << "Invalid: seq_len_kv_cache must be divisible by page_size" << std::endl; + return; + } + } softmax_scale = 1 / sqrt(static_cast(head_size_qk)); } @@ -125,10 +142,13 @@ struct Options { << " --num_heads_kv= Sets the Number of Attention Heads for Query input in the Multi-Head Self Attention module\n" << " --seq_len_qo= Sets the Sequence length of the Query input in Multi-Head Self Attention module\n" << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" + << " --seq_len_kv_cache= Sets the Sequence length of the cached Key-Value pair in Multi-Head Self Attention module\n" + << " --use_paged_kv Use paged (non-contiguous) KV cache. Default is contiguous KV Cache\n" + << " --page_size= Block size for paged KV cache. 
Default is 128\n" << " --head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" - << " --iterations= Iterations\n\n"; - + << " --iterations= Iterations\n" + << " --verify= Specify whether to verify.\n\n"; return out; } }; @@ -191,19 +211,32 @@ template struct ExampleRunner { StrideQ stride_Q; StrideK stride_K; StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; StrideO stride_O; uint64_t seed = 0; cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_K_cache; + cutlass::DeviceAllocation block_V_cache; cutlass::DeviceAllocation block_O; cutlass::DeviceAllocation block_ref_O; std::vector cumulative_seqlen_q; std::vector cumulative_seqlen_kv; + std::vector cumulative_seqlen_kv_cache; cutlass::DeviceAllocation device_cumulative_seqlen_q; cutlass::DeviceAllocation device_cumulative_seqlen_kv; + cutlass::DeviceAllocation device_cumulative_seqlen_kv_cache; + + struct PagedKVParams { + cutlass::DeviceAllocation page_table; + int page_size = 0; + cutlass::DeviceAllocation num_pages_per_seq; + }; + PagedKVParams paged_kv_cache; // // Methods @@ -219,12 +252,13 @@ template struct ExampleRunner { std::mt19937 rng(0x202305151552ull); std::normal_distribution dist_q(get<3>(problem_size), get<3>(problem_size) / 2); std::normal_distribution dist_kv(get<4>(problem_size), get<4>(problem_size) / 2); + std::normal_distribution dist_kv_cache(get<5>(problem_size), get<5>(problem_size) / 2); // Use Cacheline Size to calculate alignment constexpr int cacheline_bytes = 64; constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements - + constexpr int AlignmentKVCache = 128; //Page size must be a multiple of 128 auto generate_positive_int = [](auto& dist, auto& gen) { int result = 0; do { @@ -235,30 +269,38 @@ template struct ExampleRunner { cumulative_seqlen_q = {0}; cumulative_seqlen_kv = {0}; + cumulative_seqlen_kv_cache = {0}; int total_seqlen_q = 0; int total_seqlen_kv = 0; + int total_seqlen_kv_cache = 0; int max_seqlen_q = 0; int max_seqlen_kv = 0; + int max_seqlen_kv_cache = 0; for (int i = 0; i < num_batches; i++) { int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); int seqlen_kv = cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + int seqlen_kv_cache = get<5>(problem_size) == 0 ? 
0 : cutlass::round_up(generate_positive_int(dist_kv_cache, rng), AlignmentKVCache); total_seqlen_q += seqlen_q; total_seqlen_kv += seqlen_kv; + total_seqlen_kv_cache += seqlen_kv_cache; max_seqlen_q = std::max(max_seqlen_q, seqlen_q); max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + max_seqlen_kv_cache = std::max(max_seqlen_kv_cache, seqlen_kv_cache); cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + cumulative_seqlen_kv_cache.push_back(cumulative_seqlen_kv_cache.back() + seqlen_kv_cache); } ProblemShape problem_size_for_init = problem_size; get<0>(problem_size_for_init) = 1; get<3>(problem_size_for_init) = total_seqlen_q; get<4>(problem_size_for_init) = total_seqlen_kv; + get<5>(problem_size_for_init) = total_seqlen_kv_cache; ProblemShapeType problem_size_for_launch; problem_size_for_launch.batch = get<0>(problem_size); @@ -266,21 +308,22 @@ template struct ExampleRunner { problem_size_for_launch.num_heads_kv = get<2>(problem_size); problem_size_for_launch.seq_len_qo = cutlass::fmha::collective::VariableLength{max_seqlen_q}; problem_size_for_launch.seq_len_kv = cutlass::fmha::collective::VariableLength{max_seqlen_kv}; - problem_size_for_launch.head_size_qk = get<5>(problem_size); - problem_size_for_launch.head_size_vo = get<6>(problem_size); + problem_size_for_launch.seq_len_kv_cache = cutlass::fmha::collective::VariableLength{max_seqlen_kv_cache}; + problem_size_for_launch.head_size_qk = get<6>(problem_size); + problem_size_for_launch.head_size_vo = get<7>(problem_size); return cute::make_tuple(problem_size_for_init, problem_size_for_launch); } - - bool verify(ProblemShapeType shape, bool is_causal) { if constexpr (isVarLen) { int max_seq_len_q = shape.seq_len_qo; int max_seq_len_kv = shape.seq_len_kv; + int max_seq_len_kv_cache = shape.seq_len_kv_cache; shape.seq_len_qo = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; shape.seq_len_kv = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + shape.seq_len_kv_cache = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; } auto batch = shape.batch; @@ -288,50 +331,124 @@ template struct ExampleRunner { auto num_heads_kv = shape.num_heads_kv; auto head_size_qk = shape.head_size_qk; auto head_size_vo = shape.head_size_vo; - int seq_len_qo, seq_len_kv; + int seq_len_qo, seq_len_kv, seq_len_kv_cache; auto block_Q_ = in_memory(block_Q); auto block_K_ = in_memory(block_K); auto block_V_ = in_memory(block_V); + auto block_K_cache_ = in_memory(block_K_cache); + auto block_V_cache_ = in_memory(block_V_cache); using ElementV_ = std::remove_pointer_t; + using ElementK_ = std::remove_pointer_t; int offset_q = 0; int offset_k = 0; int offset_v = 0; + int offset_k_cache = 0; + int offset_v_cache = 0; int offset_o = 0; + std::vector page_table_host; + std::vector num_pages_per_seq_host; + if (paged_kv_cache.page_size > 0) { + page_table_host.resize(paged_kv_cache.page_table.size()); + compat::memcpy(page_table_host.data(), paged_kv_cache.page_table.get(), page_table_host.size() * sizeof(int)); + num_pages_per_seq_host.resize(paged_kv_cache.num_pages_per_seq.size()); + compat::memcpy(num_pages_per_seq_host.data(), paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq_host.size() * sizeof(int)); + compat::wait(); + } + // loop over the batch dimension to compute the output // to avoid the risk of running out of device memory int 
q_group_size = num_heads_q/num_heads_kv; for (int b = 0; b < batch; b++) { if constexpr (isVarLen) { - auto logical_seq_shape = cutlass::fmha::collective::apply_variable_length(make_shape(shape.seq_len_qo, shape.seq_len_kv), b); + auto logical_seq_shape = cutlass::fmha::collective::apply_variable_length(make_shape(shape.seq_len_qo, shape.seq_len_kv, shape.seq_len_kv_cache), b); seq_len_qo = get<0>(logical_seq_shape); seq_len_kv = get<1>(logical_seq_shape); + seq_len_kv_cache = get<2>(logical_seq_shape); } else { seq_len_qo = shape.seq_len_qo; seq_len_kv = shape.seq_len_kv; + seq_len_kv_cache = shape.seq_len_kv_cache; } + int seq_len_kv_total = seq_len_kv + seq_len_kv_cache; int kv_group_update=1; for (int h = 0; h < num_heads_q; h++) { cutlass::DeviceAllocation block_S; - block_S.reset(seq_len_qo * seq_len_kv); + block_S.reset(seq_len_qo * seq_len_kv_total); + + ElementK_* k_ptr; + ElementV_* v_ptr; + cutlass::DeviceAllocation block_K_concat; + cutlass::DeviceAllocation block_V_concat; + + if (seq_len_kv_cache > 0) { + block_K_concat.reset(head_size_qk * seq_len_kv_total); + block_V_concat.reset(seq_len_kv_total * head_size_vo); + + if (paged_kv_cache.page_size > 0) { + int page_size = paged_kv_cache.page_size; + int start_page_idx = isVarLen ? num_pages_per_seq_host[b] : b * (seq_len_kv_cache / page_size); + int num_pages = ceil_div(seq_len_kv_cache, page_size); + + for (int i = 0; i < num_pages; ++i) { + int physical_page_id = page_table_host[start_page_idx + i]; + int current_copy_len = std::min(page_size, seq_len_kv_cache - i * page_size); + + compat::memcpy( + block_K_concat.get() + head_size_qk * i * page_size, + block_K_cache_.get() + offset_k_cache + head_size_qk * physical_page_id * page_size, + head_size_qk * current_copy_len); + + compat::memcpy( + block_V_concat.get() + i * page_size * head_size_vo, + block_V_cache_.get() + offset_v_cache + physical_page_id * page_size * head_size_vo, + current_copy_len * head_size_vo); + } + } else { + compat::memcpy( + block_K_concat.get(), + block_K_cache_.get() + offset_k_cache, + head_size_qk * seq_len_kv_cache); + compat::memcpy( + block_V_concat.get(), + block_V_cache_.get() + offset_v_cache, + seq_len_kv_cache * head_size_vo); + } + + compat::memcpy( + block_K_concat.get() + head_size_qk * seq_len_kv_cache, + block_K_.get() + offset_k, + head_size_qk * seq_len_kv); + + compat::memcpy( + block_V_concat.get() + seq_len_kv_cache * head_size_vo, + block_V_.get() + offset_v, + seq_len_kv * head_size_vo); + + k_ptr = block_K_concat.get(); + v_ptr = block_V_concat.get(); + } else { + k_ptr = block_K_.get() + offset_k; + v_ptr = block_V_.get() + offset_v; + } cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); - cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); - cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); - cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total})); + cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, head_size_vo})); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); - cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, + cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q, cutlass::ComplexTransform::kNone, ref_K, 
cutlass::ComplexTransform::kNone, 0.f, ref_S, ref_S, ElementS(0), 1, // batch_count seq_len_qo * head_size_qk, // batch_stride_Q - seq_len_kv * head_size_qk, // batch_stride_K - seq_len_qo * seq_len_kv, // batch_stride_S - seq_len_qo * seq_len_kv // batch_stride_S + seq_len_kv_total * head_size_qk, // batch_stride_K + seq_len_qo * seq_len_kv_total, // batch_stride_S + seq_len_qo * seq_len_kv_total // batch_stride_S ); compat::wait(); @@ -347,9 +464,9 @@ template struct ExampleRunner { if (is_causal) { // apply mask to S for (int row = 0; row < seq_len_qo; row++) { - for (int col = 0; col < seq_len_kv; col++) { - if ((col - full_tile_offset) > (row - discard_seq_coord)) - host_S[col + row * seq_len_kv] = ElementS{-INFINITY}; + for (int col = seq_len_kv_cache; col < seq_len_kv_total; col++) { + if ((col - seq_len_kv_cache - full_tile_offset) > (row - discard_seq_coord)) + host_S[col + row * seq_len_kv_total] = ElementS{-INFINITY}; } } } @@ -357,10 +474,10 @@ template struct ExampleRunner { // compute max element per row of S std::vector max_vec(seq_len_qo, ElementS{-INFINITY}); for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int max_idx = row; max_vec[max_idx] = host_S[idx++]; - for (int col = 1; col < seq_len_kv; col++, idx++) { + for (int col = 1; col < seq_len_kv_total; col++, idx++) { if (max_vec[max_idx] < host_S[idx]) max_vec[max_idx] = host_S[idx]; } @@ -368,9 +485,9 @@ template struct ExampleRunner { // compute exp of S for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int max_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { /* FIXME: use softmax_scale instead of assuming its value here */ host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast((head_size_qk)))); } @@ -379,16 +496,16 @@ template struct ExampleRunner { // compute sum per row of S std::vector sum_vec(seq_len_qo, ElementS{0}); for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int sum_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { sum_vec[sum_idx] += host_S[idx]; } // scale each row with the sum to compute softmax - idx = row * seq_len_kv; + idx = row * seq_len_kv_total; sum_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { if(is_causal && row < discard_seq_coord) { host_S[idx] = 0; } else { @@ -406,18 +523,18 @@ template struct ExampleRunner { compat::memcpy(block_P.get(), host_P.data(), host_P.size()); - cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); cutlass::DeviceAllocation block_acc; block_acc.reset(seq_len_qo * head_size_vo); cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); - cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementS{1}, ref_P, + cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementS{1}, ref_P, cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, ElementS{0}, ref_acc, ref_acc, ElementS{0}, 1, // batch_count - seq_len_qo * seq_len_kv, // batch_stride_P - seq_len_kv * head_size_vo, // batch_stride_V + seq_len_qo * 
seq_len_kv_total, // batch_stride_P + seq_len_kv_total * head_size_vo, // batch_stride_V seq_len_qo * head_size_vo, // batch_stride_O seq_len_qo * head_size_vo // batch_stride_O ); @@ -441,6 +558,8 @@ template struct ExampleRunner { if(kv_group_update % q_group_size==0) { offset_k += seq_len_kv * head_size_qk; offset_v += seq_len_kv * head_size_vo; + offset_k_cache += seq_len_kv_cache * head_size_qk; + offset_v_cache += seq_len_kv_cache * head_size_vo; } kv_group_update++; offset_o += seq_len_qo * head_size_vo; @@ -458,7 +577,7 @@ template struct ExampleRunner { /// Initialize operands to be used in the GEMM and reference GEMM ProblemShapeType initialize(const Options &options) { - auto problem_shape_in = cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo); + auto problem_shape_in = cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.seq_len_kv_cache, options.head_size_qk, options.head_size_vo); ProblemShapeType shape; decltype(problem_shape_in) problem_size; @@ -474,25 +593,71 @@ template struct ExampleRunner { shape.num_heads_kv = options.num_heads_kv; shape.seq_len_qo = options.seq_len_qo; shape.seq_len_kv = options.seq_len_kv; + shape.seq_len_kv_cache = options.seq_len_kv_cache; shape.head_size_qk = options.head_size_qk; shape.head_size_vo = options.head_size_vo; } - auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_size; - stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch)); - stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch)); - stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch)); - stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch)); + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_size; + auto shape_Q = cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch); + auto shape_K = cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch); + auto shape_V = cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch); + auto shape_K_cache = cute::make_shape(seq_len_kv_cache, head_size_qk, num_heads_kv, batch); + auto shape_V_cache = cute::make_shape(head_size_vo, seq_len_kv_cache, num_heads_kv, batch); + auto shape_O = cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch); + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, shape_Q); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, shape_K); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, shape_V); + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_K_cache); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_V_cache); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, shape_O); block_Q.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_qk); block_K.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_qk); block_V.reset(static_cast(batch) * num_heads_kv * seq_len_kv * head_size_vo); + block_K_cache.reset(static_cast(batch) * num_heads_kv * seq_len_kv_cache * head_size_qk); + block_V_cache.reset(static_cast(batch) * num_heads_kv * seq_len_kv_cache * head_size_vo); 
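    // (Editorial note, not part of the patch.) The block_K_cache / block_V_cache allocations above
    // act as the physical page pool when --use_paged_kv is set: the page table built a few lines
    // below maps each sequence's logical page i to a shuffled physical page id within that
    // sequence's own slice of the pool, and num_pages_per_seq stores cumulative page counts, so
    // entry b is the first page-table index of batch b, which is exactly how the mainloop's
    // get_physical_k_tile interprets it on the device side.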
block_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); block_ref_O.reset(static_cast(batch) * num_heads_q * seq_len_qo * head_size_vo); + // Zero-initialize output buffer for the kernel result + // block_ref_O is fully written in verify() before being read, so no initialization needed + compat::memset(block_O.get(), 0, block_O.size() * sizeof(ElementO)); + if (options.use_paged_kv) { + paged_kv_cache.page_size = options.page_size; + std::vector num_pages_per_seq{0}; + int num_pages = 0; + for(int b = 0; b < shape.batch; b++) { + int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache; + int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size); + num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq); + num_pages += pages_per_seq; + } + paged_kv_cache.page_table.reset(num_pages); + + // initialize block table with random mapping for non-contiguous layout + std::vector page_mapping(num_pages); + for (int b = 0; b < shape.batch; ++b) { + std::vector physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]); + std::iota(physical_pages.begin(), physical_pages.end(), 0); + // shuffle physical pages + std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() }); + for (int blk = 0; blk < physical_pages.size(); ++blk) { + int logical_idx = num_pages_per_seq[b] + blk; + page_mapping[logical_idx] = physical_pages[blk]; + } + } + compat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int)); + + paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size()); + compat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int)); + } initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); + initialize_block(block_K_cache, seed + 2024); + initialize_block(block_V_cache, seed + 2025); if (!cumulative_seqlen_q.empty()) { device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); @@ -503,9 +668,16 @@ template struct ExampleRunner { device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); device_cumulative_seqlen_kv.copy_from_host(cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); } + + if (!cumulative_seqlen_kv_cache.empty()) { + device_cumulative_seqlen_kv_cache.reset(cumulative_seqlen_kv_cache.size()); + device_cumulative_seqlen_kv_cache.copy_from_host(cumulative_seqlen_kv_cache.data(), cumulative_seqlen_kv_cache.size()); + } + if constexpr (isVarLen) { shape.seq_len_qo.cumulative_length = device_cumulative_seqlen_q.get(); shape.seq_len_kv.cumulative_length = device_cumulative_seqlen_kv.get(); + shape.seq_len_kv_cache.cumulative_length = device_cumulative_seqlen_kv_cache.get(); } return shape; } @@ -550,9 +722,16 @@ template struct ExampleRunner { block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V, - block_O.get(), stride_O + block_O.get(), stride_O, + block_K_cache.get(), stride_K_cache, + block_V_cache.get(), stride_V_cache, + }, + { + options.softmax_scale, + options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr, + options.use_paged_kv ? paged_kv_cache.page_size : 0, + options.use_paged_kv ? 
paged_kv_cache.num_pages_per_seq.get() : nullptr }, - {options.softmax_scale}, {}, hw_info }; @@ -579,14 +758,17 @@ template struct ExampleRunner { compat::wait(); - // Verify that the result is correct - bool passed = verify(shape, options.is_causal); - std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + if (options.verify != 0) { + // Verify that the result is correct + bool passed = verify(shape, options.is_causal); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; - if (!passed) { - return cutlass::Status::kErrorInternal; + if (!passed) { + return cutlass::Status::kErrorInternal; + } + } else { + std::cout << "Disposition is skipped." << std::endl; } - if (options.iterations > 0) { GPU_Clock timer; timer.start(); @@ -657,7 +839,7 @@ struct FMHAConfig { decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), SubgroupLayoutPV_>; - template + template static int run(const Options &options) { // // Run examples @@ -686,14 +868,20 @@ struct FMHAConfig { using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideO{})); + using TensorK_cache = TensorK; + using TensorV_cache = TensorV; + using GmemTiledCopyK_cache = GmemTiledCopyK; + using GmemTiledCopyV_cache = GmemTiledCopyV; // Mainloop using MainloopDispatchPolicy = cutlass::fmha::XeDefault; using CollectiveMainloop = cutlass::fmha::collective::FMHAFwdMainloop< - MainloopDispatchPolicy, Causal, + MainloopDispatchPolicy, Causal, PagedKV, TiledMMAQK, TiledMMAPV, VTiles, TensorQ, TensorK, TensorV, - GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV + TensorK_cache, TensorV_cache, + GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV, + GmemTiledCopyK_cache, GmemTiledCopyV_cache >; // Epilogue @@ -719,11 +907,20 @@ struct FMHAConfig { } static int run(const Options &options) { - if (options.varlen) { - return run(options); + if (persistent) { + if (options.use_paged_kv || options.seq_len_kv_cache > 0) { + std::cerr << "Error: Persistent kernel does not support paged/cached KV cache (use_paged_kv or seq_len_kv_cache > 0)." << std::endl; + return -1; + } + return run(options); + } else if (options.use_paged_kv && !options.varlen) { + return run(options); + } else if(!options.use_paged_kv && options.varlen) { + return run(options); + } else if(!options.use_paged_kv && !options.varlen) { + return run(options); } else { - return persistent ? run(options) : - run(options); + return run(options); } } };
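Reviewer note (not part of the patch): the sketch below restates, on the host side, the logical-to-physical K-tile mapping that FMHAFwdMainloop::get_physical_k_tile implements in the mainloop hunk above. The helper name and parameter list are hypothetical; like the kernel, it assumes page_size is a multiple of the K-tile width get<1>(TileShapeQK{}).

// Host-side sketch of the paged-KV tile lookup (illustrative only).
#include <cassert>
#include <vector>

int physical_k_tile(int K,                              // logical K-tile index within the cached KV of one sequence
                    int l_coord,                        // batch / sequence index
                    int seq_len_kv_cache,               // cached KV length of this sequence
                    int tile_k,                         // get<1>(TileShapeQK{})
                    int page_size,
                    std::vector<int> const& page_table,
                    int const* num_pages_per_seq) {     // nullptr for fixed-length batches
  assert(page_size % tile_k == 0);                      // same assumption the kernel comment states
  int tiles_per_page = page_size / tile_k;
  int logical_page   = (K * tile_k) / page_size;        // which logical page this K-tile falls in
  int batch_offset   = num_pages_per_seq ? num_pages_per_seq[l_coord]                   // var-len: cumulative page base
                                         : l_coord * (seq_len_kv_cache / page_size);    // fixed: pages per sequence
  return page_table[batch_offset + logical_page] * tiles_per_page + K % tiles_per_page;
}

For example, with page_size = 128 and a 64-column K tile (tiles_per_page = 2), logical tile K = 5 lies in logical page 2; if that sequence's page-table entries are {7, 2, 9}, the lookup returns 9 * 2 + 1 = 19, i.e. the second K tile of physical page 9.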