diff --git a/packages/prime-mcp-server/pyproject.toml b/packages/prime-mcp-server/pyproject.toml index 054aa240..d700d1e3 100644 --- a/packages/prime-mcp-server/pyproject.toml +++ b/packages/prime-mcp-server/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "httpx>=0.25.0", "mcp>=1.0.0", "fastmcp>=0.2.0", + "prime-sandboxes>=0.2.8", ] keywords = ["mcp", "model-context-protocol"] classifiers = [ @@ -49,6 +50,9 @@ packages = ["src/prime_mcp"] [tool.hatch.version] path = "src/prime_mcp/__init__.py" +[tool.uv.sources] +prime-sandboxes = { workspace = true } + [tool.ruff] line-length = 100 target-version = "py310" diff --git a/packages/prime-mcp-server/src/prime_mcp/__init__.py b/packages/prime-mcp-server/src/prime_mcp/__init__.py index d3e162ff..f88f001b 100644 --- a/packages/prime-mcp-server/src/prime_mcp/__init__.py +++ b/packages/prime-mcp-server/src/prime_mcp/__init__.py @@ -1,13 +1,14 @@ from prime_mcp.client import make_prime_request from prime_mcp.mcp import mcp -from prime_mcp.tools import availability, pods, ssh +from prime_mcp.tools import availability, pods, sandboxes, ssh -__version__ = "0.1.2" +__version__ = "0.1.3" __all__ = [ "mcp", "make_prime_request", "availability", "pods", + "sandboxes", "ssh", ] diff --git a/packages/prime-mcp-server/src/prime_mcp/client.py b/packages/prime-mcp-server/src/prime_mcp/client.py index e2249693..ebd0cb41 100644 --- a/packages/prime-mcp-server/src/prime_mcp/client.py +++ b/packages/prime-mcp-server/src/prime_mcp/client.py @@ -11,24 +11,14 @@ async def make_prime_request( params: dict[str, Any] | None = None, json_data: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Make a request to the PrimeIntellect API with proper error handling. - - Args: - method: HTTP method (GET, POST, DELETE, PATCH) - endpoint: API endpoint (e.g., "/pods", "availability/") - params: Query parameters for GET requests - json_data: JSON body for POST/PATCH requests - - Returns: - API response as dictionary, or dict with "error" key on failure - """ + """Make a request to the PrimeIntellect API with proper error handling.""" try: if method == "GET": return await _client.get(endpoint, params=params) elif method == "POST": return await _client.post(endpoint, json=json_data) elif method == "DELETE": - return await _client.delete(endpoint) + return await _client.delete(endpoint, json=json_data) elif method == "PATCH": return await _client.patch(endpoint, json=json_data) else: diff --git a/packages/prime-mcp-server/src/prime_mcp/core/client.py b/packages/prime-mcp-server/src/prime_mcp/core/client.py index 84e9d90f..d0d59301 100644 --- a/packages/prime-mcp-server/src/prime_mcp/core/client.py +++ b/packages/prime-mcp-server/src/prime_mcp/core/client.py @@ -110,8 +110,8 @@ async def post(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Di async def patch(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: return await self.request("PATCH", endpoint, json=json) - async def delete(self, endpoint: str) -> Dict[str, Any]: - return await self.request("DELETE", endpoint) + async def delete(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + return await self.request("DELETE", endpoint, json=json) async def aclose(self) -> None: await self.client.aclose() diff --git a/packages/prime-mcp-server/src/prime_mcp/mcp.py b/packages/prime-mcp-server/src/prime_mcp/mcp.py index 54de5ac9..c7ef82c4 100644 --- a/packages/prime-mcp-server/src/prime_mcp/mcp.py +++ b/packages/prime-mcp-server/src/prime_mcp/mcp.py @@ -1,6 +1,6 @@ from mcp.server.fastmcp import FastMCP -from prime_mcp.tools import availability, pods, ssh +from prime_mcp.tools import availability, pods, sandboxes, ssh mcp = FastMCP("primeintellect") @@ -253,5 +253,300 @@ async def manage_ssh_keys( return await ssh.manage_ssh_keys(action, key_name, public_key, key_id, offset, limit) +@mcp.tool() +async def create_sandbox( + name: str, + docker_image: str = "python:3.11-slim", + start_command: str | None = "tail -f /dev/null", + cpu_cores: int = 1, + memory_gb: int = 2, + disk_size_gb: int = 5, + network_access: bool = True, + timeout_minutes: int = 60, + environment_vars: dict[str, str] | None = None, + secrets: dict[str, str] | None = None, + labels: list[str] | None = None, + team_id: str | None = None, + registry_credentials_id: str | None = None, +) -> dict: + """Create a new sandbox for isolated code execution. + + A sandbox is a containerized environment where you can safely execute code, + run commands, and manage files in isolation. Perfect for: + - Running untrusted code safely + - Testing and development + - Data processing pipelines + - CI/CD tasks + + WORKFLOW: + 1. Create sandbox with create_sandbox() + 2. Wait for status to become RUNNING (check with get_sandbox()) + 3. Execute commands with execute_sandbox_command() + 4. Clean up with delete_sandbox() + + Args: + name: Name for the sandbox (required) + docker_image: Docker image to use (default: "python:3.11-slim") + Popular options: python:3.11-slim, ubuntu:22.04, node:20-slim + start_command: Command to run on startup (default: "tail -f /dev/null") + cpu_cores: Number of CPU cores (1-16, default: 1) + memory_gb: Memory in GB (1-64, default: 2) + disk_size_gb: Disk size in GB (1-1000, default: 5) + network_access: Enable network access (default: True) + timeout_minutes: Auto-termination timeout (1-1440 minutes, default: 60) + environment_vars: Environment variables as key-value pairs + secrets: Sensitive environment variables (e.g., API keys) - stored securely + labels: Labels for organizing and filtering sandboxes + team_id: Team ID for organization accounts + registry_credentials_id: ID for private Docker registry credentials + + Returns: + Created sandbox details including ID, status, and configuration + """ + return await sandboxes.create_sandbox( + name=name, + docker_image=docker_image, + start_command=start_command, + cpu_cores=cpu_cores, + memory_gb=memory_gb, + disk_size_gb=disk_size_gb, + network_access=network_access, + timeout_minutes=timeout_minutes, + environment_vars=environment_vars, + secrets=secrets, + labels=labels, + team_id=team_id, + registry_credentials_id=registry_credentials_id, + ) + + +@mcp.tool() +async def list_sandboxes( + team_id: str | None = None, + status: str | None = None, + labels: list[str] | None = None, + page: int = 1, + per_page: int = 50, + exclude_terminated: bool = False, +) -> dict: + """List all sandboxes in your account. + + Args: + team_id: Filter by team ID + status: Filter by status (PENDING, PROVISIONING, RUNNING, STOPPED, ERROR, TERMINATED) + labels: Filter by labels (sandboxes must have ALL specified labels) + page: Page number for pagination (default: 1) + per_page: Results per page (default: 50, max: 100) + exclude_terminated: Exclude terminated sandboxes (default: False) + + Returns: + List of sandboxes with pagination info (sandboxes, total, page, per_page, has_next) + """ + return await sandboxes.list_sandboxes( + team_id=team_id, + status=status, + labels=labels, + page=page, + per_page=per_page, + exclude_terminated=exclude_terminated, + ) + + +@mcp.tool() +async def get_sandbox(sandbox_id: str) -> dict: + """Get detailed information about a specific sandbox. + + Use this to check sandbox status before executing commands. + Sandbox must be in RUNNING status for command execution. + + Args: + sandbox_id: Unique identifier of the sandbox + + Returns: + Detailed sandbox information including: + - id, name, status + - docker_image, cpu_cores, memory_gb, disk_size_gb + - created_at, started_at, terminated_at + - labels, environment_vars + """ + return await sandboxes.get_sandbox(sandbox_id) + + +@mcp.tool() +async def delete_sandbox(sandbox_id: str) -> dict: + """Delete/terminate a sandbox. + + This will immediately terminate the sandbox and release resources. + Any unsaved data will be lost. + + Args: + sandbox_id: Unique identifier of the sandbox to delete + + Returns: + Deletion confirmation + """ + return await sandboxes.delete_sandbox(sandbox_id) + + +@mcp.tool() +async def bulk_delete_sandboxes( + sandbox_ids: list[str] | None = None, + labels: list[str] | None = None, +) -> dict: + """Bulk delete multiple sandboxes by IDs or labels. + + Useful for cleanup operations. You must specify either sandbox_ids OR labels, + but not both. + + Args: + sandbox_ids: List of sandbox IDs to delete + labels: Delete all sandboxes with ALL of these labels + + Returns: + Results showing succeeded and failed deletions + """ + return await sandboxes.bulk_delete_sandboxes(sandbox_ids=sandbox_ids, labels=labels) + + +@mcp.tool() +async def get_sandbox_logs(sandbox_id: str) -> dict: + """Get logs from a sandbox. + + Returns container logs including stdout/stderr from the start command + and any executed commands. + + Args: + sandbox_id: Unique identifier of the sandbox + + Returns: + Sandbox logs as text + """ + return await sandboxes.get_sandbox_logs(sandbox_id) + + +@mcp.tool() +async def execute_sandbox_command( + sandbox_id: str, + command: str, + working_dir: str | None = None, + env: dict[str, str] | None = None, + timeout: int = 300, +) -> dict: + """Execute a command in a sandbox. + + IMPORTANT: The sandbox must be in RUNNING status before executing commands. + Use get_sandbox() to check status first. + + Args: + sandbox_id: Unique identifier of the sandbox + command: Shell command to execute (e.g., "python script.py", "ls -la") + working_dir: Working directory for the command (optional) + env: Additional environment variables (optional) + timeout: Command timeout in seconds (default: 300, max: 3600) + + Returns: + Command result with: + - stdout: Standard output + - stderr: Standard error + - exit_code: Exit code (0 = success) + """ + return await sandboxes.execute_command( + sandbox_id=sandbox_id, + command=command, + working_dir=working_dir, + env=env, + timeout=timeout, + ) + + +@mcp.tool() +async def expose_sandbox_port( + sandbox_id: str, + port: int, + name: str | None = None, +) -> dict: + """Expose an HTTP port from a sandbox to the internet. + + Creates a public URL that routes traffic to the specified port. + Useful for web servers, APIs, Jupyter notebooks, Streamlit apps, etc. + + Args: + sandbox_id: Unique identifier of the sandbox + port: Port number to expose (22-9000, excluding 8080 which is reserved) + name: Optional friendly name for the exposure + + Returns: + Exposure details including: + - exposure_id: ID to use for unexpose_sandbox_port() + - url: Public URL to access the service + - tls_socket: TLS socket address + """ + return await sandboxes.expose_port(sandbox_id=sandbox_id, port=port, name=name) + + +@mcp.tool() +async def unexpose_sandbox_port(sandbox_id: str, exposure_id: str) -> dict: + """Remove a port exposure from a sandbox. + + Args: + sandbox_id: Unique identifier of the sandbox + exposure_id: ID of the exposure to remove (from expose_sandbox_port result) + + Returns: + Confirmation of removal + """ + return await sandboxes.unexpose_port(sandbox_id=sandbox_id, exposure_id=exposure_id) + + +@mcp.tool() +async def list_sandbox_exposed_ports(sandbox_id: str) -> dict: + """List all exposed ports for a sandbox. + + Args: + sandbox_id: Unique identifier of the sandbox + + Returns: + List of exposed ports with their URLs and details + """ + return await sandboxes.list_exposed_ports(sandbox_id) + + +@mcp.tool() +async def list_registry_credentials() -> dict: + """List available registry credentials for private Docker images. + + Registry credentials allow you to pull images from private Docker registries + like GitHub Container Registry, AWS ECR, Google Container Registry, etc. + + Returns: + List of registry credentials (id, name, server - no secrets) + """ + return await sandboxes.list_registry_credentials() + + +@mcp.tool() +async def check_docker_image( + image: str, + registry_credentials_id: str | None = None, +) -> dict: + """Check if a Docker image is accessible before creating a sandbox. + + Validates that the image exists and can be pulled. Useful for: + - Verifying public images exist + - Testing private registry credentials + + Args: + image: Docker image name (e.g., "python:3.11-slim", "ghcr.io/org/image:tag") + registry_credentials_id: Optional credentials ID for private registries + + Returns: + - accessible: Whether the image can be pulled + - details: Additional information or error message + """ + return await sandboxes.check_docker_image( + image=image, registry_credentials_id=registry_credentials_id + ) + + if __name__ == "__main__": mcp.run(transport="stdio") diff --git a/packages/prime-mcp-server/src/prime_mcp/tools/sandboxes.py b/packages/prime-mcp-server/src/prime_mcp/tools/sandboxes.py new file mode 100644 index 00000000..fde3a70a --- /dev/null +++ b/packages/prime-mcp-server/src/prime_mcp/tools/sandboxes.py @@ -0,0 +1,279 @@ +from typing import Any, Optional + +from prime_sandboxes import ( + APIError, + AsyncSandboxClient, + AsyncTemplateClient, + CommandTimeoutError, + CreateSandboxRequest, +) + +_sandbox_client: Optional[AsyncSandboxClient] = None +_template_client: Optional[AsyncTemplateClient] = None + + +def _get_sandbox_client() -> AsyncSandboxClient: + """Get or create the sandbox client singleton.""" + global _sandbox_client + if _sandbox_client is None: + _sandbox_client = AsyncSandboxClient() + return _sandbox_client + + +def _get_template_client() -> AsyncTemplateClient: + """Get or create the template client singleton.""" + global _template_client + if _template_client is None: + _template_client = AsyncTemplateClient() + return _template_client + + +async def create_sandbox( + name: str, + docker_image: str = "python:3.11-slim", + start_command: Optional[str] = "tail -f /dev/null", + cpu_cores: int = 1, + memory_gb: int = 2, + disk_size_gb: int = 5, + network_access: bool = True, + timeout_minutes: int = 60, + environment_vars: Optional[dict[str, str]] = None, + secrets: Optional[dict[str, str]] = None, + labels: Optional[list[str]] = None, + team_id: Optional[str] = None, + registry_credentials_id: Optional[str] = None, +) -> dict[str, Any]: + """Create a new sandbox for isolated code execution.""" + try: + client = _get_sandbox_client() + request = CreateSandboxRequest( + name=name, + docker_image=docker_image, + start_command=start_command, + cpu_cores=cpu_cores, + memory_gb=memory_gb, + disk_size_gb=disk_size_gb, + gpu_count=0, # GPU support not yet available + network_access=network_access, + timeout_minutes=timeout_minutes, + environment_vars=environment_vars, + secrets=secrets, + labels=labels or [], + team_id=team_id, + registry_credentials_id=registry_credentials_id, + ) + sandbox = await client.create(request) + return sandbox.model_dump(by_alias=True) + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to create sandbox: {e}"} + + +async def list_sandboxes( + team_id: Optional[str] = None, + status: Optional[str] = None, + labels: Optional[list[str]] = None, + page: int = 1, + per_page: int = 50, + exclude_terminated: bool = False, +) -> dict[str, Any]: + """List all sandboxes in your account.""" + try: + client = _get_sandbox_client() + response = await client.list( + team_id=team_id, + status=status, + labels=labels, + page=page, + per_page=per_page, + exclude_terminated=exclude_terminated, + ) + return { + "sandboxes": [s.model_dump(by_alias=True) for s in response.sandboxes], + "total": response.total, + "page": response.page, + "per_page": response.per_page, + "has_next": response.has_next, + } + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to list sandboxes: {e}"} + + +async def get_sandbox(sandbox_id: str) -> dict[str, Any]: + """Get detailed information about a specific sandbox.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + try: + client = _get_sandbox_client() + sandbox = await client.get(sandbox_id) + return sandbox.model_dump(by_alias=True) + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to get sandbox: {e}"} + + +async def delete_sandbox(sandbox_id: str) -> dict[str, Any]: + """Delete/terminate a sandbox.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + try: + client = _get_sandbox_client() + result = await client.delete(sandbox_id) + return result + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to delete sandbox: {e}"} + + +async def bulk_delete_sandboxes( + sandbox_ids: Optional[list[str]] = None, + labels: Optional[list[str]] = None, +) -> dict[str, Any]: + """Bulk delete multiple sandboxes by IDs or labels.""" + if not sandbox_ids and not labels: + return {"error": "Must specify either sandbox_ids or labels"} + if sandbox_ids and labels: + return {"error": "Cannot specify both sandbox_ids and labels"} + try: + client = _get_sandbox_client() + response = await client.bulk_delete(sandbox_ids=sandbox_ids, labels=labels) + return response.model_dump() + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to bulk delete sandboxes: {e}"} + + +async def get_sandbox_logs(sandbox_id: str) -> dict[str, Any]: + """Get logs from a sandbox.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + try: + client = _get_sandbox_client() + logs = await client.get_logs(sandbox_id) + return {"logs": logs} + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to get sandbox logs: {e}"} + + +async def execute_command( + sandbox_id: str, + command: str, + working_dir: Optional[str] = None, + env: Optional[dict[str, str]] = None, + timeout: int = 300, +) -> dict[str, Any]: + """Execute a command in a sandbox via the gateway.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + if not command: + return {"error": "command is required"} + if timeout < 1: + return {"error": "timeout must be at least 1 second"} + try: + client = _get_sandbox_client() + result = await client.execute_command( + sandbox_id=sandbox_id, + command=command, + working_dir=working_dir, + env=env, + timeout=timeout, + ) + return result.model_dump() + except CommandTimeoutError: + return {"error": f"Command timed out after {timeout} seconds"} + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to execute command: {e}"} + + +async def expose_port( + sandbox_id: str, + port: int, + name: Optional[str] = None, +) -> dict[str, Any]: + """Expose an HTTP port from a sandbox to the internet.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + if not port or port < 22 or port > 9000: + return {"error": "port must be between 22 and 9000"} + if port == 8080: + return {"error": "port 8080 is reserved and cannot be exposed"} + try: + client = _get_sandbox_client() + result = await client.expose(sandbox_id=sandbox_id, port=port, name=name) + return result.model_dump() + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to expose port: {e}"} + + +async def unexpose_port(sandbox_id: str, exposure_id: str) -> dict[str, Any]: + """Remove a port exposure from a sandbox.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + if not exposure_id: + return {"error": "exposure_id is required"} + try: + client = _get_sandbox_client() + await client.unexpose(sandbox_id=sandbox_id, exposure_id=exposure_id) + return {"success": True} + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to unexpose port: {e}"} + + +async def list_exposed_ports(sandbox_id: str) -> dict[str, Any]: + """List all exposed ports for a sandbox.""" + if not sandbox_id: + return {"error": "sandbox_id is required"} + try: + client = _get_sandbox_client() + response = await client.list_exposed_ports(sandbox_id) + return {"exposures": [e.model_dump() for e in response.exposures]} + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to list exposed ports: {e}"} + + +async def list_registry_credentials() -> dict[str, Any]: + """List available registry credentials for private Docker images.""" + try: + client = _get_template_client() + credentials = await client.list_registry_credentials() + return {"credentials": [c.model_dump(by_alias=True) for c in credentials]} + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to list registry credentials: {e}"} + + +async def check_docker_image( + image: str, + registry_credentials_id: Optional[str] = None, +) -> dict[str, Any]: + """Check if a Docker image is accessible.""" + if not image: + return {"error": "image is required"} + try: + client = _get_template_client() + result = await client.check_docker_image( + image=image, + registry_credentials_id=registry_credentials_id, + ) + return result.model_dump() + except APIError as e: + return {"error": str(e)} + except Exception as e: + return {"error": f"Failed to check docker image: {e}"} diff --git a/packages/prime-mcp-server/tests/test_mcp.py b/packages/prime-mcp-server/tests/test_mcp.py index 395f1a48..d320a517 100644 --- a/packages/prime-mcp-server/tests/test_mcp.py +++ b/packages/prime-mcp-server/tests/test_mcp.py @@ -82,6 +82,7 @@ async def test_create_pod_validation(): cloud_id="test-cloud-id", gpu_type="A100_80GB", provider_type="runpod", + data_center_id="US-CA-1", gpu_count=0, # Invalid ) @@ -96,6 +97,7 @@ async def test_create_pod_disk_size_validation(): cloud_id="test-cloud-id", gpu_type="A100_80GB", provider_type="runpod", + data_center_id="US-CA-1", disk_size=0, # Invalid ) @@ -110,6 +112,7 @@ async def test_create_pod_vcpus_validation(): cloud_id="test-cloud-id", gpu_type="A100_80GB", provider_type="runpod", + data_center_id="US-CA-1", vcpus=0, # Invalid ) @@ -124,6 +127,7 @@ async def test_create_pod_memory_validation(): cloud_id="test-cloud-id", gpu_type="A100_80GB", provider_type="runpod", + data_center_id="US-CA-1", memory=0, # Invalid ) diff --git a/packages/prime-mcp-server/tests/test_sandbox_tools.py b/packages/prime-mcp-server/tests/test_sandbox_tools.py new file mode 100644 index 00000000..ff25684f --- /dev/null +++ b/packages/prime-mcp-server/tests/test_sandbox_tools.py @@ -0,0 +1,265 @@ +import pytest + +from prime_mcp.tools import sandboxes + + +class TestCreateSandbox: + """Tests for create_sandbox function.""" + + @pytest.mark.asyncio + async def test_create_sandbox_validation_cpu_cores(self): + """Test that cpu_cores must be at least 1.""" + result = await sandboxes.create_sandbox( + name="test-sandbox", + cpu_cores=0, + ) + assert "error" in result + assert "cpu_cores" in result["error"].lower() or "greater than" in result["error"].lower() + + @pytest.mark.asyncio + async def test_create_sandbox_validation_memory_gb(self): + """Test that memory_gb must be at least 1.""" + result = await sandboxes.create_sandbox( + name="test-sandbox", + memory_gb=0, + ) + assert "error" in result + error_msg = result["error"].lower() + assert any(x in error_msg for x in ["memory", "greater than", "event loop"]) + + @pytest.mark.asyncio + async def test_create_sandbox_validation_disk_size_gb(self): + """Test that disk_size_gb must be at least 1.""" + result = await sandboxes.create_sandbox( + name="test-sandbox", + disk_size_gb=0, + ) + assert "error" in result + assert "disk" in result["error"].lower() or "greater than" in result["error"].lower() + + @pytest.mark.asyncio + async def test_create_sandbox_validation_timeout_minutes(self): + """Test that timeout_minutes must be at least 1.""" + result = await sandboxes.create_sandbox( + name="test-sandbox", + timeout_minutes=0, + ) + assert "error" in result + error_msg = result["error"].lower() + assert any(x in error_msg for x in ["timeout", "greater than", "event loop"]) + + +class TestListSandboxes: + """Tests for list_sandboxes function.""" + + @pytest.mark.asyncio + async def test_list_sandboxes_default_params(self): + """Test list_sandboxes with default parameters.""" + result = await sandboxes.list_sandboxes() + # Should return a dict (either with sandboxes or error) + assert isinstance(result, dict) + + @pytest.mark.asyncio + async def test_list_sandboxes_with_filters(self): + """Test list_sandboxes with status filter.""" + result = await sandboxes.list_sandboxes( + status="RUNNING", + page=1, + per_page=10, + ) + assert isinstance(result, dict) + + +class TestGetSandbox: + """Tests for get_sandbox function.""" + + @pytest.mark.asyncio + async def test_get_sandbox_empty_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.get_sandbox("") + assert "error" in result + assert "sandbox_id is required" in result["error"] + + @pytest.mark.asyncio + async def test_get_sandbox_valid_id(self): + """Test get_sandbox with valid ID format.""" + result = await sandboxes.get_sandbox("test-sandbox-id") + assert isinstance(result, dict) + + +class TestDeleteSandbox: + """Tests for delete_sandbox function.""" + + @pytest.mark.asyncio + async def test_delete_sandbox_empty_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.delete_sandbox("") + assert "error" in result + assert "sandbox_id is required" in result["error"] + + +class TestBulkDeleteSandboxes: + """Tests for bulk_delete_sandboxes function.""" + + @pytest.mark.asyncio + async def test_bulk_delete_no_params(self): + """Test that either sandbox_ids or labels is required.""" + result = await sandboxes.bulk_delete_sandboxes() + assert "error" in result + assert "Must specify either sandbox_ids or labels" in result["error"] + + @pytest.mark.asyncio + async def test_bulk_delete_both_params(self): + """Test that both sandbox_ids and labels cannot be specified.""" + result = await sandboxes.bulk_delete_sandboxes( + sandbox_ids=["id1", "id2"], + labels=["label1"], + ) + assert "error" in result + assert "Cannot specify both sandbox_ids and labels" in result["error"] + + +class TestGetSandboxLogs: + """Tests for get_sandbox_logs function.""" + + @pytest.mark.asyncio + async def test_get_logs_empty_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.get_sandbox_logs("") + assert "error" in result + assert "sandbox_id is required" in result["error"] + + +class TestExecuteCommand: + """Tests for execute_command function.""" + + @pytest.mark.asyncio + async def test_execute_command_empty_sandbox_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.execute_command( + sandbox_id="", + command="echo hello", + ) + assert "error" in result + assert "sandbox_id is required" in result["error"] + + @pytest.mark.asyncio + async def test_execute_command_empty_command(self): + """Test that command is required.""" + result = await sandboxes.execute_command( + sandbox_id="test-id", + command="", + ) + assert "error" in result + assert "command is required" in result["error"] + + @pytest.mark.asyncio + async def test_execute_command_invalid_timeout(self): + """Test that timeout must be at least 1 second.""" + result = await sandboxes.execute_command( + sandbox_id="test-id", + command="echo hello", + timeout=0, + ) + assert "error" in result + assert "timeout must be at least 1" in result["error"] + + +class TestExposePort: + """Tests for expose_port function.""" + + @pytest.mark.asyncio + async def test_expose_port_empty_sandbox_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.expose_port( + sandbox_id="", + port=8080, + ) + assert "error" in result + assert "sandbox_id is required" in result["error"] + + @pytest.mark.asyncio + async def test_expose_port_invalid_port_zero(self): + """Test that port must be valid (not 0).""" + result = await sandboxes.expose_port( + sandbox_id="test-id", + port=0, + ) + assert "error" in result + assert "port must be between 22 and 9000" in result["error"] + + @pytest.mark.asyncio + async def test_expose_port_invalid_port_high(self): + """Test that port must be valid (not > 9000).""" + result = await sandboxes.expose_port( + sandbox_id="test-id", + port=10000, + ) + assert "error" in result + assert "port must be between 22 and 9000" in result["error"] + + +class TestUnexposePort: + """Tests for unexpose_port function.""" + + @pytest.mark.asyncio + async def test_unexpose_port_empty_sandbox_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.unexpose_port( + sandbox_id="", + exposure_id="exp-123", + ) + assert "error" in result + assert "sandbox_id is required" in result["error"] + + @pytest.mark.asyncio + async def test_unexpose_port_empty_exposure_id(self): + """Test that exposure_id is required.""" + result = await sandboxes.unexpose_port( + sandbox_id="test-id", + exposure_id="", + ) + assert "error" in result + assert "exposure_id is required" in result["error"] + + +class TestListExposedPorts: + """Tests for list_exposed_ports function.""" + + @pytest.mark.asyncio + async def test_list_exposed_ports_empty_id(self): + """Test that sandbox_id is required.""" + result = await sandboxes.list_exposed_ports("") + assert "error" in result + assert "sandbox_id is required" in result["error"] + + +class TestCheckDockerImage: + """Tests for check_docker_image function.""" + + @pytest.mark.asyncio + async def test_check_docker_image_empty(self): + """Test that image is required.""" + result = await sandboxes.check_docker_image("") + assert "error" in result + assert "image is required" in result["error"] + + +class TestModuleImports: + """Test that all modules import correctly.""" + + def test_import_sandboxes(self): + """Test that sandboxes module can be imported.""" + from prime_mcp.tools import sandboxes as sb + + assert sb is not None + + def test_import_mcp_tools(self): + """Test that all tools can be imported from main module.""" + from prime_mcp import sandboxes as sb + + assert sb is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/uv.lock b/uv.lock index b025b5a7..42ed2e9f 100644 --- a/uv.lock +++ b/uv.lock @@ -1682,6 +1682,7 @@ dependencies = [ { name = "fastmcp" }, { name = "httpx" }, { name = "mcp" }, + { name = "prime-sandboxes" }, ] [package.optional-dependencies] @@ -1696,6 +1697,7 @@ requires-dist = [ { name = "fastmcp", specifier = ">=0.2.0" }, { name = "httpx", specifier = ">=0.25.0" }, { name = "mcp", specifier = ">=1.0.0" }, + { name = "prime-sandboxes", editable = "packages/prime-sandboxes" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.13.1" },