diff --git a/packages/prime-mcp-server/src/prime_mcp/client.py b/packages/prime-mcp-server/src/prime_mcp/client.py
index e2249693..991921df 100644
--- a/packages/prime-mcp-server/src/prime_mcp/client.py
+++ b/packages/prime-mcp-server/src/prime_mcp/client.py
@@ -27,6 +27,8 @@ async def make_prime_request(
         return await _client.get(endpoint, params=params)
     elif method == "POST":
         return await _client.post(endpoint, json=json_data)
+    elif method == "PUT":
+        return await _client.put(endpoint, json=json_data)
     elif method == "DELETE":
         return await _client.delete(endpoint)
     elif method == "PATCH":
diff --git a/packages/prime-mcp-server/src/prime_mcp/core/client.py b/packages/prime-mcp-server/src/prime_mcp/core/client.py
index 84e9d90f..062a23eb 100644
--- a/packages/prime-mcp-server/src/prime_mcp/core/client.py
+++ b/packages/prime-mcp-server/src/prime_mcp/core/client.py
@@ -113,6 +113,9 @@ async def patch(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> D
     async def delete(self, endpoint: str) -> Dict[str, Any]:
         return await self.request("DELETE", endpoint)
 
+    async def put(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        return await self.request("PUT", endpoint, json=json)
+
     async def aclose(self) -> None:
         await self.client.aclose()
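
For context, a minimal end-to-end sketch of the new verb, assuming the client is already configured with credentials (the run ID is a placeholder; make_prime_request is the existing helper extended above, which now dispatches to the core client's put):

    import asyncio

    from prime_mcp.client import make_prime_request

    async def main() -> None:
        # Routes through the new "PUT" branch in make_prime_request,
        # which awaits the core client's put() helper.
        response = await make_prime_request("PUT", "rft/runs/run_123/stop")
        print(response)

    asyncio.run(main())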
diff --git a/packages/prime-mcp-server/src/prime_mcp/mcp.py b/packages/prime-mcp-server/src/prime_mcp/mcp.py
index 54de5ac9..d89d8edd 100644
--- a/packages/prime-mcp-server/src/prime_mcp/mcp.py
+++ b/packages/prime-mcp-server/src/prime_mcp/mcp.py
@@ -1,6 +1,6 @@
 from mcp.server.fastmcp import FastMCP
 
-from prime_mcp.tools import availability, pods, ssh
+from prime_mcp.tools import availability, pods, rl, ssh
 
 mcp = FastMCP("primeintellect")
 
@@ -253,5 +253,232 @@ async def manage_ssh_keys(
     return await ssh.manage_ssh_keys(action, key_name, public_key, key_id, offset, limit)
 
 
+@mcp.tool()
+async def list_rl_models() -> dict:
+    """List all available RL models for training.
+
+    Returns models from healthy RL clusters that are ready to accept training jobs.
+    Check this before creating a run to see which models are available.
+
+    Returns:
+        List of available RL models with their names
+    """
+    return await rl.list_rl_models()
+
+
+@mcp.tool()
+async def list_rl_runs(team_id: str | None = None) -> dict:
+    """List RL training runs for the authenticated user.
+
+    If team_id is provided, returns runs for that team only (requires team membership).
+    If team_id is None, returns the user's personal runs AND all runs from teams they belong to.
+
+    Args:
+        team_id: Optional team ID to filter runs by team
+
+    Returns:
+        List of RL runs with status, configuration, and progress information
+    """
+    return await rl.list_rl_runs(team_id)
+
+
+@mcp.tool()
+async def get_rl_run(run_id: str) -> dict:
+    """Get detailed information about a specific RL training run.
+
+    Args:
+        run_id: Unique identifier of the RL run
+
+    Returns:
+        Detailed run information including:
+        - status: QUEUED, PENDING, RUNNING, COMPLETED, FAILED, STOPPED
+        - configuration: model, environments, hyperparameters
+        - progress: current step, started_at, completed_at
+        - error_message: if the run failed
+    """
+    return await rl.get_rl_run(run_id)
+
+
+@mcp.tool()
+async def create_rl_run(
+    model_name: str,
+    environments: list[dict],
+    rollouts_per_example: int,
+    seq_len: int,
+    max_steps: int,
+    name: str | None = None,
+    eval_config: dict | None = None,
+    wandb_entity: str | None = None,
+    wandb_project: str | None = None,
+    wandb_run_name: str | None = None,
+    wandb_api_key: str | None = None,
+    secrets: list[dict] | None = None,
+    team_id: str | None = None,
+) -> dict:
+    """Create a new RL (Reinforcement Learning) training run.
+
+    WORKFLOW:
+    1. First check available models with list_rl_models()
+    2. Configure your training environments
+    3. Optionally set up W&B monitoring with your API key
+    4. Create the run - it will be queued and start automatically
+
+    Args:
+        model_name: Model to fine-tune (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B").
+            Use list_rl_models() to see available models.
+        environments: Training environments list. Each environment dict should have:
+            - id (required): Environment ID like "reverse-text" or hub slug "primeintellect/vf-math"
+            - name (optional): Display name for this environment
+            - args (optional): Dict of environment-specific arguments
+        rollouts_per_example: Number of rollouts per training example.
+            MUST be one of: 1, 2, 4, 8, 16, 32, 64, 128 (each divides the batch size of 128 evenly)
+        seq_len: Sequence length for training (context window size)
+        max_steps: Maximum number of training steps
+        name: Optional run name (auto-generated if not provided)
+        eval_config: Optional evaluation configuration dict with:
+            - environments: List of eval environments (same format as training)
+            - interval: Evaluate every N steps (default: 100)
+            - num_examples: Examples per environment (-1 for all)
+            - rollouts_per_example: Rollouts per eval example (default: 1)
+            - eval_base_model: Whether to eval the base model first (default: True)
+        wandb_entity: W&B entity (username or team name) for metrics logging
+        wandb_project: W&B project name - REQUIRED if you want monitoring
+        wandb_run_name: W&B run name (optional)
+        wandb_api_key: Your W&B API key - REQUIRED for W&B monitoring
+        secrets: Additional secrets as a list of {"key": "NAME", "value": "secret"} dicts
+        team_id: Team ID to create the run for (requires team membership)
+
+    Returns:
+        Created RL run details including run ID and initial status (QUEUED)
+
+    Example:
+        create_rl_run(
+            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+            environments=[{"id": "reverse-text"}],
+            rollouts_per_example=8,
+            seq_len=2048,
+            max_steps=1000,
+            wandb_project="my-rl-project",
+            wandb_api_key="your-wandb-key"
+        )
+    """
+    return await rl.create_rl_run(
+        model_name=model_name,
+        environments=environments,
+        rollouts_per_example=rollouts_per_example,
+        seq_len=seq_len,
+        max_steps=max_steps,
+        name=name,
+        eval_config=eval_config,
+        wandb_entity=wandb_entity,
+        wandb_project=wandb_project,
+        wandb_run_name=wandb_run_name,
+        wandb_api_key=wandb_api_key,
+        secrets=secrets,
+        team_id=team_id,
+    )
+
+
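Because environments and eval_config are plain dicts, a fuller call than the docstring's example might look like the sketch below (the hub environment's args and the W&B values are illustrative placeholders, not known-good inputs):

    await create_rl_run(
        model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        environments=[
            {"id": "reverse-text"},
            # "args" is passed through to the environment; the key shown
            # here is hypothetical.
            {"id": "primeintellect/vf-math", "name": "math", "args": {"max_turns": 1}},
        ],
        rollouts_per_example=8,  # must be a power of two that divides 128
        seq_len=2048,
        max_steps=1000,
        eval_config={
            "environments": [{"id": "reverse-text"}],
            "interval": 100,            # evaluate every 100 steps
            "num_examples": -1,         # -1 = all examples
            "rollouts_per_example": 1,
            "eval_base_model": True,
        },
        wandb_project="my-rl-project",
        wandb_api_key="your-wandb-key",
    )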
+@mcp.tool()
+async def stop_rl_run(run_id: str) -> dict:
+    """Stop/abort a running RL training run.
+
+    Can only stop runs that are in QUEUED, PENDING, or RUNNING status.
+    The run will save any adapter weights produced so far before stopping.
+
+    Args:
+        run_id: Unique identifier of the RL run to stop
+
+    Returns:
+        Updated run details with STOPPED status
+    """
+    return await rl.stop_rl_run(run_id)
+
+
+@mcp.tool()
+async def delete_rl_run(run_id: str) -> dict:
+    """Delete an RL training run.
+
+    This will:
+    - Clean up Kubernetes resources if the run is still running
+    - Delete the run record from the database
+    - Note: Adapter weights from completed runs are preserved separately
+
+    Args:
+        run_id: Unique identifier of the RL run to delete
+
+    Returns:
+        Deletion confirmation with run_id and success status
+    """
+    return await rl.delete_rl_run(run_id)
+
+
+@mcp.tool()
+async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict:
+    """Get training logs for an RL run.
+
+    Fetches logs from the orchestrator pod running the training job.
+    Useful for debugging failed runs or monitoring progress.
+
+    Args:
+        run_id: Unique identifier of the RL run
+        tail_lines: Number of lines to return from the end of the logs (default: 1000)
+
+    Returns:
+        Log content for the run
+    """
+    return await rl.get_rl_run_logs(run_id, tail_lines)
+
+
+@mcp.tool()
+async def list_rl_adapters(team_id: str | None = None) -> dict:
+    """List trained adapters (LoRA weights) from completed RL runs.
+
+    Adapters are the output of successful RL training - they contain the
+    fine-tuned LoRA weights that can be used for inference.
+
+    Args:
+        team_id: Optional team ID to filter adapters by team
+
+    Returns:
+        List of adapters with:
+        - id: Adapter ID for use with inference
+        - display_name: Optional friendly name
+        - base_model: The base model these weights are for
+        - rl_run_id: The training run that produced this adapter
+        - status: PENDING, READY, or FAILED
+    """
+    return await rl.list_rl_adapters(team_id)
+
+
+@mcp.tool()
+async def get_rl_adapter(adapter_id: str) -> dict:
+    """Get detailed information about a specific adapter.
+
+    Args:
+        adapter_id: Unique identifier of the adapter
+
+    Returns:
+        Adapter details including base model, status, and source run
+    """
+    return await rl.get_rl_adapter(adapter_id)
+
+
+@mcp.tool()
+async def delete_rl_adapter(adapter_id: str) -> dict:
+    """Delete an adapter.
+
+    Removes the adapter record from the database; storage files may be retained.
+
+    Args:
+        adapter_id: Unique identifier of the adapter to delete
+
+    Returns:
+        Deletion confirmation with adapter_id and success status
+    """
+    return await rl.delete_rl_adapter(adapter_id)
+
+
 if __name__ == "__main__":
     mcp.run(transport="stdio")
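
Taken together, these tools support a create/monitor/stop loop. A rough polling sketch against the underlying tool module (the response field name "status" is an assumption about the API payload, and the poll interval is arbitrary):

    import asyncio

    from prime_mcp.tools import rl

    async def watch_run(run_id: str) -> None:
        # Poll until the run leaves its active states, then pull recent logs.
        status = None
        while True:
            run = await rl.get_rl_run(run_id)
            status = run.get("status")  # assumed top-level field
            if status not in ("QUEUED", "PENDING", "RUNNING"):
                break
            await asyncio.sleep(30)
        logs = await rl.get_rl_run_logs(run_id, tail_lines=200)
        print(f"run {run_id} finished with status {status}: {logs}")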
diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/rl.py b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py
new file mode 100644
index 00000000..07b200da
--- /dev/null
+++ b/packages/prime-mcp-server/src/prime_mcp/tools/rl.py
@@ -0,0 +1,136 @@
+from typing import Any
+
+from prime_mcp.client import make_prime_request
+
+
+async def list_rl_models() -> dict[str, Any]:
+    """List available RL models from healthy clusters."""
+    response_data = await make_prime_request("GET", "rft/models")
+    if not response_data:
+        return {"error": "Unable to fetch RL models"}
+    return response_data
+
+
+async def list_rl_runs(team_id: str | None = None) -> dict[str, Any]:
+    """List RL runs. If team_id is None, returns personal + team runs."""
+    params = {"team_id": team_id} if team_id else None
+    response_data = await make_prime_request("GET", "rft/runs", params=params)
+    if not response_data:
+        return {"error": "Unable to fetch RL runs"}
+    return response_data
+
+
+async def get_rl_run(run_id: str) -> dict[str, Any]:
+    """Get details of a specific RL run."""
+    response_data = await make_prime_request("GET", f"rft/runs/{run_id}")
+    if not response_data:
+        return {"error": f"Unable to fetch RL run: {run_id}"}
+    return response_data
+
+
+async def create_rl_run(
+    model_name: str,
+    environments: list[dict[str, Any]],
+    rollouts_per_example: int,
+    seq_len: int,
+    max_steps: int,
+    name: str | None = None,
+    eval_config: dict[str, Any] | None = None,
+    wandb_entity: str | None = None,
+    wandb_project: str | None = None,
+    wandb_run_name: str | None = None,
+    wandb_api_key: str | None = None,
+    secrets: list[dict[str, str]] | None = None,
+    team_id: str | None = None,
+) -> dict[str, Any]:
+    """Create a new RL training run."""
+    valid_rollouts = [1, 2, 4, 8, 16, 32, 64, 128]
+    if rollouts_per_example not in valid_rollouts:
+        return {"error": f"rollouts_per_example must be one of {valid_rollouts}"}
+
+    request_body: dict[str, Any] = {
+        "model": {"name": model_name},
+        "environments": environments,
+        "rollouts_per_example": rollouts_per_example,
+        "seq_len": seq_len,
+        "max_steps": max_steps,
+    }
+
+    if name:
+        request_body["name"] = name
+
+    all_secrets = list(secrets) if secrets else []
+    if wandb_api_key:
+        all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key})
+    if all_secrets:
+        request_body["secrets"] = all_secrets
+
+    if wandb_project:
+        request_body["monitoring"] = {"wandb": {"project": wandb_project}}
+        if wandb_entity:
+            request_body["monitoring"]["wandb"]["entity"] = wandb_entity
+        if wandb_run_name:
+            request_body["monitoring"]["wandb"]["name"] = wandb_run_name
+
+    if eval_config:
+        request_body["eval"] = eval_config
+    if team_id:
+        request_body["team_id"] = team_id
+
+    response_data = await make_prime_request("POST", "rft/runs", json_data=request_body)
+    if not response_data:
+        return {"error": "Unable to create RL run"}
+    return response_data
+
+
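For reference, the body assembled above for the docstring example in mcp.py (model, one environment, W&B project plus API key) comes out as below; the nesting mirrors the code, though the server-side schema is not shown in this diff:

    {
        "model": {"name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
        "environments": [{"id": "reverse-text"}],
        "rollouts_per_example": 8,
        "seq_len": 2048,
        "max_steps": 1000,
        "secrets": [{"key": "WANDB_API_KEY", "value": "your-wandb-key"}],
        "monitoring": {"wandb": {"project": "my-rl-project"}},
    }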
+async def stop_rl_run(run_id: str) -> dict[str, Any]:
+    """Stop an RL training run."""
+    response_data = await make_prime_request("PUT", f"rft/runs/{run_id}/stop")
+    if not response_data:
+        return {"error": f"Unable to stop RL run: {run_id}"}
+    return response_data
+
+
+async def delete_rl_run(run_id: str) -> dict[str, Any]:
+    """Delete an RL training run."""
+    response_data = await make_prime_request("DELETE", f"rft/runs/{run_id}")
+    if response_data is not None and not response_data.get("error"):
+        return {"success": True, "run_id": run_id}
+    return response_data or {"error": f"Unable to delete RL run: {run_id}"}
+
+
+async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]:
+    """Get logs for an RL run."""
+    response_data = await make_prime_request(
+        "GET", f"rft/runs/{run_id}/logs", params={"tail_lines": tail_lines}
+    )
+    if not response_data:
+        return {"error": f"Unable to fetch logs for RL run: {run_id}"}
+    return response_data
+
+
+async def list_rl_adapters(team_id: str | None = None) -> dict[str, Any]:
+    """List LoRA adapters from completed RL runs."""
+    params = {"team_id": team_id} if team_id else None
+    response_data = await make_prime_request("GET", "rft/adapters", params=params)
+    if not response_data:
+        return {"error": "Unable to fetch RL adapters"}
+    return response_data
+
+
+async def get_rl_adapter(adapter_id: str) -> dict[str, Any]:
+    """Get a specific adapter by ID."""
+    response_data = await make_prime_request("GET", f"rft/adapters/{adapter_id}")
+    if not response_data:
+        return {"error": f"Unable to fetch adapter: {adapter_id}"}
+    return response_data
+
+
+async def delete_rl_adapter(adapter_id: str) -> dict[str, Any]:
+    """Delete an adapter by ID."""
+    response_data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}")
+    if response_data is not None and not response_data.get("error"):
+        return {"success": True, "adapter_id": adapter_id}
+    return response_data or {"error": f"Unable to delete adapter: {adapter_id}"}
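
As a closing illustration, a small cleanup pass over adapters using this module directly (the "adapters", "id", and "status" fields of the list response are assumptions about the API payload, which this diff does not pin down; the FAILED status comes from the tool docstrings above):

    import asyncio

    from prime_mcp.tools import rl

    async def prune_failed_adapters() -> None:
        # List adapters, then delete any marked FAILED.
        listing = await rl.list_rl_adapters()
        for adapter in listing.get("adapters", []):  # assumed response shape
            if adapter.get("status") == "FAILED":
                result = await rl.delete_rl_adapter(adapter["id"])
                print(result)

    asyncio.run(prune_failed_adapters())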