diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 71341a99542..1a6ea38b32f 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -33,6 +33,7 @@ if is_vllm_available():
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
     from vllm.distributed.utils import StatelessProcessGroup
+    from vllm.lora.request import LoRARequest
 
 if is_vllm_ascend_available():
     from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
@@ -173,6 +174,7 @@ def generate(
         top_k: int = -1,
         min_p: float = 0.0,
         max_tokens: int = 16,
+        lora_request: Optional[LoRARequest] = None,
         guided_decoding_regex: Optional[str] = None,
         generation_kwargs: Optional[dict] = None,
     ) -> list[list[int]]:
@@ -196,6 +198,8 @@ def generate(
                 Minimum probability for sampling.
             max_tokens (`int`, *optional*, defaults to `16`):
                 Maximum number of tokens to generate for each prompt.
+            lora_request (`LoRARequest` or `None`, *optional*, defaults to `None`):
+                Details of the LoRA adapter to apply during generation.
             guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
                 Regular expression to guide the decoding process.
             generation_kwargs (`dict` or `None`, *optional*, defaults to `None`):
@@ -219,6 +223,7 @@ def generate(
                     "top_k": top_k,
                     "min_p": min_p,
                     "max_tokens": max_tokens,
+                    "lora_request": lora_request,
                     "guided_decoding_regex": guided_decoding_regex,
                     "generation_kwargs": generation_kwargs or {},
                 },
@@ -293,7 +298,7 @@ def update_model_params(self, model: nn.Module):
         for name, param in model.named_parameters():
             # Update each parameter individually
             self.update_named_param(name, param.data)
-        
+
     def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
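For context outside the patch itself: a minimal usage sketch of the new `lora_request` argument on the client side. It assumes a `trl vllm-serve` instance started with `--enable_lora` is already reachable at `VLLMClient`'s default host/port; the adapter name, integer id, and path are placeholders.

```python
# Minimal sketch (not part of the patch): exercising the new `lora_request` argument.
from trl.extras.vllm_client import VLLMClient
from vllm.lora.request import LoRARequest

# Assumes a server started with `trl vllm-serve ... --enable_lora` at the client's default host/port.
client = VLLMClient()

# Positional (name, id, path) mirrors PydanticLoRARequest.to_vllm_lora_request below.
lora_request = LoRARequest("my_adapter", 1, "/path/to/adapter")  # placeholder adapter

completion_ids = client.generate(
    prompts=["The capital of France is"],
    max_tokens=32,
    lora_request=lora_request,
)  # returns a list of token-id lists, one per completion
```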
diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py
index eb4a6e269ca..d45347db97c 100644
--- a/trl/scripts/vllm_serve.py
+++ b/trl/scripts/vllm_serve.py
@@ -24,6 +24,7 @@
 from typing import Optional
 
 import torch
+from vllm.lora.request import LoRARequest
 
 from trl import TrlParser
 from trl.import_utils import (
@@ -158,6 +159,8 @@ class ScriptArguments:
             Model name or path to load the model from.
         revision (`str` or `None`, *optional*, defaults to `None`):
             Revision to use for the model. If not specified, the default branch will be used.
+        enable_lora (`bool`, *optional*, defaults to `False`):
+            Whether to enable LoRA.
         tensor_parallel_size (`int`, *optional*, defaults to `1`):
             Number of tensor parallel workers to use.
         data_parallel_size (`int`, *optional*, defaults to `1`):
@@ -201,6 +204,10 @@ class ScriptArguments:
         default=None,
         metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
     )
+    enable_lora: bool = field(
+        default=False,
+        metadata={"help": "Whether to enable LoRA."},
+    )
     tensor_parallel_size: int = field(
         default=1,
         metadata={"help": "Number of tensor parallel workers to use."},
@@ -290,6 +297,7 @@ def llm_worker(
     llm = LLM(
         model=script_args.model,
         revision=script_args.revision,
+        enable_lora=script_args.enable_lora,
         tensor_parallel_size=script_args.tensor_parallel_size,
         gpu_memory_utilization=script_args.gpu_memory_utilization,
         enforce_eager=script_args.enforce_eager,
@@ -423,6 +431,16 @@ async def get_world_size():
         """
         return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}
 
+    class PydanticLoRARequest(BaseModel):
+        """Pydantic-compatible LoRA request model."""
+
+        lora_name: str
+        lora_int_id: int
+        lora_path: str
+
+        def to_vllm_lora_request(self) -> LoRARequest:
+            """Convert to vLLM LoRARequest object."""
+            return LoRARequest(self.lora_name, self.lora_int_id, self.lora_path)
+
     class GenerateRequest(BaseModel):
         prompts: list[str]
         n: int = 1
@@ -432,6 +450,7 @@ class GenerateRequest(BaseModel):
         top_k: int = -1
         min_p: float = 0.0
         max_tokens: int = 16
+        lora_request: Optional[PydanticLoRARequest] = None
        guided_decoding_regex: Optional[str] = None
        generation_kwargs: dict = field(default_factory=dict)
 
@@ -453,6 +472,7 @@ async def generate(request: GenerateRequest):
        - `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables top-k sampling.
        - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
        - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion.
+       - `lora_request` (`LoRARequest`, *optional*): A request for LoRA parameters. If provided, the model will use the LoRA parameters to generate completions.
        - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the model will only generate tokens that match this regex pattern.
        - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will override them.
@@ -500,7 +520,9 @@ async def generate(request: GenerateRequest):
         # with vLLM's requirement, and we later ignore the result.
         if not prompts:
             prompts = [""]
-        kwargs = {"prompts": prompts, "sampling_params": sampling_params}
+        # Convert PydanticLoRARequest to vLLM LoRARequest if provided
+        vllm_lora_request = request.lora_request.to_vllm_lora_request() if request.lora_request else None
+        kwargs = {"prompts": prompts, "sampling_params": sampling_params, "lora_request": vllm_lora_request}
         connection.send({"type": "call", "method": "generate", "kwargs": kwargs})
 
         # Receive results
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 558b77e9b2e..9598199051a 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -873,7 +873,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
             all_logps.append(logps)
         return torch.cat(all_logps, dim=0)
 
-    def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
+    def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None, move_static_params: bool = False):
         """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
         if visited is None:
             visited = set()
@@ -881,7 +881,7 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
         for child_name, child_module in module.named_children():
             child_prefix = f"{prefix}.{child_name}" if prefix else child_name
             self._sync_fsdp_params_to_vllm(
-                child_module, prefix=child_prefix, visited=visited
+                child_module, prefix=child_prefix, visited=visited, move_static_params=move_static_params
             )  # recurse into the child
 
         if isinstance(module, FSDP):
@@ -894,6 +894,9 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
                 if full_name in visited:
                     continue  # skip FSDP subtrees already traversed
                 visited.add(full_name)
+                if not param.requires_grad and not move_static_params:
+                    # Useful e.g. for PEFT where most parameters are static
+                    continue
                 if self.vllm_mode == "server" and self.accelerator.is_main_process:
                     self.vllm_client.update_named_param(full_name, param.data)
@@ -924,7 +927,7 @@ def _move_model_to_vllm(self):
             if self.is_fsdp_enabled:  # note if using FSDP, gather_if_zero3 is nullcontext
                 # Update vLLM weights while parameters are gathered
                 # For PEFT with FSDP we need to use the memory efficient post-order traversal
-                self._sync_fsdp_params_to_vllm(self.model)
+                self._sync_fsdp_params_to_vllm(self.model, move_static_params=True)
             else:  # DeepSpeed ZeRO-3 with PEFT
                 for name, param in self.model.named_parameters():
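On the server side, the new `lora_request` field of `GenerateRequest` can be exercised with a plain HTTP call. This is a hedged sketch: the endpoint URL, port, adapter details, and the `trl vllm-serve` invocation shown in the comment are assumptions, not values taken from this diff.

```python
# Hypothetical request against the updated generate endpoint; endpoint URL and adapter
# details are placeholders. The server is assumed to have been launched with something like:
#   trl vllm-serve --model <model_name_or_path> --enable_lora
import requests

payload = {
    "prompts": ["The capital of France is"],
    "max_tokens": 32,
    # Shape matches the PydanticLoRARequest model added above.
    "lora_request": {
        "lora_name": "my_adapter",
        "lora_int_id": 1,
        "lora_path": "/path/to/adapter",
    },
}
response = requests.post("http://localhost:8000/generate/", json=payload)
print(response.json())
```

On the trainer side, note that `_sync_fsdp_params_to_vllm` now skips `requires_grad=False` parameters by default and only streams them to vLLM when called with `move_static_params=True`, as `_move_model_to_vllm` does in its FSDP branch.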