Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 56 additions & 69 deletions examples/local/local-container-mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"- **Container Runtime**: Docker or Podman required\n",
"- **Use Case**: Testing container workflows, simulating production environments\n",
"- **Prerequisites**: Python 3.9+ and Docker Desktop/Engine OR Podman\n",
"- **GPU Support**: Optional. NVIDIA GPUs are passed through to the container automatically when available (requires Linux/WSL2, NVIDIA drivers, and the NVIDIA Container Toolkit)\n",
"\n",
"This example trains a CNN on the classic [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digit dataset using PyTorch."
]
Expand All @@ -30,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "pip-install",
"metadata": {},
"outputs": [],
Expand All @@ -52,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "train-function",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -179,7 +180,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "backend-config",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -218,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "create-client",
"metadata": {},
"outputs": [],
Expand All @@ -238,7 +239,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "get-runtimes",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -271,18 +272,41 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"id": "train-job",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NVIDIA GPU detected, enabling GPU passthrough\n"
]
}
],
"source": [
"from kubeflow.trainer import CustomTrainer\n",
"import shutil\n",
"import platform\n",
"\n",
"# GPU passthrough requires Linux/WSL2, NVIDIA drivers, and the NVIDIA Container Toolkit.\n",
"# The check below only looks for nvidia-smi on PATH as a proxy for the driver; it does not\n",
"# verify that the Container Toolkit is installed or that a GPU is actually usable.\n",
"gpu_available = (\n",
" platform.system() == \"Linux\" and \n",
" shutil.which(\"nvidia-smi\") is not None\n",
")\n",
"\n",
"if gpu_available:\n",
" print(\"NVIDIA GPU detected, enabling GPU passthrough\")\n",
"else:\n",
" print(\"No GPU detected, running on CPU\")\n",
"\n",
"job_name = client.train(\n",
" trainer=CustomTrainer(\n",
" func=train_mnist,\n",
" packages_to_install=[\"torchvision\"],\n",
" num_nodes=2,\n",
" num_nodes=1,\n",
" resources_per_node={\"gpu\": 1} if gpu_available else None,\n",
" ),\n",
" runtime=torch_runtime,\n",
")"
Expand All @@ -300,15 +324,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"id": "job-status",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Job: p6072ef2ca48, Status: Running\n"
"Job: mb7c8d26fd6d, Status: Running\n"
]
}
],
Expand All @@ -329,7 +353,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 8,
"id": "logs",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -367,61 +391,24 @@
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch==2.7.1->torchvision) (3.0.2)\n",
"WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n",
"Using Device: cpu, Backend: gloo\n",
"Distributed Training for WORLD_SIZE: 2, RANK: 0, LOCAL_RANK: 0\n",
"100%|██████████| 26.4M/26.4M [00:00<00:00, 50.3MB/s]\n",
"100%|██████████| 29.5k/29.5k [00:00<00:00, 2.97MB/s]\n",
"100%|██████████| 4.42M/4.42M [00:00<00:00, 43.0MB/s]\n",
"100%|██████████| 5.15k/5.15k [00:00<00:00, 47.2MB/s]\n",
"Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.315568\n",
"Train Epoch: 1 [10000/60000 (33%)]\tLoss: 0.734196\n",
"Train Epoch: 1 [20000/60000 (67%)]\tLoss: 0.756162\n",
"Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.722734\n",
"Train Epoch: 2 [10000/60000 (33%)]\tLoss: 0.428438\n",
"Train Epoch: 2 [20000/60000 (67%)]\tLoss: 0.625209\n",
"Training is finished\n",
"Requirement already satisfied: torchvision in /opt/conda/lib/python3.11/site-packages (0.22.1+cu128)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.2.6)\n",
"Requirement already satisfied: torch==2.7.1 in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.7.1+cu128)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (11.0.0)\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (4.14.0)\n",
"Requirement already satisfied: sympy>=1.13.3 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (1.14.0)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.5)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.1.6)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (2025.5.1)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.61 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.61)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.57 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.57)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.57 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.57)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.7.1.26 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (9.7.1.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.8.3.14 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.3.14)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.3.3.41 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (11.3.3.41)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.9.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (10.3.9.55)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.2.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (11.7.2.55)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.7.53 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.5.7.53)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.3 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (0.6.3)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.26.2 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (2.26.2)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.8.55 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.55)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.61 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (12.8.61)\n",
"Requirement already satisfied: nvidia-cufile-cu12==1.13.0.11 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (1.13.0.11)\n",
"Requirement already satisfied: triton==3.3.1 in /opt/conda/lib/python3.11/site-packages (from torch==2.7.1->torchvision) (3.3.1)\n",
"Requirement already satisfied: setuptools>=40.8.0 in /opt/conda/lib/python3.11/site-packages (from triton==3.3.1->torch==2.7.1->torchvision) (75.8.2)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.11/site-packages (from sympy>=1.13.3->torch==2.7.1->torchvision) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch==2.7.1->torchvision) (3.0.2)\n",
"WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n",
"Waiting for master node p6072ef2ca48-node-0:29500...\n",
"Master node is reachable\n",
"Using Device: cpu, Backend: gloo\n",
"Distributed Training for WORLD_SIZE: 2, RANK: 1, LOCAL_RANK: 0\n",
"100%|██████████| 26.4M/26.4M [00:00<00:00, 51.8MB/s]\n",
"100%|██████████| 29.5k/29.5k [00:00<00:00, 2.81MB/s]\n",
"100%|██████████| 4.42M/4.42M [00:00<00:00, 25.2MB/s]\n",
"100%|██████████| 5.15k/5.15k [00:00<00:00, 13.3MB/s]\n",
"Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.322156\n",
"Train Epoch: 1 [10000/60000 (33%)]\tLoss: 0.986901\n",
"Train Epoch: 1 [20000/60000 (67%)]\tLoss: 0.654047\n",
"Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.686302\n",
"Train Epoch: 2 [10000/60000 (33%)]\tLoss: 0.607211\n",
"Train Epoch: 2 [20000/60000 (67%)]\tLoss: 0.552617\n"
"Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0\n",
"100%|██████████| 26.4M/26.4M [00:14<00:00, 1.79MB/s]\n",
"100%|██████████| 29.5k/29.5k [00:00<00:00, 191kB/s]\n",
"100%|██████████| 4.42M/4.42M [00:02<00:00, 2.17MB/s]\n",
"100%|██████████| 5.15k/5.15k [00:00<00:00, 15.1MB/s]\n",
"Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.303455\n",
"Train Epoch: 1 [10000/60000 (17%)]\tLoss: 2.050205\n",
"Train Epoch: 1 [20000/60000 (33%)]\tLoss: 1.071181\n",
"Train Epoch: 1 [30000/60000 (50%)]\tLoss: 0.900017\n",
"Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.710202\n",
"Train Epoch: 1 [50000/60000 (83%)]\tLoss: 0.771787\n",
"Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.897072\n",
"Train Epoch: 2 [10000/60000 (17%)]\tLoss: 0.649391\n",
"Train Epoch: 2 [20000/60000 (33%)]\tLoss: 0.592039\n",
"Train Epoch: 2 [30000/60000 (50%)]\tLoss: 0.606340\n",
"Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.533571\n",
"Train Epoch: 2 [50000/60000 (83%)]\tLoss: 0.617757\n",
"Training is finished\n"
]
}
],
Expand All @@ -442,7 +429,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "delete",
"metadata": {},
"outputs": [],
Expand All @@ -453,7 +440,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -467,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
"version": "3.11.13"
}
},
"nbformat": 4,
Expand Down
Loading