from pathlib import Path
from collections import OrderedDict, namedtuple
import json
import sys
import argparse
import os
import random
import logging
import datetime

import numpy as np
import torch
import monai

ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/preprocess'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/network'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/train'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/utils'))

# These modules live inside the deepatlas tree, so they can only be imported
# after the sys.path inserts above.
import seg_train  # used by the commented-out segmentation-only path at the bottom
from train import train_network
from network import regNet, segNet
from process_data import (
    split_data, load_seg_dataset, load_reg_dataset,
    take_data_pairs, subdivide_list_of_data_pairs
)
from utils import load_json, make_if_dont_exist


def parse_command_line():
    parser = argparse.ArgumentParser(
        description='pipeline for DeepAtlas training')
    parser.add_argument('--config', metavar='path to the configuration file', type=str,
                        help='absolute path to the configuration file')
    parser.add_argument('--continue_training', action='store_true',
                        help='resume a previously interrupted training run')
    parser.add_argument('--train_only', action='store_true',
                        help='train on all data instead of training plus testing per fold')
    parser.add_argument('--plot_network', action='store_true',
                        help='plot the network architectures')
    argv = parser.parse_args()
    return argv
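
# Example invocation (the script name and paths below are illustrative):
#
#   python deep_atlas_train.py \
#       --config /path/to/config.json \
#       --continue_training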


def get_seg_net(spatial_dims, num_label, dropout, activation_type, normalization_type, num_res):
    seg_net = segNet(
        spatial_dim=spatial_dims,
        in_channel=1,
        out_channel=num_label,
        channel=(8, 16, 16, 32, 32, 64, 64),
        stride=(1, 2, 1, 2, 1, 2),
        dropouts=dropout,
        acts=activation_type,
        norms=normalization_type,
        num_res_unit=num_res
    )
    return seg_net
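
# Assumption: segNet wraps a MONAI UNet-style encoder-decoder, so `channel`
# and `stride` follow the usual convention of one more channel entry than
# stride entries (7 channels, 6 strides above).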


def get_reg_net(spatial_dims, num_channel, dropout, activation_type, normalization_type, num_res):
    reg_net = regNet(
        spatial_dim=spatial_dims,
        in_channel=2,             # fixed and moving images stacked as two channels
        out_channel=num_channel,  # one displacement-field channel per spatial dimension
        channel=(16, 32, 32, 32, 32),
        stride=(1, 2, 2, 2),
        dropouts=dropout,
        acts=activation_type,
        norms=normalization_type,
        num_res_unit=num_res
    )
    return reg_net
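
# main() calls get_reg_net(spatial_dim, spatial_dim, ...): the registration
# network takes an image pair as a two-channel input and predicts a dense
# displacement field with one channel per spatial dimension.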


def setup_logger(logger_name, log_file, level=logging.INFO):
    log_setup = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
    fileHandler = logging.FileHandler(log_file, mode='w')
    fileHandler.setFormatter(formatter)
    streamHandler = logging.StreamHandler()
    streamHandler.setFormatter(formatter)
    log_setup.setLevel(level)
    log_setup.addHandler(fileHandler)
    log_setup.addHandler(streamHandler)


def classify_data(data_info, fold):
    lab_each_fold = {}
    lab = []
    unlab = []
    total_seg = 0
    total_seg_each_fold = {}
    for key, value in data_info.items():
        if key != f'fold_{fold}':
            lab_each_fold[key] = []
            total_seg_each_fold[key] = 0
            for val in value:
                if 'seg' not in val.keys():
                    unlab.append(val)
                else:
                    lab_each_fold[key].append(val)
                    lab.append(val)
                    total_seg += 1
                    total_seg_each_fold[key] += 1
    return lab_each_fold, lab, unlab, total_seg, total_seg_each_fold
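
# Sketch of the assumed structure of `data_info` (the info*.json file read in
# main()); keys follow the f'fold_{k}' pattern used above, file names are
# illustrative:
#
#   {
#     "fold_1": [
#       {"img": "images/scan_001.nii.gz", "seg": "labels/scan_001.nii.gz"},
#       {"img": "images/scan_002.nii.gz"}
#     ],
#     "fold_2": [...]
#   }
#
# Items with a 'seg' key are labeled; items without one are unlabeled.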


def select_n_seg(lab, fold, num, total_seg_each_fold):
    # Randomly draw `num` labeled items from the given fold and remove them
    # from the pool so they cannot be drawn again.
    seg_items = lab[f'fold_{fold}']
    rand_num = random.sample(range(len(seg_items)), num)
    seg_item = [seg_items[idx] for idx in rand_num]
    for idx in sorted(rand_num, reverse=True):  # pop from the back so indices stay valid
        seg_items.pop(idx)
    total_seg_each_fold[f'fold_{fold}'] -= num
    lab[f'fold_{fold}'] = seg_items
    return seg_item, lab, total_seg_each_fold


def combine_data(data_info, fold, exp, num_seg):
    all_fold = np.arange(len(data_info.keys())) + 1
    num_train_fold = len(data_info.keys()) - 1
    # Rotate the training folds so that selection starts just after the test fold.
    fake_train_fold = np.delete(all_fold, fold - 1)
    fake_train_fold = np.tile(fake_train_fold, 2)
    real_train_fold = fake_train_fold[fold - 1:fold + num_train_fold - 1]
    test = []
    for j in data_info[f'fold_{fold}']:
        if 'seg' in j.keys():
            test.append(j)

    lab_each_fold, lab, unlab, total_seg, total_seg_each_fold = classify_data(data_info, fold)
    if total_seg < num_seg:
        num_seg = total_seg

    # Spread the requested number of labeled scans as evenly as possible
    # across the training folds.
    num_each_fold_seg, num_remain_seg = divmod(num_seg, num_train_fold)
    fold_num_seg = np.repeat(num_each_fold_seg, num_train_fold)
    count = 0
    while num_remain_seg > 0:
        fold_num_seg[count] += 1
        count = (count + 1) % num_train_fold
        num_remain_seg -= 1

    train = unlab
    k = 0
    while num_seg > 0:
        next_fold = real_train_fold[k]
        if total_seg_each_fold[f'fold_{next_fold}'] > 0:
            seg_items, lab_each_fold, total_seg_each_fold = select_n_seg(
                lab_each_fold, next_fold, 1, total_seg_each_fold)
            train.extend(seg_items)
            num_seg -= 1
        k = (k + 1) % num_train_fold

    # For experiment sets other than 1, the leftover labeled scans join the
    # training set as unlabeled images (their segmentations are dropped).
    if exp != 1:
        num_segs = 0
        for key, value in total_seg_each_fold.items():
            if value != 0:
                for j in lab_each_fold[key]:
                    item = {'img': j['img']}
                    train.append(item)
                    total_seg_each_fold[key] -= 1
        for key, value in total_seg_each_fold.items():
            num_segs += value
        assert num_segs == 0

    return train, test
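
# Worked example of the fold rotation in combine_data, assuming 5 folds and
# fold=2: all_fold = [1 2 3 4 5]; deleting the test fold gives [1 3 4 5];
# tiling it twice and slicing [fold-1 : fold-1+num_train_fold] yields
# real_train_fold = [3 4 5 1], i.e. the training folds visited in rotated
# order starting just after the test fold.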


def main():
    args = parse_command_line()
    config = args.config
    continue_training = args.continue_training
    train_only = args.train_only
    config = load_json(config)
    config = namedtuple("config", config.keys())(*config.values())
    folder_name = config.folder_name
    num_seg_used = config.num_seg_used
    experiment_set = config.exp_set
    monai.utils.set_determinism(seed=2938649572)
    data_path = os.path.join(ROOT_DIR, 'deepatlas_results')
    base_path = os.path.join(ROOT_DIR, 'deepatlas_preprocessed')
    task = os.path.join(data_path, config.task_name)
    exp_path = os.path.join(task, f'set_{experiment_set}')
    gt_path = os.path.join(exp_path, f'{num_seg_used}gt')
    folder_path = os.path.join(gt_path, folder_name)
    result_path = os.path.join(folder_path, 'training_results')
    if train_only:
        info_name = 'info_train_only'
    else:
        info_name = 'info'
    info_path = os.path.join(base_path, config.task_name, 'Training_dataset',
                             'data_info', folder_name, info_name + '.json')
    info = load_json(info_path)
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(torch.cuda.current_device()))
    else:
        device = torch.device("cpu")  # fall back to CPU so `device` is always defined

    spatial_dim = config.network['spatial_dim']
    dropout = config.network['dropout']
    activation_type = config.network['activation_type']
    normalization_type = config.network['normalization_type']
    num_res = config.network['num_res']
    lr_reg = config.network["registration_network_learning_rate"]
    lr_seg = config.network["segmentation_network_learning_rate"]
    lam_a = config.network["anatomy_loss_weight"]
    lam_sp = config.network["supervised_segmentation_loss_weight"]
    lam_re = config.network["regularization_loss_weight"]
    max_epoch = config.network["number_epoch"]
    val_step = config.network["validation_step"]
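
    # The configuration JSON is assumed to look roughly like this; the field
    # names match the accesses in this script, the values are illustrative only:
    #
    #   {
    #     "task_name": "Task001_Example", "folder_name": "run_1",
    #     "num_seg_used": 5, "exp_set": 1, "num_fold": 5,
    #     "labels": {"0": "background", "1": "structure"},
    #     "network": {"spatial_dim": 3, "dropout": 0.2, "activation_type": "prelu",
    #                 "normalization_type": "batch", "num_res": 0,
    #                 "registration_network_learning_rate": 5e-4,
    #                 "segmentation_network_learning_rate": 1e-3,
    #                 "anatomy_loss_weight": 1.0,
    #                 "supervised_segmentation_loss_weight": 1.0,
    #                 "regularization_loss_weight": 1.0,
    #                 "number_epoch": 300, "validation_step": 5}
    #   }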

    make_if_dont_exist(data_path)
    make_if_dont_exist(task)
    make_if_dont_exist(exp_path)
    make_if_dont_exist(gt_path)
    make_if_dont_exist(folder_path)
    make_if_dont_exist(result_path)

    if not continue_training:
        start_fold = 1
    else:
        folds = sorted(os.listdir(result_path))
        if len(folds) == 0:
            continue_training = False
            start_fold = 1
        else:
            # resume from the last fold that already has results on disk
            last_fold_num = folds[-1].split('_')[-1]
            start_fold = int(last_fold_num)

    if train_only:
        num_fold = 1
    else:
        num_fold = config.num_fold

    for i in range(start_fold, num_fold + 1):
        if not train_only:
            fold_path = os.path.join(result_path, f'fold_{i}')
        else:
            fold_path = os.path.join(result_path, 'all')
        result_seg_path = os.path.join(fold_path, 'SegNet')
        result_reg_path = os.path.join(fold_path, 'RegNet')

        make_if_dont_exist(fold_path)
        make_if_dont_exist(result_reg_path)
        make_if_dont_exist(result_seg_path)
        log_name = 'training_log_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.log'
        log_path = os.path.join(fold_path, log_name)

        if not train_only:
            if not continue_training:
                setup_logger(f'log_{i}', log_path)
                logger = logging.getLogger(f'log_{i}')
                logger.info(f"Start Pipeline with fold_{i}")
            else:
                # use a distinct logger name when resuming, so handlers from a
                # previous run in this process are not duplicated
                setup_logger(f'log_{i+1}', log_path)
                logger = logging.getLogger(f'log_{i+1}')
                logger.info(f"Resume Pipeline with fold_{i}")
        else:
            setup_logger('all', log_path)
            logger = logging.getLogger('all')
            logger.info("Start Pipeline with all data")

        if not os.path.exists(os.path.join(fold_path, 'dataset.json')):
            logger.info('prepare dataset into train and test')
            json_dict = OrderedDict()
            json_dict['name'] = os.path.basename(task).split('_')[0]
            json_dict['description'] = '_'.join(os.path.basename(task).split('_')[1:])
            json_dict['tensorImageSize'] = "4D"
            json_dict['reference'] = "MODIFY"
            json_dict['licence'] = "MODIFY"
            json_dict['release'] = "0.0"
            json_dict['modality'] = {
                "0": "CT"
            }
            json_dict['labels'] = config.labels
            json_dict['network'] = config.network
            json_dict['experiment_set'] = experiment_set
            if not train_only:
                json_dict['num_fold'] = f'fold_{i}'
                train, test = combine_data(info, i, experiment_set, num_seg_used)
            else:
                json_dict['num_fold'] = 'all'
                train = info
                test = []
                num_seg_used = len(list(filter(lambda d: 'seg' in d.keys(), train)))

            num_seg = num_seg_used
            num_train = len(train)
            num_test = len(test)

            json_dict['total_numScanTraining'] = num_train
            json_dict['total_numLabelTraining'] = num_seg
            json_dict['total_numTest'] = num_test
            json_dict['total_train'] = train
            json_dict['total_test'] = test

            logger.info('prepare segmentation dataset')
            data_seg_available = list(filter(lambda d: 'seg' in d.keys(), train))
            data_seg_unavailable = list(filter(lambda d: 'seg' not in d.keys(), train))
            data_seg_available_train, data_seg_available_valid = \
                monai.data.utils.partition_dataset(data_seg_available, ratios=(8, 2))
            json_dict['seg_numTrain'] = len(data_seg_available_train)
            json_dict['seg_train'] = data_seg_available_train
            json_dict['seg_numValid'] = len(data_seg_available_valid)
            json_dict['seg_valid'] = data_seg_available_valid
            dataset_seg_available_train, dataset_seg_available_valid = load_seg_dataset(
                data_seg_available_train, data_seg_available_valid)
            data_item = random.choice(dataset_seg_available_train)
            img_shape = data_item['seg'].unsqueeze(0).shape[2:]
            num_label = len(torch.unique(data_item['seg']))
            logger.info('prepare segmentation network')
            seg_net = get_seg_net(spatial_dim, num_label, dropout,
                                  activation_type, normalization_type, num_res)

            logger.info('prepare registration dataset')
            data_without_seg_valid = data_seg_unavailable + data_seg_available_train
            data_valid, data_train = monai.data.utils.partition_dataset(
                data_without_seg_valid,
                ratios=(2, 8),
                shuffle=False
            )
            data_pairs_without_seg_valid = take_data_pairs(data_without_seg_valid)
            data_pairs_valid = take_data_pairs(data_valid)
            data_pairs_train = take_data_pairs(data_train)
            data_pairs_valid_subdivided = subdivide_list_of_data_pairs(
                data_pairs_valid)
            data_pairs_train_subdivided = subdivide_list_of_data_pairs(
                data_pairs_train)
            num_train_reg_net = len(data_pairs_train)
            num_valid_reg_net = len(data_pairs_valid)
            num_train_both = len(data_pairs_train_subdivided['01']) + \
                len(data_pairs_train_subdivided['10']) + \
                len(data_pairs_train_subdivided['11'])
            json_dict['reg_seg_numTrain'] = num_train_reg_net
            json_dict['reg_seg_numTrain_00'] = len(data_pairs_train_subdivided['00'])
            json_dict['reg_seg_train_00'] = data_pairs_train_subdivided['00']
            json_dict['reg_seg_numTrain_01'] = len(data_pairs_train_subdivided['01'])
            json_dict['reg_seg_train_01'] = data_pairs_train_subdivided['01']
            json_dict['reg_seg_numTrain_10'] = len(data_pairs_train_subdivided['10'])
            json_dict['reg_seg_train_10'] = data_pairs_train_subdivided['10']
            json_dict['reg_seg_numTrain_11'] = len(data_pairs_train_subdivided['11'])
            json_dict['reg_seg_train_11'] = data_pairs_train_subdivided['11']
            json_dict['reg_numValid'] = num_valid_reg_net
            json_dict['reg_numValid_00'] = len(data_pairs_valid_subdivided['00'])
            json_dict['reg_valid_00'] = data_pairs_valid_subdivided['00']
            json_dict['reg_numValid_01'] = len(data_pairs_valid_subdivided['01'])
            json_dict['reg_valid_01'] = data_pairs_valid_subdivided['01']
            json_dict['reg_numValid_10'] = len(data_pairs_valid_subdivided['10'])
            json_dict['reg_valid_10'] = data_pairs_valid_subdivided['10']
            json_dict['reg_numValid_11'] = len(data_pairs_valid_subdivided['11'])
            json_dict['reg_valid_11'] = data_pairs_valid_subdivided['11']
            print(f"We have {num_train_both} pairs to train reg_net and seg_net together, "
                  f"and an additional {num_train_reg_net - num_train_both} to train reg_net alone.")
            print(f"We have {num_valid_reg_net} pairs for reg_net validation.")

            dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided = load_reg_dataset(
                data_pairs_train_subdivided, data_pairs_valid_subdivided)
            logger.info('prepare registration network')
            reg_net = get_reg_net(spatial_dim, spatial_dim, dropout,
                                  activation_type, normalization_type, num_res)
            logger.info('generate dataset json file')
            with open(os.path.join(fold_path, 'dataset.json'), 'w') as f:
                json.dump(json_dict, f, indent=4, sort_keys=False)

        else:
            dataset_json = load_json(os.path.join(fold_path, 'dataset.json'))

            data_seg_available_train = dataset_json['seg_train']
            data_seg_available_valid = dataset_json['seg_valid']
            dataset_seg_available_train, dataset_seg_available_valid = load_seg_dataset(
                data_seg_available_train, data_seg_available_valid)
            data_item = random.choice(dataset_seg_available_train)
            img_shape = data_item['seg'].unsqueeze(0).shape[2:]
            num_label = len(torch.unique(data_item['seg']))
            logger.info('prepare segmentation network')
            seg_net = get_seg_net(spatial_dim, num_label, dropout,
                                  activation_type, normalization_type, num_res)

            data_pairs_train_subdivided = {
                '00': dataset_json['reg_seg_train_00'],
                '01': dataset_json['reg_seg_train_01'],
                '10': dataset_json['reg_seg_train_10'],
                '11': dataset_json['reg_seg_train_11']
            }
            data_pairs_valid_subdivided = {
                '00': dataset_json['reg_valid_00'],
                '01': dataset_json['reg_valid_01'],
                '10': dataset_json['reg_valid_10'],
                '11': dataset_json['reg_valid_11']
            }
            num_train_reg_net = dataset_json['reg_seg_numTrain']
            num_valid_reg_net = dataset_json['reg_numValid']
            num_train_both = len(data_pairs_train_subdivided['01']) + \
                len(data_pairs_train_subdivided['10']) + \
                len(data_pairs_train_subdivided['11'])
            print(f"We have {num_train_both} pairs to train reg_net and seg_net together, "
                  f"and an additional {num_train_reg_net - num_train_both} to train reg_net alone.")
            print(f"We have {num_valid_reg_net} pairs for reg_net validation.")

            dataset_pairs_train_subdivided, dataset_pairs_valid_subdivided = load_reg_dataset(
                data_pairs_train_subdivided, data_pairs_valid_subdivided)
            logger.info('prepare registration network')
            reg_net = get_reg_net(spatial_dim, spatial_dim, dropout,
                                  activation_type, normalization_type, num_res)

        dataloader_train_seg = monai.data.DataLoader(
            dataset_seg_available_train,
            batch_size=2,
            num_workers=4,
            shuffle=True
        )
        dataloader_valid_seg = monai.data.DataLoader(
            dataset_seg_available_valid,
            batch_size=4,
            num_workers=4,
            shuffle=False
        )
        dataloader_train_reg = {
            seg_availability: monai.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=4,
                shuffle=True
            ) if len(dataset) > 0 else []
            for seg_availability, dataset in dataset_pairs_train_subdivided.items()
        }
        dataloader_valid_reg = {
            seg_availability: monai.data.DataLoader(
                dataset,
                batch_size=2,
                num_workers=4,
                shuffle=True
            ) if len(dataset) > 0 else []
            for seg_availability, dataset in dataset_pairs_valid_subdivided.items()
        }
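
        # Subdivisions with no pairs are mapped to an empty list instead of a
        # DataLoader, so the training loop can skip seg-availability
        # combinations ('00', '01', '10', '11') that do not occur in this split.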

        train_network(
            dataloader_train_reg,
            dataloader_valid_reg,
            dataloader_train_seg,
            dataloader_valid_seg,
            device,
            seg_net,
            reg_net,
            num_label,
            lr_reg,
            lr_seg,
            lam_a,
            lam_sp,
            lam_re,
            max_epoch,
            val_step,
            result_seg_path,
            result_reg_path,
            logger,
            img_shape,
            plot_network=args.plot_network,
            continue_training=continue_training
        )
        # Alternative path: train the segmentation network alone.
        '''
        seg_train.train_seg(
            dataloader_train_seg,
            dataloader_valid_seg,
            device,
            seg_net,
            lr_seg,
            max_epoch,
            val_step,
            result_seg_path
        )
        '''


if __name__ == '__main__':
    torch.cuda.empty_cache()
    main()