###############################
####  Synthetic Data Demo  ####
###############################


import datetime
import os
import sys
import time

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch

import utils.misc as utils
from Generator import build_datasets

# default & demo generator cfg files #
default_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/default.yaml' 
demo_gen_cfg_file = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_synth.yaml'


def map_back_orig(img, idx, shp):
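    """Crop a generated volume back to the original image grid.

    `idx` gives the starting voxel and `shp` the spatial shape of the original
    volume inside the generator's (padded) output; 3D or 4D inputs are first
    promoted to a 5D (batch, channel, D, H, W) tensor. If `idx` or `shp` is
    None, the input is returned unchanged.
    """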
    if idx is None or shp is None:
        return img
    if len(img.shape) == 3:
        img = img[None, None]
    elif len(img.shape) == 4:
        img = img[None]
    return img[:, :, idx[0]:idx[0] + shp[0], idx[1]:idx[1] + shp[1], idx[2]:idx[2] + shp[2]]
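
# Example with hypothetical values: crop a generated volume back to an original
# 160 x 192 x 160 grid whose corner sits at voxel (18, 2, 18) of the padded output:
#   orig_vol = map_back_orig(gen_vol, idx=(18, 2, 18), shp=(160, 192, 160))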


def generate(args):
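    """Generate synthetic data for each subject (up to gen_args.test_itr_limit) and save it under gen_args.out_dir."""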

    _, gen_args, _ = args
 
    # Pick the generation device: explicit cfg entry > current CUDA device > CPU.
    if gen_args.device_generator:
        device = gen_args.device_generator
    elif torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        device = 'cpu'
    print('device: %s' % device)

    print('out_dir:', gen_args.out_dir)

    # ============ preparing data ... ============ 
    dataset_dict = build_datasets(gen_args, device = gen_args.device_generator if gen_args.device_generator is not None else device) 
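    # One dataset is built per entry in gen_args.dataset_names; the demo uses the first one.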
    dataset = dataset_dict[gen_args.dataset_names[0]] 

    tasks = [key for (key, value) in vars(gen_args.task).items() if value]  # names of the enabled generation tasks
       
    print("Start generating")
    start_time = time.time()


    dataset.mild_samples = gen_args.mild_samples
    dataset.all_samples = gen_args.all_samples 
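    # Iterate over at most test_itr_limit subjects.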
    for itr in range(min(gen_args.test_itr_limit, len(dataset.names))):
        
        subj_name = os.path.basename(dataset.names[itr]).split('.nii')[0]

        save_dir = utils.make_dir(os.path.join(gen_args.out_dir, subj_name))

        print('Processing image (%d/%d): %s' % (itr + 1, len(dataset.names), dataset.names[itr]))

        # Re-sample the subject num_deformations times; each deformation gets its own sub-directory.
        for i_deform in range(gen_args.num_deformations):
            def_save_dir = utils.make_dir(os.path.join(save_dir, 'deform-%s' % i_deform))

            # Draw a new random synthesis: `subjects` holds the ground truth, `samples` the generated inputs.
            _, subjects, samples = dataset[itr]

            if 'aff' in subjects:
                # Affine, original shape and crop location, used to map outputs back onto the input grid.
                aff = subjects['aff']
                shp = subjects['shp']
                loc_idx = subjects['loc_idx']
            else:
                aff = torch.eye(4)
                shp = loc_idx = None
            
            print('num samples:', len(samples))
            print('     deform:', i_deform)
            
            #print(subjects.keys())

            # Save whichever ground-truth volumes and task targets exist for this subject.
            if 'T1' in subjects:
                utils.viewVolume(subjects['T1'], aff, names = ['T1'], save_dir = def_save_dir)
            if 'T2' in subjects:
                utils.viewVolume(subjects['T2'], aff, names = ['T2'], save_dir = def_save_dir)
            if 'FLAIR' in subjects:
                utils.viewVolume(subjects['FLAIR'], aff, names = ['FLAIR'], save_dir = def_save_dir)
            if 'CT' in subjects:
                utils.viewVolume(subjects['CT'], aff, names = ['CT'], save_dir = def_save_dir) 
            if 'pathology' in tasks:
                utils.viewVolume(subjects['pathology'], aff, names = ['pathology'], save_dir = def_save_dir)
            if 'segmentation' in tasks:
                utils.viewVolume(subjects['segmentation']['label'], aff, names = ['label'], save_dir = def_save_dir)

            # Save each generated sample, cropped back to the original grid where needed.
            for i_sample, sample in enumerate(samples):
                print('         sample:', i_sample)
                sample_save_dir = utils.make_dir(os.path.join(def_save_dir, 'sample-%s' % i_sample))

                #print(sample.keys())
                
                if 'input' in sample:
                    utils.viewVolume(map_back_orig(sample['input'], loc_idx, shp), aff, names = ['input'], save_dir = sample_save_dir)
                if 'super_resolution' in tasks:
                    utils.viewVolume(map_back_orig(sample['orig'], loc_idx, shp), aff, names = ['high_reso'], save_dir = sample_save_dir)
                if 'bias_field' in tasks:
                    utils.viewVolume(map_back_orig(torch.exp(sample['bias_field_log']), loc_idx, shp), aff, names = ['bias_field'], save_dir = sample_save_dir)

 
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Generation time {}'.format(total_time_str))


#####################################################################################


if __name__ == '__main__': 
    gen_args = utils.preprocess_cfg([default_gen_cfg_file, demo_gen_cfg_file])
    utils.launch_job(submit_cfg = None, gen_cfg = gen_args, train_cfg = None, func = generate)
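
# Note: the two cfg paths above are specific to the authors' cluster; point them at the
# corresponding files in your own checkout (cfgs/generator/...) before running this script.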