From 4b1450e81915c9f6ece534c37f4d5b449914c97d Mon Sep 17 00:00:00 2001 From: muzzlol Date: Tue, 6 Jan 2026 16:11:59 +0530 Subject: [PATCH] docs(examples): add GPU passthrough support to container backend example Add runtime NVIDIA GPU detection to automatically enable GPU passthrough when available. Relates to kubeflow/sdk#219 Signed-off-by: muzzlol --- examples/local/local-container-mnist.ipynb | 125 +++++++++------------ 1 file changed, 56 insertions(+), 69 deletions(-) diff --git a/examples/local/local-container-mnist.ipynb b/examples/local/local-container-mnist.ipynb index a19689fb29..6c4ee3fb35 100644 --- a/examples/local/local-container-mnist.ipynb +++ b/examples/local/local-container-mnist.ipynb @@ -14,6 +14,7 @@ "- **Container Runtime**: Docker or Podman required\n", "- **Use Case**: Testing container workflows, simulating production environments\n", "- **Prerequisites**: Python 3.9+ and Docker Desktop/Engine OR Podman\n", + "- **GPU Support**: Automatic GPU passthrough when available (NVIDIA: Linux/WSL2)\n", "\n", "This example trains a CNN on the classic [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digit dataset using PyTorch." ] @@ -30,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "pip-install", "metadata": {}, "outputs": [], @@ -52,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "train-function", "metadata": {}, "outputs": [], @@ -179,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "backend-config", "metadata": {}, "outputs": [], @@ -218,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "create-client", "metadata": {}, "outputs": [], @@ -238,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "id": "get-runtimes", "metadata": {}, "outputs": [ @@ -271,18 +272,41 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "id": "train-job", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NVIDIA GPU detected, enabling GPU passthrough\n" + ] + } + ], "source": [ "from kubeflow.trainer import CustomTrainer\n", + "import shutil\n", + "import platform\n", + "\n", + "# GPU passthrough requires: Linux/WSL2 + NVIDIA drivers + Container Toolkit\n", + "# Check for nvidia-smi as a proxy for NVIDIA driver availability\n", + "gpu_available = (\n", + " platform.system() == \"Linux\" and \n", + " shutil.which(\"nvidia-smi\") is not None\n", + ")\n", + "\n", + "if gpu_available:\n", + " print(\"NVIDIA GPU detected, enabling GPU passthrough\")\n", + "else:\n", + " print(\"No GPU detected, running on CPU\")\n", "\n", "job_name = client.train(\n", " trainer=CustomTrainer(\n", " func=train_mnist,\n", " packages_to_install=[\"torchvision\"],\n", - " num_nodes=2,\n", + " num_nodes=1,\n", + " resources_per_node={\"gpu\": 1} if gpu_available else None,\n", " ),\n", " runtime=torch_runtime,\n", ")" @@ -300,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "id": "job-status", "metadata": {}, "outputs": [ @@ -308,7 +332,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Job: p6072ef2ca48, Status: Running\n" + "Job: mb7c8d26fd6d, Status: Running\n" ] } ], @@ -329,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "id": "logs", "metadata": {}, "outputs": [ @@ -367,61 +391,24 @@ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch==2.7.1->torchvision) (3.0.2)\n", "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n", "Using Device: cpu, Backend: gloo\n", - "Distributed Training for WORLD_SIZE: 2, RANK: 0, LOCAL_RANK: 0\n", - "100%|██████████| 26.4M/26.4M [00:00<00:00, 50.3MB/s]\n", - "100%|██████████| 29.5k/29.5k [00:00<00:00, 2.97MB/s]\n", - "100%|██████████| 4.42M/4.42M [00:00<00:00, 43.0MB/s]\n", - "100%|██████████| 5.15k/5.15k [00:00<00:00, 47.2MB/s]\n", - "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.315568\n", - "Train Epoch: 1 [10000/60000 (33%)]\tLoss: 0.734196\n", - "Train Epoch: 1 [20000/60000 (67%)]\tLoss: 0.756162\n", - "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.722734\n", - "Train Epoch: 2 [10000/60000 (33%)]\tLoss: 0.428438\n", - "Train Epoch: 2 [20000/60000 (67%)]\tLoss: 0.625209\n", - "Training is finished\n", - "Requirement already satisfied: torchvision in /opt/conda/lib/python3.11/site-packages (0.22.1+cu128)\n", - "Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.2.6)\n", - "Requirement already satisfied: torch==2.7.1 in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.7.1+cu128)\n", - "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (11.0.0)\n", - "Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.18.0)\n", - "Requirement already satisfied: typing-extensions>=4.10.0 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (4.14.0)\n", - "Requirement already satisfied: sympy>=1.13.3 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (1.14.0)\n", - "Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.5)\n", - "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.1.6)\n", - "Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (2025.5.1)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.61)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.57)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.57)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==9.7.1.26 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (9.7.1.26)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.3.14)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (11.3.3.41)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (10.3.9.55)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (11.7.2.55)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.5.7.53)\n", - "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (0.6.3)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (2.26.2)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.55)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.61)\n", - "Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (1.13.0.11)\n", - "Requirement already satisfied: triton==3.3.1 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.3.1)\n", - "Requirement already satisfied: setuptools>=40.8.0 in /opt/conda/lib/python3.11/site-packages (from triton==3.3.1->torch==2.7.1->torchvision) (75.8.2)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.11/site-packages (from sympy>=1.13.3->torch==2.7.1->torchvision) (1.3.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch==2.7.1->torchvision) (3.0.2)\n", - "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n", - "Waiting for master node p6072ef2ca48-node-0:29500...\n", - "Master node is reachable\n", - "Using Device: cpu, Backend: gloo\n", - "Distributed Training for WORLD_SIZE: 2, RANK: 1, LOCAL_RANK: 0\n", - "100%|██████████| 26.4M/26.4M [00:00<00:00, 51.8MB/s]\n", - "100%|██████████| 29.5k/29.5k [00:00<00:00, 2.81MB/s]\n", - "100%|██████████| 4.42M/4.42M [00:00<00:00, 25.2MB/s]\n", - "100%|██████████| 5.15k/5.15k [00:00<00:00, 13.3MB/s]\n", - "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.322156\n", - "Train Epoch: 1 [10000/60000 (33%)]\tLoss: 0.986901\n", - "Train Epoch: 1 [20000/60000 (67%)]\tLoss: 0.654047\n", - "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.686302\n", - "Train Epoch: 2 [10000/60000 (33%)]\tLoss: 0.607211\n", - "Train Epoch: 2 [20000/60000 (67%)]\tLoss: 0.552617\n" + "Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0\n", + "100%|██████████| 26.4M/26.4M [00:14<00:00, 1.79MB/s]\n", + "100%|██████████| 29.5k/29.5k [00:00<00:00, 191kB/s]\n", + "100%|██████████| 4.42M/4.42M [00:02<00:00, 2.17MB/s]\n", + "100%|██████████| 5.15k/5.15k [00:00<00:00, 15.1MB/s]\n", + "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.303455\n", + "Train Epoch: 1 [10000/60000 (17%)]\tLoss: 2.050205\n", + "Train Epoch: 1 [20000/60000 (33%)]\tLoss: 1.071181\n", + "Train Epoch: 1 [30000/60000 (50%)]\tLoss: 0.900017\n", + "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.710202\n", + "Train Epoch: 1 [50000/60000 (83%)]\tLoss: 0.771787\n", + "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.897072\n", + "Train Epoch: 2 [10000/60000 (17%)]\tLoss: 0.649391\n", + "Train Epoch: 2 [20000/60000 (33%)]\tLoss: 0.592039\n", + "Train Epoch: 2 [30000/60000 (50%)]\tLoss: 0.606340\n", + "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.533571\n", + "Train Epoch: 2 [50000/60000 (83%)]\tLoss: 0.617757\n", + "Training is finished\n" ] } ], @@ -442,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "delete", "metadata": {}, "outputs": [], @@ -453,7 +440,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -467,7 +454,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.12" + "version": "3.11.13" } }, "nbformat": 4,