
Conversation

ClarkChin08 commented on Dec 8, 2025

Description

This PR introduces support for cached KV and paged KV in the New Flash Attention Kernel.

Type

[x] Cached KV on fixed sequence lengths with multi-batch and GQA support (Good accuracy and performance).
[x] Paged KV (non-contiguous) on fixed sequence lengths with multi-batch and GQA support (Good accuracy and performance).
[x] Cached KV/Paged KV on variable sequence lengths (Good accuracy and performance).
[x] Cached KV/Paged KV with Causal Mask enabled, supporting multi-batch and GQA (Good accuracy and performance).

The most complex combination of these features now runs with good accuracy and performance. Example command:

./examples/06_bmg_flash_attention/06_xe_fmha_fwd_prefill_bfloat16_t_hdim128 --iterations=1 --batch=2 --num_heads_q=32 --seq_len_kv=1024 --seq_len_qo=1024 --num_heads_kv=8 --varlen --seq_len_kv_cache=512 --is_causal --use_paged_kv

Copilot AI review requested due to automatic review settings December 8, 2025 01:07
ClarkChin08 changed the title from "Add cached KV and paged KV support to Flash Attention" to "[WIP] Add cached KV and paged KV support to Flash Attention" on Dec 8, 2025
ClarkChin08 marked this pull request as draft on December 8, 2025 01:08

Copilot AI left a comment

Pull request overview

This PR adds support for cached KV and paged KV cache to the Flash Attention implementation for Intel Xe GPUs. The changes enable more efficient memory management for attention mechanisms by allowing:

  1. Cached KV support: Appending new KV tokens to existing cached tokens
  2. Paged KV support: Non-contiguous memory layout for the KV cache using a page table mapping (see the sketch after this list)
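
A minimal sketch of the paged-KV idea, assuming a flat page table of int32 physical page indices. The names (PagedKvView, page_table, page_size, locate) are illustrative only and are not the identifiers used in this PR:

```cpp
#include <cstdint>
#include <utility>

// Illustrative sketch: a logical token position in the cached KV sequence is
// split into a logical page index and an in-page offset, and the page table
// maps the logical page to a physical page in non-contiguous memory.
struct PagedKvView {
  const int32_t* page_table;  // maps logical page index -> physical page index
  int page_size;              // tokens per page

  // Map a logical cached-KV token position to (physical page, in-page offset).
  std::pair<int, int> locate(int kv_pos) const {
    int logical_page = kv_pos / page_size;
    int in_page      = kv_pos % page_size;
    return {page_table[logical_page], in_page};
  }
};
```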

Key changes include:

  • New command-line options --seq_len_kv_cache, --use_paged_kv, --page_size, and --verify
  • Extension of the problem shape to include the cached KV sequence length (see the sketch after this list)
  • Page table infrastructure for mapping logical to physical pages
  • Mainloop modifications to handle both cached and new KV tensors with optional paging
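
A minimal sketch of how the cached length extends the effective KV length and shifts the causal mask, assuming the common convention that new query tokens are aligned to the end of the KV sequence. The variable names mirror the command-line options above, but the code is illustrative and is not the kernel's implementation:

```cpp
#include <cstdio>

int main() {
  int seq_len_kv_cache = 512;   // tokens already in the KV cache
  int seq_len_kv       = 1024;  // new KV tokens appended in this call
  int seq_len_qo       = 1024;  // new query tokens

  // The mainloop iterates over cached plus new KV tokens.
  int total_kv = seq_len_kv_cache + seq_len_kv;

  // Causal visibility for one (query, key) pair: key position k is visible to
  // query position q iff k <= q + seq_len_kv_cache, so every new query can
  // attend to all cached tokens plus the new tokens up to its own position.
  auto visible = [&](int q, int k) { return k <= q + seq_len_kv_cache; };

  printf("total KV length: %d\n", total_kv);
  // Key 512 is the first new KV token, i.e. query 0's own token.
  printf("query 0 can see key %d: %s\n", seq_len_kv_cache,
         visible(0, seq_len_kv_cache) ? "yes" : "no");
  (void)seq_len_qo;
  return 0;
}
```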

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File and description:

examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp: Adds command-line options, page table initialization, and verification logic for cached/paged KV, and extends the problem shape to 8 dimensions (see the sketch after this list)
applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp: Adds K_cache and V_cache tensors to the kernel arguments, extends the sequence-length shape to a 3D tuple, and computes offsets for the cached tensors
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp: Adds a PagedKV template parameter, implements the page table lookup for physical tile indexing, and handles prefetch/copy for both cached and new KV tensors
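
A minimal sketch of how a test runner might initialize a page table to exercise the non-contiguous path, by permuting which physical page backs each logical page. This is an assumption about the general approach, not the runner's actual code:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Build a host-side page table for a cached KV sequence: start from the
// identity mapping (logical page i -> physical page i) and shuffle it so that
// the physical layout is no longer contiguous in logical order.
std::vector<int32_t> make_page_table(int seq_len_kv_cache, int page_size) {
  int num_pages = (seq_len_kv_cache + page_size - 1) / page_size;
  std::vector<int32_t> table(num_pages);
  std::iota(table.begin(), table.end(), 0);       // identity mapping
  std::mt19937 rng(/*seed=*/42);
  std::shuffle(table.begin(), table.end(), rng);  // randomize physical order
  return table;                                   // copied to device before launch
}
```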

Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
ClarkChin08 changed the title from "[WIP] Add cached KV and paged KV support to Flash Attention" to "Add cached KV and paged KV support to Flash Attention" on Dec 11, 2025
ClarkChin08 marked this pull request as ready for review on December 11, 2025 06:32
ClarkChin08 changed the title from "Add cached KV and paged KV support to Flash Attention" to "Support for cached KV and paged KV in the New Flash Attention Kernel" on Dec 11, 2025
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>