Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions packages/prime/src/prime_cli/api/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def create_run(
lora_alpha: Optional[int] = None,
oversampling_factor: Optional[float] = None,
max_async_level: Optional[int] = None,
run_config: Optional[Dict[str, Any]] = None,
) -> RLRun:
"""Create a new RL training run."""
try:
Expand Down Expand Up @@ -177,6 +178,9 @@ def create_run(
if max_async_level is not None:
payload["max_async_level"] = max_async_level

if run_config:
payload["run_config"] = run_config
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent payload key breaks config naming pattern

Medium Severity

The run_config parameter uses payload["run_config"] as its key, but the established pattern for similar configs uses shortened names without the _config suffix: eval_config → "eval", val_config → "val", buffer_config → "buffer". If the backend API follows this same naming convention, it would expect "run" as the key rather than "run_config", causing the configuration to be silently ignored.

Fix in Cursor Fix in Web


response = self.client.post("/rft/runs", json=payload)
return RLRun.model_validate(response.get("run"))
except ValidationError:
Expand Down Expand Up @@ -311,9 +315,7 @@ def get_distributions(
if step is not None:
params["step"] = step

response = self.client.get(
f"/rft/runs/{run_id}/distributions", params=params
)
response = self.client.get(f"/rft/runs/{run_id}/distributions", params=params)
return {
"bins": response.get("bins", []),
"step": response.get("step"),
Expand Down
12 changes: 8 additions & 4 deletions packages/prime/src/prime_cli/commands/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def generate_rl_config_template(environment: str | None = None) -> str:
# env_ratios = [0.5, 0.5]
# skip_verification = false
# seed = 42

# Optional: advanced run configuration (admin only)
# [run_config]
# custom_key = "custom_value"
'''


Expand Down Expand Up @@ -281,6 +285,7 @@ class RLConfig(BaseModel):
wandb: WandbConfig = Field(default_factory=WandbConfig)
env_file: List[str] = Field(default_factory=list) # deprecated, use env_files
env_files: List[str] = Field(default_factory=list)
run_config: Dict[str, Any] = Field(default_factory=dict) # advanced config (admin only)


def _format_validation_errors(errors: list[dict]) -> list[str]:
Expand Down Expand Up @@ -440,11 +445,9 @@ def warn(msg: str) -> None:
wandb_configured = cfg.wandb.entity or cfg.wandb.project
if wandb_configured and (not secrets or "WANDB_API_KEY" not in secrets):
console.print("[red]Configuration Error:[/red]")
console.print(
" WANDB_API_KEY is required when W&B monitoring is configured.\n"
)
console.print(" WANDB_API_KEY is required when W&B monitoring is configured.\n")
console.print("Provide it via:")
console.print(" - env_files in your config: env_files = [\"secrets.env\"]")
console.print(' - env_files in your config: env_files = ["secrets.env"]')
console.print(" - CLI flag: --env-file secrets.env")
console.print(" - CLI flag: -e WANDB_API_KEY=your-key")
console.print(
Expand Down Expand Up @@ -542,6 +545,7 @@ def warn(msg: str) -> None:
lora_alpha=cfg.lora_alpha,
oversampling_factor=cfg.oversampling_factor,
max_async_level=cfg.max_async_level,
run_config=cfg.run_config if cfg.run_config else None,
)

if output == "json":
Expand Down