Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class FMHAFwdEpilogue {
Stride<E<1>, E<0>>{})
));
using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus<void>{}));
// static_assert(is_same_v<ReduceFragARow, float>, "dtype mismatched");

static auto default_tiled_copy_O_helper() {
if constexpr (ReduceK{} == _1{})
Expand Down Expand Up @@ -149,7 +150,79 @@ class FMHAFwdEpilogue {
using ElementA = typename FragA::element_type;

// Reduce k-blocks of A and A_sum across WG, if needed.
auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);
auto [rA, rA_max_unused, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);

/* Some subgroups may not have any work to do; if so, quit early. */
if (!active) return;

/* Complete softmax, dividing out sums. */
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA_sum.size(); i++)
rA_sum(i) = ElementA(1) / rA_sum(i);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA.size(); i++)
rA(i) *= broadcast<0>(rA_sum, rA, i);

#if 1
if (ThreadIdxX() == 0) {
// cute::print("wg id: %d, rA(0): %f, rA_sum(0): %f\n", BlockIdxZ(), (float)rA(0),(float)rA_sum(0));
}
#endif

/* Tile output */
Tensor cO = make_identity_tensor(O.shape()); // (q,v)
Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v)

/* Prepare slices */
TiledCopyO copy_o{O};
auto thr_copy_o = copy_o.get_slice(thr_id);

auto tOrO = thr_copy_o.partition_sg_fragment_S(gO);
auto tOgO = thr_copy_o.partition_D(gO);

/* Reorder tile and write out */
reorder(rA, tOrO);
copy(copy_o, tOrO, tOgO);
}

// splitK version
template <typename QVCoord>
CUTLASS_DEVICE
void
operator()(TensorO2D const& O, // Global O tensor: (q,v)
FragA & tArA, // O accumulator: (q,v)
FragARow & tA_max, // Softmax row-wise max accumulator
FragARow & tA_sum, // Softmax row-wise sum accumulator
QVCoord blk_qv, // WG tile indices: (q,v)
int thr_id, // Work-item ID
const TensorO2D & exp_sums, // Global exp sum tensor
const TensorO2D & max_logits, // Global max logits tensor
int idx_kv_split,
int head_q
) {

using namespace cute;
using ElementA = typename FragA::element_type;

#if 0
if (ThreadIdxX() == 0 && BlockIdxZ() == 0) {
// cute::print("idx_kv_split: %d, idx_b: %d, head_q: %d, Q(0,0,head_q,l_coord): %f\n", idx_kv_split, idx_b, head_q, float(Q(0,34,head_q,l_coord)));
cute::print(" fwd epilogue tA_max(0): %f\n", float(tA_max(0)));
cute::print(" fwd epilogue tA_sum(0): %f\n", float(tA_sum(0)));
cute::print(" fwd epilogue tArA(0): %f\n", float(tArA(0)));
}
#endif

// Reduce k-blocks of A and A_sum across WG, if needed.
auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);

// store exp sum and max logits for current KV split
// assume seq_len_qo == 1
if (ThreadIdxX() == 0) {
exp_sums(head_q,idx_kv_split) = rA_sum(0);
max_logits(head_q,idx_kv_split) = rA_max(0);
}

/* Some subgroups may not have any work to do; if so, quit early. */
if (!active) return;
Expand Down Expand Up @@ -193,7 +266,8 @@ class FMHAFwdEpilogue {
using namespace sycl::ext::oneapi::this_work_item;

if constexpr (ReduceK{} == _1{}) {
return std::make_tuple(tArA, tA_sum, true);
ReduceFragARow rA_max;
return std::make_tuple(tArA, rA_max, tA_sum, true);
} else {
/* Identify A tile ID and k block for this subgroup. */
auto thr_vak = group<1,3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id));
Expand Down Expand Up @@ -285,7 +359,7 @@ class FMHAFwdEpilogue {
}
}
}
return std::make_tuple(rA, rA_sum, active);
return std::make_tuple(rA, rA_max, rA_sum, active);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
using SingleFragA = FragC<TiledMMAPV>; // (atom val,q',v')
using FragA = expand_sg_fragment_t<SingleFragA, 1, VTiles>; // (atom val,q',v',VV)
using FragARow = decltype(reduce<1>(FragA{}, sycl::plus<void>{}));
// static_assert(is_same_v<decltype(FragSRow{}.shape()), float>, "dtype mismatched");
using ElementA = typename TiledMMAPV::ValTypeD;

static constexpr bool CausalMask = CausalMask_;
Expand Down Expand Up @@ -175,7 +176,8 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
int thr_id,
int seq_len,
int full_tile_offset,
int discard_seq_coord) {
int discard_seq_coord,
bool need_init = false) {
using namespace sycl::ext::oneapi::this_work_item;

// Short dimension names:
Expand All @@ -195,6 +197,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k)
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); // (q,k)

#if 0
if (ThreadIdxX() == 0 && BlockIdxZ() == 0) {
print("Q 2D shape: "); print(Q_2D.shape()); print("\n");
}
#endif

/* Partition global tensors into workgroup tiles */
Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv,_), Step<_1,X,_1>{}); // (q,d,D)
Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_,_,_), Step<X,_1,_1>{}); // (k,d,K,D)
Expand Down Expand Up @@ -251,7 +259,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,

/* Initialization steps for first block: Q/K prefetch, O init */
/* TODO: limit D prefetch for large head size, and reorder K prefetches */
if (blk_k0 == 0) {
if (blk_k0 == 0 || need_init) {
for (int D = 0; D < size<3>(pQgQ); D++) {
prefetch(prefetch_q, pQgQ(_,_,_,D));
}
Expand Down
Loading