qicq1c commited on
Commit
afe7f68
1 Parent(s): 948abd5

a3661a547c9a42b65413352701dc44d90f93ae03fa3e31e8d61086010509c3f3

Browse files
Files changed (48) hide show
  1. Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml +37 -0
  2. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__init__.py +1 -0
  3. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc +0 -0
  4. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc +0 -0
  5. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc +0 -0
  6. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc +0 -0
  7. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc +0 -0
  8. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc +0 -0
  9. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc +0 -0
  10. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/ddim.py +206 -0
  11. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py +1016 -0
  12. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/text.py +94 -0
  13. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py +75 -0
  14. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/unet.py +226 -0
  15. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/util.py +271 -0
  16. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc +0 -0
  17. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py +3 -0
  18. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc +0 -0
  19. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc +0 -0
  20. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc +0 -0
  21. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc +0 -0
  22. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth +3 -0
  23. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py +109 -0
  24. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py +181 -0
  25. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py +561 -0
  26. Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py +177 -0
  27. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_early.pt +3 -0
  28. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt +3 -0
  29. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_early.pt +3 -0
  30. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt +3 -0
  31. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt +3 -0
  32. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt +3 -0
  33. Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/recon_96d4_all.ckpt +3 -0
  34. Generation_Pipeline_filter/syn_liver/TumorGeneration/utils.py +471 -0
  35. Generation_Pipeline_filter/syn_liver/TumorGeneration/utils_.py +298 -0
  36. Generation_Pipeline_filter/syn_liver/healthy_liver_1k.txt +895 -0
  37. Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/ct.nii.gz +3 -0
  38. Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/segmentations/liver_tumor.nii.gz +3 -0
  39. Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/ct.nii.gz +3 -0
  40. Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/segmentations/liver_tumor.nii.gz +3 -0
  41. Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/ct.nii.gz +3 -0
  42. Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/segmentations/liver_tumor.nii.gz +3 -0
  43. Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/ct.nii.gz +3 -0
  44. Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/segmentations/liver_tumor.nii.gz +3 -0
  45. Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/ct.nii.gz +3 -0
  46. Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/segmentations/liver_tumor.nii.gz +3 -0
  47. Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/ct.nii.gz +3 -0
  48. Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/segmentations/liver_tumor.nii.gz +3 -0
Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 1234
2
+ batch_size: 2 # 30
3
+ num_workers: 32 # 30
4
+
5
+ gpus: 1
6
+ accumulate_grad_batches: 1
7
+ default_root_dir: checkpoints/vq_gan/
8
+ default_root_dir_postfix: 'flair'
9
+ resume_from_checkpoint:
10
+ max_steps: -1
11
+ max_epochs: -1
12
+ precision: 16
13
+ gradient_clip_val: 1.0
14
+
15
+
16
+ embedding_dim: 8 # 256
17
+ n_codes: 16384 # 2048
18
+ n_hiddens: 16
19
+ lr: 3e-4
20
+ downsample: [2, 2, 2] # [4, 4, 4]
21
+ disc_channels: 64
22
+ disc_layers: 3
23
+ discriminator_iter_start: 10000 # 50000
24
+ disc_loss_type: hinge
25
+ image_gan_weight: 1.0
26
+ video_gan_weight: 1.0
27
+ l1_weight: 4.0
28
+ gan_feat_weight: 4.0 # 0.0
29
+ perceptual_weight: 4.0 # 0.0
30
+ i3d_feat: False
31
+ restart_thres: 1.0
32
+ no_random_restart: False
33
+ norm_type: group
34
+ padding_type: replicate
35
+ num_groups: 32
36
+
37
+
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .diffusion import Unet3D, GaussianDiffusion, Tester
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (285 Bytes). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (6.05 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc ADDED
Binary file (28.5 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc ADDED
Binary file (1.9 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc ADDED
Binary file (2.85 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc ADDED
Binary file (5.77 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc ADDED
Binary file (9.47 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/ddim.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad'
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
+
28
+ alphas_cumprod = self.model.alphas_cumprod
29
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
+
32
+ self.register_buffer('betas', to_torch(self.model.betas))
33
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
+
36
+ # calculations for diffusion q(x_t | x_{t-1}) and others
37
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
+ # breakpoint()
43
+ # ddim sampling parameters
44
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
+ ddim_timesteps=self.ddim_timesteps,
46
+ eta=ddim_eta,verbose=verbose)
47
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
48
+ self.register_buffer('ddim_alphas', ddim_alphas)
49
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
+
56
+ @torch.no_grad()
57
+ def sample(self,
58
+ S,
59
+ batch_size,
60
+ shape,
61
+ conditioning=None,
62
+ callback=None,
63
+ normals_sequence=None,
64
+ img_callback=None,
65
+ quantize_x0=False,
66
+ eta=0.,
67
+ mask=None,
68
+ x0=None,
69
+ temperature=1.,
70
+ noise_dropout=0.,
71
+ score_corrector=None,
72
+ corrector_kwargs=None,
73
+ verbose=True,
74
+ x_T=None,
75
+ log_every_t=100,
76
+ unconditional_guidance_scale=1.,
77
+ unconditional_conditioning=None,
78
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
+ **kwargs
80
+ ):
81
+ if conditioning is not None:
82
+ if isinstance(conditioning, dict):
83
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
84
+ if cbs != batch_size:
85
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
86
+ else:
87
+ if conditioning.shape[0] != batch_size:
88
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
89
+
90
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
91
+ # sampling
92
+ C, T, H, W = shape
93
+ # breakpoint()
94
+ size = (batch_size, C, T, H, W)
95
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
96
+
97
+ samples, intermediates = self.ddim_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def ddim_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
138
+
139
+ # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
140
+
141
+ for i, step in enumerate(time_range):
142
+ index = total_steps - i - 1
143
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
144
+
145
+ if mask is not None:
146
+ assert x0 is not None
147
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
148
+ img = img_orig * mask + (1. - mask) * img
149
+
150
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
151
+ quantize_denoised=quantize_denoised, temperature=temperature,
152
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
153
+ corrector_kwargs=corrector_kwargs,
154
+ unconditional_guidance_scale=unconditional_guidance_scale,
155
+ unconditional_conditioning=unconditional_conditioning)
156
+ img, pred_x0 = outs
157
+ if callback: callback(i)
158
+ if img_callback: img_callback(pred_x0, i)
159
+
160
+ if index % log_every_t == 0 or index == total_steps - 1:
161
+ intermediates['x_inter'].append(img)
162
+ intermediates['pred_x0'].append(pred_x0)
163
+
164
+ return img, intermediates
165
+
166
+ @torch.no_grad()
167
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
168
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
169
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
170
+ b, *_, device = *x.shape, x.device
171
+
172
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
173
+ # breakpoint()
174
+ e_t = self.model.denoise_fn(x, t, c)
175
+ else:
176
+ x_in = torch.cat([x] * 2)
177
+ t_in = torch.cat([t] * 2)
178
+ c_in = torch.cat([unconditional_conditioning, c])
179
+ e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2)
180
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
181
+
182
+ if score_corrector is not None:
183
+ assert self.model.parameterization == "eps"
184
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
185
+
186
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
187
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
188
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
189
+ sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
190
+ # select parameters corresponding to the currently considered timestep
191
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
192
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
193
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
194
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
195
+
196
+ # current prediction for x_0
197
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
198
+ if quantize_denoised:
199
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
200
+ # direction pointing to x_t
201
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
202
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
203
+ if noise_dropout > 0.:
204
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
205
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
206
+ return x_prev, pred_x0
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py ADDED
@@ -0,0 +1,1016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import math
4
+ import copy
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+ from functools import partial
9
+
10
+ from torch.utils import data
11
+ from pathlib import Path
12
+ from torch.optim import Adam
13
+ from torchvision import transforms as T, utils
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from PIL import Image
16
+
17
+ from tqdm import tqdm
18
+ from einops import rearrange
19
+ from einops_exts import check_shape, rearrange_many
20
+
21
+ from rotary_embedding_torch import RotaryEmbedding
22
+
23
+ from .text import tokenize, bert_embed, BERT_MODEL_DIM
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from ..vq_gan_3d.model.vqgan import VQGAN
26
+
27
+ import matplotlib.pyplot as plt
28
+
29
+ # helpers functions
30
+
31
+
32
+ def exists(x):
33
+ return x is not None
34
+
35
+
36
+ def noop(*args, **kwargs):
37
+ pass
38
+
39
+
40
+ def is_odd(n):
41
+ return (n % 2) == 1
42
+
43
+
44
+ def default(val, d):
45
+ if exists(val):
46
+ return val
47
+ return d() if callable(d) else d
48
+
49
+
50
+ def cycle(dl):
51
+ while True:
52
+ for data in dl:
53
+ yield data
54
+
55
+
56
+ def num_to_groups(num, divisor):
57
+ groups = num // divisor
58
+ remainder = num % divisor
59
+ arr = [divisor] * groups
60
+ if remainder > 0:
61
+ arr.append(remainder)
62
+ return arr
63
+
64
+
65
+ def prob_mask_like(shape, prob, device):
66
+ if prob == 1:
67
+ return torch.ones(shape, device=device, dtype=torch.bool)
68
+ elif prob == 0:
69
+ return torch.zeros(shape, device=device, dtype=torch.bool)
70
+ else:
71
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
72
+
73
+
74
+ def is_list_str(x):
75
+ if not isinstance(x, (list, tuple)):
76
+ return False
77
+ return all([type(el) == str for el in x])
78
+
79
+ # relative positional bias
80
+
81
+
82
+ class RelativePositionBias(nn.Module):
83
+ def __init__(
84
+ self,
85
+ heads=8,
86
+ num_buckets=32,
87
+ max_distance=128
88
+ ):
89
+ super().__init__()
90
+ self.num_buckets = num_buckets
91
+ self.max_distance = max_distance
92
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
93
+
94
+ @staticmethod
95
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
96
+ ret = 0
97
+ n = -relative_position
98
+
99
+ num_buckets //= 2
100
+ ret += (n < 0).long() * num_buckets
101
+ n = torch.abs(n)
102
+
103
+ max_exact = num_buckets // 2
104
+ is_small = n < max_exact
105
+
106
+ val_if_large = max_exact + (
107
+ torch.log(n.float() / max_exact) / math.log(max_distance /
108
+ max_exact) * (num_buckets - max_exact)
109
+ ).long()
110
+ val_if_large = torch.min(
111
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1))
112
+
113
+ ret += torch.where(is_small, n, val_if_large)
114
+ return ret
115
+
116
+ def forward(self, n, device):
117
+ q_pos = torch.arange(n, dtype=torch.long, device=device)
118
+ k_pos = torch.arange(n, dtype=torch.long, device=device)
119
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
120
+ rp_bucket = self._relative_position_bucket(
121
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
122
+ values = self.relative_attention_bias(rp_bucket)
123
+ return rearrange(values, 'i j h -> h i j')
124
+
125
+ # small helper modules
126
+
127
+
128
+ class EMA():
129
+ def __init__(self, beta):
130
+ super().__init__()
131
+ self.beta = beta
132
+
133
+ def update_model_average(self, ma_model, current_model):
134
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
135
+ old_weight, up_weight = ma_params.data, current_params.data
136
+ ma_params.data = self.update_average(old_weight, up_weight)
137
+
138
+ def update_average(self, old, new):
139
+ if old is None:
140
+ return new
141
+ return old * self.beta + (1 - self.beta) * new
142
+
143
+
144
+ class Residual(nn.Module):
145
+ def __init__(self, fn):
146
+ super().__init__()
147
+ self.fn = fn
148
+
149
+ def forward(self, x, *args, **kwargs):
150
+ return self.fn(x, *args, **kwargs) + x
151
+
152
+
153
+ class SinusoidalPosEmb(nn.Module):
154
+ def __init__(self, dim):
155
+ super().__init__()
156
+ self.dim = dim
157
+
158
+ def forward(self, x):
159
+ device = x.device
160
+ half_dim = self.dim // 2
161
+ emb = math.log(10000) / (half_dim - 1)
162
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
163
+ emb = x[:, None] * emb[None, :]
164
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
165
+ return emb
166
+
167
+
168
+ def Upsample(dim):
169
+ return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))
170
+
171
+
172
+ def Downsample(dim):
173
+ return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))
174
+
175
+
176
+ class LayerNorm(nn.Module):
177
+ def __init__(self, dim, eps=1e-5):
178
+ super().__init__()
179
+ self.eps = eps
180
+ self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
181
+
182
+ def forward(self, x):
183
+ var = torch.var(x, dim=1, unbiased=False, keepdim=True)
184
+ mean = torch.mean(x, dim=1, keepdim=True)
185
+ return (x - mean) / (var + self.eps).sqrt() * self.gamma
186
+
187
+
188
+ class PreNorm(nn.Module):
189
+ def __init__(self, dim, fn):
190
+ super().__init__()
191
+ self.fn = fn
192
+ self.norm = LayerNorm(dim)
193
+
194
+ def forward(self, x, **kwargs):
195
+ x = self.norm(x)
196
+ return self.fn(x, **kwargs)
197
+
198
+ # building block modules
199
+
200
+
201
+ class Block(nn.Module):
202
+ def __init__(self, dim, dim_out, groups=8):
203
+ super().__init__()
204
+ self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
205
+ self.norm = nn.GroupNorm(groups, dim_out)
206
+ self.act = nn.SiLU()
207
+
208
+ def forward(self, x, scale_shift=None):
209
+ x = self.proj(x)
210
+ x = self.norm(x)
211
+
212
+ if exists(scale_shift):
213
+ scale, shift = scale_shift
214
+ x = x * (scale + 1) + shift
215
+
216
+ return self.act(x)
217
+
218
+
219
+ class ResnetBlock(nn.Module):
220
+ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
221
+ super().__init__()
222
+ self.mlp = nn.Sequential(
223
+ nn.SiLU(),
224
+ nn.Linear(time_emb_dim, dim_out * 2)
225
+ ) if exists(time_emb_dim) else None
226
+
227
+ self.block1 = Block(dim, dim_out, groups=groups)
228
+ self.block2 = Block(dim_out, dim_out, groups=groups)
229
+ self.res_conv = nn.Conv3d(
230
+ dim, dim_out, 1) if dim != dim_out else nn.Identity()
231
+
232
+ def forward(self, x, time_emb=None):
233
+
234
+ scale_shift = None
235
+ if exists(self.mlp):
236
+ assert exists(time_emb), 'time emb must be passed in'
237
+ time_emb = self.mlp(time_emb)
238
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
239
+ scale_shift = time_emb.chunk(2, dim=1)
240
+
241
+ h = self.block1(x, scale_shift=scale_shift)
242
+
243
+ h = self.block2(h)
244
+ return h + self.res_conv(x)
245
+
246
+
247
+ class SpatialLinearAttention(nn.Module):
248
+ def __init__(self, dim, heads=4, dim_head=32):
249
+ super().__init__()
250
+ self.scale = dim_head ** -0.5
251
+ self.heads = heads
252
+ hidden_dim = dim_head * heads
253
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
254
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
255
+
256
+ def forward(self, x):
257
+ b, c, f, h, w = x.shape
258
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
259
+
260
+ qkv = self.to_qkv(x).chunk(3, dim=1)
261
+ q, k, v = rearrange_many(
262
+ qkv, 'b (h c) x y -> b h c (x y)', h=self.heads)
263
+
264
+ q = q.softmax(dim=-2)
265
+ k = k.softmax(dim=-1)
266
+
267
+ q = q * self.scale
268
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
269
+
270
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
271
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y',
272
+ h=self.heads, x=h, y=w)
273
+ out = self.to_out(out)
274
+ return rearrange(out, '(b f) c h w -> b c f h w', b=b)
275
+
276
+ # attention along space and time
277
+
278
+
279
+ class EinopsToAndFrom(nn.Module):
280
+ def __init__(self, from_einops, to_einops, fn):
281
+ super().__init__()
282
+ self.from_einops = from_einops
283
+ self.to_einops = to_einops
284
+ self.fn = fn
285
+
286
+ def forward(self, x, **kwargs):
287
+ shape = x.shape
288
+ reconstitute_kwargs = dict(
289
+ tuple(zip(self.from_einops.split(' '), shape)))
290
+ x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
291
+ x = self.fn(x, **kwargs)
292
+ x = rearrange(
293
+ x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
294
+ return x
295
+
296
+
297
+ class Attention(nn.Module):
298
+ def __init__(
299
+ self,
300
+ dim,
301
+ heads=4,
302
+ dim_head=32,
303
+ rotary_emb=None
304
+ ):
305
+ super().__init__()
306
+ self.scale = dim_head ** -0.5
307
+ self.heads = heads
308
+ hidden_dim = dim_head * heads
309
+
310
+ self.rotary_emb = rotary_emb
311
+ self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
312
+ self.to_out = nn.Linear(hidden_dim, dim, bias=False)
313
+
314
+ def forward(
315
+ self,
316
+ x,
317
+ pos_bias=None,
318
+ focus_present_mask=None
319
+ ):
320
+ n, device = x.shape[-2], x.device
321
+
322
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
323
+
324
+ if exists(focus_present_mask) and focus_present_mask.all():
325
+ # if all batch samples are focusing on present
326
+ # it would be equivalent to passing that token's values through to the output
327
+ values = qkv[-1]
328
+ return self.to_out(values)
329
+
330
+ # split out heads
331
+
332
+ q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
333
+
334
+ # scale
335
+
336
+ q = q * self.scale
337
+
338
+ # rotate positions into queries and keys for time attention
339
+
340
+ if exists(self.rotary_emb):
341
+ q = self.rotary_emb.rotate_queries_or_keys(q)
342
+ k = self.rotary_emb.rotate_queries_or_keys(k)
343
+
344
+ # similarity
345
+
346
+ sim = einsum('... h i d, ... h j d -> ... h i j', q, k)
347
+
348
+ # relative positional bias
349
+
350
+ if exists(pos_bias):
351
+ sim = sim + pos_bias
352
+
353
+ if exists(focus_present_mask) and not (~focus_present_mask).all():
354
+ attend_all_mask = torch.ones(
355
+ (n, n), device=device, dtype=torch.bool)
356
+ attend_self_mask = torch.eye(n, device=device, dtype=torch.bool)
357
+
358
+ mask = torch.where(
359
+ rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
360
+ rearrange(attend_self_mask, 'i j -> 1 1 1 i j'),
361
+ rearrange(attend_all_mask, 'i j -> 1 1 1 i j'),
362
+ )
363
+
364
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
365
+
366
+ # numerical stability
367
+
368
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
369
+ attn = sim.softmax(dim=-1)
370
+
371
+ # aggregate values
372
+
373
+ out = einsum('... h i j, ... h j d -> ... h i d', attn, v)
374
+ out = rearrange(out, '... h n d -> ... n (h d)')
375
+ return self.to_out(out)
376
+
377
+ # model
378
+
379
+
380
+ class Unet3D(nn.Module):
381
+ def __init__(
382
+ self,
383
+ dim,
384
+ cond_dim=None,
385
+ out_dim=None,
386
+ dim_mults=(1, 2, 4, 8),
387
+ channels=3,
388
+ attn_heads=8,
389
+ attn_dim_head=32,
390
+ use_bert_text_cond=False,
391
+ init_dim=None,
392
+ init_kernel_size=7,
393
+ use_sparse_linear_attn=True,
394
+ block_type='resnet',
395
+ resnet_groups=8
396
+ ):
397
+ super().__init__()
398
+ self.channels = channels
399
+
400
+ # temporal attention and its relative positional encoding
401
+
402
+ rotary_emb = RotaryEmbedding(min(32, attn_dim_head))
403
+
404
+ def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(
405
+ dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb))
406
+
407
+ # realistically will not be able to generate that many frames of video... yet
408
+ self.time_rel_pos_bias = RelativePositionBias(
409
+ heads=attn_heads, max_distance=32)
410
+
411
+ # initial conv
412
+
413
+ init_dim = default(init_dim, dim)
414
+ assert is_odd(init_kernel_size)
415
+
416
+ init_padding = init_kernel_size // 2
417
+ self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size,
418
+ init_kernel_size), padding=(0, init_padding, init_padding))
419
+
420
+ self.init_temporal_attn = Residual(
421
+ PreNorm(init_dim, temporal_attn(init_dim)))
422
+
423
+ # dimensions
424
+
425
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
426
+ in_out = list(zip(dims[:-1], dims[1:]))
427
+
428
+ # time conditioning
429
+
430
+ time_dim = dim * 4
431
+ self.time_mlp = nn.Sequential(
432
+ SinusoidalPosEmb(dim),
433
+ nn.Linear(dim, time_dim),
434
+ nn.GELU(),
435
+ nn.Linear(time_dim, time_dim)
436
+ )
437
+
438
+ # text conditioning
439
+
440
+ self.has_cond = exists(cond_dim) or use_bert_text_cond
441
+ cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim
442
+
443
+ self.null_cond_emb = nn.Parameter(
444
+ torch.randn(1, cond_dim)) if self.has_cond else None
445
+
446
+ cond_dim = time_dim + int(cond_dim or 0)
447
+
448
+ # layers
449
+
450
+ self.downs = nn.ModuleList([])
451
+ self.ups = nn.ModuleList([])
452
+
453
+ num_resolutions = len(in_out)
454
+ # block type
455
+
456
+ block_klass = partial(ResnetBlock, groups=resnet_groups)
457
+ block_klass_cond = partial(block_klass, time_emb_dim=cond_dim)
458
+
459
+ # modules for all layers
460
+ for ind, (dim_in, dim_out) in enumerate(in_out):
461
+ is_last = ind >= (num_resolutions - 1)
462
+
463
+ self.downs.append(nn.ModuleList([
464
+ block_klass_cond(dim_in, dim_out),
465
+ block_klass_cond(dim_out, dim_out),
466
+ Residual(PreNorm(dim_out, SpatialLinearAttention(
467
+ dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
468
+ Residual(PreNorm(dim_out, temporal_attn(dim_out))),
469
+ Downsample(dim_out) if not is_last else nn.Identity()
470
+ ]))
471
+
472
+ mid_dim = dims[-1]
473
+ self.mid_block1 = block_klass_cond(mid_dim, mid_dim)
474
+
475
+ spatial_attn = EinopsToAndFrom(
476
+ 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))
477
+
478
+ self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
479
+ self.mid_temporal_attn = Residual(
480
+ PreNorm(mid_dim, temporal_attn(mid_dim)))
481
+
482
+ self.mid_block2 = block_klass_cond(mid_dim, mid_dim)
483
+
484
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
485
+ is_last = ind >= (num_resolutions - 1)
486
+
487
+ self.ups.append(nn.ModuleList([
488
+ block_klass_cond(dim_out * 2, dim_in),
489
+ block_klass_cond(dim_in, dim_in),
490
+ Residual(PreNorm(dim_in, SpatialLinearAttention(
491
+ dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
492
+ Residual(PreNorm(dim_in, temporal_attn(dim_in))),
493
+ Upsample(dim_in) if not is_last else nn.Identity()
494
+ ]))
495
+
496
+ out_dim = default(out_dim, channels)
497
+ self.final_conv = nn.Sequential(
498
+ block_klass(dim * 2, dim),
499
+ nn.Conv3d(dim, out_dim, 1)
500
+ )
501
+
502
+ def forward_with_cond_scale(
503
+ self,
504
+ *args,
505
+ cond_scale=2.,
506
+ **kwargs
507
+ ):
508
+ logits = self.forward(*args, null_cond_prob=0., **kwargs)
509
+ if cond_scale == 1 or not self.has_cond:
510
+ return logits
511
+
512
+ null_logits = self.forward(*args, null_cond_prob=1., **kwargs)
513
+ return null_logits + (logits - null_logits) * cond_scale
514
+
515
+ def forward(
516
+ self,
517
+ x,
518
+ time,
519
+ cond=None,
520
+ null_cond_prob=0.,
521
+ focus_present_mask=None,
522
+ # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
523
+ prob_focus_present=0.
524
+ ):
525
+ assert not (self.has_cond and not exists(cond)
526
+ ), 'cond must be passed in if cond_dim specified'
527
+ x = torch.cat([x, cond], dim=1)
528
+
529
+ batch, device = x.shape[0], x.device
530
+
531
+ focus_present_mask = default(focus_present_mask, lambda: prob_mask_like(
532
+ (batch,), prob_focus_present, device=device))
533
+
534
+ time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)
535
+
536
+ x = self.init_conv(x)
537
+ r = x.clone()
538
+
539
+ x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)
540
+
541
+ t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128]
542
+
543
+ # classifier free guidance
544
+
545
+ if self.has_cond:
546
+ batch, device = x.shape[0], x.device
547
+ mask = prob_mask_like((batch,), null_cond_prob, device=device)
548
+ cond = torch.where(rearrange(mask, 'b -> b 1'),
549
+ self.null_cond_emb, cond)
550
+ t = torch.cat((t, cond), dim=-1)
551
+
552
+ h = []
553
+
554
+ for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
555
+ x = block1(x, t)
556
+ x = block2(x, t)
557
+ x = spatial_attn(x)
558
+ x = temporal_attn(x, pos_bias=time_rel_pos_bias,
559
+ focus_present_mask=focus_present_mask)
560
+ h.append(x)
561
+ x = downsample(x)
562
+
563
+ # [2, 256, 32, 4, 4]
564
+ x = self.mid_block1(x, t)
565
+ x = self.mid_spatial_attn(x)
566
+ x = self.mid_temporal_attn(
567
+ x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
568
+ x = self.mid_block2(x, t)
569
+
570
+ for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
571
+ x = torch.cat((x, h.pop()), dim=1)
572
+ x = block1(x, t)
573
+ x = block2(x, t)
574
+ x = spatial_attn(x)
575
+ x = temporal_attn(x, pos_bias=time_rel_pos_bias,
576
+ focus_present_mask=focus_present_mask)
577
+ x = upsample(x)
578
+
579
+ x = torch.cat((x, r), dim=1)
580
+ return self.final_conv(x)
581
+
582
+ # gaussian diffusion trainer class
583
+
584
+
585
+ def extract(a, t, x_shape):
586
+ b, *_ = t.shape
587
+ out = a.gather(-1, t)
588
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
589
+
590
+
591
+ def cosine_beta_schedule(timesteps, s=0.008):
592
+ """
593
+ cosine schedule
594
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
595
+ """
596
+ steps = timesteps + 1
597
+ x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
598
+ alphas_cumprod = torch.cos(
599
+ ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
600
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
601
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
602
+ return torch.clip(betas, 0, 0.9999)
603
+
604
+
605
+ class GaussianDiffusion(nn.Module):
606
+ def __init__(
607
+ self,
608
+ denoise_fn,
609
+ *,
610
+ image_size,
611
+ num_frames,
612
+ text_use_bert_cls=False,
613
+ channels=3,
614
+ timesteps=1000,
615
+ loss_type='l1',
616
+ use_dynamic_thres=False, # from the Imagen paper
617
+ dynamic_thres_percentile=0.9,
618
+ vqgan_ckpt=None,
619
+ device=None
620
+ ):
621
+ super().__init__()
622
+ self.channels = channels
623
+ self.image_size = image_size
624
+ self.num_frames = num_frames
625
+ self.denoise_fn = denoise_fn
626
+ self.device = device
627
+
628
+ if vqgan_ckpt:
629
+ self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda()
630
+ self.vqgan.eval()
631
+ else:
632
+ self.vqgan = None
633
+
634
+ betas = cosine_beta_schedule(timesteps)
635
+
636
+ alphas = 1. - betas
637
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
638
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
639
+
640
+ timesteps, = betas.shape
641
+ self.num_timesteps = int(timesteps)
642
+ self.loss_type = loss_type
643
+
644
+ # register buffer helper function that casts float64 to float32
645
+
646
+ def register_buffer(name, val): return self.register_buffer(
647
+ name, val.to(torch.float32))
648
+
649
+ register_buffer('betas', betas)
650
+ register_buffer('alphas_cumprod', alphas_cumprod)
651
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
652
+
653
+ # calculations for diffusion q(x_t | x_{t-1}) and others
654
+
655
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
656
+ register_buffer('sqrt_one_minus_alphas_cumprod',
657
+ torch.sqrt(1. - alphas_cumprod))
658
+ register_buffer('log_one_minus_alphas_cumprod',
659
+ torch.log(1. - alphas_cumprod))
660
+ register_buffer('sqrt_recip_alphas_cumprod',
661
+ torch.sqrt(1. / alphas_cumprod))
662
+ register_buffer('sqrt_recipm1_alphas_cumprod',
663
+ torch.sqrt(1. / alphas_cumprod - 1))
664
+
665
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
666
+
667
+ posterior_variance = betas * \
668
+ (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
669
+
670
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
671
+
672
+ register_buffer('posterior_variance', posterior_variance)
673
+
674
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
675
+
676
+ register_buffer('posterior_log_variance_clipped',
677
+ torch.log(posterior_variance.clamp(min=1e-20)))
678
+ register_buffer('posterior_mean_coef1', betas *
679
+ torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
680
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev)
681
+ * torch.sqrt(alphas) / (1. - alphas_cumprod))
682
+
683
+ # text conditioning parameters
684
+
685
+ self.text_use_bert_cls = text_use_bert_cls
686
+
687
+ # dynamic thresholding when sampling
688
+
689
+ self.use_dynamic_thres = use_dynamic_thres
690
+ self.dynamic_thres_percentile = dynamic_thres_percentile
691
+
692
+ def q_mean_variance(self, x_start, t):
693
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
694
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
695
+ log_variance = extract(
696
+ self.log_one_minus_alphas_cumprod, t, x_start.shape)
697
+ return mean, variance, log_variance
698
+
699
+ def predict_start_from_noise(self, x_t, t, noise):
700
+ return (
701
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
702
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
703
+ )
704
+
705
+ def q_posterior(self, x_start, x_t, t):
706
+ posterior_mean = (
707
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
708
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
709
+ )
710
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
711
+ posterior_log_variance_clipped = extract(
712
+ self.posterior_log_variance_clipped, t, x_t.shape)
713
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
714
+
715
+ def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.):
716
+ x_recon = self.predict_start_from_noise(
717
+ x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale))
718
+
719
+ if clip_denoised:
720
+ s = 1.
721
+ if self.use_dynamic_thres:
722
+ s = torch.quantile(
723
+ rearrange(x_recon, 'b ... -> b (...)').abs(),
724
+ self.dynamic_thres_percentile,
725
+ dim=-1
726
+ )
727
+
728
+ s.clamp_(min=1.)
729
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
730
+
731
+ # clip by threshold, depending on whether static or dynamic
732
+ x_recon = x_recon.clamp(-s, s) / s
733
+
734
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
735
+ x_start=x_recon, x_t=x, t=t)
736
+ return model_mean, posterior_variance, posterior_log_variance
737
+
738
+ @torch.inference_mode()
739
+ def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True):
740
+ b, *_, device = *x.shape, x.device
741
+ model_mean, _, model_log_variance = self.p_mean_variance(
742
+ x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale)
743
+ noise = torch.randn_like(x)
744
+ # no noise when t == 0
745
+ nonzero_mask = (1 - (t == 0).float()).reshape(b,
746
+ *((1,) * (len(x.shape) - 1)))
747
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
748
+
749
+ @torch.inference_mode()
750
+ def p_sample_loop(self, shape, cond=None, cond_scale=1.):
751
+ device = self.betas.device
752
+
753
+ b = shape[0]
754
+ img = torch.randn(shape, device=device)
755
+ # print('cond', cond.shape)
756
+ for i in reversed(range(0, self.num_timesteps)):
757
+ img = self.p_sample(img, torch.full(
758
+ (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale)
759
+
760
+ return img
761
+
762
+ @torch.inference_mode()
763
+ def sample(self, cond=None, cond_scale=1., batch_size=16):
764
+ device = next(self.denoise_fn.parameters()).device
765
+
766
+ if is_list_str(cond):
767
+ cond = bert_embed(tokenize(cond)).to(device)
768
+
769
+ # batch_size = cond.shape[0] if exists(cond) else batch_size
770
+ batch_size = batch_size
771
+ image_size = self.image_size
772
+ channels = 8 # self.channels
773
+ num_frames = self.num_frames
774
+ # print((batch_size, channels, num_frames, image_size, image_size))
775
+ # print('cond_',cond.shape)
776
+ _sample = self.p_sample_loop(
777
+ (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale)
778
+
779
+ if isinstance(self.vqgan, VQGAN):
780
+ # denormalize TODO: Remove eventually
781
+ _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() -
782
+ self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min()
783
+
784
+ _sample = self.vqgan.decode(_sample, quantize=True)
785
+ else:
786
+ unnormalize_img(_sample)
787
+
788
+ return _sample
789
+
790
+ @torch.inference_mode()
791
+ def interpolate(self, x1, x2, t=None, lam=0.5):
792
+ b, *_, device = *x1.shape, x1.device
793
+ t = default(t, self.num_timesteps - 1)
794
+
795
+ assert x1.shape == x2.shape
796
+
797
+ t_batched = torch.stack([torch.tensor(t, device=device)] * b)
798
+ xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
799
+
800
+ img = (1 - lam) * xt1 + lam * xt2
801
+ for i in reversed(range(0, t)):
802
+ img = self.p_sample(img, torch.full(
803
+ (b,), i, device=device, dtype=torch.long))
804
+
805
+ return img
806
+
807
+ def q_sample(self, x_start, t, noise=None):
808
+ noise = default(noise, lambda: torch.randn_like(x_start))
809
+
810
+ return (
811
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
812
+ extract(self.sqrt_one_minus_alphas_cumprod,
813
+ t, x_start.shape) * noise
814
+ )
815
+
816
+ def p_losses(self, x_start, t, cond=None, noise=None, **kwargs):
817
+ b, c, f, h, w, device = *x_start.shape, x_start.device
818
+ noise = default(noise, lambda: torch.randn_like(x_start))
819
+ # breakpoint()
820
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32]
821
+
822
+ if is_list_str(cond):
823
+ cond = bert_embed(
824
+ tokenize(cond), return_cls_repr=self.text_use_bert_cls)
825
+ cond = cond.to(device)
826
+
827
+ x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs)
828
+
829
+ if self.loss_type == 'l1':
830
+ loss = F.l1_loss(noise, x_recon)
831
+ elif self.loss_type == 'l2':
832
+ loss = F.mse_loss(noise, x_recon)
833
+ else:
834
+ raise NotImplementedError()
835
+
836
+ return loss
837
+
838
+ def forward(self, x, *args, **kwargs):
839
+ bs = int(x.shape[0]/2)
840
+ img=x[:bs,...]
841
+ mask=x[bs:,...]
842
+ mask_=(1-mask).detach()
843
+ masked_img = (img*mask_).detach()
844
+ masked_img=masked_img.permute(0,1,-1,-3,-2)
845
+ img=img.permute(0,1,-1,-3,-2)
846
+ mask=mask.permute(0,1,-1,-3,-2)
847
+ # breakpoint()
848
+ if isinstance(self.vqgan, VQGAN):
849
+ with torch.no_grad():
850
+ img = self.vqgan.encode(
851
+ img, quantize=False, include_embeddings=True)
852
+ # normalize to -1 and 1
853
+ img = ((img - self.vqgan.codebook.embeddings.min()) /
854
+ (self.vqgan.codebook.embeddings.max() -
855
+ self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
856
+
857
+ masked_img = self.vqgan.encode(
858
+ masked_img, quantize=False, include_embeddings=True)
859
+ # normalize to -1 and 1
860
+ masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) /
861
+ (self.vqgan.codebook.embeddings.max() -
862
+ self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
863
+ else:
864
+ print("Hi")
865
+ img = normalize_img(img)
866
+ masked_img = normalize_img(masked_img)
867
+ mask = mask*2.0 - 1.0
868
+ cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:])
869
+ cond = torch.cat((masked_img, cc), dim=1)
870
+
871
+ b, device, img_size, = img.shape[0], img.device, self.image_size
872
+ t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
873
+ # breakpoint()
874
+ return self.p_losses(img, t, cond=cond, *args, **kwargs)
875
+
876
+ # trainer class
877
+
878
+
879
+ CHANNELS_TO_MODE = {
880
+ 1: 'L',
881
+ 3: 'RGB',
882
+ 4: 'RGBA'
883
+ }
884
+
885
+
886
+ def seek_all_images(img, channels=3):
887
+ assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
888
+ mode = CHANNELS_TO_MODE[channels]
889
+
890
+ i = 0
891
+ while True:
892
+ try:
893
+ img.seek(i)
894
+ yield img.convert(mode)
895
+ except EOFError:
896
+ break
897
+ i += 1
898
+
899
+ # tensor of shape (channels, frames, height, width) -> gif
900
+
901
+
902
+ def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
903
+ tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0
904
+ images = map(T.ToPILImage(), tensor.unbind(dim=1))
905
+ first_img, *rest_imgs = images
906
+ first_img.save(path, save_all=True, append_images=rest_imgs,
907
+ duration=duration, loop=loop, optimize=optimize)
908
+ return images
909
+
910
+ # gif -> (channels, frame, height, width) tensor
911
+
912
+
913
+ def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
914
+ img = Image.open(path)
915
+ tensors = tuple(map(transform, seek_all_images(img, channels=channels)))
916
+ return torch.stack(tensors, dim=1)
917
+
918
+
919
+ def identity(t, *args, **kwargs):
920
+ return t
921
+
922
+
923
+ def normalize_img(t):
924
+ return t * 2 - 1
925
+
926
+
927
+ def unnormalize_img(t):
928
+ return (t + 1) * 0.5
929
+
930
+
931
+ def cast_num_frames(t, *, frames):
932
+ f = t.shape[1]
933
+
934
+ if f == frames:
935
+ return t
936
+
937
+ if f > frames:
938
+ return t[:, :frames]
939
+
940
+ return F.pad(t, (0, 0, 0, 0, 0, frames - f))
941
+
942
+
943
+ class Dataset(data.Dataset):
944
+ def __init__(
945
+ self,
946
+ folder,
947
+ image_size,
948
+ channels=3,
949
+ num_frames=16,
950
+ horizontal_flip=False,
951
+ force_num_frames=True,
952
+ exts=['gif']
953
+ ):
954
+ super().__init__()
955
+ self.folder = folder
956
+ self.image_size = image_size
957
+ self.channels = channels
958
+ self.paths = [p for ext in exts for p in Path(
959
+ f'{folder}').glob(f'**/*.{ext}')]
960
+
961
+ self.cast_num_frames_fn = partial(
962
+ cast_num_frames, frames=num_frames) if force_num_frames else identity
963
+
964
+ self.transform = T.Compose([
965
+ T.Resize(image_size),
966
+ T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
967
+ T.CenterCrop(image_size),
968
+ T.ToTensor()
969
+ ])
970
+
971
+ def __len__(self):
972
+ return len(self.paths)
973
+
974
+ def __getitem__(self, index):
975
+ path = self.paths[index]
976
+ tensor = gif_to_tensor(path, self.channels, transform=self.transform)
977
+ return self.cast_num_frames_fn(tensor)
978
+
979
+ # trainer class
980
+
981
+
982
+ class Tester(object):
983
+ def __init__(
984
+ self,
985
+ diffusion_model,
986
+ ):
987
+ super().__init__()
988
+ self.model = diffusion_model
989
+ self.ema_model = copy.deepcopy(self.model)
990
+ self.step=0
991
+ self.image_size = diffusion_model.image_size
992
+
993
+ self.reset_parameters()
994
+
995
+ def reset_parameters(self):
996
+ self.ema_model.load_state_dict(self.model.state_dict())
997
+
998
+
999
+ def load(self, milestone, map_location=None, **kwargs):
1000
+ if milestone == -1:
1001
+ all_milestones = [int(p.stem.split('-')[-1])
1002
+ for p in Path(self.results_folder).glob('**/*.pt')]
1003
+ assert len(
1004
+ all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
1005
+ milestone = max(all_milestones)
1006
+
1007
+ if map_location:
1008
+ data = torch.load(milestone, map_location=map_location)
1009
+ else:
1010
+ data = torch.load(milestone)
1011
+
1012
+ self.step = data['step']
1013
+ self.model.load_state_dict(data['model'], **kwargs)
1014
+ self.ema_model.load_state_dict(data['ema'], **kwargs)
1015
+
1016
+
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/text.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ def exists(val):
8
+ return val is not None
9
+
10
+ # singleton globals
11
+
12
+
13
+ MODEL = None
14
+ TOKENIZER = None
15
+ BERT_MODEL_DIM = 768
16
+
17
+
18
+ def get_tokenizer():
19
+ global TOKENIZER
20
+ if not exists(TOKENIZER):
21
+ TOKENIZER = torch.hub.load(
22
+ 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
23
+ return TOKENIZER
24
+
25
+
26
+ def get_bert():
27
+ global MODEL
28
+ if not exists(MODEL):
29
+ MODEL = torch.hub.load(
30
+ 'huggingface/pytorch-transformers', 'model', 'bert-base-cased')
31
+ if torch.cuda.is_available():
32
+ MODEL = MODEL.cuda()
33
+
34
+ return MODEL
35
+
36
+ # tokenize
37
+
38
+
39
+ def tokenize(texts, add_special_tokens=True):
40
+ if not isinstance(texts, (list, tuple)):
41
+ texts = [texts]
42
+
43
+ tokenizer = get_tokenizer()
44
+
45
+ encoding = tokenizer.batch_encode_plus(
46
+ texts,
47
+ add_special_tokens=add_special_tokens,
48
+ padding=True,
49
+ return_tensors='pt'
50
+ )
51
+
52
+ token_ids = encoding.input_ids
53
+ return token_ids
54
+
55
+ # embedding function
56
+
57
+
58
+ @torch.no_grad()
59
+ def bert_embed(
60
+ token_ids,
61
+ return_cls_repr=False,
62
+ eps=1e-8,
63
+ pad_id=0.
64
+ ):
65
+ model = get_bert()
66
+ mask = token_ids != pad_id
67
+
68
+ if torch.cuda.is_available():
69
+ token_ids = token_ids.cuda()
70
+ mask = mask.cuda()
71
+
72
+ outputs = model(
73
+ input_ids=token_ids,
74
+ attention_mask=mask,
75
+ output_hidden_states=True
76
+ )
77
+
78
+ hidden_state = outputs.hidden_states[-1]
79
+
80
+ if return_cls_repr:
81
+ # return [cls] as representation
82
+ return hidden_state[:, 0]
83
+
84
+ if not exists(mask):
85
+ return hidden_state.mean(dim=1)
86
+
87
+ # mean all tokens excluding [cls], accounting for length
88
+ mask = mask[:, 1:]
89
+ mask = rearrange(mask, 'b n -> b n 1')
90
+
91
+ numer = (hidden_state[:, 1:] * mask).sum(dim=1)
92
+ denom = mask.sum(dim=1)
93
+ masked_mean = numer / (denom + eps)
94
+ return
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from monai.networks.layers.utils import get_act_layer
6
+
7
+
8
+ class SinusoidalPosEmb(nn.Module):
9
+ def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
10
+ super().__init__()
11
+ self.emb_dim = emb_dim
12
+ self.downscale_freq_shift = downscale_freq_shift
13
+ self.max_period = max_period
14
+ self.flip_sin_to_cos = flip_sin_to_cos
15
+
16
+ def forward(self, x):
17
+ device = x.device
18
+ half_dim = self.emb_dim // 2
19
+ emb = math.log(self.max_period) / \
20
+ (half_dim - self.downscale_freq_shift)
21
+ emb = torch.exp(-emb*torch.arange(half_dim, device=device))
22
+ emb = x[:, None] * emb[None, :]
23
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
24
+
25
+ if self.flip_sin_to_cos:
26
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
27
+
28
+ if self.emb_dim % 2 == 1:
29
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
30
+ return emb
31
+
32
+
33
+ class LearnedSinusoidalPosEmb(nn.Module):
34
+ """ following @crowsonkb 's lead with learned sinusoidal pos emb """
35
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
36
+
37
+ def __init__(self, emb_dim):
38
+ super().__init__()
39
+ self.emb_dim = emb_dim
40
+ half_dim = emb_dim // 2
41
+ self.weights = nn.Parameter(torch.randn(half_dim))
42
+
43
+ def forward(self, x):
44
+ x = x[:, None]
45
+ freqs = x * self.weights[None, :] * 2 * math.pi
46
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
47
+ fouriered = torch.cat((x, fouriered), dim=-1)
48
+ if self.emb_dim % 2 == 1:
49
+ fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0))
50
+ return fouriered
51
+
52
+
53
+ class TimeEmbbeding(nn.Module):
54
+ def __init__(
55
+ self,
56
+ emb_dim=64,
57
+ pos_embedder=SinusoidalPosEmb,
58
+ pos_embedder_kwargs={},
59
+ act_name=("SWISH", {}) # Swish = SiLU
60
+ ):
61
+ super().__init__()
62
+ self.emb_dim = emb_dim
63
+ self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
64
+ pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
65
+ self.pos_embedder = pos_embedder(**pos_embedder_kwargs)
66
+
67
+ self.time_emb = nn.Sequential(
68
+ self.pos_embedder,
69
+ nn.Linear(self.pos_emb_dim, self.emb_dim),
70
+ get_act_layer(act_name),
71
+ nn.Linear(self.emb_dim, self.emb_dim)
72
+ )
73
+
74
+ def forward(self, time):
75
+ return self.time_emb(time)
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/unet.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ddpm.time_embedding import TimeEmbbeding
2
+
3
+ import monai.networks.nets as nets
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock
9
+ from monai.networks.layers.utils import get_act_layer
10
+
11
+
12
+ class DownBlock(nn.Module):
13
+ def __init__(
14
+ self,
15
+ spatial_dims,
16
+ in_ch,
17
+ out_ch,
18
+ time_emb_dim,
19
+ cond_emb_dim,
20
+ act_name=("swish", {}),
21
+ **kwargs):
22
+ super(DownBlock, self).__init__()
23
+ self.loca_time_embedder = nn.Sequential(
24
+ get_act_layer(name=act_name),
25
+ nn.Linear(time_emb_dim, in_ch) # in_ch * 2
26
+ )
27
+ if cond_emb_dim is not None:
28
+ self.loca_cond_embedder = nn.Sequential(
29
+ get_act_layer(name=act_name),
30
+ nn.Linear(cond_emb_dim, in_ch),
31
+ )
32
+ self.down_op = UnetBasicBlock(
33
+ spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs)
34
+
35
+ def forward(self, x, time_emb, cond_emb):
36
+ b, c, *_ = x.shape
37
+ sp_dim = x.ndim-2
38
+
39
+ # ------------ Time ----------
40
+ time_emb = self.loca_time_embedder(time_emb)
41
+ time_emb = time_emb.reshape(b, c, *((1,)*sp_dim))
42
+ # scale, shift = time_emb.chunk(2, dim = 1)
43
+
44
+ # ------------ Combine ------------
45
+ # x = x * (scale + 1) + shift
46
+ x = x + time_emb
47
+
48
+ # ----------- Condition ------------
49
+ if cond_emb is not None:
50
+ cond_emb = self.loca_cond_embedder(cond_emb)
51
+ cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim))
52
+ x = x + cond_emb
53
+
54
+ # ----------- Image ---------
55
+ y = self.down_op(x)
56
+ return y
57
+
58
+
59
+ class UpBlock(nn.Module):
60
+ def __init__(
61
+ self,
62
+ spatial_dims,
63
+ skip_ch,
64
+ enc_ch,
65
+ time_emb_dim,
66
+ cond_emb_dim,
67
+ act_name=("swish", {}),
68
+ **kwargs):
69
+ super(UpBlock, self).__init__()
70
+ self.up_op = UnetUpBlock(spatial_dims, enc_ch,
71
+ skip_ch, act_name=act_name, **kwargs)
72
+ self.loca_time_embedder = nn.Sequential(
73
+ get_act_layer(name=act_name),
74
+ nn.Linear(time_emb_dim, skip_ch * 2),
75
+ )
76
+ if cond_emb_dim is not None:
77
+ self.loca_cond_embedder = nn.Sequential(
78
+ get_act_layer(name=act_name),
79
+ nn.Linear(cond_emb_dim, skip_ch * 2),
80
+ )
81
+
82
+ def forward(self, x_skip, x_enc, time_emb, cond_emb):
83
+ b, c, *_ = x_enc.shape
84
+ sp_dim = x_enc.ndim-2
85
+
86
+ # ----------- Time --------------
87
+ time_emb = self.loca_time_embedder(time_emb)
88
+ time_emb = time_emb.reshape(b, c, *((1,)*sp_dim))
89
+ # scale, shift = time_emb.chunk(2, dim = 1)
90
+
91
+ # -------- Combine -------------
92
+ # y = x * (scale + 1) + shift
93
+ x_enc = x_enc + time_emb
94
+
95
+ # ----------- Condition ------------
96
+ if cond_emb is not None:
97
+ cond_emb = self.loca_cond_embedder(cond_emb)
98
+ cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim))
99
+ x_enc = x_enc + cond_emb
100
+
101
+ # ----------- Image -------------
102
+ y = self.up_op(x_enc, x_skip)
103
+
104
+ # -------- Combine -------------
105
+ # y = y * (scale + 1) + shift
106
+
107
+ return y
108
+
109
+
110
+ class UNet(nn.Module):
111
+
112
+ def __init__(self,
113
+ in_ch=1,
114
+ out_ch=1,
115
+ spatial_dims=3,
116
+ hid_chs=[32, 64, 128, 256, 512],
117
+ kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3],
118
+ strides=[1, (1, 2, 2), (1, 2, 2), 2, 2],
119
+ upsample_kernel_sizes=None,
120
+ act_name=("SWISH", {}),
121
+ norm_name=("INSTANCE", {"affine": True}),
122
+ time_embedder=TimeEmbbeding,
123
+ time_embedder_kwargs={},
124
+ cond_embedder=None,
125
+ cond_embedder_kwargs={},
126
+ # True = all but last layer, 0/False=disable, 1=only first layer, ...
127
+ deep_ver_supervision=True,
128
+ estimate_variance=False,
129
+ use_self_conditioning=False,
130
+ **kwargs
131
+ ):
132
+ super().__init__()
133
+ if upsample_kernel_sizes is None:
134
+ upsample_kernel_sizes = strides[1:]
135
+
136
+ # ------------- Time-Embedder-----------
137
+ self.time_embedder = time_embedder(**time_embedder_kwargs)
138
+
139
+ # ------------- Condition-Embedder-----------
140
+ if cond_embedder is not None:
141
+ self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
142
+ cond_emb_dim = self.cond_embedder.emb_dim
143
+ else:
144
+ self.cond_embedder = None
145
+ cond_emb_dim = None
146
+
147
+ # ----------- In-Convolution ------------
148
+ in_ch = in_ch*2 if use_self_conditioning else in_ch
149
+ self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
150
+ act_name=act_name, norm_name=norm_name, **kwargs)
151
+
152
+ # ----------- Encoder ----------------
153
+ self.encoders = nn.ModuleList([
154
+ DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim,
155
+ cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[
156
+ i], stride=strides[i], act_name=act_name,
157
+ norm_name=norm_name, **kwargs)
158
+ for i in range(1, len(strides))
159
+ ])
160
+
161
+ # ------------ Decoder ----------
162
+ self.decoders = nn.ModuleList([
163
+ UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim,
164
+ cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i +
165
+ 1], stride=strides[i+1], act_name=act_name,
166
+ norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs)
167
+ for i in range(len(strides)-1)
168
+ ])
169
+
170
+ # --------------- Out-Convolution ----------------
171
+ out_ch_hor = out_ch*2 if estimate_variance else out_ch
172
+ self.outc = UnetOutBlock(
173
+ spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
174
+ if isinstance(deep_ver_supervision, bool):
175
+ deep_ver_supervision = len(
176
+ strides)-2 if deep_ver_supervision else 0
177
+ self.outc_ver = nn.ModuleList([
178
+ UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
179
+ for i in range(1, deep_ver_supervision+1)
180
+ ])
181
+
182
+ def forward(self, x_t, t, cond=None, self_cond=None, **kwargs):
183
+ condition = cond
184
+ # x_t [B, C, (D), H, W]
185
+ # t [B,]
186
+
187
+ # -------- In-Convolution --------------
188
+ x = [None for _ in range(len(self.encoders)+1)]
189
+ x_t = torch.cat([x_t, self_cond],
190
+ dim=1) if self_cond is not None else x_t
191
+ x[0] = self.inc(x_t)
192
+
193
+ # -------- Time Embedding (Gloabl) -----------
194
+ time_emb = self.time_embedder(t) # [B, C]
195
+
196
+ # -------- Condition Embedding (Gloabl) -----------
197
+ if (condition is None) or (self.cond_embedder is None):
198
+ cond_emb = None
199
+ else:
200
+ cond_emb = self.cond_embedder(condition) # [B, C]
201
+
202
+ # --------- Encoder --------------
203
+ for i in range(len(self.encoders)):
204
+ x[i+1] = self.encoders[i](x[i], time_emb, cond_emb)
205
+
206
+ # -------- Decoder -----------
207
+ for i in range(len(self.decoders), 0, -1):
208
+ x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb)
209
+
210
+ # ---------Out-Convolution ------------
211
+ y_hor = self.outc(x[0])
212
+ y_ver = [outc_ver_i(x[i+1])
213
+ for i, outc_ver_i in enumerate(self.outc_ver)]
214
+
215
+ return y_hor # , y_ver
216
+
217
+ def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs):
218
+ return self.forward(*args, **kwargs)
219
+
220
+
221
+ if __name__ == '__main__':
222
+ model = UNet(in_ch=3)
223
+ input = torch.randn((1, 3, 16, 128, 128))
224
+ time = torch.randn((1,))
225
+ out_hor, out_ver = model(input, time)
226
+ print(out_hor[0].shape)
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/util.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ # from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ if c != 1:
58
+ steps_out = ddim_timesteps + 1
59
+ else:
60
+ steps_out = ddim_timesteps
61
+ if verbose:
62
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
63
+ return steps_out
64
+
65
+
66
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
67
+ # select alphas for computing the variance schedule
68
+
69
+ alphas = alphacums[ddim_timesteps]
70
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
71
+
72
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
73
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
74
+ if verbose:
75
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
76
+ print(f'For the chosen value of eta, which is {eta}, '
77
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
78
+ return sigmas, alphas, alphas_prev
79
+
80
+
81
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
82
+ """
83
+ Create a beta schedule that discretizes the given alpha_t_bar function,
84
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
85
+ :param num_diffusion_timesteps: the number of betas to produce.
86
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
87
+ produces the cumulative product of (1-beta) up to that
88
+ part of the diffusion process.
89
+ :param max_beta: the maximum beta to use; use values lower than 1 to
90
+ prevent singularities.
91
+ """
92
+ betas = []
93
+ for i in range(num_diffusion_timesteps):
94
+ t1 = i / num_diffusion_timesteps
95
+ t2 = (i + 1) / num_diffusion_timesteps
96
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
97
+ return np.array(betas)
98
+
99
+
100
+ def extract_into_tensor(a, t, x_shape):
101
+ b, *_ = t.shape
102
+ out = a.gather(-1, t)
103
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
104
+
105
+
106
+ def checkpoint(func, inputs, params, flag):
107
+ """
108
+ Evaluate a function without caching intermediate activations, allowing for
109
+ reduced memory at the expense of extra compute in the backward pass.
110
+ :param func: the function to evaluate.
111
+ :param inputs: the argument sequence to pass to `func`.
112
+ :param params: a sequence of parameters `func` depends on but does not
113
+ explicitly take as arguments.
114
+ :param flag: if False, disable gradient checkpointing.
115
+ """
116
+ if flag:
117
+ args = tuple(inputs) + tuple(params)
118
+ return CheckpointFunction.apply(func, len(inputs), *args)
119
+ else:
120
+ return func(*inputs)
121
+
122
+
123
+ class CheckpointFunction(torch.autograd.Function):
124
+ @staticmethod
125
+ def forward(ctx, run_function, length, *args):
126
+ ctx.run_function = run_function
127
+ ctx.input_tensors = list(args[:length])
128
+ ctx.input_params = list(args[length:])
129
+
130
+ with torch.no_grad():
131
+ output_tensors = ctx.run_function(*ctx.input_tensors)
132
+ return output_tensors
133
+
134
+ @staticmethod
135
+ def backward(ctx, *output_grads):
136
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
137
+ with torch.enable_grad():
138
+ # Fixes a bug where the first op in run_function modifies the
139
+ # Tensor storage in place, which is not allowed for detach()'d
140
+ # Tensors.
141
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
142
+ output_tensors = ctx.run_function(*shallow_copies)
143
+ input_grads = torch.autograd.grad(
144
+ output_tensors,
145
+ ctx.input_tensors + ctx.input_params,
146
+ output_grads,
147
+ allow_unused=True,
148
+ )
149
+ del ctx.input_tensors
150
+ del ctx.input_params
151
+ del output_tensors
152
+ return (None, None) + input_grads
153
+
154
+
155
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
156
+ """
157
+ Create sinusoidal timestep embeddings.
158
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
159
+ These may be fractional.
160
+ :param dim: the dimension of the output.
161
+ :param max_period: controls the minimum frequency of the embeddings.
162
+ :return: an [N x dim] Tensor of positional embeddings.
163
+ """
164
+ if not repeat_only:
165
+ half = dim // 2
166
+ freqs = torch.exp(
167
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
168
+ ).to(device=timesteps.device)
169
+ args = timesteps[:, None].float() * freqs[None]
170
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
171
+ if dim % 2:
172
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
173
+ else:
174
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
175
+ return embedding
176
+
177
+
178
+ def zero_module(module):
179
+ """
180
+ Zero out the parameters of a module and return it.
181
+ """
182
+ for p in module.parameters():
183
+ p.detach().zero_()
184
+ return module
185
+
186
+
187
+ def scale_module(module, scale):
188
+ """
189
+ Scale the parameters of a module and return it.
190
+ """
191
+ for p in module.parameters():
192
+ p.detach().mul_(scale)
193
+ return module
194
+
195
+
196
+ def mean_flat(tensor):
197
+ """
198
+ Take the mean over all non-batch dimensions.
199
+ """
200
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
201
+
202
+
203
+ def normalization(channels):
204
+ """
205
+ Make a standard normalization layer.
206
+ :param channels: number of input channels.
207
+ :return: an nn.Module for normalization.
208
+ """
209
+ return GroupNorm32(32, channels)
210
+
211
+
212
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
213
+ class SiLU(nn.Module):
214
+ def forward(self, x):
215
+ return x * torch.sigmoid(x)
216
+
217
+
218
+ class GroupNorm32(nn.GroupNorm):
219
+ def forward(self, x):
220
+ return super().forward(x.float()).type(x.dtype)
221
+
222
+ def conv_nd(dims, *args, **kwargs):
223
+ """
224
+ Create a 1D, 2D, or 3D convolution module.
225
+ """
226
+ if dims == 1:
227
+ return nn.Conv1d(*args, **kwargs)
228
+ elif dims == 2:
229
+ return nn.Conv2d(*args, **kwargs)
230
+ elif dims == 3:
231
+ return nn.Conv3d(*args, **kwargs)
232
+ raise ValueError(f"unsupported dimensions: {dims}")
233
+
234
+
235
+ def linear(*args, **kwargs):
236
+ """
237
+ Create a linear module.
238
+ """
239
+ return nn.Linear(*args, **kwargs)
240
+
241
+
242
+ def avg_pool_nd(dims, *args, **kwargs):
243
+ """
244
+ Create a 1D, 2D, or 3D average pooling module.
245
+ """
246
+ if dims == 1:
247
+ return nn.AvgPool1d(*args, **kwargs)
248
+ elif dims == 2:
249
+ return nn.AvgPool2d(*args, **kwargs)
250
+ elif dims == 3:
251
+ return nn.AvgPool3d(*args, **kwargs)
252
+ raise ValueError(f"unsupported dimensions: {dims}")
253
+
254
+
255
+ class HybridConditioner(nn.Module):
256
+
257
+ def __init__(self, c_concat_config, c_crossattn_config):
258
+ super().__init__()
259
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
260
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
261
+
262
+ def forward(self, c_concat, c_crossattn):
263
+ c_concat = self.concat_conditioner(c_concat)
264
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
265
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
266
+
267
+
268
+ def noise_like(shape, device, repeat=False):
269
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
270
+ noise = lambda: torch.randn(shape, device=device)
271
+ return repeat_noise() if repeat else noise()
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.89 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .vqgan import VQGAN
2
+ from .codebook import Codebook
3
+ from .lpips import LPIPS
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (322 Bytes). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc ADDED
Binary file (3.45 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc ADDED
Binary file (6.82 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc ADDED
Binary file (16.6 kB). View file
 
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.distributed as dist
10
+
11
+ from ..utils import shift_dim
12
+
13
+
14
+ class Codebook(nn.Module):
15
+ def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
16
+ super().__init__()
17
+ self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
18
+ self.register_buffer('N', torch.zeros(n_codes))
19
+ self.register_buffer('z_avg', self.embeddings.data.clone())
20
+
21
+ self.n_codes = n_codes
22
+ self.embedding_dim = embedding_dim
23
+ self._need_init = True
24
+ self.no_random_restart = no_random_restart
25
+ self.restart_thres = restart_thres
26
+
27
+ def _tile(self, x):
28
+ d, ew = x.shape
29
+ if d < self.n_codes:
30
+ n_repeats = (self.n_codes + d - 1) // d
31
+ std = 0.01 / np.sqrt(ew)
32
+ x = x.repeat(n_repeats, 1)
33
+ x = x + torch.randn_like(x) * std
34
+ return x
35
+
36
+ def _init_embeddings(self, z):
37
+ # z: [b, c, t, h, w]
38
+ self._need_init = False
39
+ breakpoint()
40
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32]
41
+ y = self._tile(flat_inputs) # [65536, 8]
42
+
43
+ d = y.shape[0]
44
+ _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
45
+ if dist.is_initialized():
46
+ dist.broadcast(_k_rand, 0)
47
+ self.embeddings.data.copy_(_k_rand)
48
+ self.z_avg.data.copy_(_k_rand)
49
+ self.N.data.copy_(torch.ones(self.n_codes))
50
+
51
+ def forward(self, z):
52
+ # z: [b, c, t, h, w]
53
+ if self._need_init and self.training:
54
+ self._init_embeddings(z)
55
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8]
56
+ distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
57
+ - 2 * flat_inputs @ self.embeddings.t() \
58
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8]
59
+
60
+ encoding_indices = torch.argmin(distances, dim=1) # [65536]
61
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
62
+ flat_inputs) # [bthw, ncode] [65536, 16384]
63
+ encoding_indices = encoding_indices.view(
64
+ z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32]
65
+
66
+ embeddings = F.embedding(
67
+ encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8]
68
+ embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32]
69
+
70
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
71
+
72
+ # EMA codebook update
73
+ if self.training:
74
+ n_total = encode_onehot.sum(dim=0) # [16384]
75
+ encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384]
76
+ if dist.is_initialized():
77
+ dist.all_reduce(n_total)
78
+ dist.all_reduce(encode_sum)
79
+
80
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
81
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
82
+
83
+ n = self.N.sum()
84
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
85
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
86
+ self.embeddings.data.copy_(encode_normalized)
87
+
88
+ y = self._tile(flat_inputs)
89
+ _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
90
+ if dist.is_initialized():
91
+ dist.broadcast(_k_rand, 0)
92
+
93
+ if not self.no_random_restart:
94
+ usage = (self.N.view(self.n_codes, 1)
95
+ >= self.restart_thres).float()
96
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
97
+
98
+ embeddings_st = (embeddings - z).detach() + z
99
+
100
+ avg_probs = torch.mean(encode_onehot, dim=0)
101
+ perplexity = torch.exp(-torch.sum(avg_probs *
102
+ torch.log(avg_probs + 1e-10)))
103
+
104
+ return dict(embeddings=embeddings_st, encodings=encoding_indices,
105
+ commitment_loss=commitment_loss, perplexity=perplexity)
106
+
107
+ def dictionary_lookup(self, encodings):
108
+ embeddings = F.embedding(encodings, self.embeddings)
109
+ return embeddings
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+
3
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
4
+
5
+
6
+ from collections import namedtuple
7
+ from torchvision import models
8
+ import torch.nn as nn
9
+ import torch
10
+ from tqdm import tqdm
11
+ import requests
12
+ import os
13
+ import hashlib
14
+ URL_MAP = {
15
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
16
+ }
17
+
18
+ CKPT_MAP = {
19
+ "vgg_lpips": "vgg.pth"
20
+ }
21
+
22
+ MD5_MAP = {
23
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
24
+ }
25
+
26
+
27
+ def download(url, local_path, chunk_size=1024):
28
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
29
+ with requests.get(url, stream=True) as r:
30
+ total_size = int(r.headers.get("content-length", 0))
31
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
32
+ with open(local_path, "wb") as f:
33
+ for data in r.iter_content(chunk_size=chunk_size):
34
+ if data:
35
+ f.write(data)
36
+ pbar.update(chunk_size)
37
+
38
+
39
+ def md5_hash(path):
40
+ with open(path, "rb") as f:
41
+ content = f.read()
42
+ return hashlib.md5(content).hexdigest()
43
+
44
+
45
+ def get_ckpt_path(name, root, check=False):
46
+ assert name in URL_MAP
47
+ path = os.path.join(root, CKPT_MAP[name])
48
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
49
+ print("Downloading {} model from {} to {}".format(
50
+ name, URL_MAP[name], path))
51
+ download(URL_MAP[name], path)
52
+ md5 = md5_hash(path)
53
+ assert md5 == MD5_MAP[name], md5
54
+ return path
55
+
56
+
57
+ class LPIPS(nn.Module):
58
+ # Learned perceptual metric
59
+ def __init__(self, use_dropout=True):
60
+ super().__init__()
61
+ self.scaling_layer = ScalingLayer()
62
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
63
+ self.net = vgg16(pretrained=True, requires_grad=False)
64
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
65
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
66
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
67
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
68
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
69
+ # self.load_from_pretrained()
70
+ for param in self.parameters():
71
+ param.requires_grad = False
72
+
73
+ def load_from_pretrained(self, name="vgg_lpips"):
74
+ ckpt = get_ckpt_path(name, os.path.join(
75
+ os.path.dirname(os.path.abspath(__file__)), "cache"))
76
+ self.load_state_dict(torch.load(
77
+ ckpt, map_location=torch.device("cpu")), strict=False)
78
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
79
+
80
+ @classmethod
81
+ def from_pretrained(cls, name="vgg_lpips"):
82
+ if name is not "vgg_lpips":
83
+ raise NotImplementedError
84
+ model = cls()
85
+ ckpt = get_ckpt_path(name, os.path.join(
86
+ os.path.dirname(os.path.abspath(__file__)), "cache"))
87
+ model.load_state_dict(torch.load(
88
+ ckpt, map_location=torch.device("cpu")), strict=False)
89
+ return model
90
+
91
+ def forward(self, input, target):
92
+ in0_input, in1_input = (self.scaling_layer(
93
+ input), self.scaling_layer(target))
94
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
95
+ feats0, feats1, diffs = {}, {}, {}
96
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
97
+ for kk in range(len(self.chns)):
98
+ feats0[kk], feats1[kk] = normalize_tensor(
99
+ outs0[kk]), normalize_tensor(outs1[kk])
100
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
101
+
102
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
103
+ for kk in range(len(self.chns))]
104
+ val = res[0]
105
+ for l in range(1, len(self.chns)):
106
+ val += res[l]
107
+ return val
108
+
109
+
110
+ class ScalingLayer(nn.Module):
111
+ def __init__(self):
112
+ super(ScalingLayer, self).__init__()
113
+ self.register_buffer('shift', torch.Tensor(
114
+ [-.030, -.088, -.188])[None, :, None, None])
115
+ self.register_buffer('scale', torch.Tensor(
116
+ [.458, .448, .450])[None, :, None, None])
117
+
118
+ def forward(self, inp):
119
+ return (inp - self.shift) / self.scale
120
+
121
+
122
+ class NetLinLayer(nn.Module):
123
+ """ A single linear layer which does a 1x1 conv """
124
+
125
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
126
+ super(NetLinLayer, self).__init__()
127
+ layers = [nn.Dropout(), ] if (use_dropout) else []
128
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1,
129
+ padding=0, bias=False), ]
130
+ self.model = nn.Sequential(*layers)
131
+
132
+
133
+ class vgg16(torch.nn.Module):
134
+ def __init__(self, requires_grad=False, pretrained=True):
135
+ super(vgg16, self).__init__()
136
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
137
+ self.slice1 = torch.nn.Sequential()
138
+ self.slice2 = torch.nn.Sequential()
139
+ self.slice3 = torch.nn.Sequential()
140
+ self.slice4 = torch.nn.Sequential()
141
+ self.slice5 = torch.nn.Sequential()
142
+ self.N_slices = 5
143
+ for x in range(4):
144
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
145
+ for x in range(4, 9):
146
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
147
+ for x in range(9, 16):
148
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
149
+ for x in range(16, 23):
150
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
151
+ for x in range(23, 30):
152
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
153
+ if not requires_grad:
154
+ for param in self.parameters():
155
+ param.requires_grad = False
156
+
157
+ def forward(self, X):
158
+ h = self.slice1(X)
159
+ h_relu1_2 = h
160
+ h = self.slice2(h)
161
+ h_relu2_2 = h
162
+ h = self.slice3(h)
163
+ h_relu3_3 = h
164
+ h = self.slice4(h)
165
+ h_relu4_3 = h
166
+ h = self.slice5(h)
167
+ h_relu5_3 = h
168
+ vgg_outputs = namedtuple(
169
+ "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
170
+ out = vgg_outputs(h_relu1_2, h_relu2_2,
171
+ h_relu3_3, h_relu4_3, h_relu5_3)
172
+ return out
173
+
174
+
175
+ def normalize_tensor(x, eps=1e-10):
176
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
177
+ return x/(norm_factor+eps)
178
+
179
+
180
+ def spatial_average(x, keepdim=True):
181
+ return x.mean([2, 3], keepdim=keepdim)
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import math
5
+ import argparse
6
+ import numpy as np
7
+ import pickle as pkl
8
+
9
+ import pytorch_lightning as pl
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as dist
14
+
15
+ from ..utils import shift_dim, adopt_weight, comp_getattr
16
+ from .lpips import LPIPS
17
+ from .codebook import Codebook
18
+
19
+
20
+ def silu(x):
21
+ return x*torch.sigmoid(x)
22
+
23
+
24
+ class SiLU(nn.Module):
25
+ def __init__(self):
26
+ super(SiLU, self).__init__()
27
+
28
+ def forward(self, x):
29
+ return silu(x)
30
+
31
+
32
+ def hinge_d_loss(logits_real, logits_fake):
33
+ loss_real = torch.mean(F.relu(1. - logits_real))
34
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
35
+ d_loss = 0.5 * (loss_real + loss_fake)
36
+ return d_loss
37
+
38
+
39
+ def vanilla_d_loss(logits_real, logits_fake):
40
+ d_loss = 0.5 * (
41
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
42
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
43
+ return d_loss
44
+
45
+
46
+ class VQGAN(pl.LightningModule):
47
+ def __init__(self, cfg):
48
+ super().__init__()
49
+ self.cfg = cfg
50
+ self.embedding_dim = cfg.model.embedding_dim # 8
51
+ self.n_codes = cfg.model.n_codes # 16384
52
+
53
+ self.encoder = Encoder(cfg.model.n_hiddens, # 16
54
+ cfg.model.downsample, # [2, 2, 2]
55
+ cfg.dataset.image_channels, # 1
56
+ cfg.model.norm_type, # group
57
+ cfg.model.padding_type, # replicate
58
+ cfg.model.num_groups, # 32
59
+ )
60
+ self.decoder = Decoder(
61
+ cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups)
62
+ self.enc_out_ch = self.encoder.out_channels
63
+ self.pre_vq_conv = SamePadConv3d(
64
+ self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type)
65
+ self.post_vq_conv = SamePadConv3d(
66
+ cfg.model.embedding_dim, self.enc_out_ch, 1)
67
+
68
+ self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim,
69
+ no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres)
70
+
71
+ self.gan_feat_weight = cfg.model.gan_feat_weight
72
+ # TODO: Changed batchnorm from sync to normal
73
+ self.image_discriminator = NLayerDiscriminator(
74
+ cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d)
75
+ self.video_discriminator = NLayerDiscriminator3D(
76
+ cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d)
77
+
78
+ if cfg.model.disc_loss_type == 'vanilla':
79
+ self.disc_loss = vanilla_d_loss
80
+ elif cfg.model.disc_loss_type == 'hinge':
81
+ self.disc_loss = hinge_d_loss
82
+
83
+ self.perceptual_model = LPIPS().eval()
84
+
85
+ self.image_gan_weight = cfg.model.image_gan_weight
86
+ self.video_gan_weight = cfg.model.video_gan_weight
87
+
88
+ self.perceptual_weight = cfg.model.perceptual_weight
89
+
90
+ self.l1_weight = cfg.model.l1_weight
91
+ self.save_hyperparameters()
92
+
93
+ def encode(self, x, include_embeddings=False, quantize=True):
94
+ h = self.pre_vq_conv(self.encoder(x))
95
+ if quantize:
96
+ vq_output = self.codebook(h)
97
+ if include_embeddings:
98
+ return vq_output['embeddings'], vq_output['encodings']
99
+ else:
100
+ return vq_output['encodings']
101
+ return h
102
+
103
+ def decode(self, latent, quantize=False):
104
+ if quantize:
105
+ vq_output = self.codebook(latent)
106
+ latent = vq_output['encodings']
107
+ h = F.embedding(latent, self.codebook.embeddings)
108
+ h = self.post_vq_conv(shift_dim(h, -1, 1))
109
+ return self.decoder(h)
110
+
111
+ def forward(self, x, optimizer_idx=None, log_image=False):
112
+ B, C, T, H, W = x.shape
113
+ z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32]
114
+ vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 'perplexity']
115
+ x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32]
116
+
117
+ recon_loss = F.l1_loss(x_recon, x) * self.l1_weight
118
+
119
+ # Selects one random 2D image from each 3D Image
120
+ frame_idx = torch.randint(0, T, [B]).cuda()
121
+ frame_idx_selected = frame_idx.reshape(-1,
122
+ 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64]
123
+ frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64]
124
+ frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64]
125
+
126
+ if log_image:
127
+ return frames, frames_recon, x, x_recon
128
+
129
+ if optimizer_idx == 0:
130
+ # Autoencoder - train the "generator"
131
+
132
+ # Perceptual loss
133
+ perceptual_loss = 0
134
+ if self.perceptual_weight > 0:
135
+ perceptual_loss = self.perceptual_model(
136
+ frames, frames_recon).mean() * self.perceptual_weight
137
+
138
+ # Discriminator loss (turned on after a certain epoch)
139
+ logits_image_fake, pred_image_fake = self.image_discriminator(
140
+ frames_recon)
141
+ logits_video_fake, pred_video_fake = self.video_discriminator(
142
+ x_recon)
143
+ g_image_loss = -torch.mean(logits_image_fake)
144
+ g_video_loss = -torch.mean(logits_video_fake)
145
+ g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss
146
+ disc_factor = adopt_weight(
147
+ self.global_step, threshold=self.cfg.model.discriminator_iter_start)
148
+ aeloss = disc_factor * g_loss
149
+
150
+ # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator
151
+ image_gan_feat_loss = 0
152
+ video_gan_feat_loss = 0
153
+ feat_weights = 4.0 / (3 + 1)
154
+ if self.image_gan_weight > 0:
155
+ logits_image_real, pred_image_real = self.image_discriminator(
156
+ frames)
157
+ for i in range(len(pred_image_fake)-1):
158
+ image_gan_feat_loss += feat_weights * \
159
+ F.l1_loss(pred_image_fake[i], pred_image_real[i].detach(
160
+ )) * (self.image_gan_weight > 0)
161
+ if self.video_gan_weight > 0:
162
+ logits_video_real, pred_video_real = self.video_discriminator(
163
+ x)
164
+ for i in range(len(pred_video_fake)-1):
165
+ video_gan_feat_loss += feat_weights * \
166
+ F.l1_loss(pred_video_fake[i], pred_video_real[i].detach(
167
+ )) * (self.video_gan_weight > 0)
168
+ gan_feat_loss = disc_factor * self.gan_feat_weight * \
169
+ (image_gan_feat_loss + video_gan_feat_loss)
170
+
171
+ self.log("train/g_image_loss", g_image_loss,
172
+ logger=True, on_step=True, on_epoch=True)
173
+ self.log("train/g_video_loss", g_video_loss,
174
+ logger=True, on_step=True, on_epoch=True)
175
+ self.log("train/image_gan_feat_loss", image_gan_feat_loss,
176
+ logger=True, on_step=True, on_epoch=True)
177
+ self.log("train/video_gan_feat_loss", video_gan_feat_loss,
178
+ logger=True, on_step=True, on_epoch=True)
179
+ self.log("train/perceptual_loss", perceptual_loss,
180
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
181
+ self.log("train/recon_loss", recon_loss, prog_bar=True,
182
+ logger=True, on_step=True, on_epoch=True)
183
+ self.log("train/aeloss", aeloss, prog_bar=True,
184
+ logger=True, on_step=True, on_epoch=True)
185
+ self.log("train/commitment_loss", vq_output['commitment_loss'],
186
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
187
+ self.log('train/perplexity', vq_output['perplexity'],
188
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
189
+ return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss
190
+
191
+ if optimizer_idx == 1:
192
+ # Train discriminator
193
+ logits_image_real, _ = self.image_discriminator(frames.detach())
194
+ logits_video_real, _ = self.video_discriminator(x.detach())
195
+
196
+ logits_image_fake, _ = self.image_discriminator(
197
+ frames_recon.detach())
198
+ logits_video_fake, _ = self.video_discriminator(x_recon.detach())
199
+
200
+ d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
201
+ d_video_loss = self.disc_loss(logits_video_real, logits_video_fake)
202
+ disc_factor = adopt_weight(
203
+ self.global_step, threshold=self.cfg.model.discriminator_iter_start)
204
+ discloss = disc_factor * \
205
+ (self.image_gan_weight*d_image_loss +
206
+ self.video_gan_weight*d_video_loss)
207
+
208
+ self.log("train/logits_image_real", logits_image_real.mean().detach(),
209
+ logger=True, on_step=True, on_epoch=True)
210
+ self.log("train/logits_image_fake", logits_image_fake.mean().detach(),
211
+ logger=True, on_step=True, on_epoch=True)
212
+ self.log("train/logits_video_real", logits_video_real.mean().detach(),
213
+ logger=True, on_step=True, on_epoch=True)
214
+ self.log("train/logits_video_fake", logits_video_fake.mean().detach(),
215
+ logger=True, on_step=True, on_epoch=True)
216
+ self.log("train/d_image_loss", d_image_loss,
217
+ logger=True, on_step=True, on_epoch=True)
218
+ self.log("train/d_video_loss", d_video_loss,
219
+ logger=True, on_step=True, on_epoch=True)
220
+ self.log("train/discloss", discloss, prog_bar=True,
221
+ logger=True, on_step=True, on_epoch=True)
222
+ return discloss
223
+
224
+ perceptual_loss = self.perceptual_model(
225
+ frames, frames_recon) * self.perceptual_weight
226
+ return recon_loss, x_recon, vq_output, perceptual_loss
227
+
228
+ def training_step(self, batch, batch_idx, optimizer_idx):
229
+ x = batch['image']
230
+ if optimizer_idx == 0:
231
+ recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(
232
+ x, optimizer_idx)
233
+ commitment_loss = vq_output['commitment_loss']
234
+ loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss
235
+ if optimizer_idx == 1:
236
+ discloss = self.forward(x, optimizer_idx)
237
+ loss = discloss
238
+ return loss
239
+
240
+ def validation_step(self, batch, batch_idx):
241
+ x = batch['image'] # TODO: batch['stft']
242
+ recon_loss, _, vq_output, perceptual_loss = self.forward(x)
243
+ self.log('val/recon_loss', recon_loss, prog_bar=True)
244
+ self.log('val/perceptual_loss', perceptual_loss, prog_bar=True)
245
+ self.log('val/perplexity', vq_output['perplexity'], prog_bar=True)
246
+ self.log('val/commitment_loss',
247
+ vq_output['commitment_loss'], prog_bar=True)
248
+
249
+ def configure_optimizers(self):
250
+ lr = self.cfg.model.lr
251
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
252
+ list(self.decoder.parameters()) +
253
+ list(self.pre_vq_conv.parameters()) +
254
+ list(self.post_vq_conv.parameters()) +
255
+ list(self.codebook.parameters()),
256
+ lr=lr, betas=(0.5, 0.9))
257
+ opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) +
258
+ list(self.video_discriminator.parameters()),
259
+ lr=lr, betas=(0.5, 0.9))
260
+ return [opt_ae, opt_disc], []
261
+
262
+ def log_images(self, batch, **kwargs):
263
+ log = dict()
264
+ x = batch['image']
265
+ x = x.to(self.device)
266
+ frames, frames_rec, _, _ = self(x, log_image=True)
267
+ log["inputs"] = frames
268
+ log["reconstructions"] = frames_rec
269
+ #log['mean_org'] = batch['mean_org']
270
+ #log['std_org'] = batch['std_org']
271
+ return log
272
+
273
+ def log_videos(self, batch, **kwargs):
274
+ log = dict()
275
+ x = batch['image']
276
+ _, _, x, x_rec = self(x, log_image=True)
277
+ log["inputs"] = x
278
+ log["reconstructions"] = x_rec
279
+ #log['mean_org'] = batch['mean_org']
280
+ #log['std_org'] = batch['std_org']
281
+ return log
282
+
283
+
284
+ def Normalize(in_channels, norm_type='group', num_groups=32):
285
+ assert norm_type in ['group', 'batch']
286
+ if norm_type == 'group':
287
+ # TODO Changed num_groups from 32 to 8
288
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
289
+ elif norm_type == 'batch':
290
+ return torch.nn.SyncBatchNorm(in_channels)
291
+
292
+
293
+ class Encoder(nn.Module):
294
+ def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32):
295
+ super().__init__()
296
+ n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
297
+ self.conv_blocks = nn.ModuleList()
298
+ max_ds = n_times_downsample.max()
299
+
300
+ self.conv_first = SamePadConv3d(
301
+ image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)
302
+
303
+ for i in range(max_ds):
304
+ block = nn.Module()
305
+ in_channels = n_hiddens * 2**i
306
+ out_channels = n_hiddens * 2**(i+1)
307
+ stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
308
+ block.down = SamePadConv3d(
309
+ in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
310
+ block.res = ResBlock(
311
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
312
+ self.conv_blocks.append(block)
313
+ n_times_downsample -= 1
314
+
315
+ self.final_block = nn.Sequential(
316
+ Normalize(out_channels, norm_type, num_groups=num_groups),
317
+ SiLU()
318
+ )
319
+
320
+ self.out_channels = out_channels
321
+
322
+ def forward(self, x):
323
+ h = self.conv_first(x)
324
+ for block in self.conv_blocks:
325
+ h = block.down(h)
326
+ h = block.res(h)
327
+ h = self.final_block(h)
328
+ return h
329
+
330
+
331
+ class Decoder(nn.Module):
332
+ def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32):
333
+ super().__init__()
334
+
335
+ n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
336
+ max_us = n_times_upsample.max()
337
+
338
+ in_channels = n_hiddens*2**max_us
339
+ self.final_block = nn.Sequential(
340
+ Normalize(in_channels, norm_type, num_groups=num_groups),
341
+ SiLU()
342
+ )
343
+
344
+ self.conv_blocks = nn.ModuleList()
345
+ for i in range(max_us):
346
+ block = nn.Module()
347
+ in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
348
+ out_channels = n_hiddens*2**(max_us-i)
349
+ us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
350
+ block.up = SamePadConvTranspose3d(
351
+ in_channels, out_channels, 4, stride=us)
352
+ block.res1 = ResBlock(
353
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
354
+ block.res2 = ResBlock(
355
+ out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
356
+ self.conv_blocks.append(block)
357
+ n_times_upsample -= 1
358
+
359
+ self.conv_last = SamePadConv3d(
360
+ out_channels, image_channel, kernel_size=3)
361
+
362
+ def forward(self, x):
363
+ h = self.final_block(x)
364
+ for i, block in enumerate(self.conv_blocks):
365
+ h = block.up(h)
366
+ h = block.res1(h)
367
+ h = block.res2(h)
368
+ h = self.conv_last(h)
369
+ return h
370
+
371
+
372
+ class ResBlock(nn.Module):
373
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
374
+ super().__init__()
375
+ self.in_channels = in_channels
376
+ out_channels = in_channels if out_channels is None else out_channels
377
+ self.out_channels = out_channels
378
+ self.use_conv_shortcut = conv_shortcut
379
+
380
+ self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
381
+ self.conv1 = SamePadConv3d(
382
+ in_channels, out_channels, kernel_size=3, padding_type=padding_type)
383
+ self.dropout = torch.nn.Dropout(dropout)
384
+ self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups)
385
+ self.conv2 = SamePadConv3d(
386
+ out_channels, out_channels, kernel_size=3, padding_type=padding_type)
387
+ if self.in_channels != self.out_channels:
388
+ self.conv_shortcut = SamePadConv3d(
389
+ in_channels, out_channels, kernel_size=3, padding_type=padding_type)
390
+
391
+ def forward(self, x):
392
+ h = x
393
+ h = self.norm1(h)
394
+ h = silu(h)
395
+ h = self.conv1(h)
396
+ h = self.norm2(h)
397
+ h = silu(h)
398
+ h = self.conv2(h)
399
+
400
+ if self.in_channels != self.out_channels:
401
+ x = self.conv_shortcut(x)
402
+
403
+ return x+h
404
+
405
+
406
+ # Does not support dilation
407
+ class SamePadConv3d(nn.Module):
408
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
409
+ super().__init__()
410
+ if isinstance(kernel_size, int):
411
+ kernel_size = (kernel_size,) * 3
412
+ if isinstance(stride, int):
413
+ stride = (stride,) * 3
414
+
415
+ # assumes that the input shape is divisible by stride
416
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
417
+ pad_input = []
418
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
419
+ pad_input.append((p // 2 + p % 2, p // 2))
420
+ pad_input = sum(pad_input, tuple())
421
+ self.pad_input = pad_input
422
+ self.padding_type = padding_type
423
+
424
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
425
+ stride=stride, padding=0, bias=bias)
426
+
427
+ def forward(self, x):
428
+ return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
429
+
430
+
431
+ class SamePadConvTranspose3d(nn.Module):
432
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
433
+ super().__init__()
434
+ if isinstance(kernel_size, int):
435
+ kernel_size = (kernel_size,) * 3
436
+ if isinstance(stride, int):
437
+ stride = (stride,) * 3
438
+
439
+ total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
440
+ pad_input = []
441
+ for p in total_pad[::-1]: # reverse since F.pad starts from last dim
442
+ pad_input.append((p // 2 + p % 2, p // 2))
443
+ pad_input = sum(pad_input, tuple())
444
+ self.pad_input = pad_input
445
+ self.padding_type = padding_type
446
+
447
+ self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
448
+ stride=stride, bias=bias,
449
+ padding=tuple([k - 1 for k in kernel_size]))
450
+
451
+ def forward(self, x):
452
+ return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
453
+
454
+
455
+ class NLayerDiscriminator(nn.Module):
456
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
457
+ # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True):
458
+ super(NLayerDiscriminator, self).__init__()
459
+ self.getIntermFeat = getIntermFeat
460
+ self.n_layers = n_layers
461
+
462
+ kw = 4
463
+ padw = int(np.ceil((kw-1.0)/2))
464
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
465
+ stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
466
+
467
+ nf = ndf
468
+ for n in range(1, n_layers):
469
+ nf_prev = nf
470
+ nf = min(nf * 2, 512)
471
+ sequence += [[
472
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
473
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
474
+ ]]
475
+
476
+ nf_prev = nf
477
+ nf = min(nf * 2, 512)
478
+ sequence += [[
479
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
480
+ norm_layer(nf),
481
+ nn.LeakyReLU(0.2, True)
482
+ ]]
483
+
484
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
485
+ stride=1, padding=padw)]]
486
+
487
+ if use_sigmoid:
488
+ sequence += [[nn.Sigmoid()]]
489
+
490
+ if getIntermFeat:
491
+ for n in range(len(sequence)):
492
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
493
+ else:
494
+ sequence_stream = []
495
+ for n in range(len(sequence)):
496
+ sequence_stream += sequence[n]
497
+ self.model = nn.Sequential(*sequence_stream)
498
+
499
+ def forward(self, input):
500
+ if self.getIntermFeat:
501
+ res = [input]
502
+ for n in range(self.n_layers+2):
503
+ model = getattr(self, 'model'+str(n))
504
+ res.append(model(res[-1]))
505
+ return res[-1], res[1:]
506
+ else:
507
+ return self.model(input), _
508
+
509
+
510
+ class NLayerDiscriminator3D(nn.Module):
511
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
512
+ super(NLayerDiscriminator3D, self).__init__()
513
+ self.getIntermFeat = getIntermFeat
514
+ self.n_layers = n_layers
515
+
516
+ kw = 4
517
+ padw = int(np.ceil((kw-1.0)/2))
518
+ sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
519
+ stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
520
+
521
+ nf = ndf
522
+ for n in range(1, n_layers):
523
+ nf_prev = nf
524
+ nf = min(nf * 2, 512)
525
+ sequence += [[
526
+ nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
527
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
528
+ ]]
529
+
530
+ nf_prev = nf
531
+ nf = min(nf * 2, 512)
532
+ sequence += [[
533
+ nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
534
+ norm_layer(nf),
535
+ nn.LeakyReLU(0.2, True)
536
+ ]]
537
+
538
+ sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
539
+ stride=1, padding=padw)]]
540
+
541
+ if use_sigmoid:
542
+ sequence += [[nn.Sigmoid()]]
543
+
544
+ if getIntermFeat:
545
+ for n in range(len(sequence)):
546
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
547
+ else:
548
+ sequence_stream = []
549
+ for n in range(len(sequence)):
550
+ sequence_stream += sequence[n]
551
+ self.model = nn.Sequential(*sequence_stream)
552
+
553
+ def forward(self, input):
554
+ if self.getIntermFeat:
555
+ res = [input]
556
+ for n in range(self.n_layers+2):
557
+ model = getattr(self, 'model'+str(n))
558
+ res.append(model(res[-1]))
559
+ return res[-1], res[1:]
560
+ else:
561
+ return self.model(input), _
Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import warnings
5
+ import torch
6
+ import imageio
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+ import sys
12
+ import pdb as pdb_original
13
+ import logging
14
+
15
+ import imageio.core.util
16
+ logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
17
+
18
+
19
+ class ForkedPdb(pdb_original.Pdb):
20
+ """A Pdb subclass that may be used
21
+ from a forked multiprocessing child
22
+
23
+ """
24
+
25
+ def interaction(self, *args, **kwargs):
26
+ _stdin = sys.stdin
27
+ try:
28
+ sys.stdin = open('/dev/stdin')
29
+ pdb_original.Pdb.interaction(self, *args, **kwargs)
30
+ finally:
31
+ sys.stdin = _stdin
32
+
33
+
34
+ # Shifts src_tf dim to dest dim
35
+ # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
36
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
37
+ n_dims = len(x.shape)
38
+ if src_dim < 0:
39
+ src_dim = n_dims + src_dim
40
+ if dest_dim < 0:
41
+ dest_dim = n_dims + dest_dim
42
+
43
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
44
+
45
+ dims = list(range(n_dims))
46
+ del dims[src_dim]
47
+
48
+ permutation = []
49
+ ctr = 0
50
+ for i in range(n_dims):
51
+ if i == dest_dim:
52
+ permutation.append(src_dim)
53
+ else:
54
+ permutation.append(dims[ctr])
55
+ ctr += 1
56
+ x = x.permute(permutation)
57
+ if make_contiguous:
58
+ x = x.contiguous()
59
+ return x
60
+
61
+
62
+ # reshapes tensor start from dim i (inclusive)
63
+ # to dim j (exclusive) to the desired shape
64
+ # e.g. if x.shape = (b, thw, c) then
65
+ # view_range(x, 1, 2, (t, h, w)) returns
66
+ # x of shape (b, t, h, w, c)
67
+ def view_range(x, i, j, shape):
68
+ shape = tuple(shape)
69
+
70
+ n_dims = len(x.shape)
71
+ if i < 0:
72
+ i = n_dims + i
73
+
74
+ if j is None:
75
+ j = n_dims
76
+ elif j < 0:
77
+ j = n_dims + j
78
+
79
+ assert 0 <= i < j <= n_dims
80
+
81
+ x_shape = x.shape
82
+ target_shape = x_shape[:i] + shape + x_shape[j:]
83
+ return x.view(target_shape)
84
+
85
+
86
+ def accuracy(output, target, topk=(1,)):
87
+ """Computes the accuracy over the k top predictions for the specified values of k"""
88
+ with torch.no_grad():
89
+ maxk = max(topk)
90
+ batch_size = target.size(0)
91
+
92
+ _, pred = output.topk(maxk, 1, True, True)
93
+ pred = pred.t()
94
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
95
+
96
+ res = []
97
+ for k in topk:
98
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
99
+ res.append(correct_k.mul_(100.0 / batch_size))
100
+ return res
101
+
102
+
103
+ def tensor_slice(x, begin, size):
104
+ assert all([b >= 0 for b in begin])
105
+ size = [l - b if s == -1 else s
106
+ for s, b, l in zip(size, begin, x.shape)]
107
+ assert all([s >= 0 for s in size])
108
+
109
+ slices = [slice(b, b + s) for b, s in zip(begin, size)]
110
+ return x[slices]
111
+
112
+
113
+ def adopt_weight(global_step, threshold=0, value=0.):
114
+ weight = 1
115
+ if global_step < threshold:
116
+ weight = value
117
+ return weight
118
+
119
+
120
+ def save_video_grid(video, fname, nrow=None, fps=6):
121
+ b, c, t, h, w = video.shape
122
+ video = video.permute(0, 2, 3, 4, 1)
123
+ video = (video.cpu().numpy() * 255).astype('uint8')
124
+ if nrow is None:
125
+ nrow = math.ceil(math.sqrt(b))
126
+ ncol = math.ceil(b / nrow)
127
+ padding = 1
128
+ video_grid = np.zeros((t, (padding + h) * nrow + padding,
129
+ (padding + w) * ncol + padding, c), dtype='uint8')
130
+ for i in range(b):
131
+ r = i // ncol
132
+ c = i % ncol
133
+ start_r = (padding + h) * r
134
+ start_c = (padding + w) * c
135
+ video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
136
+ video = []
137
+ for i in range(t):
138
+ video.append(video_grid[i])
139
+ imageio.mimsave(fname, video, fps=fps)
140
+ ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'})
141
+ #print('saved videos to', fname)
142
+
143
+
144
+ def comp_getattr(args, attr_name, default=None):
145
+ if hasattr(args, attr_name):
146
+ return getattr(args, attr_name)
147
+ else:
148
+ return default
149
+
150
+
151
+ def visualize_tensors(t, name=None, nest=0):
152
+ if name is not None:
153
+ print(name, "current nest: ", nest)
154
+ print("type: ", type(t))
155
+ if 'dict' in str(type(t)):
156
+ print(t.keys())
157
+ for k in t.keys():
158
+ if t[k] is None:
159
+ print(k, "None")
160
+ else:
161
+ if 'Tensor' in str(type(t[k])):
162
+ print(k, t[k].shape)
163
+ elif 'dict' in str(type(t[k])):
164
+ print(k, 'dict')
165
+ visualize_tensors(t[k], name, nest + 1)
166
+ elif 'list' in str(type(t[k])):
167
+ print(k, len(t[k]))
168
+ visualize_tensors(t[k], name, nest + 1)
169
+ elif 'list' in str(type(t)):
170
+ print("list length: ", len(t))
171
+ for t2 in t:
172
+ visualize_tensors(t2, name, nest + 1)
173
+ elif 'Tensor' in str(type(t)):
174
+ print(t.shape)
175
+ else:
176
+ print(t)
177
+ return ""
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_early.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d889b3561803f7490f4050c03a02163f099633e4f00fea4cb10b5b993685e5cc
3
+ size 290138333
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26bc7847ae15377a5586535cbb2e6a1ec5b6a98732f7f795c284d7dcda208c97
3
+ size 290156765
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_early.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0135000f031f741252b3e706748b674d33e7278402a7cb2500fec5f4966847bd
3
+ size 290138333
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2a21980e53efd6758ae92e79a82668f0e1e6d9b52fdf6b2a709cb929ebedb3b
3
+ size 290156765
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9438e39a44af92bb0fbaf5cc50a3ac3aaa260978a69ac341ed7ec23512c080a5
3
+ size 290138333
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb37011840156f548fd2348dbb5578f9bc81de16719ec226fbef2de6f0244f9d
3
+ size 290156765
Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/recon_96d4_all.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef88523af9590a7325bc9ca41999de191c3fbc41afc6186a8c4db5528446bb1f
3
+ size 242615727
Generation_Pipeline_filter/syn_liver/TumorGeneration/utils.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Tumor Generateion
2
+ import random
3
+ import cv2
4
+ import elasticdeform
5
+ import numpy as np
6
+ from scipy.ndimage import gaussian_filter
7
+ from TumorGeneration.ldm.ddpm.ddim import DDIMSampler
8
+
9
+ # Step 1: Random select (numbers) location for tumor.
10
+ def random_select(mask_scan, organ_type):
11
+ # we first find z index and then sample point with z slice
12
+ # print('mask_scan',np.unique(mask_scan))
13
+ # print('pixel num', (mask_scan == 1).sum())
14
+ z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
15
+ # print('z_start, z_end',z_start, z_end)
16
+ # we need to strict number z's position (0.3 - 0.7 in the middle of liver)
17
+ flag=0
18
+ while 1:
19
+ z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
20
+ liver_mask = mask_scan[..., z]
21
+ # erode the mask (we don't want the edge points)
22
+ if organ_type == 'liver':
23
+ flag+=1
24
+ if flag <= 10:
25
+ kernel = np.ones((5,5), dtype=np.uint8)
26
+ liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
27
+ elif flag >10 and flag <= 20:
28
+ kernel = np.ones((3,3), dtype=np.uint8)
29
+ liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
30
+ else:
31
+ pass
32
+ print(flag)
33
+ if (liver_mask == 1).sum() > 0:
34
+ break
35
+
36
+ # print('liver_mask', (liver_mask == 1).sum())
37
+ coordinates = np.argwhere(liver_mask == 1)
38
+ random_index = np.random.randint(0, len(coordinates))
39
+ xyz = coordinates[random_index].tolist() # get x,y
40
+ xyz.append(z)
41
+ potential_points = xyz
42
+
43
+ return potential_points
44
+
45
+ def center_select(mask_scan):
46
+ z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
47
+ x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max()
48
+ y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max()
49
+
50
+ z = round(0.5 * (z_end - z_start)) + z_start
51
+ x = round(0.5 * (x_end - x_start)) + x_start
52
+ y = round(0.5 * (y_end - y_start)) + y_start
53
+
54
+ xyz = [x, y, z]
55
+ potential_points = xyz
56
+
57
+ return potential_points
58
+
59
+ # Step 2 : generate the ellipsoid
60
+ def get_ellipsoid(x, y, z):
61
+ """"
62
+ x, y, z is the radius of this ellipsoid in x, y, z direction respectly.
63
+ """
64
+ sh = (4*x, 4*y, 4*z)
65
+ out = np.zeros(sh, int)
66
+ aux = np.zeros(sh)
67
+ radii = np.array([x, y, z])
68
+ com = np.array([2*x, 2*y, 2*z]) # center point
69
+
70
+ # calculate the ellipsoid
71
+ bboxl = np.floor(com-radii).clip(0,None).astype(int)
72
+ bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int)
73
+ roi = out[tuple(map(slice,bboxl,bboxh))]
74
+ roiaux = aux[tuple(map(slice,bboxl,bboxh))]
75
+ logrid = *map(np.square,np.ogrid[tuple(
76
+ map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]),
77
+ dst = (1-sum(logrid)).clip(0,None)
78
+ mask = dst>roiaux
79
+ roi[mask] = 1
80
+ np.copyto(roiaux,dst,where=mask)
81
+
82
+ return out
83
+
84
+ def get_fixed_geo(mask_scan, tumor_type, organ_type):
85
+ if tumor_type == 'large':
86
+ enlarge_x, enlarge_y, enlarge_z = 280, 280, 280
87
+ else:
88
+ enlarge_x, enlarge_y, enlarge_z = 160, 160, 160
89
+ geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8)
90
+ tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32
91
+
92
+ if tumor_type == 'tiny':
93
+ num_tumor = random.randint(1,3)
94
+ # num_tumor = 1
95
+ for _ in range(num_tumor):
96
+ # Tiny tumor
97
+ x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
98
+ y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
99
+ z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
100
+ sigma = random.uniform(0.5,1)
101
+
102
+ geo = get_ellipsoid(x, y, z)
103
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
104
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
105
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
106
+ point = random_select(mask_scan, organ_type)
107
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
108
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
109
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
110
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
111
+
112
+ # paste small tumor geo into test sample
113
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
114
+
115
+ if tumor_type == 'small':
116
+ num_tumor = random.randint(1,3)
117
+ # num_tumor = 1
118
+ for _ in range(num_tumor):
119
+ # Small tumor
120
+ x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
121
+ y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
122
+ z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
123
+ sigma = random.randint(1, 2)
124
+
125
+ geo = get_ellipsoid(x, y, z)
126
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
127
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
128
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
129
+ # texture = get_texture((4*x, 4*y, 4*z))
130
+ point = random_select(mask_scan, organ_type)
131
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
132
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
133
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
134
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
135
+
136
+ # paste small tumor geo into test sample
137
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
138
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
139
+
140
+ if tumor_type == 'medium':
141
+ # num_tumor = random.randint(1, 3)
142
+ num_tumor = 1
143
+ for _ in range(num_tumor):
144
+ # medium tumor
145
+ x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
146
+ y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
147
+ z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
148
+ sigma = random.randint(3, 6)
149
+
150
+ geo = get_ellipsoid(x, y, z)
151
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
152
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
153
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
154
+ # texture = get_texture((4*x, 4*y, 4*z))
155
+ point = random_select(mask_scan, organ_type)
156
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
157
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
158
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
159
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
160
+
161
+ # paste medium tumor geo into test sample
162
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
163
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
164
+
165
+ if tumor_type == 'large':
166
+ num_tumor = 1 # random.randint(1,3)
167
+ for _ in range(num_tumor):
168
+ # Large tumor
169
+
170
+ x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
171
+ y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
172
+ z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
173
+ sigma = random.randint(5, 10)
174
+
175
+ geo = get_ellipsoid(x, y, z)
176
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
177
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
178
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
179
+ if organ_type == 'liver' or organ_type == 'kidney' :
180
+ point = random_select(mask_scan, organ_type)
181
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
182
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
183
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
184
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
185
+ else:
186
+ x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max()
187
+ y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max()
188
+ z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max()
189
+ geo = geo[x_start:x_end, y_start:y_end, z_start:z_end]
190
+
191
+ point = center_select(mask_scan)
192
+
193
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
194
+ x_low = new_point[0] - geo.shape[0]//2
195
+ y_low = new_point[1] - geo.shape[1]//2
196
+ z_low = new_point[2] - geo.shape[2]//2
197
+
198
+ # paste small tumor geo into test sample
199
+ geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo
200
+
201
+ geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
202
+ # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
203
+ if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'):
204
+ if random.random() > 0.5:
205
+ geo_mask = (geo_mask>=1)
206
+ else:
207
+ geo_mask = (geo_mask * mask_scan) >=1
208
+ else:
209
+ geo_mask = (geo_mask * mask_scan) >=1
210
+
211
+ return geo_mask
212
+
213
+
214
+ from .ldm.vq_gan_3d.model.vqgan import VQGAN
215
+ import matplotlib.pyplot as plt
216
+ import SimpleITK as sitk
217
+ from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester
218
+ from hydra import initialize, compose
219
+ import torch
220
+ import yaml
221
+ def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'):
222
+ with initialize(config_path="diffusion_config/"):
223
+ cfg=compose(config_name="ddpm.yaml")
224
+ print('diffusion_ckpt',diffusion_ckpt)
225
+ vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt)
226
+ vqgan = vqgan.to(device)
227
+ vqgan.eval()
228
+
229
+ early_Unet3D = Unet3D(
230
+ dim=cfg.diffusion_img_size,
231
+ dim_mults=cfg.dim_mults,
232
+ channels=cfg.diffusion_num_channels,
233
+ out_dim=cfg.out_dim
234
+ ).to(device)
235
+
236
+ early_diffusion = GaussianDiffusion(
237
+ early_Unet3D,
238
+ vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt,
239
+ image_size=cfg.diffusion_img_size,
240
+ num_frames=cfg.diffusion_depth_size,
241
+ channels=cfg.diffusion_num_channels,
242
+ timesteps=4, # cfg.timesteps,
243
+ loss_type=cfg.loss_type,
244
+ device=device
245
+ ).to(device)
246
+
247
+ noearly_Unet3D = Unet3D(
248
+ dim=cfg.diffusion_img_size,
249
+ dim_mults=cfg.dim_mults,
250
+ channels=cfg.diffusion_num_channels,
251
+ out_dim=cfg.out_dim
252
+ ).to(device)
253
+
254
+ noearly_diffusion = GaussianDiffusion(
255
+ noearly_Unet3D,
256
+ vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt,
257
+ image_size=cfg.diffusion_img_size,
258
+ num_frames=cfg.diffusion_depth_size,
259
+ channels=cfg.diffusion_num_channels,
260
+ timesteps=200, # cfg.timesteps,
261
+ loss_type=cfg.loss_type,
262
+ device=device
263
+ ).to(device)
264
+
265
+ early_tester = Tester(early_diffusion)
266
+ # noearly_tester = Tester(noearly_diffusion)
267
+ early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device)
268
+ # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device)
269
+
270
+ # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device)
271
+ noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device)
272
+ # early_diffusion.load_state_dict(early_checkpoint['ema'])
273
+ noearly_diffusion.load_state_dict(noearly_checkpoint['ema'])
274
+ # early_sampler = DDIMSampler(early_diffusion, schedule="cosine")
275
+ noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine")
276
+ # breakpoint()
277
+ return vqgan, early_tester, noearly_sampler
278
+
279
+ def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester):
280
+ device=ct_volume.device
281
+
282
+ # generate tumor mask
283
+ tumor_types = ['tiny', 'small']
284
+ tumor_probs = np.array([0.5, 0.5])
285
+ total_tumor_mask = []
286
+ organ_mask_np = organ_mask.cpu().numpy()
287
+ with torch.no_grad():
288
+ # get model input
289
+ for bs in range(organ_mask_np.shape[0]):
290
+ synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
291
+ tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
292
+ total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
293
+ total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
294
+
295
+ volume = ct_volume*2.0 - 1.0
296
+ mask = total_tumor_mask*2.0 - 1.0
297
+ mask_ = 1-total_tumor_mask
298
+ masked_volume = (volume*mask_).detach()
299
+
300
+ volume = volume.permute(0,1,-1,-3,-2)
301
+ masked_volume = masked_volume.permute(0,1,-1,-3,-2)
302
+ mask = mask.permute(0,1,-1,-3,-2)
303
+
304
+ # vqgan encoder inference
305
+ masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
306
+ masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
307
+ (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
308
+
309
+ cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
310
+ cond = torch.cat((masked_volume_feat, cc), dim=1)
311
+
312
+ # diffusion inference and decoder
313
+ tester.ema_model.eval()
314
+ sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond)
315
+
316
+ # if organ_type == 'liver' or organ_type == 'kidney' :
317
+
318
+ mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
319
+ sigma = np.random.uniform(0, 4) # (1, 2)
320
+ mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
321
+ # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
322
+
323
+ volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
324
+ sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
325
+
326
+ mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
327
+ final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
328
+ final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
329
+ # elif organ_type == 'pancreas':
330
+ # final_volume_ = (sample+1.0)/2.0
331
+ final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
332
+ organ_tumor_mask = torch.ones_like(organ_mask)
333
+ organ_tumor_mask[total_tumor_mask==1] = 2
334
+
335
+ return final_volume_, organ_tumor_mask
336
+
337
+ def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50):
338
+ device=ct_volume.device
339
+
340
+ # generate tumor mask
341
+ # tumor_types = ['large']
342
+ # tumor_probs = np.array([1.0])
343
+ total_tumor_mask = []
344
+ organ_mask_np = organ_mask.cpu().numpy()
345
+ with torch.no_grad():
346
+ # get model input
347
+ for bs in range(organ_mask_np.shape[0]):
348
+ # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
349
+ synthetic_tumor_type = 'medium'
350
+ tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
351
+ total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
352
+ total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
353
+
354
+ volume = ct_volume*2.0 - 1.0
355
+ mask = total_tumor_mask*2.0 - 1.0
356
+ mask_ = 1-total_tumor_mask
357
+ masked_volume = (volume*mask_).detach()
358
+
359
+ volume = volume.permute(0,1,-1,-3,-2)
360
+ masked_volume = masked_volume.permute(0,1,-1,-3,-2)
361
+ mask = mask.permute(0,1,-1,-3,-2)
362
+
363
+ # vqgan encoder inference
364
+ masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
365
+ masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
366
+ (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
367
+
368
+ cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
369
+ cond = torch.cat((masked_volume_feat, cc), dim=1)
370
+
371
+ # diffusion inference and decoder
372
+ shape = masked_volume_feat.shape[-4:]
373
+ samples_ddim, _ = sampler.sample(S=ddim_ts,
374
+ conditioning=cond,
375
+ batch_size=1,
376
+ shape=shape,
377
+ verbose=False)
378
+ samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() -
379
+ vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min()
380
+
381
+ sample = vqgan.decode(samples_ddim, quantize=True)
382
+
383
+ # if organ_type == 'liver' or organ_type == 'kidney':
384
+ # post-process
385
+ mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
386
+ sigma = np.random.uniform(0, 4) # (1, 2)
387
+ mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
388
+ # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
389
+
390
+ volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
391
+ sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
392
+
393
+ mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
394
+ final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
395
+ final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
396
+ # elif organ_type == 'pancreas':
397
+ # final_volume_ = (sample+1.0)/2.0
398
+
399
+ final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
400
+ organ_tumor_mask = torch.ones_like(organ_mask)
401
+ organ_tumor_mask[total_tumor_mask==1] = 2
402
+
403
+ return final_volume_, organ_tumor_mask
404
+
405
+ def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50):
406
+ device=ct_volume.device
407
+
408
+ # generate tumor mask
409
+ # tumor_types = ['large']
410
+ # tumor_probs = np.array([1.0])
411
+ total_tumor_mask = []
412
+ organ_mask_np = organ_mask.cpu().numpy()
413
+ with torch.no_grad():
414
+ # get model input
415
+ for bs in range(organ_mask_np.shape[0]):
416
+ # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
417
+ synthetic_tumor_type = 'large'
418
+ tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type)
419
+ total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:])
420
+ total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)
421
+
422
+ volume = ct_volume*2.0 - 1.0
423
+ mask = total_tumor_mask*2.0 - 1.0
424
+ mask_ = 1-total_tumor_mask
425
+ masked_volume = (volume*mask_).detach()
426
+
427
+ volume = volume.permute(0,1,-1,-3,-2)
428
+ masked_volume = masked_volume.permute(0,1,-1,-3,-2)
429
+ mask = mask.permute(0,1,-1,-3,-2)
430
+
431
+ # vqgan encoder inference
432
+ masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
433
+ masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) /
434
+ (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0
435
+
436
+ cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
437
+ cond = torch.cat((masked_volume_feat, cc), dim=1)
438
+
439
+ # diffusion inference and decoder
440
+ shape = masked_volume_feat.shape[-4:]
441
+ samples_ddim, _ = sampler.sample(S=ddim_ts,
442
+ conditioning=cond,
443
+ batch_size=1,
444
+ shape=shape,
445
+ verbose=False)
446
+ samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() -
447
+ vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min()
448
+
449
+ sample = vqgan.decode(samples_ddim, quantize=True)
450
+
451
+ # if organ_type == 'liver' or organ_type == 'kidney':
452
+ # post-process
453
+ mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
454
+ sigma = np.random.uniform(0, 4) # (1, 2)
455
+ mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma])
456
+ # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy()
457
+
458
+ volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
459
+ sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)
460
+
461
+ mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
462
+ final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_
463
+ final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)
464
+ # elif organ_type == 'pancreas':
465
+ # final_volume_ = (sample+1.0)/2.0
466
+
467
+ final_volume_ = final_volume_.permute(0,1,-2,-1,-3)
468
+ organ_tumor_mask = torch.ones_like(organ_mask)
469
+ organ_tumor_mask[total_tumor_mask==1] = 2
470
+
471
+ return final_volume_, organ_tumor_mask
Generation_Pipeline_filter/syn_liver/TumorGeneration/utils_.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Tumor Generateion
2
+ import random
3
+ import cv2
4
+ import elasticdeform
5
+ import numpy as np
6
+ from scipy.ndimage import gaussian_filter
7
+
8
+ # Step 1: Random select (numbers) location for tumor.
9
+ def random_select(mask_scan):
10
+ # we first find z index and then sample point with z slice
11
+ z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
12
+
13
+ # we need to strict number z's position (0.3 - 0.7 in the middle of liver)
14
+ z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
15
+
16
+ liver_mask = mask_scan[..., z]
17
+
18
+ # erode the mask (we don't want the edge points)
19
+ kernel = np.ones((5,5), dtype=np.uint8)
20
+ liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
21
+
22
+ coordinates = np.argwhere(liver_mask == 1)
23
+ random_index = np.random.randint(0, len(coordinates))
24
+ xyz = coordinates[random_index].tolist() # get x,y
25
+ xyz.append(z)
26
+ potential_points = xyz
27
+
28
+ return potential_points
29
+
30
+ # Step 2 : generate the ellipsoid
31
+ def get_ellipsoid(x, y, z):
32
+ """"
33
+ x, y, z is the radius of this ellipsoid in x, y, z direction respectly.
34
+ """
35
+ sh = (4*x, 4*y, 4*z)
36
+ out = np.zeros(sh, int)
37
+ aux = np.zeros(sh)
38
+ radii = np.array([x, y, z])
39
+ com = np.array([2*x, 2*y, 2*z]) # center point
40
+
41
+ # calculate the ellipsoid
42
+ bboxl = np.floor(com-radii).clip(0,None).astype(int)
43
+ bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int)
44
+ roi = out[tuple(map(slice,bboxl,bboxh))]
45
+ roiaux = aux[tuple(map(slice,bboxl,bboxh))]
46
+ logrid = *map(np.square,np.ogrid[tuple(
47
+ map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]),
48
+ dst = (1-sum(logrid)).clip(0,None)
49
+ mask = dst>roiaux
50
+ roi[mask] = 1
51
+ np.copyto(roiaux,dst,where=mask)
52
+
53
+ return out
54
+
55
+ def get_fixed_geo(mask_scan, tumor_type):
56
+
57
+ enlarge_x, enlarge_y, enlarge_z = 160, 160, 160
58
+ geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8)
59
+ tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32
60
+
61
+ if tumor_type == 'tiny':
62
+ num_tumor = random.randint(3,10)
63
+ for _ in range(num_tumor):
64
+ # Tiny tumor
65
+ x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
66
+ y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
67
+ z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
68
+ sigma = random.uniform(0.5,1)
69
+
70
+ geo = get_ellipsoid(x, y, z)
71
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
72
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
73
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
74
+ point = random_select(mask_scan)
75
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
76
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
77
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
78
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
79
+
80
+ # paste small tumor geo into test sample
81
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
82
+
83
+ if tumor_type == 'small':
84
+ num_tumor = random.randint(3,10)
85
+ for _ in range(num_tumor):
86
+ # Small tumor
87
+ x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
88
+ y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
89
+ z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
90
+ sigma = random.randint(1, 2)
91
+
92
+ geo = get_ellipsoid(x, y, z)
93
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
94
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
95
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
96
+ # texture = get_texture((4*x, 4*y, 4*z))
97
+ point = random_select(mask_scan)
98
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
99
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
100
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
101
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
102
+
103
+ # paste small tumor geo into test sample
104
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
105
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
106
+
107
+ if tumor_type == 'medium':
108
+ num_tumor = random.randint(2, 5)
109
+ for _ in range(num_tumor):
110
+ # medium tumor
111
+ x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
112
+ y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
113
+ z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
114
+ sigma = random.randint(3, 6)
115
+
116
+ geo = get_ellipsoid(x, y, z)
117
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
118
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
119
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
120
+ # texture = get_texture((4*x, 4*y, 4*z))
121
+ point = random_select(mask_scan)
122
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
123
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
124
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
125
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
126
+
127
+ # paste medium tumor geo into test sample
128
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
129
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
130
+
131
+ if tumor_type == 'large':
132
+ num_tumor = random.randint(1,3)
133
+ for _ in range(num_tumor):
134
+ # Large tumor
135
+ x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
136
+ y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
137
+ z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
138
+ sigma = random.randint(5, 10)
139
+
140
+ geo = get_ellipsoid(x, y, z)
141
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
142
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
143
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
144
+ # texture = get_texture((4*x, 4*y, 4*z))
145
+ point = random_select(mask_scan)
146
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
147
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
148
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
149
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
150
+
151
+ # paste small tumor geo into test sample
152
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
153
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
154
+
155
+ if tumor_type == "mix":
156
+ # tiny
157
+ num_tumor = random.randint(3,10)
158
+ for _ in range(num_tumor):
159
+ # Tiny tumor
160
+ x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
161
+ y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
162
+ z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
163
+ sigma = random.uniform(0.5,1)
164
+
165
+ geo = get_ellipsoid(x, y, z)
166
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
167
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
168
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
169
+ point = random_select(mask_scan)
170
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
171
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
172
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
173
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
174
+
175
+ # paste small tumor geo into test sample
176
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
177
+
178
+ # small
179
+ num_tumor = random.randint(5,10)
180
+ for _ in range(num_tumor):
181
+ # Small tumor
182
+ x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
183
+ y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
184
+ z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
185
+ sigma = random.randint(1, 2)
186
+
187
+ geo = get_ellipsoid(x, y, z)
188
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
189
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
190
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
191
+ # texture = get_texture((4*x, 4*y, 4*z))
192
+ point = random_select(mask_scan)
193
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
194
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
195
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
196
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
197
+
198
+ # paste small tumor geo into test sample
199
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
200
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
201
+
202
+ # medium
203
+ num_tumor = random.randint(2, 5)
204
+ for _ in range(num_tumor):
205
+ # medium tumor
206
+ x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
207
+ y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
208
+ z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
209
+ sigma = random.randint(3, 6)
210
+
211
+ geo = get_ellipsoid(x, y, z)
212
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
213
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
214
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
215
+ # texture = get_texture((4*x, 4*y, 4*z))
216
+ point = random_select(mask_scan)
217
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
218
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
219
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
220
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
221
+
222
+ # paste medium tumor geo into test sample
223
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
224
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
225
+
226
+ # large
227
+ num_tumor = random.randint(1,3)
228
+ for _ in range(num_tumor):
229
+ # Large tumor
230
+ x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
231
+ y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
232
+ z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
233
+ sigma = random.randint(5, 10)
234
+ geo = get_ellipsoid(x, y, z)
235
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
236
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
237
+ geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
238
+ # texture = get_texture((4*x, 4*y, 4*z))
239
+ point = random_select(mask_scan)
240
+ new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
241
+ x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
242
+ y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
243
+ z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
244
+
245
+ # paste small tumor geo into test sample
246
+ geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
247
+ # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
248
+
249
+ geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
250
+ # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
251
+ geo_mask = (geo_mask * mask_scan) >=1
252
+
253
+ return geo_mask
254
+
255
+
256
+ def get_tumor(volume_scan, mask_scan, tumor_type):
257
+ tumor_mask = get_fixed_geo(mask_scan, tumor_type)
258
+
259
+ sigma = np.random.uniform(1, 2)
260
+ # difference = np.random.uniform(65, 145)
261
+ difference = 1
262
+
263
+ # blur the boundary
264
+ tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma)
265
+
266
+
267
+ abnormally_full = volume_scan * (1 - mask_scan) + abnormally
268
+ abnormally_mask = mask_scan + geo_mask
269
+
270
+ return abnormally_full, abnormally_mask
271
+
272
+ def SynthesisTumor(volume_scan, mask_scan, tumor_type):
273
+ # for speed_generate_tumor, we only send the liver part into the generate program
274
+ x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]]
275
+ y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]]
276
+ z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
277
+
278
+ # shrink the boundary
279
+ x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1)
280
+ y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1)
281
+ z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1)
282
+
283
+ ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end]
284
+ organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end]
285
+
286
+ # input texture shape: 420, 300, 320
287
+ # we need to cut it into the shape of liver_mask
288
+ # for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape
289
+ x_length, y_length, z_length = 64, 64, 64
290
+ crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check
291
+ crop_y = random.randint(y_start, y_end - y_length - 1)
292
+ crop_z = random.randint(z_start, z_end - z_length - 1)
293
+
294
+ ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type)
295
+ volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume
296
+ mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask
297
+
298
+ return volume_scan, mask_scan
Generation_Pipeline_filter/syn_liver/healthy_liver_1k.txt ADDED
@@ -0,0 +1,895 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BDMAP_00004578
2
+ BDMAP_00004183
3
+ BDMAP_00002690
4
+ BDMAP_00004295
5
+ BDMAP_00001736
6
+ BDMAP_00000411
7
+ BDMAP_00003277
8
+ BDMAP_00000696
9
+ BDMAP_00004196
10
+ BDMAP_00001598
11
+ BDMAP_00001183
12
+ BDMAP_00002626
13
+ BDMAP_00004793
14
+ BDMAP_00003385
15
+ BDMAP_00005037
16
+ BDMAP_00004652
17
+ BDMAP_00001383
18
+ BDMAP_00001092
19
+ BDMAP_00004927
20
+ BDMAP_00001618
21
+ BDMAP_00004087
22
+ BDMAP_00002273
23
+ BDMAP_00001288
24
+ BDMAP_00000043
25
+ BDMAP_00003356
26
+ BDMAP_00002776
27
+ BDMAP_00003961
28
+ BDMAP_00002422
29
+ BDMAP_00000345
30
+ BDMAP_00000438
31
+ BDMAP_00001517
32
+ BDMAP_00003564
33
+ BDMAP_00001275
34
+ BDMAP_00003315
35
+ BDMAP_00002986
36
+ BDMAP_00003514
37
+ BDMAP_00000190
38
+ BDMAP_00001434
39
+ BDMAP_00003608
40
+ BDMAP_00001995
41
+ BDMAP_00000414
42
+ BDMAP_00003451
43
+ BDMAP_00002612
44
+ BDMAP_00003744
45
+ BDMAP_00005170
46
+ BDMAP_00002328
47
+ BDMAP_00002940
48
+ BDMAP_00005020
49
+ BDMAP_00000562
50
+ BDMAP_00000810
51
+ BDMAP_00003833
52
+ BDMAP_00000320
53
+ BDMAP_00001791
54
+ BDMAP_00004895
55
+ BDMAP_00003576
56
+ BDMAP_00001924
57
+ BDMAP_00005140
58
+ BDMAP_00003946
59
+ BDMAP_00005067
60
+ BDMAP_00001102
61
+ BDMAP_00001826
62
+ BDMAP_00004131
63
+ BDMAP_00003141
64
+ BDMAP_00002758
65
+ BDMAP_00004969
66
+ BDMAP_00003633
67
+ BDMAP_00004195
68
+ BDMAP_00000030
69
+ BDMAP_00000939
70
+ BDMAP_00001835
71
+ BDMAP_00003762
72
+ BDMAP_00003215
73
+ BDMAP_00003396
74
+ BDMAP_00001078
75
+ BDMAP_00003484
76
+ BDMAP_00001096
77
+ BDMAP_00001688
78
+ BDMAP_00005155
79
+ BDMAP_00005064
80
+ BDMAP_00001862
81
+ BDMAP_00004867
82
+ BDMAP_00001982
83
+ BDMAP_00002295
84
+ BDMAP_00000062
85
+ BDMAP_00000715
86
+ BDMAP_00004608
87
+ BDMAP_00000162
88
+ BDMAP_00003558
89
+ BDMAP_00005070
90
+ BDMAP_00003812
91
+ BDMAP_00000725
92
+ BDMAP_00004624
93
+ BDMAP_00003752
94
+ BDMAP_00001557
95
+ BDMAP_00002185
96
+ BDMAP_00000093
97
+ BDMAP_00003774
98
+ BDMAP_00001701
99
+ BDMAP_00004184
100
+ BDMAP_00000873
101
+ BDMAP_00000236
102
+ BDMAP_00001676
103
+ BDMAP_00001635
104
+ BDMAP_00002475
105
+ BDMAP_00002653
106
+ BDMAP_00003400
107
+ BDMAP_00001863
108
+ BDMAP_00003017
109
+ BDMAP_00001283
110
+ BDMAP_00001359
111
+ BDMAP_00001281
112
+ BDMAP_00004293
113
+ BDMAP_00000582
114
+ BDMAP_00001752
115
+ BDMAP_00004910
116
+ BDMAP_00003373
117
+ BDMAP_00004297
118
+ BDMAP_00003947
119
+ BDMAP_00003612
120
+ BDMAP_00003598
121
+ BDMAP_00002746
122
+ BDMAP_00004552
123
+ BDMAP_00002333
124
+ BDMAP_00002580
125
+ BDMAP_00002871
126
+ BDMAP_00001565
127
+ BDMAP_00003549
128
+ BDMAP_00003976
129
+ BDMAP_00001712
130
+ BDMAP_00001602
131
+ BDMAP_00000812
132
+ BDMAP_00000353
133
+ BDMAP_00001251
134
+ BDMAP_00004841
135
+ BDMAP_00000429
136
+ BDMAP_00000432
137
+ BDMAP_00000159
138
+ BDMAP_00002347
139
+ BDMAP_00002496
140
+ BDMAP_00004735
141
+ BDMAP_00001514
142
+ BDMAP_00003560
143
+ BDMAP_00001209
144
+ BDMAP_00002313
145
+ BDMAP_00005092
146
+ BDMAP_00005009
147
+ BDMAP_00004673
148
+ BDMAP_00000547
149
+ BDMAP_00003255
150
+ BDMAP_00000229
151
+ BDMAP_00001522
152
+ BDMAP_00002426
153
+ BDMAP_00004015
154
+ BDMAP_00004541
155
+ BDMAP_00003952
156
+ BDMAP_00003853
157
+ BDMAP_00001119
158
+ BDMAP_00004198
159
+ BDMAP_00004427
160
+ BDMAP_00004417
161
+ BDMAP_00000833
162
+ BDMAP_00002487
163
+ BDMAP_00002981
164
+ BDMAP_00000653
165
+ BDMAP_00003815
166
+ BDMAP_00003972
167
+ BDMAP_00000373
168
+ BDMAP_00002864
169
+ BDMAP_00002902
170
+ BDMAP_00001836
171
+ BDMAP_00004897
172
+ BDMAP_00002889
173
+ BDMAP_00003493
174
+ BDMAP_00000667
175
+ BDMAP_00004163
176
+ BDMAP_00004586
177
+ BDMAP_00001704
178
+ BDMAP_00002152
179
+ BDMAP_00001258
180
+ BDMAP_00003827
181
+ BDMAP_00001265
182
+ BDMAP_00001040
183
+ BDMAP_00004106
184
+ BDMAP_00000059
185
+ BDMAP_00002363
186
+ BDMAP_00000161
187
+ BDMAP_00001475
188
+ BDMAP_00001747
189
+ BDMAP_00001027
190
+ BDMAP_00000279
191
+ BDMAP_00002242
192
+ BDMAP_00004175
193
+ BDMAP_00003358
194
+ BDMAP_00004815
195
+ BDMAP_00003580
196
+ BDMAP_00001068
197
+ BDMAP_00003327
198
+ BDMAP_00004616
199
+ BDMAP_00000197
200
+ BDMAP_00003740
201
+ BDMAP_00005074
202
+ BDMAP_00001261
203
+ BDMAP_00002775
204
+ BDMAP_00002545
205
+ BDMAP_00000104
206
+ BDMAP_00004738
207
+ BDMAP_00005099
208
+ BDMAP_00004672
209
+ BDMAP_00004074
210
+ BDMAP_00004288
211
+ BDMAP_00003590
212
+ BDMAP_00001545
213
+ BDMAP_00004922
214
+ BDMAP_00002619
215
+ BDMAP_00000874
216
+ BDMAP_00001438
217
+ BDMAP_00003138
218
+ BDMAP_00002251
219
+ BDMAP_00003769
220
+ BDMAP_00003267
221
+ BDMAP_00002216
222
+ BDMAP_00003994
223
+ BDMAP_00002742
224
+ BDMAP_00001089
225
+ BDMAP_00003957
226
+ BDMAP_00001533
227
+ BDMAP_00004636
228
+ BDMAP_00004499
229
+ BDMAP_00000698
230
+ BDMAP_00002232
231
+ BDMAP_00004250
232
+ BDMAP_00004491
233
+ BDMAP_00001636
234
+ BDMAP_00005078
235
+ BDMAP_00004121
236
+ BDMAP_00001845
237
+ BDMAP_00004264
238
+ BDMAP_00000137
239
+ BDMAP_00003516
240
+ BDMAP_00005017
241
+ BDMAP_00000087
242
+ BDMAP_00000319
243
+ BDMAP_00001828
244
+ BDMAP_00000948
245
+ BDMAP_00001977
246
+ BDMAP_00003457
247
+ BDMAP_00005157
248
+ BDMAP_00003150
249
+ BDMAP_00002166
250
+ BDMAP_00003301
251
+ BDMAP_00003680
252
+ BDMAP_00003133
253
+ BDMAP_00000574
254
+ BDMAP_00002305
255
+ BDMAP_00004843
256
+ BDMAP_00002230
257
+ BDMAP_00000332
258
+ BDMAP_00003063
259
+ BDMAP_00002076
260
+ BDMAP_00003319
261
+ BDMAP_00004373
262
+ BDMAP_00004880
263
+ BDMAP_00000623
264
+ BDMAP_00003631
265
+ BDMAP_00001737
266
+ BDMAP_00001057
267
+ BDMAP_00002173
268
+ BDMAP_00000139
269
+ BDMAP_00001891
270
+ BDMAP_00000552
271
+ BDMAP_00004717
272
+ BDMAP_00003172
273
+ BDMAP_00003955
274
+ BDMAP_00001664
275
+ BDMAP_00003070
276
+ BDMAP_00004550
277
+ BDMAP_00002057
278
+ BDMAP_00000616
279
+ BDMAP_00000913
280
+ BDMAP_00000388
281
+ BDMAP_00000355
282
+ BDMAP_00003333
283
+ BDMAP_00004148
284
+ BDMAP_00001985
285
+ BDMAP_00001921
286
+ BDMAP_00001624
287
+ BDMAP_00004129
288
+ BDMAP_00002598
289
+ BDMAP_00000859
290
+ BDMAP_00000558
291
+ BDMAP_00002226
292
+ BDMAP_00000452
293
+ BDMAP_00004829
294
+ BDMAP_00003455
295
+ BDMAP_00002402
296
+ BDMAP_00000117
297
+ BDMAP_00000826
298
+ BDMAP_00000243
299
+ BDMAP_00002319
300
+ BDMAP_00002737
301
+ BDMAP_00002318
302
+ BDMAP_00003357
303
+ BDMAP_00000692
304
+ BDMAP_00003427
305
+ BDMAP_00001441
306
+ BDMAP_00004796
307
+ BDMAP_00002171
308
+ BDMAP_00001296
309
+ BDMAP_00004296
310
+ BDMAP_00003808
311
+ BDMAP_00003058
312
+ BDMAP_00003502
313
+ BDMAP_00001045
314
+ BDMAP_00003438
315
+ BDMAP_00002884
316
+ BDMAP_00004561
317
+ BDMAP_00000462
318
+ BDMAP_00001785
319
+ BDMAP_00000794
320
+ BDMAP_00000942
321
+ BDMAP_00002947
322
+ BDMAP_00004744
323
+ BDMAP_00004328
324
+ BDMAP_00004671
325
+ BDMAP_00005108
326
+ BDMAP_00002278
327
+ BDMAP_00000679
328
+ BDMAP_00004903
329
+ BDMAP_00001732
330
+ BDMAP_00001095
331
+ BDMAP_00003343
332
+ BDMAP_00001289
333
+ BDMAP_00001109
334
+ BDMAP_00003650
335
+ BDMAP_00001710
336
+ BDMAP_00003031
337
+ BDMAP_00001617
338
+ BDMAP_00001246
339
+ BDMAP_00004894
340
+ BDMAP_00003520
341
+ BDMAP_00004097
342
+ BDMAP_00001020
343
+ BDMAP_00003600
344
+ BDMAP_00001518
345
+ BDMAP_00000416
346
+ BDMAP_00004990
347
+ BDMAP_00005151
348
+ BDMAP_00000132
349
+ BDMAP_00000138
350
+ BDMAP_00004885
351
+ BDMAP_00000771
352
+ BDMAP_00003928
353
+ BDMAP_00001419
354
+ BDMAP_00003130
355
+ BDMAP_00001892
356
+ BDMAP_00003886
357
+ BDMAP_00004479
358
+ BDMAP_00003918
359
+ BDMAP_00003324
360
+ BDMAP_00002410
361
+ BDMAP_00002509
362
+ BDMAP_00000701
363
+ BDMAP_00003847
364
+ BDMAP_00004450
365
+ BDMAP_00003363
366
+ BDMAP_00002875
367
+ BDMAP_00002793
368
+ BDMAP_00005113
369
+ BDMAP_00000465
370
+ BDMAP_00004847
371
+ BDMAP_00004294
372
+ BDMAP_00000936
373
+ BDMAP_00002476
374
+ BDMAP_00003840
375
+ BDMAP_00004130
376
+ BDMAP_00003614
377
+ BDMAP_00000883
378
+ BDMAP_00000542
379
+ BDMAP_00002562
380
+ BDMAP_00000285
381
+ BDMAP_00001256
382
+ BDMAP_00004597
383
+ BDMAP_00002260
384
+ BDMAP_00001067
385
+ BDMAP_00000968
386
+ BDMAP_00005085
387
+ BDMAP_00003412
388
+ BDMAP_00003884
389
+ BDMAP_00001420
390
+ BDMAP_00003268
391
+ BDMAP_00001735
392
+ BDMAP_00003392
393
+ BDMAP_00000241
394
+ BDMAP_00003326
395
+ BDMAP_00001853
396
+ BDMAP_00001126
397
+ BDMAP_00002237
398
+ BDMAP_00003809
399
+ BDMAP_00001584
400
+ BDMAP_00003359
401
+ BDMAP_00002730
402
+ BDMAP_00000923
403
+ BDMAP_00000687
404
+ BDMAP_00003281
405
+ BDMAP_00004431
406
+ BDMAP_00001440
407
+ BDMAP_00001410
408
+ BDMAP_00004650
409
+ BDMAP_00004065
410
+ BDMAP_00001806
411
+ BDMAP_00002227
412
+ BDMAP_00001906
413
+ BDMAP_00000331
414
+ BDMAP_00001130
415
+ BDMAP_00003178
416
+ BDMAP_00002707
417
+ BDMAP_00001646
418
+ BDMAP_00001707
419
+ BDMAP_00003592
420
+ BDMAP_00003943
421
+ BDMAP_00002361
422
+ BDMAP_00004901
423
+ BDMAP_00003329
424
+ BDMAP_00005075
425
+ BDMAP_00002326
426
+ BDMAP_00003713
427
+ BDMAP_00003832
428
+ BDMAP_00004165
429
+ BDMAP_00004415
430
+ BDMAP_00004331
431
+ BDMAP_00001035
432
+ BDMAP_00004457
433
+ BDMAP_00003347
434
+ BDMAP_00001422
435
+ BDMAP_00002437
436
+ BDMAP_00003996
437
+ BDMAP_00003461
438
+ BDMAP_00002751
439
+ BDMAP_00002523
440
+ BDMAP_00000439
441
+ BDMAP_00004746
442
+ BDMAP_00002188
443
+ BDMAP_00004253
444
+ BDMAP_00000935
445
+ BDMAP_00002451
446
+ BDMAP_00003971
447
+ BDMAP_00000926
448
+ BDMAP_00003109
449
+ BDMAP_00000660
450
+ BDMAP_00001169
451
+ BDMAP_00001331
452
+ BDMAP_00001175
453
+ BDMAP_00000881
454
+ BDMAP_00000263
455
+ BDMAP_00002401
456
+ BDMAP_00005167
457
+ BDMAP_00002041
458
+ BDMAP_00000656
459
+ BDMAP_00000366
460
+ BDMAP_00002582
461
+ BDMAP_00001238
462
+ BDMAP_00001590
463
+ BDMAP_00001784
464
+ BDMAP_00001564
465
+ BDMAP_00004719
466
+ BDMAP_00001917
467
+ BDMAP_00003956
468
+ BDMAP_00003225
469
+ BDMAP_00000982
470
+ BDMAP_00004992
471
+ BDMAP_00003479
472
+ BDMAP_00001215
473
+ BDMAP_00004147
474
+ BDMAP_00001711
475
+ BDMAP_00000626
476
+ BDMAP_00000516
477
+ BDMAP_00004876
478
+ BDMAP_00003376
479
+ BDMAP_00001628
480
+ BDMAP_00001148
481
+ BDMAP_00003672
482
+ BDMAP_00001205
483
+ BDMAP_00004651
484
+ BDMAP_00000987
485
+ BDMAP_00004104
486
+ BDMAP_00001647
487
+ BDMAP_00000998
488
+ BDMAP_00002244
489
+ BDMAP_00004676
490
+ BDMAP_00001908
491
+ BDMAP_00000714
492
+ BDMAP_00001104
493
+ BDMAP_00001911
494
+ BDMAP_00000882
495
+ BDMAP_00003930
496
+ BDMAP_00000368
497
+ BDMAP_00003923
498
+ BDMAP_00002099
499
+ BDMAP_00000240
500
+ BDMAP_00003658
501
+ BDMAP_00005077
502
+ BDMAP_00002696
503
+ BDMAP_00002184
504
+ BDMAP_00003890
505
+ BDMAP_00002704
506
+ BDMAP_00000066
507
+ BDMAP_00005006
508
+ BDMAP_00001242
509
+ BDMAP_00002396
510
+ BDMAP_00004389
511
+ BDMAP_00002656
512
+ BDMAP_00000469
513
+ BDMAP_00001138
514
+ BDMAP_00004773
515
+ BDMAP_00004033
516
+ BDMAP_00004128
517
+ BDMAP_00002631
518
+ BDMAP_00004925
519
+ BDMAP_00004475
520
+ BDMAP_00001521
521
+ BDMAP_00000364
522
+ BDMAP_00002953
523
+ BDMAP_00003776
524
+ BDMAP_00004154
525
+ BDMAP_00002654
526
+ BDMAP_00002959
527
+ BDMAP_00002199
528
+ BDMAP_00003551
529
+ BDMAP_00002465
530
+ BDMAP_00005154
531
+ BDMAP_00002648
532
+ BDMAP_00000128
533
+ BDMAP_00001001
534
+ BDMAP_00002017
535
+ BDMAP_00004712
536
+ BDMAP_00004286
537
+ BDMAP_00000568
538
+ BDMAP_00004858
539
+ BDMAP_00001782
540
+ BDMAP_00001496
541
+ BDMAP_00004407
542
+ BDMAP_00002250
543
+ BDMAP_00001212
544
+ BDMAP_00000972
545
+ BDMAP_00004374
546
+ BDMAP_00002846
547
+ BDMAP_00002472
548
+ BDMAP_00000569
549
+ BDMAP_00004981
550
+ BDMAP_00000176
551
+ BDMAP_00003510
552
+ BDMAP_00003771
553
+ BDMAP_00002804
554
+ BDMAP_00004558
555
+ BDMAP_00003411
556
+ BDMAP_00001563
557
+ BDMAP_00000604
558
+ BDMAP_00002075
559
+ BDMAP_00005160
560
+ BDMAP_00001511
561
+ BDMAP_00001273
562
+ BDMAP_00002603
563
+ BDMAP_00001656
564
+ BDMAP_00003822
565
+ BDMAP_00004510
566
+ BDMAP_00001809
567
+ BDMAP_00002944
568
+ BDMAP_00002739
569
+ BDMAP_00002609
570
+ BDMAP_00003849
571
+ BDMAP_00001128
572
+ BDMAP_00003717
573
+ BDMAP_00000036
574
+ BDMAP_00002863
575
+ BDMAP_00004956
576
+ BDMAP_00004229
577
+ BDMAP_00003425
578
+ BDMAP_00001865
579
+ BDMAP_00000608
580
+ BDMAP_00004620
581
+ BDMAP_00000589
582
+ BDMAP_00001597
583
+ BDMAP_00003543
584
+ BDMAP_00004645
585
+ BDMAP_00004395
586
+ BDMAP_00005105
587
+ BDMAP_00001426
588
+ BDMAP_00000264
589
+ BDMAP_00001504
590
+ BDMAP_00001649
591
+ BDMAP_00000662
592
+ BDMAP_00002854
593
+ BDMAP_00004060
594
+ BDMAP_00003440
595
+ BDMAP_00003367
596
+ BDMAP_00004011
597
+ BDMAP_00003634
598
+ BDMAP_00003443
599
+ BDMAP_00000828
600
+ BDMAP_00000889
601
+ BDMAP_00000321
602
+ BDMAP_00004615
603
+ BDMAP_00000244
604
+ BDMAP_00003685
605
+ BDMAP_00001461
606
+ BDMAP_00001396
607
+ BDMAP_00004262
608
+ BDMAP_00004579
609
+ BDMAP_00005022
610
+ BDMAP_00004804
611
+ BDMAP_00001632
612
+ BDMAP_00002661
613
+ BDMAP_00000980
614
+ BDMAP_00001445
615
+ BDMAP_00000809
616
+ BDMAP_00004384
617
+ BDMAP_00003114
618
+ BDMAP_00000435
619
+ BDMAP_00003406
620
+ BDMAP_00002899
621
+ BDMAP_00002164
622
+ BDMAP_00002498
623
+ BDMAP_00000039
624
+ BDMAP_00002524
625
+ BDMAP_00000805
626
+ BDMAP_00004604
627
+ BDMAP_00000338
628
+ BDMAP_00002990
629
+ BDMAP_00001516
630
+ BDMAP_00002896
631
+ BDMAP_00004549
632
+ BDMAP_00000259
633
+ BDMAP_00001945
634
+ BDMAP_00002695
635
+ BDMAP_00005141
636
+ BDMAP_00002828
637
+ BDMAP_00003781
638
+ BDMAP_00003900
639
+ BDMAP_00004278
640
+ BDMAP_00004551
641
+ BDMAP_00000532
642
+ BDMAP_00002844
643
+ BDMAP_00001476
644
+ BDMAP_00004887
645
+ BDMAP_00005174
646
+ BDMAP_00000836
647
+ BDMAP_00001456
648
+ BDMAP_00001607
649
+ BDMAP_00003164
650
+ BDMAP_00002404
651
+ BDMAP_00003036
652
+ BDMAP_00001225
653
+ BDMAP_00002022
654
+ BDMAP_00004030
655
+ BDMAP_00000329
656
+ BDMAP_00002253
657
+ BDMAP_00000154
658
+ BDMAP_00003111
659
+ BDMAP_00003384
660
+ BDMAP_00000023
661
+ BDMAP_00001125
662
+ BDMAP_00001414
663
+ BDMAP_00002383
664
+ BDMAP_00003483
665
+ BDMAP_00000034
666
+ BDMAP_00001413
667
+ BDMAP_00003767
668
+ BDMAP_00001368
669
+ BDMAP_00003448
670
+ BDMAP_00000940
671
+ BDMAP_00000430
672
+ BDMAP_00003153
673
+ BDMAP_00003603
674
+ BDMAP_00003202
675
+ BDMAP_00002421
676
+ BDMAP_00005001
677
+ BDMAP_00004447
678
+ BDMAP_00001325
679
+ BDMAP_00003168
680
+ BDMAP_00000887
681
+ BDMAP_00004481
682
+ BDMAP_00001324
683
+ BDMAP_00004066
684
+ BDMAP_00001474
685
+ BDMAP_00004850
686
+ BDMAP_00002233
687
+ BDMAP_00000511
688
+ BDMAP_00001223
689
+ BDMAP_00003581
690
+ BDMAP_00002930
691
+ BDMAP_00001305
692
+ BDMAP_00002689
693
+ BDMAP_00002332
694
+ BDMAP_00000683
695
+ BDMAP_00003300
696
+ BDMAP_00003701
697
+ BDMAP_00001015
698
+ BDMAP_00001562
699
+ BDMAP_00001898
700
+ BDMAP_00001247
701
+ BDMAP_00001941
702
+ BDMAP_00002840
703
+ BDMAP_00002440
704
+ BDMAP_00000245
705
+ BDMAP_00002855
706
+ BDMAP_00004493
707
+ BDMAP_00000989
708
+ BDMAP_00003736
709
+ BDMAP_00002265
710
+ BDMAP_00004039
711
+ BDMAP_00002826
712
+ BDMAP_00002924
713
+ BDMAP_00003299
714
+ BDMAP_00001361
715
+ BDMAP_00004014
716
+ BDMAP_00001444
717
+ BDMAP_00001370
718
+ BDMAP_00002304
719
+ BDMAP_00000774
720
+ BDMAP_00000614
721
+ BDMAP_00000434
722
+ BDMAP_00001230
723
+ BDMAP_00000044
724
+ BDMAP_00001768
725
+ BDMAP_00004783
726
+ BDMAP_00004494
727
+ BDMAP_00001905
728
+ BDMAP_00003824
729
+ BDMAP_00002309
730
+ BDMAP_00004511
731
+ BDMAP_00000233
732
+ BDMAP_00002845
733
+ BDMAP_00005016
734
+ BDMAP_00002829
735
+ BDMAP_00001059
736
+ BDMAP_00001549
737
+ BDMAP_00002403
738
+ BDMAP_00001794
739
+ BDMAP_00001286
740
+ BDMAP_00003294
741
+ BDMAP_00003722
742
+ BDMAP_00000902
743
+ BDMAP_00002298
744
+ BDMAP_00005191
745
+ BDMAP_00001487
746
+ BDMAP_00003364
747
+ BDMAP_00001605
748
+ BDMAP_00001483
749
+ BDMAP_00000676
750
+ BDMAP_00002945
751
+ BDMAP_00005073
752
+ BDMAP_00002085
753
+ BDMAP_00000716
754
+ BDMAP_00003435
755
+ BDMAP_00002803
756
+ BDMAP_00002663
757
+ BDMAP_00003727
758
+ BDMAP_00000839
759
+ BDMAP_00002068
760
+ BDMAP_00004764
761
+ BDMAP_00002114
762
+ BDMAP_00004741
763
+ BDMAP_00004077
764
+ BDMAP_00004870
765
+ BDMAP_00000571
766
+ BDMAP_00004115
767
+ BDMAP_00001868
768
+ BDMAP_00004113
769
+ BDMAP_00002039
770
+ BDMAP_00004257
771
+ BDMAP_00001620
772
+ BDMAP_00000470
773
+ BDMAP_00000149
774
+ BDMAP_00002815
775
+ BDMAP_00000304
776
+ BDMAP_00005185
777
+ BDMAP_00003113
778
+ BDMAP_00005063
779
+ BDMAP_00000122
780
+ BDMAP_00004482
781
+ BDMAP_00002471
782
+ BDMAP_00004023
783
+ BDMAP_00000225
784
+ BDMAP_00003657
785
+ BDMAP_00001255
786
+ BDMAP_00002616
787
+ BDMAP_00002407
788
+ BDMAP_00002060
789
+ BDMAP_00004546
790
+ BDMAP_00004917
791
+ BDMAP_00003615
792
+ BDMAP_00003525
793
+ BDMAP_00002120
794
+ BDMAP_00000481
795
+ BDMAP_00004770
796
+ BDMAP_00003683
797
+ BDMAP_00000618
798
+ BDMAP_00001875
799
+ BDMAP_00003409
800
+ BDMAP_00003381
801
+ BDMAP_00004398
802
+ BDMAP_00000867
803
+ BDMAP_00000487
804
+ BDMAP_00003073
805
+ BDMAP_00002592
806
+ BDMAP_00005120
807
+ BDMAP_00003128
808
+ BDMAP_00001754
809
+ BDMAP_00004232
810
+ BDMAP_00000855
811
+ BDMAP_00000069
812
+ BDMAP_00002744
813
+ BDMAP_00004808
814
+ BDMAP_00004031
815
+ BDMAP_00001842
816
+ BDMAP_00000324
817
+ BDMAP_00002933
818
+ BDMAP_00004954
819
+ BDMAP_00000541
820
+ BDMAP_00002458
821
+ BDMAP_00002288
822
+ BDMAP_00002807
823
+ BDMAP_00000837
824
+ BDMAP_00002065
825
+ BDMAP_00000152
826
+ BDMAP_00003491
827
+ BDMAP_00001464
828
+ BDMAP_00003486
829
+ BDMAP_00003244
830
+ BDMAP_00000871
831
+ BDMAP_00002362
832
+ BDMAP_00000993
833
+ BDMAP_00000219
834
+ BDMAP_00000192
835
+ BDMAP_00001218
836
+ BDMAP_00001024
837
+ BDMAP_00004980
838
+ BDMAP_00000713
839
+ BDMAP_00001523
840
+ BDMAP_00002688
841
+ BDMAP_00003143
842
+ BDMAP_00005114
843
+ BDMAP_00003749
844
+ BDMAP_00002354
845
+ BDMAP_00000052
846
+ BDMAP_00002710
847
+ BDMAP_00004817
848
+ BDMAP_00004964
849
+ BDMAP_00004775
850
+ BDMAP_00005005
851
+ BDMAP_00004216
852
+ BDMAP_00002936
853
+ BDMAP_00000956
854
+ BDMAP_00002942
855
+ BDMAP_00001705
856
+ BDMAP_00001823
857
+ BDMAP_00002387
858
+ BDMAP_00000690
859
+ BDMAP_00002021
860
+ BDMAP_00000851
861
+ BDMAP_00000427
862
+ BDMAP_00002133
863
+ BDMAP_00004231
864
+ BDMAP_00005169
865
+ BDMAP_00003640
866
+ BDMAP_00000977
867
+ BDMAP_00002103
868
+ BDMAP_00000449
869
+ BDMAP_00001214
870
+ BDMAP_00003506
871
+ BDMAP_00002411
872
+ BDMAP_00003973
873
+ BDMAP_00001912
874
+ BDMAP_00000710
875
+ BDMAP_00004514
876
+ BDMAP_00001807
877
+ BDMAP_00001769
878
+ BDMAP_00001746
879
+ BDMAP_00001804
880
+ BDMAP_00002484
881
+ BDMAP_00003444
882
+ BDMAP_00002029
883
+ BDMAP_00001237
884
+ BDMAP_00004420
885
+ BDMAP_00000431
886
+ BDMAP_00003252
887
+ BDMAP_00005081
888
+ BDMAP_00003694
889
+ BDMAP_00002655
890
+ BDMAP_00004641
891
+ BDMAP_00000297
892
+ BDMAP_00001077
893
+ BDMAP_00003254
894
+ BDMAP_00000447
895
+ BDMAP_00004834
Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccce1a9d60005f0ebedea7a9bfc7f4ca0228d2ba6285d52d6ef40eda6714f6f3
3
+ size 20965188
Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ddbe505108fb18bca8b6edf1150249647080ad0ef05bfa750409590ef2880f9
3
+ size 64437
Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6e057c93fe99ecf8b9d2c81fc2f4535df63b19fefd623b8f7ad96b8d1463e1a
3
+ size 27445316
Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12e2b1662f93f991b04ac06b4c50573b8321a49671a65366d1de28769ea37805
3
+ size 76815
Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2eb22ba25ca011daff851e2a3c2d96e0378df78586caadde56217a049fc37f5
3
+ size 16868219
Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca7474ac04c3fb9ff09bf080f4b0c3c9d054d9f6b3f56a8b9a09c4283fa2ac31
3
+ size 46592
Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fad3428ec7670ad5d3d277275732c148cf9b7510a261e47dd3f967f17cae3511
3
+ size 20584373
Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00e316a875b7b1ed74b42023fbf58dab903030e65be627f0857038598a902d83
3
+ size 53855
Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30e2630e9acb2526793d24b67ee57a80beb430d09980444cc3ca4cd61504214e
3
+ size 26242723
Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b993b9c74d95feebd2a4c7a1617cd2c4d21ca926096e5d1df39e760a802aca70
3
+ size 76864
Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/ct.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13a7ee052fc2fbaf0139f3a2fcd280723330e986407038348e23658097d0cc44
3
+ size 25002393
Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/segmentations/liver_tumor.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d88c86229effcde00b19dc23ae357edb5656f0a702cabca6613450201652702
3
+ size 66834