applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp (128 changes: 114 additions & 14 deletions)
@@ -54,30 +54,39 @@ using namespace cute;

template <class DispatchPolicy_,
bool CausalMask_,
bool PagedKV_,
class TiledMMAQK_, // Tiling for Q*K GEMM
class TiledMMAPV_, // Tiling for P*V GEMM
int VTiles_, // # of tiles in V dimension
class TensorQ_, // Global Q/K/V tensors
class TensorK_,
class TensorV_,
class TensorK_cache_,
class TensorV_cache_,
class TiledCopyQ_ = void, // Optional TiledCopy for loading Q
class TiledCopyK_ = void, // Optional TiledCopy for loading K
-class TiledCopyV_ = void> // Optional TiledCopy for loading V
class TiledCopyV_ = void, // Optional TiledCopy for loading V
class TiledCopyK_cache_ = void, // Optional TiledCopy for loading K_cache
class TiledCopyV_cache_ = void> // Optional TiledCopy for loading V_cache
struct FMHAFwdMainloop {
static_assert(cutlass::detail::dependent_false<DispatchPolicy_>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <int Stages,
-bool CausalMask_,
bool CausalMask_, bool PagedKV_,
class TiledMMAQK_, class TiledMMAPV_, int VTiles_,
class TensorQ_, class TensorK_, class TensorV_,
-class TiledCopyQ_, class TiledCopyK_, class TiledCopyV_>
-struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
class TensorK_cache_, class TensorV_cache_,
class TiledCopyQ_, class TiledCopyK_, class TiledCopyV_,
class TiledCopyK_cache_, class TiledCopyV_cache_>
struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_, PagedKV_,
TiledMMAQK_, TiledMMAPV_, VTiles_,
TensorQ_, TensorK_, TensorV_,
-TiledCopyQ_, TiledCopyK_, TiledCopyV_> {
TensorK_cache_, TensorV_cache_,
TiledCopyQ_, TiledCopyK_, TiledCopyV_,
TiledCopyK_cache_, TiledCopyV_cache_> {
//
// Type Aliases
//
@@ -100,6 +109,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
using TiledCopyQ = conditional_t<is_void_v<TiledCopyQ_>, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>;
using TiledCopyK = conditional_t<is_void_v<TiledCopyK_>, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>;
using TiledCopyV = conditional_t<is_void_v<TiledCopyV_>, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>;
using TensorK_cache = TensorK_cache_;
using TensorV_cache = TensorV_cache_;
using TensorK_cache2D = decltype(TensorK_cache_{}(append<rank_v<TensorK_cache_>>(make_coord(_,_),0)));
using TensorV_cache2D = decltype(TensorV_cache_{}(append<rank_v<TensorV_cache_>>(make_coord(_,_),0)));
using TiledCopyK_cache = conditional_t<is_void_v<TiledCopyK_cache_>, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK_cache2D{})), TiledCopyK_cache_>;
using TiledCopyV_cache = conditional_t<is_void_v<TiledCopyV_cache_>, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV_cache2D{})), TiledCopyV_cache_>;
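// The *_cache2D aliases fix all trailing (batch/head) modes of the cache
// tensors at 0, leaving a 2D view analogous to TensorK2D/TensorV2D above, so
// that the default block-2D copies for the cache path can be deduced the same
// way as for K/V.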

// TODO: static_asserts on TiledMMAPV here...

@@ -127,10 +142,14 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
using ElementA = typename TiledMMAPV::ValTypeD;

static constexpr bool CausalMask = CausalMask_;
static constexpr bool PagedKV = PagedKV_;

// User-facing arguments
struct Arguments {
ElementS const scale;
int const* ptr_page_table = nullptr;
int page_size = 0;
int const* num_pages_per_seq = nullptr;
};
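// A minimal host-side sketch of how the paged-KV fields might be filled in
// (hypothetical names and sizes, for illustration only). With page_size = 64
// and seq_len_kv_cache = 256, each sequence spans 4 logical pages, and with
// num_pages_per_seq == nullptr the mainloop indexes the flat table as
// ptr_page_table[l_coord * (seq_len_kv_cache / page_size) + logical_page]:
//
//   std::vector<int> page_table = {7, 2, 11, 5,   // seq 0: logical pages 0..3
//                                  0, 9,  4, 1};  // seq 1: logical pages 0..3
//   Arguments args{softmax_scale, page_table.data(), /*page_size=*/64,
//                  /*num_pages_per_seq=*/nullptr};
//
// When sequences have different page counts, num_pages_per_seq[b] instead
// supplies sequence b's starting offset into the flat page table.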

// Kernel-facing parameters
@@ -151,14 +170,28 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
Params to_underlying_arguments(Arguments const &args, void * /* workspace */) {
constexpr double kLog2e = 1.4426950408889634074; // log_2(e)
ElementS val = args.scale * static_cast<ElementS>(kLog2e);
-return Params{val};
return Params{val, args.ptr_page_table, args.page_size, args.num_pages_per_seq};
}
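// (Context: the softmax scale is pre-multiplied by log2(e) above so the
// mainloop can evaluate e^(scale*x) as the cheaper exp2((scale*log2e)*x).)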

CUTLASS_HOST_DEVICE static
bool can_implement(Arguments const&) {
return true;
}

CUTLASS_DEVICE
int get_physical_k_tile(int K, int l_coord, int seq_len_kv_cache) {
// Map the logical k-tile index K to its physical tile index via the page
// table. get<1>(TileShapeQK{}) is usually smaller than page_size; we assume
// page_size is a multiple of get<1>(TileShapeQK{}).
int page_logical_idx = K * get<1>(TileShapeQK{}) / params.page_size;
int tiles_per_page = params.page_size / get<1>(TileShapeQK{});
int batch_offset = params.num_pages_per_seq ? params.num_pages_per_seq[l_coord] : l_coord * (seq_len_kv_cache / params.page_size);

return params.ptr_page_table[batch_offset + page_logical_idx] * tiles_per_page + K % tiles_per_page;
}
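// Worked example (hypothetical numbers): with get<1>(TileShapeQK{}) = 32 and
// page_size = 64, tiles_per_page = 2. For logical tile K = 5, the logical page
// is 5*32/64 = 2; if ptr_page_table[batch_offset + 2] = 7, the physical tile
// is 7*2 + 5%2 = 15, i.e. the second k-tile inside physical page 7.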

template <typename QVCoord>
CUTLASS_DEVICE
void
@@ -174,8 +207,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
int total_blk, // Total # of K blocks
int thr_id,
int seq_len,
int seq_len_kv_cache,
int l_coord,
int full_tile_offset,
-int discard_seq_coord) {
int discard_seq_coord,
TensorK_cache2D const& K_cache_2D = TensorK_cache2D{},
TensorV_cache2D const& V_cache_2D = TensorV_cache2D{}) {
using namespace sycl::ext::oneapi::this_work_item;

// Short dimension names:
@@ -193,6 +230,8 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d)
Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d)
Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k)
Tensor cK_cache = make_identity_tensor(K_cache_2D.shape()); // (k,d)
Tensor cV_cache = make_identity_tensor(V_cache_2D.shape()); // (v,k)
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k)

/* Partition global tensors into workgroup tiles */
@@ -201,10 +240,16 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K)
Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_,_,0), Step<X,_1,_1>{}); // (v,k,VV,K)

Tensor gK_cache = local_tile(cK_cache, TileShapeQK{}, make_coord(_,_,_), Step<X,_1,_1>{}); // (k,d,K,D)
Tensor gV_cache = local_tile(cV_cache, tile_shape_v, make_coord(get<1>(blk_qv),_)); // (v,k,K)
Tensor gV_cache_split = local_tile(gV_cache, TileShapePV{}, make_coord(_,_,0), Step<X,_1,_1>{}); // (v,k,VV,K)

/* Create global -> register copies */
TiledCopyQ copy_q{Q_2D};
TiledCopyK copy_k{K_2D};
TiledCopyV copy_v{V_2D};
TiledCopyK_cache copy_k_cache{K_cache_2D};
TiledCopyV_cache copy_v_cache{V_cache_2D};

/* Create MMAs */
TiledMMAQK mma_qk{};
@@ -214,13 +259,17 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
auto thr_copy_q = copy_q.get_slice(thr_id);
auto thr_copy_k = copy_k.get_slice(thr_id);
auto thr_copy_v = copy_v.get_slice(thr_id);
auto thr_copy_k_cache = copy_k_cache.get_slice(thr_id);
auto thr_copy_v_cache = copy_v_cache.get_slice(thr_id);
auto thr_mma_qk = mma_qk.get_slice(thr_id);
auto thr_mma_pv = mma_pv.get_slice(thr_id);

/* Partition coordinate tensors for copy */
auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D)
auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D)
auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K)
auto tKgK_cache = thr_copy_k_cache.partition_S(gK_cache);
auto tVgV_cache = thr_copy_v_cache.partition_S(gV_cache_split);

/* Create register fragments for MMA and copies */
auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_,_,0));
@@ -239,18 +288,23 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
auto prefetch_q = make_block_2d_prefetch(copy_q);
auto prefetch_k = make_block_2d_prefetch(copy_k);
auto prefetch_v = make_block_2d_prefetch<SGPerWG::value>(tile_shape_v, V_2D);
auto prefetch_k_cache = make_block_2d_prefetch(copy_k_cache);
auto prefetch_v_cache = make_block_2d_prefetch<SGPerWG::value>(tile_shape_v, V_cache_2D);

/* Partition global tensors for prefetch */
auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ);
auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK);
auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV);
auto pKgK_cache = prefetch_k_cache.get_slice(thr_id).partition_S(gK_cache);
auto pVgV_cache = prefetch_v_cache.get_slice(thr_id).partition_S(gV_cache);

// ------
// Kernel
// ------

/* Initialization steps for first block: Q/K prefetch, O init */
/* TODO: limit D prefetch for large head size, and reorder K prefetches */
int kblocks_cache = ceil_div(seq_len_kv_cache, get<1>(TileShapeQK{}));
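/* Logical K blocks [0, kblocks_cache) are served from the (possibly paged)
   KV cache; blocks [kblocks_cache, total_blk) come from the regular K/V
   tensors, so block indices into the non-cache tensors below are shifted
   down by kblocks_cache. */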
if (blk_k0 == 0) {
for (int D = 0; D < size<3>(pQgQ); D++) {
prefetch(prefetch_q, pQgQ(_,_,_,D));
@@ -259,7 +313,16 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
for (int D = 0; D < size<4>(pKgK); D++) {
CUTLASS_PRAGMA_UNROLL
for (int K = 0; K < Stages; K++) {
-prefetch(prefetch_k, pKgK(_,_,_,K,D));
if (K < kblocks_cache) {
if constexpr (PagedKV) {
int physical_K_tile = get_physical_k_tile(K, l_coord, seq_len_kv_cache);
prefetch(prefetch_k_cache, pKgK_cache(_,_,_,physical_K_tile,D));
} else {
prefetch(prefetch_k_cache, pKgK_cache(_,_,_,K,D));
}
} else {
prefetch(prefetch_k, pKgK(_,_,_,K - kblocks_cache,D));
}
}
}

@@ -276,11 +339,23 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
/* Split barrier to keep threads together */
barrier_arrive(ScopeWorkgroup);

bool is_cache = K < kblocks_cache;
int physical_K_tile = K;
if constexpr (PagedKV) {
if (is_cache) {
physical_K_tile = get_physical_k_tile(K, l_coord, seq_len_kv_cache);
}
}

/* GEMM 1: S = K * Q */
clear(tSrS); /* TODO: fuse w/ initial gemm call */
for (int D = 0; D < size<4>(tKgK); D++) {
copy(copy_q, tQgQ(_,_,_,D), tQrQ);
-copy(copy_k, tKgK(_,_,_,K,D), tKrK);
if (is_cache) {
copy(copy_k_cache, tKgK_cache(_,_,_,physical_K_tile,D), tKrK);
} else {
copy(copy_k, tKgK(_,_,_,K - kblocks_cache,D), tKrK);
}

reorder(tQrQ, tSrQ);
reorder(tKrK, tSrK);
@@ -289,7 +364,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
}

/* V prefetch for GEMM 2 */
-prefetch(prefetch_v, pVgV(_,_,_,K));
if (is_cache) {
prefetch(prefetch_v_cache, pVgV_cache(_,_,_,physical_K_tile));
} else {
prefetch(prefetch_v, pVgV(_,_,_,K-kblocks_cache));
}

/* Causal masking */
if constexpr (CausalMask) {
@@ -302,7 +381,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
for (int i = 0; i < tSrS.size(); ++i) {
int row_idx = get<0>(cS_thread(i));
int col_idx = get<1>(cS_thread(i));
-if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
if (col_idx - seq_len_kv_cache - full_tile_offset > row_idx - discard_seq_coord) {
tSrS(i) = ElementS(-INFINITY);
}
}
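/* With a KV cache prepended, cached columns (col_idx < seq_len_kv_cache)
   satisfy the visibility test for any in-range query row, so the causal
   diagonal only cuts into the newly appended keys. Illustration with
   hypothetical values seq_len_kv_cache = 128, full_tile_offset = 0,
   discard_seq_coord = 0: the query at row 0 keeps columns 0..128 and
   masks everything past them. */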
@@ -311,7 +390,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
/* k masking for remainder tiles */
if (check_remainder_k && K == total_blk - 1) {
FragSRow k_rem_mask;
-int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0];
int k_val;
if (is_cache) k_val = get<0>(tKgK_cache(0,0,0,physical_K_tile,0));
else k_val = get<0>(tKgK(0,0,0,K - kblocks_cache,0)) + kblocks_cache * get<1>(TileShapeQK{}); // shift the tile-relative index back to a global k coordinate

int k = k_val + get_sub_group().get_local_id()[0];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) {
k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY);
@@ -329,14 +412,31 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
/* GEMM 2: A += P * V, split in v dimension */
CUTLASS_PRAGMA_UNROLL
for (int VV = 0; VV < VTiles; VV++) {
-copy(copy_v, tVgV(_,_,_,VV,K), tVrV);
if (is_cache) {
copy(copy_v_cache, tVgV_cache(_,_,_,VV,physical_K_tile), tVrV);
} else {
copy(copy_v, tVgV(_,_,_,VV,K-kblocks_cache), tVrV);
}
reorder(tVrV, tArV);
cute::gemm(mma_pv, tArP, tArV, tArA(_,_,_,VV));
}

/* K prefetch */
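/* (Note: the prefetch runs Stages blocks ahead of the compute loop, so the
   cache/non-cache split and, under PagedKV, the page-table lookup are
   re-evaluated at K + Stages rather than reusing this iteration's indices.) */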
for (int D = 0; D < size<4>(pKgK); D++) {
-prefetch(prefetch_k, pKgK(_,_,_,K+Stages,D));
int K_next = K + Stages;
bool is_cache_next = K_next < kblocks_cache;
int physical_K_next = K_next;
if constexpr (PagedKV) {
if (is_cache_next) {
physical_K_next = get_physical_k_tile(K_next, l_coord, seq_len_kv_cache);
}
}

if (is_cache_next) {
prefetch(prefetch_k_cache, pKgK_cache(_,_,_,physical_K_next,D));
} else {
prefetch(prefetch_k, pKgK(_,_,_,K_next-kblocks_cache,D));
}
}

barrier_wait(ScopeWorkgroup);