diff --git a/packages/prime/src/prime_cli/commands/env.py b/packages/prime/src/prime_cli/commands/env.py
index 9fa68314..48ac63be 100644
--- a/packages/prime/src/prime_cli/commands/env.py
+++ b/packages/prime/src/prime_cli/commands/env.py
@@ -1,3 +1,4 @@
+import asyncio
 import hashlib
 import json
 import os
@@ -16,6 +17,7 @@
 import httpx
 import toml
 import typer
+from prime_evals import EvalsClient
 from rich.console import Console
 from rich.table import Table
 from rich.text import Text
@@ -2001,10 +2003,135 @@ def _install_single_environment(env_slug: str, tool: str = "uv") -> bool:
         return False
 
 
+async def _run_eval_streaming(
+    env_id: str,
+    model: str,
+    api_key: str,
+    inference_url: str,
+    num_examples: int,
+    rollouts_per_example: int,
+    max_concurrent: int,
+    sampling_args_dict: Dict[str, Any],
+    env_args_dict: Dict[str, Any],
+    save_results: bool,
+    save_every: int,
+    job_id: str,
+    environments: List[Dict[str, str]],
+    team_id: Optional[str],
+    verbose: bool,
+) -> None:
+    from verifiers.types import ClientConfig, EvalConfig, State
+    from verifiers.utils.eval_utils import run_evaluation, save_rollout_results
+
+    api_client = APIClient()
+    evals_client = EvalsClient(api_client)
+
+    eval_name = f"{env_id}--{model}--{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+    eval_metadata = {
+        "framework": "verifiers",
+        "job_id": job_id,
+        "num_examples": num_examples,
+        "rollouts_per_example": rollouts_per_example,
+    }
+
+    console.print("[blue]Creating evaluation on server...[/blue]")
+    create_response = evals_client.create_evaluation(
+        name=eval_name,
+        environments=environments,
+        model_name=model,
+        framework="verifiers",
+        metadata=eval_metadata,
+        is_public=False,
+    )
+    eval_id = create_response.get("evaluation_id")
+    if not eval_id:
+        raise ValueError("Failed to get evaluation ID from server")
+    console.print(f"[green]✓ Created evaluation:[/green] {eval_id}")
+
+    samples_pushed = 0
+
+    async def on_group_complete(states: list[State], completed: int, total: int) -> None:
+        nonlocal samples_pushed
+        samples = []
+        for s in states:
+            prompt = s.get("prompt", [])
+            completion = s.get("completion", [])
+            if isinstance(prompt, list):
+                prompt = [dict(m) if hasattr(m, "items") else m for m in prompt]
+            if isinstance(completion, list):
+                completion = [dict(m) if hasattr(m, "items") else m for m in completion]
+
+            sample = {
+                "example_id": s.get("example_id", 0),
+                "reward": s.get("reward", 0.0),
+                "prompt": prompt,
+                "completion": completion,
+                "task": s.get("task", ""),
+                "generation_ms": s["timing"]["generation_ms"] if s.get("timing") else 0,
+                "scoring_ms": s["timing"]["scoring_ms"] if s.get("timing") else 0,
+            }
+            if s.get("metrics"):
+                for k, v in s["metrics"].items():
+                    sample[k] = v
+            samples.append(sample)
+
+        evals_client.push_samples(eval_id, samples)
+        samples_pushed += len(samples)
+        console.print(f"[dim]Streamed {completed}/{total} groups ({samples_pushed} samples)[/dim]")
+
+    extra_headers = {"X-PI-Job-Id": job_id}
+    if team_id:
+        extra_headers["X-Prime-Team-ID"] = team_id
+
+    client_config = ClientConfig(
+        api_key_var="PRIME_API_KEY",
+        api_base_url=inference_url,
+        extra_headers=extra_headers,
+    )
+
+    os.environ["PRIME_API_KEY"] = api_key
+
+    eval_config = EvalConfig(
+        env_id=env_id,
+        env_args=env_args_dict,
+        env_dir_path="./environments",
+        model=model,
+        client_config=client_config,
+        sampling_args=sampling_args_dict,
+        num_examples=num_examples,
+        rollouts_per_example=rollouts_per_example,
+        max_concurrent=max_concurrent,
+        print_results=True,
+        verbose=verbose,
+        save_results=save_results,
+        save_every=save_every,
+    )
+
+    console.print(f"[blue]Starting evaluation with model: {model}[/blue]")
+    results = await run_evaluation(eval_config, on_group_complete=on_group_complete)
+
+    if save_results:
+        save_rollout_results(results)
+
+    metrics = {f"avg_{k}": sum(v) / len(v) for k, v in results["metrics"].items() if v}
+    metrics["avg_reward"] = (
+        sum(results["reward"]) / len(results["reward"]) if results["reward"] else 0.0
+    )
+
+    console.print("[blue]Finalizing evaluation...[/blue]")
+    evals_client.finalize_evaluation(eval_id, metrics=metrics)
+    console.print("[green]✓ Evaluation finalized[/green]")
+
+    frontend_url = api_client.config.frontend_url
+    eval_url = f"{frontend_url}/dashboard/evaluations/{eval_id}"
+    console.print(f"\n[green]View results at:[/green] {eval_url}")
+
+
 @app.command(
     "eval",
     no_args_is_help=True,
     context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+    deprecated=True,
 )
 def eval_env(
     ctx: typer.Context,
@@ -2155,12 +2282,10 @@ def eval_env(
 
     inference_url = chosen_base
 
-    # Fast fail if the model doesn't exist (only for Prime Inference, not custom URLs)
-    # Check if using Prime Inference URL (either from config or explicitly provided)
     if chosen_base == inference_base_url:
-        client = InferenceClient()
+        inf_client = InferenceClient()
         try:
-            client.retrieve_model(model)
+            inf_client.retrieve_model(model)
        except InferenceAPIError as e:
             console.print(
                 f"[red]Invalid model:[/red] {e} \n\n"
@@ -2168,107 +2293,151 @@ def eval_env(
             )
             raise typer.Exit(1)
 
-    cmd = ["uv", "run", "vf-eval", env_name_for_vf_eval]
-
-    # Add chosen inference url
-    cmd += ["-b", inference_url]
-
-    # Always pass the selected model (required option)
-    cmd += ["-m", model]
-
-    # Environment modification may be necessary for passing in API key
-    env = os.environ.copy()
-
-    # API key var: respect --api-key-var if provided to this command, else inject PRIME_API_KEY
-    if api_key_var:
-        cmd += ["-k", api_key_var]
-    else:
-        env["PRIME_API_KEY"] = api_key
-        cmd += ["-k", "PRIME_API_KEY"]
-
-    # Forward vf-eval options if provided here
-    if env_args:
-        cmd += ["-a", env_args]
-    if num_examples is not None:
-        cmd += ["-n", str(num_examples)]
-    if rollouts_per_example is not None:
-        cmd += ["-r", str(rollouts_per_example)]
-    if max_concurrent is not None:
-        cmd += ["-c", str(max_concurrent)]
-    if max_tokens is not None:
-        cmd += ["-t", str(max_tokens)]
-    if temperature is not None:
-        cmd += ["-T", str(temperature)]
-    if sampling_args:
-        cmd += ["-S", sampling_args]
-    if verbose:
-        cmd += ["-v"]
-    if save_results:
-        cmd += ["-s"]
-    if save_every is not None:
-        cmd += ["-f", str(save_every)]
-    if save_to_hf_hub:
-        cmd += ["-H"]
-    if hf_hub_dataset_name:
-        cmd += ["-D", hf_hub_dataset_name]
-
-    # Generate job_id for end-to-end tracing of eval runs
     eval_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     job_uuid = str(uuid.uuid4())[:8]
     sanitized_env = env_name_for_vf_eval.replace("-", "_").replace("/", "_")
     sanitized_model = model.replace("/", "_").replace("-", "_")
     job_id = f"{sanitized_env}_{sanitized_model}_{eval_timestamp}_{job_uuid}"
-    # Pass tracking header to vf-eval
-    cmd += ["--header", f"X-PI-Job-Id: {job_id}"]
-
-    # If a team is configured, pass it to vf-eval via header
-    if config.team_id:
-        cmd += ["--header", f"X-Prime-Team-ID: {config.team_id}"]
-
     console.print(f"[dim]Eval job_id: {job_id}[/dim]")
 
-    # Execute; stream output directly
-    try:
-        result = subprocess.run(cmd, env=env)
-        if result.returncode != 0:
-            raise typer.Exit(result.returncode)
-    except KeyboardInterrupt:
-        raise typer.Exit(130)
-    except FileNotFoundError:
-        console.print("[red]Failed to start vf-eval process.[/red]")
-        raise typer.Exit(1)
+    use_streaming = is_resolved and not skip_upload and not api_key_var
 
-    # Automatically push to hub after successful eval (unless --skip-upload is used)
-    if not skip_upload:
-        if is_resolved:
-            try:
-                if is_slug and upstream_owner and upstream_name:
-                    push_eval_results_to_hub(
-                        env_name=env_name_for_vf_eval,
-                        model=model,
-                        job_id=job_id,
-                        env_path=Path(env_path) if env_path else None,
-                        upstream_slug=f"{upstream_owner}/{upstream_name}",
-                    )
-                else:
-                    check_path = Path(env_path) if env_path else Path.cwd()
-                    push_eval_results_to_hub(
-                        env_name=env_name_for_vf_eval,
-                        model=model,
-                        job_id=job_id,
-                        env_path=check_path,
-                    )
-            except Exception as e:
-                console.print(f"[red]Failed to push results to hub:[/red] {e}")
-                console.print("[yellow]Evaluation completed but results were not pushed.[/yellow]")
-                raise typer.Exit(1)
+    if use_streaming:
+        if is_slug and upstream_owner and upstream_name:
+            environments: List[Dict[str, str]] = [{"slug": f"{upstream_owner}/{upstream_name}"}]
         else:
-            console.print(
-                "[dim]No upstream environment found. Skipped uploading evaluation "
-                "results to platform.\nUse `prime env push` to set an "
-                "upstream, or use `--env-path` to specify the correct path to the "
-                "environment if it's not the current directory.[/dim]"
+            hub_metadata = find_environment_metadata(
+                env_name=env_name_for_vf_eval,
+                env_path=Path(env_path) if env_path else Path.cwd(),
+                module_name=env_name_for_vf_eval.replace("-", "_"),
+            )
+            if hub_metadata and hub_metadata.get("owner") and hub_metadata.get("name"):
+                environments = [{"slug": f"{hub_metadata['owner']}/{hub_metadata['name']}"}]
+            else:
+                environments = [{"name": env_name_for_vf_eval}]
+
+        sampling_args_dict: Dict[str, Any] = {}
+        if sampling_args:
+            sampling_args_dict = json.loads(sampling_args)
+        if max_tokens is not None:
+            sampling_args_dict["max_tokens"] = max_tokens
+        if temperature is not None:
+            sampling_args_dict["temperature"] = temperature
+
+        env_args_dict: Dict[str, Any] = {}
+        if env_args:
+            env_args_dict = json.loads(env_args)
+
+        try:
+            asyncio.run(
+                _run_eval_streaming(
+                    env_id=env_name_for_vf_eval,
+                    model=model,
+                    api_key=api_key,
+                    inference_url=inference_url,
+                    num_examples=num_examples or 5,
+                    rollouts_per_example=rollouts_per_example or 3,
+                    max_concurrent=max_concurrent or 32,
+                    sampling_args_dict=sampling_args_dict,
+                    env_args_dict=env_args_dict,
+                    save_results=save_results,
+                    save_every=save_every,
+                    job_id=job_id,
+                    environments=environments,
+                    team_id=config.team_id,
+                    verbose=verbose,
+                )
             )
+        except KeyboardInterrupt:
+            raise typer.Exit(130)
+        except Exception as e:
+            console.print(f"[red]Evaluation failed:[/red] {e}")
+            raise typer.Exit(1)
     else:
-        console.print("[dim]Skipped uploading evaluation results[/dim]")
+        cmd = ["uv", "run", "vf-eval", env_name_for_vf_eval]
+        cmd += ["-b", inference_url]
+        cmd += ["-m", model]
+
+        env = os.environ.copy()
+
+        if api_key_var:
+            cmd += ["-k", api_key_var]
+        else:
+            env["PRIME_API_KEY"] = api_key
+            cmd += ["-k", "PRIME_API_KEY"]
+
+        if env_args:
+            cmd += ["-a", env_args]
+        if num_examples is not None:
+            cmd += ["-n", str(num_examples)]
+        if rollouts_per_example is not None:
+            cmd += ["-r", str(rollouts_per_example)]
+        if max_concurrent is not None:
+            cmd += ["-c", str(max_concurrent)]
+        if max_tokens is not None:
+            cmd += ["-t", str(max_tokens)]
+        if temperature is not None:
+            cmd += ["-T", str(temperature)]
+        if sampling_args:
+            cmd += ["-S", sampling_args]
+        if verbose:
+            cmd += ["-v"]
+        if save_results:
+            cmd += ["-s"]
+        if save_every is not None:
+            cmd += ["-f", str(save_every)]
+        if save_to_hf_hub:
+            cmd += ["-H"]
+        if hf_hub_dataset_name:
+            cmd += ["-D", hf_hub_dataset_name]
+
+        cmd += ["--header", f"X-PI-Job-Id: {job_id}"]
+
+        if config.team_id:
+            cmd += ["--header", f"X-Prime-Team-ID: {config.team_id}"]
+
+        try:
+            result = subprocess.run(cmd, env=env)
+            if result.returncode != 0:
+                raise typer.Exit(result.returncode)
+        except KeyboardInterrupt:
+            raise typer.Exit(130)
+        except FileNotFoundError:
+            console.print("[red]Failed to start vf-eval process.[/red]")
+            raise typer.Exit(1)
+
+        if not skip_upload:
+            if is_resolved:
+                try:
+                    if is_slug and upstream_owner and upstream_name:
+                        push_eval_results_to_hub(
+                            env_name=env_name_for_vf_eval,
+                            model=model,
+                            job_id=job_id,
+                            env_path=Path(env_path) if env_path else None,
+                            upstream_slug=f"{upstream_owner}/{upstream_name}",
+                        )
+                    else:
+                        check_path = Path(env_path) if env_path else Path.cwd()
+                        push_eval_results_to_hub(
+                            env_name=env_name_for_vf_eval,
+                            model=model,
+                            job_id=job_id,
+                            env_path=check_path,
+                        )
+                except Exception as e:
+                    console.print(f"[red]Failed to push results to hub:[/red] {e}")
+                    console.print(
+                        "[yellow]Evaluation completed but results were not pushed.[/yellow]"
+                    )
+                    raise typer.Exit(1)
+            else:
+                console.print(
+                    "[dim]No upstream environment found. Skipped uploading evaluation "
+                    "results to platform.\nUse `prime env push` to set an "
+                    "upstream, or use `--env-path` to specify the correct path to the "
+                    "environment if it's not the current directory.[/dim]"
+                )
+        else:
+            console.print("[dim]Skipped uploading evaluation results[/dim]")