diff --git a/packages/prime/src/prime_cli/commands/env.py b/packages/prime/src/prime_cli/commands/env.py index 9ed18e84..ec41bc3c 100644 --- a/packages/prime/src/prime_cli/commands/env.py +++ b/packages/prime/src/prime_cli/commands/env.py @@ -1530,18 +1530,28 @@ def install( simple_index_url = details.get("simple_index_url") wheel_url = process_wheel_url(details.get("wheel_url")) - # Check if this is a private environment + # Check if this is a private environment - pull, build, and install from cache if not simple_index_url and not wheel_url and details.get("visibility") == "PRIVATE": - skipped_envs.append((f"{env_id}@{target_version}", "Private")) - console.print( - f"[yellow]⚠ Skipping {env_id}@{target_version}: Private environment[/yellow]" - ) - console.print( - "[dim] Direct installation not available for private environments.[/dim]\n" - "[dim] Please use one of these alternatives:[/dim]\n" - " 1. Use 'prime env pull' to download and install locally\n" - " 2. Make the environment public to enable direct installation" - ) + console.print("[dim]Private environment detected, pulling and building...[/dim]") + try: + # Pull, build, and get actual version (resolves "latest" from pyproject.toml) + wheel_path, resolved_version = _pull_and_build_private_env( + client, owner, name, target_version, details + ) + if with_tool == "uv": + cmd_parts = ["uv", "pip", "install", str(wheel_path)] + else: + cmd_parts = ["pip", "install", str(wheel_path)] + if not no_upgrade: + cmd_parts.insert(-1, "--upgrade") + installable_envs.append((cmd_parts, env_id, resolved_version, name)) + console.print(f"[green]✓ Built {env_id}@{resolved_version}[/green]") + except Exception as e: + failed_envs.append((f"{env_id}@{target_version}", f"Failed to build: {e}")) + console.print( + f"[red]✗ Failed to build private environment {env_id}@{target_version}: " + f"{e}[/red]" + ) continue elif not simple_index_url and not wheel_url: skipped_envs.append((f"{env_id}@{target_version}", "No 
def _safe_tar_extract(tar: tarfile.TarFile, dest_path: Path) -> None:
    """Extract a tar archive, refusing anything that could write outside ``dest_path``.

    Every member is validated before a single byte is written, so a rejected
    archive leaves the destination untouched.

    Args:
        tar: Open tarfile object.
        dest_path: Destination directory for extraction.

    Raises:
        ValueError: If the archive contains symlinks, hardlinks, special files
            (device nodes, FIFOs), absolute paths, ``..`` components, or a
            path that resolves outside ``dest_path``.
    """
    dest_path = dest_path.resolve()
    members = tar.getmembers()

    for member in members:
        member_path = Path(member.name)

        # Block symlinks - they can be used to write outside destination
        # (e.g., symlink "evil" -> "/tmp", then file "evil/malicious.txt")
        if member.issym():
            raise ValueError(f"Refusing to extract symlink: {member.name}")

        # Block hardlinks - they can also be used for attacks
        if member.islnk():
            raise ValueError(f"Refusing to extract hardlink: {member.name}")

        # A source package only needs regular files and directories; device
        # nodes and FIFOs have no place in it and must never be created.
        if not (member.isfile() or member.isdir()):
            raise ValueError(f"Refusing to extract special file: {member.name}")

        # Block absolute paths
        if member_path.is_absolute():
            raise ValueError(f"Refusing to extract absolute path: {member.name}")

        # Block path traversal
        if ".." in member_path.parts:
            raise ValueError(f"Refusing to extract path with '..': {member.name}")

        # Verify resolved path is within destination
        target_path = (dest_path / member_path).resolve()
        if not target_path.is_relative_to(dest_path):
            raise ValueError(f"Path escapes destination directory: {member.name}")

    # All members validated, safe to extract. Where the interpreter supports
    # extraction filters, also apply the stdlib "data" filter as a second line
    # of defence; this avoids the no-filter DeprecationWarning on Python 3.12+.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(dest_path, members=members, filter="data")
    else:
        tar.extractall(dest_path, members=members)


def _get_env_cache_dir() -> Path:
    """Return the cache directory for private environment wheels, creating it if needed.

    The directory is ``~/.prime/wheel_cache`` under the current user's home.
    """
    cache_dir = Path.home() / ".prime" / "wheel_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
def _validate_path_component(component: str, component_name: str) -> None:
    """Validate that a value is safe to use as a single directory name.

    The owner, name and version are joined into the cache path, so each one
    must name exactly one directory level.

    Args:
        component: The path component to validate (owner, name, or version).
        component_name: Name of the component for error messages.

    Raises:
        ValueError: If the component is empty, is not a string, is ``.``, or
            contains ``..``, a path separator, or a null byte.
    """
    if not component:
        raise ValueError(f"{component_name} cannot be empty")

    # A non-string (e.g. a numeric version parsed from TOML) would make the
    # substring checks below raise TypeError instead of the documented
    # ValueError.
    if not isinstance(component, str):
        raise ValueError(f"{component_name} must be a string")

    # Block path traversal sequences
    if ".." in component:
        raise ValueError(f"{component_name} cannot contain '..'")

    # "." names the parent level itself, so it would silently merge two cache
    # levels into one directory.
    if component == ".":
        raise ValueError(f"{component_name} cannot be '.'")

    # Block path separators
    if "/" in component or "\\" in component:
        raise ValueError(f"{component_name} cannot contain path separators")

    # Block null bytes
    if "\x00" in component:
        raise ValueError(f"{component_name} cannot contain null bytes")


def _get_version_from_pyproject(env_path: Path) -> Optional[str]:
    """Read ``project.version`` from ``pyproject.toml`` in the environment directory.

    Args:
        env_path: Directory expected to contain ``pyproject.toml``.

    Returns:
        The version string, or None when the file is missing, cannot be
        parsed, has no ``[project]`` table, or has a missing, empty or
        non-string version.
    """
    pyproject_path = env_path / "pyproject.toml"
    if not pyproject_path.exists():
        return None

    try:
        pyproject_data = toml.load(pyproject_path)
    except Exception:
        # Best effort: an unreadable or malformed file just means the version
        # is unknown, and the caller falls back to the requested version.
        return None

    project = pyproject_data.get("project")
    if not isinstance(project, dict):
        return None

    version = project.get("version")
    # Only a non-empty string is usable as a cache directory name.
    if isinstance(version, str) and version:
        return version
    return None
def _find_cached_wheel(wheel_dir: Path) -> Optional[Path]:
    """Return a wheel from ``wheel_dir``, or None if the directory holds none."""
    if not wheel_dir.is_dir():
        return None
    # Sorted so the choice is deterministic if several wheels are present.
    wheels = sorted(wheel_dir.glob("*.whl"))
    return wheels[0] if wheels else None


def _pull_and_build_private_env(
    client: APIClient,
    owner: str,
    name: str,
    version: str,
    details: Dict[str, Any],
) -> Tuple[Path, str]:
    """Pull a private environment, build it, and return the wheel path and resolved version.

    Sources are cached under ``<cache>/<owner>/<name>/<version>/`` and the
    wheel under its ``dist/`` subdirectory. A wheel already present there is
    reused instead of rebuilding. For ``"latest"`` the archive is always
    downloaded first, because the real version is only known once its
    ``pyproject.toml`` has been read.

    Args:
        client: API client with authentication
        owner: Environment owner
        name: Environment name
        version: Environment version (may be "latest")
        details: Environment details from API; ``package_url`` is the archive
            to download and ``id`` is recorded in the metadata file

    Returns:
        Tuple of (wheel_path, resolved_version)

    Raises:
        ValueError: If a path component is unsafe, no ``package_url`` is
            available, or the cache path escapes the cache directory
        RuntimeError: If the build fails or produces no wheel
        Exception: Download and extraction errors propagate unchanged
    """
    # Validate path components to prevent directory traversal
    _validate_path_component(owner, "owner")
    _validate_path_component(name, "name")
    _validate_path_component(version, "version")

    download_url = details.get("package_url")
    if not download_url:
        raise ValueError("No downloadable package found for private environment")

    cache_dir = _get_env_cache_dir()
    cache_root = cache_dir.resolve()

    # If version is not "latest", check cache directly
    if version != "latest":
        env_cache_path = cache_dir / owner / name / version
        if not env_cache_path.resolve().is_relative_to(cache_root):
            raise ValueError("Cache path escapes cache directory")
        cached_wheel = _find_cached_wheel(env_cache_path / "dist")
        if cached_wheel is not None:
            console.print(f"[dim]Using cached wheel at {cached_wheel}[/dim]")
            return cached_wheel, version

    # Download to temp directory first to determine actual version
    temp_extract_dir = None
    temp_file_path = None
    try:
        temp_extract_dir = tempfile.mkdtemp(prefix="prime_env_")
        temp_extract_path = Path(temp_extract_dir)

        # NOTE(review): the API key is sent to whatever host package_url
        # points at - confirm the API only ever returns first-party URLs.
        headers = {}
        if client.api_key:
            headers["Authorization"] = f"Bearer {client.api_key}"

        # Write through the handle we already hold rather than reopening the
        # same path a second time; the file is closed before it is read back.
        with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
            temp_file_path = tmp.name
            with httpx.stream("GET", download_url, headers=headers, timeout=60.0) as resp:
                resp.raise_for_status()
                for chunk in resp.iter_bytes(chunk_size=8192):
                    tmp.write(chunk)

        # Extract to temp path (with path traversal protection)
        with tarfile.open(temp_file_path, "r:gz") as tar:
            _safe_tar_extract(tar, temp_extract_path)

        # Get actual version from pyproject.toml
        actual_version = _get_version_from_pyproject(temp_extract_path) or version
        _validate_path_component(actual_version, "version")

        # Now we know the real version - check if it's already cached
        env_cache_path = cache_dir / owner / name / actual_version
        if not env_cache_path.resolve().is_relative_to(cache_root):
            raise ValueError("Cache path escapes cache directory")
        wheel_cache_path = env_cache_path / "dist"

        cached_wheel = _find_cached_wheel(wheel_cache_path)
        if cached_wheel is not None:
            console.print(f"[dim]Using cached wheel at {cached_wheel}[/dim]")
            return cached_wheel, actual_version

        # No wheel here, so anything already in this directory is left over
        # from an earlier failed build. Clear it first: moving onto existing
        # entries would nest directories or raise.
        if env_cache_path.exists():
            shutil.rmtree(env_cache_path)

        # Move extracted content to final cache location
        env_cache_path.mkdir(parents=True, exist_ok=True)
        for item in temp_extract_path.iterdir():
            shutil.move(str(item), str(env_cache_path / item.name))

    finally:
        if temp_file_path and Path(temp_file_path).exists():
            Path(temp_file_path).unlink()
        if temp_extract_dir and Path(temp_extract_dir).exists():
            shutil.rmtree(temp_extract_dir, ignore_errors=True)

    # Build the wheel
    console.print("[dim]Building wheel...[/dim]")
    try:
        if shutil.which("uv"):
            subprocess.run(
                ["uv", "build", "--wheel", "--out-dir", "dist"],
                cwd=env_cache_path,
                capture_output=True,
                text=True,
                check=True,
            )
        else:
            subprocess.run(
                [sys.executable, "-m", "build", "--wheel", str(env_cache_path)],
                capture_output=True,
                text=True,
                check=True,
            )
    except subprocess.CalledProcessError as e:
        # Fall back to stdout so the message is not empty when the build tool
        # reports its error there.
        build_output = (e.stderr or e.stdout or "").strip()
        raise RuntimeError(f"Failed to build wheel: {build_output}") from e

    # Find the built wheel
    wheel_path = _find_cached_wheel(wheel_cache_path)
    if wheel_path is None:
        raise RuntimeError("No wheel file found after build")

    # Create metadata file for tracking
    try:
        prime_dir = env_cache_path / ".prime"
        prime_dir.mkdir(exist_ok=True)
        metadata_path = prime_dir / ".env-metadata.json"
        env_metadata = {
            "environment_id": details.get("id"),
            "owner": owner,
            "name": name,
            "version": actual_version,
            "cached_at": datetime.now().isoformat(),
            "wheel_path": str(wheel_path),
        }
        with open(metadata_path, "w") as f:
            json.dump(env_metadata, f, indent=2)
    except (OSError, TypeError, ValueError):
        pass  # Non-critical if metadata save fails

    return wheel_path, actual_version
# Applied to every test that needs real credentials to reach the private env.
requires_api_key = pytest.mark.skipif(
    not os.environ.get("PRIME_API_KEY"),
    reason="PRIME_API_KEY not set - required for private env access",
)


def _subprocess_env(home: Path) -> dict:
    """Build the child-process environment with HOME redirected to ``home``."""
    return {
        **os.environ,
        "HOME": str(home),
        "PRIME_API_KEY": os.environ.get("PRIME_API_KEY", ""),
    }


def _run_prime_install(home: Path) -> subprocess.CompletedProcess:
    """Run ``prime env install`` for the test environment with ``home`` as HOME."""
    return subprocess.run(
        ["uv", "run", "prime", "env", "install", f"{ENV_OWNER}/{ENV_NAME}"],
        capture_output=True,
        text=True,
        timeout=300,
        env=_subprocess_env(home),
    )


@pytest.fixture
def temp_home(tmp_path: Path, monkeypatch):
    """Temporarily set HOME to a temp directory for cache isolation.

    ``monkeypatch`` restores the original HOME even if the cleanup below
    raises (for example when ``uv`` is not installed).
    """
    monkeypatch.setenv("HOME", str(tmp_path))

    yield tmp_path

    # Cleanup: uninstall after tests.
    # NOTE(review): the result is never checked, so a failed uninstall goes
    # unnoticed - confirm that `uv pip uninstall` accepts `-y`.
    subprocess.run(
        ["uv", "pip", "uninstall", ENV_NAME.replace("-", "_"), "-y"],
        capture_output=True,
    )


class TestPrivateEnvInstall:
    """Tests for private environment installation."""

    @requires_api_key
    def test_install_private_env_creates_cache(self, temp_home: Path):
        """Test that installing a private env creates the correct cache structure."""
        result = _run_prime_install(temp_home)

        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")

        assert result.returncode == 0, f"Install failed: {result.stderr}\n{result.stdout}"

        # Verify cache structure: ~/.prime/wheel_cache/{owner}/{name}/{version}/
        envs_cache = temp_home / ".prime" / "wheel_cache"
        assert envs_cache.exists(), "Cache directory ~/.prime/wheel_cache/ not created"

        owner_dir = envs_cache / ENV_OWNER
        assert owner_dir.exists(), f"Owner directory not created: {owner_dir}"

        name_dir = owner_dir / ENV_NAME
        assert name_dir.exists(), f"Environment directory not created: {name_dir}"

        # Should have at least one version directory
        version_dirs = [d for d in name_dir.iterdir() if d.is_dir()]
        assert len(version_dirs) > 0, "No version directories found"

        version_dir = version_dirs[0]
        # Note: version may be "latest" if API doesn't return semantic_version
        # The important thing is the cache exists and wheel was built

        # Verify wheel was built
        dist_dir = version_dir / "dist"
        assert dist_dir.exists(), f"dist/ directory not created: {dist_dir}"

        wheels = list(dist_dir.glob("*.whl"))
        assert len(wheels) > 0, "No wheel file found in dist/"

        # Verify metadata was saved
        metadata_path = version_dir / ".prime" / ".env-metadata.json"
        assert metadata_path.exists(), f"Metadata file not created: {metadata_path}"

    @requires_api_key
    def test_installed_private_env_can_be_loaded(self, temp_home: Path):
        """Test that an installed private env can be loaded by verifiers."""
        install_result = _run_prime_install(temp_home)

        assert install_result.returncode == 0, (
            f"Install failed: {install_result.stderr}\n{install_result.stdout}"
        )

        # Try to load the environment using verifiers, both with and without the owner/name
        load_script = f"""
import sys
try:
    from verifiers import load_environment
    env = load_environment('{ENV_NAME.replace("-", "_")}')
    env = load_environment('{ENV_OWNER}/{ENV_NAME}')
    print(f"Successfully loaded: {{type(env).__name__}}")
    sys.exit(0)
except ImportError as e:
    print(f"Import error: {{e}}")
    sys.exit(1)
except Exception as e:
    print(f"Load error: {{e}}")
    sys.exit(1)
"""

        load_result = subprocess.run(
            ["uv", "run", "python", "-c", load_script],
            capture_output=True,
            text=True,
            timeout=60,
            env=_subprocess_env(temp_home),
        )

        print(f"Load stdout: {load_result.stdout}")
        print(f"Load stderr: {load_result.stderr}")

        assert load_result.returncode == 0, (
            f"Failed to load environment: {load_result.stderr}\n{load_result.stdout}"
        )
        assert "Successfully loaded" in load_result.stdout

    @requires_api_key
    def test_cached_wheel_is_reused(self, temp_home: Path):
        """Test that subsequent installs reuse the cached wheel."""
        result1 = _run_prime_install(temp_home)
        assert result1.returncode == 0, f"First install failed: {result1.stderr}"

        # Get the wheel modification time
        envs_cache = temp_home / ".prime" / "wheel_cache" / ENV_OWNER / ENV_NAME
        # Only directories are version caches; ignore any stray files.
        version_dirs = [d for d in envs_cache.iterdir() if d.is_dir()]
        assert len(version_dirs) > 0, f"No version dirs in {envs_cache}"
        wheel_files = list((version_dirs[0] / "dist").glob("*.whl"))
        assert len(wheel_files) > 0, f"No wheels in {version_dirs[0] / 'dist'}"
        wheel_mtime_1 = wheel_files[0].stat().st_mtime

        # Second install (should reuse cache)
        result2 = _run_prime_install(temp_home)
        assert result2.returncode == 0, f"Second install failed: {result2.stderr}"

        # Verify cache message appears ("Using cached wheel at ..." contains it)
        assert "Using cached" in result2.stdout, (
            f"Expected cache reuse message in output: {result2.stdout}"
        )

        # Verify wheel wasn't rebuilt (same mtime)
        wheel_mtime_2 = wheel_files[0].stat().st_mtime
        assert wheel_mtime_1 == wheel_mtime_2, "Wheel was rebuilt instead of reusing cache"
class TestPathComponentValidation:
    """Checks that unsafe owner/name/version values are refused before they reach the cache path."""

    def test_validate_blocks_double_dot(self):
        """A component containing '..' must raise ValueError."""
        from prime_cli.commands.env import _validate_path_component as validate

        for bad_value, label in (("..", "owner"), ("foo/../bar", "name")):
            with pytest.raises(ValueError, match="cannot contain '\\.\\.'"):
                validate(bad_value, label)

    def test_validate_blocks_path_separators(self):
        """Forward and backward slashes must raise ValueError."""
        from prime_cli.commands.env import _validate_path_component as validate

        for bad_value in ("foo/bar", "foo\\bar"):
            with pytest.raises(ValueError, match="path separators"):
                validate(bad_value, "version")

    def test_validate_blocks_empty(self):
        """An empty component must raise ValueError."""
        from prime_cli.commands.env import _validate_path_component as validate

        with pytest.raises(ValueError, match="cannot be empty"):
            validate("", "owner")

    def test_validate_blocks_null_bytes(self):
        """A component containing a null byte must raise ValueError."""
        from prime_cli.commands.env import _validate_path_component as validate

        with pytest.raises(ValueError, match="null bytes"):
            validate("foo\x00bar", "name")

    def test_validate_allows_normal_components(self):
        """Ordinary owners, names and versions pass without raising."""
        from prime_cli.commands.env import _validate_path_component as validate

        accepted = (
            ("primeintellect", "owner"),
            ("my-environment", "name"),
            ("1.0.0", "version"),
            ("latest", "version"),
            ("v2.3.4-beta.1", "version"),
        )
        # None of these should raise
        for value, label in accepted:
            validate(value, label)
class TestSafeTarExtract:
    """Checks that tar extraction refuses archives that could write outside the destination."""

    @staticmethod
    def _write_archive(tmp_path: Path, filename: str, entries) -> Path:
        """Write a gzip tarball holding ``entries`` (TarInfo objects) and return its path."""
        import tarfile

        archive_path = tmp_path / filename
        with tarfile.open(archive_path, "w:gz") as archive:
            for entry in entries:
                archive.addfile(entry)
        return archive_path

    @staticmethod
    def _make_dest(tmp_path: Path) -> Path:
        """Create and return the extraction target directory."""
        target = tmp_path / "dest"
        target.mkdir()
        return target

    def test_safe_extract_blocks_absolute_path(self, tmp_path: Path):
        """An archive member with an absolute path must be rejected."""
        import tarfile

        from prime_cli.commands.env import _safe_tar_extract

        entry = tarfile.TarInfo(name="/etc/malicious")
        entry.size = 0
        archive_path = self._write_archive(tmp_path, "malicious.tar.gz", [entry])
        target = self._make_dest(tmp_path)

        with tarfile.open(archive_path, "r:gz") as archive:
            with pytest.raises(ValueError, match="absolute path"):
                _safe_tar_extract(archive, target)

    def test_safe_extract_blocks_path_traversal(self, tmp_path: Path):
        """An archive member containing '..' must be rejected."""
        import tarfile

        from prime_cli.commands.env import _safe_tar_extract

        entry = tarfile.TarInfo(name="../../../etc/malicious")
        entry.size = 0
        archive_path = self._write_archive(tmp_path, "malicious.tar.gz", [entry])
        target = self._make_dest(tmp_path)

        with tarfile.open(archive_path, "r:gz") as archive:
            with pytest.raises(ValueError, match="'\\.\\.'"):
                _safe_tar_extract(archive, target)

    def test_safe_extract_allows_normal_paths(self, tmp_path: Path):
        """Plain relative file paths are extracted into the destination."""
        import tarfile

        from prime_cli.commands.env import _safe_tar_extract

        entries = []
        for member_name in ("file.txt", "subdir/nested.txt"):
            entry = tarfile.TarInfo(name=member_name)
            entry.size = 0
            entries.append(entry)
        archive_path = self._write_archive(tmp_path, "normal.tar.gz", entries)
        target = self._make_dest(tmp_path)

        with tarfile.open(archive_path, "r:gz") as archive:
            _safe_tar_extract(archive, target)  # Should not raise

        assert (target / "file.txt").exists()
        assert (target / "subdir" / "nested.txt").exists()

    def test_safe_extract_blocks_symlinks(self, tmp_path: Path):
        """A symlink member must be rejected (prevents symlink attacks)."""
        import tarfile

        from prime_cli.commands.env import _safe_tar_extract

        # Symlink pointing outside the destination
        entry = tarfile.TarInfo(name="evil_link")
        entry.type = tarfile.SYMTYPE
        entry.linkname = "/tmp"
        archive_path = self._write_archive(tmp_path, "malicious.tar.gz", [entry])
        target = self._make_dest(tmp_path)

        with tarfile.open(archive_path, "r:gz") as archive:
            with pytest.raises(ValueError, match="symlink"):
                _safe_tar_extract(archive, target)

    def test_safe_extract_blocks_hardlinks(self, tmp_path: Path):
        """A hardlink member must be rejected."""
        import tarfile

        from prime_cli.commands.env import _safe_tar_extract

        entry = tarfile.TarInfo(name="evil_hardlink")
        entry.type = tarfile.LNKTYPE
        entry.linkname = "/etc/passwd"
        archive_path = self._write_archive(tmp_path, "malicious.tar.gz", [entry])
        target = self._make_dest(tmp_path)

        with tarfile.open(archive_path, "r:gz") as archive:
            with pytest.raises(ValueError, match="hardlink"):
                _safe_tar_extract(archive, target)