|
import os, sys, glob |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from collections import defaultdict |
|
import random |
|
|
|
import torch |
|
import numpy as np |
|
import nibabel as nib |
|
from torch.utils.data import Dataset |
|
|
|
|
|
from .utils import * |
|
from .constants import n_pathology, pathology_paths, pathology_prob_paths, \ |
|
n_neutral_labels_brainseg_with_extracerebral, label_list_segmentation_brainseg_with_extracerebral, \ |
|
label_list_segmentation_brainseg_left, augmentation_funcs, processing_funcs |
|
import utils.interpol as interpol |
|
|
|
from utils.misc import viewVolume |
|
|
|
|
|
from ShapeID.DiffEqs.pde import AdvDiffPDE |
|
|
|
|
|
|
|
class BaseGen(Dataset):
    """
    BaseGen dataset

    Each item is built on the fly:
      1. an input volume is read (a real modality, or generation labels for synthesis),
      2. random setup parameters and a random affine / non-linear deformation are drawn,
      3. every requested target is read and deformed through ``processing_funcs``,
      4. the input is synthesized (``generate_sample``) and/or augmented (``augment_sample``).

    Per-item state (``self.modalities``, ``self.hemis_mask``) is kept on the
    instance between helper calls, so the helpers must be called in the order
    used by ``__getitem__``.
    """

    def __init__(self, gen_args, device='cpu'):
        """
        Args:
            gen_args: namespace-like config. Read here: ``split``, ``generator``,
                ``pathology_shape_generator``, ``real_image_generator``,
                ``synth_image_generator``, ``augmentation_steps``, ``modality_probs``;
                the ``prepare_*`` methods read further fields.
            device: torch device on which all generated tensors are created.
        """
        self.gen_args = gen_args
        self.split = gen_args.split

        self.synth_args = self.gen_args.generator
        self.shape_gen_args = gen_args.pathology_shape_generator
        self.real_image_args = gen_args.real_image_generator
        self.synth_image_args = gen_args.synth_image_generator
        self.augmentation_steps = vars(gen_args.augmentation_steps)
        self.input_prob = vars(gen_args.modality_probs)
        self.device = device

        # prepare_paths reads self.tasks, so prepare_tasks must run first.
        self.prepare_tasks()
        self.prepare_paths()
        self.prepare_grid()
        self.prepare_one_hot()

    def __len__(self):
        """Total number of scans over all datasets."""
        return sum(len(names) for names in self.names)

    def idx_to_path(self, idx):
        """Map a global sample index to its dataset and T1 path.

        Indices are laid out dataset after dataset, in ``self.datasets`` order.

        Returns:
            (dataset_name, input_prob dict for that dataset, t1_path, age or None)

        Raises:
            IndexError: if ``idx`` is not in ``[0, len(self))``.
        """
        cnt = 0
        for i, n_scans in enumerate(self.datasets_len):
            if cnt <= idx < cnt + n_scans:
                dataset_name = self.datasets[i]
                t1_path = self.names[i][idx - cnt]
                # Age keys are the basename up to '.T1w'.
                age = self.ages[i][os.path.basename(t1_path).split('.T1w')[0]] if len(self.ages) > 0 else None
                return dataset_name, vars(self.input_prob[dataset_name]), t1_path, age
            cnt += n_scans
        # Previously fell through and returned None, which only failed later at unpacking.
        raise IndexError('Sample index %s out of range for dataset of size %d' % (idx, cnt))

    def prepare_paths(self):
        """Collect dataset names, per-dataset scan lists and (optionally) ages.

        Sets ``self.datasets``, ``self.names`` (one list of paths per dataset),
        ``self.datasets_num``, ``self.datasets_len``, ``self.ages`` (one
        {subject_name: age} dict per dataset, or [] when 'age' is not a task)
        and ``self.pathology_type``. When 'age' is a task, ``self.split`` gets
        the suffix '_age'.
        """
        if len(self.gen_args.dataset_names) < 1:
            # No explicit list: a dataset name is the file-name prefix before the
            # first '.' of every '*T1w.nii' file in data_root.
            g = glob.glob(os.path.join(self.gen_args.data_root, '*' + 'T1w.nii'))
            datasets = []
            for path in g:
                filename = os.path.basename(path)
                dataset = filename[:filename.find('.')]
                if dataset not in datasets:
                    datasets.append(dataset)
            print('Found ' + str(len(datasets)) + ' datasets with ' + str(len(g)) + ' scans in total')
        else:
            datasets = self.gen_args.dataset_names
        print('Dataset list', datasets)

        names = []
        if 'age' in self.tasks:
            self.split = self.split + '_age'
        if self.gen_args.split_root is not None:
            # One scan path per line; scans are assigned to a dataset by basename prefix.
            with open(os.path.join(self.gen_args.split_root, self.split + '.txt'), 'r') as split_file:
                split_names = [subj.strip() for subj in split_file.readlines()]
            for dataset in datasets:
                names.append([name for name in split_names if os.path.basename(name).startswith(dataset)])

        ages = []
        if 'age' in self.tasks:
            # Each non-blank line: '<subject_name> <age>'.
            with open(os.path.join(self.gen_args.split_root, 'participants_age.txt'), 'r') as age_file:
                subj_name_age = [line.strip().split(' ') for line in age_file.readlines() if line.strip()]
            for dataset in datasets:
                dataset_ages = {}
                for [name, age] in subj_name_age:
                    if name.startswith(dataset):
                        dataset_ages[name] = float(age)
                ages.append(dataset_ages)
            # Guard: min()/max() of an empty dict would raise.
            if ages and ages[0]:
                print('Age info', self.split, len(ages[0].items()), min(ages[0].values()), max(ages[0].values()))

        self.ages = ages
        self.names = names
        self.datasets = datasets
        self.datasets_num = len(datasets)
        self.datasets_len = [len(dataset_names) for dataset_names in self.names]
        print('Num of data', sum(self.datasets_len))

        self.pathology_type = None

    def prepare_tasks(self):
        """Collect the enabled tasks and, when needed, build the pathology-shape PDE.

        Sets ``self.tasks`` (names of truthy entries of ``gen_args.task``),
        ``self.t`` and ``self.adv_pde`` (both None when unused).
        """
        self.tasks = [key for (key, value) in vars(self.gen_args.task).items() if value]
        if 'bias_field' in self.tasks and 'segmentation' not in self.tasks:
            # The bias_field task forces the segmentation target on.
            self.tasks += ['segmentation']

        if 'pathology' in self.tasks and self.synth_args.augment_pathology and self.synth_args.random_shape_prob < 1.:
            # Time points 0, dt, ..., (max_nt - 1) * dt for the advection PDE.
            self.t = torch.from_numpy(np.arange(self.shape_gen_args.max_nt) * self.shape_gen_args.dt).to(self.device)
            with torch.no_grad():
                self.adv_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
                                          perf_pattern='adv',
                                          V_type='vector_div_free',
                                          V_dict={},
                                          BC=self.shape_gen_args.bc,
                                          dt=self.shape_gen_args.dt,
                                          device=self.device
                                          )
        else:
            self.t, self.adv_pde = None, None

        for task_name in self.tasks:
            if task_name not in processing_funcs.keys():
                print('Warning: Function for task "%s" not found' % task_name)

    def prepare_grid(self):
        """Precompute the output voxel grid for ``self.synth_args.size``.

        Sets ``self.xx/yy/zz`` (voxel indices, 'ij' indexing), ``self.c``
        (grid center) and ``self.xc/yc/zc`` (indices relative to the center).
        """
        self.size = self.synth_args.size

        # Resolution of the training data (hard-coded to 1.0 on every axis).
        self.res_training_data = np.array([1.0, 1.0, 1.0])

        xx, yy, zz = np.meshgrid(range(self.size[0]), range(self.size[1]), range(self.size[2]), sparse=False, indexing='ij')
        self.xx = torch.tensor(xx, dtype=torch.float, device=self.device)
        self.yy = torch.tensor(yy, dtype=torch.float, device=self.device)
        self.zz = torch.tensor(zz, dtype=torch.float, device=self.device)
        self.c = torch.tensor((np.array(self.size) - 1) / 2, dtype=torch.float, device=self.device)
        self.xc = self.xx - self.c[0]
        self.yc = self.yy - self.c[1]
        self.zc = self.zz - self.c[2]

    def prepare_one_hot(self):
        """Build the label lookup table, one-hot matrix and left/right flip permutation."""
        if self.synth_args.left_hemis_only:
            label_list_segmentation = label_list_segmentation_brainseg_left
        else:
            label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral
        n_labels = len(label_list_segmentation)

        # Raw label value -> index in label_list_segmentation (unlisted values map to 0).
        self.lut = torch.zeros(10000, dtype=torch.long, device=self.device)
        for label_idx, label in enumerate(label_list_segmentation):
            self.lut[label] = label_idx
        self.onehotmatrix = torch.eye(n_labels, dtype=torch.float, device=self.device)

        # Index permutation: keeps the first n_neutral indices and swaps the two
        # halves of the remaining (lateral) indices. Presumably lateral labels are
        # ordered [left..., right...] — confirm against constants.
        n_neutral = n_neutral_labels_brainseg_with_extracerebral
        nlat = int((n_labels - n_neutral) / 2.0)
        self.vflip = np.concatenate([np.array(range(n_neutral)),
                                     np.array(range(n_neutral + nlat, n_labels)),
                                     np.array(range(n_neutral, n_neutral + nlat))])

    def random_affine_transform(self, shp):
        """Draw a random affine transform for an input volume of shape ``shp``.

        Returns:
            (scaling_factor_distances, A, c2): geometric mean of the three
            scalings, the 3x3 matrix from ``make_affine_matrix`` and the center
            (in input voxel coordinates) the output grid is mapped around,
            randomly shifted when ``random_shift`` is enabled.
        """
        args = self.synth_args
        rotations = (2 * args.max_rotation * np.random.rand(3) - args.max_rotation) / 180.0 * np.pi
        shears = (2 * args.max_shear * np.random.rand(3) - args.max_shear)
        scalings = 1 + (2 * args.max_scaling * np.random.rand(3) - args.max_scaling)
        scaling_factor_distances = np.prod(scalings) ** .33333333333
        A = torch.tensor(make_affine_matrix(rotations, shears, scalings), dtype=torch.float, device=self.device)

        center = torch.tensor((np.array(shp[0:3]) - 1) / 2, dtype=torch.float, device=self.device)
        if args.random_shift:
            # Shift the center by up to half the margin between input and output size.
            max_shift = (torch.tensor(np.array(shp[0:3]) - self.size, dtype=torch.float, device=self.device)) / 2
            max_shift[max_shift < 0] = 0
            # torch.float (not Python float == float64) to match every other tensor here.
            c2 = center + (2 * (max_shift * torch.rand(3, dtype=torch.float, device=self.device)) - max_shift)
        else:
            c2 = center
        return scaling_factor_distances, A, c2

    def random_nonlinear_transform(self, photo_mode, spac):
        """Draw a smooth random displacement field on the output grid.

        A low-resolution Gaussian field is upsampled with ``myzoom_torch``. In
        photo mode the field is zero along axis 1 and its low-res size on that
        axis depends on ``spac``. When 'surface' is a task, the field is
        integrated as a stationary velocity field (scaling and squaring) in both
        directions.

        Returns:
            (F, Fneg): displacement field of shape (*size, 3), and its inverse
            counterpart (None unless 'surface' is a task).
        """
        args = self.synth_args
        nonlin_scale = args.nonlin_scale_min + np.random.rand(1) * (args.nonlin_scale_max - args.nonlin_scale_min)
        size_F_small = np.round(nonlin_scale * np.array(self.size)).astype(int).tolist()
        if photo_mode:
            size_F_small[1] = np.round(self.size[1] / spac).astype(int)
        nonlin_std = args.nonlin_std_max * np.random.rand()
        Fsmall = nonlin_std * torch.randn([*size_F_small, 3], dtype=torch.float, device=self.device)
        F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small)
        if photo_mode:
            F[:, :, :, 1] = 0

        if 'surface' not in self.tasks:
            return F, None

        n_steps = args.n_steps_svf_integration
        steplength = 1.0 / (2.0 ** n_steps)
        Fsvf = F * steplength
        for _ in range(n_steps):
            Fsvf += fast_3D_interp_torch(Fsvf, self.xx + Fsvf[:, :, :, 0], self.yy + Fsvf[:, :, :, 1], self.zz + Fsvf[:, :, :, 2], 'linear')
        Fsvf_neg = -F * steplength
        for _ in range(n_steps):
            Fsvf_neg += fast_3D_interp_torch(Fsvf_neg, self.xx + Fsvf_neg[:, :, :, 0], self.yy + Fsvf_neg[:, :, :, 1], self.zz + Fsvf_neg[:, :, :, 2], 'linear')
        return Fsvf, Fsvf_neg

    def generate_deformation(self, setups, shp):
        """Draw affine (+ optional non-linear) deformation and build the sampling grid.

        Returns:
            dict with 'scaling_factor_distances', 'A', 'c2', 'F', 'Fneg' and
            'grid' = [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] (see ``deform_grid``).
        """
        scaling_factor_distances, A, c2 = self.random_affine_transform(shp)

        if self.synth_args.nonlinear_transform:
            F, Fneg = self.random_nonlinear_transform(setups['photo_mode'], setups['spac'])
        else:
            F, Fneg = None, None

        xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_grid(shp, A, c2, F)

        return {'scaling_factor_distances': scaling_factor_distances,
                'A': A,
                'c2': c2,
                'F': F,
                'Fneg': Fneg,
                'grid': [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2],
                }

    def get_left_hemis_mask(self, grid):
        """Set ``self.hemis_mask`` for the current item (None unless left_hemis_only).

        The mask is 1 where the (LUT-mapped) segmentation is non-zero AND the
        first registration map ('mni_reg.x') is negative — presumably the left
        hemisphere in MNI space. It lives in the cropped input space
        ``[x1:x2, y1:y2, z1:z2]``. Requires ``self.modalities`` (set by ``get_info``).
        """
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = grid

        if not self.synth_args.left_hemis_only:
            self.hemis_mask = None
            return

        S, aff, res = read_image(self.modalities['segmentation'])
        S = torch.squeeze(torch.from_numpy(S.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int))).to(self.device)
        S = self.lut[S.int()]
        X, aff, res = read_image(self.modalities['registration'][0])
        X = torch.squeeze(torch.from_numpy(X.get_fdata()[x1:x2, y1:y2, z1:z2])).to(self.device)
        self.hemis_mask = ((S > 0) & (X < 0)).int()

    def deform_grid(self, shp, A, c2, F):
        """Map the output grid into the input volume and crop to its bounding box.

        Returns:
            (xx2, yy2, zz2, x1, y1, z1, x2, y2, z2): sampling coordinates
            (output-grid shaped, clamped to the input volume and relative to the
            crop origin) and the integer crop bounds [x1:x2, y1:y2, z1:z2].
        """
        if F is not None:
            xx1 = self.xc + F[:, :, :, 0]
            yy1 = self.yc + F[:, :, :, 1]
            zz1 = self.zc + F[:, :, :, 2]
        else:
            xx1 = self.xc
            yy1 = self.yc
            zz1 = self.zc

        xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0]
        yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1]
        zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2]
        # Clamp the coordinates into the input volume.
        xx2[xx2 < 0] = 0
        yy2[yy2 < 0] = 0
        zz2[zz2 < 0] = 0
        xx2[xx2 > (shp[0] - 1)] = shp[0] - 1
        yy2[yy2 > (shp[1] - 1)] = shp[1] - 1
        zz2[zz2 > (shp[2] - 1)] = shp[2] - 1

        # Bounding box of the sampled region (upper bound exclusive).
        x1 = torch.floor(torch.min(xx2))
        y1 = torch.floor(torch.min(yy2))
        z1 = torch.floor(torch.min(zz2))
        x2 = 1 + torch.ceil(torch.max(xx2))
        y2 = 1 + torch.ceil(torch.max(yy2))
        z2 = 1 + torch.ceil(torch.max(zz2))
        # Make coordinates relative to the crop origin.
        xx2 -= x1
        yy2 -= y1
        zz2 -= z1

        x1 = x1.cpu().numpy().astype(int)
        y1 = y1.cpu().numpy().astype(int)
        z1 = z1.cpu().numpy().astype(int)
        x2 = x2.cpu().numpy().astype(int)
        y2 = y2.cpu().numpy().astype(int)
        z2 = z2.cpu().numpy().astype(int)

        return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2

    def augment_sample(self, name, I_def, setups, deform_dict, res, target, pathol_direction = None, input_mode = 'synth'):
        """Deform (if needed), encode pathology in, augment and normalize an input.

        Args:
            I_def: either an already-deformed tensor (synthetic path) or an
                image object exposing ``get_fdata()`` (real path), which is
                cropped, hemisphere-masked and deformed here.
            target: target dict; 'pathology'/'pathology_prob' are set to 0.
                when no pathology is present.

        Returns:
            dict with 'input' and, depending on ``self.tasks``,
            'high_res_residual' and 'bias_field_log' (each with a leading channel
            dim, flipped along axis 0 when ``setups['flip']``).
        """
        sample = {}
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

        if not isinstance(I_def, torch.Tensor):
            I_def = torch.squeeze(torch.tensor(I_def.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
            if self.hemis_mask is not None:
                I_def[self.hemis_mask == 0] = 0
            I_def = fast_3D_interp_torch(I_def, xx2, yy2, zz2, 'linear')

        if input_mode == 'CT':
            I_def = torch.clamp(I_def, min = 0., max = 80.)

        if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
            I_def = self.encode_pathology(I_def, target['pathology'], target['pathology_prob'], pathol_direction)
            I_def[I_def < 0.] = 0.
        else:
            target['pathology'] = 0.
            target['pathology_prob'] = 0.

        # Apply the configured augmentation steps in order; they exchange data via aux_dict.
        aux_dict = {}
        augmentation_steps = self.augmentation_steps['synth'] if input_mode == 'synth' else self.augmentation_steps['real']
        for func_name in augmentation_steps:
            I_def, aux_dict = augmentation_funcs[func_name](I = I_def, aux_dict = aux_dict, cfg = self.gen_args.generator,
                                                            input_mode = input_mode, setups = setups, size = self.size, res = res, device = self.device)

        # Resample back to the output size.
        if self.synth_args.bspline_zooming:
            I_def = interpol.resize(I_def, shape=self.size, anchor='edge', interpolation=3, bound='dct2', prefilter=True)
        else:
            I_def = myzoom_torch(I_def, 1 / aux_dict['factors'])

        maxi = torch.max(I_def)
        I_final = I_def / maxi

        if 'super_resolution' in self.tasks:
            SRresidual = aux_dict['high_res'] / maxi - I_final
            sample.update({'high_res_residual': torch.flip(SRresidual, [0])[None] if setups['flip'] else SRresidual[None]})

        sample.update({'input': torch.flip(I_final, [0])[None] if setups['flip'] else I_final[None]})
        if 'bias_field' in self.tasks and input_mode != 'CT':
            sample.update({'bias_field_log': torch.flip(aux_dict['BFlog'], [0])[None] if setups['flip'] else aux_dict['BFlog'][None]})

        return sample

    def generate_sample(self, name, G, setups, deform_dict, res, target):
        """Synthesize an image from generation labels ``G`` and augment it.

        Intensities are drawn per label from ``get_contrast`` in the cropped
        native space, then deformed to the output grid and optionally mixed
        with the deformed T1/T2/FLAIR targets.

        Returns:
            (target['pathology'], target['pathology_prob'], sample dict from ``augment_sample``)
        """
        [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']

        mus, sigmas = self.get_contrast(setups['photo_mode'])

        G = torch.squeeze(torch.tensor(G.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
        G[G == 77] = 2  # label 77 is merged into label 2
        if self.hemis_mask is not None:
            G[self.hemis_mask == 0] = 0
        Gr = torch.round(G).long()

        SYN = mus[Gr] + sigmas[Gr] * torch.randn(Gr.shape, dtype=torch.float, device=self.device)
        SYN[SYN < 0] = 0
        # Keep the native-space image: the label masks below (Gr) live in this
        # cropped, undeformed space, not on the deformed output grid.
        SYN_native = SYN
        SYN = fast_3D_interp_torch(SYN, xx2, yy2, zz2)

        # Random convex combination with the deformed real modalities.
        if np.random.rand() < self.gen_args.mix_synth_prob:
            v = torch.rand(4)
            v[2] = 0 if 'T2' not in self.modalities else v[2]
            v[3] = 0 if 'FLAIR' not in self.modalities else v[3]
            v /= torch.sum(v)
            SYN = v[0] * SYN + v[1] * target['T1'][0]
            if 'T2' in self.modalities:
                SYN += v[2] * target['T2'][0]
            if 'FLAIR' in self.modalities:
                SYN += v[3] * target['FLAIR'][0]

        if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
            # Previously this masked the *deformed* SYN with native-space Gr (shape
            # mismatch) and re-deformed an already deformed image.
            SYN_cerebral = SYN_native.clone()
            SYN_cerebral[Gr == 0] = 0
            SYN_cerebral = fast_3D_interp_torch(SYN_cerebral, xx2, yy2, zz2)[None]

            # Mean intensity of labels 2/41 vs. all other non-zero labels decides
            # whether pathology is rendered brighter or darker.
            wm_mask = (Gr == 2) | (Gr == 41)
            wm_mean = (SYN_native * wm_mask).sum() / wm_mask.sum()
            gm_mask = (Gr != 0) & (Gr != 2) & (Gr != 41)
            gm_mean = (SYN_native * gm_mask).sum() / gm_mask.sum()

            # Restrict pathology to voxels that are non-zero in the deformed cerebral image.
            target['pathology'][SYN_cerebral == 0] = 0
            target['pathology_prob'][SYN_cerebral == 0] = 0

            pathol_direction = self.get_pathology_direction('synth', gm_mean > wm_mean)
        else:
            pathol_direction = None
            target['pathology'] = 0.
            target['pathology_prob'] = 0.

        SYN[SYN < 0.] = 0.
        return target['pathology'], target['pathology_prob'], self.augment_sample(name, SYN, setups, deform_dict, res, target, pathol_direction = pathol_direction)

    def get_pathology_direction(self, input_mode, pathol_direction = None):
        """Return whether pathology should be brighter (True) or darker (False).

        An explicit ``pathol_direction`` wins; otherwise T1/CT -> False,
        T2/FLAIR -> True, anything else -> random.
        """
        if pathol_direction is not None:
            return pathol_direction
        if input_mode in ['T1', 'CT']:
            return False
        if input_mode in ['T2', 'FLAIR']:
            return True
        return random.choice([True, False])

    def get_contrast(self, photo_mode):
        """Draw per-label Gaussian intensity parameters for 256 labels.

        Returns:
            (mus, sigmas): float tensors of length 256. With probability
            ``ct_prob`` the means of the ``ct_brightness_group`` label groups are
            tied to shared values. Label 0 is set to 0 in photo mode or with
            probability 0.5. Labels 100-250 are linear blends of labels 1-4
            (presumably partial-volume labels — confirm against the label maps).
        """
        mus = 25 + 200 * torch.rand(256, dtype=torch.float, device=self.device)
        sigmas = 5 + 20 * torch.rand(256, dtype=torch.float, device=self.device)

        if np.random.rand() < self.synth_args.ct_prob:
            darker = 25 + 10 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['darker']:
                mus[l] = darker
            dark = 90 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['dark']:
                mus[l] = dark
            bright = 110 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['bright']:
                mus[l] = bright
            brighter = 150 + 50 * torch.rand(1, dtype=torch.float, device=self.device)[0]
            for l in ct_brightness_group['brighter']:
                mus[l] = brighter

        if photo_mode or np.random.rand(1) < 0.5:
            mus[0] = 0

        # Blend weights 0, 0.02, ..., 0.98 for each 50-label range.
        v = 0.02 * torch.arange(50).to(self.device)
        mus[100:150] = mus[1] * (1 - v) + mus[2] * v
        mus[150:200] = mus[2] * (1 - v) + mus[3] * v
        mus[200:250] = mus[3] * (1 - v) + mus[4] * v
        mus[250] = mus[4]
        sigmas[100:150] = torch.sqrt(sigmas[1]**2 * (1 - v) + sigmas[2]**2 * v)
        sigmas[150:200] = torch.sqrt(sigmas[2]**2 * (1 - v) + sigmas[3]**2 * v)
        sigmas[200:250] = torch.sqrt(sigmas[3]**2 * (1 - v) + sigmas[4]**2 * v)
        sigmas[250] = sigmas[4]

        return mus, sigmas

    def get_setup_params(self):
        """Draw the random per-item setup parameters.

        Returns:
            dict with 'resolution', 'thickness', 'photo_mode', 'pathol_mode',
            'pathol_random_shape', 'spac' (None unless photo mode), 'flip' and
            'hemis' ('left' or 'both').
        """
        args = self.synth_args
        hemis = 'left' if args.left_hemis_only else 'both'

        if args.low_res_only:
            photo_mode = False
        elif args.left_hemis_only:
            photo_mode = True
        else:
            photo_mode = np.random.rand() < args.photo_prob

        pathol_mode = np.random.rand() < args.pathology_prob
        pathol_random_shape = np.random.rand() < args.random_shape_prob
        spac = 2.5 + 10 * np.random.rand() if photo_mode else None
        # Uniform draw so that flip_prob really is the flip probability
        # (the previous standard-normal draw made it e.g. ~0.5 for flip_prob=0).
        flip = (np.random.rand() < args.flip_prob) if not args.left_hemis_only else False

        if photo_mode:
            resolution = np.array([self.res_training_data[0], spac, self.res_training_data[2]])
            thickness = np.array([self.res_training_data[0], 0.1, self.res_training_data[2]])
        else:
            resolution, thickness = resolution_sampler(args.low_res_only)
        return {'resolution': resolution, 'thickness': thickness,
                'photo_mode': photo_mode, 'pathol_mode': pathol_mode,
                'pathol_random_shape': pathol_random_shape,
                'spac': spac, 'flip': flip, 'hemis': hemis}

    def encode_pathology(self, I, P, Pprob, pathol_direction = None):
        """Add random pathology intensities to ``I`` inside the pathology mask.

        The added intensity scales with the mean of ``I`` over ``P`` and is
        positive when ``pathol_direction`` is truthy, negative otherwise
        (random when None). ``I`` is modified in place, clipped at 0 and returned.
        Callers only pass a ``P`` with a non-zero sum.
        """
        if pathol_direction is None:
            pathol_direction = random.choice([True, False])

        P, Pprob = torch.squeeze(P), torch.squeeze(Pprob)
        I_mu = (I * P).sum() / P.sum()

        p_mask = torch.round(P).long()

        # Random mean/std tables indexed by the rounded mask value.
        pth_mus = 3*I_mu/4 + I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device)
        pth_mus = pth_mus if pathol_direction else -pth_mus
        pth_sigmas = I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device)
        I += Pprob * (pth_mus[p_mask] + pth_sigmas[p_mask] * torch.randn(p_mask.shape, dtype=torch.float, device=self.device))
        I[I < 0] = 0

        return I

    def get_info(self, t1):
        """Derive all per-subject file paths from the T1 path and set ``self.modalities``.

        ``t1`` is expected to end with 'T1w.nii'; the other paths replace that
        7-character suffix. 'T1', 'Gen', 'segmentation', 'distance' and
        'registration' are always present; the optional modalities and
        defacing masks are added only when the file exists.

        Returns:
            ``self.modalities``.
        """
        base = t1[:-7]
        self.modalities = {'T1': t1,
                           'Gen': base + 'generation_labels.nii',
                           'segmentation': base + self.gen_args.segment_prefix + '.nii',
                           'distance': [base + 'lp_dist_map.nii', base + 'lw_dist_map.nii',
                                        base + 'rp_dist_map.nii', base + 'rw_dist_map.nii'],
                           'registration': [base + 'mni_reg.x.nii', base + 'mni_reg.y.nii', base + 'mni_reg.z.nii']}

        optional = [('T1_DM', 'T1w.defacingmask.nii'),
                    ('T2', 'T2w.nii'),
                    ('T2_DM', 'T2w.defacingmask.nii'),
                    ('FLAIR', 'FLAIR.nii'),
                    ('FLAIR_DM', 'FLAIR.defacingmask.nii'),
                    ('CT', 'CT.nii'),
                    ('CT_DM', 'CT.defacingmask.nii')]
        for key, suffix in optional:
            path = base + suffix
            if os.path.isfile(path):
                self.modalities[key] = path

        return self.modalities

    def read_input(self, idx):
        """
        determine input type according to prob (in generator/constants.py)

        Logic: if np.random.rand() < real_image_prob and is real_image_exist --> input real images; otherwise, synthesize images.
        One uniform draw is compared against the per-dataset thresholds for
        T1, T2, FLAIR and CT in that order; the first modality whose threshold
        exceeds the draw and whose file exists is used, otherwise 'synth'
        (generation labels are read).

        Returns:
            (dataset_name, case_name, input_mode, img, aff, res, age)
        """
        dataset_name, input_prob, t1_path, age = self.idx_to_path(idx)
        case_name = os.path.basename(t1_path).split('.T1w.nii')[0]
        self.modalities = self.get_info(t1_path)

        prob = np.random.rand()
        if prob < input_prob['T1'] and 'T1' in self.modalities:
            input_mode = 'T1'
            img, aff, res = read_image(self.modalities['T1'])
        elif prob < input_prob['T2'] and 'T2' in self.modalities:
            input_mode = 'T2'
            img, aff, res = read_image(self.modalities['T2'])
        elif prob < input_prob['FLAIR'] and 'FLAIR' in self.modalities:
            input_mode = 'FLAIR'
            img, aff, res = read_image(self.modalities['FLAIR'])
        elif prob < input_prob['CT'] and 'CT' in self.modalities:
            input_mode = 'CT'
            img, aff, res = read_image(self.modalities['CT'])
        else:
            input_mode = 'synth'
            img, aff, res = read_image(self.modalities['Gen'])

        return dataset_name, case_name, input_mode, img, aff, res, age

    def read_and_deform_target(self, idx, exist_keys, task_name, input_mode, setups, deform_dict, linear_weights = None):
        """Read and deform one target via ``processing_funcs[task_name]``.

        For 'pathology', a pathology probability source ('random_shape', one of
        ``pathology_prob_paths``, or None when ``setups['pathol_mode']`` is off)
        is chosen first. Other tasks are processed only if present in
        ``self.modalities``; otherwise ``{task_name: 0.}`` is returned.

        Returns:
            dict of target entries to merge into the item's target dict.
        """
        p_prob_path, augment, thres = None, False, 0.1

        if task_name == 'pathology':
            # self.pathology_type is set to None in prepare_paths.
            if self.pathology_type is None and setups['pathol_mode']:
                if setups['pathol_random_shape']:
                    p_prob_path = 'random_shape'
                    augment, thres = False, self.shape_gen_args.pathol_thres
                else:
                    p_prob_path = random.choice(pathology_prob_paths)
                    augment, thres = self.synth_args.augment_pathology, self.shape_gen_args.pathol_thres

            return processing_funcs[task_name](exist_keys, task_name, p_prob_path, setups, deform_dict, self.device,
                                               mask = self.hemis_mask,
                                               augment = augment,
                                               pde_func = self.adv_pde,
                                               t = self.t,
                                               shape_gen_args = self.shape_gen_args,
                                               thres = thres
                                               )

        if task_name in self.modalities:
            return processing_funcs[task_name](exist_keys, task_name, self.modalities[task_name],
                                               setups, deform_dict, self.device,
                                               mask = self.hemis_mask,
                                               cfg = self.gen_args,
                                               onehotmatrix = self.onehotmatrix,
                                               lut = self.lut, vflip = self.vflip
                                               )
        return {task_name: 0.}

    def update_gen_args(self, new_args):
        """Copy every attribute of ``new_args`` onto ``self.gen_args.generator``.

        ``self.synth_args`` is the same object, so it sees the new values too.
        """
        for key, value in vars(new_args).items():
            vars(self.gen_args.generator)[key] = value

    def __getitem__(self, idx):
        """Build one item.

        Returns:
            (datasets_num, dataset_name, input_mode, target, sample)
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Input volume: a real modality or generation labels.
        dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)

        # Random setup (resolution, photo/pathology mode, flip, ...).
        setups = self.get_setup_params()

        # Random deformation and cropped sampling grid.
        deform_dict = self.generate_deformation(setups, img.shape)

        # Sets self.hemis_mask (None unless left_hemis_only).
        self.get_left_hemis_mask(deform_dict['grid'])

        # T1/T2/FLAIR are read first: generate_sample may mix them into the synthetic input.
        target = defaultdict(lambda: None)
        target['name'] = case_name
        for modality in ['T1', 'T2', 'FLAIR']:
            target.update(self.read_and_deform_target(idx, target.keys(), modality, input_mode, setups, deform_dict))
        for task_name in self.tasks:
            if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
                target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))

        if input_mode == 'synth':
            self.update_gen_args(self.synth_image_args)
            target['pathology'], target['pathology_prob'], sample = \
                self.generate_sample(case_name, img, setups, deform_dict, res, target)
        else:
            self.update_gen_args(self.real_image_args)
            sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                                         pathol_direction = self.get_pathology_direction(input_mode), input_mode = input_mode)

        if setups['flip'] and isinstance(target['pathology'], torch.Tensor):
            target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])

        if age is not None:
            target['age'] = age

        return self.datasets_num, dataset_name, input_mode, target, sample
|
|
|
|
|
|
|
|
|
|
|
class BrainIDGen(BaseGen):
    """
    BrainIDGen dataset

    BrainIDGen enables intra-subject augmentation, i.e., each subject will have multiple augmentations

    All samples of one item share the same input volume, setup parameters,
    deformation and targets. The first ``mild_samples`` samples are produced
    with the mild generator settings and the remaining ones (up to
    ``all_samples``) with the severe settings.
    """

    def __init__(self, gen_args, device='cpu'):
        super(BrainIDGen, self).__init__(gen_args, device)

        self.all_samples = gen_args.generator.all_samples
        self.mild_samples = gen_args.generator.mild_samples
        self.mild_generator_args = gen_args.mild_generator
        self.severe_generator_args = gen_args.severe_generator

    def _make_sample(self, case_name, input_mode, img, setups, deform_dict, res, target):
        """Produce one synthesized ('synth') or augmented (real) sample.

        Applies the synth/real image settings on top of the settings already
        set by the caller. For 'synth', also writes the updated
        'pathology'/'pathology_prob' back into ``target``.
        """
        if input_mode == 'synth':
            self.update_gen_args(self.synth_image_args)
            target['pathology'], target['pathology_prob'], sample = \
                self.generate_sample(case_name, img, setups, deform_dict, res, target)
        else:
            self.update_gen_args(self.real_image_args)
            sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                                         pathol_direction = self.get_pathology_direction(input_mode), input_mode = input_mode)
        return sample

    def __getitem__(self, idx):
        """Build one item with ``self.all_samples`` augmented samples.

        Returns:
            (datasets_num, dataset_name, input_mode, target, samples)
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Input volume: a real modality or generation labels.
        dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)

        # Random setup, shared by all samples of this item.
        setups = self.get_setup_params()

        # Random deformation, shared by all samples of this item.
        deform_dict = self.generate_deformation(setups, img.shape)

        # Sets self.hemis_mask (None unless left_hemis_only).
        self.get_left_hemis_mask(deform_dict['grid'])

        # T1/T2/FLAIR are read first: generate_sample may mix them into the synthetic input.
        target = defaultdict(lambda: 1.)
        target['name'] = case_name
        for modality in ['T1', 'T2', 'FLAIR']:
            target.update(self.read_and_deform_target(idx, target.keys(), modality, input_mode, setups, deform_dict))
        for task_name in self.tasks:
            if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
                target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))

        samples = []
        for i_sample in range(self.all_samples):
            # Mild settings for the first mild_samples samples, severe for the rest;
            # the synth/real settings are applied on top inside _make_sample.
            if i_sample < self.mild_samples:
                self.update_gen_args(self.mild_generator_args)
            else:
                self.update_gen_args(self.severe_generator_args)
            samples.append(self._make_sample(case_name, input_mode, img, setups, deform_dict, res, target))

        if setups['flip'] and isinstance(target['pathology'], torch.Tensor):
            target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])

        if age is not None:
            target['age'] = age
        return self.datasets_num, dataset_name, input_mode, target, samples