-
Notifications
You must be signed in to change notification settings - Fork 70
Support for cached KV and paged KV in the New Flash Attention Kernel #661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds support for cached KV and paged KV cache to the Flash Attention implementation for Intel Xe GPUs. The changes enable more efficient memory management for attention mechanisms by allowing:
- Cached KV support: Appending new KV tokens to existing cached tokens
- Paged KV support: Non-contiguous memory layout for KV cache using a page table mapping
Key changes include:
- New command-line options
--seq_len_kv_cache,--use_paged_kv,--page_size, and--verify - Extension of problem shape to include cached KV sequence length
- Page table infrastructure for mapping logical to physical pages
- Mainloop modifications to handle both cached and new KV tensors with optional paging
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | Adds command-line options, page table initialization, verification logic for cached/paged KV, and extends problem shape to 8 dimensions |
| applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | Adds K_cache and V_cache tensors to kernel arguments, extends sequence length shape to 3D tuple, computes offsets for cached tensors |
| applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | Adds PagedKV template parameter, implements page table lookup for physical tile indexing, handles prefetch/copy for both cached and new KV tensors |
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Outdated
Show resolved
Hide resolved
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
cf63631 to
6fbed33
Compare
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
Signed-off-by: Chen, Xi2 <xi2.chen@intel.com>
Description
This PR introduces support for cached KV and paged KV in the New Flash Attention Kernel.
Type
[x] Cached KV on fixed sequence lengths with multi-batch and GQA support (Good accuracy and performance).
[x] Paged KV (non-contiguous) on fixed sequence lengths with multi-batch and GQA support (Good accuracy and performance).
[x] Cached KV/Paged KV on variable sequence lengths (Good accuracy and performance).
[x] Cached KV/Paged KV with Causal Mask enabled, supporting multi-batch and GQA (Good accuracy and performance).
We now support the most complex running combinations with good accuracy and performance. Example command:
./examples/06_bmg_flash_attention/06_xe_fmha_fwd_prefill_bfloat16_t_hdim128 --iterations=1 --batch=2 --num_heads_q=32 --seq_len_kv=1024 --seq_len_qo=1024 --num_heads_kv=8 --varlen --seq_len_kv_cache=512 --is_causal --use_paged_kv