# AutoSeg4Sinonasal/deepatlas/scripts/deep_atlas_train.py
# import seg_train  # only needed by the optional seg-only training block at the end of main()
import argparse
import datetime
import json
import logging
import os
import random
import sys
from collections import OrderedDict, namedtuple
from pathlib import Path

import monai
import numpy as np
import torch
# Resolve the repository root from this script's location (deepatlas/scripts/)
# so the script works regardless of the current working directory.
ROOT_DIR = str(Path(__file__).resolve().parent.parent.parent)
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/preprocess'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/network'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/train'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/utils'))
from train import (
train_network
)
from network import (
regNet, segNet
)
from process_data import (
    load_seg_dataset, load_reg_dataset, take_data_pairs, subdivide_list_of_data_pairs
)
from utils import (
load_json, make_if_dont_exist
)
def parse_command_line():
    """Parse the command-line arguments for the DeepAtlas training pipeline."""
    parser = argparse.ArgumentParser(
        description='pipeline for DeepAtlas training')
    parser.add_argument('--config', metavar='PATH', type=str, required=True,
                        help='absolute path to the configuration file')
    parser.add_argument('--continue_training', action='store_true',
                        help='resume a previously interrupted training run')
    parser.add_argument('--train_only', action='store_true',
                        help='train on all data instead of cross-validation with a held-out test fold')
    parser.add_argument('--plot_network', action='store_true',
                        help='plot the network architecture')
    return parser.parse_args()
def get_seg_net(spatial_dims, num_label, dropout, activation_type, normalization_type, num_res):
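    """
    Instantiate the segmentation network (segNet, defined in deepatlas/network):
    it takes a single-channel image and outputs one score map per label.
    """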
seg_net = segNet(
spatial_dim=spatial_dims, # spatial dims
in_channel=1, # input channels
out_channel=num_label, # output channels
channel=(8, 16, 16, 32, 32, 64, 64), # channel sequence
stride=(1, 2, 1, 2, 1, 2), # convolutional strides
dropouts=dropout,
acts=activation_type,
norms=normalization_type,
num_res_unit=num_res
)
return seg_net
def get_reg_net(spatial_dims, out_channels, dropout, activation_type, normalization_type, num_res):
    """
    Instantiate the registration network (regNet, defined in deepatlas/network).
    The two input channels hold the image pair to register; the output channels
    hold the components of the predicted displacement field, one per spatial
    dimension, which is why callers pass `spatial_dim` for `out_channels`.
    """
    reg_net = regNet(
        spatial_dim=spatial_dims,  # spatial dims
        in_channel=2,  # input channels: the image pair
        out_channel=out_channels,  # output channels: displacement components
        channel=(16, 32, 32, 32, 32),  # channel sequence
        stride=(1, 2, 2, 2),  # convolutional strides
        dropouts=dropout,
        acts=activation_type,
        norms=normalization_type,
        num_res_unit=num_res
    )
    return reg_net
def setup_logger(logger_name, log_file, level=logging.INFO):
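    """Send log records for `logger_name` to both a file and the console."""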
log_setup = logging.getLogger(logger_name)
formatter = logging.Formatter('%(asctime)s %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
fileHandler = logging.FileHandler(log_file, mode='w')
fileHandler.setFormatter(formatter)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(formatter)
log_setup.setLevel(level)
log_setup.addHandler(fileHandler)
log_setup.addHandler(streamHandler)
def classify_data(data_info, fold):
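    """
    Separate the training folds (every fold except the held-out `fold`) into
    labeled and unlabeled items. Returns the labeled items grouped by fold,
    a flat list of labeled items, the unlabeled items, the total number of
    segmentations, and the per-fold segmentation counts.
    """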
lab_each_fold = {}
lab = []
unlab = []
total_seg = 0
total_seg_each_fold = {}
for key, value in data_info.items():
if key != f'fold_{fold}':
lab_each_fold[key] = []
total_seg_each_fold[key] = 0
for val in value:
if 'seg' not in val.keys():
unlab.append(val)
else:
lab_each_fold[key].append(val)
lab.append(val)
total_seg += 1
total_seg_each_fold[key] += 1
return lab_each_fold, lab, unlab, total_seg, total_seg_each_fold
def select_n_seg(lab, fold, num, total_seg_each_fold):
    """Randomly draw `num` labeled items from training fold `fold` and remove
    them from the pool so they cannot be drawn again."""
    seg_items = lab[f'fold_{fold}']
    rand_idx = random.sample(range(len(seg_items)), num)
    selected = [seg_items[i] for i in rand_idx]
    # Pop the drawn items highest-index first so earlier indices stay valid.
    for i in sorted(rand_idx, reverse=True):
        seg_items.pop(i)
    total_seg_each_fold[f'fold_{fold}'] -= num
    lab[f'fold_{fold}'] = seg_items
    return selected, lab, total_seg_each_fold
def combine_data(data_info, fold, exp, num_seg):
    """
    Build the train/test split for cross-validation fold `fold`. Labeled items
    of the held-out fold become the test set. From the remaining folds,
    `num_seg` segmentations are drawn round-robin; for experiment sets other
    than 1, the leftover labeled scans are also added to training, but
    image-only (their segmentations dropped).
    """
    all_fold = np.arange(len(data_info.keys())) + 1
    num_train_fold = len(data_info.keys()) - 1
    # Rotate the training folds so the round-robin starts after the held-out fold.
    fake_train_fold = np.delete(all_fold, fold-1)
    fake_train_fold = np.tile(fake_train_fold, 2)
    real_train_fold = fake_train_fold[fold-1:fold+num_train_fold-1]
    test = []
    for j in data_info[f'fold_{fold}']:
        if 'seg' in j.keys():
            test.append(j)
    lab_each_fold, lab, unlab, total_seg, total_seg_each_fold = classify_data(data_info, fold)
    # Never request more segmentations than exist.
    if total_seg < num_seg:
        num_seg = total_seg
    # Even per-fold target counts (informational; the loop below draws round-robin).
    num_each_fold_seg, num_remain_seg = divmod(num_seg, num_train_fold)
    fold_num_seg = np.repeat(num_each_fold_seg, num_train_fold)
    count = 0
    while num_remain_seg > 0:
        fold_num_seg[count] += 1
        count = (count+1) % num_train_fold
        num_remain_seg -= 1
    train = unlab
    # Draw one segmentation at a time, cycling through the training folds.
    k = 0
    while num_seg > 0:
        next_fold = real_train_fold[k]
        if total_seg_each_fold[f'fold_{next_fold}'] > 0:
            seg_items, lab_each_fold, total_seg_each_fold = select_n_seg(
                lab_each_fold, next_fold, 1, total_seg_each_fold)
            train.extend(seg_items)
            num_seg -= 1
        k = (k+1) % num_train_fold
    if exp != 1:
        # Add the leftover labeled scans as unlabeled images (segmentations dropped).
        for key, value in total_seg_each_fold.items():
            if value != 0:
                for j in lab_each_fold[key]:
                    train.append({'img': j['img']})
                    total_seg_each_fold[key] -= 1
        assert sum(total_seg_each_fold.values()) == 0
    return train, test
def main():
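    """Run the DeepAtlas training pipeline: set up the result directories,
    build the per-fold train/test splits, construct seg_net and reg_net,
    and launch joint training for each fold."""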
args = parse_command_line()
config = args.config
continue_training = args.continue_training
train_only = args.train_only
config = load_json(config)
config = namedtuple("config", config.keys())(*config.values())
folder_name = config.folder_name
num_seg_used = config.num_seg_used
experiment_set = config.exp_set
monai.utils.set_determinism(seed=2938649572)
data_path = os.path.join(ROOT_DIR, 'deepatlas_results')
base_path = os.path.join(ROOT_DIR, 'deepatlas_preprocessed')
task = os.path.join(data_path, config.task_name)
exp_path = os.path.join(task, f'set_{experiment_set}')
gt_path = os.path.join(exp_path, f'{num_seg_used}gt')
folder_path = os.path.join(gt_path, folder_name)
result_path = os.path.join(folder_path, 'training_results')
if train_only:
info_name = 'info_train_only'
else:
info_name = 'info'
info_path = os.path.join(base_path, config.task_name, 'Training_dataset', 'data_info', folder_name, info_name+'.json')
info = load_json(info_path)
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        device = torch.device("cpu")
spatial_dim = config.network['spatial_dim']
dropout = config.network['dropout']
activation_type = config.network['activation_type']
normalization_type = config.network['normalization_type']
num_res = config.network['num_res']
lr_reg = config.network["registration_network_learning_rate"]
lr_seg = config.network["segmentation_network_learning_rate"]
lam_a = config.network["anatomy_loss_weight"]
lam_sp = config.network["supervised_segmentation_loss_weight"]
lam_re = config.network["regularization_loss_weight"]
max_epoch = config.network["number_epoch"]
val_step = config.network["validation_step"]
make_if_dont_exist(data_path)
make_if_dont_exist(task)
make_if_dont_exist(exp_path)
make_if_dont_exist(gt_path)
make_if_dont_exist(folder_path)
make_if_dont_exist(result_path)
    if not continue_training:
        start_fold = 1
    else:
        # Resume from the highest-numbered fold directory; sort numerically so
        # e.g. fold_10 sorts after fold_2 (assumes directories named fold_<n>).
        folds = sorted(os.listdir(result_path), key=lambda f: int(f.split('_')[-1]))
        if len(folds) == 0:
            continue_training = False
            start_fold = 1
        else:
            start_fold = int(folds[-1].split('_')[-1])
if train_only:
num_fold = 1
else:
num_fold = config.num_fold
    for i in range(start_fold, num_fold + 1):
if not train_only:
fold_path = os.path.join(result_path, f'fold_{i}')
result_seg_path = os.path.join(fold_path, 'SegNet')
result_reg_path = os.path.join(fold_path, 'RegNet')
else:
            fold_path = os.path.join(result_path, 'all')
result_seg_path = os.path.join(fold_path, 'SegNet')
result_reg_path = os.path.join(fold_path, 'RegNet')
make_if_dont_exist(fold_path)
make_if_dont_exist(result_reg_path)
make_if_dont_exist(result_seg_path)
        log_filename = 'training_log_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.log'
        log_path = os.path.join(fold_path, log_filename)
        if not train_only:
            if not continue_training:
                setup_logger(f'log_{i}', log_path)
                logger = logging.getLogger(f'log_{i}')
                logger.info(f"Start Pipeline with fold_{i}")
            else:
                # Use a distinct logger name on resume so handlers attached to
                # this fold's original logger are not duplicated.
                setup_logger(f'log_{i}_resume', log_path)
                logger = logging.getLogger(f'log_{i}_resume')
                logger.info(f"Resume Pipeline with fold_{i}")
        else:
            setup_logger('all', log_path)
            logger = logging.getLogger('all')
            logger.info("Start Pipeline with all data")
if not os.path.exists(os.path.join(fold_path, 'dataset.json')):
logger.info('prepare dataset into train and test')
json_dict = OrderedDict()
json_dict['name'] = os.path.basename(task).split('_')[0]
json_dict['description'] = '_'.join(os.path.basename(task).split('_')[1:])
json_dict['tensorImageSize'] = "4D"
json_dict['reference'] = "MODIFY"
json_dict['licence'] = "MODIFY"
json_dict['release'] = "0.0"
json_dict['modality'] = {
"0": "CT"
}
json_dict['labels'] = config.labels
json_dict['network'] = config.network
json_dict['experiment_set'] = experiment_set
if not train_only:
json_dict['num_fold'] = f'fold_{i}'
train, test = combine_data(info, i, experiment_set, num_seg_used)
else:
json_dict['num_fold'] = 'all'
train = info
test = []
num_seg_used = len(list(filter(lambda d: 'seg' in d.keys(), train)))
num_seg = num_seg_used
num_train = len(train)
num_test = len(test)
json_dict['total_numScanTraining'] = num_train
json_dict['total_numLabelTraining'] = num_seg
json_dict['total_numTest'] = num_test
json_dict['total_train'] = train
json_dict['total_test'] = test
# prepare segmentation dataset
logger.info('prepare segmentation dataset')
data_seg_available = list(filter(lambda d: 'seg' in d.keys(), train))
data_seg_unavailable = list(filter(lambda d: 'seg' not in d.keys(), train))
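            # Split the labeled scans 80/20 into seg_net training and validation sets.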
data_seg_available_train, data_seg_available_valid = \
monai.data.utils.partition_dataset(data_seg_available, ratios=(8, 2))
json_dict['seg_numTrain'] = len(data_seg_available_train)
json_dict['seg_train'] = data_seg_available_train
json_dict['seg_numValid'] = len(data_seg_available_valid)
json_dict['seg_valid'] = data_seg_available_valid
dataset_seg_available_train, dataset_seg_available_valid = load_seg_dataset(
data_seg_available_train, data_seg_available_valid)
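            # Infer the image shape and label count from one random training
            # sample; this assumes every label appears in that segmentation.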
data_item = random.choice(dataset_seg_available_train)
img_shape = data_item['seg'].unsqueeze(0).shape[2:]
num_label = len(torch.unique(data_item['seg']))
logger.info('prepare segmentation network')
seg_net = get_seg_net(spatial_dim, num_label, dropout,
activation_type, normalization_type, num_res)
# prepare registration dataset
logger.info('prepare registration dataset')
data_without_seg_valid = data_seg_unavailable + data_seg_available_train
data_valid, data_train = monai.data.utils.partition_dataset(
data_without_seg_valid, # Note the order
ratios=(2, 8), # Note the order
shuffle=False
)
data_pairs_valid = take_data_pairs(data_valid)
data_pairs_train = take_data_pairs(data_train)
data_pairs_valid_subdivided = subdivide_list_of_data_pairs(
data_pairs_valid)
data_pairs_train_subdivided = subdivide_list_of_data_pairs(
data_pairs_train)
num_train_reg_net = len(data_pairs_train)
num_valid_reg_net = len(data_pairs_valid)
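            # Pairs where at least one image is labeled ('01', '10', '11') can
            # train reg_net and seg_net jointly; '00' pairs train reg_net alone.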
num_train_both = len(data_pairs_train_subdivided['01']) +\
len(data_pairs_train_subdivided['10']) +\
len(data_pairs_train_subdivided['11'])
json_dict['reg_seg_numTrain'] = num_train_reg_net
json_dict['reg_seg_numTrain_00'] = len(data_pairs_train_subdivided['00'])
json_dict['reg_seg_train_00'] = data_pairs_train_subdivided['00']
json_dict['reg_seg_numTrain_01'] = len(data_pairs_train_subdivided['01'])
json_dict['reg_seg_train_01'] = data_pairs_train_subdivided['01']
json_dict['reg_seg_numTrain_10'] = len(data_pairs_train_subdivided['10'])
json_dict['reg_seg_train_10'] = data_pairs_train_subdivided['10']
json_dict['reg_seg_numTrain_11'] = len(data_pairs_train_subdivided['11'])
json_dict['reg_seg_train_11'] = data_pairs_train_subdivided['11']
json_dict['reg_numValid'] = num_valid_reg_net
json_dict['reg_numValid_00'] = len(data_pairs_valid_subdivided['00'])
json_dict['reg_valid_00'] = data_pairs_valid_subdivided['00']
json_dict['reg_numValid_01'] = len(data_pairs_valid_subdivided['01'])
json_dict['reg_valid_01'] = data_pairs_valid_subdivided['01']
json_dict['reg_numValid_10'] = len(data_pairs_valid_subdivided['10'])
json_dict['reg_valid_10'] = data_pairs_valid_subdivided['10']
json_dict['reg_numValid_11'] = len(data_pairs_valid_subdivided['11'])
json_dict['reg_valid_11'] = data_pairs_valid_subdivided['11']
print(f"""We have {num_train_both} pairs to train reg_net and seg_net together, and an additional {num_train_reg_net - num_train_both} to train reg_net alone.""")
print(f"We have {num_valid_reg_net} pairs for reg_net validation.")
dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided = load_reg_dataset(
data_pairs_train_subdivided, data_pairs_valid_subdivided)
logger.info('prepare registration network')
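            # reg_net takes the image pair (2 channels) and predicts a
            # displacement field with one channel per spatial dimension.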
reg_net = get_reg_net(spatial_dim, spatial_dim, dropout,
activation_type, normalization_type, num_res)
logger.info('generate dataset json file')
with open(os.path.join(fold_path, 'dataset.json'), 'w') as f:
json.dump(json_dict, f, indent=4, sort_keys=False)
else:
dataset_json = load_json(os.path.join(fold_path, 'dataset.json'))
data_seg_available_train = dataset_json['seg_train']
data_seg_available_valid = dataset_json['seg_valid']
dataset_seg_available_train, dataset_seg_available_valid = load_seg_dataset(data_seg_available_train, data_seg_available_valid)
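            # As above, infer the image shape and label count from one random
            # sample; this assumes every label appears in that segmentation.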
data_item = random.choice(dataset_seg_available_train)
img_shape = data_item['seg'].unsqueeze(0).shape[2:]
num_label = len(torch.unique(data_item['seg']))
logger.info('prepare segmentation network')
seg_net = get_seg_net(spatial_dim, num_label, dropout, activation_type, normalization_type, num_res)
data_pairs_train_subdivided = {
'00': dataset_json['reg_seg_train_00'],
'01': dataset_json['reg_seg_train_01'],
'10': dataset_json['reg_seg_train_10'],
'11': dataset_json['reg_seg_train_11']
}
data_pairs_valid_subdivided = {
'00': dataset_json['reg_valid_00'],
'01': dataset_json['reg_valid_01'],
'10': dataset_json['reg_valid_10'],
'11': dataset_json['reg_valid_11']
}
num_train_reg_net = dataset_json['reg_seg_numTrain']
num_valid_reg_net = dataset_json['reg_numValid']
num_train_both = len(data_pairs_train_subdivided['01']) +\
len(data_pairs_train_subdivided['10']) +\
len(data_pairs_train_subdivided['11'])
print(f"""We have {num_train_both} pairs to train reg_net and seg_net together,
and an additional {num_train_reg_net - num_train_both} to train reg_net alone.""")
print(f"We have {num_valid_reg_net} pairs for reg_net validation.")
dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided = load_reg_dataset(
data_pairs_train_subdivided, data_pairs_valid_subdivided)
logger.info('prepare registration network')
reg_net = get_reg_net(spatial_dim, spatial_dim, dropout,
activation_type, normalization_type, num_res)
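        # Segmentation loaders. Batch sizes are hard-coded; adjust to GPU memory.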
dataloader_train_seg = monai.data.DataLoader(
dataset_seg_available_train,
batch_size=2,
num_workers=4,
shuffle=True
)
dataloader_valid_seg = monai.data.DataLoader(
dataset_seg_available_valid,
batch_size=4,
num_workers=4,
shuffle=False
)
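        # Registration loaders, keyed by segmentation availability of the pair:
        # '00' = neither image labeled, '11' = both, '01'/'10' = exactly one.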
dataloader_train_reg = {
seg_availability: monai.data.DataLoader(
dataset,
batch_size=1,
num_workers=4,
shuffle=True
)
# empty dataloaders are not a thing-- put an empty list if needed
if len(dataset) > 0 else []
for seg_availability, dataset in dataset_pairs_train_subdivided.items()
}
dataloader_valid_reg = {
seg_availability: monai.data.DataLoader(
dataset,
batch_size=2,
num_workers=4,
shuffle=True # Shuffle validation data because we will only take a sample for validation each time
)
# empty dataloaders are not a thing-- put an empty list if needed
if len(dataset) > 0 else []
for seg_availability, dataset in dataset_pairs_valid_subdivided.items()
}
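        # Launch alternating training of reg_net and seg_net (the DeepAtlas
        # scheme), saving checkpoints under the fold's SegNet/ and RegNet/ dirs.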
train_network(dataloader_train_reg,
dataloader_valid_reg,
dataloader_train_seg,
dataloader_valid_seg,
device,
seg_net,
reg_net,
num_label,
lr_reg,
lr_seg,
lam_a,
lam_sp,
lam_re,
max_epoch,
val_step,
result_seg_path,
result_reg_path,
logger,
img_shape,
plot_network=args.plot_network,
continue_training=continue_training
)
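    # Optional: train the segmentation network alone; restore the seg_train
    # import at the top of the file before uncommenting.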
'''
seg_train.train_seg(
dataloader_train_seg,
dataloader_valid_seg,
device,
seg_net,
lr_seg,
max_epoch,
val_step,
result_seg_path
)
'''
if __name__ == '__main__':
torch.cuda.empty_cache()
main()