# BrainFM / scripts / demo_generator.py
# (Hosting-page metadata carried over with this file: uploaded by peirong26,
#  commit message "Upload 187 files", commit 2571f24, verified.)
###############################
#### Synthetic Data Demo ####
###############################
import datetime
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import torch
import utils.misc as utils
from Generator import build_datasets
# Generator configuration files passed, in this order, to
# utils.preprocess_cfg in the __main__ block: the default generator config
# first, then the demo-specific one.
# NOTE(review): both are absolute paths on one specific filesystem -- edit
# them before running this script on another machine.
default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml'
demo_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_synth.yaml'
def map_back_orig(img, idx, shp):
    """Cut a window of size ``shp`` starting at ``idx`` out of ``img``.

    When either ``idx`` or ``shp`` is ``None`` the input is returned
    untouched. Otherwise a 3-D input gets two leading singleton axes and a
    4-D input gets one, and axes 2, 3 and 4 of the result are sliced to
    ``[idx[k], idx[k] + shp[k])`` for ``k = 0, 1, 2``.

    Args:
        img: Indexable array/tensor exposing ``.shape``.
            (presumably (x, y, z), (c, x, y, z) or (b, c, x, y, z) -- TODO confirm)
        idx: Start offsets for the three sliced axes, or ``None``.
        shp: Window extents for the three sliced axes, or ``None``.

    Returns:
        ``img`` itself if no window is given, else the sliced view.
    """
    if idx is None or shp is None:
        return img

    # Pad with leading singleton axes so the sliced axes sit at 2, 3, 4.
    rank = len(img.shape)
    if rank == 3:
        img = img[None, None]
    elif rank == 4:
        img = img[None]

    keep_all = slice(None)
    window = tuple(slice(idx[axis], idx[axis] + shp[axis]) for axis in range(3))
    return img[(keep_all, keep_all) + window]
def generate(args):
    """Generate synthetic demo volumes and write them under ``gen_args.out_dir``.

    For every processed subject and every deformation index, one item is
    fetched from the dataset; the subject-level volumes and each generated
    sample are written with ``utils.viewVolume`` into::

        <out_dir>/<subject>/deform-<i>/            (subject-level volumes)
        <out_dir>/<subject>/deform-<i>/sample-<j>/ (per-sample volumes)

    Only the first dataset in ``gen_args.dataset_names`` is used, and at most
    ``gen_args.test_itr_limit`` subjects are processed.

    Args:
        args: A 3-item sequence; only the middle item (the generator config)
            is read. (presumably (submit_cfg, gen_cfg, train_cfg) as passed
            by ``utils.launch_job`` -- TODO confirm)

    Returns:
        None. Progress and the total generation time are printed.
    """
    _, gen_args, _ = args

    # Resolve the device once. Compare against None rather than truthiness so
    # that an explicit device index of 0 is reported as the device in use;
    # the value handed to build_datasets is the same as before in all cases.
    if gen_args.device_generator is not None:
        device = gen_args.device_generator
    elif torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = 'cpu'
    print('device: %s' % device)
    print('out_dir:', gen_args.out_dir)

    # ============ preparing data ... ============
    dataset_dict = build_datasets(gen_args, device=device)
    dataset = dataset_dict[gen_args.dataset_names[0]]

    # Names of the task flags that are switched on in the config.
    tasks = [name for name, enabled in vars(gen_args.task).items() if enabled]

    print("Start generating")
    start_time = time.time()

    dataset.mild_samples = gen_args.mild_samples
    dataset.all_samples = gen_args.all_samples

    # Subject-level image keys that are saved whenever the dataset provides them.
    modalities = ('T1', 'T2', 'FLAIR', 'CT')

    num_subjects = min(gen_args.test_itr_limit, len(dataset.names))
    for itr in range(num_subjects):
        subj_path = dataset.names[itr]
        subj_name = os.path.basename(subj_path).split('.nii')[0]
        save_dir = utils.make_dir(os.path.join(gen_args.out_dir, subj_name))

        # 1-based counter over the number of subjects actually processed.
        print('Processing image (%d/%d): %s' % (itr + 1, num_subjects, subj_path))

        for i_deform in range(gen_args.num_deformations):
            def_save_dir = utils.make_dir(os.path.join(save_dir, 'deform-%s' % i_deform))

            # The same index is fetched again for every deformation
            # (presumably each fetch draws a new random deformation -- TODO confirm).
            _, subjects, samples = dataset[itr]

            if 'aff' in subjects:
                aff = subjects['aff']
                shp = subjects['shp']
                loc_idx = subjects['loc_idx']
            else:
                # No affine supplied: fall back to identity and skip cropping.
                aff = torch.eye(4)
                shp = loc_idx = None

            print('num samples:', len(samples))
            print(' deform:', i_deform)

            for modality in modalities:
                if modality in subjects:
                    utils.viewVolume(subjects[modality], aff, names=[modality], save_dir=def_save_dir)
            if 'pathology' in tasks:
                utils.viewVolume(subjects['pathology'], aff, names=['pathology'], save_dir=def_save_dir)
            if 'segmentation' in tasks:
                utils.viewVolume(subjects['segmentation']['label'], aff, names=['label'], save_dir=def_save_dir)

            for i_sample, sample in enumerate(samples):
                print(' sample:', i_sample)
                sample_save_dir = utils.make_dir(os.path.join(def_save_dir, 'sample-%s' % i_sample))

                if 'input' in sample:
                    utils.viewVolume(map_back_orig(sample['input'], loc_idx, shp), aff,
                                     names=['input'], save_dir=sample_save_dir)
                if 'super_resolution' in tasks:
                    utils.viewVolume(map_back_orig(sample['orig'], loc_idx, shp), aff,
                                     names=['high_reso'], save_dir=sample_save_dir)
                if 'bias_field' in tasks:
                    # The sample stores the log of the bias field; exponentiate before saving.
                    utils.viewVolume(map_back_orig(torch.exp(sample['bias_field_log']), loc_idx, shp), aff,
                                     names=['bias_field'], save_dir=sample_save_dir)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Generation time {}'.format(total_time_str))
#####################################################################################
if __name__ == '__main__':
    # Build the generator config from the default file followed by the demo
    # file, then hand `generate` to the job launcher with only the generator
    # config populated.
    cfg_files = [default_gen_cfg_file, demo_gen_cfg_file]
    gen_args = utils.preprocess_cfg(cfg_files)
    utils.launch_job(
        submit_cfg=None,
        gen_cfg=gen_args,
        train_cfg=None,
        func=generate,
    )