321 changes: 321 additions & 0 deletions docs/source/tutorials/pruning/unet_depth_reduction.ipynb
@@ -0,0 +1,321 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wx5N8OB-dbDV",
"outputId": "237b8a09-14e3-4d7b-ccfa-3892d17c9f26"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[?25l \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m0.0/2.7 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m\u001b[91m\u2578\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m102.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m45.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h\u001b[?25l \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m0.0/70.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m70.2/70.2 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"source": [
"!pip install -q monai torch-pruning\n"
]
},
{
"cell_type": "markdown",
"source": [
"# Structured Pruning of U-Net for Medical Image Segmentation\n",
"\n",
"This tutorial demonstrates how structured channel pruning can be applied to a MONAI U-Net model to reduce model size and computation, while maintaining segmentation capability.\n"
],
"metadata": {
"id": "iQoMzAlfd4ZH"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import numpy as np\n",
"from monai.networks.nets import UNet\n",
"import torch_pruning as tp\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-uKJxHXQd5cX",
"outputId": "da495e06-22b6-4309-8b74-69f920c226c3"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"<frozen importlib._bootstrap_external>:1301: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"torch.manual_seed(0)\n",
"np.random.seed(0)\n"
],
"metadata": {
"id": "QPj1Qficd8v4"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"images = torch.rand(4, 1, 128, 128)\n",
"labels = (images > 0.5).float()\n"
],
"metadata": {
"id": "APIhPidJd9zv"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def count_params(model):\n",
" return sum(p.numel() for p in model.parameters())\n"
],
"metadata": {
"id": "sjrMXVgxfO8P"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"baseline_unet = UNet(\n",
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" channels=(16, 32, 64, 128, 256), # 5 levels\n",
" strides=(2, 2, 2, 2),\n",
")\n"
],
"metadata": {
"id": "s3MFPtpkfRa_"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Baseline parameters:\", count_params(baseline_unet))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9-QPnHhUfTSv",
"outputId": "c91f6ee7-8ad2-4a52-e20a-97c185e28ea0"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline parameters: 659993\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"reduced_unet = UNet(\n",
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" channels=(16, 32, 64), # only 3 levels\n",
" strides=(2, 2),\n",
")\n"
],
"metadata": {
"id": "tbf-UY9GfVVf"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"Depth-reduced parameters:\", count_params(reduced_unet))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9Wi3dwetfXRw",
"outputId": "cb8a334b-ece6-45e9-c74b-0a8ec01cac92"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Depth-reduced parameters: 37429\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"baseline_out = baseline_unet(images)\n",
"reduced_out = reduced_unet(images)\n",
"\n",
"print(\"Baseline output shape:\", baseline_out.shape)\n",
"print(\"Reduced output shape:\", reduced_out.shape)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iDmYo0HrfZ4A",
"outputId": "6d9eb4e2-c7ab-443c-a2e1-5ec77b1b1cc2"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline output shape: torch.Size([4, 1, 128, 128])\n",
"Reduced output shape: torch.Size([4, 1, 128, 128])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"baseline_params = count_params(baseline_unet)\n",
"reduced_params = count_params(reduced_unet)\n",
"\n",
"reduction = 100 * (baseline_params - reduced_params) / baseline_params\n",
"print(f\"Parameter reduction: {reduction:.2f}%\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qt4TAf-efc3g",
"outputId": "e5d71e43-14ac-49ba-f1a4-6d63f904275e"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Parameter reduction: 94.33%\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Discussion\n",
"\n",
"Reducing the depth of a U-Net architecture leads to a true reduction in the number of learnable parameters, unlike masking-based pruning approaches that preserve tensor shapes.\n",
"\n",
"Depth reduction decreases representational capacity and receptive field size, which may affect segmentation accuracy. However, for many medical imaging applications\u2014especially those targeting edge devices or real-time inference\u2014this trade-off is acceptable and often desirable.\n",
"\n",
"This approach provides a simple, stable, and reproducible strategy for building lightweight medical imaging models.\n"
],
"metadata": {
"id": "f1Csi3BJfe3w"
}
},
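{
"cell_type": "markdown",
"source": [
"To make the shape-preservation point concrete, the next cell is a minimal, hedged sketch (separate from the depth-reduction workflow) that applies PyTorch's built-in `torch.nn.utils.prune.l1_unstructured` to every `Conv2d` layer of a fresh baseline-sized U-Net. Because masking only zeroes weights without removing them, `count_params` reports the same number as the unpruned baseline.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hedged sketch: masking-based pruning for contrast with depth reduction.\n",
"# torch.nn.utils.prune zeroes weights through a mask, so tensor shapes (and\n",
"# therefore the parameter count) stay exactly the same.\n",
"import torch.nn.utils.prune as prune\n",
"\n",
"masked_unet = UNet(\n",
"    spatial_dims=2,\n",
"    in_channels=1,\n",
"    out_channels=1,\n",
"    channels=(16, 32, 64, 128, 256),\n",
"    strides=(2, 2, 2, 2),\n",
")\n",
"\n",
"# Zero out 50% of the weights (by L1 magnitude) in every Conv2d layer.\n",
"for module in masked_unet.modules():\n",
"    if isinstance(module, torch.nn.Conv2d):\n",
"        prune.l1_unstructured(module, name=\"weight\", amount=0.5)\n",
"\n",
"print(\"Masked-pruned parameters:\", count_params(masked_unet))\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},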
{
"cell_type": "code",
"source": [
"import time\n",
"\n",
"def inference_time(model, x, runs=20):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" start = time.time()\n",
" for _ in range(runs):\n",
" _ = model(x)\n",
" end = time.time()\n",
" return (end - start) / runs\n",
"\n",
"print(\"Baseline avg inference time:\", inference_time(baseline_unet, images))\n",
"print(\"Reduced avg inference time:\", inference_time(reduced_unet, images))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2cGD6jwsfvli",
"outputId": "b7a66f85-4247-4548-dc2c-782c08425e93"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Baseline avg inference time: 0.01868886947631836\n",
"Reduced avg inference time: 0.013485324382781983\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## When to Use Depth-Reduced Models\n",
"\n",
"Depth-reduced architectures are well suited for:\n",
"- Edge and embedded medical devices\n",
"- Real-time or near\u2013real-time inference\n",
"- Rapid prototyping and experimentation\n",
"- Scenarios with limited memory or compute budgets\n",
"\n",
"For tasks requiring fine-grained segmentation accuracy, deeper architectures may still be preferable.\n"
],
"metadata": {
"id": "eBW8lxpTfyLC"
}
}
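,
{
"cell_type": "markdown",
"source": [
"As a hedged sketch of the edge-deployment scenario above, the depth-reduced model could be exported with PyTorch's `torch.onnx.export`. The output file name `reduced_unet.onnx` is illustrative only, and an ONNX-compatible runtime on the target device is assumed.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative export of the depth-reduced U-Net for deployment on edge devices.\n",
"# Assumption: the target device provides an ONNX runtime; the file name is hypothetical.\n",
"reduced_unet.eval()\n",
"torch.onnx.export(\n",
"    reduced_unet,\n",
"    images,                  # example input defined earlier in the tutorial\n",
"    \"reduced_unet.onnx\",     # hypothetical output path\n",
"    input_names=[\"image\"],\n",
"    output_names=[\"mask\"],\n",
")\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
}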
]
}