|
|
|
|
|
|
|
|
|
|
|
import datetime |
|
import os, sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
import time |
|
|
|
import torch |
|
|
|
import utils.misc as utils |
|
|
|
|
|
from Generator import build_datasets |
|
|
|
|
|
|
|
|
|
# Generator configuration files consumed by the ``__main__`` entry point.
# NOTE(review): both are absolute paths on a specific filesystem; they must
# exist on the machine running this script.

# Baseline generator settings.
default_gen_cfg_file = "/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml"

# Demo-specific generator settings, listed after the defaults when loading.
demo_gen_cfg_file = "/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_synth.yaml"
|
|
|
|
|
def map_back_orig(img, idx, shp):
    """Cut a sub-volume of size ``shp`` starting at ``idx`` out of ``img``.

    Args:
        img: Array-like with a ``shape`` attribute that supports ``None``
            and slice indexing. 3-D and 4-D inputs are promoted to 5-D by
            prepending singleton axes; other ranks are indexed as they are.
        idx: Start offsets along the last three axes of the 5-D array, or
            ``None``.
        shp: Extents along the same three axes, or ``None``.

    Returns:
        ``img`` untouched when ``idx`` or ``shp`` is ``None``; otherwise the
        5-D window ``[:, :, idx[d]:idx[d] + shp[d]]`` for ``d`` in 0..2.
    """
    if idx is None or shp is None:
        # Nothing to map back to: hand the input through unchanged.
        return img

    # Prepend as many singleton axes as needed to reach five dimensions
    # (two for a 3-D input, one for a 4-D input).
    rank = len(img.shape)
    if rank in (3, 4):
        img = img[(None,) * (5 - rank)]

    # Keep the two leading axes whole and window the remaining three.
    window = tuple(slice(idx[d], idx[d] + shp[d]) for d in range(3))
    return img[(slice(None), slice(None)) + window]
|
|
|
|
|
def generate(args):
    """Run the demo generator and dump targets and samples for each subject.

    For every processed subject and every deformation index, one item is
    drawn from the dataset and passed to ``utils.viewVolume`` with a
    ``save_dir`` laid out as::

        <out_dir>/<subject>/deform-<i>/            deformed targets
        <out_dir>/<subject>/deform-<i>/sample-<j>/ per-sample volumes

    Args:
        args: 3-tuple whose second element is the generator configuration.
            The fields read here are ``device_generator``, ``out_dir``,
            ``dataset_names``, ``task``, ``mild_samples``, ``all_samples``,
            ``test_itr_limit`` and ``num_deformations``.

    Returns:
        None. Progress and the total generation time are printed.
    """
    _, gen_args, _ = args

    # Resolve the device exactly once, so that the device reported below is
    # the one actually handed to build_datasets. An explicitly configured
    # device (including index 0) always wins.
    if gen_args.device_generator is not None:
        device = gen_args.device_generator
    elif torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = 'cpu'
    print('device: %s' % device)

    print('out_dir:', gen_args.out_dir)

    # Only the first configured dataset is used for the demo.
    dataset_dict = build_datasets(gen_args, device=device)
    dataset = dataset_dict[gen_args.dataset_names[0]]

    # Names of the tasks whose flag is truthy in the config.
    tasks = [key for (key, value) in vars(gen_args.task).items() if value]

    print("Start generating")
    start_time = time.time()

    dataset.mild_samples = gen_args.mild_samples
    dataset.all_samples = gen_args.all_samples

    # Number of subjects that will really be processed; also used as the
    # denominator of the progress message.
    num_subjects = min(gen_args.test_itr_limit, len(dataset.names))
    for itr in range(num_subjects):

        subj_name = os.path.basename(dataset.names[itr]).split('.nii')[0]

        save_dir = utils.make_dir(os.path.join(gen_args.out_dir, subj_name))

        print('Processing image (%d/%d): %s' % (itr + 1, num_subjects, dataset.names[itr]))

        for i_deform in range(gen_args.num_deformations):
            def_save_dir = utils.make_dir(os.path.join(save_dir, 'deform-%s' % i_deform))

            # The dataset is queried again for every deformation index.
            (_, subjects, samples) = dataset[itr]

            if 'aff' in subjects:
                aff = subjects['aff']
                shp = subjects['shp']
                loc_idx = subjects['loc_idx']
            else:
                # No geometry available: identity affine and no cropping.
                aff = torch.eye(4)
                shp = loc_idx = None

            print('num samples:', len(samples))
            print(' deform:', i_deform)

            # Deformed targets: every modality that is present ...
            for modality in ('T1', 'T2', 'FLAIR', 'CT'):
                if modality in subjects:
                    utils.viewVolume(subjects[modality], aff, names=[modality], save_dir=def_save_dir)
            # ... plus the task-specific targets that are enabled.
            if 'pathology' in tasks:
                utils.viewVolume(subjects['pathology'], aff, names=['pathology'], save_dir=def_save_dir)
            if 'segmentation' in tasks:
                utils.viewVolume(subjects['segmentation']['label'], aff, names=['label'], save_dir=def_save_dir)

            for i_sample, sample in enumerate(samples):
                print(' sample:', i_sample)
                sample_save_dir = utils.make_dir(os.path.join(def_save_dir, 'sample-%s' % i_sample))

                # Per-sample volumes, cropped with map_back_orig when
                # loc_idx and shp are available.
                if 'input' in sample:
                    utils.viewVolume(map_back_orig(sample['input'], loc_idx, shp), aff, names=['input'], save_dir=sample_save_dir)
                if 'super_resolution' in tasks:
                    utils.viewVolume(map_back_orig(sample['orig'], loc_idx, shp), aff, names=['high_reso'], save_dir=sample_save_dir)
                if 'bias_field' in tasks:
                    # The sample stores the log of the bias field; exponentiate before saving.
                    utils.viewVolume(map_back_orig(torch.exp(sample['bias_field_log']), loc_idx, shp), aff, names=['bias_field'], save_dir=sample_save_dir)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Generation time {}'.format(total_time_str))
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the generator configuration from the default file followed by the
    # demo file (presumably later files override earlier ones -- confirm in
    # utils.preprocess_cfg).
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, demo_gen_cfg_file])

    # Only a generator config is supplied; no submit or train config.
    utils.launch_job(
        func=generate,
        gen_cfg=gen_args,
        train_cfg=None,
        submit_cfg=None,
    )