diff --git a/README.md b/README.md
index a52c8a9..e6c5b89 100644
--- a/README.md
+++ b/README.md
@@ -1,203 +1,96 @@
-# slime
+# APRIL: Active Partial Rollouts in Reinforcement Learning to Tame Long-tail Generation
+## About
+### Background: Why the sampling-training loop of synchronous RL is dragged down by the "long tail"
-[中文版](./README_zh.md)
+In on-policy RLHF/GRPO training, the system enters the update phase only after collecting **N** rollout samples in a "round." Because generated samples vary widely in length, the system must wait for a few **long-tail samples** to finish before training can start, which leaves GPUs underutilized and lowers throughput in the later stages of the rollout phase.
-**slime** is an LLM post-training framework for RL scaling, providing two core capabilities:
+### What We Did: Active Partial Rollout (APRIL)
-1. **High-Performance Training**: Supports efficient training in various modes by connecting Megatron with SGLang;
-2. **Flexible Data Generation**: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines.
+**Core Idea**: In each round, we **over-sample** (N' > N) and **actively interrupt** the remaining in-progress requests once the target of **N** completed samples is reached. The **unfinished responses** are stored in a **buffer** and are **prioritized for continued rollout** in the next round, thereby mitigating the efficiency degradation caused by long-tail requests.
-## Table of Contents
+
+### Highlights
- - [Architecture Overview](#architecture-overview)
- - [Quick Start](#quick-start)
- - [Environment Setup](#environment-setup)
- - [Examples](#examples)
- - [Dense Model Examples: GLM-4-9B and Qwen3-4B](#Dense-Model-Examples-GLM-4-9B-and-Qwen3-4B)
- - [MoE Model Example: Qwen3-30B-A3B and DeepSeek-R1](#MoE-Model-Example-Qwen3-30B-A3B-and-DeepSeek-R1)
- - [Multi-Turn + Tool Calling Example: Search-R1 lite](#Multi-Turn--Tool-Calling-Example-Search-R1-lite)
- - [SFT Example: Qwen3-4B-Base with OpenHermes-2.5](#SFT-Example-Qwen3-4B-Base-with-OpenHermes-25)
- - [Checkpoint Format Conversion](#checkpoint-format-conversion)
- - [Starting the Training Process](#starting-the-training-process)
- - [Argument Descriptions](#argument-descriptions)
- - [Developer Guide](#developer-guide)
- - [FAQ & Acknowledgements](#faq--acknowledgements)
+- **Over-sampling**: Assuming the training phase requires `rollout_batch_size=32` complete samples per round, we actually launch a larger sampling request, e.g., `over_sampling_batch_size=64`.
+- **Stop upon collection**: As soon as the number of collected complete sample groups reaches `rollout_batch_size`, an `abort` signal is immediately sent to the SGLang router.
+- **Collect and reuse**: Upon receiving the `abort` signal, SGLang stops the ongoing generation tasks and returns the partially generated portions (half-completed trajectories). These partial samples are not discarded but stored in a buffer; when the next rollout round begins, they continue generating from where they left off alongside new prompts, achieving seamless reuse across iteration steps.
+- **Elegant implementation**: slime's partial rollout provides a native, lightweight optimization that is minimally intrusive to the original pipeline. You can enable it out of the box simply by setting the `--partial-rollout` flag and specifying `--over-sampling-batch-size`. A toy simulation of this scheduling loop is sketched below.
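+
+To make the scheduling concrete, below is a toy, self-contained simulation of the over-sample / stop-on-target / recycle loop. It is an illustrative sketch only, not slime's actual implementation (the real logic lives in `slime/rollout/sglang_example.py` and `slime/ray/buffer.py`), and all names in it are hypothetical:
+
+```python
+import random
+from collections import deque
+
+ROLLOUT_BATCH_SIZE = 4        # N: complete samples the trainer needs per round
+OVER_SAMPLING_BATCH_SIZE = 8  # N': requests actually launched per round (N' > N)
+
+partial_buffer = deque()      # unfinished responses carried across rounds
+
+def new_request():
+    # Response lengths are long-tailed: a few requests need far more decode steps.
+    return {"remaining_steps": random.randint(20, 300), "generated_tokens": 0}
+
+def run_round(round_id):
+    # Buffered partials are prioritized, then fresh prompts top the batch up to N'.
+    in_flight = [partial_buffer.popleft()
+                 for _ in range(min(len(partial_buffer), OVER_SAMPLING_BATCH_SIZE))]
+    in_flight += [new_request() for _ in range(OVER_SAMPLING_BATCH_SIZE - len(in_flight))]
+
+    completed = []
+    while len(completed) < ROLLOUT_BATCH_SIZE:
+        for req in in_flight:             # one decode step for every in-flight request
+            req["remaining_steps"] -= 1
+            req["generated_tokens"] += 1
+        completed += [r for r in in_flight if r["remaining_steps"] <= 0]
+        in_flight = [r for r in in_flight if r["remaining_steps"] > 0]
+
+    # Target reached: "abort" the rest and keep their partial generations.
+    partial_buffer.extend(in_flight)
+    print(f"round {round_id}: {len(completed)} completed, "
+          f"{len(in_flight)} partials buffered for the next round")
+
+for r in range(3):
+    run_round(r)
+```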
-## Architecture Overview
+## Getting Started
-
-
-**Module Descriptions**:
-
- - **training (Megatron)**: Responsible for the main training process, reads data from the Data Buffer, and synchronizes parameters to the rollout module after training.
- - **rollout (SGLang + router)**: Generates new data (including rewards/verifier outputs) and stores it in the Data Buffer.
- - **data buffer**: A bridge module that manages prompt initialization, custom data, and rollout generation methods.
-
-## Quick Start
-
-### Environment Setup
-
-Based on the `zhuzilin/slime:latest` image (pre-installed with SGLang 0.4.7 and Megatron):
+### 1) Environment Setup (Requires an AMD GPU)
+**Start docker**
```bash
docker run --rm --gpus all --ipc=host --shm-size=16g \
--ulimit memlock=-1 --ulimit stack=67108864 \
- -it zhuzilin/slime:latest /bin/bash
-
-git clone https://github.com/THUDM/slime.git
-cd slime
-pip install -e .
+ -it rlsys/slime:slime_ubuntu22.04_rocm6.3.4-patch-numa-patch_sglang0.4.9_megatron-patch_ray2.47.1_apex_torch-memory-saver0.0.8-patch-vim /bin/bash
```
-
-- If you prefer not to use Docker, or if it's inconvenient, please refer to [Setting up the Environment from Scratch](./docs/en/build.md).
-- For AMD support, please refer to [AMD Tutorial](./docs/en/amd_tutorial.md).
-
-### Examples
-
-#### Dense Model Examples: GLM-4-9B and Qwen3-4B
-
-We provide examples to use [GLM-4-9B](https://huggingface.co/THUDM/GLM-Z1-9B-0414) and [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B), please refer to:
-
-- [Example: GLM-4-9B](docs/en/models/glm4-9B.md).
-- [Example: Qwen3-4B](docs/en/models/qwen3-4B.md).
-
-#### MoE Model Example: Qwen3-30B-A3B and DeepSeek-R1
-
-For MoE example, please refer to:
-
-- [Example: Qwen3-30B-A3B](docs/en/models/qwen3-30B-A3B.md).
-- [Example: Training DeepSeek R1 with 128xH100](docs/en/models/deepseek-r1.md)
-
-#### Multi-Turn + Tool Calling Example: Search-R1 lite
-
-For multi-turn and tool calling, we also provides an minimal reimplenmentation of Search-R1, please refer to:
-
-- [Example: Search-R1 lite](examples/search-r1/README.md).
-
-#### SFT Example: Qwen3-4B-Base with OpenHermes-2.5
-
-slime is not just a RL framework, we support a diverse set of post-training setups. For an SFT example, please refer to:
-
-- [Example: Qwen3-4B-Base with OpenHermes-2.5](docs/en/sft.md).
-
-### Checkpoint Format Conversion
-
-Since slime uses Megatron, and Megatron does not support loading Hugging Face checkpoints directly, we need to convert the model to the `torch_dist` format that Megatron supports.
-
-#### HF → Megatron torch\_dist ckpt
-
-We recommend using [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) for mcore checkpoint conversion.
-
-If the mode you are using are not supported by Pai-Megatron-Patch, you could use [mbridge](https://github.com/ISEEKYAN/mbridge.git) for conversion:
+### 2) Install APRIL
```bash
-cd slime/
-PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
- --hf-checkpoint /root/GLM-Z1-9B-0414 \
- --save /root/GLM-Z1-9B-0414_torch_dist
+git clone https://github.com/RLsys-Foundation/APRIL.git
+cd APRIL
+pip install -e .
```
-⚠️ If you encounter an issue where slime cannot be found, please run `pip install -e .` in the slime directory.
-
-#### Megatron torch\_dist → HF ckpt
+### 3) Run an Example
-To convert a `torch_dist` checkpoint saved during training back to a Hugging Face checkpoint:
+All scripts are in the `scripts/partial_rollout/` directory.
```bash
-cd slime/
-PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \
- --input-dir /path/to/torch_dist_ckpt/iter_xxx/ \
- --output-dir /root/GLM-Z1-9B-0414-iter_xxx \
- --origin-hf-dir /root/GLM-Z1-9B-0414
+bash scripts/partial_rollout/qwen/grpo/run-qwen3-4B-dapo-partial.sh
```
+### 4) Parameter Details
-⚠️ Since the `torch_dist` checkpoint converted by mbridge does not currently save args, you cannot convert the checkpoint from the previous step back to HF format.
-
-#### Any Megatron ckpt → HF
-
-Applicable for custom save formats (e.g., `--ckpt-format torch`).
-
-The principle behind this conversion method is to reuse the function that updates parameters from Megatron to SGLang during training. This means reusing the training script and changing the original command from:
-
+The core functionality of partial rollout is controlled by the following parameters:
```bash
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": { ...}
- }' \
- -- python3 train.py \
- ... # Other training args
+# Enable the partial rollout feature
+# Set this parameter to enable the mechanism of stopping generation upon reaching the target count + recycling unfinished samples
+--partial-rollout
+
+# The batch size for sampling. This parameter controls the sampling granularity per round.
+# If this parameter > rollout_batch_size, over-sampling is performed.
+# If this parameter < rollout_batch_size, sampling will continue at this granularity until rollout_batch_size samples are collected.
+--over-sampling-batch-size 16
```
+For other parameters, please refer to the arguments in [arguments.py](./slime/utils/arguments.py). For more details, you can consult the original [slime](https://github.com/THUDM/slime) repository.
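+
+For example, a launch command might combine these flags as follows (an illustrative sketch; the flag values are arbitrary and all other arguments are elided, see the scripts under `scripts/partial_rollout/` for complete, working configurations):
+
+```bash
+python3 train.py \
+  --rollout-batch-size 32 \
+  --over-sampling-batch-size 64 \
+  --partial-rollout \
+  ... # other Megatron/SGLang/slime arguments
+```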
+## Results and Comparison (Abridged)
-To:
-
-```bash
-torchrun --nproc_per_node ${NUM_GPU} tools/convert_to_hf.py \
- --load /your/saved/megatron_ckpt \
- --output-dir /your/converted/hf_ckpt \
- ... # Other training args
-```
+| Dataset | Model | Metric | APRIL vs. Baseline |
+|---------------|----------|------------------|-----------------------|
+| DAPO‑Math‑17k | Qwen3‑4B | Rollout Throughput | **+17%** |
+| DeepScaleR | Qwen3‑4B | Rollout Throughput | **+21%** |
+| DeepMath‑103K | Qwen3‑4B | Rollout Throughput | **+35%** |
-That is, keep all other arguments the same, and:
+
-1. Change the task launcher from `ray` to `torchrun`. Set the number of GPUs to the minimum required for Megatron's parallelism without data parallelism (DP). For example, if you are using `tp4`, set it to 4.
-2. Make sure to change `--load` to the path of the checkpoint you want to load.
-3. Add the `--output-dir` argument to specify where the converted Hugging Face checkpoint should be saved.
+## Frequently Asked Questions (FAQ)
-## Starting the Training Process
+- **Q: Will APRIL affect policy purity and convergence?**
+ - A: It will definitely have an impact on policy purity; the proportion of off-policy tokens in one round is about 40%. However, from both an engineering and experimental perspective, partial rollout has not introduced significant instability under the current settings. Further verification is needed for tasks with a much larger `max_response_length` (e.g., agent tasks, multi-turn tasks).
-The entire program needs to be launched using Ray. First, you need to start a Ray cluster. On node 0, run:
+- **Q: Are changes to the decoding kernel required?**
+ - A: No. APRIL operates at the **system scheduling layer** and does not conflict with inference acceleration techniques like speculative decoding or continuous batching. Instead, they are complementary and can be stacked.
-```bash
-# Node0 (HEAD)
-ray start --head --node-ip-address ${MASTER_ADDR} \
- --num-gpus 8 --disable-usage-stats
+## Directory Structure
-# Other Nodes
-ray start --address=${MASTER_ADDR}:6379 --num-gpus 8
```
+APRIL/
+├── scripts/
+│ └── partial_rollout/
+│ ├── deepseek/ # Experiment code for deepseek-r1-distill-1.5B
+│ └── qwen/ # Experiment code for qwen3-4B
+├── slime/
+│ ├── backends/
+│ ├── rollout/
+│ │ └── sglang_example.py # Core sampling code
+│ ├── ray/ # Core scheduling logic
+│ │ └── buffer.py # Buffer implementation code
+│ └── utils/
+└── tools/ # Megatron format conversion tools
-After the Ray cluster has started, you can submit a job from node 0, for example:
-
-```bash
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM/",
- ... # e.g., no_proxy, API variables, etc.
- }
- }' \
- -- python3 train.py \
- --... # Other Megatron/SGLang/slime arguments
```
+## Paper
-### Argument Descriptions
-
-Arguments are divided into three categories:
-
-1. **Megatron arguments**: slime reads all arguments set in Megatron via `PYTHONPATH`. You can configure Megatron by passing arguments like `--tensor-model-parallel-size 2`.
-2. **SGLang arguments**: All arguments for the installed SGLang are supported. These arguments must be prefixed with `--sglang-`. For example, `--mem-fraction-static` should be passed as `--sglang-mem-fraction-static`.
-3. **slime-specific arguments**: Please refer to: [slime/utils/arguments.py](slime/utils/arguments.py)
-
-For complete usage instructions, please refer to the [Usage Documentation](docs/en/usage.md).
-
-## Developer Guide
-
- - **Contributions are welcome\!** If you have suggestions for new features, performance tuning, or feedback on user experience, feel free to submit an Issue or PR 😊
-
- - Use [pre-commit](https://pre-commit.com/) to ensure code style consistency for your commits:
-
- ```bash
- apt install pre-commit -y
- pre-commit install
- ```
-
- - For debugging tips, please refer to the [Debugging Guide](docs/en/debug.md)
-
-## Hardware Support
-- Nvidia: refer to this repo README
-- AMD: refer to the [tutorial](docs/en/amd_tutorial.md)
-
-## FAQ & Acknowledgements
-
- - For frequently asked questions, please see the [Q\&A](docs/en/qa.md)
- - Special thanks to the following projects & communities: SGLang, Megatron‑LM, mbridge, OpenRLHF, veRL, and others.
+(TODO: arXiv link for the paper)
\ No newline at end of file
diff --git a/README_zh.md b/README_zh.md
index 111ffa0..9bbb80f 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,199 +1,86 @@
-# slime
+# APRIL: Active Partial Rollouts in Reinforcement Learning to Tame Long-tail Generation
+## About
+### Background: Why the sampling-training loop of synchronous RL is dragged down by the "long tail"
+In on-policy RLHF/GRPO training, the system enters the update phase only after collecting **N** rollout samples in a "round." Because generated samples vary widely in length, the system must wait for a few **long-tail samples** to finish before training can start, which leaves GPUs underutilized and lowers throughput in the later stages of the rollout phase.
-[English](./README.md)
+### What We Did: Active Partial Rollout (APRIL)
+**Core Idea**: In each round, we **over-sample** (N' > N) and **actively interrupt** the remaining in-progress requests once the target of **N** completed samples is reached. The **unfinished responses** are stored in a **buffer** and are **prioritized for continued rollout** in the next round, thereby mitigating the efficiency degradation caused by long-tail requests.
-**slime** 是为 RL scaling 设计的 LLM post‑training 框架,提供两大核心能力:
+### Highlights
+- **Over-sampling**: Assuming the training phase requires `rollout_batch_size=32` complete samples per round, we actually launch a larger sampling request, e.g., `over_sampling_batch_size=64`.
-1. **高性能训练**:通过连接 Megatron 与 SGLang,支持各种模式的高效训练;
-2. **灵活的数据生成**:通过自定义数据生成接口以及 server based engine,实现任意的数据训练数据生成流程。
+- **Stop upon collection**: As soon as the number of collected complete sample groups reaches `rollout_batch_size`, an `abort` signal is immediately sent to the SGLang router.
-## 目录
+- **Collect and reuse**: Upon receiving the `abort` signal, SGLang stops the ongoing generation tasks and returns the partially generated portions (half-completed trajectories). These partial samples are not discarded but stored in a buffer; when the next rollout round begins, they continue generating from where they left off alongside new prompts, achieving seamless reuse across iteration steps.
-- [架构总览](#架构总览)
-- [快速开始](#快速开始)
- - [环境准备](#环境准备)
- - [示例](#示例)
- - [Dense 模型示例:GLM-4-9B 与 Qwen3-4B](#Dense-模型示例GLM-4-9B-与-Qwen3-4B)
- - [MoE 模型示例:Qwen3-30B-A3B 与 DeepSeek-R1](#MoE-模型示例Qwen3-30B-A3B-与-DeepSeek-R1)
- - [多轮对话 + 工具调用示例:Search-R1 lite](#多轮对话--工具调用示例Search-R1-lite)
- - [SFT 示例:Qwen3-4B-Base + OpenHermes-2.5](#SFT-示例Qwen3-4B-Base--OpenHermes-25)
-- [Checkpoint 格式转换](#checkpoint-格式转换)
-- [启动训练流程](#启动训练流程)
-- [参数说明](#参数说明)
-- [开发指南](#开发指南)
-- [常见 Q&A 与致谢](#常见-qa-与致谢)
+- **Elegant implementation**: slime's partial rollout provides a native, lightweight optimization that is minimally intrusive to the original pipeline. You can enable it out of the box simply by setting the `--partial-rollout` flag and specifying `--over-sampling-batch-size`.
-## 架构总览
+## Getting Started
+### 1) Environment Setup (Requires an AMD GPU)
+**Start docker**
-
+```bash
-**模块说明**:
-
-- **training (Megatron)**:负责主训练流程,从 Data Buffer 读取数据,训练完后将参数同步至 rollout 模块;
-- **rollout (SGLang + router)**:生成新数据(含 reward/verifier),存储至 Data Buffer;
-- **data buffer**:桥梁模块,管理 prompt 初始化、自定义数据与 rollout 生成方法。
-
-## 快速开始
-
-### 环境准备
-
-基于镜像 zhuzilin/slime:latest(已预装 SGLang 0.4.7 和 Megatron):
-
-```bash
docker run --rm --gpus all --ipc=host --shm-size=16g \
--ulimit memlock=-1 --ulimit stack=67108864 \
- -it zhuzilin/slime:latest /bin/bash
+ -it rlsys/slime:slime_ubuntu22.04_rocm6.3.4-patch-numa-patch_sglang0.4.9_megatron-patch_ray2.47.1_apex_torch-memory-saver0.0.8-patch-vim /bin/bash
+```
+### 2) Install APRIL
+```bash
-git clone https://github.com/THUDM/slime.git
-cd slime
+git clone https://github.com/RLsys-Foundation/APRIL.git
+cd APRIL
pip install -e .
 ```
-
-- 对于不方便使用 docker 的场景,请参考 [从零搭建环境](./docs/zh/build.md);
-- 对于 AMD 支持,请参考 [AMD 使用教程](./docs/en/amd_tutorial.md)。
-
-### 示例
-
-#### Dense 模型示例:GLM-4-9B 与 Qwen3-4B
-
-我们提供了 [GLM-4-9B](https://huggingface.co/THUDM/GLM-Z1-9B-0414) 和 [Qwen3-4B](https://huggingface.co/Qwen/Qwen3-4B) 的使用示例,可以通过他们对 slime 的使用方法有个基本的了解:
-
-- [示例:GLM-4-9B](docs/zh/models/glm4-9B.md)
-- [示例:Qwen3-4B](docs/zh/models/qwen3-4B.md)
-
-#### MoE 模型示例:Qwen3-30B-A3B 与 DeepSeek-R1
-
-我们也提供了 MoE 模型的示例,请查看:
-
-- [示例:Qwen3-30B-A3B](docs/zh/models/qwen3-30B-A3B.md)
-- [示例:128xH100 训练 DeepSeek-R1](docs/zh/models/deepseek-r1.md)
-
-#### 多轮对话 + 工具调用示例:Search-R1 lite
-
-针对多轮对话和工具调用场景,我们提供了一个简化版的 Search-R1 复现,请查看:
-
-- [示例:Search-R1 lite](examples/search-r1/README_zh.md)
-
-#### SFT 示例:Qwen3-4B-Base + OpenHermes-2.5
-
-slime is not just a RL framework, we support a diverse set of post-training setups. For an SFT example, please refer to:
-
-slime 不仅仅是一个 RL 框架,我们还支持了各种后训练流程。如果想使用 SFT,请参看:
-
-- [示例: Qwen3-4B-Base + OpenHermes-2.5](docs/zh/sft.md).
-
-### Checkpoint 格式转换
-
-由于 slime 使用 megatron,而 megatron 不支持加载 huggingface checkpoint,我们需要将模型转换至 megatron 可以支持的 torch_dist 格式。
-
-#### HF → Megatron torch_dist ckpt
-
-我们推荐使用 [Pai-Megatron-Patch](https://github.com/alibaba/Pai-Megatron-Patch) 进行转换。如果你目前在使用的模型不被 Pai-Megatron-Patch 支持,可以使用 [mbridge](https://github.com/ISEEKYAN/mbridge.git) 转换:
-
-```bash
-cd slime/
-PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
- --hf-checkpoint /root/GLM-Z1-9B-0414 \
- --save /root/GLM-Z1-9B-0414_torch_dist
-```
-
-⚠️ 如果出现找不到 slime 的问题,请在 slime 目录下 `pip install -e .`。
-
-#### Megatron torch_dist → HF ckpt
-
-将训练过程中的存储的 torch_dist ckpt 转为 hf ckpt:
-
-```bash
-cd slime/
-PYTHONPATH=/root/Megatron-LM python tools/convert_torch_dist_to_hf.py \
- --input-dir /path/to/torch_dist_ckpt/iter_xxx/ \
- --output-dir /root/GLM-Z1-9B-0414-iter_xxx \
- --origin-hf-dir /root/GLM-Z1-9B-0414
-```
-
-⚠️ 由于 mbridge 转换的 torch_dist ckpt 目前不保存 args,不能基于上一步的 torch_dist ckpt 反转回 HF。
-
-#### 任意 Megatron ckpt → HF
-
-适用于自定义保存格式(如 `--ckpt-format torch`)。
-
-转化方式的原理是直接复用训练中,从 megatron 向 sglang 更新参数的函数,也就是直接复用一下训练脚本,将原先的:
-
-```bash
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": { ...}
- }' \
- -- python3 train.py \
- ... # 其他训练 args
-```
-
-改成:
-
-```bash
-torchrun --nproc_per_node ${NUM_GPU} tools/convert_to_hf.py \
- --load /your/saved/megatron_ckpt \
- --output-dir /your/converted/hf_ckpt \
- ... # 其他训练 args
-```
-
-即,保持所有的参数不变,将:
-
-1. 任务启动从 ray 变成 torchrun,把 gpu 数量保存为 megatron 并行的不带 dp 的最小 gpu 数,例如如果是 tp4,就设成 4;
-2. 确认把 `--load` 改成了需要 load 的路径;
-3. 增加 `--output-dir` 对应要保存的 hf_ckpt。
-
-## 启动训练流程
-
-整个程序需要使用 ray 进行启动,首先需要启动一个 ray 集群,即在 node 0 运行:
-
-```bash
-# Node0(HEAD)
-ray start --head --node-ip-address ${MASTER_ADDR} \
- --num-gpus 8 --disable-usage-stats
-
-# 其他 Node
-ray start --address=${MASTER_ADDR}:6379 --num-gpus 8
-```
-
-在 ray 集群启动后,可以在 node 0 提交任务,例如:
-
-```bash
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM/",
- ... # e.g. no_proxy、接口变量等
- }
- }' \
- -- python3 train.py \
- --...(其他 Megatron/SGLang/slime 参数)
-```
-
-#### 参数说明
-
-参数分为三类:
-
-1. **megatron 参数**:slime 会读取 `PYTHONPATH` 中的 megatron 里设置的所有参数,可以通过传入如 `--tensor-model-parallel-size 2` 的方式配置 megatron;
-2. **sglang 参数**:支持环境中安装的 sglang 的所有参数,这些参数需要以 `--sglang` 起始,例如 `--mem-fraction-static` 需要通过 `--sglang-mem-fraction-static` 传入。
-3. **slime 自身的参数**:请见:[slime/utils/arguments.py](slime/utils/arguments.py)
-
-完整使用说明请查阅 [使用文档](docs/zh/usage.md)。
-
-## 开发指南
-
-- **欢迎贡献!** 若有功能建议、性能调优或使用体验反馈,欢迎提交 Issue / PR 😊
-
-- 使用 [pre-commit](https://pre-commit.com/) 保证提交代码风格:
-
- ```bash
- apt install pre-commit -y
- pre-commit install
- ```
-
-- 调试技巧请参考 [debug 指南](docs/zh/debug.md)
-
-## 常见 Q&A 与致谢
-
-- 常见问题请见 [Q&A](docs/zh/qa.md)
-- 特别感谢以下项目 & 社区:SGLang、Megatron‑LM、mbridge、OpenRLHF、veRL 等。
+### 3) Run an Example
+All scripts are in the `scripts/partial_rollout/` directory.
+```bash
+bash scripts/partial_rollout/qwen/grpo/run-qwen3-4B-dapo-partial.sh
+```
+### 4) Parameter Details
+The core functionality of partial rollout is controlled by the following parameters:
+```bash
+# Enable the partial rollout feature
+# Set this parameter to enable the mechanism of stopping generation upon reaching the target count + recycling unfinished samples
+--partial-rollout
+
+# The batch size for sampling. This parameter controls the sampling granularity per round.
+# If this parameter > rollout_batch_size, over-sampling is performed.
+# If this parameter < rollout_batch_size, sampling will continue at this granularity until rollout_batch_size samples are collected.
+--over-sampling-batch-size 16
+```
+For other parameters, please refer to the arguments in [arguments.py](./slime/utils/arguments.py). For more details, you can consult the original [slime](https://github.com/THUDM/slime) repository.
+
+## Results and Comparison (Abridged)
+| Dataset | Model | Metric | APRIL vs. Baseline |
+|---------------|----------|--------------------|--------------------|
+| DAPO‑Math‑17k | Qwen3‑4B | Rollout Throughput | **+17%** |
+| DeepScaleR | Qwen3‑4B | Rollout Throughput | **+21%** |
+| DeepMath‑103K | Qwen3‑4B | Rollout Throughput | **+35%** |
+
+## Frequently Asked Questions (FAQ)
+- **Q: Will APRIL affect policy purity and convergence?**
+  - A: It will definitely have an impact on policy purity; the proportion of off-policy tokens in one round is about 40%. However, from both an engineering and experimental perspective, partial rollout has not introduced significant instability under the current settings. Further verification is needed for tasks with a much larger `max_response_length` (e.g., agent tasks, multi-turn tasks).
+
+- **Q: Are changes to the decoding kernel required?**
+  - A: No. APRIL operates at the **system scheduling layer** and does not conflict with inference acceleration techniques like speculative decoding or continuous batching. Instead, they are complementary and can be stacked.
+
+## Directory Structure
+```
+APRIL/
+├── scripts/
+│ └── partial_rollout/
+│ ├── deepseek/ # Experiment code for deepseek-r1-distill-1.5B
+│ └── qwen/ # Experiment code for qwen3-4B
+├── slime/
+│ ├── backends/
+│ ├── rollout/
+│ │ └── sglang_example.py # Core sampling code
+│ ├── ray/ # Core scheduling logic
+│ │ └── buffer.py # Buffer implementation code
+│ └── utils/
+└── tools/                     # Megatron format conversion tools
+```
+
+## Paper
+(TODO: arXiv link for the paper)
\ No newline at end of file
diff --git a/docs/en/models/qwen3-4B.md b/docs/en/models/qwen3-4B.md
index 00b4dbd..d1489f0 100644
--- a/docs/en/models/qwen3-4B.md
+++ b/docs/en/models/qwen3-4B.md
@@ -309,9 +309,3 @@ In this case, 2 GPUs will be allocated for training, and 6 GPUs will be allocate
```bash
--sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
```
-
-### Asynchronous Training
-
-When you separate training and inference, you may notice that the training and inference GPUs are always waiting for each other. To prevent these resources from being idle, we can enable asynchronous training. This can be done by changing `train.py` to `train_async.py` in the startup script. By doing this, slime will generate data for the next rollout while training on the current one.
-
-The only difference between `train.py` and `train_async.py` lies in the synchronization logic of the training loop. We achieve this by using Ray's asynchronous features (`.remote`, `ray.get`).
diff --git a/docs/en/sft.md b/docs/en/sft.md
deleted file mode 100644
index bead865..0000000
--- a/docs/en/sft.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Example: Qwen3-4B-Base with OpenHermes-2.5
-
-[中文版](../zh/sft.md)
-
-## Environment Preparation
-
-First, we need to create a mirror environment and convert the `Qwen3-4B-Base` model by following the [Example: Qwen3-4B Model](./models/qwen3-4B.md).
-
-After that, we will process the SFT data. Here, we use the classic [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) as an example. First, we process the data into a format suitable for `slime` to load. You can use the following script to add a column that conforms to the OpenAI message format and save it to `/root/openhermes2_5.parquet`.
-
-```python
-from datasets import load_dataset
-
-ds = load_dataset("teknium/OpenHermes-2.5")["train"]
-
-def convert(sample):
- conversations = sample["conversations"]
-
- def convert_role(role):
- if role == "human":
- return "user"
- elif role == "gpt":
- return "assistant"
- elif role == "system":
- return "system"
- else:
- raise ValueError(f"Unknown role: {role}")
-
- messages = [
- {
- "role": convert_role(turn["from"]),
- "content": turn["value"],
- }
- for turn in conversations
- ]
-
- return {"messages": messages}
-
-ds = ds.map(convert)
-ds.to_parquet("/root/openhermes2_5.parquet")
-```
-
-## Execute Training
-
-Execute the training:
-
-```bash
-cd /root/slime
-bash script/run-qwen3-4B-base-sft.sh
-```
-
-### Parameter Introduction
-
-You can compare [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B.sh) with [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh). You will find that besides changing the model from the instruct version to the base model, the main adjustments are as follows:
-
-1. Removed `SGLANG_ARGS` and `GRPO_ARGS`. This is because it is not necessary to start SGLang or configure GRPO-related settings during the SFT process.
-
-2. Renamed `ROLLOUT_ARGS` to `SFT_ARGS` and configured it as follows:
-
- ```bash
- SFT_ARGS=(
- --rollout-function-path slime.rollout.sft_example.generate_rollout
- --prompt-data /root/openhermes2_5.parquet
- --input-key messages
- --rollout-shuffle
- --num-epoch 3
- --rollout-batch-size 128
- --global-batch-size 128
-
- --loss-type sft_loss
- --calculate-per-token-loss
- --disable-compute-advantages-and-returns
- --debug-train-only
- )
- ```
-
- SFT actually reuses the custom rollout functionality of slime. By using `--rollout-function-path`, the data generation part is switched from the RL rollout that uses `sglang` to the SFT version that reads data from a file, which is `slime.rollout.sft_example.generate_rollout`.
-
- For SFT, it is recommended to set `rollout_batch_size` and `global_batch_size` to the same value and not to configure `n_samples_per_prompt`. This is equivalent to training one batch right after reading one batch.
-
- `slime` also supports different loss types, and we configure the SFT loss using `--loss-type sft_loss`.
-
- As for `--calculate-per-token-loss`, this is because `slime` defaults to calculating the per-sample mean for GRPO. In general SFT training, the average is taken over all unmasked tokens in a batch, so it is recommended to configure this.
-
- Finally, `--disable-compute-advantages-and-returns` indicates that there is no need to pre-calculate log probabilities during the SFT process, and `--debug-train-only` means that `sglang` does not need to be initialized.
-
-3. Used `train_async.py` instead of `train.py`. This is to leverage the asynchronous training process to implement data prefetching.
diff --git a/docs/zh/models/qwen3-4B.md b/docs/zh/models/qwen3-4B.md
index f2802a7..e99348f 100644
--- a/docs/zh/models/qwen3-4B.md
+++ b/docs/zh/models/qwen3-4B.md
@@ -309,11 +309,3 @@ ray job submit ... \
```bash
--sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
```
-
-### 异步训练
-
-当进行训推分离时,你会发现训练和推理的 GPU 总是相互等待着,为了避免这种资源空闲,我们可以开启异步训练。开启的方式即为将启动脚本中的 `train.py` 改变为 `train_async.py`。这样 slime 就会在进行当前 rollout 的训练时进行下一个 rollout 的数据生成了。
-
-`train.py` 和 `train_async.py` 的差别只在于 train loop 的同步逻辑,我们通过 ray 的异步(`.remote`, `ray.get`)实现了这点。
-
-⚠️ 在异步训练时,sglang 的性能检测日志与训练日志可能会混到一起,不易区分,可以通过 `--sglang-log-level` 来减少 sglang 的日志。
\ No newline at end of file
diff --git a/docs/zh/sft.md b/docs/zh/sft.md
deleted file mode 100644
index ff58882..0000000
--- a/docs/zh/sft.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# 示例:Qwen3-4B-Base + OpenHermes-2.5
-
-[English](../en/sft.md)
-
-## 环境准备
-
-首先需要我们仿照 [示例:Qwen3-4B 模型](./models/qwen3-4B.md) 创建镜像环境与转换 `Qwen3-4B-Base` 模型。
-
-之后,我们处理 sft 数据。这里我们以经典的 [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5) 为例,首先把数据处理成适合 slime 加载的格式,可以用如下的脚本进行处理,增加一个符合 openai message 格式的列,并保存在 `/root/openhermes2_5.parquet`。
-
-```python
-from datasets import load_dataset
-
-ds = load_dataset("teknium/OpenHermes-2.5")["train"]
-
-def convert(sample):
- conversations = sample["conversations"]
-
- def convert_role(role):
- if role == "human":
- return "user"
- elif role == "gpt":
- return "assistant"
- elif role == "system":
- return "system"
- else:
- raise ValueError(f"Unknown role: {role}")
-
- messages = [
- {
- "role": convert_role(turn["from"]),
- "content": turn["value"],
- }
- for turn in conversations
- ]
-
- return {"messages": messages}
-
-ds = ds.map(convert)
-ds.to_parquet("/root/openhermes2_5.parquet")
-```
-
-## 执行训练
-
-执行训练:
-
-```bash
-cd /root/slime
-bash script/run-qwen3-4B-base-sft.sh
-```
-
-### 参数简介
-
-可以将 [run-qwen3-4B-base-sft.sh](../../scripts/run-qwen3-4B-base-sft.sh) 与 [run-qwen3-4B.sh](../../scripts/run-qwen3-4B.sh) 进行对比。会发现除了我们将模型由 instruct 模型换为了 base 模型之外,主要进行了如下的几个调整:
-
-1. 移除了 `SGLANG_ARGS` 和 `GRPO_ARGS`。这是因为 sft 的过程中不需要启动 sglang 或者做 grpo 相关的配置;
-
-2. 将 `ROLLOUT_ARGS` 改名为了 `SFT_ARGS`,并配置为:
-
- ```bash
- SFT_ARGS=(
- --rollout-function-path slime.rollout.sft_example.generate_rollout
- --prompt-data /root/openhermes2_5.parquet
- --input-key messages
- --rollout-shuffle
- --num-epoch 3
- --rollout-batch-size 128
- --global-batch-size 128
-
- --loss-type sft_loss
- --calculate-per-token-loss
- --disable-compute-advantages-and-returns
- --debug-train-only
- )
- ```
-
- slime 中的 sft 实际上是复用了 slime 的 custom rollout 功能,通过 `--rollout-function-path` 将数据生成部分从使用 sglang 的 RL rollout,切换成了从文件中读取数据的 sft 版本,即 `slime.rollout.sft_example.generate_rollout`。
-
- 对于 sft 来说,建议将 `rollout_batch_size` 与 `global_batch_size` 设置成相同的,并不要配置 `n_samples_per_prompt`,这样相当于是读一个 batch 就训一个 batch。
-
- slime 还支持不同的 loss 类型,我们就是通过 `--loss-type sft_loss` 配置上 sft loss 的。
-
- 至于 `--calculate-per-token-loss`,这是因为 slime 默认是以 GRPO 的 per sample mean 进行计算的,而一般 sft 训练都是按一个 batch 的所有不被 mask 的 token 取平均,所以建议配置上。
-
- 最后 `--disable-compute-advantages-and-returns` 表示 sft 的过程中不需要预先计算 log prob,`--debug-train-only` 表示不需要初始化 sglang。
-
-3. 使用了 `train_async.py` 而不是 `train.py`。这是为了利用异步训练的流程,来实现数据 prefetch。
diff --git a/examples/search-r1/README.md b/examples/search-r1/README.md
deleted file mode 100644
index 53f21a7..0000000
--- a/examples/search-r1/README.md
+++ /dev/null
@@ -1,75 +0,0 @@
-# Example: Search-R1 lite
-
-[中文版](./README_zh.md)
-
-This is a minimal reproduction of [Search-R1](https://github.com/PeterGriffinJin/Search-R1) and an example of using multi-turn conversation and tool-calling in slime.
-
-## Environment Setup
-
-Use the `zhuzilin/slime:latest` image and initialize the environment required for Search-R1:
-
-```bash
-cd /root/
-git clone https://github.com/THUDM/slime.git
-pip install -e .
-# for Search R1
-pip install chardet
-```
-
-Please refer to the script provided in Search-R1 to download the data:
-
-```bash
-git clone https://github.com/PeterGriffinJin/Search-R1.git
-cd Search-R1/
-python scripts/data_process/nq_search.py --local_dir /root/nq_search/
-```
-
-Initialize the Qwen2.5-3B model:
-
-```bash
-# hf checkpoint
-huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B
-
-# mcore checkpoint
-cd /root/slime
-PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
- --hf-checkpoint /root/Qwen2.5-3B \
- --save /root/Qwen2.5-3B_torch_dist
-```
-
-## Running the Script
-
-You need to configure your serper.dev API in `generate_with_search.py`:
-
-```python
-SEARCH_R1_CONFIGS = {
- "max_turns": 3,
- "topk": 3,
- "google_api_key": "YOUR_API_KEY", # Replace with your actual API key
- "snippet_only": True, # Set to True to only return snippets
- "proxy": None, # Set to your proxy if needed
- "search_concurrency": 256,
- # rm
- "format_score": 0.2,
-}
-```
-
-And run:
-
-```bash
-cd slime/
-bash examples/search-r1/run_qwen2.5_3B.sh
-```
-
-## Code Structure
-
-To implement multi-turn conversation + tool-calling in slime, you only need to implement a custom data generation function and a reward model for the task. These correspond to the following 2 configuration items in the startup script:
-
-```bash
-CUSTOM_ARGS=(
- --custom-generate-function-path generate_with_search.generate
- --custom-rm-path generate_with_search.reward_func
-)
-```
-
-These are the `generate` and `reward_func` functions in `generate_with_search.py`.
diff --git a/examples/search-r1/README_zh.md b/examples/search-r1/README_zh.md
deleted file mode 100644
index cbf13b5..0000000
--- a/examples/search-r1/README_zh.md
+++ /dev/null
@@ -1,75 +0,0 @@
-# 示例:Search-R1 lite
-
-[English](./README.md)
-
-这里是一个对 [Search-R1](https://github.com/PeterGriffinJin/Search-R1) 的简单复现,以及是一个在 slime 中使用多轮对话和工具调用的样例。
-
-## 配置环境
-
-使用 `zhuzilin/slime:latest` 镜像,并初始化 Search-R1 需要的环境:
-
-```bash
-cd /root/
-git clone https://github.com/THUDM/slime.git
-pip install -e .
-# for Search R1
-pip install chardet
-```
-
-请参照 Search-R1 中提供的脚本下载数据:
-
-```bash
-git clone https://github.com/PeterGriffinJin/Search-R1.git
-cd Search-R1/
-python scripts/data_process/nq_search.py --local_dir /root/nq_search/
-```
-
-初始化 Qwen2.5-3B 模型:
-
-```bash
-# hf checkpoint
-huggingface-cli download Qwen/Qwen2.5-3B --local-dir /root/Qwen2.5-3B
-
-# mcore checkpoint
-cd /root/slime
-PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \
- --hf-checkpoint /root/Qwen2.5-3B \
- --save /root/Qwen2.5-3B_torch_dist
-```
-
-## 运行脚本
-
-需要将你的 serper.dev API 配置在 `generate_with_search.py` 中:
-
-```python
-SEARCH_R1_CONFIGS = {
- "max_turns": 3,
- "topk": 3,
- "google_api_key": "YOUR_API_KEY", # Replace with your actual API key
- "snippet_only": True, # Set to True to only return snippets
- "proxy": None, # Set to your proxy if needed
- "search_concurrency": 256,
- # rm
- "format_score": 0.2,
-}
-```
-
-并运行:
-
-```bash
-cd slime/
-bash examples/search-r1/run_qwen2.5_3B.sh
-```
-
-## 代码结构
-
-为了实现多轮 + 工具调用,在 slime 中只需要实现一个自定义的数据生成函数,以及一个任务所需的 reward model,对应启动脚本中的这 2 个配置项:
-
-```bash
-CUSTOM_ARGS=(
- --custom-generate-function-path generate_with_search.generate
- --custom-rm-path generate_with_search.reward_func
-)
-```
-
-也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。
diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py
deleted file mode 100644
index bbe9760..0000000
--- a/examples/search-r1/generate_with_search.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Adapted form https://github.com/PeterGriffinJin/Search-R1/blob/ceee7b89655ed52f205b9beb98e1190c3eedcfb0/search_r1/llm_agent/generation.py
-import asyncio
-import re
-
-from google_search_server import google_search
-from qa_em_format import compute_score_em
-
-from slime.rollout.sglang_example import GenerateState
-from slime.utils.http_utils import post
-from slime.utils.types import Sample
-
-SEARCH_R1_CONFIGS = {
- "max_turns": 3,
- "topk": 3,
- "google_api_key": "YOUR_API_KEY", # Replace with your actual API key
- "snippet_only": True, # Set to True to only return snippets
- "proxy": None, # Set to your proxy if needed
- "search_concurrency": 256,
- # rm
- "format_score": 0.2,
-}
-
-
-SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"])
-
-
-def _passages2string(retrieval_result):
- format_reference = ""
- for idx, doc_item in enumerate(retrieval_result):
-
- content = doc_item["document"]["contents"]
- title = content.split("\n")[0]
- text = "\n".join(content.split("\n")[1:])
- format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
-
- return format_reference
-
-
-async def search(query: str) -> str:
- result = await google_search(
- SEARCH_R1_CONFIGS["google_api_key"],
- query,
- SEARCH_R1_CONFIGS["topk"],
- snippet_only=SEARCH_R1_CONFIGS["snippet_only"],
- proxy=SEARCH_R1_CONFIGS["proxy"],
- )
- return _passages2string(result)
-
-
-def postprocess_responses(resp: str) -> str:
- return (
- resp.split("")[0] + ""
- if "" in resp
- else resp.split("")[0] + "" if "" in resp else resp
- )
-
-
-def postprocess_predictions(prediction: str):
- pattern = r"<(search|answer)>(.*?)\1>"
- match = re.search(pattern, prediction, re.DOTALL)
- if match:
- content = match.group(2).strip() # Return only the content inside the tags
- action = match.group(1)
- else:
- content = ""
- action = None
-
- return action, content
-
-
-async def execute_predictions(prediction: str) -> str:
- action, content = postprocess_predictions(prediction)
-
- if action == "search":
- search_query = content
- async with SEMAPHORE:
- search_results = await search(search_query)
- next_obs = f"\n\n{search_results.strip()}\n\n"
- done = False
- elif action == "answer":
- next_obs = ""
- done = True
- else:
- next_obs = f"\nMy previous action is invalid. \
-If I want to search, I should put the query between and . \
-If I want to give the final answer, I should put the answer between and . Let me try again.\n"
- done = False
-
- return next_obs, done
-
-
-async def generate(args, sample: Sample, sampling_params) -> Sample:
- assert not args.partial_rollout, f"Partial rollout is not supported for this function at the moment."
-
- state = GenerateState(args)
-
- url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
-
- # Handle partial rollout samples: continue generation from existing response
- prompt = sample.prompt
- prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
- response = ""
- response_token_ids = []
- loss_masks = []
- for _ in range(SEARCH_R1_CONFIGS["max_turns"]):
- payload = {
- "text": prompt + response,
- "sampling_params": sampling_params,
- }
- output = await post(url, payload, use_http2=args.use_http2)
-
- # abort
- if output["meta_info"]["finish_reason"]["type"] == "abort":
- sample.status = Sample.Status.ABORTED
- return sample
-
- cur_response = output["text"]
- cur_response = postprocess_responses(cur_response)
-
- cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
- response += cur_response
- response_token_ids += cur_response_token_ids
- loss_masks += [1] * len(cur_response_token_ids)
-
- if output["meta_info"]["finish_reason"]["type"] == "length":
- break
-
- next_obs, done = await execute_predictions(cur_response)
- if done:
- break
-
- assert next_obs != "", "Next observation should not be empty."
- obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
- response += next_obs
- response_token_ids += obs_tokens_ids
- loss_masks += [0] * len(obs_tokens_ids)
-
- sample.tokens = prompt_tokens_ids + response_token_ids
- sample.response_length = len(response_token_ids)
- sample.response = response
- sample.loss_masks = loss_masks
- match output["meta_info"]["finish_reason"]["type"]:
- case "length":
- sample.status = Sample.Status.TRUNCATED
- case "abort":
- sample.status = Sample.Status.ABORTED
- case "stop":
- sample.status = Sample.Status.COMPLETED
-
- return sample
-
-
-async def reward_func(args, sample, **kwargs):
- """The reward function for retrieval-based question answering.
-
- Args:
- args: the arguments
- sample: the sample to evaluate
- """
- if not isinstance(sample, Sample):
- raise TypeError("Sample must be an instance of Sample class.")
-
- score = compute_score_em(
- solution_str=sample.prompt + sample.response,
- ground_truth=sample.label["ground_truth"],
- format_score=SEARCH_R1_CONFIGS["format_score"],
- )
-
- return score
diff --git a/examples/search-r1/google_search_server.py b/examples/search-r1/google_search_server.py
deleted file mode 100644
index 6f19356..0000000
--- a/examples/search-r1/google_search_server.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import asyncio
-import os
-import random
-import re
-from typing import Dict, List
-
-import aiohttp
-import chardet
-
-
-# --- Utilities ---
-def parse_snippet(snippet: str) -> List[str]:
- segments = snippet.split("...")
- return [s.strip() for s in segments if len(s.strip().split()) > 5]
-
-
-def sanitize_search_query(query: str) -> str:
- # Remove or replace special characters that might cause issues.
- # This is a basic example; you might need to add more characters or patterns.
- sanitized_query = re.sub(r"[^\w\s]", " ", query) # Replace non-alphanumeric and non-whitespace with spaces.
- sanitized_query = re.sub(
- r"[\t\r\f\v\n]", " ", sanitized_query
- ) # replace tab, return, formfeed, vertical tab with spaces.
- sanitized_query = re.sub(
- r"\s+", " ", sanitized_query
- ).strip() # remove duplicate spaces, and trailing/leading spaces.
-
- return sanitized_query
-
-
-def filter_links(search_results: List[Dict]) -> List[str]:
- links = []
- for result in search_results:
- for item in result.get("items", []):
- if "mime" in item:
- continue
- ext = os.path.splitext(item["link"])[1]
- if ext in ["", ".html", ".htm", ".shtml"]:
- links.append(item["link"])
- return links
-
-
-async def fetch(session: aiohttp.ClientSession, url: str, semaphore: asyncio.Semaphore) -> str:
- if url == "":
- return ""
- user_agents = [
- "Mozilla/5.0 (Linux; Android 6.0.1; Nexus 5X Build/MMB29P)...",
- "Mozilla/5.0 AppleWebKit/537.36...",
- "Mozilla/5.0 (compatible; Googlebot/2.1; +https://www.google.com/bot.html)",
- ]
- headers = {"User-Agent": random.choice(user_agents)}
-
- async with semaphore:
- try:
- async with session.get(url, headers=headers) as response:
- raw = await response.read()
- detected = chardet.detect(raw)
- encoding = detected["encoding"] or "utf-8"
- return raw.decode(encoding, errors="ignore")
- except (aiohttp.ClientError, asyncio.TimeoutError):
- return ""
-
-
-async def fetch_all(urls: List[str], limit: int = 8) -> List[str]:
- semaphore = asyncio.Semaphore(limit)
- timeout = aiohttp.ClientTimeout(total=5)
- connector = aiohttp.TCPConnector(limit_per_host=limit, force_close=True)
-
- async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
- tasks = [fetch(session, url, semaphore) for url in urls]
- return await asyncio.gather(*tasks)
-
-
-def collect_context(snippet: str, doc: str) -> str:
- snippets = parse_snippet(snippet)
- ctx_paras = []
-
- for s in snippets:
- pos = doc.replace("\n", " ").find(s)
- if pos == -1:
- continue
- sta = pos
- while sta > 0 and doc[sta] != "\n":
- sta -= 1
- end = pos + len(s)
- while end < len(doc) and doc[end] != "\n":
- end += 1
- para = doc[sta:end].strip()
- if para not in ctx_paras:
- ctx_paras.append(para)
-
- return "\n".join(ctx_paras)
-
-
-async def google_search(api_key, query, top_k=5, timeout: int = 60, proxy=None, snippet_only=False) -> List[Dict]:
- timeout_obj = aiohttp.ClientTimeout(total=timeout)
- session_kwargs = {}
- if proxy:
- session_kwargs["proxy"] = proxy
- async with aiohttp.ClientSession(**session_kwargs) as session:
- async with session.post(
- "https://google.serper.dev/search",
- json={
- "q": query,
- "num": top_k,
- "gl": "us",
- "hl": "en",
- },
- headers={
- "Content-Type": "application/json",
- "X-API-KEY": api_key,
- },
- timeout=timeout_obj,
- ) as resp:
- resp.raise_for_status()
- response = await resp.json()
- items = response.get("organic", [])
-
- contexts = []
- if snippet_only:
- for item in items:
- title = item.get("title", "")
- context = " ".join(parse_snippet(item.get("snippet", "")))
- if title != "" or context != "":
- title = "No title." if not title else title
- context = "No snippet available." if not context else context
- contexts.append(
- {
- "document": {"contents": f'"{title}"\n{context}'},
- }
- )
- else:
- links = [item.get("link", "") for item in items if "link" in item]
- web_contents = await fetch_all(links)
- contexts = []
- for i, item in enumerate(items):
- title = item.get("title", "")
- snippet = item.get("snippet", "")
-
- context = collect_context(snippet, web_contents[i])
- if title != "" or context != "":
- title = "No title." if not title else title
- context = "No snippet available." if not context else context
- contexts.append(
- {
- "document": {"contents": f'"{title}"\n{context}'},
- }
- )
-
- return contexts
diff --git a/examples/search-r1/qa_em_format.py b/examples/search-r1/qa_em_format.py
deleted file mode 100644
index 7820168..0000000
--- a/examples/search-r1/qa_em_format.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# Adapt from https://github.com/PeterGriffinJin/Search-R1/blob/ceee7b89655ed52f205b9beb98e1190c3eedcfb0/verl/utils/reward_score/qa_em_format.py
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import random
-import re
-import string
-
-
-def normalize_answer(s):
- def remove_articles(text):
- return re.sub(r"\b(a|an|the)\b", " ", text)
-
- def white_space_fix(text):
- return " ".join(text.split())
-
- def remove_punc(text):
- exclude = set(string.punctuation)
- return "".join(ch for ch in text if ch not in exclude)
-
- def lower(text):
- return text.lower()
-
- return white_space_fix(remove_articles(remove_punc(lower(s))))
-
-
-def em_check(prediction, golden_answers):
- if isinstance(golden_answers, str):
- golden_answers = [golden_answers]
- normalized_prediction = normalize_answer(prediction)
- score = 0
- for golden_answer in golden_answers:
- golden_answer = normalize_answer(golden_answer)
- if golden_answer == normalized_prediction:
- score = 1
- break
- return score
-
-
-def is_valid_sequence(text):
- # Find the position of "<|im_start|>assistant" with potential whitespace
- assistant_pattern = r"<\|im_start\|>assistant\s*"
- assistant_match = re.search(assistant_pattern, text)
-
- if not assistant_match:
- return False, "Missing assistant marker"
-
- # Extract the content after the assistant marker
- start_pos = assistant_match.end()
- content = text[start_pos:]
-
- # Check for balanced tags
- tags_to_check = ["think", "search", "information", "answer"]
- for tag in tags_to_check:
- opening_count = len(re.findall(f"<{tag}>", content))
- closing_count = len(re.findall(f"{tag}>", content))
- if opening_count != closing_count:
- return False, f"Mismatch in {tag} tags: {opening_count} opening vs {closing_count} closing tags"
-
- # Now check for proper sequence pattern and no extraneous content
-
- # 1. First split the content by any tags we recognize
- split_pattern = r"(?(?:think|search|information|answer)>)"
- parts = re.split(split_pattern, content)
-
- # 2. Keep track of the current position in the expected sequence
- state = "start" # start -> think -> search -> information -> think -> ... -> answer -> end
-
- # 3. Check each part
- for i, part in enumerate(parts):
- # Skip empty parts
- if not part.strip():
- continue
-
- # Check if this is a tag
- if re.match(r"?(?:think|search|information|answer)>", part):
- # This is a tag, check if it's valid in the current state
- if part == "" and state in ["start", "information"]:
- state = "in_think"
- elif part == "" and state == "in_think":
- state = "after_think"
- elif part == "" and state == "after_think":
- state = "in_search"
- elif part == "" and state == "in_search":
- state = "after_search"
- elif part == "" and state == "after_search":
- state = "in_information"
- elif part == "" and state == "in_information":
- state = "information"
- elif part == "" and state == "after_think":
- state = "in_answer"
- elif part == "" and state == "in_answer":
- state = "end"
- else:
- return False, f"Unexpected tag {part} in state {state}"
- else:
- # This is content, check if it's valid in the current state
- if state in ["in_think", "in_search", "in_information", "in_answer"]:
- # Content is allowed inside tags
- pass
- elif state in ["start", "after_think", "after_search", "information"]:
- # Only whitespace is allowed between tags
- if part.strip():
- return False, f"Unexpected content '{part.strip()}' between tags (state: {state})"
- else:
- return False, f"Unexpected content in state {state}"
-
- # Check final state
- if state != "end":
- return False, f"Incomplete sequence, ended in state {state}"
-
- return True, "Valid sequence format"
-
-
-def extract_solution(solution_str):
- """Extract the equation from the solution string."""
-
- answer_pattern = r"(.*?)"
- match = re.finditer(answer_pattern, solution_str, re.DOTALL)
- matches = list(match)
-
- # If there are 0 or exactly 1 matches, return None
- if len(matches) <= 1:
- return None
-
- # If there are 2 or more matches, return the last one
- return matches[-1].group(1).strip()
-
-
-def extract_information_blocks(text: str) -> list[str]:
- pattern = r"(.*?)"
- matches = re.findall(pattern, text, re.DOTALL)
- return [match.strip() for match in matches]
-
-
-def is_retrieval_correct(text: str, golden_answers: list[str]) -> list[str]:
- seqs = extract_information_blocks(text)
- for seq in seqs:
- for golden_answer in golden_answers:
- if normalize_answer(golden_answer) in normalize_answer(seq):
- return True
- return False
-
-
-def compute_score_em(
- solution_str,
- ground_truth,
- method="strict",
- structure_format_score=0,
- final_format_score=0,
- retrieval_score=0,
- format_score=0,
- score=1.0,
-):
- """The scoring function for exact match (EM).
-
- Args:
- solution_str: the solution text
- ground_truth: the ground truth
- method: the method to extract the solution, choices are 'strict' and 'flexible'
- format_score: the score for the format
- score: the score for the correct answer
- """
- is_valid_format, _ = is_valid_sequence(solution_str)
- retrieval_correct = False
- if is_valid_format:
- retrieval_correct = is_retrieval_correct(solution_str, ground_truth["target"])
- answer = extract_solution(solution_str=solution_str)
- do_print = random.randint(1, 64) == 1
-
- if do_print:
- print(f"--------------------------------")
- print(f"Golden answers: {ground_truth['target']}")
- print(f"Extracted answer: {answer}")
- print(f"Solution string: {solution_str}")
-
- if answer is None:
- if is_valid_format:
- if retrieval_correct:
- return structure_format_score + retrieval_score # 0.3
- else:
- return structure_format_score # 0.2
- else:
- return 0
- else:
- if em_check(answer, ground_truth["target"]):
- if is_valid_format:
- return score # 1
- else:
- return score - structure_format_score # 0.8
- elif is_valid_format:
- if retrieval_correct:
- return structure_format_score + retrieval_score # 0.3
- else:
- return structure_format_score # 0.2
- else:
- return final_format_score # 0.1
diff --git a/examples/search-r1/run_qwen2.5_3B.sh b/examples/search-r1/run_qwen2.5_3B.sh
deleted file mode 100644
index 6cfd820..0000000
--- a/examples/search-r1/run_qwen2.5_3B.sh
+++ /dev/null
@@ -1,137 +0,0 @@
-#!/bin/bash
-
-# for rerun the task
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-
-set -ex
-
-# will prevent ray from buffering stdout/stderr
-export PYTHONBUFFERED=16
-
-SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
-source "${SCRIPT_DIR}/../../scripts/models/qwen2.5-3B.sh"
-
-CKPT_ARGS=(
- --hf-checkpoint /root/Qwen2.5-3B/
- --ref-load /root/Qwen2.5-3B_torch_dist/
- --load /root/Qwen2.5-3B_slime/
- --save /root/Qwen2.5-3B_slime/
- --save-interval 20
-)
-
-ROLLOUT_ARGS=(
- --prompt-data /root/nq_search/train.parquet
- --input-key prompt
- --label-key reward_model
- --apply-chat-template
- --rollout-shuffle
- --num-rollout 3000
- --rollout-batch-size 32
- --n-samples-per-prompt 8
- --rollout-max-response-len 512
- --rollout-temperature 0.8
-
- --global-batch-size 256
- --balance-data
-)
-
-PERF_ARGS=(
- --tensor-model-parallel-size 2
- --sequence-parallel
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --expert-model-parallel-size 1
- --expert-tensor-parallel-size 1
-
- --recompute-granularity full
- --recompute-method uniform
- --recompute-num-layers 1
-
- # --micro-batch-size 1
- --use-dynamic-batch-size
- --max-tokens-per-gpu 9216
-)
-
-GRPO_ARGS=(
- --advantage-estimator grpo
- --use-kl-loss
- --kl-loss-coef 0.00
- --kl-loss-type low_var_kl
- --entropy-coef 0.00
- --eps-clip 0.2
- --eps-clip-high 0.28
-)
-
-OPTIMIZER_ARGS=(
- --optimizer adam
- --lr 1e-6
- --lr-decay-style constant
- --weight-decay 0.1
- --adam-beta1 0.9
- --adam-beta2 0.98
-)
-
-WANDB_ARGS=(
- # --use-wandb
- # --wandb-project slime-dev
- # --wandb-group search-r1_qwen2.5-3B-test
- # --wandb-key ${WANDB_KEY}
-)
-
-SGLANG_ARGS=(
- --rollout-num-gpus-per-engine 2
- --sglang-mem-fraction-static 0.7
-)
-
-MISC_ARGS=(
- # default dropout in megatron is 0.1
- --attention-dropout 0.0
- --hidden-dropout 0.0
- # should be good for model performance
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
- # need to comment this when using model with MLA
- --attention-backend flash
-)
-
-CUSTOM_ARGS=(
- --custom-generate-function-path generate_with_search.generate
- --custom-rm-path generate_with_search.reward_func
-)
-
-# launch the master node of ray in container
-export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
-
-RUNTIME_ENV_JSON="{
- \"env_vars\": {
- \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
- \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
- }
-}"
-
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json="${RUNTIME_ENV_JSON}" \
- -- python3 train.py \
- --actor-num-nodes 1 \
- --actor-num-gpus-per-node 4 \
- --rollout-num-gpus 4 \
- --colocate \
- ${MODEL_ARGS[@]} \
- ${CKPT_ARGS[@]} \
- ${ROLLOUT_ARGS[@]} \
- ${OPTIMIZER_ARGS[@]} \
- ${GRPO_ARGS[@]} \
- ${DISTRIBUTED_ARGS[@]} \
- ${WANDB_ARGS[@]} \
- ${PERF_ARGS[@]} \
- ${SGLANG_ARGS[@]} \
- ${MISC_ARGS[@]} \
- ${CUSTOM_ARGS[@]}
diff --git a/imgs/eval_dapo_qwen.png b/imgs/eval_dapo_qwen.png
new file mode 100644
index 0000000..7caf4fd
Binary files /dev/null and b/imgs/eval_dapo_qwen.png differ
diff --git a/imgs/partial_scheduling.png b/imgs/partial_scheduling.png
new file mode 100644
index 0000000..f295c84
Binary files /dev/null and b/imgs/partial_scheduling.png differ
diff --git a/slime/backends/fsdp_utils/__init__.py b/slime/backends/fsdp_utils/__init__.py
deleted file mode 100644
index 2ab3cdc..0000000
--- a/slime/backends/fsdp_utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .actor import FSDPTrainRayActor
-
-__all__ = ["FSDPTrainRayActor"]
diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
deleted file mode 100644
index 5015b53..0000000
--- a/slime/backends/fsdp_utils/actor.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from slime.ray.ppo_actor import TrainRayActor
-
-
-class FSDPTrainRayActor(TrainRayActor):
- def init(self, args, role, with_ref=False):
- super().init(args, role, with_ref)
-
- raise NotImplementedError
-
- def sleep(self, tags):
- raise NotImplementedError
-
- def wake_up(self, tags):
- raise NotImplementedError
-
- def connect_rollout_engines(self, rollout_engines, rollout_engine_lock):
- raise NotImplementedError
-
- def set_data_buffer(self, data_buffer):
- raise NotImplementedError
-
- def train(self, rollout_id, with_data_fetching=True):
- raise NotImplementedError
-
- def eval(self, rollout_id):
- raise NotImplementedError
-
- def save_model(self, iteration, with_optimizer=True):
- raise NotImplementedError
-
- def update_weights(self):
- raise NotImplementedError
diff --git a/slime/rollout/sft_example.py b/slime/rollout/sft_example.py
deleted file mode 100644
index 2956ca9..0000000
--- a/slime/rollout/sft_example.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from transformers import AutoTokenizer
-
-from slime.utils.mask_utils import MultiTurnLossMaskGenerator
-
-__all__ = ["generate_rollout"]
-
-
-TOKENIZER = None
-MASK_GENERATOR = None
-
-
-def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
- """An example to implement the generate_rollout function for an rule based rm rollout generation.
-
- Args:
- args: the whole args
- rollout_id: int, the id of the rollout, used for deterministic data generation
- data_buffer: the data buffer to store the generated samples
- evaluation: bool, whether the rollout is for evaluation or not
-
- Returns:
- list[Sample]: a list of samples generated by the rollout
- """
- assert not evaluation
- assert args.rollout_global_dataset
-
- global TOKENIZER, MASK_GENERATOR
- if TOKENIZER is None:
- TOKENIZER = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
-
- if MASK_GENERATOR is None:
- MASK_GENERATOR = MultiTurnLossMaskGenerator(TOKENIZER, tokenizer_type=args.loss_mask_type)
-
- samples = data_buffer.get_samples(args.rollout_batch_size)
-
- for sample in samples:
- (sample,) = sample
- messages = sample.prompt
- token_ids, loss_mask = MASK_GENERATOR.get_loss_mask(messages)
- response_length = MASK_GENERATOR.get_response_lengths([loss_mask])[0]
-
- sample.tokens = token_ids
- sample.response_length = response_length
- sample.reward = 0
- sample.loss_mask = loss_mask[-response_length:]
-
- return samples
diff --git a/slime_plugins/rollout_buffer/README.md b/slime_plugins/rollout_buffer/README.md
deleted file mode 100644
index e85e68d..0000000
--- a/slime_plugins/rollout_buffer/README.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# Rollout Buffer
-
-## Overview
-
-Rollout Buffer is an independent component for asynchronous agent trajectory generation. Its main job is to generate agent trajectories through the OpenAI-compatible LLM server launched by slime training.
-
-### Workflow
-
-```
-slime Training Process ←─── HTTP API ───→ Rollout Buffer
- ↓ ↓
- LLM Server ←─────── HTTP Requests ─────── Agent Framework
- ↓ ↓
- Model Response ──────────────────────→ Trajectory Generation
-```
-
-Each distinct agent task should have its own independent Generator class responsible for producing trajectories for that type of task. Rollout Buffer automatically discovers and loads the different Generator types.
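-
-The buffer exposes plain HTTP endpoints, so any agent framework can feed it trajectories. As a hypothetical sketch (the field values are illustrative; the port matches the default in `buffer.py`), an agent-side write might look like:
-
-```python
-import requests
-
-# One trajectory sample; the buffer groups samples by instance_id.
-sample = {
-    "uid": "0001",
-    "instance_id": "task-42",
-    "messages": [
-        {"role": "user", "content": "What is 1 + 1?"},
-        {"role": "assistant", "content": "2"},
-    ],
-    "reward": 1.0,
-    "extra_info": {},
-}
-requests.post("http://127.0.0.1:8889/buffer/write", json=sample)
-```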
-
-## Quick Start
-
-### Basic Usage Process
-
-1. **Copy the template**: Copy `base_generator.py` as a starting point
-2. **Change the task type**: Set `TASK_TYPE` to your task name (it must not collide with any other Generator)
-3. **Implement the core function**: Implement the `run_rollout()` function
-4. **Optional customization**: Override the five optional functions as needed
-
-
-Generator files must end with `_generator.py` and be placed in the `generator/` directory:
-
-```
-generator/
-├── base_generator.py # Math task implementation (default template)
-└── your_task_generator.py # Your custom task
-```
-
-Each Generator file must define `TASK_TYPE` and `run_rollout()`.
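-
-A minimal custom generator might look like the following sketch. The `my_task` name and `my_reward` function are illustrative, not part of the plugin; `BaseGenerator` and `query_single_turn` are re-exported from `generator/__init__.py`:
-
-```python
-# generator/my_task_generator.py -- hypothetical example
-from slime_plugins.rollout_buffer.generator import BaseGenerator, query_single_turn
-
-TASK_TYPE = "my_task"  # must be unique across generators
-
-
-def my_reward(item):
-    # Illustrative rule: reward 1 if the assistant produced any response.
-    content = item["messages"][-1]["content"]
-    return 1 if content else 0
-
-
-def run_rollout(data: dict):
-    generator = BaseGenerator(
-        data["remote_engine_url"],
-        data["remote_buffer_url"],
-        num_repeat_per_sample=int(data["num_repeat_per_sample"]),
-        task_type=TASK_TYPE,
-    )
-    generator.entry(data["input_file"], query_single_turn, my_reward)
-```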
-
-In addition, Rollout Buffer provides several customizable functions to meet task-specific needs. If no custom implementation is provided, the system falls back to the default implementations (located in `slime_plugins/rollout_buffer/default_func.py`); a sketch of one such override follows below.
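-
-For example, a generator may override `is_valid_group` to require non-empty responses. This is a hypothetical sketch mirroring the default signature in `buffer.py`, not the shipped default:
-
-```python
-def is_valid_group(group_data, min_valid_group_size, task_type):
-    instance_id, samples = group_data
-    # Count only samples whose final assistant message is non-empty.
-    non_empty = [s for s in samples if (s["messages"][-1]["content"] or "").strip()]
-    return len(non_empty) >= min_valid_group_size
-```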
-
-### Example Script
-
-First, follow [Example: Qwen3-4B Model](../../docs/en/models/qwen3-4B.md) to set up the environment, download the data, and convert the model checkpoint. Then run the following scripts:
-```bash
-cd slime_plugins/rollout_buffer
-bash rollout_buffer_example.sh
-
-# In a different terminal
-python buffer.py
-```
diff --git a/slime_plugins/rollout_buffer/README_zh.md b/slime_plugins/rollout_buffer/README_zh.md
deleted file mode 100644
index cfa689f..0000000
--- a/slime_plugins/rollout_buffer/README_zh.md
+++ /dev/null
@@ -1,51 +0,0 @@
-# Rollout Buffer
-
-## Overview
-
-Rollout Buffer is an independent component that supports fully asynchronous agent training. Its main job is to generate agent trajectories through the OpenAI-compatible LLM server launched by slime training.
-
-### Workflow
-
-```
-slime Training Process ←─── HTTP API ───→ Rollout Buffer
-        ↓                                       ↓
-   LLM Server ←─────── HTTP Requests ─────── Agent Framework
-        ↓                                       ↓
-  Model Response ──────────────────────→ Trajectory Generation
-```
-
-Each distinct agent task should have its own independent Generator class responsible for producing trajectories for that type of task. Rollout Buffer automatically discovers and loads the different Generator types.
-
-## Quick Start
-
-### Basic Usage Process
-
-1. **Copy the template**: Copy `base_generator.py` as a starting point
-2. **Change the task type**: Set `TASK_TYPE` to your task name (it must not collide with any other Generator)
-3. **Implement the core function**: Implement the `run_rollout()` function
-4. **Optional customization**: Override the five optional functions as needed
-
-
-Generator files must end with `_generator.py` and be placed in the `generator/` directory:
-
-```
-generator/
-├── base_generator.py        # Math task implementation (default template)
-└── your_task_generator.py   # Your custom task
-```
-
-Each Generator file must define `TASK_TYPE` and `run_rollout()`.
-
-In addition, Rollout Buffer provides several customizable functions to meet task-specific needs. If no custom implementation is provided, the system falls back to the default implementations (located in `slime_plugins/rollout_buffer/default_func.py`).
-
-### Example Script
-
-First, follow [Example: Qwen3-4B Model](../../docs/zh/models/qwen3-4B.md) to set up the slime runtime environment, download the data, and convert the model checkpoint. Then run, in separate terminals:
-
-```bash
-cd slime_plugins/rollout_buffer
-bash rollout_buffer_example.sh
-
-# In a different terminal
-python buffer.py
-```
diff --git a/slime_plugins/rollout_buffer/buffer.py b/slime_plugins/rollout_buffer/buffer.py
deleted file mode 100644
index 2f2dfc7..0000000
--- a/slime_plugins/rollout_buffer/buffer.py
+++ /dev/null
@@ -1,340 +0,0 @@
-import copy
-import glob
-import importlib.util
-import json
-import pathlib
-import threading
-import time
-from typing import Any, Dict, Optional, List
-
-import uvicorn
-from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
-from pydantic import BaseModel
-
-app = FastAPI(title="Rollout Buffer Server", debug=True)
-
-
-def default_is_valid_group(group_data, min_valid_group_size, task_type):
- instance_id, samples = group_data
- return len(samples) >= min_valid_group_size
-
-
-def default_get_group_data_meta_info(temp_data: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
- """
- Default implementation for getting meta information about the temporary data
- collected between get_batch calls.
- """
- if not temp_data:
- return {
- "total_samples": 0,
- "num_groups": 0,
- "avg_group_size": 0,
- "avg_reward": 0,
- }
-
- meta_info = {"total_samples": 0, "num_groups": len(temp_data)}
-
- all_rewards = []
- # Calculate per-group statistics
- for instance_id, samples in temp_data.items():
- group_size = len(samples)
-        group_rewards = [s["reward"] for s in samples]  # collect rewards for global statistics
- meta_info["total_samples"] += group_size
- all_rewards.extend(group_rewards)
- # Calculate global statistics
- meta_info["avg_group_size"] = meta_info["total_samples"] / meta_info["num_groups"]
-
- if all_rewards:
- meta_info["avg_reward"] = sum(all_rewards) / len(all_rewards)
- else:
- meta_info["avg_reward"] = 0
- return meta_info
-
-
-def discover_generators():
- """
- Automatically discover generator modules in the generator directory.
- Returns a dictionary mapping task_type to module with run_rollout function.
- """
- generator_map = {}
- generator_dir = pathlib.Path(__file__).parent / "generator"
-
- # Find all files within generator_dir
- for file_path in glob.glob(str(generator_dir / "*.py")):
- if file_path.endswith("__init__.py"):
- continue
-
- try:
- # Load the module
- spec = importlib.util.spec_from_file_location("generator_module", file_path)
- if spec is None or spec.loader is None:
- print(f"Warning: Could not load spec for {file_path}")
- continue
-
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
-
- # Check if module has TASK_TYPE constant
- if not hasattr(module, "TASK_TYPE"):
- print(f"Warning: {file_path} does not define TASK_TYPE constant")
- continue
-
- # Check if module has run_rollout function
- if not hasattr(module, "run_rollout"):
- print(f"Warning: {file_path} does not define run_rollout function")
- continue
-
- task_type = getattr(module, "TASK_TYPE")
- generator_info = {
- "module": module,
- "file_path": file_path,
- "run_rollout": getattr(module, "run_rollout"),
- }
-
- # Check for optional functions and use defaults if not present
- for func_name in [
- "transform_group",
- "is_valid_group",
- "get_group_data_meta_info",
- ]:
- generator_info[func_name] = getattr(module, func_name, None)
-
- generator_map[task_type] = generator_info
- print(f"Discovered generator: {task_type} -> {file_path}")
-
- except Exception as e:
- print(f"Error loading generator from {file_path}: {str(e)}")
- continue
-
- return generator_map
-
-
-@app.middleware("http")
-async def set_body_size(request: Request, call_next):
- request._body_size_limit = 1_073_741_824 # 1GB
- response = await call_next(request)
- return response
-
-
-class BufferResponse(BaseModel):
- success: bool
- message: str = ""
- data: Optional[Dict[str, Any]] = None
-
-
-class BufferQueue:
- def __init__(
- self,
- group_size,
- task_type="math",
- transform_group_func=None,
- is_valid_group_func=None,
- get_group_data_meta_info_func=None,
- ):
- self.data = {}
- self.temp_data = {}
- self.group_timestamps = {}
- self.group_size = group_size
- self.task_type = task_type
-
- # Set up function handlers with defaults
- self.is_valid_group_func = is_valid_group_func or default_is_valid_group
- self.get_group_data_meta_info_func = get_group_data_meta_info_func or default_get_group_data_meta_info
- self.transform_group_func = transform_group_func or (lambda group, task_type: group)
-
- def append(self, item):
- instance_id = item["instance_id"]
- current_time = time.time()
-
- # Update timestamp for this group
- self.group_timestamps[instance_id] = current_time
-
- if instance_id not in self.temp_data:
- self.temp_data[instance_id] = [copy.deepcopy(item)]
- else:
- self.temp_data[instance_id].append(copy.deepcopy(item))
-
- if instance_id not in self.data:
- self.data[instance_id] = [item]
- else:
- self.data[instance_id].append(item)
-
- def _get_valid_groups_with_timeout(self, del_data=False):
- """Get valid groups including timeout-based groups"""
- valid_groups = {}
- timed_out_groups = {}
- finished_groups = []
-
- for instance_id, group_data in self.data.items():
- if self.is_valid_group_func((instance_id, group_data), self.group_size, self.task_type):
- valid_groups[instance_id] = group_data
-
- # Remove finished groups and timed out groups with insufficient data
- if del_data:
- for instance_id in finished_groups:
- self.data.pop(instance_id, None)
- self.group_timestamps.pop(instance_id, None)
- print(f"Removed finished group {instance_id}")
-
- # Combine normal valid groups and timeout groups
- all_valid_groups = {**valid_groups, **timed_out_groups}
-
- return all_valid_groups, finished_groups
-
- def get(self):
- output = {"data": [], "meta_info": {}}
-
- # Get meta information about temp data before processing
- meta_info = self.get_group_data_meta_info_func(self.temp_data)
- output["meta_info"] = meta_info
-
- valid_groups, finished_groups = self._get_valid_groups_with_timeout(del_data=True)
- output["meta_info"]["finished_groups"] = finished_groups
-
- print(f"meta info: {json.dumps(meta_info, indent=2)}")
-
- valid_groups = list(valid_groups.items())
-
- for instance_id, group in valid_groups:
- # First filter individual items
- transformed_group = self.transform_group_func((instance_id, group), self.task_type)
- output["data"].extend(transformed_group[1])
-
- if instance_id in self.data:
- self.data.pop(instance_id)
-
- return output
-
- def __len__(self):
- valid_groups, _ = self._get_valid_groups_with_timeout()
- num = sum([len(v) for v in valid_groups.values()])
- num_of_all_groups = sum([len(v) for v in self.data.values()])
- print(f"valid_groups: {len(valid_groups)}, num: {num}, num_of_all_groups: {num_of_all_groups}")
- return num
-
-
-class RolloutBuffer:
- def __init__(
- self,
- group_size=16,
- task_type="math",
- transform_group_func=None,
- is_valid_group_func=None,
- get_group_data_meta_info_func=None,
- ):
- self.buffer = BufferQueue(
- group_size=group_size,
- task_type=task_type,
- transform_group_func=transform_group_func,
- is_valid_group_func=is_valid_group_func,
- get_group_data_meta_info_func=get_group_data_meta_info_func,
- )
- self.lock = threading.RLock()
- self.not_empty = threading.Condition(self.lock)
- self.total_written = 0
- self.total_read = 0
- self.task_type = task_type
-
- def write(self, data):
- with self.lock:
- self.buffer.append(data)
- self.total_written += 1
- self.not_empty.notify_all()
- return data
-
- def read(self):
- with self.not_empty:
- if len(self.buffer) == 0:
- return {"data": [], "meta_info": {}}
-
- # Don't clear temp_data for regular read operations
- result = self.buffer.get()
- self.total_read += len(result["data"])
- return result
-
-
-buffer = RolloutBuffer()
-
-
-@app.post("/buffer/write", response_model=BufferResponse)
-async def write_to_buffer(request: Request):
- try:
- data = await request.json()
- item = buffer.write(data)
- return BufferResponse(
- success=True,
- message="Data has been successfully written to buffer",
- data={"data": [item], "meta_info": "write to buffer"},
- )
- except Exception as e:
- print(f"Write failed: {str(e)}")
- import traceback
-
- traceback.print_exc()
- raise HTTPException(status_code=500, detail=f"Write failed: {str(e)}")
-
-
-@app.post("/get_rollout_data", response_model=BufferResponse)
-async def get_rollout_data(request: Request):
- items = buffer.read()
-
- if not items["data"]:
- return BufferResponse(
- success=False,
- message="No data available to read",
- data={"data": [], "meta_info": items["meta_info"]},
- )
-
- print(f"return {len(items['data'])} items and save them to local")
- buffer.buffer.temp_data = {}
-
- return BufferResponse(
- success=True,
- message=f"Successfully read {len(items['data'])} items",
- data=items,
- )
-
-
-def run_rollout(data: dict):
- global buffer
- # Auto-discover generators
- generator_map = discover_generators()
-
- task_type = data["task_type"]
- if task_type not in generator_map:
- print(f"Error: No generator found for task_type '{task_type}'")
- print(f"Available generators: {list(generator_map.keys())}")
- return
-
- generator_info = generator_map[task_type]
- print(f"Using generator: {generator_info['file_path']} for task_type: {task_type}")
-
- buffer = RolloutBuffer(
- group_size=int(data["num_repeat_per_sample"]),
- task_type=task_type,
- transform_group_func=generator_info.get("transform_group", None),
- is_valid_group_func=generator_info.get("is_valid_group"),
- get_group_data_meta_info_func=generator_info.get("get_group_data_meta_info"),
- )
-
- # Call the run_rollout function from the appropriate generator module
- generator_info["run_rollout"](data)
- print(f"Rollout completed successfully for task_type: {task_type}")
-
-
-@app.post("/start_rollout")
-async def start_rollout(request: Request, background: BackgroundTasks):
- payload = await request.json()
- background.add_task(run_rollout, payload)
- return {"message": "Rollout started"}
-
-
-if __name__ == "__main__":
- uvicorn.run(
- app,
- host="0.0.0.0",
- port=8889,
- limit_concurrency=1000, # Connection concurrency limit
- # limit_max_requests=1000000, # Maximum request limit
- timeout_keep_alive=5, # Keep-alive timeout,
- )
diff --git a/slime_plugins/rollout_buffer/generator/__init__.py b/slime_plugins/rollout_buffer/generator/__init__.py
deleted file mode 100644
index 87090b1..0000000
--- a/slime_plugins/rollout_buffer/generator/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .base_generator import BaseGenerator, query_single_turn
-
-__all__ = [
- "BaseGenerator",
- "query_single_turn",
-]
diff --git a/slime_plugins/rollout_buffer/generator/base_generator.py b/slime_plugins/rollout_buffer/generator/base_generator.py
deleted file mode 100644
index 0b4c5bc..0000000
--- a/slime_plugins/rollout_buffer/generator/base_generator.py
+++ /dev/null
@@ -1,351 +0,0 @@
-import copy
-import json
-import random
-import time
-import uuid
-from functools import partial
-from multiprocessing import Process, Queue
-from time import sleep
-from typing import List, Optional
-
-import requests
-from openai import OpenAI
-from tqdm import tqdm
-from slime.rollout.rm_hub import get_deepscaler_rule_based_reward
-
-TASK_TYPE = "math"
-
-SAMPLING_PARAMS = {
- "top_p": 1,
-}
-
-
-def get_rule_based_math_reward(item):
- messages = item["messages"]
- label = item["label"]
- assert messages[-1]["role"] == "assistant", "last message must be assistant, but got {}".format(
- messages[-1]["role"]
- )
-
- response = messages[-1]["content"]
- if response is None or len(response) == 0:
- return 0
-
- reward = get_deepscaler_rule_based_reward(response, label)
- return reward
-
-
-def query_single_turn(client, messages, sampling_params, tools=None):
- base_payload = {
- "messages": messages,
- **sampling_params,
- "model": "custom",
- "stream": False,
- "seed": random.randint(1, 10000000),
- "tools": tools,
- }
-
-    text = None
-    accumulated_tokens = 0
-    finish_reason = None  # guard in case every attempt fails before any response arrives
-
- for attempt in range(6):
- try:
- # Create a fresh payload for each attempt
- current_payload = copy.deepcopy(base_payload)
-
- if text is not None:
- # Update messages with current progress
- current_messages = copy.deepcopy(messages)
- current_messages.append({"role": "assistant", "content": text})
- current_payload["messages"] = current_messages
-
- # Adjust max_tokens based on accumulated tokens
- if "max_tokens" in sampling_params:
- current_payload["max_tokens"] = max(0, sampling_params["max_tokens"] - accumulated_tokens)
-
- # Add continue flag for partial rollouts
- current_payload["extra_body"] = {"continue_final_message": True}
- if current_payload["max_tokens"] == 0:
- break
- response = client.chat.completions.create(**current_payload)
-
- if len(response.choices) > 0:
- finish_reason = response.choices[0].finish_reason
- if finish_reason == "abort":
- print(
- f"query failed, reason: {response.choices[0].finish_reason}, currently generated: {response.usage.completion_tokens}"
- )
-
- accumulated_tokens += response.usage.completion_tokens
-
- if text is None:
- text = response.choices[0].message.content
- else:
- text += response.choices[0].message.content
-
- sleep(10)
- continue
- if text is None:
- text = response.choices[0].message.content
- elif response.choices[0].message.content is not None:
- text += response.choices[0].message.content
- break
- else:
- print(f"Error in query, status code: {response.status_code}")
- continue
- except Exception as e:
- print(f"query failed in single turn, error: {e}")
- continue
-
- # Update final messages
- if len(messages) > 0 and messages[-1]["role"] == "assistant":
- messages = messages[:-1]
- messages.append({"role": "assistant", "content": text})
-
- return messages, finish_reason
-
-
-def worker_process(task_queue, done_queue, rollout_func, reward_func, client, sampling_params):
-
- for line in iter(task_queue.get, "STOP"):
- if isinstance(line, str):
- item = json.loads(line)
- else:
- item = line
-
- messages, finish_reason = rollout_func(client, item["prompt"], sampling_params)
-
- item["uid"] = str(uuid.uuid4())
- item["messages"] = messages
- reward = reward_func(item)
- item["rollout_index"] = 1
- item["reward"] = reward
- item["extra_info"] = {}
- item.update(sampling_params)
- item["timestamp"] = str(time.time())
- item["round_number"] = len([_ for _ in item["messages"] if _["role"] == "assistant"])
- item["finish_reason"] = finish_reason
-
- output_item = {
- "uid": item.pop("uid"),
- "messages": messages,
- "reward": reward,
- "instance_id": item.pop("instance_id"),
- "extra_info": item,
- }
-
- done_queue.put(output_item)
-
- done_queue.put("COMPLETE")
-
-
-class BaseGenerator:
- def __init__(
- self,
- remote_engine_url,
- remote_buffer_url,
- num_repeat_per_sample=1,
- queue_size=1000000,
- num_process=10,
- task_type="math",
- max_tokens=4096,
- num_repeats=10,
- skip_instance_ids: Optional[List[str]] = None,
- ):
- self.queue_size = queue_size
- self.num_process = num_process
- self.remote_engine_url = remote_engine_url
- self.remote_buffer_url = remote_buffer_url
- self.num_repeat_per_sample = num_repeat_per_sample
- self.task_type = task_type
- self.max_tokens = max_tokens
- self.num_repeats = num_repeats
- # Ensure skip_instance_ids is a mutable list (copy to avoid modifying original)
- self.skip_instance_ids = list(skip_instance_ids) if skip_instance_ids is not None else None
-
- if self.skip_instance_ids is not None:
- print(f"BaseGenerator initialized with {len(self.skip_instance_ids)} instance_ids to skip")
- self.skip_instance_ids = self.skip_instance_ids * self.num_repeat_per_sample
-
- if "/v1" in remote_engine_url:
- self.client = OpenAI(api_key="test", base_url=remote_engine_url)
- else:
- remote_engine_url = remote_engine_url.strip("/") + "/v1"
- self.client = OpenAI(api_key="test", base_url=remote_engine_url)
-
- def send_data_to_buffer(self, data):
- remote_buffer_url = self.remote_buffer_url.rstrip("/") + "/buffer/write"
-
- for _ in range(2):
- try:
- response = requests.post(remote_buffer_url, json=data)
- if response.status_code == 200:
- break
- else:
- print(f"send data to buffer failed, status code: {response.status_code}")
- continue
- except Exception as e:
- print(f"send data to buffer failed, error: {e}")
- continue
-
- def run(self, input_file, rollout_func, reward_func):
- task_queue, done_queue = Queue(maxsize=self.queue_size), Queue(maxsize=self.queue_size)
-
- def read_data_into_queue():
- cnt = 0
- items = []
- skipped_count = 0
- with open(input_file, "r") as f:
- for i, line in enumerate(f):
- item = json.loads(line)
- if "instance_id" not in item:
- item["instance_id"] = i
- items.append(item)
- random.shuffle(items)
-
- for _ in range(self.num_repeats):
-
- for item in items:
- for _ in range(self.num_repeat_per_sample):
- item_repeat = copy.deepcopy(item)
-
- if "uid" not in item_repeat:
- item_repeat["uid"] = str(uuid.uuid4())
-
- # Check if instance_id should be skipped
- if self.skip_instance_ids is not None and item_repeat["instance_id"] in self.skip_instance_ids:
- print(f"Skipping instance_id: {item_repeat['instance_id']}")
- # Remove from skip list to handle potential duplicates in multiple epochs
- self.skip_instance_ids.remove(item_repeat["instance_id"])
- skipped_count += 1
- continue
-
- task_queue.put(item_repeat)
- cnt += 1
-            # Pause between repeat passes before enqueuing the next round
-            time.sleep(300)
-
- if skipped_count > 0:
- remaining_skip_count = len(self.skip_instance_ids) if self.skip_instance_ids is not None else 0
- print(
- f"Rollout summary: skipped {skipped_count} instance_ids, {remaining_skip_count} still in skip list"
- )
-
- for _ in range(self.num_process):
- task_queue.put("STOP")
-
- processes = []
- SAMPLING_PARAMS["max_tokens"] = self.max_tokens
-
- for _ in range(self.num_process):
- process = Process(
- target=partial(worker_process, client=self.client, sampling_params=SAMPLING_PARAMS),
- args=(task_queue, done_queue, rollout_func, reward_func),
- )
- process.start()
- processes.append(process)
-
- process = Process(target=read_data_into_queue)
- process.start()
-
- progress_bar = tqdm()
- num_finished = 0
- while num_finished < self.num_process:
- item = done_queue.get()
- if item == "COMPLETE":
- num_finished += 1
- else:
- assert "reward" in item, f"reward not in item: {item}"
- assert "instance_id" in item, f"instance_id not in item: {item}"
- self.send_data_to_buffer(item)
- progress_bar.update(1)
-
- progress_bar.close()
-
- return "finished"
-
- def entry(self, input_file, rollout_func, reward_func, num_epoch=1):
- for _ in range(num_epoch):
- status = self.run(input_file, rollout_func, reward_func)
-
-
-def run_rollout(data: dict):
-
- print(f"Starting math rollout with data: {data}")
-
- rollout_func = query_single_turn
- reward_func = get_rule_based_math_reward
-
- print(f"Waiting for 10 seconds for buffer server to start")
- time.sleep(10)
- global SAMPLING_PARAMS
- for k, v in data["sampling_params"].items():
- SAMPLING_PARAMS[k] = v
- print(f"Set {k} to {v}", type(v))
-
- generator = BaseGenerator(
- data["remote_engine_url"],
- data["remote_buffer_url"],
- num_repeat_per_sample=int(data["num_repeat_per_sample"]),
- queue_size=1000000,
- max_tokens=int(data["sampling_params"]["max_tokens"]),
- num_process=int(data.get("num_process", 100)),
- task_type=data["task_type"],
- skip_instance_ids=data.get("skip_instance_ids", None),
- )
-
- generator.entry(data["input_file"], rollout_func, reward_func, int(data.get("num_epoch", 1)))
-
-
-def normalize_group_data(group, epsilon=1e-8, algo="grpo"):
- print(f"Using math-specific normalization for group {group[0]}")
-
- assert algo == "grpo", "Only 'grpo' is supported for now."
-
- instance_id = group[0]
- data = group[1]
- rewards = [item["reward"] for item in data]
-
- valid_rewards = [r for r in rewards if 1 >= r >= 0]
-
-    # Guard against empty or all-zero reward groups, where mean/std are undefined
-    if not valid_rewards or set(valid_rewards) == {0}:
- normalized_rewards = rewards
- else:
- mean_reward = sum(valid_rewards) / len(valid_rewards)
- std_reward = (sum((r - mean_reward) ** 2 for r in valid_rewards) / len(valid_rewards)) ** 0.5
-
- if std_reward < epsilon:
- print(f"[Math Info] Zero variance in group {instance_id}, setting all to 0.")
- normalized_rewards = [0.0 if 1 >= r >= 0 else r for r in rewards]
- else:
- normalized_rewards = [(r - mean_reward) / (std_reward + epsilon) if 1 >= r >= 0 else r for r in rewards]
-
- for i, item in enumerate(data):
- item["reward"] = normalized_rewards[i]
- item["raw_reward"] = rewards[i]
-
- return (instance_id, data)
-
-
-def is_valid_group(group, min_valid_group_size, task_type="math"):
- # Handle both tuple and list inputs
- if isinstance(group, tuple):
- instance_id, items = group
- else:
- items = group
-
- # Count valid items (non-empty responses)
- valid_indices = []
- for i, item in enumerate(items):
- if item["messages"][-1]["content"].strip():
- valid_indices.append(i)
-
- group_size = len(items)
- valid_count = len(valid_indices)
-
- # A group is finished if it has reached the target size
- is_finished = group_size >= min_valid_group_size
-
- is_valid = is_finished and valid_count >= min_valid_group_size
-
- return is_valid
diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.py b/slime_plugins/rollout_buffer/rollout_buffer_example.py
deleted file mode 100644
index a25e6ac..0000000
--- a/slime_plugins/rollout_buffer/rollout_buffer_example.py
+++ /dev/null
@@ -1,301 +0,0 @@
-import asyncio
-import time
-from typing import Any, Dict, List
-
-import aiohttp
-import requests
-from transformers import AutoTokenizer
-
-import wandb
-from slime.ray.buffer import Buffer
-from slime.utils.async_utils import run
-from slime.utils.mask_utils import MultiTurnLossMaskGenerator
-from slime.utils.types import Sample
-
-__all__ = ["generate_rollout"]
-
-
-# Global state shared across rollout calls
-TOKENIZER = None
-START_ROLLOUT = True
-
-
-def select_rollout_data(args, results, need_length):
- """
- Select the most recent groups when there are too many samples.
- Groups all samples by instance_id, sorts groups by timestamp.
-
- Args:
- args: Arguments containing configuration
- results: List of rollout data items with timestamps
-
- Returns:
- Selected samples from the newest groups based on timestamp cutoff
- """
- if not results:
- return results
-
- # Group samples by instance_id
- groups = {}
- for item in results:
- assert "instance_id" in item, "instance_id must be in item"
- instance_id = item["instance_id"]
- if instance_id not in groups:
- groups[instance_id] = []
- groups[instance_id].append(item)
-
- print(f"📊 Total groups: {len(groups)}, total samples: {len(results)}")
-
-    # Selection is only meaningful when there are more samples than needed
- assert need_length < len(results), "need_length must be smaller than results length"
-
- # Get timestamp for each group (use the latest timestamp in the group)
- def get_group_timestamp(group_items):
- timestamps = []
- for item in group_items:
- if "timestamp" in item:
- timestamps.append(float(item["timestamp"]))
- elif "extra_info" in item and "timestamp" in item["extra_info"]:
- timestamps.append(float(item["extra_info"]["timestamp"]))
- return max(timestamps) if timestamps else 0
-
- # Create list of (group_id, timestamp, samples) and sort by timestamp
- group_data = []
- for group_id, group_items in groups.items():
- group_timestamp = get_group_timestamp(group_items)
- group_data.append((group_id, group_timestamp, group_items))
-
- # Sort groups by timestamp (newest first)
- group_data.sort(key=lambda x: x[1], reverse=True)
-
- selected_groups = group_data[:need_length]
-
- # Flatten selected groups back to sample list
- selected_results = []
- for group_id, timestamp, group_items in selected_groups:
- selected_results.extend(group_items)
-
- # Statistics for monitoring
- if selected_groups:
- newest_ts = selected_groups[0][1]
- oldest_ts = selected_groups[-1][1]
- print(f"📈 Selected {len(selected_groups)} groups with {len(selected_results)} samples")
- print(f"📈 Group timestamp range: {oldest_ts:.2f} to {newest_ts:.2f}")
- print(f"📈 Time span: {newest_ts - oldest_ts:.2f} seconds")
-
- return selected_results
-
-
-def log_raw_info(args, all_meta_info, rollout_id):
- final_meta_info = {}
- if all_meta_info:
- final_meta_info = {
- "total_samples": sum(meta["total_samples"] for meta in all_meta_info if "total_samples" in meta)
- }
-
- total_samples = final_meta_info["total_samples"]
- if total_samples > 0:
- weighted_reward_sum = sum(
- meta["avg_reward"] * meta["total_samples"]
- for meta in all_meta_info
- if "avg_reward" in meta and "total_samples" in meta
- )
-
- final_meta_info.update(
- {
- "avg_reward": weighted_reward_sum / total_samples,
- }
- )
- if hasattr(args, "use_wandb") and args.use_wandb:
- log_dict = {
- f"rollout/no_filter/total_samples": final_meta_info["total_samples"],
- f"rollout/no_filter/avg_reward": final_meta_info["avg_reward"],
- }
- try:
- if args.use_wandb:
- log_dict["rollout/step"] = (
- rollout_id
- if not args.wandb_always_use_train_step
- else rollout_id
- * args.rollout_batch_size
- * args.n_samples_per_prompt
- // args.global_batch_size
- )
- wandb.log(log_dict)
- print(f"no filter rollout log {rollout_id}: {log_dict}")
- except Exception as e:
- print(f"Failed to log to wandb: {e}")
- print(f"no filter rollout log {rollout_id}: {final_meta_info}")
- else:
- print(f"no filter rollout log {rollout_id}: {final_meta_info}")
-
-
-async def get_rollout_data(api_base_url: str) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
- start_time = time.time()
- async with aiohttp.ClientSession() as session:
- while True:
- async with session.post(
- f"{api_base_url}/get_rollout_data", json={}, timeout=aiohttp.ClientTimeout(total=120)
- ) as response:
- response.raise_for_status()
- resp_json = await response.json()
- if resp_json["success"]:
- break
- await asyncio.sleep(3)
- if time.time() - start_time > 30:
- print("rollout data is not ready, have been waiting for 30 seconds")
- # Reset start_time to continue waiting or handle timeout differently
- start_time = time.time() # Or raise an exception, or return empty list
-
- data = resp_json["data"]
- meta_info = {}
- if isinstance(data, list):
- if "data" in data[0]:
- data = [item["data"] for item in data]
- elif isinstance(data, dict):
- if "data" in data:
- meta_info = data["meta_info"]
- data = data["data"]
- print(f"Meta info: {meta_info}")
- required_keys = {"uid", "instance_id", "messages", "reward", "extra_info"}
- for item in data:
- if not required_keys.issubset(item.keys()):
- raise ValueError(f"Missing required keys in response item: {item}")
-
- return data, meta_info
-
-
-def start_rollout(api_base_url: str, args, metadata):
- url = f"{api_base_url}/start_rollout"
- print(f"metadata: {metadata}")
- finished_groups_instance_id_list = [item for sublist in metadata.values() for item in sublist]
- payload = {
- "num_process": str(getattr(args, "rollout_num_process", 100)),
- "num_epoch": str(args.num_epoch or 3),
- "remote_engine_url": f"http://{args.sglang_router_ip}:{args.sglang_router_port}",
- "remote_buffer_url": args.rollout_buffer_url,
- "task_type": args.rollout_task_type,
- "input_file": args.prompt_data,
- "num_repeat_per_sample": str(args.n_samples_per_prompt),
- "max_tokens": str(args.rollout_max_response_len),
- "sampling_params": {
- "max_tokens": args.rollout_max_response_len,
- "temperature": args.rollout_temperature,
- "top_p": args.rollout_top_p,
- },
- "tokenizer_path": args.hf_checkpoint,
- "skip_instance_ids": finished_groups_instance_id_list,
- }
- print("start rollout with payload: ", payload)
-
- while True:
- try:
- resp = requests.post(url, json=payload, timeout=10)
- resp.raise_for_status()
- data = resp.json()
- print(f"[start_rollout] Success: {data}")
- return data
- except Exception as e:
- print(f"[start_rollout] Failed to send rollout config: {e}")
-
-
-async def generate_rollout_async(
- args, rollout_id: int, data_buffer: Buffer, evaluation: bool = False
-) -> Dict[str, Any]:
-
- global START_ROLLOUT
- if evaluation:
- raise NotImplementedError("Evaluation rollout is not implemented")
-
- if START_ROLLOUT:
- metadata = data_buffer.get_metadata()
- start_inform = start_rollout(args.rollout_buffer_url, args, metadata)
- print(f"start rollout with payload: {start_inform}")
- print(f"start rollout id: {rollout_id}")
- START_ROLLOUT = False
-
- data_number_to_fetch = args.rollout_batch_size * args.n_samples_per_prompt - data_buffer.get_buffer_length()
- if data_number_to_fetch <= 0:
- print(
- f"❕buffer length: {data_buffer.get_buffer_length()}, buffer has enough data, return {args.rollout_batch_size} prompts"
- )
- return data_buffer.get_samples(args.rollout_batch_size)
- assert (
- data_number_to_fetch % args.n_samples_per_prompt == 0
- ), "data_number_to_fetch must be a multiple of n_samples_per_prompt"
- print(f"INFO: buffer length: {data_buffer.get_buffer_length()}, data_number_to_fetch: {data_number_to_fetch}")
- base_url = args.rollout_buffer_url
- tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
- retry_times = 0
- results = []
- all_meta_info = []
-
- if args.fetch_trajectory_retry_times == -1:
- print(
- f"⚠️ [get_rollout_data] Fetch trajectory retry times set to -1, will retry indefinitely until sufficient data is collected"
- )
- while args.fetch_trajectory_retry_times == -1 or retry_times < args.fetch_trajectory_retry_times:
- try:
- while len(results) < data_number_to_fetch:
- time.sleep(5)
- data, meta_info = await get_rollout_data(api_base_url=base_url)
- results.extend(data)
- if meta_info:
- all_meta_info.append(meta_info)
- print(f"get rollout data with length: {len(results)}")
- break
- except Exception as err:
- print(f"[get_rollout_data] Failed to get rollout data: {err}, retry times: {retry_times}")
- retry_times += 1
-
- log_raw_info(args, all_meta_info, rollout_id)
-
- # Apply group-based data selection if there are too many samples
- results = select_rollout_data(args, results, data_number_to_fetch // args.n_samples_per_prompt)
-
- if len(all_meta_info) > 0 and "finished_groups" in all_meta_info[0]:
- finished_groups_instance_id_list = []
- for item in all_meta_info:
- finished_groups_instance_id_list.extend(item["finished_groups"])
-
- data_buffer.update_metadata({str(rollout_id): finished_groups_instance_id_list})
-
- print("finally get rollout data with length: ", len(results))
- sample_results = []
-
- for i, record in enumerate(results):
- oai_messages = record["messages"]
-
- mask_generator = MultiTurnLossMaskGenerator(tokenizer, tokenizer_type=args.loss_mask_type)
- token_ids, loss_mask = mask_generator.get_loss_mask(oai_messages)
- response_length = mask_generator.get_response_lengths([loss_mask])[0]
-
- loss_mask = loss_mask[-response_length:]
-
- sample_results.append(
- Sample(
- index=record["instance_id"],
- prompt=record["uid"],
- tokens=token_ids,
- response_length=response_length,
- reward=record["reward"],
- status=(
- Sample.Status.COMPLETED
- if "finish_reason" not in record["extra_info"] or record["extra_info"]["finish_reason"] != "length"
- else Sample.Status.TRUNCATED
- ),
- loss_mask=loss_mask,
- metadata={**record["extra_info"]},
- )
- )
- final_return_results = []
-
- data_buffer.add_samples(sample_results)
- final_return_results = data_buffer.get_samples(args.rollout_batch_size)
-
- return final_return_results
-
-
-def generate_rollout(args, rollout_id, data_buffer, evaluation=False):
- """Generate rollout for both training and evaluation."""
- return run(generate_rollout_async(args, rollout_id, data_buffer, evaluation))
diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.sh b/slime_plugins/rollout_buffer/rollout_buffer_example.sh
deleted file mode 100644
index 1c369bd..0000000
--- a/slime_plugins/rollout_buffer/rollout_buffer_example.sh
+++ /dev/null
@@ -1,133 +0,0 @@
-#!/bin/bash
-
-# clean up any processes left over from a previous run
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-
-set -ex
-
-# prevent python from buffering stdout/stderr
-export PYTHONUNBUFFERED=16
-
-# DeepSeek-R1-Distill-Qwen-7B
-MODEL_ARGS=(
- --swiglu
- --num-layers 28
- --hidden-size 3584
- --ffn-hidden-size 18944
- --num-attention-heads 28
- --group-query-attention
- --num-query-groups 4
- --max-position-embeddings 131072
- --seq-length 4096
- --use-rotary-position-embeddings
- --disable-bias-linear
- --add-qkv-bias
- --normalization "RMSNorm"
- --norm-epsilon 1e-06
- --rotary-base 10000
- --vocab-size 152064
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
- --attention-backend flash
- --moe-token-dispatcher-type alltoall
- --untie-embeddings-and-output-weights
- --attention-dropout 0.0
- --hidden-dropout 0.0
-)
-
-CKPT_ARGS=(
- --hf-checkpoint /root/DeepSeek-R1-Distill-Qwen-7B
- --ref-load /root/DeepSeek-R1-Distill-Qwen-7B_torch_dist
- --save-interval 100
- --save /root/DeepSeek-R1-Distill-Qwen-7B_slime
-)
-
-ROLLOUT_ARGS=(
-    --rollout-function-path slime_plugins.rollout_buffer.rollout_buffer_example.generate_rollout
- --rm-type deepscaler
- --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
- --input-key prompt
- --label-key label
- --num-rollout 3000
- --rollout-batch-size 128
- --rollout-max-response-len 8192
- --rollout-temperature 0.8
- --rollout-shuffle
- --n-samples-per-prompt 8
- --global-batch-size 1024
- --micro-batch-size 8
- --ref-micro-batch-size 8
- --use-dynamic-batch-size
- --max-tokens-per-gpu 9216
- --balance-data
-)
-
-DISTRIBUTED_ARGS=(
- --tensor-model-parallel-size 2
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --sequence-parallel
-)
-
-PERF_ARGS=(
- --recompute-granularity full
- --recompute-method uniform
- --recompute-num-layers 1
-)
-
-GRPO_ARGS=(
- --advantage-estimator grpo
- --use-kl-loss
- --kl-loss-coef 0.001
- --kl-loss-type low_var_kl
- --entropy-coef 0.00
-)
-
-OPTIMIZER_ARGS=(
- --lr 1e-6
- --lr-decay-style constant
- --weight-decay 0.1
- --adam-beta1 0.9
- --adam-beta2 0.98
-)
-
-WANDB_ARGS=(
- # --use-wandb
-)
-
-# launch the master node of ray in container
-export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
-
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM/",
- "CUDA_DEVICE_MAX_CONNECTIONS": "1",
- "NCCL_CUMEM_ENABLE": "0"
- }
- }' \
- -- python3 train_async.py \
- --actor-num-nodes 1 \
- --actor-num-gpus-per-node 4 \
- --rollout-num-gpus 4 \
- --rollout-num-gpus-per-engine 1 \
- ${MODEL_ARGS[@]} \
- ${CKPT_ARGS[@]} \
- ${ROLLOUT_ARGS[@]} \
- ${OPTIMIZER_ARGS[@]} \
- ${GRPO_ARGS[@]} \
- ${DISTRIBUTED_ARGS[@]} \
- ${WANDB_ARGS[@]} \
- ${PERF_ARGS[@]} \
- --rollout-buffer-url http://${MASTER_ADDR}:8889 \
- --keep-old-actor \
- --disable-rewards-normalization \
- --loss-mask-type distill_qwen \
- --log-passrate
diff --git a/tests/test-qwen2.5-0.5B-async.sh b/tests/test-qwen2.5-0.5B-async.sh
deleted file mode 100644
index 4f7fb4f..0000000
--- a/tests/test-qwen2.5-0.5B-async.sh
+++ /dev/null
@@ -1,135 +0,0 @@
-#!/bin/bash
-
-# clean up any processes left over from a previous run
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-
-set -ex
-
-
-huggingface-cli download --repo-type dataset zhuzilin/gsm8k --local-dir gsm8k
-
-
-# will prevent ray from buffering stdout/stderr
-export PYTHONUNBUFFERED=16
-
-SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
-source "${SCRIPT_DIR}/../scripts/models/qwen2.5-0.5B.sh"
-
-CKPT_ARGS=(
- --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/
- --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/
-)
-
-ROLLOUT_ARGS=(
- --prompt-data gsm8k/train.parquet
- --input-key messages
- --label-key label
- --apply-chat-template
- --rollout-shuffle
- --rm-type math
- --num-rollout 3000
- --rollout-batch-size 32
- --n-samples-per-prompt 8
- --rollout-max-response-len 1024
- --rollout-temperature 0.8
- --rollout-num-gpus 2
-
- --over-sampling-batch-size 64
- --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
-
- --global-batch-size 256
-)
-
-EVAL_ARGS=(
- --eval-interval 20
- --eval-prompt-data gsm8k gsm8k/test.parquet
- --n-samples-per-eval-prompt 1
- --eval-max-response-len 1024
- --eval-top-k 1
-)
-
-PERF_ARGS=(
- --tensor-model-parallel-size 1
- --sequence-parallel
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --expert-model-parallel-size 1
- --expert-tensor-parallel-size 1
-
- # --micro-batch-size 1
- --use-dynamic-batch-size
- --max-tokens-per-gpu 9216
-)
-
-GRPO_ARGS=(
- --advantage-estimator grpo
- --use-kl-loss
- --kl-loss-coef 0.00
- --kl-loss-type low_var_kl
- --entropy-coef 0.00
- --eps-clip 0.2
- --eps-clip-high 0.28
-)
-
-OPTIMIZER_ARGS=(
- --optimizer adam
- --lr 1e-6
- --lr-decay-style constant
- --weight-decay 0.1
- --adam-beta1 0.9
- --adam-beta2 0.98
-)
-
-WANDB_ARGS=(
- --use-wandb
- --wandb-project slime-test
- --wandb-group test-qwen2.5-0.5B-gsm8k
-)
-
-SGLANG_ARGS=(
- --rollout-num-gpus-per-engine 1
- --sglang-mem-fraction-static 0.7
-)
-
-MISC_ARGS=(
- # default dropout in megatron is 0.1
- --attention-dropout 0.0
- --hidden-dropout 0.0
- # should be good for model performance
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
-    # comment this out when using a model with MLA
- --attention-backend flash
-)
-
-# launch the master node of ray in container
-export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats
-
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM",
- "CUDA_DEVICE_MAX_CONNECTIONS": "1"
- }
- }' \
- -- python3 train_async.py \
- --actor-num-nodes 1 \
- --actor-num-gpus-per-node 2 \
- ${MODEL_ARGS[@]} \
- ${CKPT_ARGS[@]} \
- ${ROLLOUT_ARGS[@]} \
- ${OPTIMIZER_ARGS[@]} \
- ${GRPO_ARGS[@]} \
- ${DISTRIBUTED_ARGS[@]} \
- ${WANDB_ARGS[@]} \
- ${PERF_ARGS[@]} \
- ${EVAL_ARGS[@]} \
- ${SGLANG_ARGS[@]} \
- ${MISC_ARGS[@]}
diff --git a/tests/test-qwen2.5-0.5B.sh b/tests/test-qwen2.5-0.5B.sh
deleted file mode 100644
index 954ce96..0000000
--- a/tests/test-qwen2.5-0.5B.sh
+++ /dev/null
@@ -1,135 +0,0 @@
-#!/bin/bash
-
-# clean up any processes left over from a previous run
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-
-set -ex
-
-
-huggingface-cli download --repo-type dataset zhuzilin/gsm8k --local-dir gsm8k
-
-
-# will prevent ray from buffering stdout/stderr
-export PYTHONUNBUFFERED=16
-
-SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
-source "${SCRIPT_DIR}/../scripts/models/qwen2.5-0.5B.sh"
-
-CKPT_ARGS=(
- --hf-checkpoint /root/Qwen2.5-0.5B-Instruct/
- --ref-load /root/Qwen2.5-0.5B-Instruct_torch_dist/
-)
-
-ROLLOUT_ARGS=(
- --prompt-data gsm8k/train.parquet
- --input-key messages
- --label-key label
- --apply-chat-template
- --rollout-shuffle
- --rm-type math
- --num-rollout 3000
- --rollout-batch-size 32
- --n-samples-per-prompt 8
- --rollout-max-response-len 1024
- --rollout-temperature 0.8
-
- --over-sampling-batch-size 64
- --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
-
- --global-batch-size 256
-)
-
-EVAL_ARGS=(
- --eval-interval 20
- --eval-prompt-data gsm8k gsm8k/test.parquet
- --n-samples-per-eval-prompt 1
- --eval-max-response-len 1024
- --eval-top-k 1
-)
-
-PERF_ARGS=(
- --tensor-model-parallel-size 1
- --sequence-parallel
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --expert-model-parallel-size 1
- --expert-tensor-parallel-size 1
-
- # --micro-batch-size 1
- --use-dynamic-batch-size
- --max-tokens-per-gpu 9216
-)
-
-GRPO_ARGS=(
- --advantage-estimator grpo
- --use-kl-loss
- --kl-loss-coef 0.00
- --kl-loss-type low_var_kl
- --entropy-coef 0.00
- --eps-clip 0.2
- --eps-clip-high 0.28
-)
-
-OPTIMIZER_ARGS=(
- --optimizer adam
- --lr 1e-6
- --lr-decay-style constant
- --weight-decay 0.1
- --adam-beta1 0.9
- --adam-beta2 0.98
-)
-
-WANDB_ARGS=(
- --use-wandb
- --wandb-project slime-test
- --wandb-group test-qwen2.5-0.5B-gsm8k
-)
-
-SGLANG_ARGS=(
- --rollout-num-gpus-per-engine 1
- --sglang-mem-fraction-static 0.7
-)
-
-MISC_ARGS=(
- # default dropout in megatron is 0.1
- --attention-dropout 0.0
- --hidden-dropout 0.0
- # should be good for model performance
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
-    # comment this out when using a model with MLA
- --attention-backend flash
-)
-
-# launch the master node of ray in container
-export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats
-
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM",
- "CUDA_DEVICE_MAX_CONNECTIONS": "1"
- }
- }' \
- -- python3 train.py \
- --actor-num-nodes 1 \
- --actor-num-gpus-per-node 4 \
- --colocate \
- ${MODEL_ARGS[@]} \
- ${CKPT_ARGS[@]} \
- ${ROLLOUT_ARGS[@]} \
- ${OPTIMIZER_ARGS[@]} \
- ${GRPO_ARGS[@]} \
- ${DISTRIBUTED_ARGS[@]} \
- ${WANDB_ARGS[@]} \
- ${PERF_ARGS[@]} \
- ${EVAL_ARGS[@]} \
- ${SGLANG_ARGS[@]} \
- ${MISC_ARGS[@]}
diff --git a/tests/test_qwen3_0.6B.sh b/tests/test_qwen3_0.6B.sh
deleted file mode 100644
index a283ab3..0000000
--- a/tests/test_qwen3_0.6B.sh
+++ /dev/null
@@ -1,133 +0,0 @@
-#!/bin/bash
-
-# clean up any processes left over from a previous run
-pkill -9 sglang
-sleep 3
-ray stop --force
-pkill -9 ray
-pkill -9 python
-sleep 3
-pkill -9 ray
-pkill -9 python
-
-set -ex
-
-# will prevent ray from buffering stdout/stderr
-export PYTHONUNBUFFERED=16
-
-SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
-source "${SCRIPT_DIR}/../scripts/models/qwen3-0.6B.sh"
-
-CKPT_ARGS=(
- --hf-checkpoint /root/Qwen3-0.6B
- --ref-load /root/Qwen3-0.6B_torch_dist
-)
-
-ROLLOUT_ARGS=(
- --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
- --input-key prompt
- --label-key label
- --apply-chat-template
- --rollout-shuffle
- --rm-type deepscaler
- --num-rollout 3000
- --rollout-batch-size 32
- --n-samples-per-prompt 8
- --rollout-max-response-len 8192
- --rollout-temperature 0.8
-
- --over-sampling-batch-size 64
- --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std
- #--partial-rollout
-
- --global-batch-size 256
- #--balance-data
-)
-
-EVAL_ARGS=(
- --eval-interval 20
- --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
- --n-samples-per-eval-prompt 1
- --eval-max-response-len 16384
- --eval-temperature 0
-)
-
-PERF_ARGS=(
- --tensor-model-parallel-size 1
- --sequence-parallel
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --expert-model-parallel-size 1
- --expert-tensor-parallel-size 1
-
- # --micro-batch-size 1
- --use-dynamic-batch-size
- --max-tokens-per-gpu 9216
-)
-
-GRPO_ARGS=(
- --advantage-estimator grpo
- --use-kl-loss
- --kl-loss-coef 0.00
- --kl-loss-type low_var_kl
- --entropy-coef 0.00
- --eps-clip 0.2
- --eps-clip-high 0.28
-)
-
-OPTIMIZER_ARGS=(
- --optimizer adam
- --lr 1e-6
- --lr-decay-style constant
- --weight-decay 0.1
- --adam-beta1 0.9
- --adam-beta2 0.98
-)
-
-WANDB_ARGS=(
- #--use-wandb
- --wandb-project slime-test
- --wandb-group test-qwen-3-0.6B
-)
-
-SGLANG_ARGS=(
- --rollout-num-gpus-per-engine 1
- --sglang-mem-fraction-static 0.7
-)
-
-MISC_ARGS=(
- # default dropout in megatron is 0.1
- --attention-dropout 0.0
- --hidden-dropout 0.0
- # should be good for model performance
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
-    # comment this out when using a model with MLA
- --attention-backend flash
-)
-
-# launch the master node of ray in container
-export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
-ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
-
-ray job submit --address="http://127.0.0.1:8265" \
- --runtime-env-json='{
- "env_vars": {
- "PYTHONPATH": "/root/Megatron-LM",
- "CUDA_DEVICE_MAX_CONNECTIONS": "1"
- }
- }' \
- -- python3 train.py \
- --actor-num-nodes 1 \
- --actor-num-gpus-per-node 1 \
- --colocate \
- ${MODEL_ARGS[@]} \
- ${CKPT_ARGS[@]} \
- ${ROLLOUT_ARGS[@]} \
- ${OPTIMIZER_ARGS[@]} \
- ${GRPO_ARGS[@]} \
- ${DISTRIBUTED_ARGS[@]} \
- ${WANDB_ARGS[@]} \
- ${PERF_ARGS[@]} \
- ${EVAL_ARGS[@]} \
- ${SGLANG_ARGS[@]} \
- ${MISC_ARGS[@]}
\ No newline at end of file
diff --git a/train_async.py b/train_async.py
deleted file mode 100644
index f0d96bd..0000000
--- a/train_async.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import ray
-
-from slime.ray.placement_group import create_actor_group, create_placement_groups, create_rollout_manager
-from slime.utils.arguments import parse_args
-from slime.utils.wandb_utils import init_wandb_primary
-
-
-def train(args):
- assert not args.colocate, "Colocation is not supported for async training."
- # allocate the GPUs
- pgs = create_placement_groups(args)
- wandb_run_id = init_wandb_primary(args)
-
- actor_model = create_actor_group(args, pgs["actor"], wandb_run_id=wandb_run_id)
-
- # create the rollout manager, with sglang engines inside.
- rollout_manager = create_rollout_manager(args, pgs["rollout"], wandb_run_id=wandb_run_id)
-
- # calculate num_rollout from num_epoch
- num_rollout_per_epoch = None
- if args.num_rollout is None:
- num_rollout_per_epoch = ray.get(rollout_manager.data_buffer.get_num_rollout_per_epoch.remote())
- args.num_rollout = num_rollout_per_epoch * args.num_epoch
- assert args.num_rollout > 0
-
-    # sync the initialization (model initialization, checkpoint loading, etc.)
-    # Note: we initialize this early because Megatron checkpoint loading can have a very high peak memory usage.
- start_rollout_ids = ray.get(
- actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
- )
- assert len(set(start_rollout_ids)) == 1
- if args.start_rollout_id is None:
- args.start_rollout_id = start_rollout_ids[0]
-
- if args.rollout_global_dataset:
- ray.get(rollout_manager.data_buffer.load.remote(args.start_rollout_id - 1))
-
- # initialize the connection for weight update during training
- ray.get(actor_model.async_init_weight_update_connections(rollout_manager))
-
- # always update weight first so that sglang has the loaded weights from training.
- ray.get(actor_model.async_update_weights())
-
- # async train loop.
- rollout_data_next_future = rollout_manager.async_generate(args.start_rollout_id)
- for rollout_id in range(args.start_rollout_id, args.num_rollout):
- # Sync the last generation
- if rollout_data_next_future is not None:
- rollout_data_curr_ref = ray.get(rollout_data_next_future)
-
- # Start the next rollout early.
- if rollout_id + 1 < args.num_rollout:
- rollout_data_next_future = rollout_manager.async_generate(rollout_id + 1)
-
- ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
-
- if args.save_interval is not None and (
- (rollout_id + 1) % args.save_interval == 0
- or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
- ):
- ray.get(actor_model.async_save_model(rollout_id))
- if args.rollout_global_dataset:
- ray.get(rollout_manager.data_buffer.save.remote(rollout_id))
-
- if (rollout_id + 1) % args.update_weights_interval == 0:
- # sync generate before update weights to prevent update weight in the middle of generation
- rollout_data_curr_ref = ray.get(rollout_data_next_future)
- rollout_data_next_future = None
- ray.get(actor_model.async_update_weights())
-
- if args.eval_interval is not None and (
- (rollout_id + 1) % args.eval_interval == 0
- or (num_rollout_per_epoch is not None and (rollout_id + 1) % num_rollout_per_epoch == 0)
- ):
- eval_rollout_data_ref = ray.get(rollout_manager.async_generate(rollout_id, evaluation=True))
- ray.get(actor_model.async_eval(rollout_id, eval_rollout_data_ref))
-
-
-if __name__ == "__main__":
- args = parse_args()
- train(args)