From 02e1ecdbc07f02f756172e9081c2af0f4c7239b6 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Wed, 21 Jan 2026 11:53:41 -0800 Subject: [PATCH] Add run_config support to RL command - Add run_config field to RLConfig for TOML parsing - Add run_config parameter to RLClient.create_run() - Pass run_config from CLI to API - Update config template with example --- packages/prime/src/prime_cli/api/rl.py | 8 +++++--- packages/prime/src/prime_cli/commands/rl.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index a8beba46..c875a4e8 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -113,6 +113,7 @@ def create_run( lora_alpha: Optional[int] = None, oversampling_factor: Optional[float] = None, max_async_level: Optional[int] = None, + run_config: Optional[Dict[str, Any]] = None, ) -> RLRun: """Create a new RL training run.""" try: @@ -177,6 +178,9 @@ def create_run( if max_async_level is not None: payload["max_async_level"] = max_async_level + if run_config: + payload["run_config"] = run_config + response = self.client.post("/rft/runs", json=payload) return RLRun.model_validate(response.get("run")) except ValidationError: @@ -311,9 +315,7 @@ def get_distributions( if step is not None: params["step"] = step - response = self.client.get( - f"/rft/runs/{run_id}/distributions", params=params - ) + response = self.client.get(f"/rft/runs/{run_id}/distributions", params=params) return { "bins": response.get("bins", []), "step": response.get("step"), diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 4439147e..0df523e8 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -129,6 +129,10 @@ def generate_rl_config_template(environment: str | None = None) -> str: # env_ratios = [0.5, 0.5] # skip_verification = false # seed = 42 + +# Optional: advanced run configuration (admin only) +# [run_config] +# custom_key = "custom_value" ''' @@ -281,6 +285,7 @@ class RLConfig(BaseModel): wandb: WandbConfig = Field(default_factory=WandbConfig) env_file: List[str] = Field(default_factory=list) # deprecated, use env_files env_files: List[str] = Field(default_factory=list) + run_config: Dict[str, Any] = Field(default_factory=dict) # advanced config (admin only) def _format_validation_errors(errors: list[dict]) -> list[str]: @@ -440,11 +445,9 @@ def warn(msg: str) -> None: wandb_configured = cfg.wandb.entity or cfg.wandb.project if wandb_configured and (not secrets or "WANDB_API_KEY" not in secrets): console.print("[red]Configuration Error:[/red]") - console.print( - " WANDB_API_KEY is required when W&B monitoring is configured.\n" - ) + console.print(" WANDB_API_KEY is required when W&B monitoring is configured.\n") console.print("Provide it via:") - console.print(" - env_files in your config: env_files = [\"secrets.env\"]") + console.print(' - env_files in your config: env_files = ["secrets.env"]') console.print(" - CLI flag: --env-file secrets.env") console.print(" - CLI flag: -e WANDB_API_KEY=your-key") console.print( @@ -542,6 +545,7 @@ def warn(msg: str) -> None: lora_alpha=cfg.lora_alpha, oversampling_factor=cfg.oversampling_factor, max_async_level=cfg.max_async_level, + run_config=cfg.run_config if cfg.run_config else None, ) if output == "json":