PrefixGrouper

Efficient GRPO training through shared-prefix forward

PrefixGrouper is a plug-and-play tool for efficient GRPO training that requires only minimal modifications to existing codebases to reduce computation, lower device memory consumption, and accelerate training. Beyond GRPO, it can also be applied to other scenarios that require shared-prefix training or inference.

In current mainstream GRPO training pipelines, policy model training copies each prefix (typically the question, multimodal inputs, etc.) G times. When training-data prefixes are long (e.g., long-context reasoning, image or long-video inference), this redundant computation becomes non-negligible, increasing device memory usage, raising computation cost, and slowing training. To address this, we propose PrefixGrouper, a plug-and-play GRPO training tool that enables efficient training through shared-prefix forward passes. The reduced device memory consumption in turn allows the same devices to support larger group sizes, which is critical for GRPO algorithms.

News

[2025/6/9] Our technical report is available here!

[2025/6/7] We've updated PrefixGrouper to version 0.0.1rc2, with better encapsulation and fewer required code changes. You're welcome to try it!

[2025/6/3] We released PrefixGrouper. The tech report is coming; please stay tuned.

Method Overview

The core of PrefixGrouper lies in its attention operation design:

By decomposing the original redundant self-attention operation into prefix self-attention + suffix concat-attention, PrefixGrouper enables efficient GRPO training and is theoretically compatible with various attention implementations (EagerAttention, FlashAttention, SDPA, etc.) as well as hardware devices (GPU, NPU, etc.).
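To make the decomposition concrete, here is a minimal sketch in PyTorch (assuming standard causal attention and the SDPA backend; the function and tensor layout are illustrative, not the library's API) of prefix self-attention plus suffix concat-attention for a single group with one shared prefix and G responses:

import torch
import torch.nn.functional as F

def grouped_attention(q_p, k_p, v_p, q_s, k_s, v_s):
    # q_p/k_p/v_p: [1, num_heads, prefix_len, head_dim] -- the shared prefix
    # q_s/k_s/v_s: [G, num_heads, suffix_len, head_dim] -- the G responses
    G, p_len, s_len = q_s.shape[0], k_p.shape[2], q_s.shape[2]
    # 1) Prefix self-attention: computed once instead of G times.
    o_p = F.scaled_dot_product_attention(q_p, k_p, v_p, is_causal=True)
    # 2) Suffix concat-attention: each response attends to [prefix, own suffix].
    k_cat = torch.cat([k_p.expand(G, -1, -1, -1), k_s], dim=2)
    v_cat = torch.cat([v_p.expand(G, -1, -1, -1), v_s], dim=2)
    # Suffix token i sees every prefix token and suffix tokens <= i.
    mask = torch.ones(s_len, p_len + s_len, dtype=torch.bool, device=q_s.device)
    mask[:, p_len:] = torch.tril(mask[:, p_len:])
    o_s = F.scaled_dot_product_attention(q_s, k_cat, v_cat, attn_mask=mask)
    return o_p, o_s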

The comparison of FLOPs and memory usage between PrefixGrouper and the baseline is shown below, with results at fixed prefix lengths (4096, 8192, and 16384) across different ratios (prefix length / suffix length):

PrefixGrouper demonstrates significant advantages in long-context scenarios, further highlighting its efficiency.

Installation

pip install prefix_grouper

Quick Start

To make PrefixGrouper simpler and easier to use, we provide modification examples for some models.

  • Model file modification examples can be found in examples. For clarity, we wrap key modifications with "PrefixGrouper Start" and "PrefixGrouper End" comments.
  • For examples simulating the full training workflow, see tests/equivalence. We provide an almost complete flow for one training step.

If you happen to use one of these models, you can directly integrate the example code into your codebase. However, we recommend briefly reviewing the tutorial below to better understand the tool's workflow.

Running examples:

cd PrefixGrouper
python src/tests/equivalence/test_xxx.py --model_path /path/to/your/model

Tutorial

Tip

For better understanding, it's recommended to read alongside the code in both examples and tests/equivalence.

Briefly, PrefixGrouper requires modifications in three areas: data input/output, attention mechanisms, and position encoding. Throughout this document, we refer to data corresponding to a query (prefix) as a sample, and each model-generated output based on the prefix as a response.

Data Input/Output

To minimize redundant prefix forward passes and maximize parallel acceleration, PrefixGrouper first concatenates each sample in the batch with its corresponding responses (pseudocode example):

  • Best Practices (requires version 0.0.1rc2 or above)
# NOTE: imports assumed for this sketch
import torch
from prefix_grouper import PrefixGrouper

# Prefix: [b1, seq_len1], where b1 should be the number of samples
prompt_ids = ...
# Prefix mask: [b1, seq_len1]
prompt_mask = ...
# Suffix: [b2, seq_len2], where b2 should be the total number of responses across all samples
completion_ids = ...
# Suffix mask: [b2, seq_len2]
completion_mask = ...
# int or List[int]. int indicates each sample has the same number of responses, List[int] specifies different response counts per sample.
group_sizes = ...

# Initialize a PrefixGrouper instance.
prefix_grouper = PrefixGrouper.from_ungrouped_masks(
    prefix_mask=prompt_mask,
    suffix_mask=completion_mask,
    group_sizes=group_sizes,
    padding_mode="right",
    device=device,
)
# Here we use PrefixGrouper to concatenate inputs into final input_ids with shape [b1, seq_len].
# NOTE: Can also input features, i.e. prompt_embeds ([b1, seq_len1, dim]), suffix_embeds ([b2, seq_len2, dim])
input_ids = prefix_grouper.concat_input(prompt_ids, prompt_mask, completion_ids, completion_mask)
attention_mask = prefix_grouper.padding_mask
# Perform model forward - just add one extra argument
res = model(*args, **kwargs, prefix_grouper=prefix_grouper)
# ====== Forward process complete ======
# Explanation of the ``include_prefix_last`` parameter: the first output token of
# a response is actually produced by the last input token of the prefix, so the
# output at the prefix's last position must enter the loss. Passing
# ``include_prefix_last=1`` to ``split_output`` tells ``PrefixGrouper`` to repeat
# the prefix's last token and concatenate it to the beginning of each suffix.
# The mask undergoes identical processing.
prefix_output, prefix_mask, suffix_output, suffix_mask = (
    prefix_grouper.split_output(res.logits, include_prefix_last=1)
)
# Must convert completion_ids to right padding to align with suffix_output positions
completion_ids = prefix_grouper.convert_padding(completion_ids, completion_mask, padding_mode="right")
# ====== Entire input/output process complete ======

# After obtaining normal outputs, proceed to calculate loss and backpropagate per standard GRPO procedure - fully identical.
# NOTE: Some parts are omitted here, such as advantage, KL loss, importance sampling, etc. Please write your own GRPO loss according to your actual needs.
suffix_output = suffix_output[:, :-1]
suffix_mask = suffix_mask[:, 1:]
# NOTE: Because ``suffix_output`` was produced with ``include_prefix_last=1``,
# ``completion_ids`` is 1 token shorter than ``suffix_output``.
# It therefore needs no [:, 1:] slicing: its first token is already a valid target.
# Per-token probabilities of the sampled tokens (log-softmax gathered at the
# targets, then exponentiated), standing in for the importance ratio here.
per_token = (suffix_output.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1) - suffix_output.logsumexp(-1)).exp()
per_token = per_token * suffix_mask
loss = (per_token.sum(-1) / suffix_mask.sum(-1)).mean()
(-loss).backward()

Explanation of concat_input and split_output:

  1. concat_input concatenates prompts and completions based on prompt_mask and completion_mask. The resulting input_ids are organized according to the padding_mode parameter passed to PrefixGrouper. For example, PrefixGrouper.from_ungrouped_masks(..., padding_mode="right") means the concatenated input_ids will use compact right padding: prompts and completions form continuous sequences without intermediate padding, left-aligned with padding added on the right.
  2. split_output splits the output logits into prefix and suffix portions and returns the corresponding masks. With include_prefix_last=1, the last token of the original prefix is assigned to the beginning of each suffix. Concretely: inputs prompt_ids ([b1, seq_len1]) and completion_ids ([b2, seq_len2]) produce output logits of shape [b1, seq_len, dim]. After split_output(..., include_prefix_last=1), prefix_output has shape [b1, seq_len1 - 1, dim] (its last token removed) and suffix_output has shape [b2, seq_len2 + 1, dim] (one extra first token). To implement include_prefix_last=1 cleanly, PrefixGrouper uses left padding for prefixes and right padding for suffixes during splitting, ensuring continuous boundaries; this requires converting completion_ids to the same padding pattern via prefix_grouper.convert_padding. After conversion, completion_ids remains shape [b2, seq_len2], while suffix_output and suffix_mask have sequence length seq_len2 + 1 (due to the extra token). For alignment, use suffix_output = suffix_output[:, :-1] and suffix_mask = suffix_mask[:, 1:]; completion_ids requires no modification (see the shape walk-through below).
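For concreteness, a hypothetical shape walk-through (the numbers are illustrative):

# prompt_ids:      [2, 100]    (b1 = 2 samples)
# completion_ids:  [6, 50]     (b2 = 6 responses, group_size = 3)
# res.logits:      [2, seq_len, vocab]
# After split_output(..., include_prefix_last=1):
#   prefix_output: [2, 99, vocab]   (prefix minus its last token)
#   suffix_output: [6, 51, vocab]   (suffix plus the borrowed prefix token)
# After suffix_output[:, :-1] and suffix_mask[:, 1:], both have length 50 and
# align with completion_ids for the gather in the loss.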

Note

Use the suffix_mask returned by split_output, rather than completion_mask, when calculating the loss: both suffix_output and completion_ids are converted to right padding, so the original completion_mask may no longer align.

  • Older Version Examples: Please see tests/test_equivalence.

The key points of data processing are input concatenation, group_info statistics, and output splitting. Customize them to your project's needs while maintaining interface consistency (see docs).

Attention Mechanism

Minor model modifications suffice for attention adaptation. For transformers versions that support AttentionInterface, simple registration is possible (experimental). The generic approach is as follows:

if prefix_grouper is None:
    # Original attention path (baseline)
    attn_output = _flash_attention_forward(...)
else:
    # ===== PrefixGrouper Start =====
    def attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor, *args, **kwargs):
        # Adapter function for dimension/parameter alignment;
        # PrefixGrouper supplies attn_mask itself.
        return _flash_attention_forward(...)

    # Documented signature: forward(attn_func, q, k, v, *args, **kwargs)
    attn_output = prefix_grouper.forward(attn_func, q, k, v)
    # ====== PrefixGrouper End ======

Propagate the prefix_grouper argument through the model's forward passes, as in the sketch below.
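A minimal sketch of the threading (module names like embed_tokens, layers, and norm are illustrative):

def forward(self, input_ids, attention_mask=None, prefix_grouper=None, **kwargs):
    hidden_states = self.embed_tokens(input_ids)
    for layer in self.layers:
        # Each decoder layer passes prefix_grouper down to its attention module.
        hidden_states = layer(hidden_states, attention_mask=attention_mask, prefix_grouper=prefix_grouper)
    return self.norm(hidden_states)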

Position Encoding

Position IDs for the concatenated responses should share the same starting pattern: [0, 1, ..., prefix_len, prefix_len+1, ..., suffix1_end, prefix_len+1, ..., suffix2_end, ...]. Some useful information: prefix_grouper.group_info[i].prefix_len and prefix_grouper.group_info[i].suffix_lens give the prefix/suffix lengths (number of valid tokens, excluding padding) of the i-th sample, and prefix_grouper.padding_mask is the attention mask of the input tensor after concatenating prefix and suffixes. This information can assist in computing position IDs, as in the sketch below.
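A hypothetical helper (not part of the library) built on the attributes above; off-by-one conventions vary per model, so verify the result against your model's position encoding:

import torch

def build_position_ids(prefix_grouper, i: int, device=None) -> torch.Tensor:
    info = prefix_grouper.group_info[i]
    p_len = info.prefix_len                     # valid prefix tokens (no padding)
    positions = list(range(p_len))              # prefix: 0 .. prefix_len - 1
    for s_len in info.suffix_lens:              # every response restarts after the prefix
        positions.extend(range(p_len, p_len + s_len))
    return torch.tensor(positions, device=device)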

Position encoding is pre-adapted for models in the Quick Start section (see examples).

Start Training

Complete GRPO training simulations are provided in tests for reference.

Documentation

Core API documentation:

PrefixGrouper

PrefixGrouper(group_info: Optional[List[List[int]]] = None, device=None, padding_mode: Union[str, torch.Tensor] = "right")

group_info: The outer list has one entry per sample (length b); each inner list is [prefix_len, suffix1_len, suffix2_len, ...]. This parameter can be None, in which case you must manually call init (same signature as PrefixGrouper.__init__) for deferred initialization.
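For illustration, a hypothetical group_info for two samples with three responses each:

# Sample 0: a 120-token prefix with responses of 30, 28, and 35 tokens.
# Sample 1: a 95-token prefix with responses of 40, 33, and 29 tokens.
group_info = [
    [120, 30, 28, 35],
    [95, 40, 33, 29],
]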

device: Device for initializing PrefixGrouper (actual ops use input tensor's device).

padding_mode: "left"/"right" (dense padding) or torch.Tensor (custom padding mask, shape [b, seq_len]).

Usage examples:

  • With concat_input (recommended):
prefix_grouper = PrefixGrouper(group_info, padding_mode="right")
  • Custom input handling:
prefix_grouper = PrefixGrouper(group_info, padding_mode=custom_padding_mask)

PrefixGrouper.concat_input(self, prefix: torch.Tensor, prefix_mask: torch.Tensor, suffix: torch.Tensor, suffix_mask: torch.Tensor)

Concatenates prefix ([b, seq_len] or [b, seq_len, dim]) and suffix ([b * group_size, seq_len] or [b * group_size, seq_len, dim]) using group_info. Requires prefix_mask ([b, seq_len]) and suffix_mask ([b * group_size, seq_len]).

PrefixGrouper.forward(self, __attn_func: AttnFuncType, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs)

Performs attention using __attn_func. Function signature: attn_func(q, k, v, attn_mask, *args, **kwargs). Input q/k/v shape: [b, num_heads, seq_len, head_dim]. Output shape: [b, seq_len, num_heads, head_dim]. Do not manually pass attention masks.
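As a sketch of this contract (assuming an SDPA backend and an SDPA-compatible mask; the real adapter depends on your attention implementation):

import torch.nn.functional as F

def attn_func(q, k, v, attn_mask, *args, **kwargs):
    # q/k/v arrive as [b, num_heads, seq_len, head_dim]; PrefixGrouper supplies attn_mask.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Return the documented layout [b, seq_len, num_heads, head_dim].
    return out.transpose(1, 2)

attn_output = prefix_grouper.forward(attn_func, q, k, v)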

PrefixGrouper.split_output(self, output: torch.Tensor, include_prefix_last: int = 0)

output: Shape [b, seq_len, dim]

include_prefix_last: Controls prefix boundary handling (0: no conversion; 1: attach last prefix token to suffixes).

Future Plans

  • Hugging Face Transformers AttentionInterface Integration (This feature is currently in testing)
  • Additional Training Device Support (NPU under testing - no compatibility issues found so far)
  • Test Cases for More Models (We plan to release plain-text test cases for Qwen2.5 and Qwen3 models)
  • Support for other attention implementations (EagerAttention, SDPA)

Data Usage Statement

Test data in this project is strictly for academic research purposes with the following limitations:

  1. Commercial use is prohibited
  2. Data redistribution is prohibited
  3. De-anonymization attempts are prohibited

Citation

If you find this work helpful, please cite the following paper:

@misc{liu2025prefixgrouperefficientgrpo,
      title={Prefix Grouper: Efficient GRPO Training through Shared-Prefix Forward}, 
      author={Zikang Liu and Tongtian Yue and Yepeng Tang and Longteng Guo and Junxian Cai and Qingbin Liu and Xi Chen and Jing Liu},
      year={2025},
      eprint={2506.05433},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2506.05433}, 
}
