BrainFM / scripts /test.py
peirong26's picture
Upload 187 files
2571f24 verified
###############################
#### Brani-ID Inference #####
###############################
import os, sys, warnings, shutil, glob, time, datetime
warnings.filterwarnings("ignore")
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import defaultdict
import torch
import numpy as np
from utils.misc import make_dir, viewVolume, MRIread
import utils.test_utils as utils
from Generator.utils import fast_3D_interp_torch
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
#################################
### For hemisphere prediction ###
#################################
label_list_left_segmentation = [0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 17, 31, 34, 36, 38, 40, 42]
lut = torch.zeros(10000, dtype=torch.long, device=device)
for l in range(len(label_list_left_segmentation)):
lut[label_list_left_segmentation[l]] = l
# get left hemis mask
#S = read_brainseg.nii
#S = lut[X.astype(np.int)]
#X = read_mni_coord_X
#M = (S > 0) & (X < 0)
# apply hemis mask for all image I
#I[M==0] = 0
#################################
#################################
def prepare_paths(data_root, split_txt):
# Collect list of available images, per dataset
datasets = []
g = glob.glob(os.path.join(data_root, '*' + 'T1w.nii'))
for i in range(len(g)):
filename = os.path.basename(g[i])
dataset = filename[:filename.find('.')]
found = False
for d in datasets:
if dataset == d:
found = True
if found is False:
datasets.append(dataset)
print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total')
print('Dataset list', datasets)
names = []
split_file = open(split_txt, 'r')
split_names = []
for subj in split_file.readlines():
split_names.append(subj.strip())
for i in range(len(datasets)):
names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])])
datasets_num = len(datasets)
datasets_len = [len(names[i]) for i in range(len(names))]
print('Num of testing data', sum([len(names[i]) for i in range(len(names))]))
return names, datasets
def get_info(t1):
t2 = t1[:-7] + 'T2w.nii'
flair = t1[:-7] + 'FLAIR.nii'
ct = t1[:-7] + 'CT.nii'
cerebral_labels = t1[:-7] + 'brainseg.nii'
segmentation_labels = t1[:-7] + 'brainseg_with_extracerebral.nii'
brain_dist_map = t1[:-7] + 'brain_dist_map.nii'
lp_dist_map = t1[:-7] + 'lp_dist_map.nii'
rp_dist_map = t1[:-7] + 'rp_dist_map.nii'
lw_dist_map = t1[:-7] + 'lw_dist_map.nii'
rw_dist_map = t1[:-7] + 'rw_dist_map.nii'
mni_reg_x = t1[:-7] + 'mni_reg.x.nii'
mni_reg_y = t1[:-7] + 'mni_reg.y.nii'
mni_reg_z = t1[:-7] + 'mni_reg.z.nii'
modalities = {'T1': t1}
if os.path.isfile(t2):
modalities.update({'T2': t2})
if os.path.isfile(flair):
modalities.update({'FLAIR': flair})
if os.path.isfile(ct):
modalities.update({'CT': ct})
aux = {'label': segmentation_labels, 'cerebral_label': cerebral_labels, 'distance': brain_dist_map,
'regx': mni_reg_x, 'regy': mni_reg_y, 'regz': mni_reg_z,
'lp': lp_dist_map, 'lw': lw_dist_map, 'rp': rp_dist_map, 'rw': rw_dist_map}
return modalities, aux
#################################
gen_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test.yaml'
gen_hemis_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/generator/test/demo_test_hemis.yaml'
model_cfg = '/autofs/space/yogurt_003/users/pl629/code/MTBrainID/cfgs/trainer/test/demo_test.yaml'
#win_size = [192, 192, 192]
win_size = [160, 160, 160]
mask_output = False
exclude_keys = ['segmentation']
data_root = '/autofs/vast/lemon/data_curated/brain_mris_QCed'
split_txt = '/autofs/vast/lemon/temp_stuff/peirong/train_test_split/test.txt'
names, datasets = prepare_paths(data_root, split_txt)
max_num_test_dataset = None #1
max_num_per_dataset = None #5
zero_crop = False
main_save_dir = make_dir('/autofs/space/yogurt_002/users/pl629/results/MTBrainID/test/', reset = False)
models = [
#('test_reggr', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad/l6_16/1025-1744/ckp/checkpoint_latest.pth'),
#('test_lowres', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/wosr_reggrad_lowres/l6_16/1025-1746/ckp/checkpoint_latest.pth'),
('test_sr', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr/l6_16/0926-2035/ckp/checkpoint_latest.pth'),
#('test_sr_lowres', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/sr_lowres/l6_16/0926-2025/ckp/checkpoint_latest.pth'),
#('test_hemis', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/hemis/l6_16/0806-1008/ckp/checkpoint_latest.pth'),
#('comp_synth', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/synth/l6_16/0924-0929/ckp/checkpoint_latest.pth'),
#('comp_dist', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/dist/l6_16/0806-1024/ckp/checkpoint_latest.pth'),
#('comp_reg', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/reg/l6_16/0806-1029/ckp/checkpoint_latest.pth'),
#('comp_bf', '/autofs/vast/lemon/temp_stuff/peirong/results/MTBrainID/bf/l6_16/0806-1026/ckp/checkpoint_latest.pth'),
]
#spacing = [1.5, 1.5, 5] # [1, 1, 1], [1.5, 1.5, 5], [3, 3, 3], None
#add_bf = False
setups = [
#([1, 1, 1], False),
#([1, 1, 1], True),
([1.5, 1.5, 5], False),
#([1.5, 1.5, 5], True),
]
all_start_time = time.time()
for postfix, ckp_path in models:
for spacing, add_bf in setups:
curr_postfix = postfix + '_BF' if add_bf else postfix
curr_postfix += '_%s-%s-%s' % (str(spacing[0]), str(spacing[1]), str(spacing[2])) if spacing is not None else '_1-1-1'
save_dir = make_dir(os.path.join(main_save_dir, curr_postfix), reset = True)
print('\nSave at: %s\n' % save_dir)
curr_gen_cfg = gen_hemis_cfg if 'hemis' in postfix else gen_cfg
for i, curr_dataset in enumerate(names):
curr_dataset.sort()
print('Dataset: %s (%d/%d) -- %d total cases' % (datasets[i], i+1, len(datasets), len(curr_dataset)))
#'''
if max_num_test_dataset is not None and i >= max_num_test_dataset:
break
start_time = time.time()
for j, t1_name in enumerate(curr_dataset):
if max_num_per_dataset is not None and j >= max_num_per_dataset:
break
subj_name = os.path.basename(t1_name).split('.T1w')[0]
subj_dir = make_dir(os.path.join(save_dir, subj_name))
print('Now testing: %s (%d/%d)' % (t1_name, j+1, len(curr_dataset)))
modalities, aux = get_info(t1_name)
S_cerebral = torch.squeeze(utils.prepare_image(aux['cerebral_label'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device)) # read seg map
if 'hemis' in postfix:
S = utils.prepare_image(aux['cerebral_label'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device) # read seg map
S = lut[S.int()] # mask out non-left labels
X = utils.prepare_image(aux['regx'], win_size = win_size, zero_crop = zero_crop, spacing = spacing, rescale = False, im_only = True, device = device) # read_mni_coord_X
hemis_mask = (S > 0) & (X < 0).int() # (1, 1, s, r, c)
viewVolume(hemis_mask, names = ['.'.join(os.path.basename(aux['label']).split('.')[:2]) + '.hemis_mask'], save_dir = subj_dir)
else:
hemis_mask = None
# save all GT
for mod in modalities.keys():
final, orig, high_res, bf, _, _, _ = utils.prepare_image(modalities[mod], win_size = win_size, zero_crop = zero_crop, spacing = spacing, add_bf = add_bf, is_CT = 'CT' in mod, rescale = False, hemis_mask = hemis_mask, im_only = False, device = device)
viewVolume(orig, names = [os.path.basename(modalities[mod])[:-4]], save_dir = subj_dir)
viewVolume(final, names = [os.path.basename(modalities[mod])[:-4] + '.input'], save_dir = subj_dir)
viewVolume(high_res, names = [os.path.basename(modalities[mod])[:-4] + '.high_res'], save_dir = subj_dir)
if bf is not None:
viewVolume(bf, names = [os.path.basename(modalities[mod])[:-4] + '.bias_field'], save_dir = subj_dir)
for mod in aux.keys():
im = utils.prepare_image(aux[mod], win_size = win_size, zero_crop = zero_crop, is_label = 'label' in mod, rescale = False, hemis_mask = hemis_mask, im_only = True, device = device)
viewVolume(im, names = [os.path.basename(aux[mod])[:-4]], save_dir = subj_dir)
# testing
for mod in modalities.keys():
test_dir = make_dir(os.path.join(subj_dir, 'input_' + mod))
im = utils.prepare_image(os.path.join(subj_dir, os.path.basename(modalities[mod])[:-4] + '.input.nii.gz'), win_size = win_size, zero_crop = zero_crop, is_CT = 'CT' in mod, hemis_mask = hemis_mask, im_only = True, device = device)
outs = utils.evaluate_image(im, ckp_path = ckp_path, feature_only = False, device = device, gen_cfg = curr_gen_cfg, model_cfg = model_cfg)
if mask_output:
mask = im.clone()
mask[im != 0.] = 1.
for k, v in outs.items():
if 'feat' not in k and k not in exclude_keys:
viewVolume(v * mask if mask_output else v, names = [ 'out_' + k], save_dir = test_dir)
print(S_cerebral.shape, outs['regx'].shape)
deformed_atlas = utils.get_deformed_atlas(S_cerebral, torch.squeeze(outs['regx']), torch.squeeze(outs['regy']), torch.squeeze(outs['regz']))
viewVolume(deformed_atlas * mask if mask_output else deformed_atlas, names = [ 'out_deformed_atlas'], save_dir = test_dir)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Testing time for {}: {}'.format(total_time_str, datasets[i]))
all_total_time = time.time() - all_start_time
all_total_time_str = str(datetime.timedelta(seconds=int(all_total_time)))
print('Total testing time: {}'.format(total_time_str))
#'''