Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/prime-mcp-server/src/prime_mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ async def make_prime_request(
return await _client.get(endpoint, params=params)
elif method == "POST":
return await _client.post(endpoint, json=json_data)
elif method == "PUT":
return await _client.put(endpoint, json=json_data)
elif method == "DELETE":
return await _client.delete(endpoint)
elif method == "PATCH":
Expand Down
3 changes: 3 additions & 0 deletions packages/prime-mcp-server/src/prime_mcp/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ async def patch(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> D
async def delete(self, endpoint: str) -> Dict[str, Any]:
    """Issue a DELETE request against *endpoint* via the shared request helper."""
    response = await self.request("DELETE", endpoint)
    return response

async def put(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Issue a PUT request against *endpoint*, forwarding an optional JSON body."""
    response = await self.request("PUT", endpoint, json=json)
    return response

async def aclose(self) -> None:
    """Close the underlying async HTTP client and release its connections."""
    await self.client.aclose()

Expand Down
229 changes: 228 additions & 1 deletion packages/prime-mcp-server/src/prime_mcp/mcp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from mcp.server.fastmcp import FastMCP

from prime_mcp.tools import availability, pods, ssh
from prime_mcp.tools import availability, pods, rl, ssh

mcp = FastMCP("primeintellect")

Expand Down Expand Up @@ -253,5 +253,232 @@ async def manage_ssh_keys(
return await ssh.manage_ssh_keys(action, key_name, public_key, key_id, offset, limit)


@mcp.tool()
async def list_rl_models() -> dict:
    """List all available RL models for training.

    Returns models from healthy RL clusters that are ready to accept training jobs.
    Check this before creating a run to see which models are available.

    Returns:
        List of available RL models with their names
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    # NOTE: FastMCP exposes this docstring as the tool description shown to clients.
    return await rl.list_rl_models()


@mcp.tool()
async def list_rl_runs(team_id: str | None = None) -> dict:
    """List RL training runs for the authenticated user.

    If team_id is provided, returns runs for that team only (requires team membership).
    If team_id is None, returns user's personal runs AND all runs from teams they're in.

    Args:
        team_id: Optional team ID to filter runs by team

    Returns:
        List of RL runs with status, configuration, and progress information
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.list_rl_runs(team_id)


@mcp.tool()
async def get_rl_run(run_id: str) -> dict:
    """Get detailed information about a specific RL training run.

    Args:
        run_id: Unique identifier of the RL run

    Returns:
        Detailed run information including:
        - status: QUEUED, PENDING, RUNNING, COMPLETED, FAILED, STOPPED
        - configuration: model, environments, hyperparameters
        - progress: current step, started_at, completed_at
        - error_message: if run failed
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.get_rl_run(run_id)


@mcp.tool()
async def create_rl_run(
    model_name: str,
    environments: list[dict],
    rollouts_per_example: int,
    seq_len: int,
    max_steps: int,
    name: str | None = None,
    eval_config: dict | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    wandb_run_name: str | None = None,
    wandb_api_key: str | None = None,
    secrets: list[dict] | None = None,
    team_id: str | None = None,
) -> dict:
    """Create a new RL (Reinforcement Learning) training run.

    WORKFLOW:
    1. First check available models with list_rl_models()
    2. Configure your training environments
    3. Optionally set up W&B monitoring with your API key
    4. Create the run - it will be queued and start automatically

    Args:
        model_name: Model to fine-tune (e.g., "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B").
            Use list_rl_models() to see available models.
        environments: Training environments list. Each environment dict should have:
            - id (required): Environment ID like "reverse-text" or hub slug "primeintellect/vf-math"
            - name (optional): Display name for this environment
            - args (optional): Dict of environment-specific arguments
        rollouts_per_example: Number of rollouts per training example.
            MUST be one of: 1, 2, 4, 8, 16, 32, 64, 128 (must divide batch size 128 evenly)
        seq_len: Sequence length for training (context window size)
        max_steps: Maximum number of training steps
        name: Optional run name (auto-generated if not provided)
        eval_config: Optional evaluation configuration dict with:
            - environments: List of eval environments (same format as training)
            - interval: Evaluate every N steps (default: 100)
            - num_examples: Examples per environment (-1 for all)
            - rollouts_per_example: Rollouts per eval example (default: 1)
            - eval_base_model: Whether to eval base model first (default: True)
        wandb_entity: W&B entity (username or team name) for metrics logging
        wandb_project: W&B project name - REQUIRED if you want monitoring
        wandb_run_name: W&B run name (optional)
        wandb_api_key: Your W&B API key - REQUIRED for W&B monitoring
        secrets: Additional secrets as list of {"key": "NAME", "value": "secret"} dicts
        team_id: Team ID to create run for (requires team membership)

    Returns:
        Created RL run details including run ID and initial status (QUEUED)

    Example:
        create_rl_run(
            model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            environments=[{"id": "reverse-text"}],
            rollouts_per_example=8,
            seq_len=2048,
            max_steps=1000,
            wandb_project="my-rl-project",
            wandb_api_key="your-wandb-key"
        )
    """
    # Thin MCP wrapper: all validation and request assembly lives in prime_mcp.tools.rl.
    return await rl.create_rl_run(
        model_name=model_name,
        environments=environments,
        rollouts_per_example=rollouts_per_example,
        seq_len=seq_len,
        max_steps=max_steps,
        name=name,
        eval_config=eval_config,
        wandb_entity=wandb_entity,
        wandb_project=wandb_project,
        wandb_run_name=wandb_run_name,
        wandb_api_key=wandb_api_key,
        secrets=secrets,
        team_id=team_id,
    )


@mcp.tool()
async def stop_rl_run(run_id: str) -> dict:
    """Stop/abort a running RL training run.

    Can only stop runs that are in QUEUED, PENDING, or RUNNING status.
    The run will save any adapter weights produced so far before stopping.

    Args:
        run_id: Unique identifier of the RL run to stop

    Returns:
        Updated run details with STOPPED status
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.stop_rl_run(run_id)


@mcp.tool()
async def delete_rl_run(run_id: str) -> dict:
    """Delete an RL training run.

    This will:
    - Cleanup Kubernetes resources if still running
    - Delete the run record from the database
    - Note: Adapter weights from completed runs are preserved separately

    Args:
        run_id: Unique identifier of the RL run to delete

    Returns:
        Deletion confirmation with run_id and success status
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.delete_rl_run(run_id)


@mcp.tool()
async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict:
    """Get training logs for an RL run.

    Fetches logs from the orchestrator pod running the training job.
    Useful for debugging failed runs or monitoring progress.

    Args:
        run_id: Unique identifier of the RL run
        tail_lines: Number of lines to return from the end of logs (default: 1000)

    Returns:
        Log content as a string
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.get_rl_run_logs(run_id, tail_lines)


@mcp.tool()
async def list_rl_adapters(team_id: str | None = None) -> dict:
    """List trained adapters (LoRA weights) from completed RL runs.

    Adapters are the output of successful RL training - they contain the
    fine-tuned LoRA weights that can be used for inference.

    Args:
        team_id: Optional team ID to filter adapters by team

    Returns:
        List of adapters with:
        - id: Adapter ID for use with inference
        - display_name: Optional friendly name
        - base_model: The base model these weights are for
        - rl_run_id: The training run that produced this adapter
        - status: PENDING, READY, or FAILED
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.list_rl_adapters(team_id)


@mcp.tool()
async def get_rl_adapter(adapter_id: str) -> dict:
    """Get detailed information about a specific adapter.

    Args:
        adapter_id: Unique identifier of the adapter

    Returns:
        Adapter details including base model, status, and source run
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.get_rl_adapter(adapter_id)


@mcp.tool()
async def delete_rl_adapter(adapter_id: str) -> dict:
    """Delete an adapter.

    Removes the adapter record from the database.
    Note: This deletes the database record but storage files may be retained.

    Args:
        adapter_id: Unique identifier of the adapter to delete

    Returns:
        Deletion confirmation with adapter_id and success status
    """
    # Thin MCP wrapper: delegates to the shared implementation in prime_mcp.tools.rl.
    return await rl.delete_rl_adapter(adapter_id)


if __name__ == "__main__":
    # Serve the MCP server over stdio when this module is executed directly.
    mcp.run(transport="stdio")
136 changes: 136 additions & 0 deletions packages/prime-mcp-server/src/prime_mcp/tools/rl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Any

from prime_mcp.client import make_prime_request


async def list_rl_models() -> dict[str, Any]:
    """Fetch the catalog of RL-capable models reported by healthy clusters."""
    data = await make_prime_request("GET", "rft/models")
    if data:
        return data
    # Falsy response (None / empty) means the request could not be served.
    return {"error": "Unable to fetch RL models"}


async def list_rl_runs(team_id: str | None = None) -> dict[str, Any]:
    """List RL runs; a None team_id returns personal runs plus all team runs."""
    # Only attach the query parameter when a team filter was actually given.
    query = {"team_id": team_id} if team_id else None
    data = await make_prime_request("GET", "rft/runs", params=query)
    return data if data else {"error": "Unable to fetch RL runs"}


async def get_rl_run(run_id: str) -> dict[str, Any]:
    """Fetch a single RL run by its identifier."""
    data = await make_prime_request("GET", f"rft/runs/{run_id}")
    return data if data else {"error": f"Unable to fetch RL run: {run_id}"}


async def create_rl_run(
model_name: str,
environments: list[dict[str, Any]],
rollouts_per_example: int,
seq_len: int,
max_steps: int,
name: str | None = None,
eval_config: dict[str, Any] | None = None,
wandb_entity: str | None = None,
wandb_project: str | None = None,
wandb_run_name: str | None = None,
wandb_api_key: str | None = None,
secrets: list[dict[str, str]] | None = None,
team_id: str | None = None,
) -> dict[str, Any]:
"""Create a new RL training run."""
valid_rollouts = [1, 2, 4, 8, 16, 32, 64, 128]
if rollouts_per_example not in valid_rollouts:
return {"error": f"rollouts_per_example must be one of {valid_rollouts}"}

request_body: dict[str, Any] = {
"model": {"name": model_name},
"environments": environments,
"rollouts_per_example": rollouts_per_example,
"seq_len": seq_len,
"max_steps": max_steps,
}

if name:
request_body["name"] = name

all_secrets = list(secrets) if secrets else []
if wandb_api_key:
all_secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key})
if all_secrets:
request_body["secrets"] = all_secrets

if wandb_project:
request_body["monitoring"] = {"wandb": {"project": wandb_project}}
if wandb_entity:
request_body["monitoring"]["wandb"]["entity"] = wandb_entity
if wandb_run_name:
request_body["monitoring"]["wandb"]["name"] = wandb_run_name

if eval_config:
request_body["eval"] = eval_config
if team_id:
request_body["team_id"] = team_id

response_data = await make_prime_request("POST", "rft/runs", json_data=request_body)
if not response_data:
return {"error": "Unable to create RL run"}
return response_data


async def stop_rl_run(run_id: str) -> dict[str, Any]:
    """Ask the service to stop the given RL run."""
    data = await make_prime_request("PUT", f"rft/runs/{run_id}/stop")
    if data:
        return data
    return {"error": f"Unable to stop RL run: {run_id}"}


async def delete_rl_run(run_id: str) -> dict[str, Any]:
    """Delete an RL run, reporting success unless the service signals an error."""
    data = await make_prime_request("DELETE", f"rft/runs/{run_id}")
    # A None response or an explicit error field means the delete failed.
    if data is None or data.get("error"):
        return data or {"error": f"Unable to delete RL run: {run_id}"}
    return {"success": True, "run_id": run_id}


async def get_rl_run_logs(run_id: str, tail_lines: int = 1000) -> dict[str, Any]:
    """Fetch the trailing log lines for an RL run."""
    query = {"tail_lines": tail_lines}
    data = await make_prime_request("GET", f"rft/runs/{run_id}/logs", params=query)
    return data if data else {"error": f"Unable to fetch logs for RL run: {run_id}"}


async def list_rl_adapters(team_id: str | None = None) -> dict[str, Any]:
    """List LoRA adapters produced by RL runs, optionally filtered by team."""
    # Only attach the query parameter when a team filter was actually given.
    query = {"team_id": team_id} if team_id else None
    data = await make_prime_request("GET", "rft/adapters", params=query)
    if data:
        return data
    return {"error": "Unable to fetch RL adapters"}


async def get_rl_adapter(adapter_id: str) -> dict[str, Any]:
    """Fetch a single adapter record by its identifier."""
    data = await make_prime_request("GET", f"rft/adapters/{adapter_id}")
    return data if data else {"error": f"Unable to fetch adapter: {adapter_id}"}


async def delete_rl_adapter(adapter_id: str) -> dict[str, Any]:
    """Delete an adapter record, reporting success unless the service errors."""
    data = await make_prime_request("DELETE", f"rft/adapters/{adapter_id}")
    # A None response or an explicit error field means the delete failed.
    if data is None or data.get("error"):
        return data or {"error": f"Unable to delete adapter: {adapter_id}"}
    return {"success": True, "adapter_id": adapter_id}