diff --git a/README.md b/README.md
index bf9a4eeed..7e94d6535 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,29 @@ TrainerClient().wait_for_job_status(job_id)
print("\n".join(TrainerClient().get_job_logs(name=job_id)))
```
+## Local Development
+
+Kubeflow Trainer client supports local development without needing a Kubernetes cluster.
+
+### Available Backends
+
+- **KubernetesBackend** (default) - Production training on Kubernetes
+- **ContainerBackend** - Local development with Docker/Podman isolation
+- **LocalProcessBackend** - Quick prototyping with Python subprocesses
+
+**Quick Start:**
+Install container support: `pip install kubeflow[docker]` or `pip install kubeflow[podman]`
+
+```python
+from kubeflow.trainer import TrainerClient, ContainerBackendConfig, CustomTrainer
+
+# Switch to local container execution
+client = TrainerClient(backend_config=ContainerBackendConfig())
+
+# Your training runs locally in isolated containers
+job_id = client.train(trainer=CustomTrainer(func=train_fn))
+```
+
## Supported Kubeflow Projects
| Project | Status | Version Support | Description |
diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py
index 1fce8e0f5..a61867b6a 100644
--- a/kubeflow/trainer/__init__.py
+++ b/kubeflow/trainer/__init__.py
@@ -18,6 +18,10 @@
# Import the Kubeflow Trainer client.
from kubeflow.trainer.api.trainer_client import TrainerClient
+from kubeflow.trainer.backends.container.types import (
+ ContainerBackendConfig,
+ TrainingRuntimeSource,
+)
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
# Import the Kubeflow Trainer constants.
@@ -64,5 +68,7 @@
"TrainerClient",
"TrainerType",
"LocalProcessBackendConfig",
+ "ContainerBackendConfig",
"KubernetesBackendConfig",
+ "TrainingRuntimeSource",
]
diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py
index cb96837dd..79613f541 100644
--- a/kubeflow/trainer/api/trainer_client.py
+++ b/kubeflow/trainer/api/trainer_client.py
@@ -17,6 +17,8 @@
from typing import Optional, Union
from kubeflow.common.types import KubernetesBackendConfig
+from kubeflow.trainer.backends.container.backend import ContainerBackend
+from kubeflow.trainer.backends.container.types import ContainerBackendConfig
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.localprocess.backend import (
LocalProcessBackend,
@@ -31,14 +33,19 @@
class TrainerClient:
def __init__(
self,
- backend_config: Optional[Union[KubernetesBackendConfig, LocalProcessBackendConfig]] = None,
+        backend_config: Optional[Union[
+            KubernetesBackendConfig,
+            LocalProcessBackendConfig,
+            ContainerBackendConfig,
+        ]] = None,
):
"""Initialize a Kubeflow Trainer client.
Args:
- backend_config: Backend configuration. Either KubernetesBackendConfig or
- LocalProcessBackendConfig, or None to use the backend's
- default config class. Defaults to KubernetesBackendConfig.
+ backend_config: Backend configuration. Either KubernetesBackendConfig,
+ LocalProcessBackendConfig, ContainerBackendConfig,
+ or None to use the backend's default config class.
+ Defaults to KubernetesBackendConfig.
Raises:
ValueError: Invalid backend configuration.
@@ -52,6 +59,8 @@ def __init__(
self.backend = KubernetesBackend(backend_config)
elif isinstance(backend_config, LocalProcessBackendConfig):
self.backend = LocalProcessBackend(backend_config)
+ elif isinstance(backend_config, ContainerBackendConfig):
+ self.backend = ContainerBackend(backend_config)
else:
raise ValueError(f"Invalid backend config '{backend_config}'")
diff --git a/kubeflow/trainer/backends/container/adapters/base.py b/kubeflow/trainer/backends/container/adapters/base.py
new file mode 100644
index 000000000..3e38a6b89
--- /dev/null
+++ b/kubeflow/trainer/backends/container/adapters/base.py
@@ -0,0 +1,195 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Container client adapters for Docker and Podman.
+
+This module implements the adapter pattern to abstract away differences between
+Docker and Podman APIs, allowing the backend to work with either runtime through
+a common interface.
+"""
+
+from __future__ import annotations
+
+import abc
+from collections.abc import Iterator
+from typing import Optional
+
+
+class BaseContainerClientAdapter(abc.ABC):
+ """
+ Abstract adapter interface for container clients.
+
+ This adapter abstracts the container runtime API, allowing the backend
+ to work with Docker and Podman through a unified interface.
+ """
+
+ @abc.abstractmethod
+ def ping(self):
+ """Test the connection to the container runtime."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def create_network(
+ self,
+ name: str,
+ labels: dict[str, str],
+ ) -> str:
+ """
+ Create a container network.
+
+ Args:
+ name: Network name
+ labels: Labels to attach to the network
+
+ Returns:
+ Network ID or name
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def delete_network(self, network_id: str):
+ """Delete a network."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def create_and_start_container(
+ self,
+ image: str,
+ command: list[str],
+ name: str,
+ network_id: str,
+ environment: dict[str, str],
+ labels: dict[str, str],
+ volumes: dict[str, dict[str, str]],
+ working_dir: str,
+ ) -> str:
+ """
+ Create and start a container.
+
+ Args:
+ image: Container image
+ command: Command to run
+ name: Container name
+ network_id: Network to attach to
+ environment: Environment variables
+ labels: Container labels
+ volumes: Volume mounts
+ working_dir: Working directory
+
+ Returns:
+ Container ID
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_container(self, container_id: str):
+ """Get container object by ID."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def container_logs(self, container_id: str, follow: bool) -> Iterator[str]:
+ """Stream logs from a container."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def stop_container(self, container_id: str, timeout: int = 10):
+ """Stop a container."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def remove_container(self, container_id: str, force: bool = True):
+ """Remove a container."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def pull_image(self, image: str):
+ """Pull an image."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def image_exists(self, image: str) -> bool:
+ """Check if an image exists locally."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def run_oneoff_container(self, image: str, command: list[str]) -> str:
+ """
+ Run a short-lived container and return its output.
+
+ Args:
+ image: Container image
+ command: Command to run
+
+ Returns:
+ Container output as string
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def container_status(self, container_id: str) -> tuple[str, Optional[int]]:
+ """
+ Get container status.
+
+ Returns:
+ Tuple of (status_string, exit_code)
+ Status strings: "running", "created", "exited", etc.
+ Exit code is None if container hasn't exited
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]:
+ """
+ Get container's IP address on a specific network.
+
+ Args:
+ container_id: Container ID
+ network_id: Network name or ID
+
+ Returns:
+ IP address string or None if not found
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]:
+ """
+ List containers, optionally filtered by labels.
+
+ Args:
+ filters: Dictionary of filters (e.g., {"label": ["key=value"]})
+
+ Returns:
+ List of container info dictionaries with keys:
+ - id: Container ID
+ - name: Container name
+ - labels: Dictionary of labels
+ - status: Container status
+ - created: Creation timestamp
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_network(self, network_id: str) -> Optional[dict]:
+ """
+ Get network information by ID or name.
+
+ Args:
+ network_id: Network ID or name
+
+ Returns:
+ Dictionary with network info including labels, or None if not found
+ """
+ raise NotImplementedError()
diff --git a/kubeflow/trainer/backends/container/adapters/docker.py b/kubeflow/trainer/backends/container/adapters/docker.py
new file mode 100644
index 000000000..c4c0e1205
--- /dev/null
+++ b/kubeflow/trainer/backends/container/adapters/docker.py
@@ -0,0 +1,229 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Docker client adapter implementation.
+
+This module provides the DockerClientAdapter class that implements the
+BaseContainerClientAdapter interface for Docker runtime.
+"""
+
+from collections.abc import Iterator
+from typing import Optional
+
+from kubeflow.trainer.backends.container.adapters.base import BaseContainerClientAdapter
+
+
+class DockerClientAdapter(BaseContainerClientAdapter):
+ """Adapter for Docker client."""
+
+ def __init__(self, host: Optional[str] = None):
+ """
+ Initialize Docker client.
+
+ Args:
+ host: Docker host URL, or None to use environment defaults
+ """
+ try:
+ import docker # type: ignore
+ except ImportError as e:
+ raise ImportError(
+ "The 'docker' Python package is not installed. Install with extras: "
+ "pip install kubeflow[docker]"
+ ) from e
+
+ if host:
+ self.client = docker.DockerClient(base_url=host)
+ else:
+ self.client = docker.from_env()
+
+ self._runtime_type = "docker"
+
+ def ping(self):
+ """Test connection to Docker daemon."""
+ self.client.ping()
+
+ def create_network(self, name: str, labels: dict[str, str]) -> str:
+ """Create a Docker network."""
+ try:
+ self.client.networks.get(name)
+ return name
+ except Exception:
+ pass
+
+ self.client.networks.create(
+ name=name,
+ check_duplicate=True,
+ labels=labels,
+ )
+ return name
+
+ def delete_network(self, network_id: str):
+ """Delete Docker network."""
+ try:
+ net = self.client.networks.get(network_id)
+ net.remove()
+ except Exception:
+ pass
+
+ def create_and_start_container(
+ self,
+ image: str,
+ command: list[str],
+ name: str,
+ network_id: str,
+ environment: dict[str, str],
+ labels: dict[str, str],
+ volumes: dict[str, dict[str, str]],
+ working_dir: str,
+ ) -> str:
+ """Create and start a Docker container."""
+ container = self.client.containers.run(
+ image=image,
+ command=tuple(command),
+ name=name,
+ detach=True,
+ working_dir=working_dir,
+ network=network_id,
+ environment=environment,
+ labels=labels,
+ volumes=volumes,
+ auto_remove=False,
+ )
+ return container.id
+
+ def get_container(self, container_id: str):
+ """Get Docker container by ID."""
+ return self.client.containers.get(container_id)
+
+ def container_logs(self, container_id: str, follow: bool) -> Iterator[str]:
+ """Stream logs from Docker container."""
+ container = self.get_container(container_id)
+ logs = container.logs(stream=bool(follow), follow=bool(follow))
+ if follow:
+ for chunk in logs:
+ if isinstance(chunk, bytes):
+ yield chunk.decode("utf-8", errors="ignore")
+ else:
+ yield str(chunk)
+ else:
+ if isinstance(logs, bytes):
+ yield logs.decode("utf-8", errors="ignore")
+ else:
+ yield str(logs)
+
+ def stop_container(self, container_id: str, timeout: int = 10):
+ """Stop Docker container."""
+ container = self.get_container(container_id)
+ container.stop(timeout=timeout)
+
+ def remove_container(self, container_id: str, force: bool = True):
+ """Remove Docker container."""
+ container = self.get_container(container_id)
+ container.remove(force=force)
+
+ def pull_image(self, image: str):
+ """Pull Docker image."""
+ self.client.images.pull(image)
+
+ def image_exists(self, image: str) -> bool:
+ """Check if Docker image exists locally."""
+ try:
+ self.client.images.get(image)
+ return True
+ except Exception:
+ return False
+
+ def run_oneoff_container(self, image: str, command: list[str]) -> str:
+ """Run a short-lived Docker container and return output."""
+ try:
+ output = self.client.containers.run(
+ image=image,
+ command=tuple(command),
+ detach=False,
+ remove=True,
+ )
+ if isinstance(output, (bytes, bytearray)):
+ return output.decode("utf-8", errors="ignore")
+ return str(output)
+ except Exception as e:
+ raise RuntimeError(f"One-off container failed to run: {e}") from e
+
+ def container_status(self, container_id: str) -> tuple[str, Optional[int]]:
+ """Get Docker container status."""
+ try:
+ container = self.get_container(container_id)
+ status = container.status
+ # Get exit code if container has exited
+ exit_code = None
+ if status == "exited":
+ inspect = container.attrs if hasattr(container, "attrs") else container.inspect()
+ exit_code = inspect.get("State", {}).get("ExitCode")
+ return (status, exit_code)
+ except Exception:
+ return ("unknown", None)
+
+ def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]:
+ """Get container's IP address on a specific network."""
+ try:
+ container = self.get_container(container_id)
+ # Refresh container info
+ container.reload()
+ # Get network settings
+ networks = container.attrs.get("NetworkSettings", {}).get("Networks", {})
+
+ # Try to find the network by exact name or ID
+ if network_id in networks:
+ return networks[network_id].get("IPAddress")
+
+ # Fallback: return first available IP
+ for _net_name, net_info in networks.items():
+ ip = net_info.get("IPAddress")
+ if ip:
+ return ip
+
+ return None
+ except Exception:
+ return None
+
+    def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]:
+ """List Docker containers with optional filters."""
+ try:
+ containers = self.client.containers.list(all=True, filters=filters)
+ result = []
+ for c in containers:
+ result.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "labels": c.labels,
+ "status": c.status,
+ "created": c.attrs.get("Created", ""),
+ }
+ )
+ return result
+ except Exception:
+ return []
+
+ def get_network(self, network_id: str) -> Optional[dict]:
+ """Get Docker network information."""
+ try:
+ network = self.client.networks.get(network_id)
+ return {
+ "id": network.id,
+ "name": network.name,
+ "labels": network.attrs.get("Labels", {}),
+ }
+ except Exception:
+ return None
diff --git a/kubeflow/trainer/backends/container/adapters/podman.py b/kubeflow/trainer/backends/container/adapters/podman.py
new file mode 100644
index 000000000..72274e344
--- /dev/null
+++ b/kubeflow/trainer/backends/container/adapters/podman.py
@@ -0,0 +1,247 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Podman client adapter implementation.
+
+This module provides the PodmanClientAdapter class that implements the
+BaseContainerClientAdapter interface for Podman runtime.
+
+Key differences from Docker:
+- Uses DNS-enabled bridge networks for better container name resolution
+- GPU support via CDI (Container Device Interface) instead of NVIDIA Container Toolkit
+- Slightly different API for some operations (e.g., container.create + start pattern)
+"""
+
+from collections.abc import Iterator
+from typing import Optional
+
+from kubeflow.trainer.backends.container.adapters.base import BaseContainerClientAdapter
+
+
+class PodmanClientAdapter(BaseContainerClientAdapter):
+ """Adapter for Podman client."""
+
+ def __init__(self, host: Optional[str] = None):
+ """
+ Initialize Podman client.
+
+ Args:
+ host: Podman host URL, or None to use environment defaults
+ """
+ try:
+ import podman # type: ignore
+ except ImportError as e:
+ raise ImportError(
+ "The 'podman' Python package is not installed. Install with extras: "
+ "pip install kubeflow[podman]"
+ ) from e
+
+ if host:
+ self.client = podman.PodmanClient(base_url=host)
+ else:
+ self.client = podman.PodmanClient()
+
+ self._runtime_type = "podman"
+
+ def ping(self):
+ """Test connection to Podman."""
+ self.client.ping()
+
+ def create_network(self, name: str, labels: dict[str, str]) -> str:
+ """Create a Podman network with DNS enabled."""
+ try:
+ self.client.networks.get(name)
+ return name
+ except Exception:
+ pass
+
+ self.client.networks.create(
+ name=name,
+ driver="bridge",
+ dns_enabled=True,
+ labels=labels,
+ )
+ return name
+
+ def delete_network(self, network_id: str):
+ """Delete Podman network."""
+ try:
+ net = self.client.networks.get(network_id)
+ net.remove()
+ except Exception:
+ pass
+
+ def create_and_start_container(
+ self,
+ image: str,
+ command: list[str],
+ name: str,
+ network_id: str,
+ environment: dict[str, str],
+ labels: dict[str, str],
+ volumes: dict[str, dict[str, str]],
+ working_dir: str,
+ ) -> str:
+ """Create and start a Podman container."""
+ container = self.client.containers.run(
+ image=image,
+ command=command,
+ name=name,
+ network=network_id,
+ working_dir=working_dir,
+ environment=environment,
+ labels=labels,
+ volumes=volumes,
+ detach=True,
+ remove=False,
+ )
+ return container.id
+
+ def get_container(self, container_id: str):
+ """Get Podman container by ID."""
+ return self.client.containers.get(container_id)
+
+ def container_logs(self, container_id: str, follow: bool) -> Iterator[str]:
+ """Stream logs from Podman container."""
+ container = self.get_container(container_id)
+ logs = container.logs(stream=bool(follow), follow=bool(follow))
+ if follow:
+ for chunk in logs:
+ if isinstance(chunk, bytes):
+ yield chunk.decode("utf-8", errors="ignore")
+ else:
+ yield str(chunk)
+ else:
+ if isinstance(logs, bytes):
+ yield logs.decode("utf-8", errors="ignore")
+ else:
+ yield str(logs)
+
+ def stop_container(self, container_id: str, timeout: int = 10):
+ """Stop Podman container."""
+ container = self.get_container(container_id)
+ container.stop(timeout=timeout)
+
+ def remove_container(self, container_id: str, force: bool = True):
+ """Remove Podman container."""
+ container = self.get_container(container_id)
+ container.remove(force=force)
+
+ def pull_image(self, image: str):
+ """Pull Podman image."""
+ self.client.images.pull(image)
+
+ def image_exists(self, image: str) -> bool:
+ """Check if Podman image exists locally."""
+ try:
+ self.client.images.get(image)
+ return True
+ except Exception:
+ return False
+
+ def run_oneoff_container(self, image: str, command: list[str]) -> str:
+ """Run a short-lived Podman container and return output."""
+        try:
+            container = self.client.containers.create(
+                image=image,
+                command=command,
+            )
+            container.start()
+            container.wait()
+            # podman-py logs() yields byte chunks; collect them before removing
+            raw = b"".join(
+                chunk if isinstance(chunk, bytes) else chunk.encode()
+                for chunk in container.logs(stdout=True, stderr=True)
+            )
+            container.remove(force=True)
+            return raw.decode("utf-8", errors="ignore")
+        except Exception as e:
+            raise RuntimeError(f"One-off container failed to run: {e}") from e
+
+ def container_status(self, container_id: str) -> tuple[str, Optional[int]]:
+ """Get Podman container status."""
+ try:
+ container = self.get_container(container_id)
+ status = container.status
+ # Get exit code if container has exited
+ exit_code = None
+ if status == "exited":
+ inspect = container.attrs if hasattr(container, "attrs") else container.inspect()
+ exit_code = inspect.get("State", {}).get("ExitCode")
+ return (status, exit_code)
+ except Exception:
+ return ("unknown", None)
+
+ def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]:
+ """Get container's IP address on a specific network."""
+ try:
+ container = self.get_container(container_id)
+ # Get container inspect data
+ inspect = container.attrs if hasattr(container, "attrs") else container.inspect()
+
+ # Get network settings - Podman structure is similar to Docker
+ networks = inspect.get("NetworkSettings", {}).get("Networks", {})
+
+ # Try to find the network by exact name or ID
+ if network_id in networks:
+ return networks[network_id].get("IPAddress")
+
+ # Fallback: return first available IP
+ for _net_name, net_info in networks.items():
+ ip = net_info.get("IPAddress")
+ if ip:
+ return ip
+
+ return None
+ except Exception:
+ return None
+
+    def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]:
+ """List Podman containers with optional filters."""
+ try:
+ containers = self.client.containers.list(all=True, filters=filters)
+ result = []
+ for c in containers:
+ inspect = c.attrs if hasattr(c, "attrs") else c.inspect()
+ labels = (
+ c.labels
+ if hasattr(c, "labels")
+ else inspect.get("Config", {}).get("Labels", {})
+ )
+ result.append(
+ {
+ "id": c.id,
+ "name": c.name,
+ "labels": labels,
+ "status": c.status,
+ "created": inspect.get("Created", ""),
+ }
+ )
+ return result
+ except Exception:
+ return []
+
+ def get_network(self, network_id: str) -> Optional[dict]:
+ """Get Podman network information."""
+ try:
+ network = self.client.networks.get(network_id)
+ inspect = network.attrs if hasattr(network, "attrs") else network.inspect()
+ return {
+ "id": inspect.get("ID", network_id),
+ "name": inspect.get("Name", network_id),
+ "labels": inspect.get("Labels", {}),
+ }
+ except Exception:
+ return None
diff --git a/kubeflow/trainer/backends/container/backend.py b/kubeflow/trainer/backends/container/backend.py
new file mode 100644
index 000000000..7408d83aa
--- /dev/null
+++ b/kubeflow/trainer/backends/container/backend.py
@@ -0,0 +1,646 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+ContainerBackend
+----------------
+
+Unified local execution backend for `CustomTrainer` jobs using containers.
+
+This backend automatically detects and uses either Docker or Podman.
+It provides a single interface regardless of the underlying container runtime.
+
+Key behaviors:
+- Auto-detection: Tries Docker first, then Podman. Can be overridden via config.
+- Multi-node jobs: one container per node connected via a per-job network.
+- Entry script generation: we serialize the user's training function and embed it
+ inline in the container command using a heredoc (no file I/O on the host). The
+ script is created inside the container at /tmp/train.py and invoked using
+ `torchrun` (preferred) or `python` as a fallback.
+- Runtimes: we use `config/training_runtimes` to define runtime images and
+ characteristics (e.g., torch). Defaults to `torch-distributed` if no runtime
+ is provided.
+- Image pulling: controlled via `pull_policy` and performed automatically if
+ needed.
+- Logs and lifecycle: streaming logs and deletion semantics similar to the
+ Docker/Podman backends, but with automatic runtime detection.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from datetime import datetime
+import logging
+import os
+import random
+import shutil
+import string
+from typing import Optional, Union
+import uuid
+
+from kubeflow.trainer.backends.base import RuntimeBackend
+from kubeflow.trainer.backends.container import utils as container_utils
+from kubeflow.trainer.backends.container.adapters.base import (
+ BaseContainerClientAdapter,
+)
+from kubeflow.trainer.backends.container.adapters.docker import DockerClientAdapter
+from kubeflow.trainer.backends.container.adapters.podman import PodmanClientAdapter
+from kubeflow.trainer.backends.container.runtime_loader import (
+ get_training_runtime_from_sources,
+ list_training_runtimes_from_sources,
+)
+from kubeflow.trainer.backends.container.types import ContainerBackendConfig
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.types import types
+
+logger = logging.getLogger(__name__)
+
+
+class ContainerBackend(RuntimeBackend):
+ """
+ Unified container backend that auto-detects Docker or Podman.
+
+ This backend uses the adapter pattern to abstract away differences between
+ Docker and Podman, providing a single consistent interface.
+ """
+
+ def __init__(self, cfg: ContainerBackendConfig):
+ self.cfg = cfg
+ self.label_prefix = "trainer.kubeflow.org"
+
+ # Initialize the container client adapter
+ self._adapter = self._create_adapter()
+
+ def _get_common_socket_locations(self, runtime_name: str) -> list[Optional[str]]:
+ """
+ Get common socket locations to try for the given runtime.
+
+ Args:
+ runtime_name: "docker" or "podman"
+
+ Returns:
+ List of socket URLs to try, including None (for default)
+ """
+ import os
+ from pathlib import Path
+
+ locations = [self.cfg.container_host] if self.cfg.container_host else []
+
+ if runtime_name == "docker":
+ # Common Docker socket locations
+ colima_sock = Path.home() / ".colima/default/docker.sock"
+ if colima_sock.exists():
+ locations.append(f"unix://{colima_sock}")
+ # Standard Docker socket
+ locations.append(None) # Use docker.from_env() default
+
+ elif runtime_name == "podman":
+ # Common Podman socket locations on macOS
+ uid = os.getuid() if hasattr(os, "getuid") else None
+        if uid is not None:
+ user_sock = f"/run/user/{uid}/podman/podman.sock"
+ if Path(user_sock).exists():
+ locations.append(f"unix://{user_sock}")
+ # Standard Podman socket
+ locations.append(None) # Use PodmanClient() default
+
+ # Remove duplicates while preserving order
+ seen = set()
+ unique_locations = []
+ for loc in locations:
+ if loc not in seen:
+ unique_locations.append(loc)
+ seen.add(loc)
+
+ return unique_locations
+
+ def _create_adapter(self) -> BaseContainerClientAdapter:
+ """
+ Create the appropriate container client adapter.
+
+ Tries Docker first, then Podman if Docker fails, unless a specific
+ runtime is requested in the config. Automatically tries common socket
+ locations (e.g., Colima for Docker on macOS, user socket for Podman).
+
+ Raises RuntimeError if neither Docker nor Podman are available.
+ """
+ runtime_map = {
+ "docker": DockerClientAdapter,
+ "podman": PodmanClientAdapter,
+ }
+
+ # Determine which runtimes to try
+ runtimes_to_try = (
+ [self.cfg.container_runtime] if self.cfg.container_runtime else ["docker", "podman"]
+ )
+
+ attempted_connections = []
+ last_error = None
+
+ for runtime_name in runtimes_to_try:
+ if runtime_name not in runtime_map:
+ continue
+
+ # Try common socket locations for this runtime
+ socket_locations = self._get_common_socket_locations(runtime_name)
+
+ for host in socket_locations:
+ try:
+ adapter = runtime_map[runtime_name](host)
+ adapter.ping()
+ host_display = host or "default"
+ logger.debug(
+ f"Using {runtime_name} as container runtime (host: {host_display})"
+ )
+ return adapter
+ except Exception as e:
+ host_str = host or "default"
+ logger.debug(f"{runtime_name} initialization failed at {host_str}: {e}")
+ attempted_connections.append(f"{runtime_name} at {host_str}")
+ last_error = e
+
+ # Build helpful error message
+ import platform
+
+ system = platform.system()
+
+ attempted = ", ".join(attempted_connections)
+ error_msg = f"Could not connect to Docker or Podman (tried: {attempted}).\n"
+
+ if system == "Darwin": # macOS
+ error_msg += (
+ "Ensure Docker/Podman is running "
+ "(e.g., 'colima start' or 'podman machine start').\n"
+ )
+ else:
+ error_msg += "Ensure Docker/Podman is installed and running.\n"
+
+ error_msg += (
+ "To specify a custom socket: ContainerBackendConfig(container_host='unix:///path/to/socket')\n"
+ "Or use LocalProcessBackendConfig for non-containerized execution."
+ )
+
+ raise RuntimeError(error_msg) from last_error
+
+ @property
+ def _runtime_type(self) -> str:
+ """Get the runtime type for debugging/logging."""
+ return self._adapter._runtime_type
+
+ # ---- Runtime APIs ----
+ def list_runtimes(self) -> list[types.Runtime]:
+ return list_training_runtimes_from_sources(self.cfg.runtime_source.sources)
+
+ def get_runtime(self, name: str) -> types.Runtime:
+ return get_training_runtime_from_sources(name, self.cfg.runtime_source.sources)
+
+ def get_runtime_packages(self, runtime: types.Runtime):
+ """
+ Spawn a short-lived container to report Python version, pip list, and nvidia-smi.
+ """
+ image = container_utils.resolve_image(runtime)
+ container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy)
+
+ command = [
+ "bash",
+ "-lc",
+ "python -c \"import sys; print(f'Python: {sys.version}')\" && "
+ "(pip list || echo 'pip not found') && "
+ "(nvidia-smi || echo 'nvidia-smi not found')",
+ ]
+
+ logs = self._adapter.run_oneoff_container(image=image, command=command)
+ print(logs)
+
+    def train(
+        self,
+        runtime: Optional[types.Runtime] = None,
+        initializer: Optional[types.Initializer] = None,
+        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
+    ) -> str:
+        """Create a local, container-based training job.
+
+        Starts one container per node, all attached to a dedicated network so
+        torchrun can rendezvous across nodes. The training function is embedded
+        inline into the container command via a heredoc (no host file I/O).
+
+        Args:
+            runtime: Runtime to use; defaults to the "torch-distributed" runtime.
+            initializer: Accepted for interface compatibility.
+                NOTE(review): currently ignored by this backend — confirm whether
+                dataset/model initialization should be supported here.
+            trainer: Must be a CustomTrainer; other trainer kinds are rejected.
+
+        Returns:
+            The generated job name (1 letter + 11 hex chars = 12 chars).
+
+        Raises:
+            ValueError: If *trainer* is not a CustomTrainer.
+        """
+        if runtime is None:
+            runtime = self.get_runtime("torch-distributed")
+
+        if not isinstance(trainer, types.CustomTrainer):
+            raise ValueError(f"{self.__class__.__name__} supports only CustomTrainer in v1")
+
+        # Generate job name (leading lowercase letter keeps the name DNS-safe)
+        job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
+        logger.debug(f"Starting training job: {job_name}")
+
+        try:
+            # Create per-job working directory on host (for outputs, checkpoints, etc.)
+            workdir = container_utils.create_workdir(job_name)
+            logger.debug(f"Created working directory: {workdir}")
+
+            # Generate training script code (inline, not written to disk)
+            training_script_code = container_utils.get_training_script_code(trainer)
+            logger.debug("Generated training script code")
+
+            # Resolve image and pull if needed
+            image = container_utils.resolve_image(runtime)
+            logger.debug(f"Using image: {image}")
+
+            container_utils.maybe_pull_image(self._adapter, image, self.cfg.pull_policy)
+            logger.debug(f"Image ready: {image}")
+
+            # Build base environment
+            env = container_utils.build_environment(trainer)
+
+            # Construct pre-run command to install packages
+            pre_install_cmd = container_utils.build_pip_install_cmd(trainer)
+
+            # Create network for multi-node communication
+            num_nodes = trainer.num_nodes or runtime.trainer.num_nodes or 1
+            logger.debug(f"Creating network for {num_nodes} nodes")
+
+            # Determine number of processes per node from GPU count
+            # For GPU training: spawn one process per GPU for optimal utilization
+            # For CPU training: use single process (PyTorch parallelizes internally via threads)
+            nproc_per_node = 1  # Default for CPU training
+            if trainer.resources_per_node and "gpu" in trainer.resources_per_node:
+                try:
+                    nproc_per_node = int(trainer.resources_per_node["gpu"])
+                    logger.debug(f"Using {nproc_per_node} processes per node (1 per GPU)")
+                except (ValueError, TypeError):
+                    logger.warning(
+                        f"Invalid GPU count in resources_per_node: "
+                        f"{trainer.resources_per_node['gpu']}, defaulting to 1 process per node"
+                    )
+            else:
+                logger.debug("No GPU specified, using 1 process per node")
+
+            # The network labels also carry job-level metadata (runtime name,
+            # host workdir) that get_job/delete_job later read back.
+            network_id = self._adapter.create_network(
+                name=f"{job_name}-net",
+                labels={
+                    f"{self.label_prefix}/trainjob-name": job_name,
+                    f"{self.label_prefix}/runtime-name": runtime.name,
+                    f"{self.label_prefix}/workdir": workdir,
+                },
+            )
+            logger.debug(f"Created network: {network_id}")
+
+            # Create N containers (one per node)
+            container_ids: list[str] = []
+            master_container_id = None
+            master_ip = None
+
+            for rank in range(num_nodes):
+                container_name = f"{job_name}-node-{rank}"
+
+                # Get master address and port for torchrun
+                master_port = 29500
+
+                # For Podman: use IP address to avoid DNS timing issues
+                # For Docker: use hostname (DNS is reliable)
+                if rank == 0:
+                    # Master node - will be created first
+                    master_addr = f"{job_name}-node-0"
+                else:
+                    # Worker nodes - determine master address based on runtime
+                    if self._runtime_type == "podman" and master_ip:
+                        master_addr = master_ip
+                        logger.debug(f"Using master IP address for Podman: {master_ip}")
+                    else:
+                        master_addr = f"{job_name}-node-0"
+                        logger.debug(f"Using master hostname: {master_addr}")
+
+                # Prefer torchrun; fall back to python if torchrun is unavailable
+                # For worker nodes, wait for master to be reachable before starting torchrun
+                wait_for_master = ""
+                if rank > 0:
+                    # Poll the master's rendezvous port via bash's /dev/tcp;
+                    # up to 60 attempts, 2s apart (~2 minutes).
+                    wait_for_master = (
+                        f"echo 'Waiting for master node {master_addr}:{master_port}...'; "
+                        f"for i in {{1..60}}; do "
+                        f"  if timeout 1 bash -c 'cat < /dev/null > "
+                        f"/dev/tcp/{master_addr}/{master_port}' 2>/dev/null; then "
+                        f"    echo 'Master node is reachable'; break; "
+                        f"  fi; "
+                        f"  if [ $i -eq 60 ]; then "
+                        f"echo 'Timeout waiting for master node'; exit 1; fi; "
+                        f"  sleep 2; "
+                        f"done; "
+                    )
+
+                # Embed training script inline using heredoc (no file I/O on host)
+                # NOTE(review): assumes the generated script never contains a line
+                # equal to "TRAINING_SCRIPT_EOF" — confirm in get_training_script_code.
+                entry_cmd = (
+                    f"{pre_install_cmd}"
+                    f"{wait_for_master}"
+                    f"cat > /tmp/train.py << 'TRAINING_SCRIPT_EOF'\n"
+                    f"{training_script_code}\n"
+                    f"TRAINING_SCRIPT_EOF\n"
+                    "if command -v torchrun >/dev/null 2>&1; then "
+                    f"  torchrun --nproc_per_node={nproc_per_node} --nnodes={num_nodes} "
+                    f"  --node-rank={rank} --rdzv-backend=c10d "
+                    f"  --rdzv-endpoint={master_addr}:{master_port} "
+                    f"  /tmp/train.py; "
+                    "else "
+                    f"  python /tmp/train.py; "
+                    "fi"
+                )
+
+                full_cmd = ["bash", "-lc", entry_cmd]
+
+                labels = {
+                    f"{self.label_prefix}/trainjob-name": job_name,
+                    f"{self.label_prefix}/step": f"node-{rank}",
+                    f"{self.label_prefix}/network-id": network_id,
+                }
+
+                volumes = {
+                    workdir: {
+                        "bind": constants.WORKSPACE_PATH,
+                        "mode": "rw",
+                    }
+                }
+
+                logger.debug(f"Creating container {rank}/{num_nodes}: {container_name}")
+
+                container_id = self._adapter.create_and_start_container(
+                    image=image,
+                    command=full_cmd,
+                    name=container_name,
+                    network_id=network_id,
+                    environment=env,
+                    labels=labels,
+                    volumes=volumes,
+                    working_dir=constants.WORKSPACE_PATH,
+                )
+
+                logger.debug(f"Started container {container_name} (ID: {container_id[:12]})")
+                container_ids.append(container_id)
+
+                # If this is the master node and we're using Podman, get its IP address
+                if rank == 0:
+                    master_container_id = container_id
+                    if self._runtime_type == "podman":
+                        # Get master IP for worker nodes to use
+                        master_ip = self._adapter.get_container_ip(master_container_id, network_id)
+                        if master_ip:
+                            logger.debug(f"Master node IP address: {master_ip}")
+                        else:
+                            logger.warning(
+                                "Could not retrieve master IP address. "
+                                "Worker nodes will fall back to DNS resolution."
+                            )
+
+            logger.debug(
+                f"Training job {job_name} created successfully with "
+                f"{len(container_ids)} container(s)"
+            )
+            return job_name
+
+        except Exception as e:
+            # Clean up on failure
+            logger.error(f"Failed to create training job {job_name}: {e}")
+            logger.exception("Full traceback:")
+
+            # Try to clean up any resources that were created
+            from contextlib import suppress
+
+            try:
+                # Stop and remove any containers that were created.
+                # locals() checks guard against failures before the variables
+                # were bound (e.g. create_workdir raising).
+                if "container_ids" in locals():
+                    for container_id in container_ids:
+                        with suppress(Exception):
+                            self._adapter.stop_container(container_id, timeout=5)
+                            self._adapter.remove_container(container_id, force=True)
+
+                # Remove network if it was created
+                if "network_id" in locals():
+                    with suppress(Exception):
+                        self._adapter.delete_network(network_id)
+
+                # Remove working directory if it was created
+                if "workdir" in locals() and os.path.isdir(workdir):
+                    shutil.rmtree(workdir, ignore_errors=True)
+
+            except Exception as cleanup_error:
+                logger.error(f"Error during cleanup: {cleanup_error}")
+
+            # Re-raise the original exception
+            raise
+
+ def _get_job_containers(self, name: str) -> list[dict]:
+ """
+ Get containers for a specific training job.
+
+ Args:
+ name: Name of the training job
+
+ Returns:
+ List of container dictionaries for this job
+
+ Raises:
+ ValueError: If no containers found for the job
+ """
+ filters = {"label": [f"{self.label_prefix}/trainjob-name={name}"]}
+ containers = self._adapter.list_containers(filters=filters)
+
+ if not containers:
+ raise ValueError(f"No TrainJob with name {name}")
+
+ return containers
+
+    def __get_trainjob_from_containers(
+        self, job_name: str, containers: list[dict]
+    ) -> types.TrainJob:
+        """
+        Build a TrainJob object from a list of containers.
+
+        Job-level metadata (runtime name) lives on the job's network labels,
+        so the network referenced by the first container is resolved first.
+
+        Args:
+            job_name: Name of the training job
+            containers: List of container dictionaries for this job
+
+        Returns:
+            TrainJob object
+
+        Raises:
+            ValueError: If network metadata is missing or runtime not found
+        """
+        if not containers:
+            raise ValueError(f"No containers found for TrainJob {job_name}")
+
+        # Get metadata from network
+        network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+        if not network_id:
+            raise ValueError(f"TrainJob {job_name} is missing network metadata")
+
+        network_info = self._adapter.get_network(network_id)
+        if not network_info:
+            raise ValueError(f"TrainJob {job_name} network not found")
+
+        network_labels = network_info.get("labels", {})
+        runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name")
+
+        # Get runtime object
+        try:
+            job_runtime = self.get_runtime(runtime_name) if runtime_name else None
+        except Exception as e:
+            raise ValueError(f"Runtime {runtime_name} not found for job {job_name}") from e
+
+        if not job_runtime:
+            raise ValueError(f"Runtime {runtime_name} not found for job {job_name}")
+
+        # Parse creation timestamp from first container.
+        # The import is inside the try so a missing python-dateutil package is
+        # also caught; any failure falls back to a naive local "now", which is
+        # only an approximation of the real creation time.
+        created_str = containers[0].get("created", "")
+        try:
+            from dateutil import parser
+
+            creation_timestamp = parser.isoparse(created_str)
+        except Exception:
+            creation_timestamp = datetime.now()
+
+        # Build steps from containers, ordered by container name so node-0
+        # comes first.
+        steps = []
+        for container in sorted(containers, key=lambda c: c["name"]):
+            step_name = container["labels"].get(f"{self.label_prefix}/step", "")
+            steps.append(
+                types.Step(
+                    name=step_name,
+                    pod_name=container["name"],
+                    status=container_utils.get_container_status(self._adapter, container["id"]),
+                )
+            )
+
+        # Get num_nodes from container count
+        num_nodes = len(containers)
+
+        return types.TrainJob(
+            name=job_name,
+            creation_timestamp=creation_timestamp,
+            runtime=job_runtime,
+            steps=steps,
+            num_nodes=num_nodes,
+            status=container_utils.aggregate_container_statuses(self._adapter, containers),
+        )
+
+ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
+ """List all training jobs by querying container runtime."""
+ # Get all containers with our label prefix
+ filters = {"label": [f"{self.label_prefix}/trainjob-name"]}
+ containers = self._adapter.list_containers(filters=filters)
+
+ # Group containers by job name
+ jobs_map: dict[str, list[dict]] = {}
+ for container in containers:
+ job_name = container["labels"].get(f"{self.label_prefix}/trainjob-name")
+ if job_name:
+ if job_name not in jobs_map:
+ jobs_map[job_name] = []
+ jobs_map[job_name].append(container)
+
+ result: list[types.TrainJob] = []
+ for job_name, job_containers in jobs_map.items():
+ # Skip jobs with no containers
+ if not job_containers:
+ continue
+
+ # Filter by runtime if specified
+ if runtime:
+ network_id = job_containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+ if network_id:
+ network_info = self._adapter.get_network(network_id)
+ if network_info:
+ network_labels = network_info.get("labels", {})
+ runtime_name = network_labels.get(f"{self.label_prefix}/runtime-name")
+ if runtime_name != runtime.name:
+ continue
+
+ # Build TrainJob from containers
+ try:
+ result.append(self.__get_trainjob_from_containers(job_name, job_containers))
+ except Exception as e:
+ logger.warning(f"Failed to get TrainJob {job_name}: {e}")
+ continue
+
+ return result
+
+ def get_job(self, name: str) -> types.TrainJob:
+ """Get a specific training job by querying container runtime."""
+ containers = self._get_job_containers(name)
+ return self.__get_trainjob_from_containers(name, containers)
+
+    def get_job_logs(
+        self,
+        name: str,
+        follow: bool = False,
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
+        """Yield logs for a training job by querying the container runtime.
+
+        Args:
+            name: Training job name.
+            follow: Stream logs instead of returning a snapshot.
+            step: Which node's logs to return.
+                NOTE(review): when *step* equals the default ("node-0"), logs
+                from ALL nodes are yielded, so node-0 alone can never be
+                selected explicitly — confirm this is the intended contract.
+        """
+        containers = self._get_job_containers(name)
+
+        # Default step value doubles as "all steps" (see NOTE above).
+        want_all = step == constants.NODE + "-0"
+        for container in sorted(containers, key=lambda c: c["name"]):
+            container_step = container["labels"].get(f"{self.label_prefix}/step", "")
+            if not want_all and container_step != step:
+                continue
+            try:
+                yield from self._adapter.container_logs(container["id"], follow)
+            except Exception as e:
+                # Surface the failure inline in the log stream instead of aborting.
+                logger.warning(f"Failed to get logs for {container['name']}: {e}")
+                yield f"Error getting logs: {e}\n"
+
+ def wait_for_job_status(
+ self,
+ name: str,
+ status: set[str] = {constants.TRAINJOB_COMPLETE},
+ timeout: int = 600,
+ polling_interval: int = 2,
+ ) -> types.TrainJob:
+ import time
+
+ end = time.time() + timeout
+ while time.time() < end:
+ tj = self.get_job(name)
+ logger.debug(f"TrainJob {name}, status {tj.status}")
+ if tj.status in status:
+ return tj
+ if constants.TRAINJOB_FAILED not in status and tj.status == constants.TRAINJOB_FAILED:
+ raise RuntimeError(f"TrainJob {name} is Failed")
+ time.sleep(polling_interval)
+ raise TimeoutError(f"Timeout waiting for TrainJob {name} to reach status: {status}")
+
+ def delete_job(self, name: str):
+ """Delete a training job by querying container runtime."""
+ containers = self._get_job_containers(name)
+
+ # Get network_id and workdir from labels
+ network_id = containers[0]["labels"].get(f"{self.label_prefix}/network-id")
+
+ # Get workdir from network labels
+ workdir_host = None
+ if network_id:
+ network_info = self._adapter.get_network(network_id)
+ if network_info:
+ network_labels = network_info.get("labels", {})
+ workdir_host = network_labels.get(f"{self.label_prefix}/workdir")
+
+ # Stop containers and remove
+ from contextlib import suppress
+
+ for container in containers:
+ with suppress(Exception):
+ self._adapter.stop_container(container["id"], timeout=10)
+ with suppress(Exception):
+ self._adapter.remove_container(container["id"], force=True)
+
+ # Remove network (best-effort)
+ if network_id:
+ with suppress(Exception):
+ self._adapter.delete_network(network_id)
+
+ # Remove working directory if configured
+ if self.cfg.auto_remove and workdir_host and os.path.isdir(workdir_host):
+ shutil.rmtree(workdir_host, ignore_errors=True)
diff --git a/kubeflow/trainer/backends/container/backend_test.py b/kubeflow/trainer/backends/container/backend_test.py
new file mode 100644
index 000000000..ac13ca0f5
--- /dev/null
+++ b/kubeflow/trainer/backends/container/backend_test.py
@@ -0,0 +1,863 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for ContainerBackend.
+
+Tests the ContainerBackend class with mocked container adapters.
+"""
+
+from collections.abc import Iterator
+from contextlib import nullcontext
+import os
+from pathlib import Path
+import shutil
+import tempfile
+from typing import Optional
+from unittest.mock import Mock, patch
+
+import pytest
+
+from kubeflow.trainer.backends.container.adapters.base import (
+ BaseContainerClientAdapter,
+)
+from kubeflow.trainer.backends.container.backend import ContainerBackend
+from kubeflow.trainer.backends.container.types import ContainerBackendConfig
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase
+from kubeflow.trainer.types import types
+
+
+# Mock Container Adapter
+class MockContainerAdapter(BaseContainerClientAdapter):
+    """Mock adapter for testing ContainerBackend without Docker/Podman.
+
+    Records every mutating call (networks/containers created, stopped,
+    removed, images pulled) in plain lists so tests can assert on them.
+    """
+
+    def __init__(self):
+        self._runtime_type = "mock"
+        self.networks_created = []
+        self.containers_created = []
+        self.containers_stopped = []
+        self.containers_removed = []
+        self.networks_deleted = []
+        self.images_pulled = []
+        self.ping_called = False
+
+    def ping(self):
+        """Always succeeds; records that connectivity was checked."""
+        self.ping_called = True
+
+    def create_network(self, name: str, labels: dict[str, str]) -> str:
+        """Record a network and return a deterministic id ("net-<name>")."""
+        network_id = f"net-{name}"
+        self.networks_created.append({"id": network_id, "name": name, "labels": labels})
+        return network_id
+
+    def delete_network(self, network_id: str):
+        """Record the deletion; the network stays in networks_created."""
+        self.networks_deleted.append(network_id)
+
+    def create_and_start_container(
+        self,
+        image: str,
+        command: list[str],
+        name: str,
+        network_id: str,
+        environment: dict[str, str],
+        labels: dict[str, str],
+        volumes: dict[str, dict[str, str]],
+        working_dir: str,
+    ) -> str:
+        """Record a "running" container and return a sequential id."""
+        container_id = f"container-{len(self.containers_created)}"
+        self.containers_created.append(
+            {
+                "id": container_id,
+                "name": name,
+                "image": image,
+                "command": command,
+                "network": network_id,
+                "environment": environment,
+                "labels": labels,
+                "volumes": volumes,
+                "working_dir": working_dir,
+                "status": "running",
+                "exit_code": None,
+            }
+        )
+        return container_id
+
+    def get_container(self, container_id: str):
+        """Return a Mock with id/status for a known container, else None."""
+        for container in self.containers_created:
+            if container["id"] == container_id:
+                return Mock(id=container_id, status=container["status"])
+        return None
+
+    def container_logs(self, container_id: str, follow: bool) -> Iterator[str]:
+        """Yield canned log lines; two lines when following, one otherwise."""
+        if follow:
+            yield f"Log line 1 from {container_id}\n"
+            yield f"Log line 2 from {container_id}\n"
+        else:
+            yield f"Complete log from {container_id}\n"
+
+    def stop_container(self, container_id: str, timeout: int = 10):
+        """Record the stop and mark the container exited with code 0."""
+        self.containers_stopped.append(container_id)
+        for container in self.containers_created:
+            if container["id"] == container_id:
+                container["status"] = "exited"
+                container["exit_code"] = 0
+
+    def remove_container(self, container_id: str, force: bool = True):
+        """Record the removal; the container stays in containers_created."""
+        self.containers_removed.append(container_id)
+
+    def pull_image(self, image: str):
+        """Record the pull; no actual image transfer happens."""
+        self.images_pulled.append(image)
+
+    def image_exists(self, image: str) -> bool:
+        """Treat images containing "local" or previously pulled ones as present."""
+        return "local" in image or image in self.images_pulled
+
+    def run_oneoff_container(self, image: str, command: list[str]) -> str:
+        """Return canned output resembling the runtime-package probe."""
+        return "Python 3.9.0\npip 21.0.1\nnvidia-smi not found\n"
+
+    def container_status(self, container_id: str) -> tuple[str, Optional[int]]:
+        """Return (status, exit_code) for a known container, else ("unknown", None)."""
+        for container in self.containers_created:
+            if container["id"] == container_id:
+                return (container["status"], container.get("exit_code"))
+        return ("unknown", None)
+
+    def set_container_status(self, container_id: str, status: str, exit_code: Optional[int] = None):
+        """Helper method to set container status for testing."""
+        for container in self.containers_created:
+            if container["id"] == container_id:
+                container["status"] = status
+                container["exit_code"] = exit_code
+
+    def get_container_ip(self, container_id: str, network_id: str) -> Optional[str]:
+        """Get container IP address on a specific network."""
+        for container in self.containers_created:
+            if container["id"] == container_id:
+                return f"192.168.1.{len(self.containers_created)}"
+        return None
+
+    def list_containers(self, filters: Optional[dict[str, list[str]]] = None) -> list[dict]:
+        """List containers with optional filters."""
+        if not filters:
+            return [
+                {
+                    "id": c["id"],
+                    "name": c["name"],
+                    "labels": c["labels"],
+                    "status": c["status"],
+                    "created": "2025-01-01T00:00:00Z",
+                }
+                for c in self.containers_created
+            ]
+
+        # Simple label filtering: "key=value" matches exactly, a bare "key"
+        # matches label presence (mirrors Docker's filter semantics).
+        result = []
+        for container in self.containers_created:
+            if "label" in filters:
+                match = True
+                for label_filter in filters["label"]:
+                    if "=" in label_filter:
+                        key, value = label_filter.split("=", 1)
+                        if container["labels"].get(key) != value:
+                            match = False
+                            break
+                    else:
+                        if label_filter not in container["labels"]:
+                            match = False
+                            break
+                if match:
+                    result.append(
+                        {
+                            "id": container["id"],
+                            "name": container["name"],
+                            "labels": container["labels"],
+                            "status": container["status"],
+                            "created": "2025-01-01T00:00:00Z",
+                        }
+                    )
+        return result
+
+    def get_network(self, network_id: str) -> Optional[dict]:
+        """Get network information (matches by id or by name)."""
+        for network in self.networks_created:
+            if network["id"] == network_id or network["name"] == network_id:
+                return {
+                    "id": network["id"],
+                    "name": network["name"],
+                    "labels": network["labels"],
+                }
+        return None
+
+
+# Fixtures
+@pytest.fixture
+def container_backend():
+ """Provide ContainerBackend with mocked adapter."""
+ with patch("kubeflow.trainer.backends.container.backend.DockerClientAdapter") as mock_docker:
+ mock_docker.return_value = MockContainerAdapter()
+ backend = ContainerBackend(ContainerBackendConfig())
+ return backend
+
+
+@pytest.fixture
+def temp_workdir():
+ """Provide a temporary working directory."""
+ tmpdir = tempfile.mkdtemp()
+ yield tmpdir
+ if os.path.exists(tmpdir):
+ shutil.rmtree(tmpdir)
+
+
+# Helper Function
+def simple_train_func():
+    """Minimal training function passed as CustomTrainer.func in tests."""
+    print("Training")
+
+
+# Tests
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="auto-detect docker first",
+            expected_status=SUCCESS,
+        ),
+        TestCase(
+            name="auto-detect falls back to podman",
+            expected_status=SUCCESS,
+        ),
+        TestCase(
+            name="both unavailable raises error",
+            expected_status=FAILED,
+            expected_error=RuntimeError,
+        ),
+    ],
+)
+def test_backend_initialization(test_case):
+    """Test ContainerBackend initialization and adapter creation.
+
+    Covers the Docker-first/Podman-fallback auto-detection order, keyed off
+    test_case.name since each scenario needs different mock wiring.
+    """
+    print("Executing test:", test_case.name)
+    try:
+        if test_case.name == "auto-detect docker first":
+            with (
+                patch(
+                    "kubeflow.trainer.backends.container.backend.DockerClientAdapter"
+                ) as mock_docker,
+                patch(
+                    "kubeflow.trainer.backends.container.backend.PodmanClientAdapter"
+                ) as mock_podman,
+            ):
+                mock_docker_instance = Mock()
+                mock_docker.return_value = mock_docker_instance
+
+                _ = ContainerBackend(ContainerBackendConfig())
+
+                # Docker should be called (could be with Colima socket or None)
+                assert mock_docker.call_count == 1
+                mock_docker_instance.ping.assert_called_once()
+                mock_podman.assert_not_called()
+                assert test_case.expected_status == SUCCESS
+
+        elif test_case.name == "auto-detect falls back to podman":
+            with (
+                patch(
+                    "kubeflow.trainer.backends.container.backend.DockerClientAdapter"
+                ) as mock_docker,
+                patch(
+                    "kubeflow.trainer.backends.container.backend.PodmanClientAdapter"
+                ) as mock_podman,
+            ):
+                # Docker ping fails -> backend must fall back to Podman.
+                mock_docker_instance = Mock()
+                mock_docker_instance.ping.side_effect = Exception("Docker not available")
+                mock_docker.return_value = mock_docker_instance
+
+                mock_podman_instance = Mock()
+                mock_podman.return_value = mock_podman_instance
+
+                _ = ContainerBackend(ContainerBackendConfig())
+
+                # Docker may be tried multiple times (different socket locations)
+                assert mock_docker.call_count >= 1
+                mock_podman.assert_called_once_with(None)
+                mock_podman_instance.ping.assert_called_once()
+                assert test_case.expected_status == SUCCESS
+
+        elif test_case.name == "both unavailable raises error":
+            with (
+                patch(
+                    "kubeflow.trainer.backends.container.backend.DockerClientAdapter"
+                ) as mock_docker,
+                patch(
+                    "kubeflow.trainer.backends.container.backend.PodmanClientAdapter"
+                ) as mock_podman,
+            ):
+                # Both runtimes unavailable -> constructor should raise.
+                mock_docker_instance = Mock()
+                mock_docker_instance.ping.side_effect = Exception("Docker not available")
+                mock_docker.return_value = mock_docker_instance
+
+                mock_podman_instance = Mock()
+                mock_podman_instance.ping.side_effect = Exception("Podman not available")
+                mock_podman.return_value = mock_podman_instance
+
+                ContainerBackend(ContainerBackendConfig())
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+def test_list_runtimes(container_backend):
+ """Test listing available local runtimes."""
+ print("Executing test: list_runtimes")
+ runtimes = container_backend.list_runtimes()
+
+ assert isinstance(runtimes, list)
+ assert len(runtimes) > 0
+ runtime_names = [r.name for r in runtimes]
+ assert "torch-distributed" in runtime_names
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="get valid runtime",
+            expected_status=SUCCESS,
+            config={"name": "torch-distributed"},
+        ),
+        TestCase(
+            name="get invalid runtime",
+            expected_status=FAILED,
+            config={"name": "nonexistent-runtime"},
+            expected_error=ValueError,
+        ),
+    ],
+)
+def test_get_runtime(container_backend, test_case):
+    """Test getting a specific runtime (valid name and unknown name)."""
+    print("Executing test:", test_case.name)
+    try:
+        runtime = container_backend.get_runtime(**test_case.config)
+
+        # For the FAILED case this assert fires, is caught below, and fails
+        # the expected_error type check — so a missing exception still fails.
+        assert test_case.expected_status == SUCCESS
+        assert isinstance(runtime, types.Runtime)
+        assert runtime.name == test_case.config["name"]
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+def test_get_runtime_packages(container_backend):
+ """Test getting runtime packages."""
+ print("Executing test: get_runtime_packages")
+ runtime = container_backend.get_runtime("torch-distributed")
+ container_backend.get_runtime_packages(runtime)
+
+ assert len(
+ container_backend._adapter.images_pulled
+ ) > 0 or container_backend._adapter.image_exists(runtime.trainer.image)
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="train single node",
+            expected_status=SUCCESS,
+            config={"num_nodes": 1, "expected_containers": 1},
+        ),
+        TestCase(
+            name="train multi-node",
+            expected_status=SUCCESS,
+            config={"num_nodes": 3, "expected_containers": 3},
+        ),
+        TestCase(
+            name="train with custom env",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 1,
+                "env": {"MY_VAR": "my_value", "DEBUG": "true"},
+                "expected_containers": 1,
+            },
+        ),
+        TestCase(
+            name="train with packages",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 1,
+                "packages": ["numpy", "pandas"],
+                "expected_containers": 1,
+            },
+        ),
+        TestCase(
+            name="train with single GPU",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 1,
+                "resources_per_node": {"gpu": "1"},
+                "expected_containers": 1,
+                "expected_nproc_per_node": 1,
+            },
+        ),
+        TestCase(
+            name="train with multiple GPUs",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 1,
+                "resources_per_node": {"gpu": "4"},
+                "expected_containers": 1,
+                "expected_nproc_per_node": 4,
+            },
+        ),
+        TestCase(
+            name="train multi-node with GPUs",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 2,
+                "resources_per_node": {"gpu": "2"},
+                "expected_containers": 2,
+                "expected_nproc_per_node": 2,
+            },
+        ),
+        TestCase(
+            name="train with CPU resources (nproc=1)",
+            expected_status=SUCCESS,
+            config={
+                "num_nodes": 1,
+                "resources_per_node": {"cpu": "16"},
+                "expected_containers": 1,
+                "expected_nproc_per_node": 1,
+            },
+        ),
+    ],
+)
+def test_train(container_backend, test_case):
+    """Test training job creation (node count, env, packages, GPU fan-out)."""
+    print("Executing test:", test_case.name)
+    try:
+        trainer = types.CustomTrainer(
+            func=simple_train_func,
+            num_nodes=test_case.config.get("num_nodes", 1),
+            env=test_case.config.get("env"),
+            packages_to_install=test_case.config.get("packages"),
+            resources_per_node=test_case.config.get("resources_per_node"),
+        )
+        runtime = container_backend.get_runtime("torch-distributed")
+
+        job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+        assert test_case.expected_status == SUCCESS
+        assert job_name is not None
+        # Job names are 1 random letter + 11 hex chars.
+        assert len(job_name) == 12
+        assert (
+            len(container_backend._adapter.containers_created)
+            == test_case.config["expected_containers"]
+        )
+        assert len(container_backend._adapter.networks_created) == 1
+
+        # Check environment if specified
+        if "env" in test_case.config:
+            container = container_backend._adapter.containers_created[0]
+            for key, value in test_case.config["env"].items():
+                assert container["environment"][key] == value
+
+        # Check packages if specified
+        if "packages" in test_case.config:
+            container = container_backend._adapter.containers_created[0]
+            command_str = " ".join(container["command"])
+            assert "pip install" in command_str
+            for package in test_case.config["packages"]:
+                assert package in command_str
+
+        # Check nproc_per_node if specified
+        if "expected_nproc_per_node" in test_case.config:
+            container = container_backend._adapter.containers_created[0]
+            # command is ["bash", "-lc", <script>]; index 2 is the script body.
+            command_str = container["command"][2]
+            expected_nproc = test_case.config["expected_nproc_per_node"]
+            assert f"--nproc_per_node={expected_nproc}" in command_str, (
+                f"Expected --nproc_per_node={expected_nproc} in command, but got: {command_str}"
+            )
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="list all jobs",
+            expected_status=SUCCESS,
+            config={"num_jobs": 2},
+        ),
+        TestCase(
+            name="list empty jobs",
+            expected_status=SUCCESS,
+            config={"num_jobs": 0},
+        ),
+    ],
+)
+def test_list_jobs(container_backend, test_case):
+    """Test listing training jobs (several jobs, and none at all)."""
+    print("Executing test:", test_case.name)
+    try:
+        runtime = container_backend.get_runtime("torch-distributed")
+        created_jobs = []
+
+        # Create the requested number of single-node jobs up front.
+        for _ in range(test_case.config["num_jobs"]):
+            trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
+            job_name = container_backend.train(runtime=runtime, trainer=trainer)
+            created_jobs.append(job_name)
+
+        jobs = container_backend.list_jobs()
+
+        assert test_case.expected_status == SUCCESS
+        assert len(jobs) == test_case.config["num_jobs"]
+        if test_case.config["num_jobs"] > 0:
+            job_names = [job.name for job in jobs]
+            for created_job in created_jobs:
+                assert created_job in job_names
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="get existing job",
+            expected_status=SUCCESS,
+            config={"num_nodes": 2},
+        ),
+        TestCase(
+            name="get nonexistent job",
+            expected_status=FAILED,
+            config={"job_name": "nonexistent-job"},
+            expected_error=ValueError,
+        ),
+    ],
+)
+def test_get_job(container_backend, test_case):
+    """Test getting a specific job (existing multi-node job, unknown name)."""
+    print("Executing test:", test_case.name)
+    try:
+        if test_case.name == "get existing job":
+            trainer = types.CustomTrainer(
+                func=simple_train_func, num_nodes=test_case.config["num_nodes"]
+            )
+            runtime = container_backend.get_runtime("torch-distributed")
+            job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+            job = container_backend.get_job(job_name)
+
+            assert test_case.expected_status == SUCCESS
+            assert job.name == job_name
+            # One step per node is expected.
+            assert job.num_nodes == test_case.config["num_nodes"]
+            assert len(job.steps) == test_case.config["num_nodes"]
+
+        elif test_case.name == "get nonexistent job":
+            container_backend.get_job(test_case.config["job_name"])
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="get logs no follow",
+            expected_status=SUCCESS,
+            config={"follow": False},
+        ),
+        TestCase(
+            name="get logs with follow",
+            expected_status=SUCCESS,
+            config={"follow": True},
+        ),
+    ],
+)
+def test_get_job_logs(container_backend, test_case):
+    """Test getting job logs in snapshot and follow modes.
+
+    The mock adapter yields "Log line ..." when following and
+    "Complete log ..." otherwise, which is what the assertions key on.
+    """
+    print("Executing test:", test_case.name)
+    try:
+        trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
+        runtime = container_backend.get_runtime("torch-distributed")
+        job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+        logs = list(container_backend.get_job_logs(job_name, follow=test_case.config["follow"]))
+
+        assert test_case.expected_status == SUCCESS
+        assert len(logs) > 0
+        if test_case.config["follow"]:
+            assert any("Log line" in log for log in logs)
+        else:
+            assert any("Complete log" in log for log in logs)
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="wait for complete",
+            expected_status=SUCCESS,
+            config={"wait_status": constants.TRAINJOB_COMPLETE, "container_exit_code": 0},
+        ),
+        TestCase(
+            name="wait timeout",
+            expected_status=FAILED,
+            config={"wait_status": constants.TRAINJOB_COMPLETE, "timeout": 2},
+            expected_error=TimeoutError,
+        ),
+        TestCase(
+            name="job fails",
+            expected_status=FAILED,
+            config={"wait_status": constants.TRAINJOB_COMPLETE, "container_exit_code": 1},
+            expected_error=RuntimeError,
+        ),
+    ],
+)
+def test_wait_for_job_status(container_backend, test_case):
+    """Test waiting for job status: success, timeout, and failure fast-path.
+
+    The mock container stays "running" unless set_container_status is called,
+    which is what drives the timeout scenario.
+    """
+    print("Executing test:", test_case.name)
+    try:
+        trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
+        runtime = container_backend.get_runtime("torch-distributed")
+        job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+        if test_case.name == "wait for complete":
+            # Exit code 0 -> job aggregates to Complete.
+            container_id = container_backend._adapter.containers_created[0]["id"]
+            container_backend._adapter.set_container_status(
+                container_id, "exited", test_case.config["container_exit_code"]
+            )
+
+            completed_job = container_backend.wait_for_job_status(
+                job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
+            )
+
+            assert test_case.expected_status == SUCCESS
+            assert completed_job.status == constants.TRAINJOB_COMPLETE
+
+        elif test_case.name == "wait timeout":
+            container_backend.wait_for_job_status(
+                job_name,
+                status={test_case.config["wait_status"]},
+                timeout=test_case.config["timeout"],
+                polling_interval=1,
+            )
+
+        elif test_case.name == "job fails":
+            # Non-zero exit -> job aggregates to Failed -> RuntimeError raised.
+            container_id = container_backend._adapter.containers_created[0]["id"]
+            container_backend._adapter.set_container_status(
+                container_id, "exited", test_case.config["container_exit_code"]
+            )
+
+            container_backend.wait_for_job_status(
+                job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
+            )
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="delete with auto_remove true",
+            expected_status=SUCCESS,
+            config={"auto_remove": True, "num_nodes": 2},
+        ),
+        TestCase(
+            name="delete with auto_remove false",
+            expected_status=SUCCESS,
+            config={"auto_remove": False, "num_nodes": 2},
+        ),
+    ],
+)
+def test_delete_job(container_backend, temp_workdir, test_case):
+    """Test deleting a job, with and without workdir auto-removal."""
+    print("Executing test:", test_case.name)
+    try:
+        container_backend.cfg.auto_remove = test_case.config["auto_remove"]
+
+        trainer = types.CustomTrainer(
+            func=simple_train_func, num_nodes=test_case.config["num_nodes"]
+        )
+        runtime = container_backend.get_runtime("torch-distributed")
+        job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+        # NOTE(review): this hard-codes container_utils.create_workdir's layout
+        # (~/.kubeflow/trainer/containers/<job>) — keep in sync with that helper.
+        job_workdir = Path.home() / ".kubeflow" / "trainer" / "containers" / job_name
+        assert job_workdir.exists()
+
+        container_backend.delete_job(job_name)
+
+        assert test_case.expected_status == SUCCESS
+        assert len(container_backend._adapter.containers_stopped) == test_case.config["num_nodes"]
+        assert len(container_backend._adapter.containers_removed) == test_case.config["num_nodes"]
+        assert len(container_backend._adapter.networks_deleted) == 1
+
+        if test_case.config["auto_remove"]:
+            assert not job_workdir.exists()
+        else:
+            assert job_workdir.exists()
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="running container",
+ expected_status=SUCCESS,
+ config={
+ "container_status": "running",
+ "exit_code": None,
+ "expected_job_status": constants.TRAINJOB_RUNNING,
+ },
+ ),
+ TestCase(
+ name="exited success",
+ expected_status=SUCCESS,
+ config={
+ "container_status": "exited",
+ "exit_code": 0,
+ "expected_job_status": constants.TRAINJOB_COMPLETE,
+ },
+ ),
+ TestCase(
+ name="exited failure",
+ expected_status=SUCCESS,
+ config={
+ "container_status": "exited",
+ "exit_code": 1,
+ "expected_job_status": constants.TRAINJOB_FAILED,
+ },
+ ),
+ ],
+)
+def test_container_status_mapping(container_backend, test_case):
+    """Container runtime state (running / exited + exit code) maps onto the
+    corresponding TrainJob status constant via get_job()."""
+    print("Executing test:", test_case.name)
+    try:
+        trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
+        runtime = container_backend.get_runtime("torch-distributed")
+        job_name = container_backend.train(runtime=runtime, trainer=trainer)
+
+        # Force the fake adapter's container into the state under test.
+        container_id = container_backend._adapter.containers_created[0]["id"]
+        container_backend._adapter.set_container_status(
+            container_id, test_case.config["container_status"], test_case.config["exit_code"]
+        )
+
+        job = container_backend.get_job(job_name)
+
+        assert test_case.expected_status == SUCCESS
+        assert job.status == test_case.config["expected_job_status"]
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="docker socket locations with colima",
+ expected_status=SUCCESS,
+ config={
+ "runtime_name": "docker",
+ "container_host": None,
+ "create_colima_socket": True,
+ "expected_contains_none": True,
+ "expected_has_colima": True,
+ },
+ ),
+ TestCase(
+ name="custom host has priority",
+ expected_status=SUCCESS,
+ config={
+ "runtime_name": "docker",
+ "container_host": "unix:///custom/path/docker.sock",
+ "create_colima_socket": False,
+ "expected_first": "unix:///custom/path/docker.sock",
+ },
+ ),
+ ],
+)
+def test_get_common_socket_locations(test_case, tmp_path):
+    """Socket discovery puts an explicit container_host first and picks up
+    well-known per-user sockets (e.g. Colima's docker.sock) when present."""
+    print("Executing test:", test_case.name)
+
+    # Setup: optionally create a fake Colima docker socket under tmp_path.
+    if test_case.config.get("create_colima_socket"):
+        colima_dir = tmp_path / ".colima" / "default"
+        colima_dir.mkdir(parents=True)
+        colima_sock = colima_dir / "docker.sock"
+        colima_sock.touch()
+
+    cfg = ContainerBackendConfig(container_host=test_case.config["container_host"])
+
+    # Test the method directly without creating the backend; Path.home is
+    # patched to tmp_path only when the Colima socket fixture is in play.
+    context_manager = (
+        patch("pathlib.Path.home", return_value=tmp_path)
+        if test_case.config.get("create_colima_socket")
+        else nullcontext()
+    )
+
+    with context_manager:
+        # __new__ bypasses __init__, so no real adapter connection is attempted.
+        backend = ContainerBackend.__new__(ContainerBackend)
+        backend.cfg = cfg
+        locations = backend._get_common_socket_locations(test_case.config["runtime_name"])
+
+    # Assertions are keyed off which expectations the case declares.
+    if "expected_contains_none" in test_case.config:
+        assert None in locations
+
+    if "expected_has_colima" in test_case.config:
+        assert f"unix://{colima_sock}" in locations
+
+    if "expected_first" in test_case.config:
+        assert locations[0] == test_case.config["expected_first"]
+
+    print("test execution complete")
+
+
+def test_create_adapter_error_message_format():
+ """Test that error message includes attempted connections."""
+ cfg = ContainerBackendConfig(container_runtime="docker")
+
+ docker_adapter = "kubeflow.trainer.backends.container.adapters.docker.DockerClientAdapter"
+ with patch(docker_adapter) as mock_docker:
+ mock_docker.side_effect = Exception("Connection failed")
+
+ with pytest.raises(RuntimeError) as exc_info:
+ ContainerBackend(cfg)
+
+ # Error message should be helpful
+ error_msg = str(exc_info.value)
+ assert "Could not connect" in error_msg
+ assert "tried:" in error_msg
diff --git a/kubeflow/trainer/backends/container/runtime_loader.py b/kubeflow/trainer/backends/container/runtime_loader.py
new file mode 100644
index 000000000..3e614400e
--- /dev/null
+++ b/kubeflow/trainer/backends/container/runtime_loader.py
@@ -0,0 +1,612 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Runtime loader for container backends (Docker, Podman).
+
+We support loading training runtime definitions from multiple sources:
+1. GitHub: Fetches latest runtimes from kubeflow/trainer repository (with caching)
+2. Local bundled: Falls back to `kubeflow/trainer/config/training_runtimes/` YAML files
+3. User custom: Additional YAML files in the local directory
+
+The loader tries GitHub first (with 24-hour cache), then falls back to bundled files
+if the network is unavailable or GitHub fetch fails.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+import json
+import logging
+from pathlib import Path
+import re
+from typing import Any, Optional
+import urllib.error
+import urllib.request
+
+import yaml
+
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.types import types as base_types
+
+logger = logging.getLogger(__name__)
+
+TRAINING_RUNTIMES_DIR = Path(__file__).parents[2] / "config" / "training_runtimes"
+CACHE_DIR = Path.home() / ".kubeflow" / "trainer" / "cache"
+CACHE_DURATION = timedelta(hours=24)
+
+# GitHub runtimes configuration
+GITHUB_RUNTIMES_BASE_URL = (
+ "https://raw.githubusercontent.com/kubeflow/trainer/master/manifests/base/runtimes"
+)
+GITHUB_RUNTIMES_TREE_URL = "https://github.com/kubeflow/trainer/tree/master/manifests/base/runtimes"
+
+__all__ = [
+ "TRAINING_RUNTIMES_DIR",
+ "get_training_runtime_from_sources",
+ "list_training_runtimes_from_sources",
+]
+
+
+def _load_runtime_from_yaml(path: Path) -> dict[str, Any]:
+ with open(path) as f:
+ data: dict[str, Any] = yaml.safe_load(f)
+ return data
+
+
+def _discover_github_runtime_files(
+ owner: str = "kubeflow",
+ repo: str = "trainer",
+ branch: str = "master",
+ path: str = "manifests/base/runtimes",
+) -> list[str]:
+ """
+ Discover available runtime YAML files from GitHub repository.
+
+ Fetches the directory listing from GitHub and extracts .yaml filenames,
+ excluding kustomization.yaml and other non-runtime files.
+
+ Args:
+ owner: GitHub repository owner (default: "kubeflow")
+ repo: GitHub repository name (default: "trainer")
+ branch: Git branch name (default: "master")
+ path: Path to runtimes directory (default: "manifests/base/runtimes")
+
+ Returns:
+ List of YAML filenames (e.g., ['torch_distributed.yaml', ...])
+ Returns empty list if discovery fails.
+ """
+ tree_url = f"https://github.com/{owner}/{repo}/tree/{branch}/{path}"
+ try:
+ logger.debug(f"Discovering runtimes from GitHub: {tree_url}")
+ with urllib.request.urlopen(tree_url, timeout=5) as response:
+ html_content = response.read().decode("utf-8")
+
+ # Parse HTML to find .yaml files
+ # Look for .yaml filenames in the HTML content
+ import re
+
+ # Pattern to match .yaml files in the HTML
+ # Matches word characters, hyphens, underscores followed by .yaml
+ pattern = r"([\w-]+\.yaml)"
+ matches = re.findall(pattern, html_content)
+
+ # Filter out kustomization.yaml, config files, and duplicates
+ # Keep only runtime files (typically named *_distributed.yaml or similar)
+ runtime_files = []
+ seen = set()
+ exclude_files = {"kustomization.yaml", "golangci.yaml", "pre-commit-config.yaml"}
+
+ for match in matches:
+ filename = match
+ if filename not in seen and filename not in exclude_files:
+ runtime_files.append(filename)
+ seen.add(filename)
+
+ logger.debug(f"Discovered {len(runtime_files)} runtime files: {runtime_files}")
+ return runtime_files
+
+ except Exception as e:
+ logger.debug(f"Failed to discover GitHub runtime files from {tree_url}: {e}")
+ return []
+
+
+def _fetch_runtime_from_github(
+ runtime_file: str,
+ owner: str = "kubeflow",
+ repo: str = "trainer",
+ branch: str = "master",
+ path: str = "manifests/base/runtimes",
+) -> Optional[dict[str, Any]]:
+ """
+ Fetch a runtime YAML from GitHub.
+
+ Args:
+ runtime_file: YAML filename to fetch
+ owner: GitHub repository owner (default: "kubeflow")
+ repo: GitHub repository name (default: "trainer")
+ branch: Git branch name (default: "master")
+ path: Path to runtimes directory (default: "manifests/base/runtimes")
+
+ Returns None if fetch fails (network error, timeout, etc.)
+ """
+ url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}/{runtime_file}"
+ try:
+ logger.debug(f"Fetching runtime from GitHub: {url}")
+ with urllib.request.urlopen(url, timeout=5) as response:
+ content = response.read().decode("utf-8")
+ data = yaml.safe_load(content)
+ logger.debug(f"Successfully fetched {runtime_file} from GitHub")
+ return data
+ except (urllib.error.URLError, TimeoutError, Exception) as e:
+ logger.debug(f"Failed to fetch {runtime_file} from GitHub: {e}")
+ return None
+
+
+def _get_cached_runtime_list() -> Optional[list[str]]:
+ """
+ Get cached runtime file list if it exists and is not expired.
+
+ Returns None if cache doesn't exist or is expired.
+ """
+ if not CACHE_DIR.exists():
+ return None
+
+ cache_file = CACHE_DIR / "runtime_list.json"
+
+ if not cache_file.exists():
+ return None
+
+ try:
+ with open(cache_file) as f:
+ data = json.load(f)
+
+ cached_time = datetime.fromisoformat(data["cached_at"])
+ if datetime.now() - cached_time > CACHE_DURATION:
+ logger.debug("Runtime list cache expired")
+ return None
+
+ logger.debug(f"Using cached runtime list: {data['files']}")
+ return data["files"]
+ except (json.JSONDecodeError, KeyError, ValueError, Exception) as e:
+ logger.debug(f"Failed to read runtime list cache: {e}")
+ return None
+
+
+def _cache_runtime_list(files: list[str]) -> None:
+ """Cache the discovered runtime file list."""
+ try:
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
+ cache_file = CACHE_DIR / "runtime_list.json"
+
+ data = {
+ "cached_at": datetime.now().isoformat(),
+ "files": files,
+ }
+ with open(cache_file, "w") as f:
+ json.dump(data, f)
+
+ logger.debug(f"Cached runtime list: {files}")
+ except Exception as e:
+ logger.debug(f"Failed to cache runtime list: {e}")
+
+
+def _get_github_runtime_files() -> list[str]:
+ """
+ Get list of runtime files from GitHub with caching.
+
+ Priority:
+ 1. Check cache (if not expired)
+ 2. Discover from GitHub (and cache if successful)
+ 3. Return empty list if both fail
+ """
+ # Try cache first
+ cached = _get_cached_runtime_list()
+ if cached is not None:
+ return cached
+
+ # Try GitHub discovery
+ files = _discover_github_runtime_files()
+ if files:
+ _cache_runtime_list(files)
+ return files
+
+ return []
+
+
+def _get_cached_runtime(runtime_file: str) -> Optional[dict[str, Any]]:
+ """
+ Get cached runtime if it exists and is not expired.
+
+ Returns None if cache doesn't exist or is expired.
+ """
+ if not CACHE_DIR.exists():
+ return None
+
+ cache_file = CACHE_DIR / runtime_file
+ metadata_file = CACHE_DIR / f"{runtime_file}.metadata"
+
+ if not cache_file.exists() or not metadata_file.exists():
+ return None
+
+ try:
+ # Check if cache is expired
+ with open(metadata_file) as f:
+ metadata = json.load(f)
+
+ cached_time = datetime.fromisoformat(metadata["cached_at"])
+ if datetime.now() - cached_time > CACHE_DURATION:
+ logger.debug(f"Cache expired for {runtime_file}")
+ return None
+
+ # Load cached runtime
+ with open(cache_file) as f:
+ data = yaml.safe_load(f)
+
+ logger.debug(f"Using cached runtime: {runtime_file}")
+ return data
+ except (json.JSONDecodeError, KeyError, ValueError, Exception) as e:
+ logger.debug(f"Failed to read cache for {runtime_file}: {e}")
+ return None
+
+
+def _cache_runtime(runtime_file: str, data: dict[str, Any]) -> None:
+ """Cache a runtime YAML with metadata."""
+ try:
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+ cache_file = CACHE_DIR / runtime_file
+ metadata_file = CACHE_DIR / f"{runtime_file}.metadata"
+
+ # Write runtime data
+ with open(cache_file, "w") as f:
+ yaml.safe_dump(data, f)
+
+ # Write metadata
+ metadata = {
+ "cached_at": datetime.now().isoformat(),
+ "source": "github",
+ }
+ with open(metadata_file, "w") as f:
+ json.dump(metadata, f)
+
+ logger.debug(f"Cached runtime: {runtime_file}")
+ except Exception as e:
+ logger.debug(f"Failed to cache {runtime_file}: {e}")
+
+
+def _load_runtime_from_github_with_cache(runtime_file: str) -> Optional[dict[str, Any]]:
+ """
+ Load runtime from GitHub with caching.
+
+ Priority:
+ 1. Check cache (if not expired)
+ 2. Fetch from GitHub (and cache if successful)
+ 3. Return None if both fail
+
+ Args:
+ runtime_file: YAML filename to load
+ """
+ # Try cache first
+ cached = _get_cached_runtime(runtime_file)
+ if cached is not None:
+ return cached
+
+ # Try GitHub
+ data = _fetch_runtime_from_github(runtime_file)
+ if data is not None:
+ _cache_runtime(runtime_file, data)
+ return data
+
+ return None
+
+
+def _create_default_runtimes() -> list[base_types.Runtime]:
+ """
+ Create default Runtime objects from DEFAULT_FRAMEWORK_IMAGES constant.
+
+ Returns:
+ List of default Runtime objects for each framework.
+ """
+ default_runtimes = []
+
+ for framework, image in constants.DEFAULT_FRAMEWORK_IMAGES.items():
+ runtime = base_types.Runtime(
+ name=f"{framework}-distributed",
+ trainer=base_types.RuntimeTrainer(
+ trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+ framework=framework,
+ num_nodes=1,
+ ),
+ pretrained_model=None,
+ )
+ default_runtimes.append(runtime)
+ logger.debug(f"Created default runtime: {runtime.name} with image {image}")
+
+ return default_runtimes
+
+
+def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_types.Runtime:
+ """
+ Parse a runtime YAML dict into a Runtime object.
+
+ Args:
+ data: The YAML data as a dictionary
+ source: Source of the YAML (for error messages)
+
+ Returns:
+ Runtime object
+
+ Raises:
+ ValueError: If the YAML is malformed or missing required fields
+ """
+ # Require CRD-like schema strictly. Accept both ClusterTrainingRuntime
+ # and TrainingRuntime kinds.
+ if not (
+ data.get("kind") in {"ClusterTrainingRuntime", "TrainingRuntime"} and data.get("metadata")
+ ):
+ raise ValueError(
+ f"Runtime YAML from {source} must be a ClusterTrainingRuntime CRD-shaped document"
+ )
+
+ name = data["metadata"].get("name")
+ if not name:
+ raise ValueError(f"Runtime YAML from {source} missing metadata.name")
+
+ labels = data["metadata"].get("labels", {})
+ framework = labels.get("trainer.kubeflow.org/framework")
+ if not framework:
+ raise ValueError(
+ f"Runtime {name} from {source} must set "
+ f"metadata.labels['trainer.kubeflow.org/framework']"
+ )
+
+ spec = data.get("spec", {})
+ ml_policy = spec.get("mlPolicy", {})
+ num_nodes = int(ml_policy.get("numNodes", 1))
+
+ # Validate presence of a 'node' replicated job with a container image
+ templ = spec.get("template", {}).get("spec", {})
+ replicated = templ.get("replicatedJobs", [])
+ node_jobs = [j for j in replicated if j.get("name") == "node"]
+ if not node_jobs:
+ raise ValueError(
+ f"Runtime {name} from {source} must define replicatedJobs with a 'node' entry"
+ )
+ node_spec = node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {})
+ containers = node_spec.get("containers", [])
+ if not containers or not containers[0].get("image"):
+ raise ValueError(f"Runtime {name} from {source} 'node' must specify containers[0].image")
+
+ return base_types.Runtime(
+ name=name,
+ trainer=base_types.RuntimeTrainer(
+ trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+ framework=framework,
+ num_nodes=num_nodes,
+ ),
+ pretrained_model=None,
+ )
+
+
+def _parse_source_url(source: str) -> tuple[str, str]:
+ """
+ Parse a source URL to determine its type and path.
+
+ Args:
+ source: Source URL with scheme (github://, https://, file://, or absolute path)
+
+ Returns:
+ Tuple of (source_type, path) where source_type is one of:
+ 'github', 'http', 'https', 'file'
+ """
+ if source.startswith("github://"):
+ return ("github", source[9:]) # Remove 'github://'
+ elif source.startswith("https://"):
+ return ("https", source)
+ elif source.startswith("http://"):
+ return ("http", source)
+ elif source.startswith("file://"):
+ return ("file", source[7:]) # Remove 'file://'
+ elif source.startswith("/"):
+ # Absolute path without file:// prefix
+ return ("file", source)
+ else:
+ raise ValueError(
+ f"Unsupported source URL scheme: {source}. "
+ f"Supported: github://, https://, http://, file://, or absolute paths"
+ )
+
+
+def _load_from_github_url(github_path: str) -> list[base_types.Runtime]:
+ """
+ Load runtimes from a GitHub URL (github://owner/repo[/path]).
+
+ Args:
+ github_path: Path after github:// (e.g., "kubeflow/trainer" or "myorg/myrepo")
+
+ Returns:
+ List of Runtime objects loaded from GitHub
+ """
+ runtimes = []
+ runtime_names_seen = set()
+
+ # Parse the GitHub path
+ # Format: owner/repo[/path/to/runtimes]
+ parts = github_path.split("/")
+ if len(parts) < 2:
+ logger.warning(f"Invalid GitHub path format: {github_path}. Expected owner/repo[/path]")
+ return runtimes
+
+ owner = parts[0]
+ repo = parts[1]
+ # Custom path if provided (default to manifests/base/runtimes)
+ custom_path = "/".join(parts[2:]) if len(parts) > 2 else "manifests/base/runtimes"
+
+ # Discover runtime files from the specified GitHub repo
+ logger.debug(f"Loading runtimes from GitHub: {owner}/{repo}/{custom_path}")
+ github_runtime_files = _discover_github_runtime_files(owner=owner, repo=repo, path=custom_path)
+
+ for runtime_file in github_runtime_files:
+ try:
+ data = _fetch_runtime_from_github(
+ runtime_file, owner=owner, repo=repo, path=custom_path
+ )
+ if data is not None:
+ runtime = _parse_runtime_yaml(data, source=f"github://{github_path}/{runtime_file}")
+ if runtime.name not in runtime_names_seen:
+ runtimes.append(runtime)
+ runtime_names_seen.add(runtime.name)
+ logger.debug(f"Loaded runtime from GitHub: {runtime.name}")
+ except Exception as e:
+ logger.debug(f"Failed to parse GitHub runtime {runtime_file}: {e}")
+
+ return runtimes
+
+
+def _load_from_http_url(url: str) -> list[base_types.Runtime]:
+ """
+ Load runtimes from an HTTP(S) URL.
+
+ Args:
+ url: HTTP(S) URL to a runtime YAML file or directory listing
+
+ Returns:
+ List of Runtime objects loaded from HTTP(S)
+ """
+ runtimes = []
+
+ try:
+ import urllib.request
+
+ logger.debug(f"Fetching runtime from HTTP: {url}")
+ with urllib.request.urlopen(url, timeout=5) as response:
+ content = response.read().decode("utf-8")
+ import yaml
+
+ data = yaml.safe_load(content)
+ runtime = _parse_runtime_yaml(data, source=url)
+ runtimes.append(runtime)
+ logger.debug(f"Loaded runtime from HTTP: {runtime.name}")
+ except Exception as e:
+ logger.debug(f"Failed to load runtime from HTTP {url}: {e}")
+
+ return runtimes
+
+
+def _load_from_filesystem(path: str) -> list[base_types.Runtime]:
+ """
+ Load runtimes from local filesystem path.
+
+ Args:
+ path: Local filesystem path to a directory or YAML file
+
+ Returns:
+ List of Runtime objects loaded from filesystem
+ """
+ from pathlib import Path
+
+ runtimes = []
+ runtime_path = Path(path).expanduser()
+
+ try:
+ if runtime_path.is_dir():
+ # Load all YAML files from directory
+ for yaml_file in sorted(runtime_path.glob("*.yaml")):
+ try:
+ data = _load_runtime_from_yaml(yaml_file)
+ runtime = _parse_runtime_yaml(data, source=str(yaml_file))
+ runtimes.append(runtime)
+ logger.debug(f"Loaded runtime from file: {runtime.name}")
+ except Exception as e:
+ logger.warning(f"Failed to load runtime from {yaml_file}: {e}")
+ elif runtime_path.is_file():
+ # Load single YAML file
+ data = _load_runtime_from_yaml(runtime_path)
+ runtime = _parse_runtime_yaml(data, source=str(runtime_path))
+ runtimes.append(runtime)
+ logger.debug(f"Loaded runtime from file: {runtime.name}")
+ else:
+ logger.warning(f"Path does not exist: {runtime_path}")
+ except Exception as e:
+ logger.warning(f"Failed to load runtimes from {path}: {e}")
+
+ return runtimes
+
+
+def list_training_runtimes_from_sources(sources: list[str]) -> list[base_types.Runtime]:
+ """
+ List all available training runtimes from configured sources.
+
+ Args:
+ sources: List of source URLs with schemes (github://, https://, http://, file://, or paths)
+
+ Returns:
+ List of Runtime objects (built-in runtimes used as default if not found in sources)
+ """
+ runtimes: list[base_types.Runtime] = []
+ runtime_names_seen = set()
+
+ # Load from each configured source in priority order
+ for source in sources:
+ try:
+ source_type, source_path = _parse_source_url(source)
+
+ if source_type == "github":
+ source_runtimes = _load_from_github_url(source_path)
+ elif source_type in ("http", "https"):
+ source_runtimes = _load_from_http_url(source)
+ elif source_type == "file":
+ source_runtimes = _load_from_filesystem(source_path)
+ else:
+ logger.warning(f"Unsupported source type: {source_type}")
+ continue
+
+ # Add runtimes, skipping duplicates
+ for runtime in source_runtimes:
+ if runtime.name not in runtime_names_seen:
+ runtimes.append(runtime)
+ runtime_names_seen.add(runtime.name)
+ except Exception as e:
+ logger.debug(f"Failed to load from source {source}: {e}")
+
+ # Fallback to default runtimes from constants if not found in sources
+ for default_runtime in _create_default_runtimes():
+ if default_runtime.name not in runtime_names_seen:
+ runtimes.append(default_runtime)
+ runtime_names_seen.add(default_runtime.name)
+
+ return runtimes
+
+
+def get_training_runtime_from_sources(name: str, sources: list[str]) -> base_types.Runtime:
+ """
+ Get a specific training runtime by name from configured sources.
+
+ Args:
+ name: The name of the runtime to get
+ sources: List of source URLs with schemes
+
+ Returns:
+ Runtime object
+
+ Raises:
+ ValueError: If the runtime is not found
+ """
+ for rt in list_training_runtimes_from_sources(sources):
+ if rt.name == name:
+ return rt
+ raise ValueError(
+ f"Runtime '{name}' not found. Available runtimes: "
+ f"{[rt.name for rt in list_training_runtimes_from_sources(sources)]}"
+ )
diff --git a/kubeflow/trainer/backends/container/runtime_loader_test.py b/kubeflow/trainer/backends/container/runtime_loader_test.py
new file mode 100644
index 000000000..ea0ec2cff
--- /dev/null
+++ b/kubeflow/trainer/backends/container/runtime_loader_test.py
@@ -0,0 +1,526 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unit tests for runtime_loader module.
+
+Tests runtime loading from various sources including GitHub, HTTP, and filesystem.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from kubeflow.trainer.backends.container import runtime_loader
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase
+from kubeflow.trainer.types import types as base_types
+
+# Sample runtime YAML data for testing.
+# NOTE: this is a module-level nested dict shared by every test in this file.
+# Tests that mutate it must take a deep copy first — a shallow .copy() still
+# aliases the nested "metadata"/"spec" dicts and would leak changes across
+# parametrized cases.
+SAMPLE_RUNTIME_YAML = {
+    "apiVersion": "trainer.kubeflow.org/v1alpha1",
+    "kind": "ClusterTrainingRuntime",
+    "metadata": {
+        "name": "torch-distributed",
+        "labels": {"trainer.kubeflow.org/framework": "torch"},
+    },
+    "spec": {
+        "mlPolicy": {"numNodes": 1},
+        "template": {
+            "spec": {
+                "replicatedJobs": [
+                    {
+                        "name": "node",
+                        "template": {
+                            "spec": {
+                                "template": {
+                                    "spec": {
+                                        "containers": [
+                                            {
+                                                "name": "trainer",
+                                                "image": "pytorch/pytorch:2.0.0",
+                                            }
+                                        ]
+                                    }
+                                }
+                            }
+                        }
+                    },
+                ]
+            }
+        },
+    },
+}
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="parse github url",
+ expected_status=SUCCESS,
+ config={
+ "url": "github://kubeflow/trainer",
+ "expected_type": "github",
+ "expected_path": "kubeflow/trainer",
+ },
+ ),
+ TestCase(
+ name="parse github url with path",
+ expected_status=SUCCESS,
+ config={
+ "url": "github://myorg/myrepo/custom/path",
+ "expected_type": "github",
+ "expected_path": "myorg/myrepo/custom/path",
+ },
+ ),
+ TestCase(
+ name="parse https url",
+ expected_status=SUCCESS,
+ config={
+ "url": "https://example.com/runtime.yaml",
+ "expected_type": "https",
+ "expected_path": "https://example.com/runtime.yaml",
+ },
+ ),
+ TestCase(
+ name="parse http url",
+ expected_status=SUCCESS,
+ config={
+ "url": "http://example.com/runtime.yaml",
+ "expected_type": "http",
+ "expected_path": "http://example.com/runtime.yaml",
+ },
+ ),
+ TestCase(
+ name="parse file url",
+ expected_status=SUCCESS,
+ config={
+ "url": "file:///path/to/runtime.yaml",
+ "expected_type": "file",
+ "expected_path": "/path/to/runtime.yaml",
+ },
+ ),
+ TestCase(
+ name="parse absolute path",
+ expected_status=SUCCESS,
+ config={
+ "url": "/absolute/path/to/runtime.yaml",
+ "expected_type": "file",
+ "expected_path": "/absolute/path/to/runtime.yaml",
+ },
+ ),
+ TestCase(
+ name="parse unsupported scheme",
+ expected_status=FAILED,
+ config={"url": "ftp://example.com/runtime.yaml"},
+ expected_error=ValueError,
+ ),
+ ],
+)
+def test_parse_source_url(test_case):
+ """Test parsing various source URL formats."""
+ print("Executing test:", test_case.name)
+ try:
+ source_type, path = runtime_loader._parse_source_url(test_case.config["url"])
+
+ assert test_case.expected_status == SUCCESS
+ assert source_type == test_case.config["expected_type"]
+ assert path == test_case.config["expected_path"]
+
+ except Exception as e:
+ assert type(e) is test_case.expected_error
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="load from default github",
+ expected_status=SUCCESS,
+ config={
+ "github_path": "kubeflow/trainer",
+ "discovered_files": ["torch_distributed.yaml"],
+ "expected_runtime_name": "torch-distributed",
+ "expected_framework": "torch",
+ },
+ ),
+ TestCase(
+ name="load from custom github",
+ expected_status=SUCCESS,
+ config={
+ "github_path": "myorg/myrepo",
+ "discovered_files": ["custom_runtime.yaml"],
+ "expected_runtime_name": "custom-runtime",
+ "expected_framework": "custom",
+ },
+ ),
+ TestCase(
+ name="load from github no files",
+ expected_status=SUCCESS,
+ config={
+ "github_path": "kubeflow/trainer",
+ "discovered_files": [],
+ "expected_count": 0,
+ },
+ ),
+ TestCase(
+ name="load from github invalid path",
+ expected_status=SUCCESS,
+ config={
+ "github_path": "invalid",
+ "expected_count": 0,
+ },
+ ),
+ ],
+)
+def test_load_from_github_url(test_case):
+ """Test loading runtimes from GitHub URLs."""
+ print("Executing test:", test_case.name)
+ try:
+ with (
+ patch(
+ "kubeflow.trainer.backends.container.runtime_loader._discover_github_runtime_files"
+ ) as mock_discover,
+ patch(
+ "kubeflow.trainer.backends.container.runtime_loader._fetch_runtime_from_github"
+ ) as mock_fetch,
+ ):
+ if test_case.name == "load from github invalid path":
+ # Don't set up mocks for invalid path test
+ runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
+ assert len(runtimes) == test_case.config["expected_count"]
+ else:
+ mock_discover.return_value = test_case.config.get("discovered_files", [])
+
+ # Create runtime YAML with custom name/framework if specified
+ runtime_yaml = SAMPLE_RUNTIME_YAML.copy()
+ if "expected_runtime_name" in test_case.config:
+ runtime_yaml["metadata"]["name"] = test_case.config["expected_runtime_name"]
+ runtime_yaml["metadata"]["labels"]["trainer.kubeflow.org/framework"] = (
+ test_case.config["expected_framework"]
+ )
+ mock_fetch.return_value = runtime_yaml
+
+ runtimes = runtime_loader._load_from_github_url(test_case.config["github_path"])
+
+ if "expected_count" in test_case.config:
+ assert len(runtimes) == test_case.config["expected_count"]
+ else:
+ assert len(runtimes) == 1
+ assert runtimes[0].name == test_case.config["expected_runtime_name"]
+ assert runtimes[0].trainer.framework == test_case.config["expected_framework"]
+
+ assert test_case.expected_status == SUCCESS
+
+ except Exception as e:
+ assert type(e) is test_case.expected_error
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="priority order github sources",
+ expected_status=SUCCESS,
+ config={
+ "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
+ "expected_count": 2,
+ "expected_names": ["torch-distributed", "deepspeed-distributed"],
+ },
+ ),
+ TestCase(
+ name="duplicate runtime names skipped",
+ expected_status=SUCCESS,
+ config={
+ "sources": ["github://myorg/myrepo", "github://kubeflow/trainer"],
+ "duplicate_names": True,
+ "expected_count": 1,
+ "expected_names": ["torch-distributed"],
+ },
+ ),
+ TestCase(
+ name="fallback to defaults",
+ expected_status=SUCCESS,
+ config={
+ "sources": ["github://myorg/myrepo"],
+ "no_github_runtimes": True,
+ "expected_count": 1,
+ "expected_names": ["torch-distributed"],
+ },
+ ),
+ ],
+)
+def test_list_training_runtimes_from_sources(test_case):
+    """Listing merges runtimes from sources in order, skips duplicate names
+    (first occurrence wins), and backfills built-in defaults."""
+    print("Executing test:", test_case.name)
+    try:
+        with (
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._load_from_github_url"
+            ) as mock_github,
+            patch(
+                "kubeflow.trainer.backends.container.runtime_loader._create_default_runtimes"
+            ) as mock_defaults,
+        ):
+            if test_case.name == "priority order github sources":
+                # Two sources, each contributing a distinct runtime.
+                torch_runtime = base_types.Runtime(
+                    name="torch-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="torch",
+                        num_nodes=1,
+                    ),
+                )
+                deepspeed_runtime = base_types.Runtime(
+                    name="deepspeed-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="deepspeed",
+                        num_nodes=1,
+                    ),
+                )
+                mock_github.side_effect = [[torch_runtime], [deepspeed_runtime]]
+                mock_defaults.return_value = []
+
+            elif test_case.name == "duplicate runtime names skipped":
+                # Same name from both sources; the first (num_nodes=1) must win.
+                torch_runtime_1 = base_types.Runtime(
+                    name="torch-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="torch",
+                        num_nodes=1,
+                    ),
+                )
+                torch_runtime_2 = base_types.Runtime(
+                    name="torch-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="torch",
+                        num_nodes=2,
+                    ),
+                )
+                mock_github.side_effect = [[torch_runtime_1], [torch_runtime_2]]
+                mock_defaults.return_value = []
+
+            elif test_case.name == "fallback to defaults":
+                # Source yields nothing, so the built-in default must appear.
+                mock_github.return_value = []
+                default_runtime = base_types.Runtime(
+                    name="torch-distributed",
+                    trainer=base_types.RuntimeTrainer(
+                        trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                        framework="torch",
+                        num_nodes=1,
+                    ),
+                )
+                mock_defaults.return_value = [default_runtime]
+
+            runtimes = runtime_loader.list_training_runtimes_from_sources(
+                test_case.config["sources"]
+            )
+
+            assert len(runtimes) == test_case.config["expected_count"]
+            runtime_names = [r.name for r in runtimes]
+            for expected_name in test_case.config["expected_names"]:
+                assert expected_name in runtime_names
+
+            assert test_case.expected_status == SUCCESS
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+def test_create_default_runtimes():
+ """Test creating default runtimes from constants."""
+ print("Executing test: create default runtimes")
+ runtimes = runtime_loader._create_default_runtimes()
+
+ assert len(runtimes) == len(constants.DEFAULT_FRAMEWORK_IMAGES)
+
+ # Check torch runtime
+ torch_runtimes = [r for r in runtimes if r.trainer.framework == "torch"]
+ assert len(torch_runtimes) == 1
+ assert torch_runtimes[0].name == "torch-distributed"
+ assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
+ assert torch_runtimes[0].trainer.num_nodes == 1
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="discover runtime files",
+ expected_status=SUCCESS,
+ config={
+ "html_content": """
+
+ torch_distributed.yaml
+ deepspeed_distributed.yaml
+ kustomization.yaml
+
+ """,
+ "expected_files": ["torch_distributed.yaml", "deepspeed_distributed.yaml"],
+ "excluded_files": ["kustomization.yaml"],
+ },
+ ),
+ TestCase(
+ name="discover runtime files custom repo",
+ expected_status=SUCCESS,
+ config={
+ "html_content": """
+
+ custom_runtime.yaml
+
+ """,
+ "expected_files": ["custom_runtime.yaml"],
+ "owner": "myorg",
+ "repo": "myrepo",
+ "path": "custom/path",
+ },
+ ),
+ TestCase(
+ name="discover runtime files network error",
+ expected_status=SUCCESS,
+ config={
+ "network_error": True,
+ "expected_files": [],
+ },
+ ),
+ ],
+)
+def test_discover_github_runtime_files(test_case):
+ """Test discovering runtime files from GitHub."""
+ print("Executing test:", test_case.name)
+ try:
+ with patch("urllib.request.urlopen") as mock_urlopen:
+ if test_case.config.get("network_error"):
+ mock_urlopen.side_effect = Exception("Network error")
+ else:
+ mock_response = MagicMock()
+ mock_response.read.return_value = test_case.config["html_content"].encode("utf-8")
+ mock_response.__enter__.return_value = mock_response
+ mock_urlopen.return_value = mock_response
+
+ kwargs = {}
+ if "owner" in test_case.config:
+ kwargs["owner"] = test_case.config["owner"]
+ kwargs["repo"] = test_case.config["repo"]
+ kwargs["path"] = test_case.config["path"]
+
+ files = runtime_loader._discover_github_runtime_files(**kwargs)
+
+ for expected_file in test_case.config["expected_files"]:
+ assert expected_file in files
+
+ for excluded_file in test_case.config.get("excluded_files", []):
+ assert excluded_file not in files
+
+ if "owner" in test_case.config and not test_case.config.get("network_error"):
+ called_url = mock_urlopen.call_args[0][0]
+ assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
+ assert kwargs["path"] in called_url
+
+ assert test_case.expected_status == SUCCESS
+
+ except Exception as e:
+ assert type(e) is test_case.expected_error
+ print("test execution complete")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ TestCase(
+ name="fetch runtime success",
+ expected_status=SUCCESS,
+ config={
+ "yaml_content": """
+apiVersion: trainer.kubeflow.org/v1alpha1
+kind: ClusterTrainingRuntime
+metadata:
+ name: torch-distributed
+""",
+ "expected_name": "torch-distributed",
+ },
+ ),
+ TestCase(
+ name="fetch runtime custom repo",
+ expected_status=SUCCESS,
+ config={
+ "yaml_content": """
+apiVersion: trainer.kubeflow.org/v1alpha1
+kind: ClusterTrainingRuntime
+metadata:
+ name: custom-runtime
+""",
+ "expected_name": "custom-runtime",
+ "runtime_file": "custom.yaml",
+ "owner": "myorg",
+ "repo": "myrepo",
+ "path": "custom/path",
+ },
+ ),
+ TestCase(
+ name="fetch runtime network error",
+ expected_status=SUCCESS,
+ config={
+ "network_error": True,
+ "expected_none": True,
+ },
+ ),
+ ],
+)
+def test_fetch_runtime_from_github(test_case):
+ """Test fetching runtime YAML from GitHub."""
+ print("Executing test:", test_case.name)
+ try:
+ with patch("urllib.request.urlopen") as mock_urlopen:
+ if test_case.config.get("network_error"):
+ mock_urlopen.side_effect = Exception("Network error")
+ else:
+ mock_response = MagicMock()
+ mock_response.read.return_value = test_case.config["yaml_content"].encode("utf-8")
+ mock_response.__enter__.return_value = mock_response
+ mock_urlopen.return_value = mock_response
+
+ default_runtime_file = "torch_distributed.yaml"
+ kwargs = {"runtime_file": test_case.config.get("runtime_file", default_runtime_file)}
+ if "owner" in test_case.config:
+ kwargs["owner"] = test_case.config["owner"]
+ kwargs["repo"] = test_case.config["repo"]
+ kwargs["path"] = test_case.config["path"]
+
+ data = runtime_loader._fetch_runtime_from_github(**kwargs)
+
+ if test_case.config.get("expected_none"):
+ assert data is None
+ else:
+ assert data is not None
+ assert data["metadata"]["name"] == test_case.config["expected_name"]
+
+ if "owner" in test_case.config:
+ called_url = mock_urlopen.call_args[0][0]
+ assert "raw.githubusercontent.com" in called_url
+ assert f"{kwargs['owner']}/{kwargs['repo']}" in called_url
+ assert f"{kwargs['path']}/{kwargs['runtime_file']}" in called_url
+
+ assert test_case.expected_status == SUCCESS
+
+ except Exception as e:
+ assert type(e) is test_case.expected_error
+ print("test execution complete")
diff --git a/kubeflow/trainer/backends/container/types.py b/kubeflow/trainer/backends/container/types.py
new file mode 100644
index 000000000..f30025cb9
--- /dev/null
+++ b/kubeflow/trainer/backends/container/types.py
@@ -0,0 +1,67 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Types and configuration for the unified Container backend.
+
+This backend automatically detects and uses either Docker or Podman.
+It provides a single interface for container-based execution regardless
+of the underlying runtime.
+
+Configuration options:
+ - pull_policy: Controls image pulling. Supported values: "IfNotPresent",
+ "Always", "Never". The default is "IfNotPresent".
+ - auto_remove: Whether to remove containers and networks when jobs are deleted.
+ Defaults to True.
+ - container_host: Optional override for connecting to a remote/local container
+ daemon. By default, auto-detects from environment or uses system defaults.
+ For Docker: uses DOCKER_HOST or default socket.
+ For Podman: uses CONTAINER_HOST or default socket.
+ - container_runtime: Force use of a specific container runtime ("docker" or "podman").
+ If not set, auto-detects based on availability (tries Docker first, then Podman).
+ - runtime_source: Configuration for training runtime sources using URL schemes.
+ Supports github://, https://, http://, file://, and absolute paths.
+ Built-in runtimes packaged with kubeflow-trainer are used as default fallback.
+"""
+
+from typing import Literal, Optional
+
+from pydantic import BaseModel, Field
+
+
+class TrainingRuntimeSource(BaseModel):
+ """Configuration for training runtime sources using URL schemes."""
+
+ sources: list[str] = Field(
+ default_factory=lambda: ["github://kubeflow/trainer"],
+ description=(
+ "Runtime sources with URL schemes (checked in priority order):\n"
+ " - github://owner/repo[/path] - GitHub repository\n"
+ " - https://url or http://url - HTTP(S) endpoint\n"
+ " - file:///path or /absolute/path - Local filesystem\n"
+ "If a runtime is not found in configured sources, built-in runtimes "
+ "packaged with kubeflow-trainer are used as default."
+ ),
+ )
+
+
+class ContainerBackendConfig(BaseModel):
+ pull_policy: str = Field(default="IfNotPresent")
+ auto_remove: bool = Field(default=True)
+ container_host: Optional[str] = Field(default=None)
+ container_runtime: Optional[Literal["docker", "podman"]] = Field(default=None)
+ runtime_source: TrainingRuntimeSource = Field(
+ default_factory=TrainingRuntimeSource,
+ description="Configuration for training runtime sources",
+ )
diff --git a/kubeflow/trainer/backends/container/utils.py b/kubeflow/trainer/backends/container/utils.py
new file mode 100644
index 000000000..8642f8651
--- /dev/null
+++ b/kubeflow/trainer/backends/container/utils.py
@@ -0,0 +1,236 @@
+# Copyright 2025 The Kubeflow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility functions for the Container backend.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+from kubeflow.common.constants import UNKNOWN
+from kubeflow.trainer.constants import constants
+from kubeflow.trainer.types import types
+
+logger = logging.getLogger(__name__)
+
+
+def create_workdir(job_name: str) -> str:
+ """
+ Create per-job working directory on host.
+
+ Working directories are created under ~/.kubeflow/trainer/containers/
+
+ Args:
+ job_name: Name of the training job.
+
+ Returns:
+ Absolute path to the working directory.
+ """
+ home_base = Path.home() / ".kubeflow" / "trainer" / "containers"
+ home_base.mkdir(parents=True, exist_ok=True)
+    workdir = str((home_base / job_name).resolve())
+ os.makedirs(workdir, exist_ok=True)
+ return workdir
+
+
+def get_training_script_code(trainer: types.CustomTrainer) -> str:
+ """
+ Generate the training script code from the trainer function.
+
+ This extracts the function source and appends a function call,
+ similar to how the Kubernetes backend handles training scripts.
+
+ Args:
+ trainer: CustomTrainer configuration.
+
+ Returns:
+ Complete Python code as a string to execute.
+ """
+ import inspect
+ import textwrap
+
+ code = inspect.getsource(trainer.func)
+ code = textwrap.dedent(code)
+ if trainer.func_args is None:
+ code += f"\n{trainer.func.__name__}()\n"
+ else:
+ code += f"\n{trainer.func.__name__}(**{trainer.func_args})\n"
+ return code
+
+
+def build_environment(trainer: types.CustomTrainer) -> dict[str, str]:
+ """
+ Build environment variables for containers.
+
+ Args:
+ trainer: CustomTrainer configuration.
+
+ Returns:
+ Dictionary of environment variables.
+ """
+ return dict(trainer.env or {})
+
+
+def build_pip_install_cmd(trainer: types.CustomTrainer) -> str:
+ """
+ Build pip install command for packages.
+
+ Args:
+ trainer: CustomTrainer configuration.
+
+ Returns:
+ Pip install command string (empty if no packages to install).
+ """
+ pkgs = trainer.packages_to_install or []
+ if not pkgs:
+ return ""
+
+ index_urls = trainer.pip_index_urls or list(constants.DEFAULT_PIP_INDEX_URLS)
+ main_idx = index_urls[0]
+ extras = " ".join(f"--extra-index-url {u}" for u in index_urls[1:])
+ quoted = " ".join(f'"{p}"' for p in pkgs)
+ return (
+ "PIP_DISABLE_PIP_VERSION_CHECK=1 pip install --no-warn-script-location "
+ f"--index-url {main_idx} {extras} {quoted} && "
+ )
+
+
+def container_status_to_trainjob_status(status: str, exit_code: int) -> str:
+ """
+ Convert container status to TrainJob status.
+
+ Args:
+ status: Container status (e.g., "running", "exited", "created").
+ exit_code: Container exit code.
+
+ Returns:
+ TrainJob status constant.
+ """
+ if status == "running":
+ return constants.TRAINJOB_RUNNING
+ if status == "created":
+ return constants.TRAINJOB_CREATED
+ if status == "exited":
+ # Exit code 0 -> complete, else failed
+ return constants.TRAINJOB_COMPLETE if exit_code == 0 else constants.TRAINJOB_FAILED
+    return UNKNOWN
+
+
+def aggregate_status_from_containers(container_statuses: list[str]) -> str:
+ """
+ Aggregate status from multiple container statuses.
+
+ Args:
+ container_statuses: List of container status strings.
+
+ Returns:
+ Aggregated TrainJob status.
+ """
+ if constants.TRAINJOB_FAILED in container_statuses:
+ return constants.TRAINJOB_FAILED
+ if constants.TRAINJOB_RUNNING in container_statuses:
+ return constants.TRAINJOB_RUNNING
+    if container_statuses and all(s == constants.TRAINJOB_COMPLETE for s in container_statuses if s != UNKNOWN):
+ return constants.TRAINJOB_COMPLETE
+ if any(s == constants.TRAINJOB_CREATED for s in container_statuses):
+ return constants.TRAINJOB_CREATED
+ return UNKNOWN
+
+
+def resolve_image(runtime: types.Runtime) -> str:
+ """
+ Resolve the container image for a runtime from DEFAULT_FRAMEWORK_IMAGES.
+
+ Args:
+ runtime: Runtime object.
+
+ Returns:
+ Container image name.
+
+ Raises:
+ ValueError: If no image is found for the runtime's framework.
+ """
+ framework = runtime.trainer.framework
+ if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
+ return constants.DEFAULT_FRAMEWORK_IMAGES[framework]
+
+ raise ValueError(
+ f"No default image found for framework '{framework}'. "
+ f"Supported frameworks: {list(constants.DEFAULT_FRAMEWORK_IMAGES.keys())}"
+ )
+
+
+def maybe_pull_image(adapter, image: str, pull_policy: str):
+ """
+ Pull image based on pull policy.
+
+ Args:
+ adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+ image: Container image name.
+ pull_policy: Pull policy ("IfNotPresent", "Always", or "Never").
+
+ Raises:
+ RuntimeError: If image is not found or pull fails.
+ """
+ policy = pull_policy.lower()
+ try:
+ if policy == "never":
+ if not adapter.image_exists(image):
+ raise RuntimeError(f"Image '{image}' not found locally and pull policy is Never")
+ return
+ if policy == "always":
+ logger.debug(f"Pulling image (Always): {image}")
+ adapter.pull_image(image)
+ return
+ # IfNotPresent
+ if not adapter.image_exists(image):
+ logger.debug(f"Pulling image (IfNotPresent): {image}")
+ adapter.pull_image(image)
+ except Exception as e:
+ raise RuntimeError(f"Failed to ensure image '{image}': {e}") from e
+
+
+def get_container_status(adapter, container_id: str) -> str:
+ """
+ Get the TrainJob status of a container.
+
+ Args:
+ adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+ container_id: Container ID.
+
+ Returns:
+ TrainJob status constant.
+ """
+ try:
+ status, exit_code = adapter.container_status(container_id)
+ return container_status_to_trainjob_status(status, exit_code)
+ except Exception:
+        return UNKNOWN
+
+
+def aggregate_container_statuses(adapter, containers: list[dict]) -> str:
+ """
+ Aggregate TrainJob status from container info dicts.
+
+ Args:
+ adapter: Container client adapter (DockerClientAdapter or PodmanClientAdapter).
+ containers: List of container info dicts with 'id' key.
+
+ Returns:
+ Aggregated TrainJob status.
+ """
+ statuses = [get_container_status(adapter, c["id"]) for c in containers]
+ return aggregate_status_from_containers(statuses)
diff --git a/kubeflow/trainer/constants/constants.py b/kubeflow/trainer/constants/constants.py
index 5356bc5a2..a15d402ae 100644
--- a/kubeflow/trainer/constants/constants.py
+++ b/kubeflow/trainer/constants/constants.py
@@ -162,3 +162,8 @@
# The Instruct Datasets class in torchtune
TORCH_TUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset"
+
+# Default container images for each framework (used as fallback when runtime not provided)
+DEFAULT_FRAMEWORK_IMAGES = {
+ "torch": "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime",
+}
diff --git a/pyproject.toml b/pyproject.toml
index bcf11c1c1..570d16f90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,14 @@ dependencies = [
"kubeflow-katib-api>=0.19.0",
]
+[project.optional-dependencies]
+docker = [
+ "docker>=6.1.3",
+]
+podman = [
+ "podman>=5.6.0"
+]
+
[dependency-groups]
dev = [
"pytest>=7.0",
@@ -100,7 +108,9 @@ select = [
]
ignore = [
- "B006", # mutable-argument-default
+ "B006", # mutable-argument-default
+ "UP007", # Use X | Y instead of Union[X, Y] (requires Python 3.10+)
+ "UP045", # Use X | None instead of Optional[X] (requires Python 3.10+)
]
diff --git a/uv.lock b/uv.lock
index 37b6e544b..64adf4957 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.9"
[[package]]
@@ -341,6 +341,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
]
+[[package]]
+name = "docker"
+version = "7.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "pywin32", marker = "sys_platform == 'win32'" },
+ { name = "requests" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
+]
+
[[package]]
name = "durationpy"
version = "0.10"
@@ -422,6 +436,14 @@ dependencies = [
{ name = "pydantic" },
]
+[package.optional-dependencies]
+docker = [
+ { name = "docker" },
+]
+podman = [
+ { name = "podman" },
+]
+
[package.dev-dependencies]
dev = [
{ name = "coverage" },
@@ -437,10 +459,13 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "kubeflow-katib-api", specifier = ">=0.19.0" },
+ { name = "docker", marker = "extra == 'docker'", specifier = ">=6.1.3" },
{ name = "kubeflow-trainer-api", specifier = ">=2.0.0" },
{ name = "kubernetes", specifier = ">=27.2.0" },
+ { name = "podman", marker = "extra == 'podman'", specifier = ">=5.6.0" },
{ name = "pydantic", specifier = ">=2.10.0" },
]
+provides-extras = ["docker", "podman"]
[package.metadata.requires-dev]
dev = [
@@ -537,6 +562,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
+[[package]]
+name = "podman"
+version = "5.6.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "requests" },
+ { name = "tomli", marker = "python_full_version < '3.11'" },
+ { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3b/36/070e7bf682ac0868450584df79198c178323e80f73b8fb9b6fec8bde0a65/podman-5.6.0.tar.gz", hash = "sha256:cc5f7aa9562e30f992fc170a48da970a7132be60d8a2e2941e6c17bd0a0b35c9", size = 72832, upload-time = "2025-09-05T09:42:40.071Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/0a/9e/8c62f05b104d9f00edbb4c298b152deceb393ea67f0288d89d1139d7a859/podman-5.6.0-py3-none-any.whl", hash = "sha256:967ff8ad8c6b851bc5da1a9410973882d80e235a9410b7d1e931ce0c3324fbe3", size = 88713, upload-time = "2025-09-05T09:42:38.405Z" },
+]
+
[[package]]
name = "pre-commit"
version = "4.3.0"
@@ -808,6 +847,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
]
+[[package]]
+name = "pywin32"
+version = "311"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" },
+ { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" },
+ { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" },
+ { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" },
+ { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" },
+ { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" },
+ { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" },
+ { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" },
+ { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" },
+ { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" },
+ { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" },
+ { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" },
+ { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" },
+ { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" },
+ { url = "https://files.pythonhosted.org/packages/59/42/b86689aac0cdaee7ae1c58d464b0ff04ca909c19bb6502d4973cdd9f9544/pywin32-311-cp39-cp39-win32.whl", hash = "sha256:aba8f82d551a942cb20d4a83413ccbac30790b50efb89a75e4f586ac0bb8056b", size = 8760837, upload-time = "2025-07-14T20:12:59.59Z" },
+ { url = "https://files.pythonhosted.org/packages/9f/8a/1403d0353f8c5a2f0829d2b1c4becbf9da2f0a4d040886404fc4a5431e4d/pywin32-311-cp39-cp39-win_amd64.whl", hash = "sha256:e0c4cfb0621281fe40387df582097fd796e80430597cb9944f0ae70447bacd91", size = 9590187, upload-time = "2025-07-14T20:13:01.419Z" },
+ { url = "https://files.pythonhosted.org/packages/60/22/e0e8d802f124772cec9c75430b01a212f86f9de7546bda715e54140d5aeb/pywin32-311-cp39-cp39-win_arm64.whl", hash = "sha256:62ea666235135fee79bb154e695f3ff67370afefd71bd7fea7512fc70ef31e3d", size = 8778162, upload-time = "2025-07-14T20:13:03.544Z" },
+]
+
[[package]]
name = "pyyaml"
version = "6.0.2"