# AutoSeg4Sinonasal / deepatlas / scripts / deep_atlas_test_customized.py
# Author: Chris Xiao
# Commit message: upload files
# Commit: 2ca2f68
from pkg_resources import add_activation_listener
import monai
import torch
import itk
import numpy as np
import os.path
import argparse
import sys
from pathlib import Path
import deep_atlas_train
import glob
# Repository root, taken as the directory two levels above the current
# working directory (not above this file's location).
# NOTE(review): this presumably expects the script to be launched from
# deepatlas/scripts -- confirm.
ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
# Put <ROOT_DIR>/deepatlas/test at the front of sys.path so the
# `from test import ...` statement that follows is resolved from there first.
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/test'))
from test import (
seg_training_inference, load_json, reg_training_inference
)
def parse_command_line():
    """Parse the command-line options of the deep atlas test pipeline.

    Options:
        -gpu: id of the GPU device to use, kept as a string (default '0').
        -ti:  task id and name (required).
        -nf:  fold number to test, parsed as an int (required).
        -op:  relative path of the output directory (required).

    Returns:
        argparse.Namespace with the attributes ``gpu``, ``ti``, ``nf``
        and ``op``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        description='pipeline for deep atlas test')
    parser.add_argument('-gpu', metavar='id of gpu', type=str, default='0',
                        help='id of gpu device to use')
    # -ti, -nf and -op have no meaningful default: main() joins them into
    # file-system paths, so a missing value (None) would only surface later
    # as a TypeError or as a bogus 'fold_None' directory. Marking them as
    # required lets argparse report the problem with a usage message.
    parser.add_argument('-ti', metavar='task id and name', type=str,
                        required=True,
                        help='task name and id')
    parser.add_argument('-nf', metavar='number of fold', type=int,
                        required=True,
                        help='number of fold for testing')
    parser.add_argument('-op', metavar='output path for prediction step', type=str,
                        required=True,
                        help="relative path of the output directory, should be same name in the registration, crop and final prediction steps")
    return parser.parse_args()
def path_to_id(path):
    """Return the identifier of a file: its base name up to the first dot.

    For example ``/data/case_01.nii.gz`` yields ``case_01``. A base name
    without any dot is returned unchanged.
    """
    file_name = os.path.basename(path)
    # Everything before the first '.' is the id; any extensions are dropped.
    stem, _, _ = file_name.partition('.')
    return stem
def _ensure_dir(path):
    """Create ``path`` (with any missing parents), reporting if it exists."""
    try:
        os.makedirs(path)
    except FileExistsError:
        # Only the "already there" case is tolerated; permission errors and
        # similar failures propagate instead of being reported as "existed".
        print(f'{path} is already existed !!!')


def _collect_test_data(img_dir, seg_dir):
    """Build the list of ``{'img': ..., 'seg': ...}`` items for inference.

    Every ``*.nii.gz`` file in ``img_dir`` produces one item. The ``'seg'``
    key is added only when ``seg_dir`` contains a ``*.nii.gz`` file with the
    same id (see ``path_to_id``).
    """
    img_paths = sorted(glob.glob(os.path.join(img_dir, '*.nii.gz')))
    seg_paths = sorted(glob.glob(os.path.join(seg_dir, '*.nii.gz')))

    # id -> label path; setdefault keeps the first match in sorted order,
    # which is what a linear list.index() lookup would have returned.
    seg_by_id = {}
    for seg_file in seg_paths:
        seg_by_id.setdefault(path_to_id(seg_file), seg_file)

    data = []
    for img_file in img_paths:
        item = {'img': img_file}
        seg_file = seg_by_id.get(path_to_id(img_file))
        if seg_file is not None:
            item['seg'] = seg_file
        data.append(item)
    return data


def main():
    """Run SegNet and RegNet inference on a customized test set.

    Steps:
      1. Read the command-line options (gpu id, task, fold, output name).
      2. Load ``dataset.json`` of the selected fold to get the label count
         and the network hyper-parameters.
      3. Create the output directory tree
         ``<root>/deepatlas_results/<task>/customize_predicted_results/
         fold_<n>/<op>/{SegNet,RegNet}``.
      4. Build both networks through ``deep_atlas_train``.
      5. Collect the images (and labels, when present) from
         ``<root>/deepatlas_preprocessed/<task>/customize_test_data/<op>``.
      6. Call ``seg_training_inference`` and ``reg_training_inference`` with
         the best checkpoints of the fold.

    ``<root>`` is the directory two levels above the current working
    directory.
    """
    root_dir = str(Path(os.getcwd()).parent.parent.absolute())
    args = parse_command_line()
    task = args.ti
    num_fold = f'fold_{args.nf}'

    # Inputs produced by the training step of the same task and fold.
    results_dir = os.path.join(root_dir, 'deepatlas_results', task)
    training_dir = os.path.join(results_dir, 'training_results', num_fold)
    json_path = os.path.join(training_dir, 'dataset.json')
    seg_model_path = os.path.join(training_dir, 'SegNet', 'seg_net_best.pth')
    reg_model_path = os.path.join(training_dir, 'RegNet', 'reg_net_best.pth')

    # Output tree, listed from the outermost directory to the innermost.
    output_path = os.path.join(results_dir, 'customize_predicted_results')
    fold_path = os.path.join(output_path, num_fold)
    out_path = os.path.join(fold_path, args.op)
    output_seg_path = os.path.join(out_path, 'SegNet')
    output_reg_path = os.path.join(out_path, 'RegNet')

    json_file = load_json(json_path)
    num_label = len(json_file['labels'])
    network_info = json_file['network']
    spatial_dim = network_info['spatial_dim']
    dropout = network_info['dropout']
    activation_type = network_info['activation_type']
    normalization_type = network_info['normalization_type']
    num_res = network_info['num_res']

    device = torch.device("cuda:" + args.gpu)

    for directory in (output_path, fold_path, out_path,
                      output_seg_path, output_reg_path):
        _ensure_dir(directory)

    seg_net = deep_atlas_train.get_seg_net(
        spatial_dim, num_label, dropout, activation_type, normalization_type, num_res)
    # NOTE(review): spatial_dim is passed twice on purpose, mirroring the
    # original call; the meaning of the second argument is defined in
    # deep_atlas_train.get_reg_net -- confirm there.
    reg_net = deep_atlas_train.get_reg_net(
        spatial_dim, spatial_dim, dropout, activation_type, normalization_type, num_res)

    test_data_dir = os.path.join(
        root_dir, 'deepatlas_preprocessed', task, 'customize_test_data', args.op)
    data = _collect_test_data(
        os.path.join(test_data_dir, 'images'),
        os.path.join(test_data_dir, 'labels'))

    # The explicit data list is supplied instead of a dataset json path.
    seg_training_inference(seg_net, device, seg_model_path, output_seg_path, num_label, json_path=None, data=data)
    reg_training_inference(reg_net, device, reg_model_path, output_reg_path, num_label, json_path=None, data=data)
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()