# BrainFM/Generator/datasets.py
import os, sys, glob
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import defaultdict
import random
import torch
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
from .utils import *
from .constants import n_pathology, pathology_paths, pathology_prob_paths, \
n_neutral_labels_brainseg_with_extracerebral, label_list_segmentation_brainseg_with_extracerebral, \
label_list_segmentation_brainseg_left, augmentation_funcs, processing_funcs
import utils.interpol as interpol
from utils.misc import viewVolume
from ShapeID.DiffEqs.pde import AdvDiffPDE
class BaseGen(Dataset):
"""
BaseGen dataset
"""
def __init__(self, gen_args, device='cpu'):
self.gen_args = gen_args
self.split = gen_args.split
self.synth_args = self.gen_args.generator
self.shape_gen_args = gen_args.pathology_shape_generator
self.real_image_args = gen_args.real_image_generator
self.synth_image_args = gen_args.synth_image_generator
self.augmentation_steps = vars(gen_args.augmentation_steps)
self.input_prob = vars(gen_args.modality_probs)
self.device = device
self.prepare_tasks()
self.prepare_paths()
self.prepare_grid()
self.prepare_one_hot()
def __len__(self):
        return sum(self.datasets_len)
def idx_to_path(self, idx):
cnt = 0
for i, l in enumerate(self.datasets_len):
            if cnt <= idx < cnt + l:
                dataset_name = self.datasets[i]
                age = self.ages[i][os.path.basename(self.names[i][idx - cnt]).split('.T1w')[0]] if len(self.ages) > 0 else None
                return dataset_name, vars(self.input_prob[dataset_name]), self.names[i][idx - cnt], age
            cnt += l
def prepare_paths(self):
# Collect list of available images, per dataset
if len(self.gen_args.dataset_names) < 1:
datasets = []
            g = glob.glob(os.path.join(self.gen_args.data_root, '*T1w.nii'))
            for path in g:
                filename = os.path.basename(path)
                dataset = filename[:filename.find('.')]
                if dataset not in datasets:
                    datasets.append(dataset)
            print('Found %d datasets with %d scans in total' % (len(datasets), len(g)))
else:
datasets = self.gen_args.dataset_names
print('Dataset list', datasets)
names = []
if 'age' in self.tasks:
self.split = self.split + '_age'
if self.gen_args.split_root is not None:
            with open(os.path.join(self.gen_args.split_root, self.split + '.txt'), 'r') as split_file:
                split_names = [subj.strip() for subj in split_file.readlines()]
            for i in range(len(datasets)):
                names.append([name for name in split_names if os.path.basename(name).startswith(datasets[i])])
#else:
# for i in range(len(datasets)):
# names.append(glob.glob(os.path.join(self.gen_args.data_root, datasets[i] + '.*' + 'T1w.nii')))
# read brain age
ages = []
if 'age' in self.tasks:
            with open(os.path.join(self.gen_args.split_root, 'participants_age.txt'), 'r') as age_file:
                subj_name_age = [line.strip().split(' ') for line in age_file.readlines()]  # each line: 'subj age'
for i in range(len(datasets)):
ages.append({})
for [name, age] in subj_name_age:
if name.startswith(datasets[i]):
ages[-1][name] = float(age)
print('Age info', self.split, len(ages[0].items()), min(ages[0].values()), max(ages[0].values()))
self.ages = ages
self.names = names
self.datasets = datasets
self.datasets_num = len(datasets)
self.datasets_len = [len(self.names[i]) for i in range(len(self.names))]
print('Num of data', sum([len(self.names[i]) for i in range(len(self.names))]))
self.pathology_type = None #setup_dict['pathology_type']
def prepare_tasks(self):
self.tasks = [key for (key, value) in vars(self.gen_args.task).items() if value]
if 'bias_field' in self.tasks and 'segmentation' not in self.tasks:
# add segmentation mask for computing bias_field_soft_mask
self.tasks += ['segmentation']
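        # When pathology augmentation is enabled, pre-build an advection PDE (divergence-free
        # velocity field) that the pathology processing function later uses to deform sampled
        # pathology probability maps into more varied lesion shapes.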
if 'pathology' in self.tasks and self.synth_args.augment_pathology and self.synth_args.random_shape_prob < 1.:
self.t = torch.from_numpy(np.arange(self.shape_gen_args.max_nt) * self.shape_gen_args.dt).to(self.device)
with torch.no_grad():
self.adv_pde = AdvDiffPDE(data_spacing=[1., 1., 1.],
perf_pattern='adv',
V_type='vector_div_free',
V_dict={},
BC=self.shape_gen_args.bc,
dt=self.shape_gen_args.dt,
device=self.device
)
else:
self.t, self.adv_pde = None, None
for task_name in self.tasks:
if task_name not in processing_funcs.keys():
print('Warning: Function for task "%s" not found' % task_name)
def prepare_grid(self):
self.size = self.synth_args.size
# Get resolution of training data
#aff = nib.load(os.path.join(self.modalities['Gen'], self.names[0])).affine
#self.res_training_data = np.sqrt(np.sum(abs(aff[:-1, :-1]), axis=0))
self.res_training_data = np.array([1.0, 1.0, 1.0])
xx, yy, zz = np.meshgrid(range(self.size[0]), range(self.size[1]), range(self.size[2]), sparse=False, indexing='ij')
self.xx = torch.tensor(xx, dtype=torch.float, device=self.device)
self.yy = torch.tensor(yy, dtype=torch.float, device=self.device)
self.zz = torch.tensor(zz, dtype=torch.float, device=self.device)
self.c = torch.tensor((np.array(self.size) - 1) / 2, dtype=torch.float, device=self.device)
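        # Centered voxel coordinates of the output grid; deformations are composed in this
        # centered frame (see deform_grid) before shifting to the sampled source center c2.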
self.xc = self.xx - self.c[0]
self.yc = self.yy - self.c[1]
self.zc = self.zz - self.c[2]
return
def prepare_one_hot(self):
if self.synth_args.left_hemis_only:
n_labels = len(label_list_segmentation_brainseg_left)
label_list_segmentation = label_list_segmentation_brainseg_left
else:
# Matrix for one-hot encoding (includes a lookup-table)
n_labels = len(label_list_segmentation_brainseg_with_extracerebral)
label_list_segmentation = label_list_segmentation_brainseg_with_extracerebral
self.lut = torch.zeros(10000, dtype=torch.long, device=self.device)
for l in range(n_labels):
self.lut[label_list_segmentation[l]] = l
self.onehotmatrix = torch.eye(n_labels, dtype=torch.float, device=self.device)
# useless for left_hemis_only
nlat = int((n_labels - n_neutral_labels_brainseg_with_extracerebral) / 2.0)
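        # Permutation mapping each label index to its contralateral counterpart: neutral labels
        # keep their position and the two lateral label blocks are swapped. Passed to the
        # segmentation processing function to relabel when the volume is randomly left-right flipped.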
self.vflip = np.concatenate([np.array(range(n_neutral_labels_brainseg_with_extracerebral)),
np.array(range(n_neutral_labels_brainseg_with_extracerebral + nlat, n_labels)),
np.array(range(n_neutral_labels_brainseg_with_extracerebral, n_neutral_labels_brainseg_with_extracerebral + nlat))])
return
def random_affine_transform(self, shp):
rotations = (2 * self.synth_args.max_rotation * np.random.rand(3) - self.synth_args.max_rotation) / 180.0 * np.pi
shears = (2 * self.synth_args.max_shear * np.random.rand(3) - self.synth_args.max_shear)
scalings = 1 + (2 * self.synth_args.max_scaling * np.random.rand(3) - self.synth_args.max_scaling)
        scaling_factor_distances = np.prod(scalings) ** (1.0 / 3.0)
A = torch.tensor(make_affine_matrix(rotations, shears, scalings), dtype=torch.float, device=self.device)
# sample center
if self.synth_args.random_shift:
max_shift = (torch.tensor(np.array(shp[0:3]) - self.size, dtype=torch.float, device=self.device)) / 2
max_shift[max_shift < 0] = 0
            c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device) + (2 * (max_shift * torch.rand(3, dtype=torch.float, device=self.device)) - max_shift)
else:
c2 = torch.tensor((np.array(shp[0:3]) - 1)/2, dtype=torch.float, device=self.device)
return scaling_factor_distances, A, c2
def random_nonlinear_transform(self, photo_mode, spac):
nonlin_scale = self.synth_args.nonlin_scale_min + np.random.rand(1) * (self.synth_args.nonlin_scale_max - self.synth_args.nonlin_scale_min)
size_F_small = np.round(nonlin_scale * np.array(self.size)).astype(int).tolist()
if photo_mode:
size_F_small[1] = np.round(self.size[1]/spac).astype(int)
nonlin_std = self.synth_args.nonlin_std_max * np.random.rand()
Fsmall = nonlin_std * torch.randn([*size_F_small, 3], dtype=torch.float, device=self.device)
F = myzoom_torch(Fsmall, np.array(self.size) / size_F_small)
if photo_mode:
F[:, :, :, 1] = 0
if 'surface' in self.tasks: # TODO need to integrate the non-linear deformation fields for inverse
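            # Treat F as a stationary velocity field and integrate it by repeated self-composition
            # (scaling-and-squaring): F becomes the forward deformation, and integrating -F gives
            # an approximate inverse (Fneg) for mapping surface targets back.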
steplength = 1.0 / (2.0 ** self.synth_args.n_steps_svf_integration)
Fsvf = F * steplength
for _ in range(self.synth_args.n_steps_svf_integration):
Fsvf += fast_3D_interp_torch(Fsvf, self.xx + Fsvf[:, :, :, 0], self.yy + Fsvf[:, :, :, 1], self.zz + Fsvf[:, :, :, 2], 'linear')
Fsvf_neg = -F * steplength
for _ in range(self.synth_args.n_steps_svf_integration):
Fsvf_neg += fast_3D_interp_torch(Fsvf_neg, self.xx + Fsvf_neg[:, :, :, 0], self.yy + Fsvf_neg[:, :, :, 1], self.zz + Fsvf_neg[:, :, :, 2], 'linear')
F = Fsvf
Fneg = Fsvf_neg
else:
Fneg = None
return F, Fneg
def generate_deformation(self, setups, shp):
# generate affine deformation
scaling_factor_distances, A, c2 = self.random_affine_transform(shp)
# generate nonlinear deformation
if self.synth_args.nonlinear_transform:
F, Fneg = self.random_nonlinear_transform(setups['photo_mode'], setups['spac'])
else:
F, Fneg = None, None
# deform the image grid
xx2, yy2, zz2, x1, y1, z1, x2, y2, z2 = self.deform_grid(shp, A, c2, F)
return {'scaling_factor_distances': scaling_factor_distances,
'A': A,
'c2': c2,
'F': F,
'Fneg': Fneg,
'grid': [xx2, yy2, zz2, x1, y1, z1, x2, y2, z2],
}
def get_left_hemis_mask(self, grid):
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = grid
if self.synth_args.left_hemis_only:
S, aff, res = read_image(self.modalities['segmentation']) # read seg map
S = torch.squeeze(torch.from_numpy(S.get_fdata()[x1:x2, y1:y2, z1:z2].astype(int))).to(self.device)
S = self.lut[S.int()] # mask out non-left labels
X, aff, res = read_image(self.modalities['registration'][0]) # read_mni_coord_X
X = torch.squeeze(torch.from_numpy(X.get_fdata()[x1:x2, y1:y2, z1:z2])).to(self.device)
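            # Left hemisphere mask: labelled voxels whose MNI x-coordinate is negative (x < 0 is left by MNI convention).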
self.hemis_mask = ((S > 0) & (X < 0)).int()
else:
self.hemis_mask = None
def deform_grid(self, shp, A, c2, F):
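        # Compose the deformation in the centered output frame: add the nonlinear displacement F
        # (if any), apply the affine A, then shift to the sampled source center c2. The result
        # (xx2, yy2, zz2) gives, for every output voxel, the source coordinate to sample, and
        # x1..z2 is the bounding box used to crop the source volume before interpolation.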
if F is not None:
            # apply the nonlinear displacement before the affine, so heavy coronal deformations remain possible in photo mode
xx1 = self.xc + F[:, :, :, 0]
yy1 = self.yc + F[:, :, :, 1]
zz1 = self.zc + F[:, :, :, 2]
else:
xx1 = self.xc
yy1 = self.yc
zz1 = self.zc
xx2 = A[0, 0] * xx1 + A[0, 1] * yy1 + A[0, 2] * zz1 + c2[0]
yy2 = A[1, 0] * xx1 + A[1, 1] * yy1 + A[1, 2] * zz1 + c2[1]
zz2 = A[2, 0] * xx1 + A[2, 1] * yy1 + A[2, 2] * zz1 + c2[2]
xx2[xx2 < 0] = 0
yy2[yy2 < 0] = 0
zz2[zz2 < 0] = 0
xx2[xx2 > (shp[0] - 1)] = shp[0] - 1
yy2[yy2 > (shp[1] - 1)] = shp[1] - 1
zz2[zz2 > (shp[2] - 1)] = shp[2] - 1
# Get the margins for reading images
x1 = torch.floor(torch.min(xx2))
y1 = torch.floor(torch.min(yy2))
z1 = torch.floor(torch.min(zz2))
        x2 = 1 + torch.ceil(torch.max(xx2))
y2 = 1 + torch.ceil(torch.max(yy2))
z2 = 1 + torch.ceil(torch.max(zz2))
xx2 -= x1
yy2 -= y1
zz2 -= z1
x1 = x1.cpu().numpy().astype(int)
y1 = y1.cpu().numpy().astype(int)
z1 = z1.cpu().numpy().astype(int)
x2 = x2.cpu().numpy().astype(int)
y2 = y2.cpu().numpy().astype(int)
z2 = z2.cpu().numpy().astype(int)
return xx2, yy2, zz2, x1, y1, z1, x2, y2, z2
def augment_sample(self, name, I_def, setups, deform_dict, res, target, pathol_direction = None, input_mode = 'synth'):
sample = {}
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
if not isinstance(I_def, torch.Tensor):
I_def = torch.squeeze(torch.tensor(I_def.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
if self.hemis_mask is not None:
I_def[self.hemis_mask == 0] = 0
        # Resample the image onto the deformed grid
I_def = fast_3D_interp_torch(I_def, xx2, yy2, zz2, 'linear')
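        # Clamp CT intensities to [0, 80], roughly a brain soft-tissue window (assuming Hounsfield units).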
if input_mode == 'CT':
I_def = torch.clamp(I_def, min = 0., max = 80.)
if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
I_def = self.encode_pathology(I_def, target['pathology'], target['pathology_prob'], pathol_direction)
I_def[I_def < 0.] = 0.
else:
target['pathology'] = 0.
target['pathology_prob'] = 0.
# Augment sample
aux_dict = {}
augmentation_steps = self.augmentation_steps['synth'] if input_mode == 'synth' else self.augmentation_steps['real']
for func_name in augmentation_steps:
I_def, aux_dict = augmentation_funcs[func_name](I = I_def, aux_dict = aux_dict, cfg = self.gen_args.generator,
input_mode = input_mode, setups = setups, size = self.size, res = res, device = self.device)
# Back to original resolution
if self.synth_args.bspline_zooming:
I_def = interpol.resize(I_def, shape=self.size, anchor='edge', interpolation=3, bound='dct2', prefilter=True)
else:
I_def = myzoom_torch(I_def, 1 / aux_dict['factors'])
maxi = torch.max(I_def)
I_final = I_def / maxi
if 'super_resolution' in self.tasks:
SRresidual = aux_dict['high_res'] / maxi - I_final
sample.update({'high_res_residual': torch.flip(SRresidual, [0])[None] if setups['flip'] else SRresidual[None]})
sample.update({'input': torch.flip(I_final, [0])[None] if setups['flip'] else I_final[None]})
if 'bias_field' in self.tasks and input_mode != 'CT':
sample.update({'bias_field_log': torch.flip(aux_dict['BFlog'], [0])[None] if setups['flip'] else aux_dict['BFlog'][None]})
return sample
def generate_sample(self, name, G, setups, deform_dict, res, target):
[xx2, yy2, zz2, x1, y1, z1, x2, y2, z2] = deform_dict['grid']
# Generate contrasts
mus, sigmas = self.get_contrast(setups['photo_mode'])
G = torch.squeeze(torch.tensor(G.get_fdata()[x1:x2, y1:y2, z1:z2].astype(float), dtype=torch.float, device=self.device))
#G[G > 255] = 0 # kill extracerebral regions
G[G == 77] = 2 # merge WM lesion to white matter region
if self.hemis_mask is not None:
G[self.hemis_mask == 0] = 0
Gr = torch.round(G).long()
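        # Sample the synthetic image from label-conditioned Gaussians: each generation label
        # gets a random mean/std drawn in get_contrast().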
SYN = mus[Gr] + sigmas[Gr] * torch.randn(Gr.shape, dtype=torch.float, device=self.device)
SYN[SYN < 0] = 0
#SYN /= mus[2] # normalize by WM
#SYN = gaussian_blur_3d(SYN, 0.5*np.ones(3), self.device) # cosmetic
SYN = fast_3D_interp_torch(SYN, xx2, yy2, zz2)
# Make random linear combinations
if np.random.rand() < self.gen_args.mix_synth_prob:
v = torch.rand(4)
v[2] = 0 if 'T2' not in self.modalities else v[2]
v[3] = 0 if 'FLAIR' not in self.modalities else v[3]
v /= torch.sum(v)
SYN = v[0] * SYN + v[1] * target['T1'][0]
if 'T2' in self.modalities:
SYN += v[2] * target['T2'][0]
if 'FLAIR' in self.modalities:
SYN += v[3] * target['FLAIR'][0]
if 'pathology' in target and isinstance(target['pathology'], torch.Tensor) and target['pathology'].sum() > 0:
SYN_cerebral = SYN.clone()
SYN_cerebral[Gr == 0] = 0
SYN_cerebral = fast_3D_interp_torch(SYN_cerebral, xx2, yy2, zz2)[None]
wm_mask = (Gr==2) | (Gr==41)
wm_mean = (SYN * wm_mask).sum() / wm_mask.sum()
gm_mask = (Gr!=0) & (Gr!=2) & (Gr!=41)
gm_mean = (SYN * gm_mask).sum() / gm_mask.sum()
target['pathology'][SYN_cerebral == 0] = 0
target['pathology_prob'][SYN_cerebral == 0] = 0
            # Decide whether the encoded lesion should look T1-like or T2/FLAIR-like:
            # if GM is brighter than WM in this synthetic contrast (T2-like), make the lesion brighter than WM.
            # pathol_direction: True = T2/FLAIR-like (hyperintense), False = T1-like (hypointense)
pathol_direction = self.get_pathology_direction('synth', gm_mean > wm_mean)
else:
pathol_direction = None
target['pathology'] = 0.
target['pathology_prob'] = 0.
SYN[SYN < 0.] = 0.
return target['pathology'], target['pathology_prob'], self.augment_sample(name, SYN, setups, deform_dict, res, target, pathol_direction = pathol_direction)
def get_pathology_direction(self, input_mode, pathol_direction = None):
#if np.random.rand() < 0.1: # in some (rare) cases, randomly pick the direction
# return random.choice([True, False])
if pathol_direction is not None: # for synth image
return pathol_direction
if input_mode in ['T1', 'CT']:
return False
if input_mode in ['T2', 'FLAIR']:
return True
return random.choice([True, False])
def get_contrast(self, photo_mode):
# Sample Gaussian image
mus = 25 + 200 * torch.rand(256, dtype=torch.float, device=self.device)
sigmas = 5 + 20 * torch.rand(256, dtype=torch.float, device=self.device)
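        # With probability ct_prob, impose CT-like contrast by fixing the brightness ordering
        # of the tissue groups (darker < dark < bright < brighter).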
if np.random.rand() < self.synth_args.ct_prob:
darker = 25 + 10 * torch.rand(1, dtype=torch.float, device=self.device)[0]
for l in ct_brightness_group['darker']:
mus[l] = darker
dark = 90 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
for l in ct_brightness_group['dark']:
mus[l] = dark
bright = 110 + 20 * torch.rand(1, dtype=torch.float, device=self.device)[0]
for l in ct_brightness_group['bright']:
mus[l] = bright
brighter = 150 + 50 * torch.rand(1, dtype=torch.float, device=self.device)[0]
for l in ct_brightness_group['brighter']:
mus[l] = brighter
if photo_mode or np.random.rand(1)<0.5: # set the background to zero every once in a while (or always in photo mode)
mus[0] = 0
# partial volume
# 1 = lesion, 2 = WM, 3 = GM, 4 = CSF
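        # Label values 100-249 encode partial-volume mixtures in 2% steps:
        # 100-149 lesion->WM, 150-199 WM->GM, 200-249 GM->CSF; 250 is pure CSF.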
v = 0.02 * torch.arange(50).to(self.device)
mus[100:150] = mus[1] * (1 - v) + mus[2] * v
mus[150:200] = mus[2] * (1 - v) + mus[3] * v
mus[200:250] = mus[3] * (1 - v) + mus[4] * v
mus[250] = mus[4]
sigmas[100:150] = torch.sqrt(sigmas[1]**2 * (1 - v) + sigmas[2]**2 * v)
sigmas[150:200] = torch.sqrt(sigmas[2]**2 * (1 - v) + sigmas[3]**2 * v)
sigmas[200:250] = torch.sqrt(sigmas[3]**2 * (1 - v) + sigmas[4]**2 * v)
sigmas[250] = sigmas[4]
return mus, sigmas
def get_setup_params(self):
if self.synth_args.left_hemis_only:
hemis = 'left'
else:
hemis = 'both'
if self.synth_args.low_res_only:
photo_mode = False
elif self.synth_args.left_hemis_only:
photo_mode = True
else:
photo_mode = np.random.rand() < self.synth_args.photo_prob
pathol_mode = np.random.rand() < self.synth_args.pathology_prob
pathol_random_shape = np.random.rand() < self.synth_args.random_shape_prob
spac = 2.5 + 10 * np.random.rand() if photo_mode else None
        flip = np.random.rand() < self.synth_args.flip_prob if not self.synth_args.left_hemis_only else False
if photo_mode:
resolution = np.array([self.res_training_data[0], spac, self.res_training_data[2]])
thickness = np.array([self.res_training_data[0], 0.1, self.res_training_data[2]])
else:
resolution, thickness = resolution_sampler(self.synth_args.low_res_only)
return {'resolution': resolution, 'thickness': thickness,
'photo_mode': photo_mode, 'pathol_mode': pathol_mode,
'pathol_random_shape': pathol_random_shape,
'spac': spac, 'flip': flip, 'hemis': hemis}
def encode_pathology(self, I, P, Pprob, pathol_direction = None):
if pathol_direction is None: # True: T2/FLAIR-resembled, False: T1-resembled
pathol_direction = random.choice([True, False])
P, Pprob = torch.squeeze(P), torch.squeeze(Pprob)
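        # Reference intensity: mean of the image under the pathology mask; lesion offsets are
        # drawn relative to it and added (T2-like) or subtracted (T1-like), weighted by the
        # pathology probability map Pprob.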
I_mu = (I * P).sum() / P.sum()
p_mask = torch.round(P).long()
#pth_mus = I_mu/4 + I_mu/2 * torch.rand(10000, dtype=torch.float, device=self.device)
pth_mus = 3*I_mu/4 + I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device) # enforce the pathology pattern harder!
pth_mus = pth_mus if pathol_direction else -pth_mus
pth_sigmas = I_mu/4 * torch.rand(10000, dtype=torch.float, device=self.device)
I += Pprob * (pth_mus[p_mask] + pth_sigmas[p_mask] * torch.randn(p_mask.shape, dtype=torch.float, device=self.device))
I[I < 0] = 0
#print('encode', P.shape, P.mean())
#print('pre', I_mu)
#I_mu = (I * P).sum() / P.sum()
#print('post', I_mu)
return I
def get_info(self, t1):
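        # Derive sibling file paths from the T1 path by swapping the 'T1w.nii' suffix;
        # optional modalities and defacing masks are only registered if the file exists on disk.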
t1dm = t1[:-7] + 'T1w.defacingmask.nii'
t2 = t1[:-7] + 'T2w.nii'
t2dm = t1[:-7] + 'T2w.defacingmask.nii'
flair = t1[:-7] + 'FLAIR.nii'
flairdm = t1[:-7] + 'FLAIR.defacingmask.nii'
ct = t1[:-7] + 'CT.nii'
ctdm = t1[:-7] + 'CT.defacingmask.nii'
generation_labels = t1[:-7] + 'generation_labels.nii'
segmentation_labels = t1[:-7] + self.gen_args.segment_prefix + '.nii'
#brain_dist_map = t1[:-7] + 'brain_dist_map.nii'
lp_dist_map = t1[:-7] + 'lp_dist_map.nii'
rp_dist_map = t1[:-7] + 'rp_dist_map.nii'
lw_dist_map = t1[:-7] + 'lw_dist_map.nii'
rw_dist_map = t1[:-7] + 'rw_dist_map.nii'
mni_reg_x = t1[:-7] + 'mni_reg.x.nii'
mni_reg_y = t1[:-7] + 'mni_reg.y.nii'
mni_reg_z = t1[:-7] + 'mni_reg.z.nii'
self.modalities = {'T1': t1, 'Gen': generation_labels, 'segmentation': segmentation_labels,
'distance': [lp_dist_map, lw_dist_map, rp_dist_map, rw_dist_map],
'registration': [mni_reg_x, mni_reg_y, mni_reg_z]}
if os.path.isfile(t1dm):
self.modalities.update({'T1_DM': t1dm})
if os.path.isfile(t2):
self.modalities.update({'T2': t2})
if os.path.isfile(t2dm):
self.modalities.update({'T2_DM': t2dm})
if os.path.isfile(flair):
self.modalities.update({'FLAIR': flair})
if os.path.isfile(flairdm):
self.modalities.update({'FLAIR_DM': flairdm})
if os.path.isfile(ct):
self.modalities.update({'CT': ct})
if os.path.isfile(ctdm):
self.modalities.update({'CT_DM': ctdm})
return self.modalities
def read_input(self, idx):
"""
determine input type according to prob (in generator/constants.py)
Logic: if np.random.rand() < real_image_prob and is real_image_exist --> input real images; otherwise, synthesize images.
"""
dataset_name, input_prob, t1_path, age = self.idx_to_path(idx)
case_name = os.path.basename(t1_path).split('.T1w.nii')[0]
self.modalities = self.get_info(t1_path)
prob = np.random.rand()
if prob < input_prob['T1'] and 'T1' in self.modalities:
input_mode = 'T1'
img, aff, res = read_image(self.modalities['T1'])
elif prob < input_prob['T2'] and 'T2' in self.modalities:
input_mode = 'T2'
img, aff, res = read_image(self.modalities['T2'])
elif prob < input_prob['FLAIR'] and 'FLAIR' in self.modalities:
input_mode = 'FLAIR'
img, aff, res = read_image(self.modalities['FLAIR'])
elif prob < input_prob['CT'] and 'CT' in self.modalities:
input_mode = 'CT'
img, aff, res = read_image(self.modalities['CT'])
else:
input_mode = 'synth'
img, aff, res = read_image(self.modalities['Gen'])
return dataset_name, case_name, input_mode, img, aff, res, age
def read_and_deform_target(self, idx, exist_keys, task_name, input_mode, setups, deform_dict, linear_weights = None):
current_target = {}
p_prob_path, augment, thres = None, False, 0.1
if task_name == 'pathology':
# NOTE: for now - encode pathology only for healthy cases
# TODO: what to do if the case has pathology itself? -- inconsistency between encoded pathol and the output
if self.pathology_type is None: # healthy
if setups['pathol_mode']: # and input_mode == 'synth':
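                    # Either synthesize a random lesion shape, or sample a real pathology probability
                    # map and (optionally) deform it with the advection PDE built in prepare_tasks().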
if setups['pathol_random_shape']:
p_prob_path = 'random_shape'
augment, thres = False, self.shape_gen_args.pathol_thres
else:
p_prob_path = random.choice(pathology_prob_paths)
augment, thres = self.synth_args.augment_pathology, self.shape_gen_args.pathol_thres
else:
pass
#p_prob_path = self.modalities['pathology_prob']
current_target = processing_funcs[task_name](exist_keys, task_name, p_prob_path, setups, deform_dict, self.device,
mask = self.hemis_mask,
augment = augment,
pde_func = self.adv_pde,
t = self.t,
shape_gen_args = self.shape_gen_args,
thres = thres
)
else:
if task_name in self.modalities:
current_target = processing_funcs[task_name](exist_keys, task_name, self.modalities[task_name],
setups, deform_dict, self.device,
mask = self.hemis_mask,
cfg = self.gen_args,
onehotmatrix = self.onehotmatrix,
lut = self.lut, vflip = self.vflip
)
else:
current_target = {task_name: 0.}
return current_target
def update_gen_args(self, new_args):
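        # Overwrite fields of the shared generator config in place; used to switch between
        # mild/severe and real/synth corruption settings before generating each sample.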
for key, value in vars(new_args).items():
vars(self.gen_args.generator)[key] = value
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
# read input: real or synthesized image, according to customized prob
dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)
# generate random values
setups = self.get_setup_params()
# sample random deformation
deform_dict = self.generate_deformation(setups, img.shape)
# get left_hemis_mask if needed
self.get_left_hemis_mask(deform_dict['grid'])
# read and deform target according to the assigned tasks
target = defaultdict(lambda: None)
target['name'] = case_name
target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict))
target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict))
target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict))
for task_name in self.tasks:
if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))
# process or generate input sample
if input_mode == 'synth':
            self.update_gen_args(self.synth_image_args)  # stronger corruption settings for synthesized inputs
target['pathology'], target['pathology_prob'], sample = \
self.generate_sample(case_name, img, setups, deform_dict, res, target)
else:
self.update_gen_args(self.real_image_args) # milder noise injection for real images
sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                                         pathol_direction = self.get_pathology_direction(input_mode), input_mode = input_mode)
if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])
if age is not None:
target['age'] = age
return self.datasets_num, dataset_name, input_mode, target, sample
# An example of a customized dataset built on BaseGen
class BrainIDGen(BaseGen):
"""
BrainIDGen dataset
BrainIDGen enables intra-subject augmentation, i.e., each subject will have multiple augmentations
"""
def __init__(self, gen_args, device='cpu'):
super(BrainIDGen, self).__init__(gen_args, device)
self.all_samples = gen_args.generator.all_samples
self.mild_samples = gen_args.generator.mild_samples
self.mild_generator_args = gen_args.mild_generator
self.severe_generator_args = gen_args.severe_generator
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
# read input: real or synthesized image, according to customized prob
dataset_name, case_name, input_mode, img, aff, res, age = self.read_input(idx)
# generate random values
setups = self.get_setup_params()
# sample random deformation
deform_dict = self.generate_deformation(setups, img.shape)
# get left_hemis_mask if needed
self.get_left_hemis_mask(deform_dict['grid'])
# read and deform target according to the assigned tasks
target = defaultdict(lambda: 1.)
target['name'] = case_name
target.update(self.read_and_deform_target(idx, target.keys(), 'T1', input_mode, setups, deform_dict))
target.update(self.read_and_deform_target(idx, target.keys(), 'T2', input_mode, setups, deform_dict))
target.update(self.read_and_deform_target(idx, target.keys(), 'FLAIR', input_mode, setups, deform_dict))
for task_name in self.tasks:
if task_name in processing_funcs.keys() and task_name not in ['T1', 'T2', 'FLAIR']:
target.update(self.read_and_deform_target(idx, target.keys(), task_name, input_mode, setups, deform_dict))
# process or generate intra-subject input samples
samples = []
        for i_sample in range(self.all_samples):
            # the first `mild_samples` draws use the mild corruption settings, the remaining ones the severe settings
            if i_sample < self.mild_samples:
                self.update_gen_args(self.mild_generator_args)
            else:
                self.update_gen_args(self.severe_generator_args)
            if input_mode == 'synth':
                self.update_gen_args(self.synth_image_args)
                target['pathology'], target['pathology_prob'], sample = \
                    self.generate_sample(case_name, img, setups, deform_dict, res, target)
            else:
                self.update_gen_args(self.real_image_args)
                sample = self.augment_sample(case_name, img, setups, deform_dict, res, target,
                                             pathol_direction = self.get_pathology_direction(input_mode), input_mode = input_mode)
            samples.append(sample)
if setups['flip'] and isinstance(target['pathology'], torch.Tensor): # flipping should happen after P has been encoded
target['pathology'], target['pathology_prob'] = torch.flip(target['pathology'], [1]), torch.flip(target['pathology_prob'], [1])
if age is not None:
target['age'] = age
return self.datasets_num, dataset_name, input_mode, target, samples
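# Minimal usage sketch (illustrative only; assumes `gen_args` is a namespace-style config
# exposing the fields accessed above, e.g. split, data_root, split_root, dataset_names,
# generator, augmentation_steps, modality_probs, task, ...):
#
#   dataset = BrainIDGen(gen_args, device='cpu')
#   datasets_num, dataset_name, input_mode, target, samples = dataset[0]
#   # or wrap the dataset in a torch.utils.data.DataLoader for batched training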