7 changes: 6 additions & 1 deletion trl/extras/vllm_client.py
@@ -33,6 +33,7 @@
if is_vllm_available():
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.lora.request import LoRARequest

if is_vllm_ascend_available():
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
@@ -173,6 +174,7 @@ def generate(
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
lora_request: Optional[LoRARequest] = None,
guided_decoding_regex: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
) -> list[list[int]]:
@@ -196,6 +198,8 @@ def generate(
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
lora_request (`LoRARequest` or `None`, *optional*, defaults to `None`):
The LoRA adapter to apply during generation.
guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regular expression to guide the decoding process.
generation_kwargs (`dict` or `None`, *optional*, defaults to `None`):
@@ -219,6 +223,7 @@
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"lora_request": lora_request,
"guided_decoding_regex": guided_decoding_regex,
"generation_kwargs": generation_kwargs or {},
},
@@ -293,7 +298,7 @@ def update_model_params(self, model: nn.Module):
for name, param in model.named_parameters():
# Update each parameter individually
self.update_named_param(name, param.data)

def reset_prefix_cache(self):
"""
Resets the prefix cache for the model.
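For reference, a client-side call using the new lora_request argument could look like the sketch below. This is not part of the diff: it assumes a TRL vLLM server started with --enable_lora is reachable at the client's default host and port, and the adapter name, integer ID, and path are illustrative placeholders.

from vllm.lora.request import LoRARequest

from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # assumes the server is reachable at the client's default host/port
# Hypothetical adapter: the name, integer ID, and local path below are placeholders
lora = LoRARequest("my_adapter", 1, "/path/to/lora_adapter")
completion_ids = client.generate(
    ["The capital of France is"],
    max_tokens=32,
    lora_request=lora,  # new argument added in this change
)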
24 changes: 23 additions & 1 deletion trl/scripts/vllm_serve.py
@@ -24,6 +24,7 @@
from typing import Optional

import torch
from vllm.lora.request import LoRARequest

from trl import TrlParser
from trl.import_utils import (
@@ -158,6 +159,8 @@ class ScriptArguments:
Model name or path to load the model from.
revision (`str` or `None`, *optional*, defaults to `None`):
Revision to use for the model. If not specified, the default branch will be used.
enable_lora (`bool`, *optional*, defaults to `False`):
Whether to enable LoRA.
tensor_parallel_size (`int`, *optional*, defaults to `1`):
Number of tensor parallel workers to use.
data_parallel_size (`int`, *optional*, defaults to `1`):
@@ -201,6 +204,10 @@ class ScriptArguments:
default=None,
metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
)
enable_lora: bool = field(
default=False,
metadata={"help": "Whether to enable LoRA."},
)
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
@@ -290,6 +297,7 @@ def llm_worker(
llm = LLM(
model=script_args.model,
revision=script_args.revision,
enable_lora=script_args.enable_lora,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
@@ -423,6 +431,16 @@ async def get_world_size():
"""
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}

class PydanticLoRARequest(BaseModel):
"""Pydantic-compatible LoRA request model."""
lora_name: str
lora_int_id: int
lora_path: str

def to_vllm_lora_request(self) -> LoRARequest:
"""Convert to vLLM LoRARequest object."""
return LoRARequest(self.lora_name, self.lora_int_id, self.lora_path)

class GenerateRequest(BaseModel):
prompts: list[str]
n: int = 1
@@ -432,6 +450,7 @@ class GenerateRequest(BaseModel):
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
lora_request: Optional[PydanticLoRARequest] = None
guided_decoding_regex: Optional[str] = None
generation_kwargs: dict = field(default_factory=dict)

@@ -453,6 +472,7 @@ async def generate(request: GenerateRequest):
- `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables top-k sampling.
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion.
- `lora_request` (`LoRARequest`, *optional*): LoRA adapter request (name, integer ID, and path). If provided, the corresponding LoRA adapter is applied when generating completions.
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will override them.

@@ -500,7 +520,9 @@ async def generate(request: GenerateRequest):
# with vLLM's requirement, and we later ignore the result.
if not prompts:
prompts = ["<placeholder>"]
kwargs = {"prompts": prompts, "sampling_params": sampling_params}
# Convert PydanticLoRARequest to vLLM LoRARequest if provided
vllm_lora_request = request.lora_request.to_vllm_lora_request() if request.lora_request else None
kwargs = {"prompts": prompts, "sampling_params": sampling_params, "lora_request": vllm_lora_request}
connection.send({"type": "call", "method": "generate", "kwargs": kwargs})

# Receive results
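To make the wire format concrete, the sketch below (not in the diff) shows what a request to the serve script's generate endpoint might look like with the nested lora_request field. The endpoint URL and port are assumptions based on a locally running server, and the adapter details are placeholders.

import requests

payload = {
    "prompts": ["The capital of France is"],
    "max_tokens": 32,
    # Parsed into PydanticLoRARequest, then converted to a vLLM LoRARequest before generation
    "lora_request": {
        "lora_name": "my_adapter",
        "lora_int_id": 1,
        "lora_path": "/path/to/lora_adapter",
    },
}
response = requests.post("http://localhost:8000/generate/", json=payload)  # assumed local server address
print(response.json())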
9 changes: 6 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -873,15 +873,15 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep,
all_logps.append(logps)
return torch.cat(all_logps, dim=0)

def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None, move_static_params: bool = False):
"""Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
if visited is None:
visited = set()

for child_name, child_module in module.named_children():
child_prefix = f"{prefix}.{child_name}" if prefix else child_name
self._sync_fsdp_params_to_vllm(
child_module, prefix=child_prefix, visited=visited
child_module, prefix=child_prefix, visited=visited, move_static_params=move_static_params
) # recurse into the child

if isinstance(module, FSDP):
@@ -894,6 +894,9 @@ def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited
if full_name in visited:
continue # skip FSDP subtrees already traversed
visited.add(full_name)
if not param.requires_grad and not move_static_params:
# Skip frozen parameters unless explicitly requested; useful e.g. for PEFT where most parameters are static
continue

if self.vllm_mode == "server" and self.accelerator.is_main_process:
self.vllm_client.update_named_param(full_name, param.data)
@@ -924,7 +927,7 @@ def _move_model_to_vllm(self):
if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext
# Update vLLM weights while parameters are gathered
# For PEFT with FSDP we need to use the memory efficient post-order traversal
self._sync_fsdp_params_to_vllm(self.model)
self._sync_fsdp_params_to_vllm(self.model, move_static_params=True)
else:
# DeepSpeed ZeRO-3 with PEFT
for name, param in self.model.named_parameters():
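To illustrate the rule behind the new move_static_params flag, here is a small, self-contained sketch (not part of the trainer): with move_static_params=False, only trainable parameters (e.g. LoRA adapter weights under PEFT) are pushed to vLLM and frozen base weights are skipped, while move_static_params=True syncs everything, matching the full-model sync used by _move_model_to_vllm. The helper function below is illustrative only.

import torch.nn as nn

model = nn.Linear(4, 4)
model.weight.requires_grad_(False)  # pretend the base weight is frozen, PEFT-style

def params_to_sync(module: nn.Module, move_static_params: bool = False) -> list[str]:
    # Mirrors the `not param.requires_grad and not move_static_params` skip condition
    return [
        name
        for name, param in module.named_parameters()
        if param.requires_grad or move_static_params
    ]

print(params_to_sync(model))                           # ['bias']
print(params_to_sync(model, move_static_params=True))  # ['weight', 'bias']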