From 02f200b1fbe8dacb3d24d6ed6b6515bc4ec9da7e Mon Sep 17 00:00:00 2001 From: ojh31 Date: Fri, 20 Jun 2025 20:15:23 -0700 Subject: [PATCH 1/6] Initial commit --- trl/extras/vllm_client.py | 4 ++++ trl/scripts/vllm_serve.py | 11 ++++++++++- trl/trainer/grpo_trainer.py | 10 +++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 71341a99542..5b12a6b6280 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -173,6 +173,7 @@ def generate( top_k: int = -1, min_p: float = 0.0, max_tokens: int = 16, + lora_request: Optional[LoRARequest] = None, guided_decoding_regex: Optional[str] = None, generation_kwargs: Optional[dict] = None, ) -> list[list[int]]: @@ -196,6 +197,8 @@ def generate( Minimum probability for sampling. max_tokens (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each prompt. + lora_request (`LoRARequest` or `None`, *optional*, defaults to `None`): + Details of LoRA adapter to apply for the generation. guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): Regular expression to guide the decoding process. generation_kwargs (`dict` or `None`, *optional*, defaults to `None`): @@ -219,6 +222,7 @@ def generate( "top_k": top_k, "min_p": min_p, "max_tokens": max_tokens, + "lora_request": lora_request, "guided_decoding_regex": guided_decoding_regex, "generation_kwargs": generation_kwargs or {}, }, diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index eb4a6e269ca..47d18288327 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -158,6 +158,8 @@ class ScriptArguments: Model name or path to load the model from. revision (`str` or `None`, *optional*, defaults to `None`): Revision to use for the model. If not specified, the default branch will be used. + enable_lora (`bool`, *optional*, defaults to `False`): + Whether to enable LoRA. tensor_parallel_size (`int`, *optional*, defaults to `1`): Number of tensor parallel workers to use. data_parallel_size (`int`, *optional*, defaults to `1`): @@ -199,7 +201,13 @@ class ScriptArguments: ) revision: Optional[str] = field( default=None, - metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."}, + metadata={ + "help": "Revision to use for the model. If not specified, the default branch will be used." 
+ }, + ) + enable_lora: bool = field( + default=False, + metadata={"help": "Whether to enable LoRA."}, ) tensor_parallel_size: int = field( default=1, @@ -290,6 +298,7 @@ def llm_worker( llm = LLM( model=script_args.model, revision=script_args.revision, + enable_lora=script_args.enable_lora, tensor_parallel_size=script_args.tensor_parallel_size, gpu_memory_utilization=script_args.gpu_memory_utilization, enforce_eager=script_args.enforce_eager, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 558b77e9b2e..654b7b9a182 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -913,7 +913,15 @@ def _move_model_to_vllm(self): else: gather_if_zero3 = nullcontext - if is_peft_model(self.model): + if ( + is_peft_model(self.model) + and self.is_fsdp_enabled + and self.vllm_mode == "server" + ): + # TODO: special handling for PEFT with FSDP and vLLM server + # Need to avoid summoning full params somehow + pass + elif is_peft_model(self.model): # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as # merging adapters in a sharded manner is not supported. # TODO: does this work with FSDP? From 1b5f73162d87f301452096883e72d4018f9fa1e7 Mon Sep 17 00:00:00 2001 From: ojh31 Date: Fri, 20 Jun 2025 20:25:01 -0700 Subject: [PATCH 2/6] Link to reference implementation --- trl/extras/vllm_client.py | 3 +++ trl/trainer/grpo_trainer.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 5b12a6b6280..89895a94b8f 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -297,6 +297,9 @@ def update_model_params(self, model: nn.Module): for name, param in model.named_parameters(): # Update each parameter individually self.update_named_param(name, param.data) + + # TODO(oskar): new interface for updating lora params + # https://github.com/AlignmentResearch/llm/blob/2aef8be95ed7e182c096a8e8135381bc4a58ee43/split/generation/vllm_backed.py#L111 def reset_prefix_cache(self): """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 654b7b9a182..d69c1c4e204 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -918,8 +918,9 @@ def _move_model_to_vllm(self): and self.is_fsdp_enabled and self.vllm_mode == "server" ): - # TODO: special handling for PEFT with FSDP and vLLM server + # TODO(oskar): special handling for PEFT with FSDP and vLLM server # Need to avoid summoning full params somehow + # https://github.com/AlignmentResearch/llm/blob/2aef8be95ed7e182c096a8e8135381bc4a58ee43/split/training/huggingface.py#L105 pass elif is_peft_model(self.model): # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as From ed9ec85c072f2c11f426a192a85caa03c5759905 Mon Sep 17 00:00:00 2001 From: ojh31 Date: Fri, 20 Jun 2025 20:26:18 -0700 Subject: [PATCH 3/6] Revert whitespace --- trl/scripts/vllm_serve.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 47d18288327..15502b3cd3b 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -201,9 +201,7 @@ class ScriptArguments: ) revision: Optional[str] = field( default=None, - metadata={ - "help": "Revision to use for the model. If not specified, the default branch will be used." - }, + metadata={"help": "Revision to use for the model. 
If not specified, the default branch will be used."}, ) enable_lora: bool = field( default=False, From 97e29cbe85592b284d74e9e22627b5befc843fe1 Mon Sep 17 00:00:00 2001 From: ojh31 Date: Thu, 26 Jun 2025 16:12:42 -0700 Subject: [PATCH 4/6] Pass lora modules through generate and only sync trainable params --- trl/extras/vllm_client.py | 4 +--- trl/scripts/vllm_serve.py | 5 ++++- trl/trainer/grpo_trainer.py | 20 +++++++------------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 89895a94b8f..1a6ea38b32f 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -33,6 +33,7 @@ if is_vllm_available(): from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup + from vllm.lora.request import LoRARequest if is_vllm_ascend_available(): from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator @@ -298,9 +299,6 @@ def update_model_params(self, model: nn.Module): # Update each parameter individually self.update_named_param(name, param.data) - # TODO(oskar): new interface for updating lora params - # https://github.com/AlignmentResearch/llm/blob/2aef8be95ed7e182c096a8e8135381bc4a58ee43/split/generation/vllm_backed.py#L111 - def reset_prefix_cache(self): """ Resets the prefix cache for the model. diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 15502b3cd3b..e22b2b51703 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -24,6 +24,7 @@ from typing import Optional import torch +from vllm.lora.request import LoRARequest from trl import TrlParser from trl.import_utils import ( @@ -439,6 +440,7 @@ class GenerateRequest(BaseModel): top_k: int = -1 min_p: float = 0.0 max_tokens: int = 16 + lora_request: Optional[LoRARequest] = None, guided_decoding_regex: Optional[str] = None generation_kwargs: dict = field(default_factory=dict) @@ -460,6 +462,7 @@ async def generate(request: GenerateRequest): - `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables top-k sampling. - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling. - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion. + - `lora_request` (`LoRARequest`, *optional*): A request for LoRA parameters. If provided, the model will use the LoRA parameters to generate completions. - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will override them. @@ -507,7 +510,7 @@ async def generate(request: GenerateRequest): # with vLLM's requirement, and we later ignore the result. 
if not prompts: prompts = [""] - kwargs = {"prompts": prompts, "sampling_params": sampling_params} + kwargs = {"prompts": prompts, "sampling_params": sampling_params, "lora_request": request.lora_request} connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) # Receive results diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d69c1c4e204..9598199051a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -873,7 +873,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, all_logps.append(logps) return torch.cat(all_logps, dim=0) - def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): + def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None, move_static_params: bool = False): """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" if visited is None: visited = set() @@ -881,7 +881,7 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited for child_name, child_module in module.named_children(): child_prefix = f"{prefix}.{child_name}" if prefix else child_name self._sync_fsdp_params_to_vllm( - child_module, prefix=child_prefix, visited=visited + child_module, prefix=child_prefix, visited=visited, move_static_params=move_static_params ) # recurse into the child if isinstance(module, FSDP): @@ -894,6 +894,9 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited if full_name in visited: continue # skip FSDP subtrees already traversed visited.add(full_name) + if not param.requires_grad and not move_static_params: + # Useful e.g. for PEFT where most parameters are static + continue if self.vllm_mode == "server" and self.accelerator.is_main_process: self.vllm_client.update_named_param(full_name, param.data) @@ -913,16 +916,7 @@ def _move_model_to_vllm(self): else: gather_if_zero3 = nullcontext - if ( - is_peft_model(self.model) - and self.is_fsdp_enabled - and self.vllm_mode == "server" - ): - # TODO(oskar): special handling for PEFT with FSDP and vLLM server - # Need to avoid summoning full params somehow - # https://github.com/AlignmentResearch/llm/blob/2aef8be95ed7e182c096a8e8135381bc4a58ee43/split/training/huggingface.py#L105 - pass - elif is_peft_model(self.model): + if is_peft_model(self.model): # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as # merging adapters in a sharded manner is not supported. # TODO: does this work with FSDP? 
@@ -933,7 +927,7 @@ def _move_model_to_vllm(self): if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext # Update vLLM weights while parameters are gathered # For PEFT with FSDP we need to use the memory efficient post-order traversal - self._sync_fsdp_params_to_vllm(self.model) + self._sync_fsdp_params_to_vllm(self.model, move_static_params=True) else: # DeepSpeed ZeRO-3 with PEFT for name, param in self.model.named_parameters(): From a125d72a603304e04095bb1c99990521c14fd838 Mon Sep 17 00:00:00 2001 From: ojh31 Date: Thu, 26 Jun 2025 16:20:51 -0700 Subject: [PATCH 5/6] Remove bad comma --- trl/scripts/vllm_serve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index e22b2b51703..20134564714 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -440,7 +440,7 @@ class GenerateRequest(BaseModel): top_k: int = -1 min_p: float = 0.0 max_tokens: int = 16 - lora_request: Optional[LoRARequest] = None, + lora_request: Optional[LoRARequest] = None guided_decoding_regex: Optional[str] = None generation_kwargs: dict = field(default_factory=dict) From 514f4ec95a85886a99571d295174cdbf591625cf Mon Sep 17 00:00:00 2001 From: ojh31 Date: Thu, 26 Jun 2025 16:27:58 -0700 Subject: [PATCH 6/6] PydanticLoraRequest --- trl/scripts/vllm_serve.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 20134564714..d45347db97c 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -431,6 +431,16 @@ async def get_world_size(): """ return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} + class PydanticLoRARequest(BaseModel): + """Pydantic-compatible LoRA request model.""" + lora_name: str + lora_int_id: int + lora_path: str + + def to_vllm_lora_request(self) -> LoRARequest: + """Convert to vLLM LoRARequest object.""" + return LoRARequest(self.lora_name, self.lora_int_id, self.lora_path) + class GenerateRequest(BaseModel): prompts: list[str] n: int = 1 @@ -440,7 +450,7 @@ class GenerateRequest(BaseModel): top_k: int = -1 min_p: float = 0.0 max_tokens: int = 16 - lora_request: Optional[LoRARequest] = None + lora_request: Optional[PydanticLoRARequest] = None guided_decoding_regex: Optional[str] = None generation_kwargs: dict = field(default_factory=dict) @@ -510,7 +520,9 @@ async def generate(request: GenerateRequest): # with vLLM's requirement, and we later ignore the result. if not prompts: prompts = [""] - kwargs = {"prompts": prompts, "sampling_params": sampling_params, "lora_request": request.lora_request} + # Convert PydanticLoRARequest to vLLM LoRARequest if provided + vllm_lora_request = request.lora_request.to_vllm_lora_request() if request.lora_request else None + kwargs = {"prompts": prompts, "sampling_params": sampling_params, "lora_request": vllm_lora_request} connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) # Receive results
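
Usage sketch (illustrative, not part of the patches above): one way the LoRA generation path added here might be exercised once the serve script is started with the --enable_lora flag from PATCH 1/6. The adapter name, integer id, and path below are hypothetical, and the sketch passes a plain dict with the same fields as the PydanticLoRARequest model from PATCH 6/6, on the assumption that the client serializes the request payload as JSON; the exact type accepted on the client side may differ.

# Server side (assumed invocation):
#   trl vllm-serve --model <base-model> --enable_lora
from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # default host/port of the serve script

# Fields mirror PydanticLoRARequest (lora_name, lora_int_id, lora_path);
# per PATCH 6/6 the server converts them into a vLLM LoRARequest before
# calling llm.generate(..., lora_request=...).
lora_request = {
    "lora_name": "grpo_adapter",       # hypothetical adapter name
    "lora_int_id": 1,                  # unique integer id for this adapter
    "lora_path": "/tmp/grpo_adapter",  # hypothetical path to a saved PEFT adapter
}

completion_ids = client.generate(
    prompts=["The capital of France is"],
    n=1,
    max_tokens=16,
    lora_request=lora_request,
)
print(completion_ids)  # list[list[int]] of generated token ids, one list per completion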