-
Notifications
You must be signed in to change notification settings - Fork 483
New Remote Env / TS Env #703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .remote_env import RemoteEnv | ||
| from .ts_env import TypeScriptEnv | ||
|
|
||
| __all__ = ["RemoteEnv", "TypeScriptEnv"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| import shlex | ||
|
|
||
| import httpx | ||
|
|
||
| import verifiers as vf | ||
| from verifiers.envs.sandbox_env import SandboxEnv | ||
|
|
||
| DEFAULT_API_URL = "https://api.primeintellect.ai" | ||
|
|
||
|
|
||
| class RemoteEnv(SandboxEnv): | ||
| def __init__( | ||
| self, | ||
| environment: str, | ||
| upload_path: str = "/app", | ||
| docker_image: str = "python:3.11-slim", | ||
| api_base_url: str | None = None, | ||
| api_key: str | None = None, | ||
| **kwargs, | ||
| ): | ||
| """ | ||
| Remote environment that downloads files from the Prime Environments Hub. | ||
|
|
||
| Args: | ||
| environment: Environment identifier in format "owner/name" or "owner/name@version" | ||
| upload_path: Path inside sandbox where files are extracted (default: /app) | ||
| docker_image: Docker image for sandbox (default: python:3.11-slim) | ||
| api_base_url: Base URL for Prime API (default: https://api.primeintellect.ai) | ||
| api_key: API key for authentication (optional, needed for private environments) | ||
| **kwargs: Additional arguments passed to SandboxEnv | ||
| """ | ||
| self.environment = environment | ||
| self.upload_path = upload_path | ||
| self.api_base_url = (api_base_url or DEFAULT_API_URL).rstrip("/") | ||
| self.api_key = api_key | ||
| self._package_url: str | None = None | ||
|
|
||
| if "@" in environment: | ||
| env_id, self.version = environment.rsplit("@", 1) | ||
| else: | ||
| env_id = environment | ||
| self.version = "latest" | ||
|
|
||
| parts = env_id.split("/") | ||
| if len(parts) != 2: | ||
| raise ValueError( | ||
| f"Invalid environment format: {environment}. Expected: owner/name or owner/name@version" | ||
| ) | ||
| self.owner, self.name = parts | ||
|
|
||
| super().__init__( | ||
| docker_image=docker_image, | ||
| start_command="tail -f /dev/null", | ||
| **kwargs, | ||
| ) | ||
|
|
||
async def _fetch_package_url(self) -> str:
    """Return the package URL for this environment, querying the hub once.

    The URL is cached on ``self._package_url`` after the first successful
    lookup. The ``Authorization`` header is only sent when an API key is set.

    Raises:
        ValueError: If the hub response carries no ``package_url``.
    """
    cached = self._package_url
    if cached:
        return cached

    auth_headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
    endpoint = f"{self.api_base_url}/environmentshub/{self.owner}/{self.name}/@{self.version}"

    async with httpx.AsyncClient() as client:
        response = await client.get(endpoint, headers=auth_headers, timeout=30.0)
        response.raise_for_status()
        body = response.json()

    # The details may be nested under "data" or be the top-level object itself.
    url = body.get("data", body).get("package_url")
    if not url:
        raise ValueError(f"No package URL found for environment {self.environment}")

    self._package_url = url
    return url
|
|
||
async def _download_and_extract(self, sandbox_id: str) -> None:
    """Download tarball from hub and extract to sandbox.

    A small Python script is generated and run inside the sandbox with
    ``python3 -c``. It downloads the package to ``/tmp/env.tar.gz``,
    extracts it into ``self.upload_path`` and removes the archive.

    Args:
        sandbox_id: Identifier of the sandbox to run the script in.

    Raises:
        RuntimeError: If the script exits with a non-zero status.
    """
    package_url = await self._fetch_package_url()

    # Embed the values as repr() literals: a quote or backslash in the path
    # or URL would otherwise produce syntactically invalid generated code.
    target_literal = repr(str(self.upload_path))
    url_literal = repr(str(package_url))

    # NOTE(review): extractall() is called without a `filter`, so archive
    # member paths are trusted as-is — confirm hub packages are trusted.
    download_script = "\n".join(
        [
            "import urllib.request",
            "import tarfile",
            "import os",
            "",
            f"os.makedirs({target_literal}, exist_ok=True)",
            f'urllib.request.urlretrieve({url_literal}, "/tmp/env.tar.gz")',
            'with tarfile.open("/tmp/env.tar.gz", "r:gz") as tar:',
            f"    tar.extractall({target_literal})",
            'os.remove("/tmp/env.tar.gz")',
            'print("Download and extraction complete")',
        ]
    )

    result = await self.sandbox_client.execute_command(
        sandbox_id,
        f"python3 -c {shlex.quote(download_script)}",
        timeout=120,
    )

    if result.exit_code != 0:
        raise RuntimeError(f"Failed to download environment: {result.stderr}")
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Non-vf.Error exceptions cause sandbox resource leaks. High Severity. The new code raises Additional Locations (2) |
||
|
|
||
async def _run_setup(self, sandbox_id: str) -> None:
    """Run setup.sh from the sandbox directory.

    Marks ``<upload_path>/sandbox/setup.sh`` executable, then starts it as a
    background job with ``<upload_path>/sandbox`` as the working directory.

    Args:
        sandbox_id: Identifier of the sandbox to run the script in.
    """
    sandbox_dir = f"{self.upload_path}/sandbox"
    # Quote the path for the shell so an upload_path containing spaces or
    # shell metacharacters stays a single argument. Ordinary paths such as
    # the default "/app" come out of shlex.quote unchanged.
    setup_script = shlex.quote(f"{sandbox_dir}/setup.sh")

    await self.sandbox_client.execute_command(
        sandbox_id,
        f"chmod +x {setup_script}",
        timeout=10,
    )

    await self.sandbox_client.start_background_job(
        sandbox_id,
        setup_script,
        working_dir=sandbox_dir,
    )
|
|
||
async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
    """Prepare a rollout: base setup, then fetch the package and start setup.sh.

    After the parent class has populated the state, this waits for the
    sandbox to be ready, downloads and extracts the environment package into
    it, and launches the package's setup script.
    """
    prepared = await super().setup_state(state, **kwargs)
    sid = prepared["sandbox_id"]

    await self._wait_for_sandbox_ready(prepared["sandbox_state"], sid)
    await self._download_and_extract(sid)
    await self._run_setup(sid)

    return prepared
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,252 @@ | ||
| import asyncio | ||
| import json | ||
| from typing import Any, Callable, cast | ||
|
|
||
| import verifiers as vf | ||
| from openai.types.chat import ChatCompletionFunctionToolParam | ||
| from .remote_env import RemoteEnv | ||
|
|
||
|
|
||
class RemoteToolWrapper:
    """Callable proxy for a tool served by a ``TypeScriptEnv`` server.

    The instance carries ``__name__`` and ``__doc__`` like a plain function
    and forwards calls to ``env._call_remote_tool``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: dict,
        env: "TypeScriptEnv",
    ):
        self.env = env
        self.parameters = parameters
        # Mirror the attributes a regular function object would expose.
        self.name = self.__name__ = name
        self.__doc__ = description

    async def __call__(self, **kwargs) -> str:
        """Invoke the remote tool with the given keyword arguments."""
        return await self.env._call_remote_tool(self.name, kwargs)

    def to_oai_tool(self) -> ChatCompletionFunctionToolParam:
        """Build the OpenAI function-tool description for this tool."""
        if self.parameters:
            schema = self.parameters
        else:
            # Falsy parameters fall back to an empty object schema.
            schema = {"type": "object", "properties": {}}

        function_spec = {
            "name": self.__name__,
            "description": self.__doc__ or "",
            "parameters": schema,
        }
        tool: ChatCompletionFunctionToolParam = {
            "type": "function",
            "function": function_spec,
        }
        return tool
|
|
||
|
|
||
class RemoteRewardRubric(vf.Rubric):
    """Rubric whose reward functions delegate to a ``TypeScriptEnv``.

    One reward function is registered per entry of ``reward_specs``. Each
    entry must have a ``"name"`` and may have a ``"weight"`` (default 1.0).
    """

    def __init__(self, reward_specs: list[dict], env: "TypeScriptEnv", **kwargs):
        super().__init__(**kwargs)
        self.env = env
        self.reward_specs = reward_specs

        for entry in reward_specs:
            self.add_reward_func(
                self._create_reward_func(entry["name"]),
                weight=entry.get("weight", 1.0),
            )

    def _create_reward_func(self, name: str) -> Callable:
        """Return an async reward function named ``name`` that calls the env."""

        async def reward_func(
            prompt: vf.Messages,
            completion: vf.Messages,
            answer: Any,
            state: vf.State,
            **kwargs,
        ) -> float:
            score = await self.env._call_remote_reward(name, prompt, completion, answer, state)
            return score

        # Expose the remote reward's name on the function object.
        reward_func.__name__ = name
        return reward_func
|
|
||
|
|
||
| class TypeScriptEnv(RemoteEnv): | ||
def __init__(
    self,
    environment: str,
    server_port: int = 3000,
    server_ready_timeout: int = 120,
    **kwargs,
):
    """
    TypeScript environment that runs a Bun server with tools and rewards.

    Args:
        environment: Environment identifier in format "owner/name" or "owner/name@version"
        server_port: Port the TypeScript server listens on (default: 3000)
        server_ready_timeout: Seconds to wait for server to be ready (default: 120)
        **kwargs: Additional arguments passed to RemoteEnv
    """
    super().__init__(environment=environment, **kwargs)

    # Discovery bookkeeping; filled in by setup_state.
    self._tools_discovered = False
    self._remote_rubric: RemoteRewardRubric | None = None
    self.remote_tools: dict[str, RemoteToolWrapper] = {}

    # Server connection settings.
    self.server_ready_timeout = server_ready_timeout
    self.server_port = server_port
|
|
||
async def _wait_for_server(self, sandbox_id: str) -> None:
    """Poll the server's ``/tools`` endpoint until it answers successfully.

    The wait is bounded by a wall-clock deadline of
    ``self.server_ready_timeout`` seconds. A fixed number of attempts is not
    used because each probe may itself take several seconds, which would
    stretch the effective timeout well beyond the configured value.

    Args:
        sandbox_id: Identifier of the sandbox running the server.

    Raises:
        TimeoutError: If no probe succeeds before the deadline.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + self.server_ready_timeout
    probe = f"curl -sf http://localhost:{self.server_port}/tools > /dev/null"

    while loop.time() < deadline:
        result = await self.sandbox_client.execute_command(
            sandbox_id,
            probe,
            timeout=5,
        )
        if result.exit_code == 0:
            return

        # Pause up to one second between probes, but never past the deadline.
        await asyncio.sleep(min(1.0, max(0.0, deadline - loop.time())))

    raise TimeoutError(f"Server not ready after {self.server_ready_timeout} seconds")
|
|
||
async def _discover_tools(self, sandbox_id: str) -> list[dict]:
    """Fetch the tool specs from the server's ``/tools`` endpoint.

    Returns:
        The value under the ``"tools"`` key of the JSON response.

    Raises:
        RuntimeError: If the curl command exits with a non-zero status.
    """
    response = await self.sandbox_client.execute_command(
        sandbox_id,
        f"curl -sf http://localhost:{self.server_port}/tools",
        timeout=10,
    )

    if response.exit_code == 0:
        return json.loads(response.stdout)["tools"]

    raise RuntimeError(f"Failed to fetch tools: {response.stderr}")
|
|
||
async def _discover_rewards(self, sandbox_id: str) -> list[dict]:
    """Fetch the reward specs from the server's ``/rewards`` endpoint.

    Returns:
        The value under the ``"rewards"`` key of the JSON response.

    Raises:
        RuntimeError: If the curl command exits with a non-zero status.
    """
    response = await self.sandbox_client.execute_command(
        sandbox_id,
        f"curl -sf http://localhost:{self.server_port}/rewards",
        timeout=10,
    )

    if response.exit_code == 0:
        return json.loads(response.stdout)["rewards"]

    raise RuntimeError(f"Failed to fetch rewards: {response.stderr}")
|
|
||
def _register_tools(self, tool_specs: list[dict]) -> None:
    """Create a wrapper for each discovered tool and register it on the env.

    Each spec may be a bare function description (``name``, ``description``,
    ``parameters``) or an OpenAI-style ``{"type": "function", "function": {...}}``
    entry. A tool is added to ``remote_tools``, ``tools``, ``oai_tools`` and
    ``tool_map``.

    Registration is idempotent per tool name: a name that is already in
    ``remote_tools`` is skipped, so calling this twice with the same specs
    does not append duplicate entries to ``tools`` / ``oai_tools``.
    """
    for spec in tool_specs:
        func_spec = spec.get("function", spec) if spec.get("type") == "function" else spec
        name = func_spec["name"]
        if name in self.remote_tools:
            # Already registered by an earlier discovery pass.
            continue

        description = func_spec.get("description", "")
        parameters = func_spec.get("parameters", {"type": "object", "properties": {}})

        wrapper = RemoteToolWrapper(name, description, parameters, self)
        self.remote_tools[name] = wrapper
        self.tools.append(wrapper)
        self.oai_tools.append(wrapper.to_oai_tool())
        self.tool_map[name] = wrapper
|
|
||
def _register_rewards(self, reward_specs: list[dict]) -> None:
    """Build a rubric from the discovered reward specs and attach it to the env."""
    rubric = RemoteRewardRubric(reward_specs, self)
    self._remote_rubric = rubric
    self.add_rubric(rubric)
|
|
||
async def _call_remote_tool(self, tool_name: str, args: dict) -> str:
    """POST a tool call to the server inside the sandbox and return its result.

    ``args`` must contain ``"_sandbox_id"`` and may contain ``"_state"``.
    Both are removed from ``args`` before the remaining entries are sent as
    the tool arguments.

    Args:
        tool_name: Name of the tool; used as the last URL path segment.
        args: Tool arguments plus the injected ``_sandbox_id`` / ``_state``.

    Returns:
        The ``"result"`` field of the JSON response, or the stringified
        response when there is no such field. If curl fails, an error
        message string is returned instead of raising.
    """
    import shlex  # imported locally: this module has no top-level shlex import

    sandbox_id = args.pop("_sandbox_id")
    state = args.pop("_state", None)

    payload = json.dumps({"args": args, "state": state or {}})
    url = f"http://localhost:{self.server_port}/tools/{tool_name}"

    # shlex.quote replaces the hand-written single-quote escaping and also
    # protects the URL, which embeds the tool name.
    result = await self.sandbox_client.execute_command(
        sandbox_id,
        f"curl -sf -X POST {shlex.quote(url)} "
        f"-H 'Content-Type: application/json' -d {shlex.quote(payload)}",
        timeout=self.timeout_per_command_seconds,
    )

    if result.exit_code != 0:
        return f"Error calling tool {tool_name}: {result.stderr or 'Unknown error'}"

    data = json.loads(result.stdout)
    if isinstance(data, dict):
        return data.get("result", str(data))
    # A JSON array or scalar has no "result" field to look up.
    return str(data)
|
|
||
async def _call_remote_reward(
    self,
    reward_name: str,
    prompt: vf.Messages,
    completion: vf.Messages,
    answer: Any,
    state: vf.State,
) -> float:
    """POST a scoring request to the server inside the sandbox.

    The request body carries the prompt, completion, answer and the rollout
    state with its ``"sandbox_state"`` entry left out.

    Args:
        reward_name: Name of the reward; used as the last URL path segment.
        prompt: Prompt messages of the rollout.
        completion: Completion messages of the rollout.
        answer: Reference answer forwarded to the reward.
        state: Rollout state; must contain ``"sandbox_id"``.

    Returns:
        The ``"score"`` field of the JSON response, converted to ``float``.

    Raises:
        RuntimeError: If the curl command exits with a non-zero status.
    """
    import shlex  # imported locally: this module has no top-level shlex import

    sandbox_id = state["sandbox_id"]

    payload = json.dumps({
        "prompt": prompt,
        "completion": completion,
        "answer": answer,
        "state": {k: v for k, v in state.items() if k != "sandbox_state"},
    })
    url = f"http://localhost:{self.server_port}/rewards/{reward_name}"

    # shlex.quote replaces the hand-written single-quote escaping and also
    # protects the URL, which embeds the reward name.
    result = await self.sandbox_client.execute_command(
        sandbox_id,
        f"curl -sf -X POST {shlex.quote(url)} "
        f"-H 'Content-Type: application/json' -d {shlex.quote(payload)}",
        timeout=30,
    )

    if result.exit_code != 0:
        raise RuntimeError(f"Reward {reward_name} failed: {result.stderr}")

    data = json.loads(result.stdout)
    return float(data["score"])
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remote rewards fail because sandbox destroyed before scoring. High Severity. The Additional Locations (1) |
||
|
|
||
def update_tool_args(
    self,
    tool_name: str,
    tool_args: dict[str, Any],
    messages: vf.Messages,
    state: vf.State,
    **kwargs,
) -> dict[str, Any]:
    """Inject sandbox routing info into the arguments of remote tool calls.

    For tools registered in ``self.remote_tools`` this adds ``"_sandbox_id"``
    and ``"_state"`` (the state without its ``"sandbox_state"`` entry) to the
    arguments returned by the parent class. Other tools are left as the
    parent class returned them.
    """
    merged = super().update_tool_args(tool_name, tool_args, messages, state, **kwargs)

    if tool_name not in self.remote_tools:
        return merged

    merged["_sandbox_id"] = state["sandbox_id"]
    merged["_state"] = {key: value for key, value in state.items() if key != "sandbox_state"}
    return merged
|
|
||
async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
    """Prepare a rollout and, on first use, discover remote tools and rewards.

    After the parent setup, this waits for the server in the sandbox to
    answer, then registers the tools and rewards it advertises. Discovery
    runs at most once per environment instance.
    """
    state = await super().setup_state(state, **kwargs)
    sandbox_id = state["sandbox_id"]

    await self._wait_for_server(sandbox_id)

    # Review finding (high severity): a race condition causes duplicate tool
    # registration. Discovery awaits between checking and setting
    # `_tools_discovered`, so concurrent rollouts could all pass the check
    # and each register the same tools and rewards. A lock serializes it.
    # The lock is created lazily; there is no await between the lookup and
    # the assignment, so only one lock is ever created per instance.
    # NOTE(review): assumes one event loop per environment instance.
    lock = getattr(self, "_discovery_lock", None)
    if lock is None:
        lock = self._discovery_lock = asyncio.Lock()

    async with lock:
        if not self._tools_discovered:
            tool_specs = await self._discover_tools(sandbox_id)
            self._register_tools(tool_specs)

            reward_specs = await self._discover_rewards(sandbox_id)
            self._register_rewards(reward_specs)

            self._tools_discovered = True

    return state
|
|
||
async def call_tool(
    self,
    tool_name: str,
    tool_args: dict,
    tool_call_id: str,
    **kwargs,
) -> vf.Message:
    """Run a tool call, routing remote tools to their wrappers.

    Tools not in ``self.remote_tools`` are handled by the parent class. For
    remote tools, any exception raised by the call is turned into a tool
    message whose content comes from ``self.error_formatter``.
    """
    if tool_name not in self.remote_tools:
        return await super().call_tool(tool_name, tool_args, tool_call_id, **kwargs)

    try:
        content = str(await self.remote_tools[tool_name](**tool_args))
    except Exception as exc:
        # Failures are reported back as the tool's output instead of raising.
        content = self.error_formatter(exc)

    return cast(
        vf.Message,
        {"role": "tool", "content": content, "tool_call_id": tool_call_id},
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unescaped string interpolation can break download script
Low Severity
The `package_url` and `upload_path` values are directly interpolated into Python code using f-strings without escaping. If either value contains quote characters (particularly double quotes), the generated Python script will have invalid syntax and fail to execute. While the default `upload_path` is safe and API-provided URLs typically don't contain quotes, this could cause hard-to-debug failures in edge cases.