diff --git a/README.md b/README.md
index bf9a4eeed..7e94d6535 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,33 @@ TrainerClient().wait_for_job_status(job_id)
 print("\n".join(TrainerClient().get_job_logs(name=job_id)))
 ```
 
+## Local Development
+
+The Kubeflow Trainer client supports local development, so you can iterate on training code without a Kubernetes cluster.
+
+### Available Backends
+
+- **KubernetesBackend** (default) - Production training on Kubernetes
+- **ContainerBackend** - Local development with Docker/Podman isolation
+- **LocalProcessBackend** - Quick prototyping with Python subprocesses
+
+**Quick Start:**
+Install container support: `pip install kubeflow[docker]` or `pip install kubeflow[podman]`.
+
+```python
+from kubeflow.trainer import TrainerClient, ContainerBackendConfig, CustomTrainer
+
+# Switch to local container execution
+client = TrainerClient(backend_config=ContainerBackendConfig())
+
+# Your training runs locally in isolated containers
+job_id = client.train(trainer=CustomTrainer(func=train_fn))
+
+# Wait for completion and read logs, just like on Kubernetes
+client.wait_for_job_status(job_id)
+print("\n".join(client.get_job_logs(name=job_id)))
+```
+
 ## Supported Kubeflow Projects
 
 | Project | Status | Version Support | Description |
diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py
index 1fce8e0f5..a61867b6a 100644
--- a/kubeflow/trainer/__init__.py
+++ b/kubeflow/trainer/__init__.py
@@ -18,6 +18,10 @@
 # Import the Kubeflow Trainer client.
 from kubeflow.trainer.api.trainer_client import TrainerClient
+from kubeflow.trainer.backends.container.types import (
+    ContainerBackendConfig,
+    TrainingRuntimeSource,
+)
 from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
 
 # Import the Kubeflow Trainer constants.
@@ -64,5 +68,7 @@
     "TrainerClient",
     "TrainerType",
     "LocalProcessBackendConfig",
+    "ContainerBackendConfig",
     "KubernetesBackendConfig",
+    "TrainingRuntimeSource",
 ]
diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py
index cb96837dd..79613f541 100644
--- a/kubeflow/trainer/api/trainer_client.py
+++ b/kubeflow/trainer/api/trainer_client.py
@@ -17,6 +17,8 @@
 from typing import Optional, Union
 
 from kubeflow.common.types import KubernetesBackendConfig
+from kubeflow.trainer.backends.container.backend import ContainerBackend
+from kubeflow.trainer.backends.container.types import ContainerBackendConfig
 from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
 from kubeflow.trainer.backends.localprocess.backend import (
     LocalProcessBackend,
@@ -31,14 +33,28 @@ class TrainerClient:
 
     def __init__(
         self,
-        backend_config: Optional[Union[KubernetesBackendConfig, LocalProcessBackendConfig]] = None,
+        backend_config: Optional[
+            Union[
+                KubernetesBackendConfig,
+                LocalProcessBackendConfig,
+                ContainerBackendConfig,
+            ]
+        ] = None,
     ):
         """Initialize a Kubeflow Trainer client.
 
         Args:
-            backend_config: Backend configuration. Either KubernetesBackendConfig or
-                LocalProcessBackendConfig, or None to use the backend's
-                default config class. Defaults to KubernetesBackendConfig.
+            backend_config: Backend configuration. Either KubernetesBackendConfig,
+                LocalProcessBackendConfig, ContainerBackendConfig,
+                or None to use the backend's default config class.
+                Defaults to KubernetesBackendConfig.
 
         Raises:
             ValueError: Invalid backend configuration.
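+
+        Example:
+            A minimal sketch of the local container flow (assumes a running
+            Docker or Podman daemon; train_fn is your training function):
+
+                client = TrainerClient(backend_config=ContainerBackendConfig())
+                job_id = client.train(trainer=CustomTrainer(func=train_fn))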
@@ -52,6 +59,8 @@ def __init__( self.backend = KubernetesBackend(backend_config) elif isinstance(backend_config, LocalProcessBackendConfig): self.backend = LocalProcessBackend(backend_config) + elif isinstance(backend_config, ContainerBackendConfig): + self.backend = ContainerBackend(backend_config) else: raise ValueError(f"Invalid backend config '{backend_config}'") diff --git a/kubeflow/trainer/backends/container/adapters/base.py b/kubeflow/trainer/backends/container/adapters/base.py new file mode 100644 index 000000000..3e38a6b89 --- /dev/null +++ b/kubeflow/trainer/backends/container/adapters/base.py @@ -0,0 +1,195 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Container client adapters for Docker and Podman. + +This module implements the adapter pattern to abstract away differences between +Docker and Podman APIs, allowing the backend to work with either runtime through +a common interface. +""" + +from __future__ import annotations + +import abc +from collections.abc import Iterator +from typing import Optional + + +class BaseContainerClientAdapter(abc.ABC): + """ + Abstract adapter interface for container clients. + + This adapter abstracts the container runtime API, allowing the backend + to work with Docker and Podman through a unified interface. + """ + + @abc.abstractmethod + def ping(self): + """Test the connection to the container runtime.""" + raise NotImplementedError() + + @abc.abstractmethod + def create_network( + self, + name: str, + labels: dict[str, str], + ) -> str: + """ + Create a container network. + + Args: + name: Network name + labels: Labels to attach to the network + + Returns: + Network ID or name + """ + raise NotImplementedError() + + @abc.abstractmethod + def delete_network(self, network_id: str): + """Delete a network.""" + raise NotImplementedError() + + @abc.abstractmethod + def create_and_start_container( + self, + image: str, + command: list[str], + name: str, + network_id: str, + environment: dict[str, str], + labels: dict[str, str], + volumes: dict[str, dict[str, str]], + working_dir: str, + ) -> str: + """ + Create and start a container. 
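+
+        Implementations should start the container detached and return its ID
+        immediately; callers stream logs and poll status through separate calls.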
+ + Args: + image: Container image + command: Command to run + name: Container name + network_id: Network to attach to + environment: Environment variables + labels: Container labels + volumes: Volume mounts + working_dir: Working directory + + Returns: + Container ID + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_container(self, container_id: str): + """Get container object by ID.""" + raise NotImplementedError() + + @abc.abstractmethod + def container_logs(self, container_id: str, follow: bool) -> Iterator[str]: + """Stream logs from a container.""" + raise NotImplementedError() + + @abc.abstractmethod + def stop_container(self, container_id: str, timeout: int = 10): + """Stop a container.""" + raise NotImplementedError() + + @abc.abstractmethod + def remove_container(self, container_id: str, force: bool = True): + """Remove a container.""" + raise NotImplementedError() + + @abc.abstractmethod + def pull_image(self, image: str): + """Pull an image.""" + raise NotImplementedError() + + @abc.abstractmethod + def image_exists(self, image: str) -> bool: + """Check if an image exists locally.""" + raise NotImplementedError() + + @abc.abstractmethod + def run_oneoff_container(self, image: str, command: list[str]) -> str: + """ + Run a short-lived container and return its output. + + Args: + image: Container image + command: Command to run + + Returns: + Container output as string + """ + raise NotImplementedError() + + @abc.abstractmethod + def container_status(self, container_id: str) -> tuple[str, Optional[int]]: + """ + Get container status. + + Returns: + Tuple of (status_string, exit_code) + Status strings: "running", "created", "exited", etc. + Exit code is None if container hasn't exited + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]: + """ + Get container's IP address on a specific network. + + Args: + container_id: Container ID + network_id: Network name or ID + + Returns: + IP address string or None if not found + """ + raise NotImplementedError() + + @abc.abstractmethod + def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]: + """ + List containers, optionally filtered by labels. + + Args: + filters: Dictionary of filters (e.g., {"label": ["key=value"]}) + + Returns: + List of container info dictionaries with keys: + - id: Container ID + - name: Container name + - labels: Dictionary of labels + - status: Container status + - created: Creation timestamp + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_network(self, network_id: str) -> Optional[dict]: + """ + Get network information by ID or name. + + Args: + network_id: Network ID or name + + Returns: + Dictionary with network info including labels, or None if not found + """ + raise NotImplementedError() diff --git a/kubeflow/trainer/backends/container/adapters/docker.py b/kubeflow/trainer/backends/container/adapters/docker.py new file mode 100644 index 000000000..c4c0e1205 --- /dev/null +++ b/kubeflow/trainer/backends/container/adapters/docker.py @@ -0,0 +1,229 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Docker client adapter implementation. + +This module provides the DockerClientAdapter class that implements the +BaseContainerClientAdapter interface for Docker runtime. +""" + +from collections.abc import Iterator +from typing import Optional + +from kubeflow.trainer.backends.container.adapters.base import BaseContainerClientAdapter + + +class DockerClientAdapter(BaseContainerClientAdapter): + """Adapter for Docker client.""" + + def __init__(self, host: Optional[str] = None): + """ + Initialize Docker client. + + Args: + host: Docker host URL, or None to use environment defaults + """ + try: + import docker # type: ignore + except ImportError as e: + raise ImportError( + "The 'docker' Python package is not installed. Install with extras: " + "pip install kubeflow[docker]" + ) from e + + if host: + self.client = docker.DockerClient(base_url=host) + else: + self.client = docker.from_env() + + self._runtime_type = "docker" + + def ping(self): + """Test connection to Docker daemon.""" + self.client.ping() + + def create_network(self, name: str, labels: dict[str, str]) -> str: + """Create a Docker network.""" + try: + self.client.networks.get(name) + return name + except Exception: + pass + + self.client.networks.create( + name=name, + check_duplicate=True, + labels=labels, + ) + return name + + def delete_network(self, network_id: str): + """Delete Docker network.""" + try: + net = self.client.networks.get(network_id) + net.remove() + except Exception: + pass + + def create_and_start_container( + self, + image: str, + command: list[str], + name: str, + network_id: str, + environment: dict[str, str], + labels: dict[str, str], + volumes: dict[str, dict[str, str]], + working_dir: str, + ) -> str: + """Create and start a Docker container.""" + container = self.client.containers.run( + image=image, + command=tuple(command), + name=name, + detach=True, + working_dir=working_dir, + network=network_id, + environment=environment, + labels=labels, + volumes=volumes, + auto_remove=False, + ) + return container.id + + def get_container(self, container_id: str): + """Get Docker container by ID.""" + return self.client.containers.get(container_id) + + def container_logs(self, container_id: str, follow: bool) -> Iterator[str]: + """Stream logs from Docker container.""" + container = self.get_container(container_id) + logs = container.logs(stream=bool(follow), follow=bool(follow)) + if follow: + for chunk in logs: + if isinstance(chunk, bytes): + yield chunk.decode("utf-8", errors="ignore") + else: + yield str(chunk) + else: + if isinstance(logs, bytes): + yield logs.decode("utf-8", errors="ignore") + else: + yield str(logs) + + def stop_container(self, container_id: str, timeout: int = 10): + """Stop Docker container.""" + container = self.get_container(container_id) + container.stop(timeout=timeout) + + def remove_container(self, container_id: str, force: bool = True): + """Remove Docker container.""" + container = self.get_container(container_id) + container.remove(force=force) + + def pull_image(self, image: str): + """Pull Docker image.""" + self.client.images.pull(image) + + def 
image_exists(self, image: str) -> bool: + """Check if Docker image exists locally.""" + try: + self.client.images.get(image) + return True + except Exception: + return False + + def run_oneoff_container(self, image: str, command: list[str]) -> str: + """Run a short-lived Docker container and return output.""" + try: + output = self.client.containers.run( + image=image, + command=tuple(command), + detach=False, + remove=True, + ) + if isinstance(output, (bytes, bytearray)): + return output.decode("utf-8", errors="ignore") + return str(output) + except Exception as e: + raise RuntimeError(f"One-off container failed to run: {e}") from e + + def container_status(self, container_id: str) -> tuple[str, Optional[int]]: + """Get Docker container status.""" + try: + container = self.get_container(container_id) + status = container.status + # Get exit code if container has exited + exit_code = None + if status == "exited": + inspect = container.attrs if hasattr(container, "attrs") else container.inspect() + exit_code = inspect.get("State", {}).get("ExitCode") + return (status, exit_code) + except Exception: + return ("unknown", None) + + def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]: + """Get container's IP address on a specific network.""" + try: + container = self.get_container(container_id) + # Refresh container info + container.reload() + # Get network settings + networks = container.attrs.get("NetworkSettings", {}).get("Networks", {}) + + # Try to find the network by exact name or ID + if network_id in networks: + return networks[network_id].get("IPAddress") + + # Fallback: return first available IP + for _net_name, net_info in networks.items(): + ip = net_info.get("IPAddress") + if ip: + return ip + + return None + except Exception: + return None + + def list_containers(self, filters: Optional[dict[str, str]] = None) -> list[dict]: + """List Docker containers with optional filters.""" + try: + containers = self.client.containers.list(all=True, filters=filters) + result = [] + for c in containers: + result.append( + { + "id": c.id, + "name": c.name, + "labels": c.labels, + "status": c.status, + "created": c.attrs.get("Created", ""), + } + ) + return result + except Exception: + return [] + + def get_network(self, network_id: str) -> Optional[dict]: + """Get Docker network information.""" + try: + network = self.client.networks.get(network_id) + return { + "id": network.id, + "name": network.name, + "labels": network.attrs.get("Labels", {}), + } + except Exception: + return None diff --git a/kubeflow/trainer/backends/container/adapters/podman.py b/kubeflow/trainer/backends/container/adapters/podman.py new file mode 100644 index 000000000..72274e344 --- /dev/null +++ b/kubeflow/trainer/backends/container/adapters/podman.py @@ -0,0 +1,247 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Podman client adapter implementation. 
+ +This module provides the PodmanClientAdapter class that implements the +BaseContainerClientAdapter interface for Podman runtime. + +Key differences from Docker: +- Uses DNS-enabled bridge networks for better container name resolution +- GPU support via CDI (Container Device Interface) instead of NVIDIA Container Toolkit +- Slightly different API for some operations (e.g., container.create + start pattern) +""" + +from collections.abc import Iterator +from typing import Optional + +from kubeflow.trainer.backends.container.adapters.base import BaseContainerClientAdapter + + +class PodmanClientAdapter(BaseContainerClientAdapter): + """Adapter for Podman client.""" + + def __init__(self, host: Optional[str] = None): + """ + Initialize Podman client. + + Args: + host: Podman host URL, or None to use environment defaults + """ + try: + import podman # type: ignore + except ImportError as e: + raise ImportError( + "The 'podman' Python package is not installed. Install with extras: " + "pip install kubeflow[podman]" + ) from e + + if host: + self.client = podman.PodmanClient(base_url=host) + else: + self.client = podman.PodmanClient() + + self._runtime_type = "podman" + + def ping(self): + """Test connection to Podman.""" + self.client.ping() + + def create_network(self, name: str, labels: dict[str, str]) -> str: + """Create a Podman network with DNS enabled.""" + try: + self.client.networks.get(name) + return name + except Exception: + pass + + self.client.networks.create( + name=name, + driver="bridge", + dns_enabled=True, + labels=labels, + ) + return name + + def delete_network(self, network_id: str): + """Delete Podman network.""" + try: + net = self.client.networks.get(network_id) + net.remove() + except Exception: + pass + + def create_and_start_container( + self, + image: str, + command: list[str], + name: str, + network_id: str, + environment: dict[str, str], + labels: dict[str, str], + volumes: dict[str, dict[str, str]], + working_dir: str, + ) -> str: + """Create and start a Podman container.""" + container = self.client.containers.run( + image=image, + command=command, + name=name, + network=network_id, + working_dir=working_dir, + environment=environment, + labels=labels, + volumes=volumes, + detach=True, + remove=False, + ) + return container.id + + def get_container(self, container_id: str): + """Get Podman container by ID.""" + return self.client.containers.get(container_id) + + def container_logs(self, container_id: str, follow: bool) -> Iterator[str]: + """Stream logs from Podman container.""" + container = self.get_container(container_id) + logs = container.logs(stream=bool(follow), follow=bool(follow)) + if follow: + for chunk in logs: + if isinstance(chunk, bytes): + yield chunk.decode("utf-8", errors="ignore") + else: + yield str(chunk) + else: + if isinstance(logs, bytes): + yield logs.decode("utf-8", errors="ignore") + else: + yield str(logs) + + def stop_container(self, container_id: str, timeout: int = 10): + """Stop Podman container.""" + container = self.get_container(container_id) + container.stop(timeout=timeout) + + def remove_container(self, container_id: str, force: bool = True): + """Remove Podman container.""" + container = self.get_container(container_id) + container.remove(force=force) + + def pull_image(self, image: str): + """Pull Podman image.""" + self.client.images.pull(image) + + def image_exists(self, image: str) -> bool: + """Check if Podman image exists locally.""" + try: + self.client.images.get(image) + return True + except Exception: + return 
False + + def run_oneoff_container(self, image: str, command: list[str]) -> str: + """Run a short-lived Podman container and return output.""" + try: + container = self.client.containers.create( + image=image, + command=command, + detach=False, + remove=True, + ) + container.start() + container.wait() + logs = container.logs() + + if isinstance(logs, (bytes, bytearray)): + return logs.decode("utf-8", errors="ignore") + return str(logs) + except Exception as e: + raise RuntimeError(f"One-off container failed to run: {e}") from e + + def container_status(self, container_id: str) -> tuple[str, Optional[int]]: + """Get Podman container status.""" + try: + container = self.get_container(container_id) + status = container.status + # Get exit code if container has exited + exit_code = None + if status == "exited": + inspect = container.attrs if hasattr(container, "attrs") else container.inspect() + exit_code = inspect.get("State", {}).get("ExitCode") + return (status, exit_code) + except Exception: + return ("unknown", None) + + def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]: + """Get container's IP address on a specific network.""" + try: + container = self.get_container(container_id) + # Get container inspect data + inspect = container.attrs if hasattr(container, "attrs") else container.inspect() + + # Get network settings - Podman structure is similar to Docker + networks = inspect.get("NetworkSettings", {}).get("Networks", {}) + + # Try to find the network by exact name or ID + if network_id in networks: + return networks[network_id].get("IPAddress") + + # Fallback: return first available IP + for _net_name, net_info in networks.items(): + ip = net_info.get("IPAddress") + if ip: + return ip + + return None + except Exception: + return None + + def list_containers(self, filters: Optional[dict[str, str]] = None) -> list[dict]: + """List Podman containers with optional filters.""" + try: + containers = self.client.containers.list(all=True, filters=filters) + result = [] + for c in containers: + inspect = c.attrs if hasattr(c, "attrs") else c.inspect() + labels = ( + c.labels + if hasattr(c, "labels") + else inspect.get("Config", {}).get("Labels", {}) + ) + result.append( + { + "id": c.id, + "name": c.name, + "labels": labels, + "status": c.status, + "created": inspect.get("Created", ""), + } + ) + return result + except Exception: + return [] + + def get_network(self, network_id: str) -> Optional[dict]: + """Get Podman network information.""" + try: + network = self.client.networks.get(network_id) + inspect = network.attrs if hasattr(network, "attrs") else network.inspect() + return { + "id": inspect.get("ID", network_id), + "name": inspect.get("Name", network_id), + "labels": inspect.get("Labels", {}), + } + except Exception: + return None diff --git a/kubeflow/trainer/backends/container/backend.py b/kubeflow/trainer/backends/container/backend.py new file mode 100644 index 000000000..7408d83aa --- /dev/null +++ b/kubeflow/trainer/backends/container/backend.py @@ -0,0 +1,646 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ContainerBackend +---------------- + +Unified local execution backend for `CustomTrainer` jobs using containers. + +This backend automatically detects and uses either Docker or Podman. +It provides a single interface regardless of the underlying container runtime. + +Key behaviors: +- Auto-detection: Tries Docker first, then Podman. Can be overridden via config. +- Multi-node jobs: one container per node connected via a per-job network. +- Entry script generation: we serialize the user's training function and embed it + inline in the container command using a heredoc (no file I/O on the host). The + script is created inside the container at /tmp/train.py and invoked using + `torchrun` (preferred) or `python` as a fallback. +- Runtimes: we use `config/training_runtimes` to define runtime images and + characteristics (e.g., torch). Defaults to `torch-distributed` if no runtime + is provided. +- Image pulling: controlled via `pull_policy` and performed automatically if + needed. +- Logs and lifecycle: streaming logs and deletion semantics similar to the + Docker/Podman backends, but with automatic runtime detection. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from datetime import datetime +import logging +import os +import random +import shutil +import string +from typing import Optional, Union +import uuid + +from kubeflow.trainer.backends.base import RuntimeBackend +from kubeflow.trainer.backends.container import utils as container_utils +from kubeflow.trainer.backends.container.adapters.base import ( + BaseContainerClientAdapter, +) +from kubeflow.trainer.backends.container.adapters.docker import DockerClientAdapter +from kubeflow.trainer.backends.container.adapters.podman import PodmanClientAdapter +from kubeflow.trainer.backends.container.runtime_loader import ( + get_training_runtime_from_sources, + list_training_runtimes_from_sources, +) +from kubeflow.trainer.backends.container.types import ContainerBackendConfig +from kubeflow.trainer.constants import constants +from kubeflow.trainer.types import types + +logger = logging.getLogger(__name__) + + +class ContainerBackend(RuntimeBackend): + """ + Unified container backend that auto-detects Docker or Podman. + + This backend uses the adapter pattern to abstract away differences between + Docker and Podman, providing a single consistent interface. + """ + + def __init__(self, cfg: ContainerBackendConfig): + self.cfg = cfg + self.label_prefix = "trainer.kubeflow.org" + + # Initialize the container client adapter + self._adapter = self._create_adapter() + + def _get_common_socket_locations(self, runtime_name: str) -> list[Optional[str]]: + """ + Get common socket locations to try for the given runtime. 
+
+        Args:
+            runtime_name: "docker" or "podman"
+
+        Returns:
+            List of socket URLs to try, including None (for default)
+        """
+        import os
+        from pathlib import Path
+
+        locations = [self.cfg.container_host] if self.cfg.container_host else []
+
+        if runtime_name == "docker":
+            # Common Docker socket locations
+            colima_sock = Path.home() / ".colima/default/docker.sock"
+            if colima_sock.exists():
+                locations.append(f"unix://{colima_sock}")
+            # Standard Docker socket
+            locations.append(None)  # Use docker.from_env() default
+
+        elif runtime_name == "podman":
+            # Rootless Podman user socket (typically on Linux)
+            uid = os.getuid() if hasattr(os, "getuid") else None
+            if uid is not None:
+                user_sock = f"/run/user/{uid}/podman/podman.sock"
+                if Path(user_sock).exists():
+                    locations.append(f"unix://{user_sock}")
+            # Standard Podman socket
+            locations.append(None)  # Use PodmanClient() default
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_locations = []
+        for loc in locations:
+            if loc not in seen:
+                unique_locations.append(loc)
+                seen.add(loc)
+
+        return unique_locations
+
+    def _create_adapter(self) -> BaseContainerClientAdapter:
+        """
+        Create the appropriate container client adapter.
+
+        Tries Docker first, then Podman if Docker fails, unless a specific
+        runtime is requested in the config. Automatically tries common socket
+        locations (e.g., Colima for Docker on macOS, user socket for Podman).
+
+        Raises RuntimeError if neither Docker nor Podman is available.
+        """
+        runtime_map = {
+            "docker": DockerClientAdapter,
+            "podman": PodmanClientAdapter,
+        }
+
+        # Determine which runtimes to try
+        runtimes_to_try = (
+            [self.cfg.container_runtime] if self.cfg.container_runtime else ["docker", "podman"]
+        )
+
+        attempted_connections = []
+        last_error = None
+
+        for runtime_name in runtimes_to_try:
+            if runtime_name not in runtime_map:
+                continue
+
+            # Try common socket locations for this runtime
+            socket_locations = self._get_common_socket_locations(runtime_name)
+
+            for host in socket_locations:
+                try:
+                    adapter = runtime_map[runtime_name](host)
+                    adapter.ping()
+                    host_display = host or "default"
+                    logger.debug(
+                        f"Using {runtime_name} as container runtime (host: {host_display})"
+                    )
+                    return adapter
+                except Exception as e:
+                    host_str = host or "default"
+                    logger.debug(f"{runtime_name} initialization failed at {host_str}: {e}")
+                    attempted_connections.append(f"{runtime_name} at {host_str}")
+                    last_error = e
+
+        # Build helpful error message
+        import platform
+
+        system = platform.system()
+
+        attempted = ", ".join(attempted_connections)
+        error_msg = f"Could not connect to Docker or Podman (tried: {attempted}).\n"
+
+        if system == "Darwin":  # macOS
+            error_msg += (
+                "Ensure Docker/Podman is running "
+                "(e.g., 'colima start' or 'podman machine start').\n"
+            )
+        else:
+            error_msg += "Ensure Docker/Podman is installed and running.\n"
+
+        error_msg += (
+            "To specify a custom socket: ContainerBackendConfig(container_host='unix:///path/to/socket')\n"
+            "Or use LocalProcessBackendConfig for non-containerized execution."
+ ) + + raise RuntimeError(error_msg) from last_error + + @property + def _runtime_type(self) -> str: + """Get the runtime type for debugging/logging.""" + return self._adapter._runtime_type + + # ---- Runtime APIs ---- + def list_runtimes(self) -> list[types.Runtime]: + return list_training_runtimes_from_sources(self.cfg.runtime_source.sources) + + def get_runtime(self, name: str) -> types.Runtime: + return get_training_runtime_from_sources(name, self.cfg.runtime_source.sources) + + def get_runtime_packages(self, runtime: types.Runtime): + """ + Spawn a short-lived container to report Python version, pip list, and nvidia-smi. + """ + image = container_utils.resolve_image(runtime) + container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy) + + command = [ + "bash", + "-lc", + "python -c \"import sys; print(f'Python: {sys.version}')\" && " + "(pip list || echo 'pip not found') && " + "(nvidia-smi || echo 'nvidia-smi not found')", + ] + + logs = self._adapter.run_oneoff_container(image=image, command=command) + print(logs) + + def train( + self, + runtime: Optional[types.Runtime] = None, + initializer: Optional[types.Initializer] = None, + trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + ) -> str: + if runtime is None: + runtime = self.get_runtime("torch-distributed") + + if not isinstance(trainer, types.CustomTrainer): + raise ValueError(f"{self.__class__.__name__} supports only CustomTrainer in v1") + + # Generate job name + job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] + logger.debug(f"Starting training job: {job_name}") + + try: + # Create per-job working directory on host (for outputs, checkpoints, etc.) + workdir = container_utils.create_workdir(job_name) + logger.debug(f"Created working directory: {workdir}") + + # Generate training script code (inline, not written to disk) + training_script_code = container_utils.get_training_script_code(trainer) + logger.debug("Generated training script code") + + # Resolve image and pull if needed + image = container_utils.resolve_image(runtime) + logger.debug(f"Using image: {image}") + + container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy) + logger.debug(f"Image ready: {image}") + + # Build base environment + env = container_utils.build_environment(trainer) + + # Construct pre-run command to install packages + pre_install_cmd = container_utils.build_pip_install_cmd(trainer) + + # Create network for multi-node communication + num_nodes = trainer.num_nodes or runtime.trainer.num_nodes or 1 + logger.debug(f"Creating network for {num_nodes} nodes") + + # Determine number of processes per node from GPU count + # For GPU training: spawn one process per GPU for optimal utilization + # For CPU training: use single process (PyTorch parallelizes internally via threads) + nproc_per_node = 1 # Default for CPU training + if trainer.resources_per_node and "gpu" in trainer.resources_per_node: + try: + nproc_per_node = int(trainer.resources_per_node["gpu"]) + logger.debug(f"Using {nproc_per_node} processes per node (1 per GPU)") + except (ValueError, TypeError): + logger.warning( + f"Invalid GPU count in resources_per_node: " + f"{trainer.resources_per_node['gpu']}, defaulting to 1 process per node" + ) + else: + logger.debug("No GPU specified, using 1 process per node") + + network_id = self._adapter.create_network( + name=f"{job_name}-net", + labels={ + f"{self.label_prefix}/trainjob-name": job_name, + f"{self.label_prefix}/runtime-name": runtime.name, + 
f"{self.label_prefix}/workdir": workdir, + }, + ) + logger.debug(f"Created network: {network_id}") + + # Create N containers (one per node) + container_ids: list[str] = [] + master_container_id = None + master_ip = None + + for rank in range(num_nodes): + container_name = f"{job_name}-node-{rank}" + + # Get master address and port for torchrun + master_port = 29500 + + # For Podman: use IP address to avoid DNS timing issues + # For Docker: use hostname (DNS is reliable) + if rank == 0: + # Master node - will be created first + master_addr = f"{job_name}-node-0" + else: + # Worker nodes - determine master address based on runtime + if self._runtime_type == "podman" and master_ip: + master_addr = master_ip + logger.debug(f"Using master IP address for Podman: {master_ip}") + else: + master_addr = f"{job_name}-node-0" + logger.debug(f"Using master hostname: {master_addr}") + + # Prefer torchrun; fall back to python if torchrun is unavailable + # For worker nodes, wait for master to be reachable before starting torchrun + wait_for_master = "" + if rank > 0: + wait_for_master = ( + f"echo 'Waiting for master node {master_addr}:{master_port}...'; " + f"for i in {{1..60}}; do " + f" if timeout 1 bash -c 'cat < /dev/null > " + f"/dev/tcp/{master_addr}/{master_port}' 2>/dev/null; then " + f" echo 'Master node is reachable'; break; " + f" fi; " + f" if [ $i -eq 60 ]; then " + f"echo 'Timeout waiting for master node'; exit 1; fi; " + f" sleep 2; " + f"done; " + ) + + # Embed training script inline using heredoc (no file I/O on host) + entry_cmd = ( + f"{pre_install_cmd}" + f"{wait_for_master}" + f"cat > /tmp/train.py << 'TRAINING_SCRIPT_EOF'\n" + f"{training_script_code}\n" + f"TRAINING_SCRIPT_EOF\n" + "if command -v torchrun >/dev/null 2>&1; then " + f" torchrun --nproc_per_node={nproc_per_node} --nnodes={num_nodes} " + f" --node-rank={rank} --rdzv-backend=c10d " + f" --rdzv-endpoint={master_addr}:{master_port} " + f" /tmp/train.py; " + "else " + f" python /tmp/train.py; " + "fi" + ) + + full_cmd = ["bash", "-lc", entry_cmd] + + labels = { + f"{self.label_prefix}/trainjob-name": job_name, + f"{self.label_prefix}/step": f"node-{rank}", + f"{self.label_prefix}/network-id": network_id, + } + + volumes = { + workdir: { + "bind": constants.WORKSPACE_PATH, + "mode": "rw", + } + } + + logger.debug(f"Creating container {rank}/{num_nodes}: {container_name}") + + container_id = self._adapter.create_and_start_container( + image=image, + command=full_cmd, + name=container_name, + network_id=network_id, + environment=env, + labels=labels, + volumes=volumes, + working_dir=constants.WORKSPACE_PATH, + ) + + logger.debug(f"Started container {container_name} (ID: {container_id[:12]})") + container_ids.append(container_id) + + # If this is the master node and we're using Podman, get its IP address + if rank == 0: + master_container_id = container_id + if self._runtime_type == "podman": + # Get master IP for worker nodes to use + master_ip = self._adapter.get_container_ip(master_container_id, network_id) + if master_ip: + logger.debug(f"Master node IP address: {master_ip}") + else: + logger.warning( + "Could not retrieve master IP address. " + "Worker nodes will fall back to DNS resolution." 
+ ) + + logger.debug( + f"Training job {job_name} created successfully with " + f"{len(container_ids)} container(s)" + ) + return job_name + + except Exception as e: + # Clean up on failure + logger.error(f"Failed to create training job {job_name}: {e}") + logger.exception("Full traceback:") + + # Try to clean up any resources that were created + from contextlib import suppress + + try: + # Stop and remove any containers that were created + if "container_ids" in locals(): + for container_id in container_ids: + with suppress(Exception): + self._adapter.stop_container(container_id, timeout=5) + self._adapter.remove_container(container_id, force=True) + + # Remove network if it was created + if "network_id" in locals(): + with suppress(Exception): + self._adapter.delete_network(network_id) + + # Remove working directory if it was created + if "workdir" in locals() and os.path.isdir(workdir): + shutil.rmtree(workdir, ignore_errors=True) + + except Exception as cleanup_error: + logger.error(f"Error during cleanup: {cleanup_error}") + + # Re-raise the original exception + raise + + def _get_job_containers(self, name: str) -> list[dict]: + """ + Get containers for a specific training job. + + Args: + name: Name of the training job + + Returns: + List of container dictionaries for this job + + Raises: + ValueError: If no containers found for the job + """ + filters = {"label": [f"{self.label_prefix}/trainjob-name={name}"]} + containers = self._adapter.list_containers(filters=filters) + + if not containers: + raise ValueError(f"No TrainJob with name {name}") + + return containers + + def __get_trainjob_from_containers( + self, job_name: str, containers: list[dict] + ) -> types.TrainJob: + """ + Build a TrainJob object from a list of containers. + + Args: + job_name: Name of the training job + containers: List of container dictionaries for this job + + Returns: + TrainJob object + + Raises: + ValueError: If network metadata is missing or runtime not found + """ + if not containers: + raise ValueError(f"No containers found for TrainJob {job_name}") + + # Get metadata from network + network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id") + if not network_id: + raise ValueError(f"TrainJob {job_name} is missing network metadata") + + network_info = self._adapter.get_network(network_id) + if not network_info: + raise ValueError(f"TrainJob {job_name} network not found") + + network_labels = network_info.get("labels", {}) + runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name") + + # Get runtime object + try: + job_runtime = self.get_runtime(runtime_name) if runtime_name else None + except Exception as e: + raise ValueError(f"Runtime {runtime_name} not found for job {job_name}") from e + + if not job_runtime: + raise ValueError(f"Runtime {runtime_name} not found for job {job_name}") + + # Parse creation timestamp from first container + created_str = containers[0].get("created", "") + try: + from dateutil import parser + + creation_timestamp = parser.isoparse(created_str) + except Exception: + creation_timestamp = datetime.now() + + # Build steps from containers + steps = [] + for container in sorted(containers, key=lambda c: c["name"]): + step_name = container["labels"].get(f"{self.label_prefix}/step", "") + steps.append( + types.Step( + name=step_name, + pod_name=container["name"], + status=container_utils.get_container_status(self._adapter, container["id"]), + ) + ) + + # Get num_nodes from container count + num_nodes = len(containers) + + return types.TrainJob( + 
name=job_name, + creation_timestamp=creation_timestamp, + runtime=job_runtime, + steps=steps, + num_nodes=num_nodes, + status=container_utils.aggregate_container_statuses(self._adapter, containers), + ) + + def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]: + """List all training jobs by querying container runtime.""" + # Get all containers with our label prefix + filters = {"label": [f"{self.label_prefix}/trainjob-name"]} + containers = self._adapter.list_containers(filters=filters) + + # Group containers by job name + jobs_map: dict[str, list[dict]] = {} + for container in containers: + job_name = container["labels"].get(f"{self.label_prefix}/trainjob-name") + if job_name: + if job_name not in jobs_map: + jobs_map[job_name] = [] + jobs_map[job_name].append(container) + + result: list[types.TrainJob] = [] + for job_name, job_containers in jobs_map.items(): + # Skip jobs with no containers + if not job_containers: + continue + + # Filter by runtime if specified + if runtime: + network_id = job_containers[0]["labels"].get(f"{self.label_prefix}/network-id") + if network_id: + network_info = self._adapter.get_network(network_id) + if network_info: + network_labels = network_info.get("labels", {}) + runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name") + if runtime_name != runtime.name: + continue + + # Build TrainJob from containers + try: + result.append(self.__get_trainjob_from_containers(job_name, job_containers)) + except Exception as e: + logger.warning(f"Failed to get TrainJob {job_name}: {e}") + continue + + return result + + def get_job(self, name: str) -> types.TrainJob: + """Get a specific training job by querying container runtime.""" + containers = self._get_job_containers(name) + return self.__get_trainjob_from_containers(name, containers) + + def get_job_logs( + self, + name: str, + follow: bool = False, + step: str = constants.NODE + "-0", + ) -> Iterator[str]: + """Get logs for a training job by querying container runtime.""" + containers = self._get_job_containers(name) + + want_all = step == constants.NODE + "-0" + for container in sorted(containers, key=lambda c: c["name"]): + container_step = container["labels"].get(f"{self.label_prefix}/step", "") + if not want_all and container_step != step: + continue + try: + yield from self._adapter.container_logs(container["id"], follow) + except Exception as e: + logger.warning(f"Failed to get logs for {container['name']}: {e}") + yield f"Error getting logs: {e}\n" + + def wait_for_job_status( + self, + name: str, + status: set[str] = {constants.TRAINJOB_COMPLETE}, + timeout: int = 600, + polling_interval: int = 2, + ) -> types.TrainJob: + import time + + end = time.time() + timeout + while time.time() < end: + tj = self.get_job(name) + logger.debug(f"TrainJob {name}, status {tj.status}") + if tj.status in status: + return tj + if constants.TRAINJOB_FAILED not in status and tj.status == constants.TRAINJOB_FAILED: + raise RuntimeError(f"TrainJob {name} is Failed") + time.sleep(polling_interval) + raise TimeoutError(f"Timeout waiting for TrainJob {name} to reach status: {status}") + + def delete_job(self, name: str): + """Delete a training job by querying container runtime.""" + containers = self._get_job_containers(name) + + # Get network_id and workdir from labels + network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id") + + # Get workdir from network labels + workdir_host = None + if network_id: + network_info = self._adapter.get_network(network_id) + if 
network_info: + network_labels = network_info.get("labels", {}) + workdir_host = network_labels.get(f"{self.label_prefix}/workdir") + + # Stop containers and remove + from contextlib import suppress + + for container in containers: + with suppress(Exception): + self._adapter.stop_container(container["id"], timeout=10) + with suppress(Exception): + self._adapter.remove_container(container["id"], force=True) + + # Remove network (best-effort) + if network_id: + with suppress(Exception): + self._adapter.delete_network(network_id) + + # Remove working directory if configured + if self.cfg.auto_remove and workdir_host and os.path.isdir(workdir_host): + shutil.rmtree(workdir_host, ignore_errors=True) diff --git a/kubeflow/trainer/backends/container/backend_test.py b/kubeflow/trainer/backends/container/backend_test.py new file mode 100644 index 000000000..ac13ca0f5 --- /dev/null +++ b/kubeflow/trainer/backends/container/backend_test.py @@ -0,0 +1,863 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for ContainerBackend. + +Tests the ContainerBackend class with mocked container adapters. +""" + +from collections.abc import Iterator +from contextlib import nullcontext +import os +from pathlib import Path +import shutil +import tempfile +from typing import Optional +from unittest.mock import Mock, patch + +import pytest + +from kubeflow.trainer.backends.container.adapters.base import ( + BaseContainerClientAdapter, +) +from kubeflow.trainer.backends.container.backend import ContainerBackend +from kubeflow.trainer.backends.container.types import ContainerBackendConfig +from kubeflow.trainer.constants import constants +from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase +from kubeflow.trainer.types import types + + +# Mock Container Adapter +class MockContainerAdapter(BaseContainerClientAdapter): + """Mock adapter for testing ContainerBackend without Docker/Podman.""" + + def __init__(self): + self._runtime_type = "mock" + self.networks_created = [] + self.containers_created = [] + self.containers_stopped = [] + self.containers_removed = [] + self.networks_deleted = [] + self.images_pulled = [] + self.ping_called = False + + def ping(self): + self.ping_called = True + + def create_network(self, name: str, labels: dict[str, str]) -> str: + network_id = f"net-{name}" + self.networks_created.append({"id": network_id, "name": name, "labels": labels}) + return network_id + + def delete_network(self, network_id: str): + self.networks_deleted.append(network_id) + + def create_and_start_container( + self, + image: str, + command: list[str], + name: str, + network_id: str, + environment: dict[str, str], + labels: dict[str, str], + volumes: dict[str, dict[str, str]], + working_dir: str, + ) -> str: + container_id = f"container-{len(self.containers_created)}" + self.containers_created.append( + { + "id": container_id, + "name": name, + "image": image, + "command": command, + "network": network_id, + "environment": environment, + "labels": 
labels, + "volumes": volumes, + "working_dir": working_dir, + "status": "running", + "exit_code": None, + } + ) + return container_id + + def get_container(self, container_id: str): + for container in self.containers_created: + if container["id"] == container_id: + return Mock(id=container_id, status=container["status"]) + return None + + def container_logs(self, container_id: str, follow: bool) -> Iterator[str]: + if follow: + yield f"Log line 1 from {container_id}\n" + yield f"Log line 2 from {container_id}\n" + else: + yield f"Complete log from {container_id}\n" + + def stop_container(self, container_id: str, timeout: int = 10): + self.containers_stopped.append(container_id) + for container in self.containers_created: + if container["id"] == container_id: + container["status"] = "exited" + container["exit_code"] = 0 + + def remove_container(self, container_id: str, force: bool = True): + self.containers_removed.append(container_id) + + def pull_image(self, image: str): + self.images_pulled.append(image) + + def image_exists(self, image: str) -> bool: + return "local" in image or image in self.images_pulled + + def run_oneoff_container(self, image: str, command: list[str]) -> str: + return "Python 3.9.0\npip 21.0.1\nnvidia-smi not found\n" + + def container_status(self, container_id: str) -> tuple[str, Optional[int]]: + for container in self.containers_created: + if container["id"] == container_id: + return (container["status"], container.get("exit_code")) + return ("unknown", None) + + def set_container_status(self, container_id: str, status: str, exit_code: Optional[int] = None): + """Helper method to set container status for testing.""" + for container in self.containers_created: + if container["id"] == container_id: + container["status"] = status + container["exit_code"] = exit_code + + def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]: + """Get container IP address on a specific network.""" + for container in self.containers_created: + if container["id"] == container_id: + return f"192.168.1.{len(self.containers_created)}" + return None + + def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]: + """List containers with optional filters.""" + if not filters: + return [ + { + "id": c["id"], + "name": c["name"], + "labels": c["labels"], + "status": c["status"], + "created": "2025-01-01T00:00:00Z", + } + for c in self.containers_created + ] + + # Simple label filtering + result = [] + for container in self.containers_created: + if "label" in filters: + match = True + for label_filter in filters["label"]: + if "=" in label_filter: + key, value = label_filter.split("=", 1) + if container["labels"].get(key) != value: + match = False + break + else: + if label_filter not in container["labels"]: + match = False + break + if match: + result.append( + { + "id": container["id"], + "name": container["name"], + "labels": container["labels"], + "status": container["status"], + "created": "2025-01-01T00:00:00Z", + } + ) + return result + + def get_network(self, network_id: str) -> Optional[dict]: + """Get network information.""" + for network in self.networks_created: + if network["id"] == network_id or network["name"] == network_id: + return { + "id": network["id"], + "name": network["name"], + "labels": network["labels"], + } + return None + + +# Fixtures +@pytest.fixture +def container_backend(): + """Provide ContainerBackend with mocked adapter.""" + with patch("kubeflow.trainer.backends.container.backend.DockerClientAdapter") as 
mock_docker: + mock_docker.return_value = MockContainerAdapter() + backend = ContainerBackend(ContainerBackendConfig()) + return backend + + +@pytest.fixture +def temp_workdir(): + """Provide a temporary working directory.""" + tmpdir = tempfile.mkdtemp() + yield tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + + +# Helper Function +def simple_train_func(): + """Simple training function for tests.""" + print("Training") + + +# Tests +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="auto-detect docker first", + expected_status=SUCCESS, + ), + TestCase( + name="auto-detect falls back to podman", + expected_status=SUCCESS, + ), + TestCase( + name="both unavailable raises error", + expected_status=FAILED, + expected_error=RuntimeError, + ), + ], +) +def test_backend_initialization(test_case): + """Test ContainerBackend initialization and adapter creation.""" + print("Executing test:", test_case.name) + try: + if test_case.name == "auto-detect docker first": + with ( + patch( + "kubeflow.trainer.backends.container.backend.DockerClientAdapter" + ) as mock_docker, + patch( + "kubeflow.trainer.backends.container.backend.PodmanClientAdapter" + ) as mock_podman, + ): + mock_docker_instance = Mock() + mock_docker.return_value = mock_docker_instance + + _ = ContainerBackend(ContainerBackendConfig()) + + # Docker should be called (could be with Colima socket or None) + assert mock_docker.call_count == 1 + mock_docker_instance.ping.assert_called_once() + mock_podman.assert_not_called() + assert test_case.expected_status == SUCCESS + + elif test_case.name == "auto-detect falls back to podman": + with ( + patch( + "kubeflow.trainer.backends.container.backend.DockerClientAdapter" + ) as mock_docker, + patch( + "kubeflow.trainer.backends.container.backend.PodmanClientAdapter" + ) as mock_podman, + ): + mock_docker_instance = Mock() + mock_docker_instance.ping.side_effect = Exception("Docker not available") + mock_docker.return_value = mock_docker_instance + + mock_podman_instance = Mock() + mock_podman.return_value = mock_podman_instance + + _ = ContainerBackend(ContainerBackendConfig()) + + # Docker may be tried multiple times (different socket locations) + assert mock_docker.call_count >= 1 + mock_podman.assert_called_once_with(None) + mock_podman_instance.ping.assert_called_once() + assert test_case.expected_status == SUCCESS + + elif test_case.name == "both unavailable raises error": + with ( + patch( + "kubeflow.trainer.backends.container.backend.DockerClientAdapter" + ) as mock_docker, + patch( + "kubeflow.trainer.backends.container.backend.PodmanClientAdapter" + ) as mock_podman, + ): + mock_docker_instance = Mock() + mock_docker_instance.ping.side_effect = Exception("Docker not available") + mock_docker.return_value = mock_docker_instance + + mock_podman_instance = Mock() + mock_podman_instance.ping.side_effect = Exception("Podman not available") + mock_podman.return_value = mock_podman_instance + + ContainerBackend(ContainerBackendConfig()) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +def test_list_runtimes(container_backend): + """Test listing available local runtimes.""" + print("Executing test: list_runtimes") + runtimes = container_backend.list_runtimes() + + assert isinstance(runtimes, list) + assert len(runtimes) > 0 + runtime_names = [r.name for r in runtimes] + assert "torch-distributed" in runtime_names + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + 
TestCase( + name="get valid runtime", + expected_status=SUCCESS, + config={"name": "torch-distributed"}, + ), + TestCase( + name="get invalid runtime", + expected_status=FAILED, + config={"name": "nonexistent-runtime"}, + expected_error=ValueError, + ), + ], +) +def test_get_runtime(container_backend, test_case): + """Test getting a specific runtime.""" + print("Executing test:", test_case.name) + try: + runtime = container_backend.get_runtime(**test_case.config) + + assert test_case.expected_status == SUCCESS + assert isinstance(runtime, types.Runtime) + assert runtime.name == test_case.config["name"] + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +def test_get_runtime_packages(container_backend): + """Test getting runtime packages.""" + print("Executing test: get_runtime_packages") + runtime = container_backend.get_runtime("torch-distributed") + container_backend.get_runtime_packages(runtime) + + assert len( + container_backend._adapter.images_pulled + ) > 0 or container_backend._adapter.image_exists(runtime.trainer.image) + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="train single node", + expected_status=SUCCESS, + config={"num_nodes": 1, "expected_containers": 1}, + ), + TestCase( + name="train multi-node", + expected_status=SUCCESS, + config={"num_nodes": 3, "expected_containers": 3}, + ), + TestCase( + name="train with custom env", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "env": {"MY_VAR": "my_value", "DEBUG": "true"}, + "expected_containers": 1, + }, + ), + TestCase( + name="train with packages", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "packages": ["numpy", "pandas"], + "expected_containers": 1, + }, + ), + TestCase( + name="train with single GPU", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "resources_per_node": {"gpu": "1"}, + "expected_containers": 1, + "expected_nproc_per_node": 1, + }, + ), + TestCase( + name="train with multiple GPUs", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "resources_per_node": {"gpu": "4"}, + "expected_containers": 1, + "expected_nproc_per_node": 4, + }, + ), + TestCase( + name="train multi-node with GPUs", + expected_status=SUCCESS, + config={ + "num_nodes": 2, + "resources_per_node": {"gpu": "2"}, + "expected_containers": 2, + "expected_nproc_per_node": 2, + }, + ), + TestCase( + name="train with CPU resources (nproc=1)", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "resources_per_node": {"cpu": "16"}, + "expected_containers": 1, + "expected_nproc_per_node": 1, + }, + ), + ], +) +def test_train(container_backend, test_case): + """Test training job creation.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer( + func=simple_train_func, + num_nodes=test_case.config.get("num_nodes", 1), + env=test_case.config.get("env"), + packages_to_install=test_case.config.get("packages"), + resources_per_node=test_case.config.get("resources_per_node"), + ) + runtime = container_backend.get_runtime("torch-distributed") + + job_name = container_backend.train(runtime=runtime, trainer=trainer) + + assert test_case.expected_status == SUCCESS + assert job_name is not None + assert len(job_name) == 12 + assert ( + len(container_backend._adapter.containers_created) + == test_case.config["expected_containers"] + ) + assert len(container_backend._adapter.networks_created) == 1 + + # Check environment if specified + if "env" in test_case.config: + 
container = container_backend._adapter.containers_created[0] + for key, value in test_case.config["env"].items(): + assert container["environment"][key] == value + + # Check packages if specified + if "packages" in test_case.config: + container = container_backend._adapter.containers_created[0] + command_str = " ".join(container["command"]) + assert "pip install" in command_str + for package in test_case.config["packages"]: + assert package in command_str + + # Check nproc_per_node if specified + if "expected_nproc_per_node" in test_case.config: + container = container_backend._adapter.containers_created[0] + command_str = container["command"][2] # Get bash script content + expected_nproc = test_case.config["expected_nproc_per_node"] + assert f"--nproc_per_node={expected_nproc}" in command_str, ( + f"Expected --nproc_per_node={expected_nproc} in command, but got: {command_str}" + ) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="list all jobs", + expected_status=SUCCESS, + config={"num_jobs": 2}, + ), + TestCase( + name="list empty jobs", + expected_status=SUCCESS, + config={"num_jobs": 0}, + ), + ], +) +def test_list_jobs(container_backend, test_case): + """Test listing training jobs.""" + print("Executing test:", test_case.name) + try: + runtime = container_backend.get_runtime("torch-distributed") + created_jobs = [] + + for _ in range(test_case.config["num_jobs"]): + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + job_name = container_backend.train(runtime=runtime, trainer=trainer) + created_jobs.append(job_name) + + jobs = container_backend.list_jobs() + + assert test_case.expected_status == SUCCESS + assert len(jobs) == test_case.config["num_jobs"] + if test_case.config["num_jobs"] > 0: + job_names = [job.name for job in jobs] + for created_job in created_jobs: + assert created_job in job_names + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get existing job", + expected_status=SUCCESS, + config={"num_nodes": 2}, + ), + TestCase( + name="get nonexistent job", + expected_status=FAILED, + config={"job_name": "nonexistent-job"}, + expected_error=ValueError, + ), + ], +) +def test_get_job(container_backend, test_case): + """Test getting a specific job.""" + print("Executing test:", test_case.name) + try: + if test_case.name == "get existing job": + trainer = types.CustomTrainer( + func=simple_train_func, num_nodes=test_case.config["num_nodes"] + ) + runtime = container_backend.get_runtime("torch-distributed") + job_name = container_backend.train(runtime=runtime, trainer=trainer) + + job = container_backend.get_job(job_name) + + assert test_case.expected_status == SUCCESS + assert job.name == job_name + assert job.num_nodes == test_case.config["num_nodes"] + assert len(job.steps) == test_case.config["num_nodes"] + + elif test_case.name == "get nonexistent job": + container_backend.get_job(test_case.config["job_name"]) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get logs no follow", + expected_status=SUCCESS, + config={"follow": False}, + ), + TestCase( + name="get logs with follow", + expected_status=SUCCESS, + config={"follow": True}, + ), + ], +) +def 
test_get_job_logs(container_backend, test_case): + """Test getting job logs.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + runtime = container_backend.get_runtime("torch-distributed") + job_name = container_backend.train(runtime=runtime, trainer=trainer) + + logs = list(container_backend.get_job_logs(job_name, follow=test_case.config["follow"])) + + assert test_case.expected_status == SUCCESS + assert len(logs) > 0 + if test_case.config["follow"]: + assert any("Log line" in log for log in logs) + else: + assert any("Complete log" in log for log in logs) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="wait for complete", + expected_status=SUCCESS, + config={"wait_status": constants.TRAINJOB_COMPLETE, "container_exit_code": 0}, + ), + TestCase( + name="wait timeout", + expected_status=FAILED, + config={"wait_status": constants.TRAINJOB_COMPLETE, "timeout": 2}, + expected_error=TimeoutError, + ), + TestCase( + name="job fails", + expected_status=FAILED, + config={"wait_status": constants.TRAINJOB_COMPLETE, "container_exit_code": 1}, + expected_error=RuntimeError, + ), + ], +) +def test_wait_for_job_status(container_backend, test_case): + """Test waiting for job status.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + runtime = container_backend.get_runtime("torch-distributed") + job_name = container_backend.train(runtime=runtime, trainer=trainer) + + if test_case.name == "wait for complete": + container_id = container_backend._adapter.containers_created[0]["id"] + container_backend._adapter.set_container_status( + container_id, "exited", test_case.config["container_exit_code"] + ) + + completed_job = container_backend.wait_for_job_status( + job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1 + ) + + assert test_case.expected_status == SUCCESS + assert completed_job.status == constants.TRAINJOB_COMPLETE + + elif test_case.name == "wait timeout": + container_backend.wait_for_job_status( + job_name, + status={test_case.config["wait_status"]}, + timeout=test_case.config["timeout"], + polling_interval=1, + ) + + elif test_case.name == "job fails": + container_id = container_backend._adapter.containers_created[0]["id"] + container_backend._adapter.set_container_status( + container_id, "exited", test_case.config["container_exit_code"] + ) + + container_backend.wait_for_job_status( + job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1 + ) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="delete with auto_remove true", + expected_status=SUCCESS, + config={"auto_remove": True, "num_nodes": 2}, + ), + TestCase( + name="delete with auto_remove false", + expected_status=SUCCESS, + config={"auto_remove": False, "num_nodes": 2}, + ), + ], +) +def test_delete_job(container_backend, temp_workdir, test_case): + """Test deleting a job.""" + print("Executing test:", test_case.name) + try: + container_backend.cfg.auto_remove = test_case.config["auto_remove"] + + trainer = types.CustomTrainer( + func=simple_train_func, num_nodes=test_case.config["num_nodes"] + ) + runtime = container_backend.get_runtime("torch-distributed") + job_name = 
container_backend.train(runtime=runtime, trainer=trainer) + + job_workdir = Path.home() / ".kubeflow" / "trainer" / "containers" / job_name + assert job_workdir.exists() + + container_backend.delete_job(job_name) + + assert test_case.expected_status == SUCCESS + assert len(container_backend._adapter.containers_stopped) == test_case.config["num_nodes"] + assert len(container_backend._adapter.containers_removed) == test_case.config["num_nodes"] + assert len(container_backend._adapter.networks_deleted) == 1 + + if test_case.config["auto_remove"]: + assert not job_workdir.exists() + else: + assert job_workdir.exists() + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="running container", + expected_status=SUCCESS, + config={ + "container_status": "running", + "exit_code": None, + "expected_job_status": constants.TRAINJOB_RUNNING, + }, + ), + TestCase( + name="exited success", + expected_status=SUCCESS, + config={ + "container_status": "exited", + "exit_code": 0, + "expected_job_status": constants.TRAINJOB_COMPLETE, + }, + ), + TestCase( + name="exited failure", + expected_status=SUCCESS, + config={ + "container_status": "exited", + "exit_code": 1, + "expected_job_status": constants.TRAINJOB_FAILED, + }, + ), + ], +) +def test_container_status_mapping(container_backend, test_case): + """Test container status mapping to TrainJob status.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + runtime = container_backend.get_runtime("torch-distributed") + job_name = container_backend.train(runtime=runtime, trainer=trainer) + + container_id = container_backend._adapter.containers_created[0]["id"] + container_backend._adapter.set_container_status( + container_id, test_case.config["container_status"], test_case.config["exit_code"] + ) + + job = container_backend.get_job(job_name) + + assert test_case.expected_status == SUCCESS + assert job.status == test_case.config["expected_job_status"] + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="docker socket locations with colima", + expected_status=SUCCESS, + config={ + "runtime_name": "docker", + "container_host": None, + "create_colima_socket": True, + "expected_contains_none": True, + "expected_has_colima": True, + }, + ), + TestCase( + name="custom host has priority", + expected_status=SUCCESS, + config={ + "runtime_name": "docker", + "container_host": "unix:///custom/path/docker.sock", + "create_colima_socket": False, + "expected_first": "unix:///custom/path/docker.sock", + }, + ), + ], +) +def test_get_common_socket_locations(test_case, tmp_path): + """Test common socket location detection.""" + print("Executing test:", test_case.name) + + # Setup + if test_case.config.get("create_colima_socket"): + colima_dir = tmp_path / ".colima" / "default" + colima_dir.mkdir(parents=True) + colima_sock = colima_dir / "docker.sock" + colima_sock.touch() + + cfg = ContainerBackendConfig(container_host=test_case.config["container_host"]) + + # Test the method directly without creating the backend + context_manager = ( + patch("pathlib.Path.home", return_value=tmp_path) + if test_case.config.get("create_colima_socket") + else nullcontext() + ) + + with context_manager: + backend = ContainerBackend.__new__(ContainerBackend) + backend.cfg = cfg + 
locations = backend._get_common_socket_locations(test_case.config["runtime_name"]) + + # Assertions + if "expected_contains_none" in test_case.config: + assert None in locations + + if "expected_has_colima" in test_case.config: + assert f"unix://{colima_sock}" in locations + + if "expected_first" in test_case.config: + assert locations[0] == test_case.config["expected_first"] + + print("test execution complete") + + +def test_create_adapter_error_message_format(): + """Test that error message includes attempted connections.""" + cfg = ContainerBackendConfig(container_runtime="docker") + + docker_adapter = "kubeflow.trainer.backends.container.adapters.docker.DockerClientAdapter" + with patch(docker_adapter) as mock_docker: + mock_docker.side_effect = Exception("Connection failed") + + with pytest.raises(RuntimeError) as exc_info: + ContainerBackend(cfg) + + # Error message should be helpful + error_msg = str(exc_info.value) + assert "Could not connect" in error_msg + assert "tried:" in error_msg diff --git a/kubeflow/trainer/backends/container/runtime_loader.py b/kubeflow/trainer/backends/container/runtime_loader.py new file mode 100644 index 000000000..3e614400e --- /dev/null +++ b/kubeflow/trainer/backends/container/runtime_loader.py @@ -0,0 +1,612 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Runtime loader for container backends (Docker, Podman). + +We support loading training runtime definitions from multiple sources: +1. GitHub: Fetches latest runtimes from kubeflow/trainer repository (with caching) +2. Local bundled: Falls back to `kubeflow/trainer/config/training_runtimes/` YAML files +3. User custom: Additional YAML files in the local directory + +The loader tries GitHub first (with 24-hour cache), then falls back to bundled files +if the network is unavailable or GitHub fetch fails. 
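+
+A typical lookup, as the container backend performs it (illustrative):
+
+    from kubeflow.trainer.backends.container import runtime_loader
+
+    runtime = runtime_loader.get_training_runtime_from_sources(
+        name="torch-distributed", sources=["github://kubeflow/trainer"]
+    )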
+""" + +from __future__ import annotations + +from datetime import datetime, timedelta +import json +import logging +from pathlib import Path +from typing import Any, Optional +import urllib.error +import urllib.request + +import yaml + +from kubeflow.trainer.constants import constants +from kubeflow.trainer.types import types as base_types + +logger = logging.getLogger(__name__) + +TRAINING_RUNTIMES_DIR = Path(__file__).parents[2] / "config" / "training_runtimes" +CACHE_DIR = Path.home() / ".kubeflow" / "trainer" / "cache" +CACHE_DURATION = timedelta(hours=24) + +# GitHub runtimes configuration +GITHUB_RUNTIMES_BASE_URL = ( + "https://raw.githubusercontent.com/kubeflow/trainer/master/manifests/base/runtimes" +) +GITHUB_RUNTIMES_TREE_URL = "https://github.com/kubeflow/trainer/tree/master/manifests/base/runtimes" + +__all__ = [ + "TRAINING_RUNTIMES_DIR", + "get_training_runtime_from_sources", + "list_training_runtimes_from_sources", +] + + +def _load_runtime_from_yaml(path: Path) -> dict[str, Any]: + with open(path) as f: + data: dict[str, Any] = yaml.safe_load(f) + return data + + +def _discover_github_runtime_files( + owner: str = "kubeflow", + repo: str = "trainer", + branch: str = "master", + path: str = "manifests/base/runtimes", +) -> list[str]: + """ + Discover available runtime YAML files from GitHub repository. + + Fetches the directory listing from GitHub and extracts .yaml filenames, + excluding kustomization.yaml and other non-runtime files. + + Args: + owner: GitHub repository owner (default: "kubeflow") + repo: GitHub repository name (default: "trainer") + branch: Git branch name (default: "master") + path: Path to runtimes directory (default: "manifests/base/runtimes") + + Returns: + List of YAML filenames (e.g., ['torch_distributed.yaml', ...]) + Returns empty list if discovery fails. + """ + tree_url = f"https://github.com/{owner}/{repo}/tree/{branch}/{path}" + try: + logger.debug(f"Discovering runtimes from GitHub: {tree_url}") + with urllib.request.urlopen(tree_url, timeout=5) as response: + html_content = response.read().decode("utf-8") + + # Parse HTML to find .yaml files + # Look for .yaml filenames in the HTML content + import re + + # Pattern to match .yaml files in the HTML + # Matches word characters, hyphens, underscores followed by .yaml + pattern = r"([\w-]+\.yaml)" + matches = re.findall(pattern, html_content) + + # Filter out kustomization.yaml, config files, and duplicates + # Keep only runtime files (typically named *_distributed.yaml or similar) + runtime_files = [] + seen = set() + exclude_files = {"kustomization.yaml", "golangci.yaml", "pre-commit-config.yaml"} + + for match in matches: + filename = match + if filename not in seen and filename not in exclude_files: + runtime_files.append(filename) + seen.add(filename) + + logger.debug(f"Discovered {len(runtime_files)} runtime files: {runtime_files}") + return runtime_files + + except Exception as e: + logger.debug(f"Failed to discover GitHub runtime files from {tree_url}: {e}") + return [] + + +def _fetch_runtime_from_github( + runtime_file: str, + owner: str = "kubeflow", + repo: str = "trainer", + branch: str = "master", + path: str = "manifests/base/runtimes", +) -> Optional[dict[str, Any]]: + """ + Fetch a runtime YAML from GitHub. 
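+
+    With the default arguments the file is read from, e.g.,
+    https://raw.githubusercontent.com/kubeflow/trainer/master/manifests/base/runtimes/torch_distributed.yaml.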
+ + Args: + runtime_file: YAML filename to fetch + owner: GitHub repository owner (default: "kubeflow") + repo: GitHub repository name (default: "trainer") + branch: Git branch name (default: "master") + path: Path to runtimes directory (default: "manifests/base/runtimes") + + Returns None if fetch fails (network error, timeout, etc.) + """ + url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}/{runtime_file}" + try: + logger.debug(f"Fetching runtime from GitHub: {url}") + with urllib.request.urlopen(url, timeout=5) as response: + content = response.read().decode("utf-8") + data = yaml.safe_load(content) + logger.debug(f"Successfully fetched {runtime_file} from GitHub") + return data + except (urllib.error.URLError, TimeoutError, Exception) as e: + logger.debug(f"Failed to fetch {runtime_file} from GitHub: {e}") + return None + + +def _get_cached_runtime_list() -> Optional[list[str]]: + """ + Get cached runtime file list if it exists and is not expired. + + Returns None if cache doesn't exist or is expired. + """ + if not CACHE_DIR.exists(): + return None + + cache_file = CACHE_DIR / "runtime_list.json" + + if not cache_file.exists(): + return None + + try: + with open(cache_file) as f: + data = json.load(f) + + cached_time = datetime.fromisoformat(data["cached_at"]) + if datetime.now() - cached_time > CACHE_DURATION: + logger.debug("Runtime list cache expired") + return None + + logger.debug(f"Using cached runtime list: {data['files']}") + return data["files"] + except (json.JSONDecodeError, KeyError, ValueError, Exception) as e: + logger.debug(f"Failed to read runtime list cache: {e}") + return None + + +def _cache_runtime_list(files: list[str]) -> None: + """Cache the discovered runtime file list.""" + try: + CACHE_DIR.mkdir(parents=True, exist_ok=True) + cache_file = CACHE_DIR / "runtime_list.json" + + data = { + "cached_at": datetime.now().isoformat(), + "files": files, + } + with open(cache_file, "w") as f: + json.dump(data, f) + + logger.debug(f"Cached runtime list: {files}") + except Exception as e: + logger.debug(f"Failed to cache runtime list: {e}") + + +def _get_github_runtime_files() -> list[str]: + """ + Get list of runtime files from GitHub with caching. + + Priority: + 1. Check cache (if not expired) + 2. Discover from GitHub (and cache if successful) + 3. Return empty list if both fail + """ + # Try cache first + cached = _get_cached_runtime_list() + if cached is not None: + return cached + + # Try GitHub discovery + files = _discover_github_runtime_files() + if files: + _cache_runtime_list(files) + return files + + return [] + + +def _get_cached_runtime(runtime_file: str) -> Optional[dict[str, Any]]: + """ + Get cached runtime if it exists and is not expired. + + Returns None if cache doesn't exist or is expired. 
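+
+    On disk, each cached runtime is a YAML file paired with a small JSON
+    metadata file recording when it was fetched, e.g.
+    ~/.kubeflow/trainer/cache/torch_distributed.yaml and
+    ~/.kubeflow/trainer/cache/torch_distributed.yaml.metadata.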
+ """ + if not CACHE_DIR.exists(): + return None + + cache_file = CACHE_DIR / runtime_file + metadata_file = CACHE_DIR / f"{runtime_file}.metadata" + + if not cache_file.exists() or not metadata_file.exists(): + return None + + try: + # Check if cache is expired + with open(metadata_file) as f: + metadata = json.load(f) + + cached_time = datetime.fromisoformat(metadata["cached_at"]) + if datetime.now() - cached_time > CACHE_DURATION: + logger.debug(f"Cache expired for {runtime_file}") + return None + + # Load cached runtime + with open(cache_file) as f: + data = yaml.safe_load(f) + + logger.debug(f"Using cached runtime: {runtime_file}") + return data + except (json.JSONDecodeError, KeyError, ValueError, Exception) as e: + logger.debug(f"Failed to read cache for {runtime_file}: {e}") + return None + + +def _cache_runtime(runtime_file: str, data: dict[str, Any]) -> None: + """Cache a runtime YAML with metadata.""" + try: + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + cache_file = CACHE_DIR / runtime_file + metadata_file = CACHE_DIR / f"{runtime_file}.metadata" + + # Write runtime data + with open(cache_file, "w") as f: + yaml.safe_dump(data, f) + + # Write metadata + metadata = { + "cached_at": datetime.now().isoformat(), + "source": "github", + } + with open(metadata_file, "w") as f: + json.dump(metadata, f) + + logger.debug(f"Cached runtime: {runtime_file}") + except Exception as e: + logger.debug(f"Failed to cache {runtime_file}: {e}") + + +def _load_runtime_from_github_with_cache(runtime_file: str) -> Optional[dict[str, Any]]: + """ + Load runtime from GitHub with caching. + + Priority: + 1. Check cache (if not expired) + 2. Fetch from GitHub (and cache if successful) + 3. Return None if both fail + + Args: + runtime_file: YAML filename to load + """ + # Try cache first + cached = _get_cached_runtime(runtime_file) + if cached is not None: + return cached + + # Try GitHub + data = _fetch_runtime_from_github(runtime_file) + if data is not None: + _cache_runtime(runtime_file, data) + return data + + return None + + +def _create_default_runtimes() -> list[base_types.Runtime]: + """ + Create default Runtime objects from DEFAULT_FRAMEWORK_IMAGES constant. + + Returns: + List of default Runtime objects for each framework. + """ + default_runtimes = [] + + for framework, image in constants.DEFAULT_FRAMEWORK_IMAGES.items(): + runtime = base_types.Runtime( + name=f"{framework}-distributed", + trainer=base_types.RuntimeTrainer( + trainer_type=base_types.TrainerType.CUSTOM_TRAINER, + framework=framework, + num_nodes=1, + ), + pretrained_model=None, + ) + default_runtimes.append(runtime) + logger.debug(f"Created default runtime: {runtime.name} with image {image}") + + return default_runtimes + + +def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_types.Runtime: + """ + Parse a runtime YAML dict into a Runtime object. + + Args: + data: The YAML data as a dictionary + source: Source of the YAML (for error messages) + + Returns: + Runtime object + + Raises: + ValueError: If the YAML is malformed or missing required fields + """ + # Require CRD-like schema strictly. Accept both ClusterTrainingRuntime + # and TrainingRuntime kinds. 
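+    # For reference, a minimal document that passes these checks:
+    #
+    #   kind: ClusterTrainingRuntime
+    #   metadata:
+    #     name: torch-distributed
+    #     labels:
+    #       trainer.kubeflow.org/framework: torch
+    #   spec:
+    #     mlPolicy:
+    #       numNodes: 1
+    #     template:
+    #       spec:
+    #         replicatedJobs:
+    #           - name: node
+    #             template:
+    #               spec:
+    #                 template:
+    #                   spec:
+    #                     containers:
+    #                       - image: pytorch/pytorch:2.0.0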
+ if not ( + data.get("kind") in {"ClusterTrainingRuntime", "TrainingRuntime"} and data.get("metadata") + ): + raise ValueError( + f"Runtime YAML from {source} must be a ClusterTrainingRuntime CRD-shaped document" + ) + + name = data["metadata"].get("name") + if not name: + raise ValueError(f"Runtime YAML from {source} missing metadata.name") + + labels = data["metadata"].get("labels", {}) + framework = labels.get("trainer.kubeflow.org/framework") + if not framework: + raise ValueError( + f"Runtime {name} from {source} must set " + f"metadata.labels['trainer.kubeflow.org/framework']" + ) + + spec = data.get("spec", {}) + ml_policy = spec.get("mlPolicy", {}) + num_nodes = int(ml_policy.get("numNodes", 1)) + + # Validate presence of a 'node' replicated job with a container image + templ = spec.get("template", {}).get("spec", {}) + replicated = templ.get("replicatedJobs", []) + node_jobs = [j for j in replicated if j.get("name") == "node"] + if not node_jobs: + raise ValueError( + f"Runtime {name} from {source} must define replicatedJobs with a 'node' entry" + ) + node_spec = node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {}) + containers = node_spec.get("containers", []) + if not containers or not containers[0].get("image"): + raise ValueError(f"Runtime {name} from {source} 'node' must specify containers[0].image") + + return base_types.Runtime( + name=name, + trainer=base_types.RuntimeTrainer( + trainer_type=base_types.TrainerType.CUSTOM_TRAINER, + framework=framework, + num_nodes=num_nodes, + ), + pretrained_model=None, + ) + + +def _parse_source_url(source: str) -> tuple[str, str]: + """ + Parse a source URL to determine its type and path. + + Args: + source: Source URL with scheme (github://, https://, file://, or absolute path) + + Returns: + Tuple of (source_type, path) where source_type is one of: + 'github', 'http', 'https', 'file' + """ + if source.startswith("github://"): + return ("github", source[9:]) # Remove 'github://' + elif source.startswith("https://"): + return ("https", source) + elif source.startswith("http://"): + return ("http", source) + elif source.startswith("file://"): + return ("file", source[7:]) # Remove 'file://' + elif source.startswith("/"): + # Absolute path without file:// prefix + return ("file", source) + else: + raise ValueError( + f"Unsupported source URL scheme: {source}. " + f"Supported: github://, https://, http://, file://, or absolute paths" + ) + + +def _load_from_github_url(github_path: str) -> list[base_types.Runtime]: + """ + Load runtimes from a GitHub URL (github://owner/repo[/path]). + + Args: + github_path: Path after github:// (e.g., "kubeflow/trainer" or "myorg/myrepo") + + Returns: + List of Runtime objects loaded from GitHub + """ + runtimes = [] + runtime_names_seen = set() + + # Parse the GitHub path + # Format: owner/repo[/path/to/runtimes] + parts = github_path.split("/") + if len(parts) < 2: + logger.warning(f"Invalid GitHub path format: {github_path}. 
Expected owner/repo[/path]")
+        return runtimes
+
+    owner = parts[0]
+    repo = parts[1]
+    # Custom path if provided (default to manifests/base/runtimes)
+    custom_path = "/".join(parts[2:]) if len(parts) > 2 else "manifests/base/runtimes"
+
+    # Discover runtime files from the specified GitHub repo
+    logger.debug(f"Loading runtimes from GitHub: {owner}/{repo}/{custom_path}")
+    github_runtime_files = _discover_github_runtime_files(owner=owner, repo=repo, path=custom_path)
+
+    for runtime_file in github_runtime_files:
+        try:
+            data = _fetch_runtime_from_github(
+                runtime_file, owner=owner, repo=repo, path=custom_path
+            )
+            if data is not None:
+                runtime = _parse_runtime_yaml(data, source=f"github://{github_path}/{runtime_file}")
+                if runtime.name not in runtime_names_seen:
+                    runtimes.append(runtime)
+                    runtime_names_seen.add(runtime.name)
+                    logger.debug(f"Loaded runtime from GitHub: {runtime.name}")
+        except Exception as e:
+            logger.debug(f"Failed to parse GitHub runtime {runtime_file}: {e}")
+
+    return runtimes
+
+
+def _load_from_http_url(url: str) -> list[base_types.Runtime]:
+    """
+    Load runtimes from an HTTP(S) URL.
+
+    Args:
+        url: HTTP(S) URL pointing to a single runtime YAML document
+
+    Returns:
+        List containing the parsed Runtime (empty if the fetch or parse fails)
+    """
+    runtimes = []
+
+    try:
+        logger.debug(f"Fetching runtime from HTTP: {url}")
+        with urllib.request.urlopen(url, timeout=5) as response:
+            content = response.read().decode("utf-8")
+
+        data = yaml.safe_load(content)
+        runtime = _parse_runtime_yaml(data, source=url)
+        runtimes.append(runtime)
+        logger.debug(f"Loaded runtime from HTTP: {runtime.name}")
+    except Exception as e:
+        logger.debug(f"Failed to load runtime from HTTP {url}: {e}")
+
+    return runtimes
+
+
+def _load_from_filesystem(path: str) -> list[base_types.Runtime]:
+    """
+    Load runtimes from local filesystem path.
+
+    Args:
+        path: Local filesystem path to a directory or YAML file
+
+    Returns:
+        List of Runtime objects loaded from filesystem
+    """
+    runtimes = []
+    runtime_path = Path(path).expanduser()
+
+    try:
+        if runtime_path.is_dir():
+            # Load all YAML files from directory
+            for yaml_file in sorted(runtime_path.glob("*.yaml")):
+                try:
+                    data = _load_runtime_from_yaml(yaml_file)
+                    runtime = _parse_runtime_yaml(data, source=str(yaml_file))
+                    runtimes.append(runtime)
+                    logger.debug(f"Loaded runtime from file: {runtime.name}")
+                except Exception as e:
+                    logger.warning(f"Failed to load runtime from {yaml_file}: {e}")
+        elif runtime_path.is_file():
+            # Load single YAML file
+            data = _load_runtime_from_yaml(runtime_path)
+            runtime = _parse_runtime_yaml(data, source=str(runtime_path))
+            runtimes.append(runtime)
+            logger.debug(f"Loaded runtime from file: {runtime.name}")
+        else:
+            logger.warning(f"Path does not exist: {runtime_path}")
+    except Exception as e:
+        logger.warning(f"Failed to load runtimes from {path}: {e}")
+
+    return runtimes
+
+
+def list_training_runtimes_from_sources(sources: list[str]) -> list[base_types.Runtime]:
+    """
+    List all available training runtimes from configured sources.
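+
+    Sources are consulted in order and the first runtime seen under a given
+    name wins. For example, with
+    ["file:///etc/kubeflow/runtimes", "github://kubeflow/trainer"], a local
+    torch-distributed definition takes precedence over the upstream one.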
+
+    Args:
+        sources: List of source URLs with schemes (github://, https://, http://, file://, or paths)
+
+    Returns:
+        List of Runtime objects (built-in runtimes used as default if not found in sources)
+    """
+    runtimes: list[base_types.Runtime] = []
+    runtime_names_seen = set()
+
+    # Load from each configured source in priority order
+    for source in sources:
+        try:
+            source_type, source_path = _parse_source_url(source)
+
+            if source_type == "github":
+                source_runtimes = _load_from_github_url(source_path)
+            elif source_type in ("http", "https"):
+                source_runtimes = _load_from_http_url(source)
+            elif source_type == "file":
+                source_runtimes = _load_from_filesystem(source_path)
+            else:
+                logger.warning(f"Unsupported source type: {source_type}")
+                continue
+
+            # Add runtimes, skipping duplicates
+            for runtime in source_runtimes:
+                if runtime.name not in runtime_names_seen:
+                    runtimes.append(runtime)
+                    runtime_names_seen.add(runtime.name)
+        except Exception as e:
+            logger.debug(f"Failed to load from source {source}: {e}")
+
+    # Fallback to default runtimes from constants if not found in sources
+    for default_runtime in _create_default_runtimes():
+        if default_runtime.name not in runtime_names_seen:
+            runtimes.append(default_runtime)
+            runtime_names_seen.add(default_runtime.name)
+
+    return runtimes
+
+
+def get_training_runtime_from_sources(name: str, sources: list[str]) -> base_types.Runtime:
+    """
+    Get a specific training runtime by name from configured sources.
+
+    Args:
+        name: The name of the runtime to get
+        sources: List of source URLs with schemes
+
+    Returns:
+        Runtime object
+
+    Raises:
+        ValueError: If the runtime is not found
+    """
+    runtimes = list_training_runtimes_from_sources(sources)
+    for rt in runtimes:
+        if rt.name == name:
+            return rt
+    raise ValueError(
+        f"Runtime '{name}' not found. Available runtimes: {[rt.name for rt in runtimes]}"
+    )
diff --git a/kubeflow/trainer/backends/container/runtime_loader_test.py b/kubeflow/trainer/backends/container/runtime_loader_test.py
new file mode 100644
index 000000000..ea0ec2cff
--- /dev/null
+++ b/kubeflow/trainer/backends/container/runtime_loader_test.py
@@ -0,0 +1,526 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for runtime_loader module.
+
+Tests runtime loading from various sources including GitHub, HTTP, and filesystem.
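+
+All network access is mocked (urllib and the loader's GitHub helpers), so the
+suite runs fully offline.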
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from kubeflow.trainer.backends.container import runtime_loader +from kubeflow.trainer.constants import constants +from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase +from kubeflow.trainer.types import types as base_types + +# Sample runtime YAML data for testing +SAMPLE_RUNTIME_YAML = { + "apiVersion": "trainer.kubeflow.org/v1alpha1", + "kind": "ClusterTrainingRuntime", + "metadata": { + "name": "torch-distributed", + "labels": {"trainer.kubeflow.org/framework": "torch"}, + }, + "spec": { + "mlPolicy": {"numNodes": 1}, + "template": { + "spec": { + "replicatedJobs": [ + { + "name": "node", + "template": { + "spec": { + "template": { + "spec": { + "containers": [ + { + "name": "trainer", + "image": "pytorch/pytorch:2.0.0", + } + ] + } + } + } + }, + } + ] + } + }, + }, +} + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="parse github url", + expected_status=SUCCESS, + config={ + "url": "github://kubeflow/trainer", + "expected_type": "github", + "expected_path": "kubeflow/trainer", + }, + ), + TestCase( + name="parse github url with path", + expected_status=SUCCESS, + config={ + "url": "github://myorg/myrepo/custom/path", + "expected_type": "github", + "expected_path": "myorg/myrepo/custom/path", + }, + ), + TestCase( + name="parse https url", + expected_status=SUCCESS, + config={ + "url": "https://example.com/runtime.yaml", + "expected_type": "https", + "expected_path": "https://example.com/runtime.yaml", + }, + ), + TestCase( + name="parse http url", + expected_status=SUCCESS, + config={ + "url": "http://example.com/runtime.yaml", + "expected_type": "http", + "expected_path": "http://example.com/runtime.yaml", + }, + ), + TestCase( + name="parse file url", + expected_status=SUCCESS, + config={ + "url": "file:///path/to/runtime.yaml", + "expected_type": "file", + "expected_path": "/path/to/runtime.yaml", + }, + ), + TestCase( + name="parse absolute path", + expected_status=SUCCESS, + config={ + "url": "/absolute/path/to/runtime.yaml", + "expected_type": "file", + "expected_path": "/absolute/path/to/runtime.yaml", + }, + ), + TestCase( + name="parse unsupported scheme", + expected_status=FAILED, + config={"url": "ftp://example.com/runtime.yaml"}, + expected_error=ValueError, + ), + ], +) +def test_parse_source_url(test_case): + """Test parsing various source URL formats.""" + print("Executing test:", test_case.name) + try: + source_type, path = runtime_loader._parse_source_url(test_case.config["url"]) + + assert test_case.expected_status == SUCCESS + assert source_type == test_case.config["expected_type"] + assert path == test_case.config["expected_path"] + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="load from default github", + expected_status=SUCCESS, + config={ + "github_path": "kubeflow/trainer", + "discovered_files": ["torch_distributed.yaml"], + "expected_runtime_name": "torch-distributed", + "expected_framework": "torch", + }, + ), + TestCase( + name="load from custom github", + expected_status=SUCCESS, + config={ + "github_path": "myorg/myrepo", + "discovered_files": ["custom_runtime.yaml"], + "expected_runtime_name": "custom-runtime", + "expected_framework": "custom", + }, + ), + TestCase( + name="load from github no files", + expected_status=SUCCESS, + config={ + "github_path": "kubeflow/trainer", + "discovered_files": [], + "expected_count": 0, + 
},
+        ),
+        TestCase(
+            name="load from github invalid path",
+            expected_status=SUCCESS,
+            config={
+                "github_path": "invalid",
+                "expected_count": 0,
+            },
+        ),
+    ],
+)
+def test_load_from_github_url(test_case):
+    """Test loading runtimes from GitHub URLs."""
+    print("Executing test:", test_case.name)
+    try:
+        with (
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._discover_github_runtime_files"
+            ) as mock_discover,
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._fetch_runtime_from_github"
+            ) as mock_fetch,
+        ):
+            if test_case.name == "load from github invalid path":
+                # Don't set up mocks for invalid path test
+                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
+                assert len(runtimes) == test_case.config["expected_count"]
+            else:
+                mock_discover.return_value = test_case.config.get("discovered_files", [])

+                # Create runtime YAML with custom name/framework if specified.
+                # Copy with fresh metadata so per-case edits do not leak into
+                # the shared SAMPLE_RUNTIME_YAML (dict.copy() is shallow).
+                runtime_yaml = SAMPLE_RUNTIME_YAML.copy()
+                runtime_yaml["metadata"] = {
+                    "name": SAMPLE_RUNTIME_YAML["metadata"]["name"],
+                    "labels": dict(SAMPLE_RUNTIME_YAML["metadata"]["labels"]),
+                }
+                if "expected_runtime_name" in test_case.config:
+                    runtime_yaml["metadata"]["name"] = test_case.config["expected_runtime_name"]
+                    runtime_yaml["metadata"]["labels"]["trainer.kubeflow.org/framework"] = (
+                        test_case.config["expected_framework"]
+                    )
+                mock_fetch.return_value = runtime_yaml
+
+                runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
+
+                if "expected_count" in test_case.config:
+                    assert len(runtimes) == test_case.config["expected_count"]
+                else:
+                    assert len(runtimes) == 1
+                    assert runtimes[0].name == test_case.config["expected_runtime_name"]
+                    assert runtimes[0].trainer.framework == test_case.config["expected_framework"]
+
+            assert test_case.expected_status == SUCCESS
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="priority order github sources",
+            expected_status=SUCCESS,
+            config={
+                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
+                "expected_count": 2,
+                "expected_names": ["torch-distributed", "deepspeed-distributed"],
+            },
+        ),
+        TestCase(
+            name="duplicate runtime names skipped",
+            expected_status=SUCCESS,
+            config={
+                "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
+                "duplicate_names": True,
+                "expected_count": 1,
+                "expected_names": ["torch-distributed"],
+            },
+        ),
+        TestCase(
+            name="fallback to defaults",
+            expected_status=SUCCESS,
+            config={
+                "sources": ["github://myorg/myrepo"],
+                "no_github_runtimes": True,
+                "expected_count": 1,
+                "expected_names": ["torch-distributed"],
+            },
+        ),
+    ],
+)
+def test_list_training_runtimes_from_sources(test_case):
+    """Test listing runtimes from multiple sources."""
+    print("Executing test:", test_case.name)
+    try:
+        with (
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._load_from_github_url"
+            ) as mock_github,
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._create_default_runtimes"
+            ) as mock_defaults,
+        ):
+            if test_case.name == "priority order github sources":
+                torch_runtime = base_types.Runtime(
+                    name="torch-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="torch",
+                        num_nodes=1,
+                    ),
+                )
+                deepspeed_runtime = base_types.Runtime(
+                    name="deepspeed-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="deepspeed",
+                        num_nodes=1,
+                    ),
+                )
+                mock_github.side_effect = [[torch_runtime], [deepspeed_runtime]]
+                
mock_defaults.return_value = [] + + elif test_case.name == "duplicate runtime names skipped": + torch_runtime_1 = base_types.Runtime( + name="torch-distributed", + trainer=base_types.RuntimeTrainer( + trainer_type=base_types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ) + torch_runtime_2 = base_types.Runtime( + name="torch-distributed", + trainer=base_types.RuntimeTrainer( + trainer_type=base_types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=2, + ), + ) + mock_github.side_effect = [[torch_runtime_1], [torch_runtime_2]] + mock_defaults.return_value = [] + + elif test_case.name == "fallback to defaults": + mock_github.return_value = [] + default_runtime = base_types.Runtime( + name="torch-distributed", + trainer=base_types.RuntimeTrainer( + trainer_type=base_types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ) + mock_defaults.return_value = [default_runtime] + + runtimes = runtime_loader.list_training_runtimes_from_sources( + test_case.config["sources"] + ) + + assert len(runtimes) == test_case.config["expected_count"] + runtime_names = [r.name for r in runtimes] + for expected_name in test_case.config["expected_names"]: + assert expected_name in runtime_names + + assert test_case.expected_status == SUCCESS + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +def test_create_default_runtimes(): + """Test creating default runtimes from constants.""" + print("Executing test: create default runtimes") + runtimes = runtime_loader._create_default_runtimes() + + assert len(runtimes) == len(constants.DEFAULT_FRAMEWORK_IMAGES) + + # Check torch runtime + torch_runtimes = [r for r in runtimes if r.trainer.framework == "torch"] + assert len(torch_runtimes) == 1 + assert torch_runtimes[0].name == "torch-distributed" + assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER + assert torch_runtimes[0].trainer.num_nodes == 1 + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="discover runtime files", + expected_status=SUCCESS, + config={ + "html_content": """ + + torch_distributed.yaml + deepspeed_distributed.yaml + kustomization.yaml + + """, + "expected_files": ["torch_distributed.yaml", "deepspeed_distributed.yaml"], + "excluded_files": ["kustomization.yaml"], + }, + ), + TestCase( + name="discover runtime files custom repo", + expected_status=SUCCESS, + config={ + "html_content": """ + + custom_runtime.yaml + + """, + "expected_files": ["custom_runtime.yaml"], + "owner": "myorg", + "repo": "myrepo", + "path": "custom/path", + }, + ), + TestCase( + name="discover runtime files network error", + expected_status=SUCCESS, + config={ + "network_error": True, + "expected_files": [], + }, + ), + ], +) +def test_discover_github_runtime_files(test_case): + """Test discovering runtime files from GitHub.""" + print("Executing test:", test_case.name) + try: + with patch("urllib.request.urlopen") as mock_urlopen: + if test_case.config.get("network_error"): + mock_urlopen.side_effect = Exception("Network error") + else: + mock_response = MagicMock() + mock_response.read.return_value = test_case.config["html_content"].encode("utf-8") + mock_response.__enter__.return_value = mock_response + mock_urlopen.return_value = mock_response + + kwargs = {} + if "owner" in test_case.config: + kwargs["owner"] = test_case.config["owner"] + kwargs["repo"] = test_case.config["repo"] + kwargs["path"] = 
test_case.config["path"] + + files = runtime_loader._discover_github_runtime_files(**kwargs) + + for expected_file in test_case.config["expected_files"]: + assert expected_file in files + + for excluded_file in test_case.config.get("excluded_files", []): + assert excluded_file not in files + + if "owner" in test_case.config and not test_case.config.get("network_error"): + called_url = mock_urlopen.call_args[0][0] + assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url + assert kwargs["path"] in called_url + + assert test_case.expected_status == SUCCESS + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="fetch runtime success", + expected_status=SUCCESS, + config={ + "yaml_content": """ +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: torch-distributed +""", + "expected_name": "torch-distributed", + }, + ), + TestCase( + name="fetch runtime custom repo", + expected_status=SUCCESS, + config={ + "yaml_content": """ +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: custom-runtime +""", + "expected_name": "custom-runtime", + "runtime_file": "custom.yaml", + "owner": "myorg", + "repo": "myrepo", + "path": "custom/path", + }, + ), + TestCase( + name="fetch runtime network error", + expected_status=SUCCESS, + config={ + "network_error": True, + "expected_none": True, + }, + ), + ], +) +def test_fetch_runtime_from_github(test_case): + """Test fetching runtime YAML from GitHub.""" + print("Executing test:", test_case.name) + try: + with patch("urllib.request.urlopen") as mock_urlopen: + if test_case.config.get("network_error"): + mock_urlopen.side_effect = Exception("Network error") + else: + mock_response = MagicMock() + mock_response.read.return_value = test_case.config["yaml_content"].encode("utf-8") + mock_response.__enter__.return_value = mock_response + mock_urlopen.return_value = mock_response + + default_runtime_file = "torch_distributed.yaml" + kwargs = {"runtime_file": test_case.config.get("runtime_file", default_runtime_file)} + if "owner" in test_case.config: + kwargs["owner"] = test_case.config["owner"] + kwargs["repo"] = test_case.config["repo"] + kwargs["path"] = test_case.config["path"] + + data = runtime_loader._fetch_runtime_from_github(**kwargs) + + if test_case.config.get("expected_none"): + assert data is None + else: + assert data is not None + assert data["metadata"]["name"] == test_case.config["expected_name"] + + if "owner" in test_case.config: + called_url = mock_urlopen.call_args[0][0] + assert "raw.githubusercontent.com" in called_url + assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url + assert f"{kwargs['path']}/{kwargs['runtime_file']}" in called_url + + assert test_case.expected_status == SUCCESS + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") diff --git a/kubeflow/trainer/backends/container/types.py b/kubeflow/trainer/backends/container/types.py new file mode 100644 index 000000000..f30025cb9 --- /dev/null +++ b/kubeflow/trainer/backends/container/types.py @@ -0,0 +1,67 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Types and configuration for the unified Container backend. + +This backend automatically detects and uses either Docker or Podman. +It provides a single interface for container-based execution regardless +of the underlying runtime. + +Configuration options: + - pull_policy: Controls image pulling. Supported values: "IfNotPresent", + "Always", "Never". The default is "IfNotPresent". + - auto_remove: Whether to remove containers and networks when jobs are deleted. + Defaults to True. + - container_host: Optional override for connecting to a remote/local container + daemon. By default, auto-detects from environment or uses system defaults. + For Docker: uses DOCKER_HOST or default socket. + For Podman: uses CONTAINER_HOST or default socket. + - container_runtime: Force use of a specific container runtime ("docker" or "podman"). + If not set, auto-detects based on availability (tries Docker first, then Podman). + - runtime_source: Configuration for training runtime sources using URL schemes. + Supports github://, https://, http://, file://, and absolute paths. + Built-in runtimes packaged with kubeflow-trainer are used as default fallback. +""" + +from typing import Literal, Optional + +from pydantic import BaseModel, Field + + +class TrainingRuntimeSource(BaseModel): + """Configuration for training runtime sources using URL schemes.""" + + sources: list[str] = Field( + default_factory=lambda: ["github://kubeflow/trainer"], + description=( + "Runtime sources with URL schemes (checked in priority order):\n" + " - github://owner/repo[/path] - GitHub repository\n" + " - https://url or http://url - HTTP(S) endpoint\n" + " - file:///path or /absolute/path - Local filesystem\n" + "If a runtime is not found in configured sources, built-in runtimes " + "packaged with kubeflow-trainer are used as default." + ), + ) + + +class ContainerBackendConfig(BaseModel): + pull_policy: str = Field(default="IfNotPresent") + auto_remove: bool = Field(default=True) + container_host: Optional[str] = Field(default=None) + container_runtime: Optional[Literal["docker", "podman"]] = Field(default=None) + runtime_source: TrainingRuntimeSource = Field( + default_factory=TrainingRuntimeSource, + description="Configuration for training runtime sources", + ) diff --git a/kubeflow/trainer/backends/container/utils.py b/kubeflow/trainer/backends/container/utils.py new file mode 100644 index 000000000..8642f8651 --- /dev/null +++ b/kubeflow/trainer/backends/container/utils.py @@ -0,0 +1,236 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""
+Utility functions for the Container backend.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+from kubeflow.common.constants import UNKNOWN
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.types import types
+
+logger = logging.getLogger(__name__)
+
+
+def create_workdir(job_name: str) -> str:
+    """
+    Create per-job working directory on host.
+
+    Working directories are created under ~/.kubeflow/trainer/containers/
+
+    Args:
+        job_name: Name of the training job.
+
+    Returns:
+        Absolute path to the working directory.
+    """
+    home_base = Path.home() / ".kubeflow" / "trainer" / "containers"
+    home_base.mkdir(parents=True, exist_ok=True)
+    workdir = str((home_base / job_name).resolve())
+    os.makedirs(workdir, exist_ok=True)
+    return workdir
+
+
+def get_training_script_code(trainer: types.CustomTrainer) -> str:
+    """
+    Generate the training script code from the trainer function.
+
+    This extracts the function source and appends a function call,
+    similar to how the Kubernetes backend handles training scripts.
+
+    Args:
+        trainer: CustomTrainer configuration.
+
+    Returns:
+        Complete Python code as a string to execute.
+    """
+    import inspect
+    import textwrap
+
+    code = inspect.getsource(trainer.func)
+    code = textwrap.dedent(code)
+    if trainer.func_args is None:
+        code += f"\n{trainer.func.__name__}()\n"
+    else:
+        code += f"\n{trainer.func.__name__}(**{trainer.func_args})\n"
+    return code
+
+
+def build_environment(trainer: types.CustomTrainer) -> dict[str, str]:
+    """
+    Build environment variables for containers.
+
+    Args:
+        trainer: CustomTrainer configuration.
+
+    Returns:
+        Dictionary of environment variables.
+    """
+    return dict(trainer.env or {})
+
+
+def build_pip_install_cmd(trainer: types.CustomTrainer) -> str:
+    """
+    Build pip install command for packages.
+
+    Args:
+        trainer: CustomTrainer configuration.
+
+    Returns:
+        Pip install command string (empty if no packages to install).
+    """
+    pkgs = trainer.packages_to_install or []
+    if not pkgs:
+        return ""
+
+    index_urls = trainer.pip_index_urls or list(constants.DEFAULT_PIP_INDEX_URLS)
+    main_idx = index_urls[0]
+    extras = " ".join(f"--extra-index-url {u}" for u in index_urls[1:])
+    quoted = " ".join(f'"{p}"' for p in pkgs)
+    return (
+        "PIP_DISABLE_PIP_VERSION_CHECK=1 pip install --no-warn-script-location "
+        f"--index-url {main_idx} {extras} {quoted} && "
+    )
+
+
+def container_status_to_trainjob_status(status: str, exit_code: int) -> str:
+    """
+    Convert container status to TrainJob status.
+
+    Args:
+        status: Container status (e.g., "running", "exited", "created").
+        exit_code: Container exit code.
+
+    Returns:
+        TrainJob status constant.
+    """
+    if status == "running":
+        return constants.TRAINJOB_RUNNING
+    if status == "created":
+        return constants.TRAINJOB_CREATED
+    if status == "exited":
+        # Exit code 0 -> complete, else failed
+        return constants.TRAINJOB_COMPLETE if exit_code == 0 else constants.TRAINJOB_FAILED
+    return UNKNOWN
+
+
+def aggregate_status_from_containers(container_statuses: list[str]) -> str:
+    """
+    Aggregate status from multiple container statuses.
+
+    Args:
+        container_statuses: List of container status strings.
+
+    Returns:
+        Aggregated TrainJob status.
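+
+    For example, [Failed, Running] aggregates to Failed and [Running, Complete]
+    to Running; Complete is reported only once every known status is Complete.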
+    """
+    if constants.TRAINJOB_FAILED in container_statuses:
+        return constants.TRAINJOB_FAILED
+    if constants.TRAINJOB_RUNNING in container_statuses:
+        return constants.TRAINJOB_RUNNING
+    # Guard against an empty or all-unknown list: all() over an empty sequence
+    # is True, which would otherwise report such a job as Complete.
+    known = [s for s in container_statuses if s != UNKNOWN]
+    if known and all(s == constants.TRAINJOB_COMPLETE for s in known):
+        return constants.TRAINJOB_COMPLETE
+    if any(s == constants.TRAINJOB_CREATED for s in container_statuses):
+        return constants.TRAINJOB_CREATED
+    return UNKNOWN
+
+
+def resolve_image(runtime: types.Runtime) -> str:
+    """
+    Resolve the container image for a runtime from DEFAULT_FRAMEWORK_IMAGES.
+
+    Args:
+        runtime: Runtime object.
+
+    Returns:
+        Container image name.
+
+    Raises:
+        ValueError: If no image is found for the runtime's framework.
+    """
+    framework = runtime.trainer.framework
+    if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
+        return constants.DEFAULT_FRAMEWORK_IMAGES[framework]
+
+    raise ValueError(
+        f"No default image found for framework '{framework}'. "
+        f"Supported frameworks: {list(constants.DEFAULT_FRAMEWORK_IMAGES.keys())}"
+    )
+
+
+def maybe_pull_image(adapter, image: str, pull_policy: str):
+    """
+    Pull image based on pull policy.
+
+    Args:
+        adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+        image: Container image name.
+        pull_policy: Pull policy ("IfNotPresent", "Always", or "Never").
+
+    Raises:
+        RuntimeError: If image is not found or pull fails.
+    """
+    policy = pull_policy.lower()
+    try:
+        if policy == "never":
+            if not adapter.image_exists(image):
+                raise RuntimeError(f"Image '{image}' not found locally and pull policy is Never")
+            return
+        if policy == "always":
+            logger.debug(f"Pulling image (Always): {image}")
+            adapter.pull_image(image)
+            return
+        # IfNotPresent
+        if not adapter.image_exists(image):
+            logger.debug(f"Pulling image (IfNotPresent): {image}")
+            adapter.pull_image(image)
+    except Exception as e:
+        raise RuntimeError(f"Failed to ensure image '{image}': {e}") from e
+
+
+def get_container_status(adapter, container_id: str) -> str:
+    """
+    Get the TrainJob status of a container.
+
+    Args:
+        adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+        container_id: Container ID.
+
+    Returns:
+        TrainJob status constant.
+    """
+    try:
+        status, exit_code = adapter.container_status(container_id)
+        return container_status_to_trainjob_status(status, exit_code)
+    except Exception:
+        return UNKNOWN
+
+
+def aggregate_container_statuses(adapter, containers: list[dict]) -> str:
+    """
+    Aggregate TrainJob status from container info dicts.
+
+    Args:
+        adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+        containers: List of container info dicts with 'id' key.
+
+    Returns:
+        Aggregated TrainJob status.
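+
+    This is a thin composition: each container's status is resolved via
+    get_container_status() and the results are folded through
+    aggregate_status_from_containers().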
+ """ + statuses = [get_container_status(adapter, c["id"]) for c in containers] + return aggregate_status_from_containers(statuses) diff --git a/kubeflow/trainer/constants/constants.py b/kubeflow/trainer/constants/constants.py index 5356bc5a2..a15d402ae 100644 --- a/kubeflow/trainer/constants/constants.py +++ b/kubeflow/trainer/constants/constants.py @@ -162,3 +162,8 @@ # The Instruct Datasets class in torchtune TORCH_TUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset" + +# Default container images for each framework (used as fallback when runtime not provided) +DEFAULT_FRAMEWORK_IMAGES = { + "torch": "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime", +} diff --git a/pyproject.toml b/pyproject.toml index bcf11c1c1..570d16f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,14 @@ dependencies = [ "kubeflow-katib-api>=0.19.0", ] +[project.optional-dependencies] +docker = [ + "docker>=6.1.3", +] +podman = [ + "podman>=5.6.0" +] + [dependency-groups] dev = [ "pytest>=7.0", @@ -100,7 +108,9 @@ select = [ ] ignore = [ - "B006", # mutable-argument-default + "B006", # mutable-argument-default + "UP007", # Use X | Y instead of Union[X, Y] (requires Python 3.10+) + "UP045", # Use X | None instead of Optional[X] (requires Python 3.10+) ] diff --git a/uv.lock b/uv.lock index 37b6e544b..64adf4957 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" [[package]] @@ -341,6 +341,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "durationpy" version = "0.10" @@ -422,6 +436,14 @@ dependencies = [ { name = "pydantic" }, ] +[package.optional-dependencies] +docker = [ + { name = "docker" }, +] +podman = [ + { name = "podman" }, +] + [package.dev-dependencies] dev = [ { name = "coverage" }, @@ -437,10 +459,13 @@ dev = [ [package.metadata] requires-dist = [ { name = "kubeflow-katib-api", specifier = ">=0.19.0" }, + { name = "docker", marker = "extra == 'docker'", specifier = ">=6.1.3" }, { name = "kubeflow-trainer-api", specifier = ">=2.0.0" }, { name = "kubernetes", specifier = ">=27.2.0" }, + { name = "podman", marker = "extra == 'podman'", specifier = ">=5.6.0" }, { name = "pydantic", specifier = ">=2.10.0" }, ] +provides-extras = ["docker", "podman"] [package.metadata.requires-dev] dev = [ @@ -537,6 +562,20 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "podman" +version = "5.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/36/070e7bf682ac0868450584df79198c178323e80f73b8fb9b6fec8bde0a65/podman-5.6.0.tar.gz", hash = "sha256:cc5f7aa9562e30f992fc170a48da970a7132be60d8a2e2941e6c17bd0a0b35c9", size = 72832, upload-time = "2025-09-05T09:42:40.071Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/9e/8c62f05b104d9f00edbb4c298b152deceb393ea67f0288d89d1139d7a859/podman-5.6.0-py3-none-any.whl", hash = "sha256:967ff8ad8c6b851bc5da1a9410973882d80e235a9410b7d1e931ce0c3324fbe3", size = 88713, upload-time = "2025-09-05T09:42:38.405Z" }, +] + [[package]] name = "pre-commit" version = "4.3.0" @@ -808,6 +847,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, + { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, + { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", 
hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, + { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, + { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, + { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, + { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, + { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, + { url = "https://files.pythonhosted.org/packages/59/42/b86689aac0cdaee7ae1c58d464b0ff04ca909c19bb6502d4973cdd9f9544/pywin32-311-cp39-cp39-win32.whl", hash = "sha256:aba8f82d551a942cb20d4a83413ccbac30790b50efb89a75e4f586ac0bb8056b", size = 8760837, upload-time = "2025-07-14T20:12:59.59Z" }, + { url = "https://files.pythonhosted.org/packages/9f/8a/1403d0353f8c5a2f0829d2b1c4becbf9da2f0a4d040886404fc4a5431e4d/pywin32-311-cp39-cp39-win_amd64.whl", hash = "sha256:e0c4cfb0621281fe40387df582097fd796e80430597cb9944f0ae70447bacd91", size = 9590187, upload-time = "2025-07-14T20:13:01.419Z" }, + { url = "https://files.pythonhosted.org/packages/60/22/e0e8d802f124772cec9c75430b01a212f86f9de7546bda715e54140d5aeb/pywin32-311-cp39-cp39-win_arm64.whl", hash = "sha256:62ea666235135fee79bb154e695f3ff67370afefd71bd7fea7512fc70ef31e3d", size = 8778162, upload-time = "2025-07-14T20:13:03.544Z" }, +] + [[package]] name = "pyyaml" version = "6.0.2"