From 4385252c3ca1c1a2b9e7172a779f0fbece1b7d33 Mon Sep 17 00:00:00 2001
From: unknown
Date: Tue, 28 Sep 2021 12:35:51 +0500
Subject: [PATCH 1/2] Added file to prepare ACDC dataset for training with TransUNet

---
 prepare_ACDCdataset.py | 59 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 prepare_ACDCdataset.py

diff --git a/prepare_ACDCdataset.py b/prepare_ACDCdataset.py
new file mode 100644
index 00000000..0257f9f4
--- /dev/null
+++ b/prepare_ACDCdataset.py
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 25 20:26:22 2021

@author: Yusra Shahid

This script prepares the ACDC dataset for training with TransUNet.

"""
import glob
import os

import cv2
import nibabel as nb
import numpy as np

## the directory with the dataset
root_dir = "../ACDC dataset/training/*/*"

## sort so that every image stays aligned with its ground-truth file
files = sorted(glob.glob(root_dir))
labels = []
images = []

## separate the ground-truth files from the images
for each in files:
    if "frame" in each and "gt" in each:
        labels.append(each)
    elif "frame" in each:
        images.append(each)

## read images and labels slice by slice and save every slice as an npz file
out_dir = "../ACDC dataset/train_npz"
os.makedirs(out_dir, exist_ok=True)  # os.mkdir has no exist_ok argument
prev_patient = None
slice_num = 0
for i in range(len(images)):
    patient = os.path.basename(os.path.dirname(images[i]))
    print(patient)
    image = nb.load(images[i]).get_fdata()
    label = nb.load(labels[i]).get_fdata()
    slices = image.shape[2]
    # restart the slice counter for a new patient; for the second frame of the
    # same patient the numbering continues where the first frame stopped
    if patient != prev_patient:
        slice_num = 0
    for num in range(slices):
        # resize with cv2 so the image is interpolated rather than tiled as with numpy
        case_image = cv2.resize(image[:, :, num], (512, 512))
        # labels must keep integer class ids, so resize them with nearest-neighbour interpolation
        case_label = cv2.resize(label[:, :, num], (512, 512), interpolation=cv2.INTER_NEAREST)
        np.savez(os.path.join(out_dir, patient + "_slice" + str(slice_num).zfill(3)),
                 image=case_image, label=case_label)
        slice_num += 1
    prev_patient = patient
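A note on the list files: the train and test scripts below look up sample names in --list_dir
(./lists/lists_ACDC), which this patch does not generate. A minimal sketch, assuming the ACDC
dataset loader follows the same convention as TransUNet's Synapse loader (a train.txt with one
npz basename per line, without the extension); adjust if datasets/dataset_acdc.py expects
something else:

import glob
import os

npz_dir = "../ACDC dataset/train_npz"   # output folder of the preparation script above
list_dir = "./lists/lists_ACDC"
os.makedirs(list_dir, exist_ok=True)
# one slice name per line, without the .npz extension
slice_names = sorted(os.path.splitext(os.path.basename(p))[0]
                     for p in glob.glob(os.path.join(npz_dir, "*.npz")))
with open(os.path.join(list_dir, "train.txt"), "w") as f:
    f.write("\n".join(slice_names) + "\n")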
From f09ddce7d425ec958fc8e592a5bb9c0af1f92ba4 Mon Sep 17 00:00:00 2001
From: yskix
Date: Wed, 29 Sep 2021 11:15:18 +0500
Subject: [PATCH 2/2] added modified scripts for train and test of ACDC

---
 test_ACDC.py  | 143 ++++++++++++++++++++++++++++++++++++++++++++++++++
 train_ACDC.py |  95 +++++++++++++++++++++++++++++++++
 2 files changed, 238 insertions(+)
 create mode 100644 test_ACDC.py
 create mode 100644 train_ACDC.py

diff --git a/test_ACDC.py b/test_ACDC.py
new file mode 100644
index 00000000..768fc200
--- /dev/null
+++ b/test_ACDC.py
@@ -0,0 +1,143 @@
import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.dataset_acdc import ACDC_dataset
from utils import test_single_volume
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

parser = argparse.ArgumentParser()
parser.add_argument('--volume_path', type=str,
                    default='../data/ACDC/test_vol_h5', help='root dir for validation volume data')  # for ACDC, volume_path = root_dir
parser.add_argument('--dataset', type=str,
                    default='ACDC', help='experiment_name')
parser.add_argument('--num_classes', type=int,
                    default=4, help='output channel of network')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_ACDC', help='list dir')

parser.add_argument('--max_iterations', type=int, default=20000, help='maximum number of iterations to train')
parser.add_argument('--max_epochs', type=int, default=30, help='maximum number of epochs to train')
parser.add_argument('--batch_size', type=int, default=12,
                    help='batch_size per gpu')
parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')
parser.add_argument('--is_savenii', action="store_true", help='whether to save results during inference')

parser.add_argument('--n_skip', type=int, default=3, help='number of skip connections to use')
parser.add_argument('--vit_name', type=str, default='ViT-B_16', help='select one vit model')

parser.add_argument('--test_save_dir', type=str, default='../predictions', help='directory for saving predictions as nii')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
parser.add_argument('--seed', type=int, default=1234, help='random seed')
parser.add_argument('--vit_patches_size', type=int, default=16, help='vit_patches_size, default is 16')
args = parser.parse_args()


def inference(args, model, test_save_path=None):
    db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("{} test iterations per epoch".format(len(testloader)))
    model.eval()
    metric_list = 0.0
    for i_batch, sampled_batch in tqdm(enumerate(testloader)):
        h, w = sampled_batch["image"].size()[2:]
        image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
        metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size],
                                      test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing)
        metric_list += np.array(metric_i)
        logging.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1]))
    # average the per-case metrics over all test volumes
    metric_list = metric_list / len(db_test)
    for i in range(1, args.num_classes):
        logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1]))
    performance = np.mean(metric_list, axis=0)[0]
    mean_hd95 = np.mean(metric_list, axis=0)[1]
    logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95))
    return "Testing Finished!"


if __name__ == "__main__":

    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset_config = {
        'ACDC': {
            'Dataset': ACDC_dataset,
            'volume_path': '../data/ACDC/test_vol_h5',
            'list_dir': './lists/lists_ACDC',
            'num_classes': 4,
            'z_spacing': 1,
        },
    }
    print(args.dataset)
    dataset_name = args.dataset
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.volume_path = dataset_config[dataset_name]['volume_path']
    args.Dataset = dataset_config[dataset_name]['Dataset']
    args.list_dir = dataset_config[dataset_name]['list_dir']
    args.z_spacing = dataset_config[dataset_name]['z_spacing']
    args.is_pretrain = True

    # the snapshot name built below must match the one defined in the train script
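    # Note: this path is rebuilt from the test-time flags, so pass the same --vit_name,
    # --max_epochs, --batch_size, --base_lr and --max_iterations values that were used for
    # training; the defaults of this script differ from train_ACDC.py's defaults and would
    # otherwise point at a snapshot directory that does not exist.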
    args.exp = 'TU_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'TU')
    snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path
    snapshot_path += '_' + args.vit_name
    snapshot_path = snapshot_path + '_skip' + str(args.n_skip)
    snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size != 16 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 30 else snapshot_path
    if dataset_name == 'ACDC':  # max_epoch is used instead of iterations to control the training duration
        snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 30000 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.img_size)
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path

    config_vit = CONFIGS_ViT_seg[args.vit_name]
    config_vit.n_classes = args.num_classes
    config_vit.n_skip = args.n_skip
    config_vit.patches.size = (args.vit_patches_size, args.vit_patches_size)
    if args.vit_name.find('R50') != -1:
        config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
    net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()

    # load the best checkpoint if it exists, otherwise fall back to the last epoch
    snapshot = os.path.join(snapshot_path, 'best_model.pth')
    if not os.path.exists(snapshot):
        snapshot = snapshot.replace('best_model', 'epoch_' + str(args.max_epochs - 1))
    net.load_state_dict(torch.load(snapshot))
    snapshot_name = snapshot_path.split('/')[-1]

    log_folder = './test_log/test_log_' + args.exp
    os.makedirs(log_folder, exist_ok=True)
    logging.basicConfig(filename=log_folder + '/' + snapshot_name + ".txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    logging.info(snapshot_name)

    if args.is_savenii:
        args.test_save_dir = '../predictions'
        test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name)
        os.makedirs(test_save_path, exist_ok=True)
    else:
        test_save_path = None
    inference(args, net, test_save_path)
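The test script above reads pre-assembled volumes from --volume_path plus a test_vol.txt list,
which this patch does not generate. A possible sketch, assuming datasets/dataset_acdc.py mirrors
the Synapse test loader (one <case>.npy.h5 file per volume with 'image' and 'label' arrays shaped
[slices, H, W], and case names listed in ./lists/lists_ACDC/test_vol.txt); the folder names below
are placeholders for wherever the held-out ACDC patients are stored:

import glob
import os

import h5py
import nibabel as nb
import numpy as np

test_dir = "../ACDC dataset/testing/*/*"   # held-out patients, same layout as training/
out_dir = "../data/ACDC/test_vol_h5"
list_dir = "./lists/lists_ACDC"
os.makedirs(out_dir, exist_ok=True)
os.makedirs(list_dir, exist_ok=True)

files = sorted(glob.glob(test_dir))
labels = [f for f in files if "frame" in f and "gt" in f]
images = [f for f in files if "frame" in f and "gt" not in f]

case_names = []
for img_path, lab_path in zip(images, labels):
    case = os.path.basename(img_path).split(".")[0]           # e.g. patient101_frame01
    image = nb.load(img_path).get_fdata().transpose(2, 0, 1)  # -> [slices, H, W]
    label = nb.load(lab_path).get_fdata().transpose(2, 0, 1)
    with h5py.File(os.path.join(out_dir, case + ".npy.h5"), "w") as f:
        f.create_dataset("image", data=image.astype(np.float32))
        f.create_dataset("label", data=label.astype(np.uint8))
    case_names.append(case)

with open(os.path.join(list_dir, "test_vol.txt"), "w") as f:
    f.write("\n".join(case_names) + "\n")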
diff --git a/train_ACDC.py b/train_ACDC.py
new file mode 100644
index 00000000..af141203
--- /dev/null
+++ b/train_ACDC.py
@@ -0,0 +1,95 @@
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from trainer import trainer_acdc

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='../data/ACDC/train_npz', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='ACDC', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_ACDC', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=4, help='output channel of network')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum number of iterations to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum number of epochs to train')
parser.add_argument('--batch_size', type=int,
                    default=12, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether to use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.005,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--n_skip', type=int,
                    default=3, help='number of skip connections to use')
parser.add_argument('--vit_name', type=str,
                    default='R50-ViT-B_16', help='select one vit model')
parser.add_argument('--vit_patches_size', type=int,
                    default=16, help='vit_patches_size, default is 16')
args = parser.parse_args()


if __name__ == "__main__":
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    dataset_name = args.dataset
    dataset_config = {
        'ACDC': {
            'root_path': '../data/ACDC/train_npz',
            'list_dir': './lists/lists_ACDC',
            'num_classes': 4,
        },
    }
    # scale the learning rate with the batch size (24 is the reference batch size)
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.list_dir = dataset_config[dataset_name]['list_dir']
    args.is_pretrain = True
    args.exp = 'TU_' + dataset_name + str(args.img_size)
    snapshot_path = "../model/{}/{}".format(args.exp, 'TU')
    snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path
    snapshot_path += '_' + args.vit_name
    snapshot_path = snapshot_path + '_skip' + str(args.n_skip)
    snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size != 16 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 30000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 30 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.img_size)
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path

    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    config_vit = CONFIGS_ViT_seg[args.vit_name]
    config_vit.n_classes = args.num_classes
    config_vit.n_skip = args.n_skip
    if args.vit_name.find('R50') != -1:
        config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
    net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
    net.load_from(weights=np.load(config_vit.pretrained_path))

    trainer = {'ACDC': trainer_acdc}
    trainer[dataset_name](args, net, snapshot_path)
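Example invocation, assuming the data preparation above has been run and the pretrained ViT weights
are available; the test flags repeat the training settings (including the batch-scaled learning rate
0.005 * 12/24 = 0.0025) so that both scripts resolve the same snapshot directory:

    python train_ACDC.py --dataset ACDC --vit_name R50-ViT-B_16 --max_epochs 150 --batch_size 12
    python test_ACDC.py --dataset ACDC --vit_name R50-ViT-B_16 --max_epochs 150 --batch_size 12 --max_iterations 30000 --base_lr 0.0025 --is_savenii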