From 239008f9ac743ca3d01562f9787fbbcf848a1726 Mon Sep 17 00:00:00 2001 From: d42me Date: Sun, 4 Jan 2026 22:45:36 -0600 Subject: [PATCH 1/4] Add rft MCP. --- .../prime-mcp-server/src/prime_mcp/client.py | 2 + .../src/prime_mcp/core/client.py | 3 + .../prime-mcp-server/src/prime_mcp/mcp.py | 229 +++++++++++++- .../src/prime_mcp/tools/rft.py | 279 ++++++++++++++++++ 4 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 packages/prime-mcp-server/src/prime_mcp/tools/rft.py diff --git a/packages/prime-mcp-server/src/prime_mcp/client.py b/packages/prime-mcp-server/src/prime_mcp/client.py index e2249693..991921df 100644 --- a/packages/prime-mcp-server/src/prime_mcp/client.py +++ b/packages/prime-mcp-server/src/prime_mcp/client.py @@ -27,6 +27,8 @@ async def make_prime_request( return await _client.get(endpoint, params=params) elif method == "POST": return await _client.post(endpoint, json=json_data) + elif method == "PUT": + return await _client.put(endpoint, json=json_data) elif method == "DELETE": return await _client.delete(endpoint) elif method == "PATCH": diff --git a/packages/prime-mcp-server/src/prime_mcp/core/client.py b/packages/prime-mcp-server/src/prime_mcp/core/client.py index 84e9d90f..062a23eb 100644 --- a/packages/prime-mcp-server/src/prime_mcp/core/client.py +++ b/packages/prime-mcp-server/src/prime_mcp/core/client.py @@ -113,6 +113,9 @@ async def patch(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> D async def delete(self, endpoint: str) -> Dict[str, Any]: return await self.request("DELETE", endpoint) + async def put(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + return await self.request("PUT", endpoint, json=json) + async def aclose(self) -> None: await self.client.aclose() diff --git a/packages/prime-mcp-server/src/prime_mcp/mcp.py b/packages/prime-mcp-server/src/prime_mcp/mcp.py index 54de5ac9..0b068634 100644 --- a/packages/prime-mcp-server/src/prime_mcp/mcp.py +++ b/packages/prime-mcp-server/src/prime_mcp/mcp.py @@ -1,6 +1,6 @@ from mcp.server.fastmcp import FastMCP -from prime_mcp.tools import availability, pods, ssh +from prime_mcp.tools import availability, pods, rft, ssh mcp = FastMCP("primeintellect") @@ -253,5 +253,232 @@ async def manage_ssh_keys( return await ssh.manage_ssh_keys(action, key_name, public_key, key_id, offset, limit) +@mcp.tool() +async def list_rft_models() -> dict: + """List all available RFT models for training. + + Returns models from healthy RFT clusters that are ready to accept training jobs. + Check this before creating a run to see which models are available. + + Returns: + List of available RFT models with their names + """ + return await rft.list_rft_models() + + +@mcp.tool() +async def list_rft_runs(team_id: str | None = None) -> dict: + """List RFT training runs for the authenticated user. + + If team_id is provided, returns runs for that team only (requires team membership). + If team_id is None, returns user's personal runs AND all runs from teams they're in. + + Args: + team_id: Optional team ID to filter runs by team + + Returns: + List of RFT runs with status, configuration, and progress information + """ + return await rft.list_rft_runs(team_id) + + +@mcp.tool() +async def get_rft_run(run_id: str) -> dict: + """Get detailed information about a specific RFT training run. 
+ + Args: + run_id: Unique identifier of the RFT run + + Returns: + Detailed run information including: + - status: QUEUED, PENDING, RUNNING, COMPLETED, FAILED, STOPPED + - configuration: model, environments, hyperparameters + - progress: current step, started_at, completed_at + - error_message: if run failed + """ + return await rft.get_rft_run(run_id) + + +@mcp.tool() +async def create_rft_run( + model_name: str, + environments: list[dict], + rollouts_per_example: int, + seq_len: int, + max_steps: int, + name: str | None = None, + eval_config: dict | None = None, + wandb_entity: str | None = None, + wandb_project: str | None = None, + wandb_run_name: str | None = None, + wandb_api_key: str | None = None, + secrets: list[dict] | None = None, + team_id: str | None = None, +) -> dict: + """Create a new RFT (Reinforcement Fine-Tuning) training run. + + WORKFLOW: + 1. First check available models with list_rft_models() + 2. Configure your training environments + 3. Optionally set up W&B monitoring with your API key + 4. Create the run - it will be queued and start automatically + + Args: + model_name: Model to fine-tune (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"). + Use list_rft_models() to see available models. + environments: Training environments list. Each environment dict should have: + - id (required): Environment ID like "reverse-text" or hub slug "primeintellect/vf-math" + - name (optional): Display name for this environment + - args (optional): Dict of environment-specific arguments + rollouts_per_example: Number of rollouts per training example. + MUST be one of: 1, 2, 4, 8, 16, 32, 64, 128 (must divide batch size 128 evenly) + seq_len: Sequence length for training (context window size) + max_steps: Maximum number of training steps + name: Optional run name (auto-generated if not provided) + eval_config: Optional evaluation configuration dict with: + - environments: List of eval environments (same format as training) + - interval: Evaluate every N steps (default: 100) + - num_examples: Examples per environment (-1 for all) + - rollouts_per_example: Rollouts per eval example (default: 1) + - eval_base_model: Whether to eval base model first (default: True) + wandb_entity: W&B entity (username or team name) for metrics logging + wandb_project: W&B project name - REQUIRED if you want monitoring + wandb_run_name: W&B run name (optional) + wandb_api_key: Your W&B API key - REQUIRED for W&B monitoring + secrets: Additional secrets as list of {"key": "NAME", "value": "secret"} dicts + team_id: Team ID to create run for (requires team membership) + + Returns: + Created RFT run details including run ID and initial status (QUEUED) + + Example: + create_rft_run( + model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + environments=[{"id": "reverse-text"}], + rollouts_per_example=8, + seq_len=2048, + max_steps=1000, + wandb_project="my-rft-project", + wandb_api_key="your-wandb-key" + ) + """ + return await rft.create_rft_run( + model_name=model_name, + environments=environments, + rollouts_per_example=rollouts_per_example, + seq_len=seq_len, + max_steps=max_steps, + name=name, + eval_config=eval_config, + wandb_entity=wandb_entity, + wandb_project=wandb_project, + wandb_run_name=wandb_run_name, + wandb_api_key=wandb_api_key, + secrets=secrets, + team_id=team_id, + ) + + +@mcp.tool() +async def stop_rft_run(run_id: str) -> dict: + """Stop/abort a running RFT training run. + + Can only stop runs that are in QUEUED, PENDING, or RUNNING status. 
+ The run will save any adapter weights produced so far before stopping. + + Args: + run_id: Unique identifier of the RFT run to stop + + Returns: + Updated run details with STOPPED status + """ + return await rft.stop_rft_run(run_id) + + +@mcp.tool() +async def delete_rft_run(run_id: str) -> dict: + """Delete an RFT training run. + + This will: + - Cleanup Kubernetes resources if still running + - Delete the run record from the database + - Note: Adapter weights from completed runs are preserved separately + + Args: + run_id: Unique identifier of the RFT run to delete + + Returns: + Deletion confirmation with run_id and success status + """ + return await rft.delete_rft_run(run_id) + + +@mcp.tool() +async def get_rft_run_logs(run_id: str, tail_lines: int = 1000) -> dict: + """Get training logs for an RFT run. + + Fetches logs from the orchestrator pod running the training job. + Useful for debugging failed runs or monitoring progress. + + Args: + run_id: Unique identifier of the RFT run + tail_lines: Number of lines to return from the end of logs (default: 1000) + + Returns: + Log content as a string + """ + return await rft.get_rft_run_logs(run_id, tail_lines) + + +@mcp.tool() +async def list_rft_adapters(team_id: str | None = None) -> dict: + """List trained adapters (LoRA weights) from completed RFT runs. + + Adapters are the output of successful RFT training - they contain the + fine-tuned LoRA weights that can be used for inference. + + Args: + team_id: Optional team ID to filter adapters by team + + Returns: + List of adapters with: + - id: Adapter ID for use with inference + - display_name: Optional friendly name + - base_model: The base model these weights are for + - rft_run_id: The training run that produced this adapter + - status: PENDING, READY, or FAILED + """ + return await rft.list_rft_adapters(team_id) + + +@mcp.tool() +async def get_rft_adapter(adapter_id: str) -> dict: + """Get detailed information about a specific adapter. + + Args: + adapter_id: Unique identifier of the adapter + + Returns: + Adapter details including base model, status, and source run + """ + return await rft.get_rft_adapter(adapter_id) + + +@mcp.tool() +async def delete_rft_adapter(adapter_id: str) -> dict: + """Delete an adapter. + + Removes the adapter record from the database. + Note: This deletes the database record but storage files may be retained. + + Args: + adapter_id: Unique identifier of the adapter to delete + + Returns: + Deletion confirmation with adapter_id and success status + """ + return await rft.delete_rft_adapter(adapter_id) + + if __name__ == "__main__": mcp.run(transport="stdio") diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rft.py b/packages/prime-mcp-server/src/prime_mcp/tools/rft.py new file mode 100644 index 00000000..43d80838 --- /dev/null +++ b/packages/prime-mcp-server/src/prime_mcp/tools/rft.py @@ -0,0 +1,279 @@ +from typing import Any + +from prime_mcp.client import make_prime_request + + +async def list_rft_models() -> dict[str, Any]: + """List all available RFT models. + + Returns models from healthy RFT clusters (heartbeat within last 1 minute). + + Returns: + List of available RFT models with their names + """ + response_data = await make_prime_request("GET", "rft/models") + + if not response_data: + return {"error": "Unable to fetch RFT models"} + + return response_data + + +async def list_rft_runs(team_id: str | None = None) -> dict[str, Any]: + """List RFT runs for the authenticated user. 
+ + If team_id is provided, returns runs for that team only (requires team membership). + If team_id is None, returns user's personal runs AND all runs from teams they're in. + + Args: + team_id: Optional team ID to filter runs + + Returns: + List of RFT runs with their details + """ + params = {} + if team_id: + params["team_id"] = team_id + + response_data = await make_prime_request("GET", "rft/runs", params=params if params else None) + + if not response_data: + return {"error": "Unable to fetch RFT runs"} + + return response_data + + +async def get_rft_run(run_id: str) -> dict[str, Any]: + """Get details of a specific RFT run. + + Args: + run_id: Unique identifier of the RFT run + + Returns: + Detailed RFT run information including status, configuration, and progress + """ + response_data = await make_prime_request("GET", f"rft/runs/{run_id}") + + if not response_data: + return {"error": f"Unable to fetch RFT run: {run_id}"} + + return response_data + + +async def create_rft_run( + model_name: str, + environments: list[dict[str, Any]], + rollouts_per_example: int, + seq_len: int, + max_steps: int, + name: str | None = None, + eval_config: dict[str, Any] | None = None, + wandb_entity: str | None = None, + wandb_project: str | None = None, + wandb_run_name: str | None = None, + wandb_api_key: str | None = None, + secrets: list[dict[str, str]] | None = None, + team_id: str | None = None, +) -> dict[str, Any]: + """Create a new RFT training run. + + IMPORTANT PREREQUISITES: + 1. Check available models with list_rft_models() first + 2. Ensure you have a W&B API key if you want monitoring + + Args: + model_name: Model name/path (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B") + environments: List of training environments. Each environment should have: + - id: Environment ID (e.g., "reverse-text" or "primeintellect/vf-math") + - name: Optional display name + - args: Optional environment-specific arguments dict + rollouts_per_example: Number of rollouts per example. 
+ Must divide 128: valid values 1,2,4,8,16,32,64,128 + seq_len: Sequence length for training + max_steps: Maximum training steps + name: Optional run name (auto-generated if not provided) + eval_config: Optional evaluation configuration with: + - environments: List of eval environments (same format as training) + - interval: Evaluate every N steps (default: 100) + - num_examples: Number of examples per environment (-1 for all) + - rollouts_per_example: Rollouts per example (default: 1) + - eval_base_model: Whether to eval base model before training (default: True) + wandb_entity: W&B entity (username or team) for monitoring + wandb_project: W&B project name for monitoring + wandb_run_name: W&B run name + wandb_api_key: W&B API key for authentication (passed as secret) + secrets: Additional secrets as list of {key, value} dicts + team_id: Optional team ID to create run for team + + Returns: + Created RFT run details including run ID and status + """ + # Validate rollouts_per_example + valid_rollouts = [1, 2, 4, 8, 16, 32, 64, 128] + if rollouts_per_example not in valid_rollouts: + return {"error": f"rollouts_per_example must be one of {valid_rollouts}"} + + # Build request body + request_body: dict[str, Any] = { + "model": {"name": model_name}, + "environments": environments, + "rollouts_per_example": rollouts_per_example, + "seq_len": seq_len, + "max_steps": max_steps, + } + + if name: + request_body["name"] = name + + # Build secrets list + all_secrets = secrets or [] + if wandb_api_key: + all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) + + if all_secrets: + request_body["secrets"] = all_secrets + + # Add monitoring config if W&B settings provided + if wandb_project: + request_body["monitoring"] = { + "wandb": { + "project": wandb_project, + } + } + if wandb_entity: + request_body["monitoring"]["wandb"]["entity"] = wandb_entity + if wandb_run_name: + request_body["monitoring"]["wandb"]["name"] = wandb_run_name + + # Add eval config if provided + if eval_config: + request_body["eval"] = eval_config + + if team_id: + request_body["team_id"] = team_id + + response_data = await make_prime_request("POST", "rft/runs", json_data=request_body) + + if not response_data: + return {"error": "Unable to create RFT run"} + + return response_data + + +async def stop_rft_run(run_id: str) -> dict[str, Any]: + """Stop/abort a running RFT training run. + + Can only stop runs in QUEUED, PENDING, or RUNNING status. + + Args: + run_id: Unique identifier of the RFT run to stop + + Returns: + Updated RFT run details with STOPPED status + """ + response_data = await make_prime_request("PUT", f"rft/runs/{run_id}/stop") + + if not response_data: + return {"error": f"Unable to stop RFT run: {run_id}"} + + return response_data + + +async def delete_rft_run(run_id: str) -> dict[str, Any]: + """Delete an RFT training run. + + This will cleanup Kubernetes resources and delete the run from the database. + + Args: + run_id: Unique identifier of the RFT run to delete + + Returns: + Deletion confirmation with run_id and success status + """ + response_data = await make_prime_request("DELETE", f"rft/runs/{run_id}") + + if not response_data: + return {"error": f"Unable to delete RFT run: {run_id}"} + + return response_data + + +async def get_rft_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]: + """Get orchestrator logs for an RFT run. 
+ + Args: + run_id: Unique identifier of the RFT run + tail_lines: Number of lines to tail from the end of logs (default: 1000) + + Returns: + Pod logs as a string + """ + params = {"tail_lines": tail_lines} + + response_data = await make_prime_request("GET", f"rft/runs/{run_id}/logs", params=params) + + if not response_data: + return {"error": f"Unable to fetch logs for RFT run: {run_id}"} + + return response_data + + +async def list_rft_adapters(team_id: str | None = None) -> dict[str, Any]: + """List adapters for the authenticated user. + + Adapters are LoRA weights produced by completed RFT training runs. + + Args: + team_id: Optional team ID to filter adapters + + Returns: + List of adapters with their details (ID, base model, status, etc.) + """ + params = {} + if team_id: + params["team_id"] = team_id + + response_data = await make_prime_request( + "GET", "rft/adapters", params=params if params else None + ) + + if not response_data: + return {"error": "Unable to fetch RFT adapters"} + + return response_data + + +async def get_rft_adapter(adapter_id: str) -> dict[str, Any]: + """Get a specific adapter by ID. + + Args: + adapter_id: Unique identifier of the adapter + + Returns: + Adapter details including ID, base model, status, and associated run + """ + response_data = await make_prime_request("GET", f"rft/adapters/{adapter_id}") + + if not response_data: + return {"error": f"Unable to fetch adapter: {adapter_id}"} + + return response_data + + +async def delete_rft_adapter(adapter_id: str) -> dict[str, Any]: + """Delete an adapter by ID. + + Note: This only deletes the database record. Storage files are not automatically cleaned up. + + Args: + adapter_id: Unique identifier of the adapter to delete + + Returns: + Deletion confirmation with adapter_id and success status + """ + response_data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}") + + if not response_data: + return {"error": f"Unable to delete adapter: {adapter_id}"} + + return response_data From 2cc125c18bf5c4ca4147eebb3cc1f87c0052afd8 Mon Sep 17 00:00:00 2001 From: d42me Date: Mon, 5 Jan 2026 09:07:41 -0600 Subject: [PATCH 2/4] Fix list secrets. --- packages/prime-mcp-server/src/prime_mcp/tools/rft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rft.py b/packages/prime-mcp-server/src/prime_mcp/tools/rft.py index 43d80838..266f6f2f 100644 --- a/packages/prime-mcp-server/src/prime_mcp/tools/rft.py +++ b/packages/prime-mcp-server/src/prime_mcp/tools/rft.py @@ -125,8 +125,8 @@ async def create_rft_run( if name: request_body["name"] = name - # Build secrets list - all_secrets = secrets or [] + # Build secrets list (copy to avoid mutating caller's list) + all_secrets = list(secrets) if secrets else [] if wandb_api_key: all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) From 2e4a0472334d415500dcf8bcf5b0d95dc2e05c17 Mon Sep 17 00:00:00 2001 From: d42me Date: Mon, 5 Jan 2026 09:47:19 -0600 Subject: [PATCH 3/4] Update naming. 
--- .../prime-mcp-server/src/prime_mcp/mcp.py | 86 +++--- .../src/prime_mcp/tools/rft.py | 279 ------------------ .../src/prime_mcp/tools/rl.py | 136 +++++++++ 3 files changed, 179 insertions(+), 322 deletions(-) delete mode 100644 packages/prime-mcp-server/src/prime_mcp/tools/rft.py create mode 100644 packages/prime-mcp-server/src/prime_mcp/tools/rl.py diff --git a/packages/prime-mcp-server/src/prime_mcp/mcp.py b/packages/prime-mcp-server/src/prime_mcp/mcp.py index 0b068634..d89d8edd 100644 --- a/packages/prime-mcp-server/src/prime_mcp/mcp.py +++ b/packages/prime-mcp-server/src/prime_mcp/mcp.py @@ -1,6 +1,6 @@ from mcp.server.fastmcp import FastMCP -from prime_mcp.tools import availability, pods, rft, ssh +from prime_mcp.tools import availability, pods, rl, ssh mcp = FastMCP("primeintellect") @@ -254,21 +254,21 @@ async def manage_ssh_keys( @mcp.tool() -async def list_rft_models() -> dict: - """List all available RFT models for training. +async def list_rl_models() -> dict: + """List all available RL models for training. - Returns models from healthy RFT clusters that are ready to accept training jobs. + Returns models from healthy RL clusters that are ready to accept training jobs. Check this before creating a run to see which models are available. Returns: - List of available RFT models with their names + List of available RL models with their names """ - return await rft.list_rft_models() + return await rl.list_rl_models() @mcp.tool() -async def list_rft_runs(team_id: str | None = None) -> dict: - """List RFT training runs for the authenticated user. +async def list_rl_runs(team_id: str | None = None) -> dict: + """List RL training runs for the authenticated user. If team_id is provided, returns runs for that team only (requires team membership). If team_id is None, returns user's personal runs AND all runs from teams they're in. @@ -277,17 +277,17 @@ async def list_rft_runs(team_id: str | None = None) -> dict: team_id: Optional team ID to filter runs by team Returns: - List of RFT runs with status, configuration, and progress information + List of RL runs with status, configuration, and progress information """ - return await rft.list_rft_runs(team_id) + return await rl.list_rl_runs(team_id) @mcp.tool() -async def get_rft_run(run_id: str) -> dict: - """Get detailed information about a specific RFT training run. +async def get_rl_run(run_id: str) -> dict: + """Get detailed information about a specific RL training run. Args: - run_id: Unique identifier of the RFT run + run_id: Unique identifier of the RL run Returns: Detailed run information including: @@ -296,11 +296,11 @@ async def get_rft_run(run_id: str) -> dict: - progress: current step, started_at, completed_at - error_message: if run failed """ - return await rft.get_rft_run(run_id) + return await rl.get_rl_run(run_id) @mcp.tool() -async def create_rft_run( +async def create_rl_run( model_name: str, environments: list[dict], rollouts_per_example: int, @@ -315,17 +315,17 @@ async def create_rft_run( secrets: list[dict] | None = None, team_id: str | None = None, ) -> dict: - """Create a new RFT (Reinforcement Fine-Tuning) training run. + """Create a new RL (Reinforcement Learning) training run. WORKFLOW: - 1. First check available models with list_rft_models() + 1. First check available models with list_rl_models() 2. Configure your training environments 3. Optionally set up W&B monitoring with your API key 4. 
Create the run - it will be queued and start automatically Args: model_name: Model to fine-tune (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"). - Use list_rft_models() to see available models. + Use list_rl_models() to see available models. environments: Training environments list. Each environment dict should have: - id (required): Environment ID like "reverse-text" or hub slug "primeintellect/vf-math" - name (optional): Display name for this environment @@ -349,20 +349,20 @@ async def create_rft_run( team_id: Team ID to create run for (requires team membership) Returns: - Created RFT run details including run ID and initial status (QUEUED) + Created RL run details including run ID and initial status (QUEUED) Example: - create_rft_run( + create_rl_run( model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", environments=[{"id": "reverse-text"}], rollouts_per_example=8, seq_len=2048, max_steps=1000, - wandb_project="my-rft-project", + wandb_project="my-rl-project", wandb_api_key="your-wandb-key" ) """ - return await rft.create_rft_run( + return await rl.create_rl_run( model_name=model_name, environments=environments, rollouts_per_example=rollouts_per_example, @@ -380,24 +380,24 @@ async def create_rft_run( @mcp.tool() -async def stop_rft_run(run_id: str) -> dict: - """Stop/abort a running RFT training run. +async def stop_rl_run(run_id: str) -> dict: + """Stop/abort a running RL training run. Can only stop runs that are in QUEUED, PENDING, or RUNNING status. The run will save any adapter weights produced so far before stopping. Args: - run_id: Unique identifier of the RFT run to stop + run_id: Unique identifier of the RL run to stop Returns: Updated run details with STOPPED status """ - return await rft.stop_rft_run(run_id) + return await rl.stop_rl_run(run_id) @mcp.tool() -async def delete_rft_run(run_id: str) -> dict: - """Delete an RFT training run. +async def delete_rl_run(run_id: str) -> dict: + """Delete an RL training run. This will: - Cleanup Kubernetes resources if still running @@ -405,36 +405,36 @@ async def delete_rft_run(run_id: str) -> dict: - Note: Adapter weights from completed runs are preserved separately Args: - run_id: Unique identifier of the RFT run to delete + run_id: Unique identifier of the RL run to delete Returns: Deletion confirmation with run_id and success status """ - return await rft.delete_rft_run(run_id) + return await rl.delete_rl_run(run_id) @mcp.tool() -async def get_rft_run_logs(run_id: str, tail_lines: int = 1000) -> dict: - """Get training logs for an RFT run. +async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict: + """Get training logs for an RL run. Fetches logs from the orchestrator pod running the training job. Useful for debugging failed runs or monitoring progress. Args: - run_id: Unique identifier of the RFT run + run_id: Unique identifier of the RL run tail_lines: Number of lines to return from the end of logs (default: 1000) Returns: Log content as a string """ - return await rft.get_rft_run_logs(run_id, tail_lines) + return await rl.get_rl_run_logs(run_id, tail_lines) @mcp.tool() -async def list_rft_adapters(team_id: str | None = None) -> dict: - """List trained adapters (LoRA weights) from completed RFT runs. +async def list_rl_adapters(team_id: str | None = None) -> dict: + """List trained adapters (LoRA weights) from completed RL runs. 
- Adapters are the output of successful RFT training - they contain the + Adapters are the output of successful RL training - they contain the fine-tuned LoRA weights that can be used for inference. Args: @@ -445,14 +445,14 @@ async def list_rft_adapters(team_id: str | None = None) -> dict: - id: Adapter ID for use with inference - display_name: Optional friendly name - base_model: The base model these weights are for - - rft_run_id: The training run that produced this adapter + - rl_run_id: The training run that produced this adapter - status: PENDING, READY, or FAILED """ - return await rft.list_rft_adapters(team_id) + return await rl.list_rl_adapters(team_id) @mcp.tool() -async def get_rft_adapter(adapter_id: str) -> dict: +async def get_rl_adapter(adapter_id: str) -> dict: """Get detailed information about a specific adapter. Args: @@ -461,11 +461,11 @@ async def get_rft_adapter(adapter_id: str) -> dict: Returns: Adapter details including base model, status, and source run """ - return await rft.get_rft_adapter(adapter_id) + return await rl.get_rl_adapter(adapter_id) @mcp.tool() -async def delete_rft_adapter(adapter_id: str) -> dict: +async def delete_rl_adapter(adapter_id: str) -> dict: """Delete an adapter. Removes the adapter record from the database. @@ -477,7 +477,7 @@ async def delete_rft_adapter(adapter_id: str) -> dict: Returns: Deletion confirmation with adapter_id and success status """ - return await rft.delete_rft_adapter(adapter_id) + return await rl.delete_rl_adapter(adapter_id) if __name__ == "__main__": diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rft.py b/packages/prime-mcp-server/src/prime_mcp/tools/rft.py deleted file mode 100644 index 266f6f2f..00000000 --- a/packages/prime-mcp-server/src/prime_mcp/tools/rft.py +++ /dev/null @@ -1,279 +0,0 @@ -from typing import Any - -from prime_mcp.client import make_prime_request - - -async def list_rft_models() -> dict[str, Any]: - """List all available RFT models. - - Returns models from healthy RFT clusters (heartbeat within last 1 minute). - - Returns: - List of available RFT models with their names - """ - response_data = await make_prime_request("GET", "rft/models") - - if not response_data: - return {"error": "Unable to fetch RFT models"} - - return response_data - - -async def list_rft_runs(team_id: str | None = None) -> dict[str, Any]: - """List RFT runs for the authenticated user. - - If team_id is provided, returns runs for that team only (requires team membership). - If team_id is None, returns user's personal runs AND all runs from teams they're in. - - Args: - team_id: Optional team ID to filter runs - - Returns: - List of RFT runs with their details - """ - params = {} - if team_id: - params["team_id"] = team_id - - response_data = await make_prime_request("GET", "rft/runs", params=params if params else None) - - if not response_data: - return {"error": "Unable to fetch RFT runs"} - - return response_data - - -async def get_rft_run(run_id: str) -> dict[str, Any]: - """Get details of a specific RFT run. 
- - Args: - run_id: Unique identifier of the RFT run - - Returns: - Detailed RFT run information including status, configuration, and progress - """ - response_data = await make_prime_request("GET", f"rft/runs/{run_id}") - - if not response_data: - return {"error": f"Unable to fetch RFT run: {run_id}"} - - return response_data - - -async def create_rft_run( - model_name: str, - environments: list[dict[str, Any]], - rollouts_per_example: int, - seq_len: int, - max_steps: int, - name: str | None = None, - eval_config: dict[str, Any] | None = None, - wandb_entity: str | None = None, - wandb_project: str | None = None, - wandb_run_name: str | None = None, - wandb_api_key: str | None = None, - secrets: list[dict[str, str]] | None = None, - team_id: str | None = None, -) -> dict[str, Any]: - """Create a new RFT training run. - - IMPORTANT PREREQUISITES: - 1. Check available models with list_rft_models() first - 2. Ensure you have a W&B API key if you want monitoring - - Args: - model_name: Model name/path (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B") - environments: List of training environments. Each environment should have: - - id: Environment ID (e.g., "reverse-text" or "primeintellect/vf-math") - - name: Optional display name - - args: Optional environment-specific arguments dict - rollouts_per_example: Number of rollouts per example. - Must divide 128: valid values 1,2,4,8,16,32,64,128 - seq_len: Sequence length for training - max_steps: Maximum training steps - name: Optional run name (auto-generated if not provided) - eval_config: Optional evaluation configuration with: - - environments: List of eval environments (same format as training) - - interval: Evaluate every N steps (default: 100) - - num_examples: Number of examples per environment (-1 for all) - - rollouts_per_example: Rollouts per example (default: 1) - - eval_base_model: Whether to eval base model before training (default: True) - wandb_entity: W&B entity (username or team) for monitoring - wandb_project: W&B project name for monitoring - wandb_run_name: W&B run name - wandb_api_key: W&B API key for authentication (passed as secret) - secrets: Additional secrets as list of {key, value} dicts - team_id: Optional team ID to create run for team - - Returns: - Created RFT run details including run ID and status - """ - # Validate rollouts_per_example - valid_rollouts = [1, 2, 4, 8, 16, 32, 64, 128] - if rollouts_per_example not in valid_rollouts: - return {"error": f"rollouts_per_example must be one of {valid_rollouts}"} - - # Build request body - request_body: dict[str, Any] = { - "model": {"name": model_name}, - "environments": environments, - "rollouts_per_example": rollouts_per_example, - "seq_len": seq_len, - "max_steps": max_steps, - } - - if name: - request_body["name"] = name - - # Build secrets list (copy to avoid mutating caller's list) - all_secrets = list(secrets) if secrets else [] - if wandb_api_key: - all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) - - if all_secrets: - request_body["secrets"] = all_secrets - - # Add monitoring config if W&B settings provided - if wandb_project: - request_body["monitoring"] = { - "wandb": { - "project": wandb_project, - } - } - if wandb_entity: - request_body["monitoring"]["wandb"]["entity"] = wandb_entity - if wandb_run_name: - request_body["monitoring"]["wandb"]["name"] = wandb_run_name - - # Add eval config if provided - if eval_config: - request_body["eval"] = eval_config - - if team_id: - request_body["team_id"] = team_id - - response_data = await 
make_prime_request("POST", "rft/runs", json_data=request_body) - - if not response_data: - return {"error": "Unable to create RFT run"} - - return response_data - - -async def stop_rft_run(run_id: str) -> dict[str, Any]: - """Stop/abort a running RFT training run. - - Can only stop runs in QUEUED, PENDING, or RUNNING status. - - Args: - run_id: Unique identifier of the RFT run to stop - - Returns: - Updated RFT run details with STOPPED status - """ - response_data = await make_prime_request("PUT", f"rft/runs/{run_id}/stop") - - if not response_data: - return {"error": f"Unable to stop RFT run: {run_id}"} - - return response_data - - -async def delete_rft_run(run_id: str) -> dict[str, Any]: - """Delete an RFT training run. - - This will cleanup Kubernetes resources and delete the run from the database. - - Args: - run_id: Unique identifier of the RFT run to delete - - Returns: - Deletion confirmation with run_id and success status - """ - response_data = await make_prime_request("DELETE", f"rft/runs/{run_id}") - - if not response_data: - return {"error": f"Unable to delete RFT run: {run_id}"} - - return response_data - - -async def get_rft_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]: - """Get orchestrator logs for an RFT run. - - Args: - run_id: Unique identifier of the RFT run - tail_lines: Number of lines to tail from the end of logs (default: 1000) - - Returns: - Pod logs as a string - """ - params = {"tail_lines": tail_lines} - - response_data = await make_prime_request("GET", f"rft/runs/{run_id}/logs", params=params) - - if not response_data: - return {"error": f"Unable to fetch logs for RFT run: {run_id}"} - - return response_data - - -async def list_rft_adapters(team_id: str | None = None) -> dict[str, Any]: - """List adapters for the authenticated user. - - Adapters are LoRA weights produced by completed RFT training runs. - - Args: - team_id: Optional team ID to filter adapters - - Returns: - List of adapters with their details (ID, base model, status, etc.) - """ - params = {} - if team_id: - params["team_id"] = team_id - - response_data = await make_prime_request( - "GET", "rft/adapters", params=params if params else None - ) - - if not response_data: - return {"error": "Unable to fetch RFT adapters"} - - return response_data - - -async def get_rft_adapter(adapter_id: str) -> dict[str, Any]: - """Get a specific adapter by ID. - - Args: - adapter_id: Unique identifier of the adapter - - Returns: - Adapter details including ID, base model, status, and associated run - """ - response_data = await make_prime_request("GET", f"rft/adapters/{adapter_id}") - - if not response_data: - return {"error": f"Unable to fetch adapter: {adapter_id}"} - - return response_data - - -async def delete_rft_adapter(adapter_id: str) -> dict[str, Any]: - """Delete an adapter by ID. - - Note: This only deletes the database record. Storage files are not automatically cleaned up. 
- - Args: - adapter_id: Unique identifier of the adapter to delete - - Returns: - Deletion confirmation with adapter_id and success status - """ - response_data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}") - - if not response_data: - return {"error": f"Unable to delete adapter: {adapter_id}"} - - return response_data diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rl.py b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py new file mode 100644 index 00000000..303f1485 --- /dev/null +++ b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py @@ -0,0 +1,136 @@ +from typing import Any + +from prime_mcp.client import make_prime_request + + +async def list_rl_models() -> dict[str, Any]: + """List available RL models from healthy clusters.""" + response_data = await make_prime_request("GET", "rft/models") + if not response_data: + return {"error": "Unable to fetch RL models"} + return response_data + + +async def list_rl_runs(team_id: str | None = None) -> dict[str, Any]: + """List RL runs. If team_id is None, returns personal + team runs.""" + params = {"team_id": team_id} if team_id else {} + response_data = await make_prime_request("GET", "rft/runs", params=params if params else None) + if not response_data: + return {"error": "Unable to fetch RL runs"} + return response_data + + +async def get_rl_run(run_id: str) -> dict[str, Any]: + """Get details of a specific RL run.""" + response_data = await make_prime_request("GET", f"rft/runs/{run_id}") + if not response_data: + return {"error": f"Unable to fetch RL run: {run_id}"} + return response_data + + +async def create_rl_run( + model_name: str, + environments: list[dict[str, Any]], + rollouts_per_example: int, + seq_len: int, + max_steps: int, + name: str | None = None, + eval_config: dict[str, Any] | None = None, + wandb_entity: str | None = None, + wandb_project: str | None = None, + wandb_run_name: str | None = None, + wandb_api_key: str | None = None, + secrets: list[dict[str, str]] | None = None, + team_id: str | None = None, +) -> dict[str, Any]: + """Create a new RL training run.""" + valid_rollouts = [1, 2, 4, 8, 16, 32, 64, 128] + if rollouts_per_example not in valid_rollouts: + return {"error": f"rollouts_per_example must be one of {valid_rollouts}"} + + request_body: dict[str, Any] = { + "model": {"name": model_name}, + "environments": environments, + "rollouts_per_example": rollouts_per_example, + "seq_len": seq_len, + "max_steps": max_steps, + } + + if name: + request_body["name"] = name + + all_secrets = list(secrets) if secrets else [] + if wandb_api_key: + all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) + if all_secrets: + request_body["secrets"] = all_secrets + + if wandb_project: + request_body["monitoring"] = {"wandb": {"project": wandb_project}} + if wandb_entity: + request_body["monitoring"]["wandb"]["entity"] = wandb_entity + if wandb_run_name: + request_body["monitoring"]["wandb"]["name"] = wandb_run_name + + if eval_config: + request_body["eval"] = eval_config + if team_id: + request_body["team_id"] = team_id + + response_data = await make_prime_request("POST", "rft/runs", json_data=request_body) + if not response_data: + return {"error": "Unable to create RL run"} + return response_data + + +async def stop_rl_run(run_id: str) -> dict[str, Any]: + """Stop an RL training run.""" + response_data = await make_prime_request("PUT", f"rft/runs/{run_id}/stop") + if not response_data: + return {"error": f"Unable to stop RL run: {run_id}"} + return response_data + + +async def 
delete_rl_run(run_id: str) -> dict[str, Any]: + """Delete an RL training run.""" + response_data = await make_prime_request("DELETE", f"rft/runs/{run_id}") + if not response_data: + return {"error": f"Unable to delete RL run: {run_id}"} + return response_data + + +async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]: + """Get logs for an RL run.""" + response_data = await make_prime_request( + "GET", f"rft/runs/{run_id}/logs", params={"tail_lines": tail_lines} + ) + if not response_data: + return {"error": f"Unable to fetch logs for RL run: {run_id}"} + return response_data + + +async def list_rl_adapters(team_id: str | None = None) -> dict[str, Any]: + """List LoRA adapters from completed RL runs.""" + params = {"team_id": team_id} if team_id else {} + response_data = await make_prime_request( + "GET", "rft/adapters", params=params if params else None + ) + if not response_data: + return {"error": "Unable to fetch RL adapters"} + return response_data + + +async def get_rl_adapter(adapter_id: str) -> dict[str, Any]: + """Get a specific adapter by ID.""" + response_data = await make_prime_request("GET", f"rft/adapters/{adapter_id}") + if not response_data: + return {"error": f"Unable to fetch adapter: {adapter_id}"} + return response_data + + +async def delete_rl_adapter(adapter_id: str) -> dict[str, Any]: + """Delete an adapter by ID.""" + response_data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}") + if not response_data: + return {"error": f"Unable to delete adapter: {adapter_id}"} + return response_data From d46e4f57aebe891454a0d6d864d03359c5a32bca Mon Sep 17 00:00:00 2001 From: d42me Date: Mon, 5 Jan 2026 10:15:40 -0600 Subject: [PATCH 4/4] Fix bugbot comment. --- packages/prime-mcp-server/src/prime_mcp/tools/rl.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rl.py b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py index 303f1485..07b200da 100644 --- a/packages/prime-mcp-server/src/prime_mcp/tools/rl.py +++ b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py @@ -94,9 +94,9 @@ async def stop_rl_run(run_id: str) -> dict[str, Any]: async def delete_rl_run(run_id: str) -> dict[str, Any]: """Delete an RL training run.""" response_data = await make_prime_request("DELETE", f"rft/runs/{run_id}") - if not response_data: - return {"error": f"Unable to delete RL run: {run_id}"} - return response_data + if response_data is not None and not response_data.get("error"): + return {"success": True, "run_id": run_id} + return response_data or {"error": f"Unable to delete RL run: {run_id}"} async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]: @@ -131,6 +131,6 @@ async def get_rl_adapter(adapter_id: str) -> dict[str, Any]: async def delete_rl_adapter(adapter_id: str) -> dict[str, Any]: """Delete an adapter by ID.""" response_data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}") - if not response_data: - return {"error": f"Unable to delete adapter: {adapter_id}"} - return response_data + if response_data is not None and not response_data.get("error"): + return {"success": True, "adapter_id": adapter_id} + return response_data or {"error": f"Unable to delete adapter: {adapter_id}"}