Code for automated severity prediction and analysis of pulmonary hypertension (PH) in newborns using echocardiography (ECHO).

Trained models can be found at the following URL:

Main code and classes are located in the echo_ph module.

Scripts for pre-processing the dataset, generating index files (for splitting into train and validation sets), training models, and analysing results are found in the scripts directory.

- Create a conda environment with
conda env create -f environment.yml
- Activate the conda environment.
- Run
pip install -e .
(to install the echo_ph module in editable mode)
- Pre-process the dataset:
python scripts/data/preprocess_videos.py
- Generate clean labels from Excel annotations:
python scripts/data/generate_labels.py
- Generate index files with samples according to the train-val split:
python scripts/data/generate_index_files.py
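
To illustrate what a per-fold train-val index split can look like, here is a minimal sketch. It is not the logic of `generate_index_files.py`: the function name `k_fold_split`, the sample identifiers, the seed, and the 0-based fold numbering are all assumptions made for the example.

```python
import random


def k_fold_split(sample_ids, k, fold, seed=0):
    """Return (train_ids, val_ids) for one fold of a k-fold split."""
    if not 0 <= fold < k:
        raise ValueError("fold must be in [0, k)")
    ids = sorted(sample_ids)          # sort first so the result is reproducible
    random.Random(seed).shuffle(ids)  # deterministic shuffle for a given seed
    val_ids = ids[fold::k]            # every k-th sample, offset by the fold number
    val_set = set(val_ids)
    train_ids = [s for s in ids if s not in val_set]
    return train_ids, val_ids


# Hypothetical identifiers, for illustration only
samples = [f"patient_{i:03d}" for i in range(25)]
train_ids, val_ids = k_fold_split(samples, k=10, fold=0)
print(len(train_ids), len(val_ids))
```

Splitting on patient identifiers rather than on individual videos or frames keeps all data from one patient on the same side of the split, which is usually what is wanted for medical data.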
- Script:
scripts/train.py
- Example of training a temporal PH severity prediction model on the PSAX view:
python scripts/train_simple.py --max_epochs 300 --wd 1e-3 --class_balance_per_epoch --cache_dir ~/.heart_echo --k 10 --fold ${fold} --augment --pretrained --num_rand_frames 10 --model r3d_18 --temporal --label_type 3class --view KAPAP --batch_size 8
- Example of training a spatial binary PH detection model on the PLAX view:
python scripts/train_simple.py --max_epochs 300 --wd 1e-3 --class_balance_per_epoch --cache_dir ~/.heart_echo --k 10 --fold ${fold} --augment --pretrained --num_rand_frames 10 --model resnet --label_type 2class_drop_ambiguous --view LA --batch_size 64
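
The `${fold}` placeholder in the commands above suggests one training run per cross-validation fold. A small sketch of generating those command lines follows; the helper name `temporal_psax_command` is hypothetical, and the assumption that folds are numbered from 0 to k-1 should be checked against the training script.

```python
def temporal_psax_command(fold, k=10):
    """Argument list for training the temporal 3-class model on one fold."""
    return [
        "python", "scripts/train_simple.py",
        "--max_epochs", "300", "--wd", "1e-3", "--class_balance_per_epoch",
        "--cache_dir", "~/.heart_echo",
        "--k", str(k), "--fold", str(fold),
        "--augment", "--pretrained", "--num_rand_frames", "10",
        "--model", "r3d_18", "--temporal",
        "--label_type", "3class", "--view", "KAPAP", "--batch_size", "8",
    ]


# Assumption: folds are numbered 0 .. k-1
commands = [temporal_psax_command(fold) for fold in range(10)]
for cmd in commands:
    print(" ".join(cmd))  # paste into a shell, or submit as separate jobs
```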
- Use the same script as for training (scripts/train.py), with the same arguments as when you trained the model you are now evaluating.
  - Add the arguments: --load_model and --model_path <path_to_trained_model>
  - This will save the result files, holding the raw output, target and sample names.
  - If desired, you can get metric results by running scripts/evaluation/get_metrics.py (see next section).

python scripts/evaluation/get_metrics.py --res_dir res_dir
- Add --multi_class if any of the models from res_dir are not binary classification.
- Note that res_dir should be the parent directory that contains the results directories of the individual model(s).
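
As a rough illustration of how metrics can be derived from saved raw outputs and targets, the sketch below computes accuracy and balanced accuracy for a multi-class case. The data layout (one list of class scores per sample), the function names, and the choice of metrics are assumptions for the example; `get_metrics.py` may read a different format and report different metrics.

```python
def predict(raw_outputs):
    """Predicted class = index of the largest score for each sample."""
    return [max(range(len(scores)), key=scores.__getitem__) for scores in raw_outputs]


def accuracy(targets, preds):
    """Fraction of samples whose prediction matches the target."""
    return sum(t == p for t, p in zip(targets, preds)) / len(targets)


def balanced_accuracy(targets, preds):
    """Mean of per-class recall, which is less sensitive to class imbalance."""
    recalls = []
    for c in sorted(set(targets)):
        idx = [i for i, t in enumerate(targets) if t == c]
        recalls.append(sum(preds[i] == c for i in idx) / len(idx))
    return sum(recalls) / len(recalls)


# Made-up scores and labels for a 3-class problem
raw_outputs = [[2.0, 0.1, -1.0], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0], [1.2, 0.9, 0.1]]
targets = [0, 1, 2, 1]
preds = predict(raw_outputs)
print(accuracy(targets, preds), balanced_accuracy(targets, preds))
```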
- Get Grad-CAM saliency map visualisations for a temporal model:
  - Save 1 clip per video:
  python scripts/visualisations/vis_grad_cam_temp.py --model_path <path_to_trained_model.pt> --model <model_type> --num_rand_samples 1 --save_video_clip
  - Save the full video (feeds all frames, although the model was not trained on inputs this long):
  python scripts/visualise/vis_grad_cam_temp.py --model_path <path_to_trained_model.pt> --model <model_type> --all_frames --save_video --view <view>
- Get Grad-CAM saliency map visualisations for a spatial model:
  - Use
  python scripts/visualisations/vis_grad_cam.py
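
For readers unfamiliar with Grad-CAM, the core computation can be sketched in a few lines: each feature map is weighted by the spatial average of the class-score gradient with respect to it, the weighted maps are summed, and negative values are clipped. The sketch below uses plain nested lists and made-up numbers; in practice the activations and gradients come from a convolutional layer of the trained model, and the function name `grad_cam` is only illustrative, not the repository's implementation.

```python
def grad_cam(activations, gradients):
    """Grad-CAM map from K feature maps and their gradients (each H x W)."""
    height, width = len(activations[0]), len(activations[0][0])
    # Channel weights: spatial mean of the gradient of the class score
    weights = [sum(map(sum, g)) / (height * width) for g in gradients]
    # Weighted sum over channels, followed by ReLU
    cam = [
        [max(0.0, sum(w * a[i][j] for w, a in zip(weights, activations)))
         for j in range(width)]
        for i in range(height)
    ]
    # Scale to [0, 1] for display, unless the map is all zeros
    peak = max(map(max, cam))
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Two 2x2 feature maps with made-up values
activations = [[[1, 0], [0, 2]], [[0, 3], [1, 0]]]
gradients = [[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]
cam = grad_cam(activations, gradients)
print(cam)
```

For a temporal model the feature maps carry an additional time axis, and the same weighting is applied with the average taken over time as well as space.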
- MV of 3 views (similar for 5 views):
python scripts/evaluation/multi_view_ensemble.py base_res_dir --res_files file_name_KAPAP file_name_CV file_name_LA --views kapap cv la
- Frame-level joining of 3 views:
python scripts/evaluation/join_view_models_frame_level.py base_res_dir --res_files file_name_KAPAP file_name_CV file_name_LA --views kapap cv la

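
If "MV" here denotes a majority vote over per-view predictions (an assumption; the script itself is the reference), the idea can be sketched as follows. The dictionary layout, the function name `majority_vote`, and the tie-breaking rule are all hypothetical choices for the example.

```python
from collections import Counter


def majority_vote(view_preds):
    """Combine per-view labels; view_preds maps view -> {sample: label}."""
    # Only samples that have a prediction from every view are combined
    common = set.intersection(*(set(p) for p in view_preds.values()))
    result = {}
    for sample in sorted(common):
        votes = Counter(preds[sample] for preds in view_preds.values())
        top = max(votes.values())
        # Tie-break: smallest label among the most-voted ones (arbitrary choice)
        result[sample] = min(label for label, n in votes.items() if n == top)
    return result


# Made-up predictions for three videos from three views
view_preds = {
    "kapap": {"v1": 1, "v2": 0, "v3": 2},
    "cv": {"v1": 1, "v2": 1, "v3": 0},
    "la": {"v1": 0, "v2": 1, "v3": 1},
}
ensemble = majority_vote(view_preds)
print(ensemble)
```

With an odd number of views and binary labels no ties can occur; with three or more classes a tie-breaking rule such as the one above, or averaging the raw scores instead of voting, is needed.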
If this code was helpful to you, please consider citing:
@inproceedings{ragnarsdottir2022interpretable,
  title={Interpretable Prediction of Pulmonary Hypertension in Newborns Using Echocardiograms},
  author={Ragnarsdottir, Hanna and Manduchi, Laura and Michel, Holger and Laumer, Fabian and Wellmann, Sven and Ozkan, Ece and Vogt, Julia E},
  booktitle={DAGM German Conference on Pattern Recognition},
  pages={529--542},
  year={2022},
  organization={Springer}
}
