A flexible PyTorch-based framework for training 2D and 3D medical image segmentation models, with support for patch-based training, configurable architectures, and comprehensive metrics tracking.
```
Brain-Segmentation/
├── base/                   # Abstract base classes
│   ├── base_dataset2d_sliced.py
│   ├── base_dataset.py
│   ├── base_model.py
│   └── base_trainer.py
├── config/                 # Configuration files for training and transforms
│   ├── config_atlas.json
│   ├── atlas_transforms.json
│   └── ...
├── datasets/               # Dataset loading and preprocessing (inherit from the base dataset classes)
│   ├── DatasetFactory.py
│   ├── ATLAS.py
│   └── BraTS2D.py
├── losses/                 # Loss function implementations
│   └── LossFactory.py
├── metrics/                # Metrics computation and tracking
│   ├── MetricsFactory.py
│   └── MetricsManager.py
├── models/                 # Model architectures (inherit from BaseModel)
│   ├── ModelFactory.py
│   ├── UNet2D.py
│   └── UNet3D.py
├── optimizers/             # Optimizer configurations
│   └── OptimizerFactory.py
├── trainer/                # Training logic (inherits from BaseTrainer)
│   ├── trainer_2Dsliced.py
│   └── trainer_3D.py
├── transforms/             # Data augmentation and preprocessing
│   └── TransformsFactory.py
├── utils/                  # Utility functions
│   ├── util.py
│   └── pad_unpad.py
├── config.py               # Config file handler
├── main.py                 # Training entry point
└── requirements.txt        # Python dependencies
```
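The JSON configuration schema is defined by the factory classes above. Purely as a hypothetical illustration (none of these keys are confirmed by the repository; check the actual files in `config/`), a config such as `config_atlas.json` could wire the factories together along these lines:

```json
{
  "dataset": { "name": "ATLAS", "data_path": "/path/to/ATLAS" },
  "model": { "name": "UNet3D", "in_channels": 1, "out_channels": 2 },
  "loss": { "name": "Dice" },
  "optimizer": { "name": "Adam", "lr": 0.0001 },
  "metrics": ["dice"],
  "transforms": "config/atlas_transforms.json"
}
```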
- Clone the repository:

```bash
git clone https://github.com/kev98/Medical-Image-Segmentation.git
cd Medical-Image-Segmentation
```

- Create and activate a virtual environment:

```bash
python -m venv .venv
source .venv/bin/activate
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

The following are some baseline examples; you can add any other CLI parameters that are useful for your `main.py` (which must be the entry point for training).
Command line arguments implemented in the provided `main.py` file:

- `--config`: Path to configuration JSON file (required)
- `--epochs`: Number of training epochs (required)
- `--save_path`: Directory to save model checkpoints (required)
- `--validation`: Enable validation during training (flag)
- `--resume`: Resume training from last checkpoint (flag)
- `--debug`: Enable debug mode with verbose output (flag)
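For reference, a minimal `argparse` setup matching these flags might look like the following sketch (the provided `main.py` may structure this differently):

```python
import argparse

def parse_args():
    # Flags mirror the list above; description and help text are illustrative.
    parser = argparse.ArgumentParser(description="Train a medical image segmentation model")
    parser.add_argument("--config", required=True, help="Path to configuration JSON file")
    parser.add_argument("--epochs", type=int, required=True, help="Number of training epochs")
    parser.add_argument("--save_path", required=True, help="Directory to save model checkpoints")
    parser.add_argument("--validation", action="store_true", help="Enable validation during training")
    parser.add_argument("--resume", action="store_true", help="Resume training from last checkpoint")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode with verbose output")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
```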
Example launch of `main.py` that trains a 3D segmentation model with validation enabled and resumes from the last checkpoint:
```bash
python main.py \
    --config config/config_atlas.json \
    --epochs 100 \
    --save_path /folder_containing_model_last.pth \
    --validation \
    --resume
```

To set up a complete training pipeline, follow these steps:
1. **Create a Dataset Class**: Inherit from `BaseDataset` or `BaseDataset2DSliced` and implement the required abstract methods.
2. **Implement a Model**: Create your custom model by inheriting from `BaseModel` and implementing the `forward()` method.
3. **Implement a Trainer**: Create a custom trainer by inheriting from `BaseTrainer` and implementing the `_train_epoch()` and `eval_epoch()` methods.
4. **Create an Entrypoint**: Write a `main.py` file that loads your configuration and instantiates your trainer. Use the provided `main.py` as a template or reference.

The sketch below outlines these four steps.
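The base-class names come from the repository; the constructor arguments, placeholder layer, and exact method signatures are assumptions, so consult the actual base classes for the real interfaces:

```python
import torch.nn as nn

from base.base_dataset import BaseDataset   # module paths inferred from the tree above
from base.base_model import BaseModel
from base.base_trainer import BaseTrainer

class MyDataset(BaseDataset):
    """Step 1: implement whatever abstract methods BaseDataset declares
    (e.g. how one subject/volume is loaded)."""

class MyModel(BaseModel):
    """Step 2: a model only needs to define forward()."""
    def __init__(self, in_channels=1, out_channels=2):  # hypothetical signature
        super().__init__()
        self.net = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)  # placeholder layer

    def forward(self, x):
        return self.net(x)

class MyTrainer(BaseTrainer):
    """Step 3: a trainer implements one training and one evaluation epoch."""
    def _train_epoch(self, epoch):   # exact signature defined by BaseTrainer
        ...

    def eval_epoch(self, epoch):
        ...

# Step 4: a main.py would parse the CLI arguments (see the argparse sketch above),
# load the JSON config, and instantiate MyTrainer.
```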
For detailed documentation on each component, refer to the README files in their respective directories:
- Base Classes - Abstract base classes for datasets, models, and trainers
- Configuration - JSON configuration files for training and transforms
- Datasets - Dataset loading and preprocessing
- Losses - Loss function implementations
- Metrics - Metrics computation and tracking
- Models - Model architectures
- Optimizers - Optimizer configurations
- Trainers - Training logic
- Transforms - Data augmentation and preprocessing
- Utils - Utility functions
- For patch-based training with 3D volumes, the framework uses TorchIO's `Queue` and `GridSampler` (see the sketch below).
- Metrics are automatically computed per-class and averaged.
- Checkpoints are saved as `model_last.pth` and `model_best.pth` in the folder specified by `--save_path`.
- The framework is compatible with PyTorch 2.3+ and uses TorchIO's `SubjectsLoader` for proper data handling.
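To make the first note concrete, here is a minimal sketch of that TorchIO patch pipeline. The file names, patch sizes, and stand-in model are placeholders; in this framework the dataset and trainer classes set this up for you:

```python
import torch
import torchio as tio

# Placeholder subject; the dataset classes build these from real files.
subject = tio.Subject(
    image=tio.ScalarImage("t1.nii.gz"),
    label=tio.LabelMap("lesion_mask.nii.gz"),
)
dataset = tio.SubjectsDataset([subject])

# Training: random patches streamed through a Queue.
sampler = tio.data.UniformSampler(patch_size=96)
queue = tio.Queue(dataset, max_length=300, samples_per_volume=10, sampler=sampler)
train_loader = tio.SubjectsLoader(queue, batch_size=4, num_workers=0)  # the Queue manages its own workers

# Inference: a dense grid of patches, reassembled by a GridAggregator.
grid_sampler = tio.inference.GridSampler(subject, patch_size=96, patch_overlap=8)
patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4)
aggregator = tio.inference.GridAggregator(grid_sampler)

model = torch.nn.Identity()  # stand-in for a trained network
with torch.no_grad():
    for batch in patch_loader:
        logits = model(batch["image"][tio.DATA])
        aggregator.add_batch(logits, batch[tio.LOCATION])
prediction = aggregator.get_output_tensor()
```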