Skip to content
262 changes: 251 additions & 11 deletions packages/prime/src/prime_cli/commands/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,18 +1530,28 @@ def install(
simple_index_url = details.get("simple_index_url")
wheel_url = process_wheel_url(details.get("wheel_url"))

# Check if this is a private environment
# Check if this is a private environment - pull, build, and install from cache
if not simple_index_url and not wheel_url and details.get("visibility") == "PRIVATE":
skipped_envs.append((f"{env_id}@{target_version}", "Private"))
console.print(
f"[yellow]⚠ Skipping {env_id}@{target_version}: Private environment[/yellow]"
)
console.print(
"[dim] Direct installation not available for private environments.[/dim]\n"
"[dim] Please use one of these alternatives:[/dim]\n"
" 1. Use 'prime env pull' to download and install locally\n"
" 2. Make the environment public to enable direct installation"
)
console.print("[dim]Private environment detected, pulling and building...[/dim]")
try:
# Pull, build, and get actual version (resolves "latest" from pyproject.toml)
wheel_path, resolved_version = _pull_and_build_private_env(
client, owner, name, target_version, details
)
if with_tool == "uv":
cmd_parts = ["uv", "pip", "install", str(wheel_path)]
else:
cmd_parts = ["pip", "install", str(wheel_path)]
if not no_upgrade:
cmd_parts.insert(-1, "--upgrade")
installable_envs.append((cmd_parts, env_id, resolved_version, name))
console.print(f"[green]✓ Built {env_id}@{resolved_version}[/green]")
except Exception as e:
failed_envs.append((f"{env_id}@{target_version}", f"Failed to build: {e}"))
console.print(
f"[red]✗ Failed to build private environment {env_id}@{target_version}: "
f"{e}[/red]"
)
continue
elif not simple_index_url and not wheel_url:
skipped_envs.append((f"{env_id}@{target_version}", "No installation method"))
Expand Down Expand Up @@ -1930,6 +1940,236 @@ def delete(
raise typer.Exit(1)


def _safe_tar_extract(tar: tarfile.TarFile, dest_path: Path) -> None:
"""Safely extract tar archive, preventing path traversal and symlink attacks.

Args:
tar: Open tarfile object
dest_path: Destination directory for extraction

Raises:
ValueError: If archive contains unsafe paths, symlinks, or hardlinks
"""
dest_path = dest_path.resolve()

for member in tar.getmembers():
member_path = Path(member.name)

# Block symlinks - they can be used to write outside destination
# (e.g., symlink "evil" -> "/tmp", then file "evil/malicious.txt")
if member.issym():
raise ValueError(f"Refusing to extract symlink: {member.name}")

# Block hardlinks - they can also be used for attacks
if member.islnk():
raise ValueError(f"Refusing to extract hardlink: {member.name}")

# Block absolute paths
if member_path.is_absolute():
raise ValueError(f"Refusing to extract absolute path: {member.name}")

# Block path traversal
if ".." in member_path.parts:
raise ValueError(f"Refusing to extract path with '..': {member.name}")

# Verify resolved path is within destination
target_path = (dest_path / member_path).resolve()
if not target_path.is_relative_to(dest_path):
raise ValueError(f"Path escapes destination directory: {member.name}")

# All members validated, safe to extract
tar.extractall(dest_path)


def _get_env_cache_dir() -> Path:
"""Get the cache directory for private environment wheels."""
cache_dir = Path.home() / ".prime" / "wheel_cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir


def _validate_path_component(component: str, component_name: str) -> None:
"""Validate a path component doesn't contain traversal sequences.

Args:
component: The path component to validate (owner, name, or version)
component_name: Name of the component for error messages

Raises:
ValueError: If component contains unsafe characters
"""
if not component:
raise ValueError(f"{component_name} cannot be empty")

# Block path traversal sequences
if ".." in component:
raise ValueError(f"{component_name} cannot contain '..'")

# Block path separators
if "/" in component or "\\" in component:
raise ValueError(f"{component_name} cannot contain path separators")

# Block null bytes
if "\x00" in component:
raise ValueError(f"{component_name} cannot contain null bytes")


def _get_version_from_pyproject(env_path: Path) -> Optional[str]:
"""Extract version from pyproject.toml in the environment directory."""
pyproject_path = env_path / "pyproject.toml"
if not pyproject_path.exists():
return None
try:
pyproject_data = toml.load(pyproject_path)
return pyproject_data.get("project", {}).get("version")
except Exception:
return None


def _pull_and_build_private_env(
client: APIClient,
owner: str,
name: str,
version: str,
details: Dict[str, Any],
) -> Tuple[Path, str]:
"""Pull a private environment, build it, and return the wheel path and resolved version.

Args:
client: API client with authentication
owner: Environment owner
name: Environment name
version: Environment version (may be "latest")
details: Environment details from API

Returns:
Tuple of (wheel_path, resolved_version)

Raises:
Exception: If download, extraction, or build fails
"""
# Validate path components to prevent directory traversal
_validate_path_component(owner, "owner")
_validate_path_component(name, "name")
_validate_path_component(version, "version")

download_url = details.get("package_url")
if not download_url:
raise ValueError("No downloadable package found for private environment")

cache_dir = _get_env_cache_dir()

# If version is not "latest", check cache directly
if version != "latest":
env_cache_path = cache_dir / owner / name / version
if not env_cache_path.resolve().is_relative_to(cache_dir.resolve()):
raise ValueError("Cache path escapes cache directory")
wheel_cache_path = env_cache_path / "dist"
if wheel_cache_path.exists():
existing_wheels = list(wheel_cache_path.glob("*.whl"))
if existing_wheels:
console.print(f"[dim]Using cached wheel at {existing_wheels[0]}[/dim]")
return existing_wheels[0], version

# Download to temp directory first to determine actual version
temp_extract_dir = None
temp_file_path = None
try:
temp_extract_dir = tempfile.mkdtemp(prefix="prime_env_")
temp_extract_path = Path(temp_extract_dir)

with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
temp_file_path = tmp.name
headers = {}
if client.api_key:
headers["Authorization"] = f"Bearer {client.api_key}"

with httpx.stream("GET", download_url, headers=headers, timeout=60.0) as resp:
resp.raise_for_status()
with open(tmp.name, "wb") as f:
for chunk in resp.iter_bytes(chunk_size=8192):
f.write(chunk)

# Extract to temp path (with path traversal protection)
with tarfile.open(tmp.name, "r:gz") as tar:
_safe_tar_extract(tar, temp_extract_path)

# Get actual version from pyproject.toml
actual_version = _get_version_from_pyproject(temp_extract_path) or version
_validate_path_component(actual_version, "version")

# Now we know the real version - check if it's already cached
env_cache_path = cache_dir / owner / name / actual_version
if not env_cache_path.resolve().is_relative_to(cache_dir.resolve()):
raise ValueError("Cache path escapes cache directory")
wheel_cache_path = env_cache_path / "dist"

if wheel_cache_path.exists():
existing_wheels = list(wheel_cache_path.glob("*.whl"))
if existing_wheels:
console.print(f"[dim]Using cached wheel at {existing_wheels[0]}[/dim]")
return existing_wheels[0], actual_version

# Move extracted content to final cache location
env_cache_path.mkdir(parents=True, exist_ok=True)
for item in temp_extract_path.iterdir():
shutil.move(str(item), str(env_cache_path / item.name))

finally:
if temp_file_path and Path(temp_file_path).exists():
Path(temp_file_path).unlink()
if temp_extract_dir and Path(temp_extract_dir).exists():
shutil.rmtree(temp_extract_dir, ignore_errors=True)

# Build the wheel
console.print("[dim]Building wheel...[/dim]")
try:
if shutil.which("uv"):
subprocess.run(
["uv", "build", "--wheel", "--out-dir", "dist"],
cwd=env_cache_path,
capture_output=True,
text=True,
check=True,
)
else:
subprocess.run(
[sys.executable, "-m", "build", "--wheel", str(env_cache_path)],
capture_output=True,
text=True,
check=True,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Failed to build wheel: {e.stderr}") from e

# Find the built wheel
wheels = list(wheel_cache_path.glob("*.whl"))
if not wheels:
raise RuntimeError("No wheel file found after build")

wheel_path = wheels[0]

# Create metadata file for tracking
try:
prime_dir = env_cache_path / ".prime"
prime_dir.mkdir(exist_ok=True)
metadata_path = prime_dir / ".env-metadata.json"
env_metadata = {
"environment_id": details.get("id"),
"owner": owner,
"name": name,
"version": actual_version,
"cached_at": datetime.now().isoformat(),
"wheel_path": str(wheel_path),
}
with open(metadata_path, "w") as f:
json.dump(env_metadata, f, indent=2)
except Exception:
pass # Non-critical if metadata save fails

return wheel_path, actual_version


def _is_environment_installed(env_name: str, required_version: Optional[str] = None) -> bool:
"""Check if an environment package is installed."""
try:
Expand Down
Loading
Loading