qicq1c committed on
Commit
245fd45
·
verified ·
1 Parent(s): 70083c8

4f53598323d7a6dc77e99aa21ab33e62828bdc1dca94a3f3ab173d6e8d4ecf73

Browse files
Files changed (50) hide show
  1. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py +1016 -0
  2. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/text.py +94 -0
  3. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py +75 -0
  4. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/unet.py +226 -0
  5. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/util.py +271 -0
  6. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc +0 -0
  7. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py +3 -0
  8. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc +0 -0
  9. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc +0 -0
  10. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc +0 -0
  11. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc +0 -0
  12. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth +3 -0
  13. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py +109 -0
  14. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py +181 -0
  15. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py +561 -0
  16. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py +177 -0
  17. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/model_weight/diffusion_colon.pt +3 -0
  18. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/model_weight/recon_colon.ckpt +3 -0
  19. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/utils.py +233 -0
  20. Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/utils_.py +298 -0
  21. Generation_Pipeline_filter_all2/syn_colon/healthy_colon_1k.txt +928 -0
  22. Generation_Pipeline_filter_all2/syn_colon/requirements.txt +94 -0
  23. Generation_Pipeline_filter_all2/syn_kidney/CT_syn_kidney_data_new.py +238 -0
  24. Generation_Pipeline_filter_all2/syn_kidney/CT_syn_kidney_data_new2.py +247 -0
  25. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/.DS_Store +0 -0
  26. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/README.md +5 -0
  27. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/TumorGenerated.py +39 -0
  28. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__init__.py +5 -0
  29. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc +0 -0
  30. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc +0 -0
  31. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc +0 -0
  32. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc +0 -0
  33. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml +29 -0
  34. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml +37 -0
  35. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py +1 -0
  36. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc +0 -0
  37. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc +0 -0
  38. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc +0 -0
  39. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc +0 -0
  40. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc +0 -0
  41. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc +0 -0
  42. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc +0 -0
  43. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py +206 -0
  44. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py +1016 -0
  45. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/text.py +94 -0
  46. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py +75 -0
  47. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/unet.py +226 -0
  48. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/util.py +271 -0
  49. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc +0 -0
  50. Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py +3 -0
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py ADDED
@@ -0,0 +1,1016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import math
4
+ import copy
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+ from functools import partial
9
+
10
+ from torch.utils import data
11
+ from pathlib import Path
12
+ from torch.optim import Adam
13
+ from torchvision import transforms as T, utils
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from PIL import Image
16
+
17
+ from tqdm import tqdm
18
+ from einops import rearrange
19
+ from einops_exts import check_shape, rearrange_many
20
+
21
+ from rotary_embedding_torch import RotaryEmbedding
22
+
23
+ from .text import tokenize, bert_embed, BERT_MODEL_DIM
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from ..vq_gan_3d.model.vqgan import VQGAN
26
+
27
+ import matplotlib.pyplot as plt
28
+
29
+ # helper functions
30
+
31
+
32
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
    if x is None:
        return False
    return True
34
+
35
+
36
def noop(*args, **kwargs):
    """Accept any positional or keyword arguments, do nothing, return ``None``."""
    return None
38
+
39
+
40
def is_odd(n):
    """Return whether ``n`` leaves a remainder of 1 when divided by 2."""
    remainder = n % 2
    return remainder == 1
42
+
43
+
44
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``.

    When ``d`` is callable it is invoked with no arguments and its result is
    returned, so a fallback is only constructed when it is actually needed.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
48
+
49
+
50
def cycle(dl):
    """Yield the items of ``dl`` endlessly, re-iterating it whenever it runs out."""
    while True:
        yield from dl
54
+
55
+
56
def num_to_groups(num, divisor):
    """Split ``num`` into chunk sizes of ``divisor`` plus one smaller remainder chunk.

    Returns a list whose entries sum to ``num``: as many ``divisor`` entries as
    fit, followed by the remainder when it is positive.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_groups)]
    if leftover > 0:
        sizes.append(leftover)
    return sizes
63
+
64
+
65
def prob_mask_like(shape, prob, device):
    """Build a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False masks
    without drawing random numbers; any other value compares uniform samples
    from [0, 1) against ``prob``.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)

    samples = torch.zeros(shape, device=device).float()
    samples.uniform_(0, 1)
    return samples < prob
72
+
73
+
74
def is_list_str(x):
    """Return whether ``x`` is a list or tuple containing only strings.

    Anything that is not a ``list``/``tuple`` yields ``False``. An empty
    list or tuple yields ``True`` (there is no non-string element).
    Elements are checked with ``isinstance`` so ``str`` subclasses count
    as strings too.
    """
    if not isinstance(x, (list, tuple)):
        return False
    return all(isinstance(el, str) for el in x)
78
+
79
+ # relative positional bias
80
+
81
+
82
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias looked up from bucketed relative positions."""

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One learned bias value per (bucket, head).
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map signed relative positions to integer bucket indices.

        Half of the buckets serve each sign. Within a half, distances below
        ``max_exact`` keep their own bucket; larger distances are placed on a
        logarithmic scale and capped at the last bucket of the half.
        """
        half_buckets = num_buckets // 2
        flipped = -relative_position

        # 0 or `half_buckets`, selecting which half of the table is used.
        sign_offset = (flipped < 0).long() * half_buckets
        distance = torch.abs(flipped)

        max_exact = half_buckets // 2
        log_ratio = torch.log(distance.float() / max_exact) / \
            math.log(max_distance / max_exact)
        coarse = max_exact + (log_ratio * (half_buckets - max_exact)).long()
        coarse = torch.min(coarse, torch.full_like(coarse, half_buckets - 1))

        return sign_offset + torch.where(distance < max_exact, distance, coarse)

    def forward(self, n, device):
        """Return the bias for a length-``n`` sequence, shaped (heads, n, n)."""
        positions = torch.arange(n, dtype=torch.long, device=device)
        # rel_pos[i, j] = j - i (key index minus query index).
        rel_pos = positions[None, :] - positions[:, None]
        buckets = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        bias = self.relative_attention_bias(buckets)  # (i, j, heads)
        return bias.permute(2, 0, 1)
124
+
125
+ # small helper modules
126
+
127
+
128
class EMA():
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` in place.

        Parameters are paired in iteration order of ``.parameters()``.
        """
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for source_param, averaged_param in pairs:
            averaged_param.data = self.update_average(
                averaged_param.data, source_param.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new
142
+
143
+
144
class Residual(nn.Module):
    """Wrap ``fn`` so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return ``fn(x, *args, **kwargs) + x``."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x
151
+
152
+
153
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions / timesteps.

    The output concatenates ``sin`` and ``cos`` of the input scaled by
    ``dim // 2`` geometrically spaced frequencies, so its last dimension is
    ``2 * (dim // 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (indexed as ``x[:, None]``, so one value per batch entry).

        Returns a tensor of shape ``(len(x), 2 * (dim // 2))``.
        """
        half_dim = self.dim // 2
        # With dim of 2 or 3, half_dim - 1 is 0 and the plain division would
        # raise ZeroDivisionError; the single frequency there is exp(0) == 1
        # whatever the step, so clamping the denominator changes nothing else.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
166
+
167
+
168
def Upsample(dim):
    """Build a transposed 3-D convolution with ``dim`` channels in and out.

    Kernel (1, 4, 4), stride (1, 2, 2), padding (0, 1, 1): the last two axes
    are doubled in size while the first non-channel axis is left untouched.
    """
    return nn.ConvTranspose3d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
170
+
171
+
172
def Downsample(dim):
    """Build a strided 3-D convolution with ``dim`` channels in and out.

    Kernel (1, 4, 4), stride (1, 2, 2), padding (0, 1, 1): the last two axes
    are halved in size while the first non-channel axis is left untouched.
    """
    return nn.Conv3d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
174
+
175
+
176
class LayerNorm(nn.Module):
    """Normalize a 5-D tensor over dim 1 and rescale with a learned gain.

    There is no learned bias; ``gamma`` has shape ``(1, dim, 1, 1, 1)``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + eps) * gamma`` with stats over dim 1."""
        variance = x.var(dim=1, unbiased=False, keepdim=True)
        centered = x - x.mean(dim=1, keepdim=True)
        scale = torch.sqrt(variance + self.eps)
        return centered / scale * self.gamma
186
+
187
+
188
class PreNorm(nn.Module):
    """Apply this file's ``LayerNorm(dim)`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; keyword arguments go to ``fn`` only."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
197
+
198
+ # building block modules
199
+
200
+
201
class Block(nn.Module):
    """Conv3d (kernel (1, 3, 3)) -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv3d(
            dim, dim_out, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Run the block.

        ``scale_shift``, when given, is a ``(scale, shift)`` pair applied after
        normalization as ``x * (scale + 1) + shift``.
        """
        normed = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)
217
+
218
+
219
class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional conditioning.

    When ``time_emb_dim`` is given, an MLP turns the embedding into a
    ``(scale, shift)`` pair that modulates the first block only.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2),
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # A 1x1x1 conv matches channel counts on the skip path when they differ.
        if dim != dim_out:
            self.res_conv = nn.Conv3d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """Return ``block2(block1(x)) + res_conv(x)``.

        ``time_emb`` is required whenever the block was built with
        ``time_emb_dim``.
        """
        scale_shift = None
        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            projected = rearrange(self.mlp(time_emb), 'b c -> b c 1 1 1')
            # First half of the channels is the scale, second half the shift.
            scale_shift = projected.chunk(2, dim=1)

        hidden = self.block1(x, scale_shift=scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
245
+
246
+
247
class SpatialLinearAttention(nn.Module):
    """Linear attention over the (h, w) positions of each frame independently.

    Frames are folded into the batch axis so that 2-D 1x1 convolutions can
    produce the queries, keys and values.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Attend within each frame of ``x`` (b, c, f, h, w); shape is preserved."""
        batch, _, _, height, width = x.shape
        frames = rearrange(x, 'b c f h w -> (b f) c h w')

        q, k, v = rearrange_many(
            self.to_qkv(frames).chunk(3, dim=1),
            'b (h c) x y -> b h c (x y)',
            h=self.heads,
        )

        # Queries are normalized over the feature axis, keys over positions.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        attended = rearrange(
            attended, 'b h c (x y) -> b (h c) x y',
            h=self.heads, x=height, y=width)
        projected = self.to_out(attended)
        return rearrange(projected, '(b f) c h w -> b c f h w', b=batch)
275
+
276
+ # attention along space and time
277
+
278
+
279
class EinopsToAndFrom(nn.Module):
    """Rearrange the input to another einops layout, apply ``fn``, rearrange back.

    ``from_einops`` must be a space-separated list of plain axis names (one per
    input dimension); their sizes are recorded so the inverse rearrange can
    split any axes that ``to_einops`` merged.
    """

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn`` applied in the ``to_einops`` layout, restored to the input layout."""
        axis_names = self.from_einops.split(' ')
        axis_sizes = dict(zip(axis_names, x.shape))

        inner = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        inner = self.fn(inner, **kwargs)
        return rearrange(
            inner, f'{self.to_einops} -> {self.from_einops}', **axis_sizes)
295
+
296
+
297
class Attention(nn.Module):
    """Multi-head softmax attention over the second-to-last axis of the input.

    Optionally rotates queries and keys with a rotary embedding, adds a
    positional bias to the similarities, and restricts selected batch entries
    to attend only to themselves (``focus_present_mask``).
    """

    def __init__(self, dim, heads=4, dim_head=32, rotary_emb=None):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, pos_bias=None, focus_present_mask=None):
        """Attend over ``x`` (..., n, dim) and return a tensor of the same shape.

        ``pos_bias`` is added to the similarity matrix when given.
        ``focus_present_mask`` is a boolean tensor with one entry per batch
        sample; True entries may only attend to their own position.
        """
        seq_len, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        has_focus_mask = exists(focus_present_mask)

        if has_focus_mask and focus_present_mask.all():
            # Every sample attends only to itself, so attention reduces to
            # passing each token's values straight through to the output.
            return self.to_out(qkv[-1])

        # Split out heads.
        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)
        q = q * self.scale

        # Rotate positions into queries and keys when a rotary embedding is set.
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)

        if exists(pos_bias):
            sim = sim + pos_bias

        if has_focus_mask and not (~focus_present_mask).all():
            everything = torch.ones(
                (seq_len, seq_len), device=device, dtype=torch.bool)
            diagonal_only = torch.eye(seq_len, device=device, dtype=torch.bool)

            # Per sample: diagonal mask where focusing, full mask elsewhere.
            allowed = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(diagonal_only, 'i j -> 1 1 1 i j'),
                rearrange(everything, 'i j -> 1 1 1 i j'),
            )
            sim = sim.masked_fill(~allowed, -torch.finfo(sim.dtype).max)

        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        weights = sim.softmax(dim=-1)

        # Aggregate values and merge the heads back together.
        merged = einsum('... h i j, ... h j d -> ... h i d', weights, v)
        merged = rearrange(merged, '... h n d -> ... n (h d)')
        return self.to_out(merged)
376
+
377
+ # model
378
+
379
+
380
class Unet3D(nn.Module):
    """3-D U-Net denoiser with per-frame spatial and cross-frame temporal attention.

    Tensors use the ``b c f h w`` layout throughout (see the einops patterns
    passed to ``EinopsToAndFrom``). All convolutions use a kernel size of 1
    along the frame axis, so frames only interact through the temporal
    attention layers.

    Args:
        dim: base channel width; also the size of the sinusoidal time embedding.
        cond_dim: size of an optional conditioning embedding.
        out_dim: output channels of the final conv (defaults to ``channels``).
        dim_mults: channel multipliers, one per resolution level.
        channels: input channels expected by ``init_conv``, i.e. the channel
            count of ``x`` *after* ``cond`` has been concatenated in ``forward``.
        attn_heads: number of heads for every attention layer.
        attn_dim_head: per-head width of the temporal attention layers.
        use_bert_text_cond: when True, ``cond_dim`` is replaced by
            ``BERT_MODEL_DIM``.
        init_dim: channels after the initial conv (defaults to ``dim``).
        init_kernel_size: odd spatial kernel size of the initial conv.
        use_sparse_linear_attn: include ``SpatialLinearAttention`` at each
            level (otherwise an identity takes its place).
        block_type: accepted for interface compatibility; not read by this
            constructor.
        resnet_groups: GroupNorm group count inside every ``ResnetBlock``.
    """

    def __init__(
        self,
        dim,
        cond_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        attn_heads=8,
        attn_dim_head=32,
        use_bert_text_cond=False,
        init_dim=None,
        init_kernel_size=7,
        use_sparse_linear_attn=True,
        block_type='resnet',
        resnet_groups=8
    ):
        super().__init__()
        self.channels = channels

        # temporal attention and its relative positional encoding
        # (a single rotary embedding instance is shared by all temporal layers)

        rotary_emb = RotaryEmbedding(min(32, attn_dim_head))

        def temporal_attn(dim):
            return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(
                dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb))

        # realistically will not be able to generate that many frames of video... yet
        self.time_rel_pos_bias = RelativePositionBias(
            heads=attn_heads, max_distance=32)

        # initial conv

        init_dim = default(init_dim, dim)
        assert is_odd(init_kernel_size)

        init_padding = init_kernel_size // 2
        self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size,
                                   init_kernel_size), padding=(0, init_padding, init_padding))

        self.init_temporal_attn = Residual(
            PreNorm(init_dim, temporal_attn(init_dim)))

        # dimensions

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # text conditioning

        self.has_cond = exists(cond_dim) or use_bert_text_cond
        cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim

        self.null_cond_emb = nn.Parameter(
            torch.randn(1, cond_dim)) if self.has_cond else None

        # width of the vector fed to every ResnetBlock: time (+ condition)
        cond_dim = time_dim + int(cond_dim or 0)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        # block type

        block_klass = partial(ResnetBlock, groups=resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim=cond_dim)

        # modules for all layers
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                Residual(PreNorm(dim_out, SpatialLinearAttention(
                    dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)

        spatial_attn = EinopsToAndFrom(
            'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))

        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
        self.mid_temporal_attn = Residual(
            PreNorm(mid_dim, temporal_attn(mid_dim)))

        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                # dim_out * 2: the skip connection is concatenated on channels
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                Residual(PreNorm(dim_in, SpatialLinearAttention(
                    dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            # dim * 2: the init_conv output is concatenated before this block
            block_klass(dim * 2, dim),
            nn.Conv3d(dim, out_dim, 1)
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale=2.,
        **kwargs
    ):
        """Run ``forward`` with classifier-free guidance.

        Returns the conditional prediction as is when ``cond_scale == 1`` or
        the model has no conditioning; otherwise extrapolates from the
        null-condition prediction: ``null + (cond - null) * cond_scale``.
        """
        logits = self.forward(*args, null_cond_prob=0., **kwargs)
        if cond_scale == 1 or not self.has_cond:
            return logits

        null_logits = self.forward(*args, null_cond_prob=1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        x,
        time,
        cond=None,
        null_cond_prob=0.,
        focus_present_mask=None,
        # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time)
        prob_focus_present=0.
    ):
        """Predict the denoiser output for ``x`` at diffusion step ``time``.

        Args:
            x: input tensor in ``b c f h w`` layout.
            time: one timestep per batch entry, fed to the time MLP.
            cond: optional tensor concatenated to ``x`` along the channel axis
                before the initial conv. When omitted, ``x`` alone must
                already have ``channels`` channels.
            null_cond_prob: per-sample probability of swapping the condition
                for the learned null embedding (only used when ``has_cond``).
            focus_present_mask: optional per-sample boolean mask forwarded to
                the temporal attention layers of the down, mid and up paths;
                sampled from ``prob_focus_present`` when not given.
            prob_focus_present: probability used to sample that mask.
        """
        assert not (self.has_cond and not exists(cond)
                    ), 'cond must be passed in if cond_dim specified'

        # `cond` defaults to None; concatenating None would raise a TypeError,
        # so only concatenate when a conditioning tensor was supplied.
        if exists(cond):
            x = torch.cat([x, cond], dim=1)

        batch, device = x.shape[0], x.device

        focus_present_mask = default(focus_present_mask, lambda: prob_mask_like(
            (batch,), prob_focus_present, device=device))

        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)

        x = self.init_conv(x)
        r = x.clone()  # kept for the final skip connection

        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)

        t = self.time_mlp(time) if exists(self.time_mlp) else None  # [2, 128]

        # classifier free guidance

        if self.has_cond:
            # NOTE(review): `cond` is used here as a 2-D embedding (it is
            # mixed with `null_cond_emb` of shape (1, cond_dim) and appended
            # to `t`), while above it is concatenated to `x` along channels.
            # Confirm which conditioning form callers with has_cond=True use.
            mask = prob_mask_like((batch,), null_cond_prob, device=device)
            cond = torch.where(rearrange(mask, 'b -> b 1'),
                               self.null_cond_emb, cond)
            t = torch.cat((t, cond), dim=-1)

        h = []

        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            h.append(x)  # saved before downsampling for the matching up level
            x = downsample(x)

        # [2, 256, 32, 4, 4]
        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(
            x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        x = self.mid_block2(x, t)

        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            x = upsample(x)

        x = torch.cat((x, r), dim=1)
        return self.final_conv(x)
581
+
582
+ # gaussian diffusion trainer class
583
+
584
+
585
def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch entry and reshape it to broadcast against ``x_shape``.

    The result has shape ``(len(t), 1, ..., 1)`` with as many dimensions as
    ``x_shape``.
    """
    batch, *_rest = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)
589
+
590
+
591
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor of ``timesteps`` betas, clipped to [0, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Normalize so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratios = cumulative[1:] / cumulative[:-1]
    return (1 - step_ratios).clamp(0, 0.9999)
603
+
604
+
605
+ class GaussianDiffusion(nn.Module):
606
+ def __init__(
607
+ self,
608
+ denoise_fn,
609
+ *,
610
+ image_size,
611
+ num_frames,
612
+ text_use_bert_cls=False,
613
+ channels=3,
614
+ timesteps=1000,
615
+ loss_type='l1',
616
+ use_dynamic_thres=False, # from the Imagen paper
617
+ dynamic_thres_percentile=0.9,
618
+ vqgan_ckpt=None,
619
+ device=None
620
+ ):
621
+ super().__init__()
622
+ self.channels = channels
623
+ self.image_size = image_size
624
+ self.num_frames = num_frames
625
+ self.denoise_fn = denoise_fn
626
+ self.device = device
627
+
628
+ if vqgan_ckpt:
629
+ self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda()
630
+ self.vqgan.eval()
631
+ else:
632
+ self.vqgan = None
633
+
634
+ betas = cosine_beta_schedule(timesteps)
635
+
636
+ alphas = 1. - betas
637
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
638
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
639
+
640
+ timesteps, = betas.shape
641
+ self.num_timesteps = int(timesteps)
642
+ self.loss_type = loss_type
643
+
644
+ # register buffer helper function that casts float64 to float32
645
+
646
+ def register_buffer(name, val): return self.register_buffer(
647
+ name, val.to(torch.float32))
648
+
649
+ register_buffer('betas', betas)
650
+ register_buffer('alphas_cumprod', alphas_cumprod)
651
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
652
+
653
+ # calculations for diffusion q(x_t | x_{t-1}) and others
654
+
655
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
656
+ register_buffer('sqrt_one_minus_alphas_cumprod',
657
+ torch.sqrt(1. - alphas_cumprod))
658
+ register_buffer('log_one_minus_alphas_cumprod',
659
+ torch.log(1. - alphas_cumprod))
660
+ register_buffer('sqrt_recip_alphas_cumprod',
661
+ torch.sqrt(1. / alphas_cumprod))
662
+ register_buffer('sqrt_recipm1_alphas_cumprod',
663
+ torch.sqrt(1. / alphas_cumprod - 1))
664
+
665
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
666
+
667
+ posterior_variance = betas * \
668
+ (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
669
+
670
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
671
+
672
+ register_buffer('posterior_variance', posterior_variance)
673
+
674
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
675
+
676
+ register_buffer('posterior_log_variance_clipped',
677
+ torch.log(posterior_variance.clamp(min=1e-20)))
678
+ register_buffer('posterior_mean_coef1', betas *
679
+ torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
680
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev)
681
+ * torch.sqrt(alphas) / (1. - alphas_cumprod))
682
+
683
+ # text conditioning parameters
684
+
685
+ self.text_use_bert_cls = text_use_bert_cls
686
+
687
+ # dynamic thresholding when sampling
688
+
689
+ self.use_dynamic_thres = use_dynamic_thres
690
+ self.dynamic_thres_percentile = dynamic_thres_percentile
691
+
692
+ def q_mean_variance(self, x_start, t):
693
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
694
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
695
+ log_variance = extract(
696
+ self.log_one_minus_alphas_cumprod, t, x_start.shape)
697
+ return mean, variance, log_variance
698
+
699
+ def predict_start_from_noise(self, x_t, t, noise):
700
+ return (
701
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
702
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
703
+ )
704
+
705
+ def q_posterior(self, x_start, x_t, t):
706
+ posterior_mean = (
707
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
708
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
709
+ )
710
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
711
+ posterior_log_variance_clipped = extract(
712
+ self.posterior_log_variance_clipped, t, x_t.shape)
713
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
714
+
715
+ def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.):
716
+ x_recon = self.predict_start_from_noise(
717
+ x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale))
718
+
719
+ if clip_denoised:
720
+ s = 1.
721
+ if self.use_dynamic_thres:
722
+ s = torch.quantile(
723
+ rearrange(x_recon, 'b ... -> b (...)').abs(),
724
+ self.dynamic_thres_percentile,
725
+ dim=-1
726
+ )
727
+
728
+ s.clamp_(min=1.)
729
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
730
+
731
+ # clip by threshold, depending on whether static or dynamic
732
+ x_recon = x_recon.clamp(-s, s) / s
733
+
734
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
735
+ x_start=x_recon, x_t=x, t=t)
736
+ return model_mean, posterior_variance, posterior_log_variance
737
+
738
+ @torch.inference_mode()
739
+ def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True):
740
+ b, *_, device = *x.shape, x.device
741
+ model_mean, _, model_log_variance = self.p_mean_variance(
742
+ x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale)
743
+ noise = torch.randn_like(x)
744
+ # no noise when t == 0
745
+ nonzero_mask = (1 - (t == 0).float()).reshape(b,
746
+ *((1,) * (len(x.shape) - 1)))
747
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
748
+
749
+ @torch.inference_mode()
750
+ def p_sample_loop(self, shape, cond=None, cond_scale=1.):
751
+ device = self.betas.device
752
+
753
+ b = shape[0]
754
+ img = torch.randn(shape, device=device)
755
+ # print('cond', cond.shape)
756
+ for i in reversed(range(0, self.num_timesteps)):
757
+ img = self.p_sample(img, torch.full(
758
+ (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale)
759
+
760
+ return img
761
+
762
+ @torch.inference_mode()
763
+ def sample(self, cond=None, cond_scale=1., batch_size=16):
764
+ device = next(self.denoise_fn.parameters()).device
765
+
766
+ if is_list_str(cond):
767
+ cond = bert_embed(tokenize(cond)).to(device)
768
+
769
+ # batch_size = cond.shape[0] if exists(cond) else batch_size
770
+ batch_size = batch_size
771
+ image_size = self.image_size
772
+ channels = 8 # self.channels
773
+ num_frames = self.num_frames
774
+ # print((batch_size, channels, num_frames, image_size, image_size))
775
+ # print('cond_',cond.shape)
776
+ _sample = self.p_sample_loop(
777
+ (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale)
778
+
779
+ if isinstance(self.vqgan, VQGAN):
780
+ # denormalize TODO: Remove eventually
781
+ _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() -
782
+ self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min()
783
+
784
+ _sample = self.vqgan.decode(_sample, quantize=True)
785
+ else:
786
+ unnormalize_img(_sample)
787
+
788
+ return _sample
789
+
790
+ @torch.inference_mode()
791
+ def interpolate(self, x1, x2, t=None, lam=0.5):
792
+ b, *_, device = *x1.shape, x1.device
793
+ t = default(t, self.num_timesteps - 1)
794
+
795
+ assert x1.shape == x2.shape
796
+
797
+ t_batched = torch.stack([torch.tensor(t, device=device)] * b)
798
+ xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
799
+
800
+ img = (1 - lam) * xt1 + lam * xt2
801
+ for i in reversed(range(0, t)):
802
+ img = self.p_sample(img, torch.full(
803
+ (b,), i, device=device, dtype=torch.long))
804
+
805
+ return img
806
+
807
+ def q_sample(self, x_start, t, noise=None):
808
+ noise = default(noise, lambda: torch.randn_like(x_start))
809
+
810
+ return (
811
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
812
+ extract(self.sqrt_one_minus_alphas_cumprod,
813
+ t, x_start.shape) * noise
814
+ )
815
+
816
+ def p_losses(self, x_start, t, cond=None, noise=None, **kwargs):
817
+ b, c, f, h, w, device = *x_start.shape, x_start.device
818
+ noise = default(noise, lambda: torch.randn_like(x_start))
819
+ # breakpoint()
820
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32]
821
+
822
+ if is_list_str(cond):
823
+ cond = bert_embed(
824
+ tokenize(cond), return_cls_repr=self.text_use_bert_cls)
825
+ cond = cond.to(device)
826
+
827
+ x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs)
828
+
829
+ if self.loss_type == 'l1':
830
+ loss = F.l1_loss(noise, x_recon)
831
+ elif self.loss_type == 'l2':
832
+ loss = F.mse_loss(noise, x_recon)
833
+ else:
834
+ raise NotImplementedError()
835
+
836
+ return loss
837
+
838
+ def forward(self, x, *args, **kwargs):
839
+ bs = int(x.shape[0]/2)
840
+ img=x[:bs,...]
841
+ mask=x[bs:,...]
842
+ mask_=(1-mask).detach()
843
+ masked_img = (img*mask_).detach()
844
+ masked_img=masked_img.permute(0,1,-1,-3,-2)
845
+ img=img.permute(0,1,-1,-3,-2)
846
+ mask=mask.permute(0,1,-1,-3,-2)
847
+ # breakpoint()
848
+ if isinstance(self.vqgan, VQGAN):
849
+ with torch.no_grad():
850
+ img = self.vqgan.encode(
851
+ img, quantize=False, include_embeddings=True)
852
+ # normalize to -1 and 1
853
+ img = ((img - self.vqgan.codebook.embeddings.min()) /
854
+ (self.vqgan.codebook.embeddings.max() -
855
+ self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
856
+
857
+ masked_img = self.vqgan.encode(
858
+ masked_img, quantize=False, include_embeddings=True)
859
+ # normalize to -1 and 1
860
+ masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) /
861
+ (self.vqgan.codebook.embeddings.max() -
862
+ self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0
863
+ else:
864
+ print("Hi")
865
+ img = normalize_img(img)
866
+ masked_img = normalize_img(masked_img)
867
+ mask = mask*2.0 - 1.0
868
+ cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:])
869
+ cond = torch.cat((masked_img, cc), dim=1)
870
+
871
+ b, device, img_size, = img.shape[0], img.device, self.image_size
872
+ t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
873
+ # breakpoint()
874
+ return self.p_losses(img, t, cond=cond, *args, **kwargs)
875
+
876
+ # trainer class
877
+
878
+
879
# Channel count -> PIL mode string used when converting frames.
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}


def seek_all_images(img, channels=3):
    """Yield every frame of ``img`` converted to the mode for ``channels``.

    Frames are visited by calling ``img.seek`` with increasing indices until
    an ``EOFError`` is raised.
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]

    frame_idx = 0
    while True:
        try:
            img.seek(frame_idx)
            yield img.convert(mode)
        except EOFError:
            return
        frame_idx += 1
898
+
899
+ # tensor of shape (channels, frames, height, width) -> gif
900
+
901
+
902
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
    """Write a video tensor to ``path`` as an animated GIF.

    The tensor is min-max scaled to ``[0, 1]`` and split along dimension 1
    into frames, which are converted with ``T.ToPILImage``.

    :param tensor: video tensor; dimension 1 is treated as the frame axis.
    :param path: destination file.
    :param duration: per-frame duration forwarded to ``Image.save``.
    :param loop: loop count forwarded to ``Image.save``.
    :param optimize: optimisation flag forwarded to ``Image.save``.
    :return: list of the PIL frames that were written.
    """
    lowest, highest = tensor.min(), tensor.max()
    value_range = highest - lowest
    if value_range > 0:
        tensor = ((tensor - lowest) / value_range) * 1.0
    else:
        # A constant tensor would divide by zero and produce NaN frames.
        tensor = torch.zeros_like(tensor)

    to_pil = T.ToPILImage()
    # Materialise the frames: the original returned a map() iterator that the
    # unpacking below had already exhausted, so callers always got nothing.
    images = [to_pil(frame) for frame in tensor.unbind(dim=1)]
    first_img, *rest_imgs = images
    first_img.save(path, save_all=True, append_images=rest_imgs,
                   duration=duration, loop=loop, optimize=optimize)
    return images
909
+
910
+ # gif -> (channels, frame, height, width) tensor
911
+
912
+
913
def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
    """Load the GIF at ``path`` and stack its transformed frames along dim 1."""
    frames = seek_all_images(Image.open(path), channels=channels)
    per_frame = tuple(transform(frame) for frame in frames)
    return torch.stack(per_frame, dim=1)
917
+
918
+
919
def identity(t, *args, **kwargs):
    """Return ``t`` untouched, ignoring any extra arguments."""
    result = t
    return result
921
+
922
+
923
def normalize_img(t):
    """Map values from ``[0, 1]`` to ``[-1, 1]`` via ``t * 2 - 1``."""
    doubled = t * 2
    return doubled - 1
925
+
926
+
927
def unnormalize_img(t):
    """Map values from ``[-1, 1]`` back to ``[0, 1]`` via ``(t + 1) * 0.5``."""
    shifted = t + 1
    return shifted * 0.5
929
+
930
+
931
def cast_num_frames(t, *, frames):
    """Force dimension 1 of ``t`` to have exactly ``frames`` entries.

    Longer inputs are truncated; shorter inputs are zero-padded at the end of
    dimension 1. An input that already matches is returned as-is.
    """
    current = t.shape[1]

    if current > frames:
        return t[:, :frames]

    if current < frames:
        # The pad tuple runs from the last dimension backwards in (left, right) pairs.
        return F.pad(t, (0, 0, 0, 0, 0, frames - current))

    return t
941
+
942
+
943
class Dataset(data.Dataset):
    """Dataset of GIF files found recursively under a folder.

    Each item is the tensor produced by ``gif_to_tensor`` for one file,
    optionally cut or padded to a fixed number of frames.
    """

    def __init__(
        self,
        folder,
        image_size,
        channels=3,
        num_frames=16,
        horizontal_flip=False,
        force_num_frames=True,
        exts=['gif']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.channels = channels

        # Collect matching files extension by extension, in the given order.
        root = Path(f'{folder}')
        self.paths = []
        for ext in exts:
            self.paths.extend(root.glob(f'**/*.{ext}'))

        if force_num_frames:
            self.cast_num_frames_fn = partial(
                cast_num_frames, frames=num_frames)
        else:
            self.cast_num_frames_fn = identity

        if horizontal_flip:
            flip = T.RandomHorizontalFlip()
        else:
            flip = T.Lambda(identity)
        self.transform = T.Compose([
            T.Resize(image_size),
            flip,
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of files found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load file ``index`` and apply the frame-count adjustment."""
        video = gif_to_tensor(
            self.paths[index], self.channels, transform=self.transform)
        return self.cast_num_frames_fn(video)
978
+
979
+ # trainer class
980
+
981
+
982
class Tester(object):
    """Holds a diffusion model plus an EMA copy and restores them from disk."""

    def __init__(
        self,
        diffusion_model,
        results_folder='./results',
    ):
        """
        :param diffusion_model: model to wrap; must expose ``image_size``.
        :param results_folder: directory searched for ``*.pt`` checkpoints
            when ``load`` is called with ``milestone == -1``.
        """
        super().__init__()
        self.model = diffusion_model
        self.ema_model = copy.deepcopy(self.model)
        self.step = 0
        self.image_size = diffusion_model.image_size
        # load(-1) reads this attribute; it was never assigned before.
        self.results_folder = Path(results_folder)

        self.reset_parameters()

    def reset_parameters(self):
        """Copy the current model weights into the EMA model."""
        self.ema_model.load_state_dict(self.model.state_dict())

    def load(self, milestone, map_location=None, **kwargs):
        """Restore ``step``, model and EMA weights from a checkpoint.

        :param milestone: path of the checkpoint file, or ``-1`` to pick the
            ``*.pt`` file under ``self.results_folder`` whose name ends in the
            highest ``-<number>`` suffix.
        :param map_location: forwarded to ``torch.load`` when truthy.
        :param kwargs: forwarded to both ``load_state_dict`` calls.
        """
        if milestone == -1:
            candidates = list(Path(self.results_folder).glob('**/*.pt'))
            assert len(
                candidates) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
            # Keep the path itself: torch.load needs a file, not the number.
            milestone = max(
                candidates, key=lambda p: int(p.stem.split('-')[-1]))

        if map_location:
            checkpoint = torch.load(milestone, map_location=map_location)
        else:
            checkpoint = torch.load(milestone)

        self.step = checkpoint['step']
        self.model.load_state_dict(checkpoint['model'], **kwargs)
        self.ema_model.load_state_dict(checkpoint['ema'], **kwargs)
1015
+
1016
+
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/text.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    if val is None:
        return False
    return True
9
+
10
+ # singleton globals
11
+
12
+
13
+ MODEL = None
14
+ TOKENIZER = None
15
+ BERT_MODEL_DIM = 768
16
+
17
+
18
def get_tokenizer():
    """Return the module-wide tokenizer, loading it via torch.hub on first use."""
    global TOKENIZER
    if exists(TOKENIZER):
        return TOKENIZER

    TOKENIZER = torch.hub.load(
        'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER
24
+
25
+
26
def get_bert():
    """Return the module-wide BERT model, loading it via torch.hub on first use.

    A freshly loaded model is moved to CUDA when CUDA is available.
    """
    global MODEL
    if exists(MODEL):
        return MODEL

    MODEL = torch.hub.load(
        'huggingface/pytorch-transformers', 'model', 'bert-base-cased')
    if torch.cuda.is_available():
        MODEL = MODEL.cuda()
    return MODEL
35
+
36
+ # tokenize
37
+
38
+
39
def tokenize(texts, add_special_tokens=True):
    """Encode one string or a list/tuple of strings into padded token ids.

    :return: the ``input_ids`` tensor from the tokenizer's
        ``batch_encode_plus`` call (padded, ``return_tensors='pt'``).
    """
    batch = texts if isinstance(texts, (list, tuple)) else [texts]

    encoding = get_tokenizer().batch_encode_plus(
        batch,
        add_special_tokens=add_special_tokens,
        padding=True,
        return_tensors='pt'
    )
    return encoding.input_ids
54
+
55
+ # embedding function
56
+
57
+
58
@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr=False,
    eps=1e-8,
    pad_id=0.
):
    """Embed token ids with BERT and pool them to one vector per sequence.

    :param token_ids: token id tensor, e.g. the output of ``tokenize``.
    :param return_cls_repr: if true, return the last-layer hidden state of the
        first token instead of a mean over tokens.
    :param eps: added to the token count to avoid division by zero.
    :param pad_id: token id treated as padding and excluded from attention
        and from the mean.
    :return: the first-token hidden state, or the mean of the last-layer
        hidden states over non-padding tokens excluding the first token.
    """
    model = get_bert()
    mask = token_ids != pad_id

    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    outputs = model(
        input_ids=token_ids,
        attention_mask=mask,
        output_hidden_states=True
    )

    hidden_state = outputs.hidden_states[-1]

    if return_cls_repr:
        # return [cls] as representation
        return hidden_state[:, 0]

    # mean all tokens excluding [cls], accounting for length
    # (mask is always a tensor here, so the former "mask is None" branch was
    # unreachable and has been dropped)
    mask = mask[:, 1:]
    mask = rearrange(mask, 'b n -> b n 1')

    numer = (hidden_state[:, 1:] * mask).sum(dim=1)
    denom = mask.sum(dim=1)
    masked_mean = numer / (denom + eps)
    # The original ended in a bare ``return`` and therefore returned None.
    return masked_mean
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from monai.networks.layers.utils import get_act_layer
6
+
7
+
8
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions/timesteps.

    The output has ``emb_dim`` columns: the sines followed by the cosines (or
    the reverse when ``flip_sin_to_cos`` is set), plus one zero column when
    ``emb_dim`` is odd.
    """

    def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.downscale_freq_shift = downscale_freq_shift
        self.max_period = max_period
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, x):
        n_freqs = self.emb_dim // 2
        log_step = math.log(self.max_period) / \
            (n_freqs - self.downscale_freq_shift)
        freqs = torch.exp(-log_step*torch.arange(n_freqs, device=x.device))
        angles = x[:, None] * freqs[None, :]
        out = torch.cat((angles.sin(), angles.cos()), dim=-1)

        if self.flip_sin_to_cos:
            # Swap the two halves so the cosines come first.
            out = torch.cat([out[:, n_freqs:], out[:, :n_freqs]], dim=-1)

        if self.emb_dim % 2 == 1:
            # Odd sizes get a single zero column appended.
            out = torch.nn.functional.pad(out, (0, 1, 0, 0))
        return out
31
+
32
+
33
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies.

    Following @crowsonkb 's lead with learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    The output concatenates the raw input, the sines and the cosines, giving
    ``1 + 2 * (emb_dim // 2)`` columns, plus one zero column when ``emb_dim``
    is odd.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        n_freqs = emb_dim // 2
        self.weights = nn.Parameter(torch.randn(n_freqs))

    def forward(self, x):
        column = x[:, None]
        angles = column * self.weights[None, :] * 2 * math.pi
        waves = torch.cat((angles.sin(), angles.cos()), dim=-1)
        out = torch.cat((column, waves), dim=-1)
        if self.emb_dim % 2 == 1:
            out = torch.nn.functional.pad(out, (0, 1, 0, 0))
        return out
51
+
52
+
53
class TimeEmbbeding(nn.Module):
    """Positional embedding followed by a two-layer MLP.

    :param emb_dim: output size of the embedding.
    :param pos_embedder: class used to build the positional embedder.
    :param pos_embedder_kwargs: keyword arguments for ``pos_embedder``; its
        ``emb_dim`` entry defaults to ``emb_dim // 4`` when absent.
    :param act_name: activation spec passed to ``get_act_layer``.
    """

    def __init__(
        self,
        emb_dim=64,
        pos_embedder=SinusoidalPosEmb,
        pos_embedder_kwargs={},
        act_name=("SWISH", {})  # Swish = SiLU
    ):
        super().__init__()
        self.emb_dim = emb_dim
        # Work on a private copy. The original wrote 'emb_dim' into the shared
        # default dict, so every later instance built with the default
        # inherited the first instance's positional size.
        pos_embedder_kwargs = dict(pos_embedder_kwargs or {})
        self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
        pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
        self.pos_embedder = pos_embedder(**pos_embedder_kwargs)

        self.time_emb = nn.Sequential(
            self.pos_embedder,
            nn.Linear(self.pos_emb_dim, self.emb_dim),
            get_act_layer(act_name),
            nn.Linear(self.emb_dim, self.emb_dim)
        )

    def forward(self, time):
        """Embed the tensor ``time`` with the positional embedder and MLP."""
        return self.time_emb(time)
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/unet.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ddpm.time_embedding import TimeEmbbeding
2
+
3
+ import monai.networks.nets as nets
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock
9
+ from monai.networks.layers.utils import get_act_layer
10
+
11
+
12
class DownBlock(nn.Module):
    """Encoder stage: add time/condition embeddings, then a ``UnetBasicBlock``.

    Both embeddings are projected to ``in_ch`` values per sample and added to
    the input as a per-channel bias before the convolutional block runs.
    """

    def __init__(
            self,
            spatial_dims,
            in_ch,
            out_ch,
            time_emb_dim,
            cond_emb_dim,
            act_name=("swish", {}),
            **kwargs):
        super(DownBlock, self).__init__()
        time_layers = [
            get_act_layer(name=act_name),
            nn.Linear(time_emb_dim, in_ch),  # in_ch * 2
        ]
        self.loca_time_embedder = nn.Sequential(*time_layers)
        if cond_emb_dim is not None:
            cond_layers = [
                get_act_layer(name=act_name),
                nn.Linear(cond_emb_dim, in_ch),
            ]
            self.loca_cond_embedder = nn.Sequential(*cond_layers)
        self.down_op = UnetBasicBlock(
            spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs)

    def forward(self, x, time_emb, cond_emb):
        batch, channels, *_ = x.shape
        n_spatial = x.ndim-2
        # Target shape (batch, channels, 1, ..., 1) so the bias broadcasts over space.
        bias_shape = (batch, channels) + (1,)*n_spatial

        # ------------ Time ----------
        time_bias = self.loca_time_embedder(time_emb).reshape(*bias_shape)
        x = x + time_bias

        # ----------- Condition ------------
        if cond_emb is not None:
            cond_bias = self.loca_cond_embedder(cond_emb).reshape(*bias_shape)
            x = x + cond_bias

        # ----------- Image ---------
        return self.down_op(x)
57
+
58
+
59
class UpBlock(nn.Module):
    """Decoder stage: add time/condition embeddings, then a ``UnetUpBlock``.

    The embeddings are projected to ``skip_ch * 2`` values and reshaped to the
    channel count of ``x_enc``.
    NOTE(review): this only lines up when ``enc_ch == skip_ch * 2`` - confirm
    against the channel configuration in use.
    """

    def __init__(
            self,
            spatial_dims,
            skip_ch,
            enc_ch,
            time_emb_dim,
            cond_emb_dim,
            act_name=("swish", {}),
            **kwargs):
        super(UpBlock, self).__init__()
        self.up_op = UnetUpBlock(spatial_dims, enc_ch,
                                 skip_ch, act_name=act_name, **kwargs)
        time_layers = [
            get_act_layer(name=act_name),
            nn.Linear(time_emb_dim, skip_ch * 2),
        ]
        self.loca_time_embedder = nn.Sequential(*time_layers)
        if cond_emb_dim is not None:
            cond_layers = [
                get_act_layer(name=act_name),
                nn.Linear(cond_emb_dim, skip_ch * 2),
            ]
            self.loca_cond_embedder = nn.Sequential(*cond_layers)

    def forward(self, x_skip, x_enc, time_emb, cond_emb):
        batch, channels, *_ = x_enc.shape
        n_spatial = x_enc.ndim-2
        # Target shape (batch, channels, 1, ..., 1) so the bias broadcasts over space.
        bias_shape = (batch, channels) + (1,)*n_spatial

        # ----------- Time --------------
        time_bias = self.loca_time_embedder(time_emb).reshape(*bias_shape)
        x_enc = x_enc + time_bias

        # ----------- Condition ------------
        if cond_emb is not None:
            cond_bias = self.loca_cond_embedder(cond_emb).reshape(*bias_shape)
            x_enc = x_enc + cond_bias

        # ----------- Image -------------
        return self.up_op(x_enc, x_skip)
108
+
109
+
110
class UNet(nn.Module):
    """Time- and condition-aware U-Net built from MONAI blocks.

    The network embeds the timestep (and, optionally, a condition), runs an
    input convolution, a stack of ``DownBlock`` encoders, a matching stack of
    ``UpBlock`` decoders, and a final output convolution.

    The ``outc_ver`` deep-supervision heads are still constructed, so existing
    checkpoints keep loading, but ``forward`` returns only the main output.
    """

    def __init__(self,
                 in_ch=1,
                 out_ch=1,
                 spatial_dims=3,
                 hid_chs=[32, 64, 128, 256, 512],
                 kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3],
                 strides=[1, (1, 2, 2), (1, 2, 2), 2, 2],
                 upsample_kernel_sizes=None,
                 act_name=("SWISH", {}),
                 norm_name=("INSTANCE", {"affine": True}),
                 time_embedder=TimeEmbbeding,
                 time_embedder_kwargs={},
                 cond_embedder=None,
                 cond_embedder_kwargs={},
                 # True = all but last layer, 0/False=disable, 1=only first layer, ...
                 deep_ver_supervision=True,
                 estimate_variance=False,
                 use_self_conditioning=False,
                 **kwargs
                 ):
        super().__init__()
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = strides[1:]

        # ------------- Time-Embedder-----------
        self.time_embedder = time_embedder(**time_embedder_kwargs)

        # ------------- Condition-Embedder-----------
        if cond_embedder is not None:
            self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
            cond_emb_dim = self.cond_embedder.emb_dim
        else:
            self.cond_embedder = None
            cond_emb_dim = None

        # ----------- In-Convolution ------------
        # Self-conditioning concatenates a second copy of the input channels.
        in_ch = in_ch*2 if use_self_conditioning else in_ch
        self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
                                  act_name=act_name, norm_name=norm_name, **kwargs)

        # ----------- Encoder ----------------
        self.encoders = nn.ModuleList([
            DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim,
                      cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i],
                      stride=strides[i], act_name=act_name,
                      norm_name=norm_name, **kwargs)
            for i in range(1, len(strides))
        ])

        # ------------ Decoder ----------
        self.decoders = nn.ModuleList([
            UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim,
                    cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + 1],
                    stride=strides[i+1], act_name=act_name,
                    norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs)
            for i in range(len(strides)-1)
        ])

        # --------------- Out-Convolution ----------------
        out_ch_hor = out_ch*2 if estimate_variance else out_ch
        self.outc = UnetOutBlock(
            spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
        if isinstance(deep_ver_supervision, bool):
            deep_ver_supervision = len(
                strides)-2 if deep_ver_supervision else 0
        self.outc_ver = nn.ModuleList([
            UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
            for i in range(1, deep_ver_supervision+1)
        ])

    def forward(self, x_t, t, cond=None, self_cond=None, **kwargs):
        """Run the network on ``x_t`` at timesteps ``t``.

        :param x_t: input of shape [B, C, (D), H, W].
        :param t: timesteps of shape [B,].
        :param cond: optional condition; embedded only when a condition
            embedder was configured.
        :param self_cond: optional tensor concatenated to ``x_t`` on dim 1.
        :return: output of the main (horizontal) output convolution.
        """
        condition = cond

        # -------- In-Convolution --------------
        x = [None for _ in range(len(self.encoders)+1)]
        x_t = torch.cat([x_t, self_cond],
                        dim=1) if self_cond is not None else x_t
        x[0] = self.inc(x_t)

        # -------- Time Embedding (Global) -----------
        time_emb = self.time_embedder(t)  # [B, C]

        # -------- Condition Embedding (Global) -----------
        if (condition is None) or (self.cond_embedder is None):
            cond_emb = None
        else:
            cond_emb = self.cond_embedder(condition)  # [B, C]

        # --------- Encoder --------------
        for i in range(len(self.encoders)):
            x[i+1] = self.encoders[i](x[i], time_emb, cond_emb)

        # -------- Decoder -----------
        # Walk back up, overwriting each skip feature with the decoded one.
        for i in range(len(self.decoders), 0, -1):
            x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb)

        # ---------Out-Convolution ------------
        # Only the main output is returned. The deep-supervision heads used to
        # be evaluated here and discarded, which only cost compute and memory.
        return self.outc(x[0])

    def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs):
        """Call ``forward``; ``cond_scale`` is accepted but ignored."""
        return self.forward(*args, **kwargs)
219
+
220
+
221
if __name__ == '__main__':
    # Smoke test: push one random volume through the network.
    model = UNet(in_ch=3)
    sample = torch.randn((1, 3, 16, 128, 128))
    time = torch.randn((1,))
    # UNet.forward returns a single tensor, so unpacking it into two names
    # (as the original did) fails for a batch of one.
    out_hor = model(sample, time)
    print(out_hor.shape)
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/ddpm/util.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ # from ldm.util import instantiate_from_config
19
+
20
+
21
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta (noise variance) schedule as a NumPy array.

    :param schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"`` or
        ``"sqrt"``.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value for the linear-family schedules.
    :param linear_end: last value for the linear-family schedules.
    :param cosine_s: offset used by the cosine schedule.
    :return: float64 array of length ``n_timestep``.
    :raises ValueError: if ``schedule`` is not recognised.
    """
    if schedule == "linear":
        # Linear in sqrt(beta), then squared.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so betas stays a tensor; np.clip on a tensor can
        # hand back an ndarray, which then has no .numpy() method.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
44
+
45
+
46
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the DDPM timesteps visited by a DDIM sampler.

    :param ddim_discr_method: ``'uniform'`` (every ``c``-th step) or ``'quad'``
        (quadratically spaced steps).
    :param num_ddim_timesteps: number of DDIM steps requested.
    :param num_ddpm_timesteps: number of steps of the underlying DDPM.
    :param verbose: print the selected timesteps.
    :return: integer array of selected timesteps.
    :raises NotImplementedError: for an unknown discretisation method.
    """
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
        # With a stride of 1 every step is already used, so no shift is needed.
        shift_by_one = c != 1
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
        # ``c`` does not exist on this path; the original check ``c != 1``
        # raised here. Quadratic spacing always takes the shift.
        shift_by_one = True
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1 if shift_by_one else ddim_timesteps
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
64
+
65
+
66
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Return ``(sigmas, alphas, alphas_prev)`` for the chosen DDIM timesteps.

    ``alphas`` are the cumulative alphas at ``ddim_timesteps``; ``alphas_prev``
    are the same values shifted by one step, starting from ``alphacums[0]``.
    The sigma formula follows https://arxiv.org/abs/2010.02502.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    earlier_steps = ddim_timesteps[:-1]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[earlier_steps].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')

    return sigmas, alphas, alphas_prev
79
+
80
+
81
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretise a continuous ``alpha_bar`` function into a beta schedule.

    ``alpha_bar`` gives the cumulative product of (1-beta) as a function of
    t in [0, 1]; each beta is recovered from the ratio of two neighbouring
    evaluations.
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: upper bound applied to every beta; values below 1
                     prevent singularities.
    """
    n = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta)
        for i in range(n)
    ]
    return np.array(betas)
98
+
99
+
100
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and reshape to broadcast against ``x_shape``.

    The result has shape ``(len(t), 1, ..., 1)`` with as many dimensions as
    ``x_shape``.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)
104
+
105
+
106
def checkpoint(func, inputs, params, flag):
    """
    Call ``func`` with optional gradient checkpointing.

    When enabled, intermediate activations are not cached, trading extra
    compute in the backward pass for lower memory use.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)

    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
121
+
122
+
123
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that re-runs ``run_function`` during backward.

    ``forward`` evaluates the function under ``torch.no_grad()``; ``backward``
    evaluates it again under ``torch.enable_grad()`` and differentiates the
    result with respect to the inputs and the extra parameters.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` entries of args are the function inputs; the rest
        # are parameters the function depends on implicitly.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs and make them require grad for the re-run.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Two leading Nones: no gradient for `run_function` and `length`.
        return (None, None) + input_grads
153
+
154
+
155
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal timestep embeddings (cosines first, then sines).

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and simply repeat each
                        timestep ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd sizes get one zero column appended.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    return embedding
176
+
177
+
178
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
185
+
186
+
187
def scale_module(module, scale):
    """
    Multiply every parameter of ``module`` by ``scale`` in place and return it.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.mul_(scale)
    return module
194
+
195
+
196
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
201
+
202
+
203
def normalization(channels):
    """
    Build the standard normalization layer: a 32-group ``GroupNorm32``.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
210
+
211
+
212
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
213
class SiLU(nn.Module):
    """SiLU / swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
216
+
217
+
218
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
221
+
222
def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` chosen by ``dims``.

    Remaining arguments are forwarded to the layer's constructor.
    """
    for n_dims, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n_dims:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
233
+
234
+
235
def linear(*args, **kwargs):
    """
    Build an ``nn.Linear`` layer, forwarding all arguments unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
240
+
241
+
242
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` chosen by ``dims``.

    Remaining arguments are forwarded to the layer's constructor.
    """
    for n_dims, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == n_dims:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
253
+
254
+
255
class HybridConditioner(nn.Module):
    """Runs a concat conditioner and a cross-attention conditioner side by side.

    NOTE(review): ``instantiate_from_config`` is used below, but its only
    visible import in this file is commented out; presumably constructing this
    class raises ``NameError`` unless the name is provided elsewhere - confirm.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        # Each sub-conditioner is built from its own config object.
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        # Apply each conditioner to its input and wrap each result in a
        # single-element list under the matching key.
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
266
+
267
+
268
def noise_like(shape, device, repeat=False):
    """Sample Gaussian noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample is drawn and tiled along the first
    dimension, so every batch element receives the same noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tiling = (1,) * (len(shape) - 1)
        return single.repeat(shape[0], *tiling)
    return torch.randn(shape, device=device)
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.85 kB). View file
 
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .vqgan import VQGAN
2
+ from .codebook import Codebook
3
+ from .lpips import LPIPS
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (282 Bytes). View file
 
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc ADDED
Binary file (3.41 kB). View file
 
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc ADDED
Binary file (6.78 kB). View file
 
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc ADDED
Binary file (16.6 kB). View file
 
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.distributed as dist
10
+
11
+ from ..utils import shift_dim
12
+
13
+
14
class Codebook(nn.Module):
    """Vector-quantisation codebook whose entries are updated by moving averages.

    The code vectors live in buffers (not parameters): `embeddings` holds the
    codes, `N` the running usage count per code and `z_avg` the running sum of
    the inputs assigned to each code.

    Args:
        n_codes: Number of code vectors.
        embedding_dim: Length of each code vector.
        no_random_restart: If True, codes are never re-seeded from the inputs.
        restart_thres: Codes whose running count `N` is below this value are
            replaced by randomly chosen inputs during training.
    """

    def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
        super().__init__()
        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
        self.register_buffer('N', torch.zeros(n_codes))
        self.register_buffer('z_avg', self.embeddings.data.clone())

        self.n_codes = n_codes
        self.embedding_dim = embedding_dim
        # The first training-mode forward replaces the random codes with inputs.
        self._need_init = True
        self.no_random_restart = no_random_restart
        self.restart_thres = restart_thres

    def _tile(self, x):
        """Return `x` ([rows, dim]) repeated with small noise until it has at least n_codes rows.

        When `x` already has n_codes rows or more it is returned unchanged.
        """
        d, ew = x.shape
        if d < self.n_codes:
            n_repeats = (self.n_codes + d - 1) // d  # ceil(n_codes / d)
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            # Jitter the copies so that repeated rows are not exactly equal.
            x = x + torch.randn_like(x) * std
        return x

    def _init_embeddings(self, z):
        """Seed the codebook with randomly chosen input vectors.

        Args:
            z: Tensor of shape [b, c, t, h, w]; the channel axis is moved last and
                all other axes are flattened to obtain candidate vectors.
        """
        # z: [b, c, t, h, w]
        self._need_init = False
        # A leftover debugger stop used to sit here and halted the first
        # training step; it has been removed.
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [b*t*h*w, c]
        y = self._tile(flat_inputs)

        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
        if dist.is_initialized():
            # Every process adopts the selection made by rank 0.
            dist.broadcast(_k_rand, 0)
        self.embeddings.data.copy_(_k_rand)
        self.z_avg.data.copy_(_k_rand)
        self.N.data.copy_(torch.ones(self.n_codes))

    def forward(self, z):
        """Quantise `z` to its nearest codes and, in training mode, update the codebook.

        Args:
            z: Tensor of shape [b, c, t, h, w].

        Returns:
            dict with keys
              'embeddings': quantised tensor shaped like `z`, written as
                  ``(q - z).detach() + z`` so gradients flow straight to `z`;
              'encodings': code indices of shape [b, t, h, w];
              'commitment_loss': 0.25 * MSE between `z` and the detached codes;
              'perplexity': exp of the entropy of the batch's code usage.
        """
        # z: [b, c, t, h, w]
        if self._need_init and self.training:
            self._init_embeddings(z)
        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [b*t*h*w, c]
        # Squared Euclidean distance via |x|^2 - 2 x.e + |e|^2.
        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
            - 2 * flat_inputs @ self.embeddings.t() \
            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)  # [b*t*h*w, n_codes]

        encoding_indices = torch.argmin(distances, dim=1)  # [b*t*h*w]
        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
            flat_inputs)  # [b*t*h*w, n_codes]
        encoding_indices = encoding_indices.view(
            z.shape[0], *z.shape[2:])  # [b, t, h, w]

        embeddings = F.embedding(
            encoding_indices, self.embeddings)  # [b, t, h, w, c]
        embeddings = shift_dim(embeddings, -1, 1)  # [b, c, t, h, w]

        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())

        # EMA codebook update
        if self.training:
            n_total = encode_onehot.sum(dim=0)  # [n_codes]
            encode_sum = flat_inputs.t() @ encode_onehot  # [c, n_codes]
            if dist.is_initialized():
                dist.all_reduce(n_total)
                dist.all_reduce(encode_sum)

            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)

            n = self.N.sum()
            # Smoothed counts keep the division finite for unused codes.
            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
            encode_normalized = self.z_avg / weights.unsqueeze(1)
            self.embeddings.data.copy_(encode_normalized)

            # Candidates are drawn (and broadcast) even when restarts are
            # disabled, exactly as before, so random-number consumption and the
            # collective call pattern stay unchanged.
            y = self._tile(flat_inputs)
            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
            if dist.is_initialized():
                dist.broadcast(_k_rand, 0)

            if not self.no_random_restart:
                # usage is 1 for codes to keep and 0 for codes to re-seed.
                usage = (self.N.view(self.n_codes, 1)
                         >= self.restart_thres).float()
                self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))

        # Straight-through estimator: forward value is the code, gradient is identity w.r.t. z.
        embeddings_st = (embeddings - z).detach() + z

        avg_probs = torch.mean(encode_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs *
                                          torch.log(avg_probs + 1e-10)))

        return dict(embeddings=embeddings_st, encodings=encoding_indices,
                    commitment_loss=commitment_loss, perplexity=perplexity)

    def dictionary_lookup(self, encodings):
        """Return the code vectors for integer `encodings`; the code axis is appended last."""
        embeddings = F.embedding(encodings, self.embeddings)
        return embeddings
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+
3
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
4
+
5
+
6
+ from collections import namedtuple
7
+ from torchvision import models
8
+ import torch.nn as nn
9
+ import torch
10
+ from tqdm import tqdm
11
+ import requests
12
+ import os
13
+ import hashlib
14
# Address each supported checkpoint is downloaded from.
URL_MAP = dict(
    vgg_lpips="https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
)

# File name each checkpoint is stored under inside the chosen root directory.
CKPT_MAP = dict(vgg_lpips="vgg.pth")

# MD5 hex digest each checkpoint file is compared against.
MD5_MAP = dict(vgg_lpips="d507d7349b931f0638a25a48a722f98a")
25
+
26
+
27
def download(url, local_path, chunk_size=1024):
    """Stream `url` into `local_path` while showing a byte-count progress bar.

    Args:
        url: Address fetched with an HTTP GET.
        local_path: Destination file; its parent directory is created if needed.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    parent = os.path.split(local_path)[0]
    if parent:
        # os.makedirs('') raises, so skip directory creation for bare file names.
        os.makedirs(parent, exist_ok=True)
    with requests.get(url, stream=True) as r:
        # Fail before touching the file so an error page is never saved as the model.
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar, \
                open(local_path, "wb") as f:
            for data in r.iter_content(chunk_size=chunk_size):
                if data:
                    f.write(data)
                    # Advance by what was actually received; a chunk can be
                    # shorter than chunk_size (typically the last one).
                    pbar.update(len(data))
37
+
38
+
39
def md5_hash(path):
    """Return the hexadecimal MD5 digest of the whole file at `path`."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        digest.update(handle.read())
    return digest.hexdigest()
43
+
44
+
45
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint `name` under `root`, downloading it if needed.

    The file is fetched when it does not exist, or when `check` is set and its
    MD5 digest differs from the recorded one. A freshly downloaded file must
    match the recorded digest; otherwise an AssertionError carrying the actual
    digest is raised.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])

    needs_download = not os.path.exists(path)
    if not needs_download and check:
        # Only hash an existing file, and only when verification was requested.
        needs_download = md5_hash(path) != MD5_MAP[name]
    if not needs_download:
        return path

    print("Downloading {} model from {} to {}".format(
        name, URL_MAP[name], path))
    download(URL_MAP[name], path)
    md5 = md5_hash(path)
    assert md5 == MD5_MAP[name], md5
    return path
55
+
56
+
57
class LPIPS(nn.Module):
    """Learned perceptual metric built on five stages of VGG16 features.

    Args:
        use_dropout: Whether each 1x1 linear head is preceded by dropout.
    """

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # channel count of each vgg16 stage
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        # NOTE: the LPIPS checkpoint is not loaded here; use load_from_pretrained()
        # or from_pretrained() to fill the linear heads with stored weights.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load checkpoint `name` from the 'cache' directory next to this file into this module."""
        ckpt = get_ckpt_path(name, os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "cache"))
        self.load_state_dict(torch.load(
            ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Build a model and load checkpoint `name` into it.

        Raises:
            NotImplementedError: If `name` is anything other than "vgg_lpips".
        """
        # Compare by value: an identity test against a string literal can
        # reject an equal string that happens to be a different object.
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name, os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "cache"))
        model.load_state_dict(torch.load(
            ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        """Return the summed per-stage distance between `input` and `target`.

        Both tensors are rescaled, passed through the VGG16 stages, normalised
        along dimension 1, and their squared difference is mapped by each
        stage's 1x1 head and averaged over dimensions 2 and 3. The five
        per-stage results are added together.
        """
        in0_input, in1_input = (self.scaling_layer(
            input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        val = None
        for kk in range(len(self.chns)):
            feats0 = normalize_tensor(outs0[kk])
            feats1 = normalize_tensor(outs1[kk])
            diff = (feats0 - feats1) ** 2
            stage_score = spatial_average(lins[kk].model(diff), keepdim=True)
            val = stage_score if val is None else val + stage_score
        return val
108
+
109
+
110
class ScalingLayer(nn.Module):
    """Shift and scale a three-channel input with fixed per-channel constants."""

    def __init__(self):
        super().__init__()
        # Stored with shape [1, 3, 1, 1] so they broadcast over the other axes.
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        """Return (inp - shift) / scale."""
        centered = inp - self.shift
        return centered / self.scale
120
+
121
+
122
class NetLinLayer(nn.Module):
    """A single bias-free 1x1 convolution, optionally preceded by dropout."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(
            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)
131
+
132
+
133
class vgg16(torch.nn.Module):
    """torchvision VGG16 feature layers split into five consecutive stages.

    Args:
        requires_grad: When False, every parameter is frozen.
        pretrained: Forwarded to ``models.vgg16``.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        features = models.vgg16(pretrained=pretrained).features
        # (attribute name, first layer index, one-past-last layer index).
        # Layers keep their original index as sub-module name.
        stage_bounds = (
            ("slice1", 0, 4),
            ("slice2", 4, 9),
            ("slice3", 9, 16),
            ("slice4", 16, 23),
            ("slice5", 23, 30),
        )
        for attr, start, stop in stage_bounds:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, attr, stage)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the five stages in order and return their outputs as a "VggOutputs" namedtuple."""
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple(
            "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*activations)
173
+
174
+
175
def normalize_tensor(x, eps=1e-10):
    """Divide `x` by its L2 norm over dimension 1; `eps` is added to the norm first."""
    squared_sum = (x ** 2).sum(dim=1, keepdim=True)
    return x / (squared_sum.sqrt() + eps)
178
+
179
+
180
def spatial_average(x, keepdim=True):
    """Return the mean of `x` over dimensions 2 and 3."""
    averaged = x.mean(dim=(2, 3), keepdim=keepdim)
    return averaged
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import math
5
+ import argparse
6
+ import numpy as np
7
+ import pickle as pkl
8
+
9
+ import pytorch_lightning as pl
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as dist
14
+
15
+ from ..utils import shift_dim, adopt_weight, comp_getattr
16
+ from .lpips import LPIPS
17
+ from .codebook import Codebook
18
+
19
+
20
def silu(x):
    """Return x * sigmoid(x), computed element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
22
+
23
+
24
class SiLU(nn.Module):
    """Module wrapper around the `silu` function so it can be used inside nn.Sequential."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply `silu` to `x`."""
        activated = silu(x)
        return activated
30
+
31
+
32
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: 0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))."""
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
37
+
38
+
39
def vanilla_d_loss(logits_real, logits_fake):
    """Softplus discriminator loss: 0.5 * (mean(softplus(-real)) + mean(softplus(fake)))."""
    real_term = torch.nn.functional.softplus(-logits_real).mean()
    fake_term = torch.nn.functional.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
44
+
45
+
46
class VQGAN(pl.LightningModule):
    """3D VQ-GAN: encoder, codebook and decoder trained against a 2D and a 3D discriminator.

    Args:
        cfg: Configuration object; the fields read are under ``cfg.model`` and
            ``cfg.dataset`` as accessed in ``__init__`` below.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embedding_dim = cfg.model.embedding_dim  # e.g. 8
        self.n_codes = cfg.model.n_codes  # e.g. 16384

        self.encoder = Encoder(cfg.model.n_hiddens,  # e.g. 16
                               cfg.model.downsample,  # e.g. [2, 2, 2]
                               cfg.dataset.image_channels,  # e.g. 1
                               cfg.model.norm_type,  # e.g. group
                               cfg.model.padding_type,  # e.g. replicate
                               cfg.model.num_groups,  # e.g. 32
                               )
        self.decoder = Decoder(
            cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups)
        self.enc_out_ch = self.encoder.out_channels
        # 1x1x1 convolutions mapping encoder width <-> codebook dimension.
        self.pre_vq_conv = SamePadConv3d(
            self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type)
        self.post_vq_conv = SamePadConv3d(
            cfg.model.embedding_dim, self.enc_out_ch, 1)

        self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim,
                                 no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres)

        self.gan_feat_weight = cfg.model.gan_feat_weight
        # TODO: Changed batchnorm from sync to normal
        self.image_discriminator = NLayerDiscriminator(
            cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d)
        self.video_discriminator = NLayerDiscriminator3D(
            cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d)

        # NOTE: any other disc_loss_type leaves self.disc_loss undefined.
        if cfg.model.disc_loss_type == 'vanilla':
            self.disc_loss = vanilla_d_loss
        elif cfg.model.disc_loss_type == 'hinge':
            self.disc_loss = hinge_d_loss

        self.perceptual_model = LPIPS().eval()

        self.image_gan_weight = cfg.model.image_gan_weight
        self.video_gan_weight = cfg.model.video_gan_weight

        self.perceptual_weight = cfg.model.perceptual_weight

        self.l1_weight = cfg.model.l1_weight
        self.save_hyperparameters()

    def encode(self, x, include_embeddings=False, quantize=True):
        """Encode `x` to latents.

        Returns:
            The un-quantised latent when `quantize` is False; otherwise the code
            indices, or ``(embeddings, indices)`` when `include_embeddings` is set.
        """
        h = self.pre_vq_conv(self.encoder(x))
        if quantize:
            vq_output = self.codebook(h)
            if include_embeddings:
                return vq_output['embeddings'], vq_output['encodings']
            else:
                return vq_output['encodings']
        return h

    def decode(self, latent, quantize=False):
        """Decode `latent` back to a volume.

        With `quantize` False, `latent` is used directly as code indices; with
        `quantize` True it is first passed through the codebook to obtain them.
        """
        if quantize:
            vq_output = self.codebook(latent)
            latent = vq_output['encodings']
        h = F.embedding(latent, self.codebook.embeddings)
        h = self.post_vq_conv(shift_dim(h, -1, 1))
        return self.decoder(h)

    def forward(self, x, optimizer_idx=None, log_image=False):
        """Reconstruct `x` and compute the losses for the requested optimiser.

        Args:
            x: Input of shape [B, C, T, H, W].
            optimizer_idx: 0 for the autoencoder step, 1 for the discriminator
                step, None (or anything else) for evaluation.
            log_image: If True, return only the sampled frames and volumes.

        Returns:
            log_image: ``(frames, frames_recon, x, x_recon)``.
            optimizer_idx == 0: ``(recon_loss, x_recon, vq_output, aeloss,
                perceptual_loss, gan_feat_loss)``.
            optimizer_idx == 1: the discriminator loss.
            otherwise: ``(recon_loss, x_recon, vq_output, perceptual_loss)``.
        """
        B, C, T, H, W = x.shape
        z = self.pre_vq_conv(self.encoder(x))
        vq_output = self.codebook(z)  # keys: embeddings, encodings, commitment_loss, perplexity
        x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))

        recon_loss = F.l1_loss(x_recon, x) * self.l1_weight

        # Selects one random 2D image from each 3D Image.
        # Indices follow the input's device instead of a hard-coded CUDA move,
        # so the module also runs on CPU or any other device.
        frame_idx = torch.randint(0, T, [B]).to(x.device)
        frame_idx_selected = frame_idx.reshape(-1,
                                               1, 1, 1, 1).repeat(1, C, 1, H, W)  # [B, C, 1, H, W]
        frames = torch.gather(x, 2, frame_idx_selected).squeeze(2)  # [B, C, H, W]
        frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2)  # [B, C, H, W]

        if log_image:
            return frames, frames_recon, x, x_recon

        if optimizer_idx == 0:
            # Autoencoder - train the "generator"

            # Perceptual loss
            perceptual_loss = 0
            if self.perceptual_weight > 0:
                perceptual_loss = self.perceptual_model(
                    frames, frames_recon).mean() * self.perceptual_weight

            # Discriminator loss (turned on after a certain epoch)
            logits_image_fake, pred_image_fake = self.image_discriminator(
                frames_recon)
            logits_video_fake, pred_video_fake = self.video_discriminator(
                x_recon)
            g_image_loss = -torch.mean(logits_image_fake)
            g_video_loss = -torch.mean(logits_video_fake)
            g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss
            disc_factor = adopt_weight(
                self.global_step, threshold=self.cfg.model.discriminator_iter_start)
            aeloss = disc_factor * g_loss

            # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator
            image_gan_feat_loss = 0
            video_gan_feat_loss = 0
            feat_weights = 4.0 / (3 + 1)
            if self.image_gan_weight > 0:
                _, pred_image_real = self.image_discriminator(frames)
                # The last entry is the final logits, which are skipped here.
                for i in range(len(pred_image_fake)-1):
                    image_gan_feat_loss += feat_weights * \
                        F.l1_loss(pred_image_fake[i], pred_image_real[i].detach())
            if self.video_gan_weight > 0:
                _, pred_video_real = self.video_discriminator(x)
                for i in range(len(pred_video_fake)-1):
                    video_gan_feat_loss += feat_weights * \
                        F.l1_loss(pred_video_fake[i], pred_video_real[i].detach())
            gan_feat_loss = disc_factor * self.gan_feat_weight * \
                (image_gan_feat_loss + video_gan_feat_loss)

            self.log("train/g_image_loss", g_image_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/g_video_loss", g_video_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/image_gan_feat_loss", image_gan_feat_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/video_gan_feat_loss", video_gan_feat_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/perceptual_loss", perceptual_loss,
                     prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log("train/recon_loss", recon_loss, prog_bar=True,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/aeloss", aeloss, prog_bar=True,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/commitment_loss", vq_output['commitment_loss'],
                     prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log('train/perplexity', vq_output['perplexity'],
                     prog_bar=True, logger=True, on_step=True, on_epoch=True)
            return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss

        if optimizer_idx == 1:
            # Train discriminator; inputs are detached so no gradient reaches the autoencoder.
            logits_image_real, _ = self.image_discriminator(frames.detach())
            logits_video_real, _ = self.video_discriminator(x.detach())

            logits_image_fake, _ = self.image_discriminator(
                frames_recon.detach())
            logits_video_fake, _ = self.video_discriminator(x_recon.detach())

            d_image_loss = self.disc_loss(logits_image_real, logits_image_fake)
            d_video_loss = self.disc_loss(logits_video_real, logits_video_fake)
            disc_factor = adopt_weight(
                self.global_step, threshold=self.cfg.model.discriminator_iter_start)
            discloss = disc_factor * \
                (self.image_gan_weight*d_image_loss +
                 self.video_gan_weight*d_video_loss)

            self.log("train/logits_image_real", logits_image_real.mean().detach(),
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/logits_image_fake", logits_image_fake.mean().detach(),
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/logits_video_real", logits_video_real.mean().detach(),
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/logits_video_fake", logits_video_fake.mean().detach(),
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/d_image_loss", d_image_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/d_video_loss", d_video_loss,
                     logger=True, on_step=True, on_epoch=True)
            self.log("train/discloss", discloss, prog_bar=True,
                     logger=True, on_step=True, on_epoch=True)
            return discloss

        # Evaluation path: the perceptual term is not reduced with mean() here.
        perceptual_loss = self.perceptual_model(
            frames, frames_recon) * self.perceptual_weight
        return recon_loss, x_recon, vq_output, perceptual_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss for optimiser 0 (autoencoder) or 1 (discriminator)."""
        x = batch['image']
        if optimizer_idx == 0:
            recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(
                x, optimizer_idx)
            commitment_loss = vq_output['commitment_loss']
            loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss
        if optimizer_idx == 1:
            discloss = self.forward(x, optimizer_idx)
            loss = discloss
        return loss

    def validation_step(self, batch, batch_idx):
        """Log reconstruction, perceptual, perplexity and commitment values for one batch."""
        x = batch['image']  # TODO: batch['stft']
        recon_loss, _, vq_output, perceptual_loss = self.forward(x)
        self.log('val/recon_loss', recon_loss, prog_bar=True)
        self.log('val/perceptual_loss', perceptual_loss, prog_bar=True)
        self.log('val/perplexity', vq_output['perplexity'], prog_bar=True)
        self.log('val/commitment_loss',
                 vq_output['commitment_loss'], prog_bar=True)

    def configure_optimizers(self):
        """Return two Adam optimisers: autoencoder parts first, discriminators second."""
        lr = self.cfg.model.lr
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.pre_vq_conv.parameters()) +
                                  list(self.post_vq_conv.parameters()) +
                                  list(self.codebook.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) +
                                    list(self.video_discriminator.parameters()),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def log_images(self, batch, **kwargs):
        """Return a dict with one sampled input frame per volume and its reconstruction."""
        log = dict()
        x = batch['image']
        x = x.to(self.device)
        frames, frames_rec, _, _ = self(x, log_image=True)
        log["inputs"] = frames
        log["reconstructions"] = frames_rec
        return log

    def log_videos(self, batch, **kwargs):
        """Return a dict with the input volumes and their reconstructions."""
        log = dict()
        x = batch['image']
        # Moved to the module's device, as log_images already did.
        x = x.to(self.device)
        _, _, x, x_rec = self(x, log_image=True)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log
282
+
283
+
284
def Normalize(in_channels, norm_type='group', num_groups=32):
    """Build the normalisation layer selected by `norm_type`.

    'group' yields a GroupNorm with `num_groups` groups (eps=1e-6, affine);
    'batch' yields a SyncBatchNorm. Any other value fails the assertion.
    """
    assert norm_type in ['group', 'batch']
    if norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)
    if norm_type == 'group':
        # TODO Changed num_groups from 32 to 8
        return torch.nn.GroupNorm(
            num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
291
+
292
+
293
class Encoder(nn.Module):
    """Strided 3D convolutional encoder.

    Args:
        n_hiddens: Channel width after the first convolution; it doubles with
            every downsampling block.
        downsample: Per-axis downsampling factors; log2 of each factor gives
            the number of stride-2 stages applied to that axis.
        image_channel: Number of input channels.
        norm_type: Passed to `Normalize` ('group' or 'batch').
        padding_type: Padding mode used by the convolutions.
        num_groups: Group count for group normalisation.

    Attributes:
        out_channels: Channel width of the encoder output.
    """

    def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32):
        super().__init__()
        # Remaining stride-2 stages per axis; decremented after every block.
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = int(n_times_downsample.max())

        self.conv_first = SamePadConv3d(
            image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        # Without any downsampling block the width stays at n_hiddens. This
        # default keeps `out_channels` defined in that case, which previously
        # raised an unbound-variable error.
        out_channels = n_hiddens
        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2**i
            out_channels = n_hiddens * 2**(i+1)
            # Only axes that still need halving get stride 2 in this block.
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            block.down = SamePadConv3d(
                in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type, num_groups=num_groups),
            SiLU()
        )

        self.out_channels = out_channels

    def forward(self, x):
        """Apply the first convolution, every down/residual block, then the final norm + SiLU."""
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h
329
+
330
+
331
class Decoder(nn.Module):
    """Transposed-convolution 3D decoder.

    Args:
        n_hiddens: Base channel width; the input is expected to have
            ``n_hiddens * 2**max_us`` channels, where max_us is the largest
            log2 upsampling factor.
        upsample: Per-axis upsampling factors; log2 of each factor gives the
            number of stride-2 stages applied to that axis.
        image_channel: Number of output channels.
        norm_type: Passed to `Normalize` ('group' or 'batch').
        num_groups: Group count for group normalisation.
    """

    def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32):
        super().__init__()

        # Remaining stride-2 stages per axis; decremented after every block.
        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = int(n_times_upsample.max())

        in_channels = n_hiddens*2**max_us
        self.final_block = nn.Sequential(
            Normalize(in_channels, norm_type, num_groups=num_groups),
            SiLU()
        )

        # Without any upsampling block the width stays at the input width. This
        # default keeps `out_channels` defined for conv_last in that case,
        # which previously raised an unbound-variable error.
        out_channels = in_channels
        self.conv_blocks = nn.ModuleList()
        for i in range(max_us):
            block = nn.Module()
            in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1)
            out_channels = n_hiddens*2**(max_us-i)
            # Only axes that still need doubling get stride 2 in this block.
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            block.up = SamePadConvTranspose3d(
                in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            block.res2 = ResBlock(
                out_channels, out_channels, norm_type=norm_type, num_groups=num_groups)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_last = SamePadConv3d(
            out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        """Apply norm + SiLU, every up/residual block, then the last convolution."""
        h = self.final_block(x)
        for block in self.conv_blocks:
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_last(h)
        return h
370
+
371
+
372
class ResBlock(nn.Module):
    """Residual block: (norm, SiLU, 3x3x3 conv) twice, added to a shortcut of the input.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output; defaults to `in_channels`.
        conv_shortcut: Stored as `use_conv_shortcut`; not consulted elsewhere in this class.
        dropout: Rate of the `dropout` sub-module. NOTE: that module is created
            but not applied in `forward`.
        norm_type: Passed to `Normalize`.
        padding_type: Padding mode of the convolutions.
        num_groups: Group count for group normalisation.
    """

    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups)
        self.conv1 = SamePadConv3d(
            in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 runs on conv1's output, so it must be sized for out_channels.
        # It used to be sized for in_channels, which only worked when both
        # widths were equal; for equal widths nothing changes.
        self.norm2 = Normalize(out_channels, norm_type, num_groups=num_groups)
        self.conv2 = SamePadConv3d(
            out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            # Projects the input so it can be added to the block output.
            self.conv_shortcut = SamePadConv3d(
                in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        """Return shortcut(x) + conv2(silu(norm2(conv1(silu(norm1(x))))))."""
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)

        return x+h
404
+
405
+
406
+ # Does not support dilation
407
class SamePadConv3d(nn.Module):
    """3D convolution that pads its input explicitly before an unpadded Conv3d.

    Per axis the total padding is ``kernel - stride``, split so that the
    leading side receives the extra element when the total is odd.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        kernel_size = (kernel_size,) * 3 if isinstance(kernel_size, int) else kernel_size
        stride = (stride,) * 3 if isinstance(stride, int) else stride

        # assumes that the input shape is divisible by stride
        per_axis_total = tuple(k - s for k, s in zip(kernel_size, stride))
        pads = []
        # F.pad lists the last dimension first, hence the reversed walk.
        for total in reversed(per_axis_total):
            pads.extend((total // 2 + total % 2, total // 2))
        self.pad_input = tuple(pads)
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)

    def forward(self, x):
        """Pad `x` with the stored amounts and mode, then convolve."""
        padded = F.pad(x, self.pad_input, mode=self.padding_type)
        return self.conv(padded)
429
+
430
+
431
class SamePadConvTranspose3d(nn.Module):
    """3D transposed convolution applied to an explicitly padded input.

    The input is padded by ``kernel - stride`` per axis (the leading side
    receives the extra element when odd), and the transposed convolution uses
    ``padding = kernel - 1`` per axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        kernel_size = (kernel_size,) * 3 if isinstance(kernel_size, int) else kernel_size
        stride = (stride,) * 3 if isinstance(stride, int) else stride

        per_axis_total = tuple(k - s for k, s in zip(kernel_size, stride))
        pads = []
        # F.pad lists the last dimension first, hence the reversed walk.
        for total in reversed(per_axis_total):
            pads.extend((total // 2 + total % 2, total // 2))
        self.pad_input = tuple(pads)
        self.padding_type = padding_type

        crop = tuple(k - 1 for k in kernel_size)
        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=crop)

    def forward(self, x):
        """Pad `x` with the stored amounts and mode, then apply the transposed convolution."""
        padded = F.pad(x, self.pad_input, mode=self.padding_type)
        return self.convt(padded)
453
+
454
+
455
class NLayerDiscriminator(nn.Module):
    """Stack of 2D convolutional stages producing a map of logits.

    Args:
        input_nc: Number of input channels.
        ndf: Width of the first stage; later stages double it up to 512.
        n_layers: Number of stride-2 stages (including the first).
        norm_layer: Normalisation layer class used after the first stage.
        use_sigmoid: Append a sigmoid stage. NOTE(review): with getIntermFeat
            set, `forward` runs only the first n_layers + 2 stages, so this
            extra stage is not executed there.
        getIntermFeat: Register every stage as its own sub-module
            ('model0', 'model1', ...) and return all stage outputs.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw,
                               stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw,
                                stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return ``(logits, stage_outputs)``.

        With getIntermFeat, `stage_outputs` lists the output of every executed
        stage, the last entry being `logits`. Without it, the list is empty.
        """
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[-1], res[1:]
        # The second element used to be an undefined name, which raised
        # NameError; an empty list keeps the two-tuple shape callers unpack.
        return self.model(input), []
508
+
509
+
510
class NLayerDiscriminator3D(nn.Module):
    """Stack of 3D convolutional stages producing a volume of logits.

    Args:
        input_nc: Number of input channels.
        ndf: Width of the first stage; later stages double it up to 512.
        n_layers: Number of stride-2 stages (including the first).
        norm_layer: Normalisation layer class used after the first stage.
        use_sigmoid: Append a sigmoid stage. NOTE(review): with getIntermFeat
            set, `forward` runs only the first n_layers + 2 stages, so this
            extra stage is not executed there.
        getIntermFeat: Register every stage as its own sub-module
            ('model0', 'model1', ...) and return all stage outputs.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True):
        super(NLayerDiscriminator3D, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw,
                               stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv3d(nf, 1, kernel_size=kw,
                                stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return ``(logits, stage_outputs)``.

        With getIntermFeat, `stage_outputs` lists the output of every executed
        stage, the last entry being `logits`. Without it, the list is empty.
        """
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[-1], res[1:]
        # The second element used to be an undefined name, which raised
        # NameError; an empty list keeps the two-tuple shape callers unpack.
        return self.model(input), []
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Adapted from https://github.com/SongweiGe/TATS"""
2
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
+
4
+ import warnings
5
+ import torch
6
+ import imageio
7
+
8
+ import math
9
+ import numpy as np
10
+
11
+ import sys
12
+ import pdb as pdb_original
13
+ import logging
14
+
15
+ import imageio.core.util
16
+ logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
17
+
18
+
19
class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child."""

    def interaction(self, *args, **kwargs):
        """Run the regular Pdb interaction with ``sys.stdin`` rebound.

        ``sys.stdin`` is temporarily replaced by a freshly opened
        ``/dev/stdin`` for the duration of the interaction and restored
        afterwards, even if the interaction raises.
        """
        _stdin = sys.stdin
        try:
            # The context manager closes the handle once the interaction
            # ends, so repeated breakpoints do not leak file descriptors.
            with open('/dev/stdin') as forked_stdin:
                sys.stdin = forked_stdin
                pdb_original.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin
32
+
33
+
34
+ # Shifts src_tf dim to dest dim
35
+ # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
36
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    The remaining dimensions keep their relative order.  Negative indices
    count from the end.  When ``make_contiguous`` is true the permuted
    tensor is made contiguous before being returned.
    """
    rank = len(x.shape)
    if src_dim < 0:
        src_dim += rank
    if dest_dim < 0:
        dest_dim += rank

    assert 0 <= src_dim < rank and 0 <= dest_dim < rank

    # Every axis except the moved one, then slot the moved axis back in.
    order = [axis for axis in range(rank) if axis != src_dim]
    order.insert(dest_dim, src_dim)

    shifted = x.permute(order)
    return shifted.contiguous() if make_contiguous else shifted
60
+
61
+
62
+ # reshapes tensor start from dim i (inclusive)
63
+ # to dim j (exclusive) to the desired shape
64
+ # e.g. if x.shape = (b, thw, c) then
65
+ # view_range(x, 1, 2, (t, h, w)) returns
66
+ # x of shape (b, t, h, w, c)
67
def view_range(x, i, j, shape):
    """Reshape dimensions ``i`` (inclusive) to ``j`` (exclusive) of ``x``.

    The selected run of dimensions is replaced by ``shape`` via ``x.view``;
    dimensions outside the range are kept.  Negative ``i``/``j`` count from
    the end and ``j=None`` means "through the last dimension".
    """
    rank = len(x.shape)
    start = rank + i if i < 0 else i
    if j is None:
        stop = rank
    else:
        stop = rank + j if j < 0 else j

    assert 0 <= start < stop <= rank

    current = x.shape
    return x.view(current[:start] + tuple(shape) + current[stop:])
84
+
85
+
86
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Returns a list with one single-element tensor per requested ``k``.
    """
    with torch.no_grad():
        n_samples = target.size(0)

        # Indices of the max(topk) highest scores per sample, transposed to
        # (k, batch) so row r holds every sample's r-th best guess.
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
101
+
102
+
103
def tensor_slice(x, begin, size):
    """Slice ``x`` starting at ``begin`` with per-dimension extents ``size``.

    Args:
        x: Array/tensor to slice.
        begin: Non-negative start index for each leading dimension.
        size: Extent for each of those dimensions; ``-1`` means "to the end
            of that dimension".

    Returns:
        ``x[b0:b0+s0, b1:b1+s1, ...]``.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s
            for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)

    # Multidimensional indexing must use a tuple: a *list* of slices is
    # rejected by current NumPy and deprecated by other array libraries.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
111
+
112
+
113
def adopt_weight(global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else 1."""
    return value if global_step < threshold else 1
118
+
119
+
120
def save_video_grid(video, fname, nrow=None, fps=6):
    """Tile a batch of videos into one grid video and write it with imageio.

    Args:
        video: Tensor whose shape unpacks as ``(b, c, t, h, w)``.  Values are
            multiplied by 255 and cast to ``uint8`` (presumably they are
            expected in ``[0, 1]`` -- confirm against callers).
        fname: Output path handed to ``imageio.mimsave``.
        nrow: Number of grid rows; defaults to ``ceil(sqrt(b))``.
        fps: Frame rate handed to ``imageio.mimsave``.
    """
    b, c, t, h, w = video.shape
    clips = (video.permute(0, 2, 3, 4, 1).cpu().numpy() * 255).astype('uint8')

    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)

    padding = 1
    cell_h = padding + h
    cell_w = padding + w
    video_grid = np.zeros((t, cell_h * nrow + padding,
                           cell_w * ncol + padding, c), dtype='uint8')

    # Clips are laid out row-major; each cell leaves `padding` blank pixels
    # between it and the next cell.
    for index, clip in enumerate(clips):
        row, col = divmod(index, ncol)
        top = cell_h * row
        left = cell_w * col
        video_grid[:, top:top + h, left:left + w] = clip

    imageio.mimsave(fname, list(video_grid), fps=fps)
142
+
143
+
144
def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if it exists, otherwise ``default``."""
    return getattr(args, attr_name, default)
149
+
150
+
151
def visualize_tensors(t, name=None, nest=0):
    """Print a debugging summary of a nested dict/list/tensor structure.

    Type detection is done on the *name* of the type (substring match on
    ``str(type(...))``), so any type whose name contains ``dict``, ``list``
    or ``Tensor`` is treated accordingly.  Nested containers are walked
    recursively with ``nest + 1``.  Always returns an empty string.
    """
    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))

    kind = str(type(t))
    if 'dict' in kind:
        print(t.keys())
        for key in t.keys():
            item = t[key]
            if item is None:
                print(key, "None")
                continue
            item_kind = str(type(item))
            if 'Tensor' in item_kind:
                print(key, item.shape)
            elif 'dict' in item_kind:
                print(key, 'dict')
                visualize_tensors(item, name, nest + 1)
            elif 'list' in item_kind:
                print(key, len(item))
                visualize_tensors(item, name, nest + 1)
    elif 'list' in kind:
        print("list length: ", len(t))
        for element in t:
            visualize_tensors(element, name, nest + 1)
    elif 'Tensor' in kind:
        print(t.shape)
    else:
        print(t)
    return ""
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/model_weight/diffusion_colon.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbf32978597f4f8cb480c83117f4e041cfd94074c933223b3d60d6e65e5cce60
3
+ size 289042653
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/model_weight/recon_colon.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a12f8e09916f2c6c35666738566c3e1612953e0ecf6ba5bfa9f209682e6bba62
3
+ size 242058287
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/utils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Tumor Generation
2
+ import random
3
+ import cv2
4
+ import elasticdeform
5
+ import numpy as np
6
+ from scipy.ndimage import gaussian_filter
7
+ from TumorGeneration.ldm.ddpm.ddim import DDIMSampler
8
+ import skimage
9
+
10
def im2col(A, BSZ, stepsize=1):
    """Gather every sliding ``BSZ`` block of the 2-D array ``A``.

    Blocks are taken every ``stepsize`` elements along both axes.  The result
    has one row per block, each row holding that block's elements in
    row-major order.
    """
    n_rows, n_cols = A.shape
    block_h, block_w = BSZ[0], BSZ[1]

    # Flat index of the top-left corner of every block.
    row_starts = np.arange(0, n_rows - block_h + 1, stepsize)
    col_starts = np.arange(0, n_cols - block_w + 1, stepsize)
    corners = (row_starts[:, None] * n_cols + col_starts).ravel()

    # Flat offsets of the elements inside one block, relative to its corner.
    within = (np.arange(block_h)[:, None] * n_cols + np.arange(block_w)).ravel()

    return np.take(A, corners[:, None] + within)
20
+
21
def seg_to_instance_bd(seg: np.ndarray,
                       tsz_h: int = 1) -> np.ndarray:
    """Generate a boundary (contour) map from a segmentation mask.

    The mask is eroded with a cubic structuring element of edge length
    ``2 * tsz_h + 1``; the boundary is every voxel where the input exceeds
    its eroded version.  The result is a ``uint8`` array of 0/1 values.
    """
    edge = int(tsz_h * 2 + 1)
    footprint = np.ones((edge, edge, edge), np.uint8)

    eroded = skimage.morphology.binary_erosion(seg.astype('uint8'), footprint)
    eroded = eroded.astype(np.uint8)

    return ((seg - eroded) > 0).astype('uint8')
36
+
37
def sector_mask(shape, centre, radius, angle_range):
    """
    Return a boolean mask for a circular sector. The start/stop angles in
    `angle_range` are given in degrees and should be in clockwise order.
    """
    rows, cols = np.ogrid[:shape[0], :shape[1]]
    cx, cy = centre
    start, stop = np.deg2rad(angle_range)

    # Make the stop angle follow the start angle.
    if stop < start:
        stop += 2 * np.pi

    dx = rows - cx
    dy = cols - cy

    # Inside the disc of the given radius?
    inside_disc = dx * dx + dy * dy <= radius * radius

    # Angle measured from the start angle, wrapped into [0, 2*pi).
    swept = np.mod(np.arctan2(dx, dy) - start, 2 * np.pi)
    inside_wedge = swept <= (stop - start)

    return inside_disc * inside_wedge
65
+
66
+ from scipy.ndimage import label
67
+ import elasticdeform
68
def generate_random_mask(organ_mask):
    """Sample a random tumor-shaped mask on the boundary of ``organ_mask``.

    Steps: build the organ boundary by erosion, keep a random slab of it
    along the last (z) axis, pick one connected component of that slab,
    deform it elastically in the x/y plane, and finally keep only the part
    inside a random angular sector around the component's centre.

    Args:
        organ_mask: 3-D array (indexed ``[x, y, z]``) of the organ.  The z
            extent is read from the array itself.

    Returns:
        Array of the same spatial shape holding the tumor mask.

    Raises:
        ValueError: If the selected slab contains no boundary voxels.
        IndexError: If the organ boundary is empty (no slice to choose from).
    """
    depth = organ_mask.shape[2]

    # initialize tumor mask
    tumor_mask = np.zeros_like(organ_mask)

    # random sector angles
    start_angle = random.randint(0, 360)
    angle_range = random.randint(90, 360)

    # generate organ boundary; wider sectors get a thicker boundary
    erode_sz = angle_range//45 * 1 + 3
    organ_bd = seg_to_instance_bd(organ_mask, tsz_h=erode_sz)

    # z slices containing boundary voxels, restricted to the middle half
    z_valid_list = np.where(np.any(organ_bd, axis=(0, 1)))[0]
    valid_num = len(z_valid_list)
    z_valid_list = z_valid_list[round(valid_num*0.25):round(valid_num*0.75)]
    z = random.choice(z_valid_list)

    # sample slab half-thickness and copy that slab of the boundary
    z_thickness = random.randint(10, 20)
    z_lo = max(0, z - z_thickness)
    # Upper bound is depth - 1 (exclusive), i.e. the last slice is never
    # used; this matches the previous fixed bound of 95 for 96-deep volumes.
    z_hi = min(depth - 1, z + z_thickness)
    tumor_mask[:, :, z_lo:z_hi] = organ_bd[:, :, z_lo:z_hi]

    # randomly select one connected component
    tumor_mask, nb = label(tumor_mask)
    if nb == 0:
        raise ValueError('no organ boundary voxels found in the selected z range')
    sample_id = random.randint(1, nb)
    sample_tumor_mask = (tumor_mask == sample_id).astype(np.uint8)

    # middle occupied slice of the chosen component
    z_valid = np.where(np.any(sample_tumor_mask, axis=(0, 1)))[0]
    z = z_valid[round(0.5 * len(z_valid))]

    # random sector region centred on the component's centroid in that slice
    selected_slice = sample_tumor_mask[..., z]
    coordinates = np.argwhere(selected_slice == 1)
    center_x, center_y = int(coordinates[:, 0].mean()), int(coordinates[:, 1].mean())
    mask_region = sector_mask(selected_slice.shape, (center_x, center_y), 48, (start_angle, start_angle+angle_range))
    # Extrude the 2-D sector through every z slice of the volume.
    mask_region = np.repeat(mask_region[:, :, np.newaxis], axis=-1, repeats=depth)

    # elastic deformation in the x/y plane (order=0 keeps the mask binary)
    sigma = random.uniform(2, 5)
    deform_tumor_mask = elasticdeform.deform_random_grid(sample_tumor_mask, sigma=sigma, points=3, order=0, axis=(0, 1))

    final_tumor_mask = deform_tumor_mask*mask_region

    return final_tumor_mask
125
+
126
+
127
+ from .ldm.vq_gan_3d.model.vqgan import VQGAN
128
+ import matplotlib.pyplot as plt
129
+ import SimpleITK as sitk
130
+ from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester
131
+ from hydra import initialize, compose
132
+ import torch
133
+ import yaml
134
def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_colon.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='colon'):
    """Load the VQGAN and the diffusion model and wrap the latter in a DDIM sampler.

    Args:
        device: Device the models are moved to / loaded on.
        vqgan_ckpt: Path of the VQGAN checkpoint.
        diffusion_ckpt: Path of the diffusion checkpoint, or a directory that
            contains ``diffusion_colon.pt``.
        fold: Accepted for interface compatibility; not used.
        organ: Accepted for interface compatibility; not used.

    Returns:
        ``(vqgan, noearly_sampler)``: the VQGAN in eval mode and a
        ``DDIMSampler`` built on the loaded diffusion model.
    """
    import os

    with initialize(config_path="diffusion_config/"):
        cfg = compose(config_name="ddpm.yaml")

    # Honour the caller's paths.  A directory (the default) resolves to the
    # colon checkpoint inside it, so default calls load the same files as
    # before.
    if os.path.isdir(diffusion_ckpt) or diffusion_ckpt.endswith(('/', os.sep)):
        diffusion_ckpt = os.path.join(diffusion_ckpt, 'diffusion_colon.pt')
    print('diffusion_ckpt', diffusion_ckpt)

    vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt)
    vqgan = vqgan.to(device)
    vqgan.eval()

    noearly_Unet3D = Unet3D(
        dim=cfg.diffusion_img_size,
        dim_mults=cfg.dim_mults,
        channels=cfg.diffusion_num_channels,
        out_dim=cfg.out_dim
    ).to(device)

    noearly_diffusion = GaussianDiffusion(
        noearly_Unet3D,
        vqgan_ckpt=vqgan_ckpt,
        image_size=cfg.diffusion_img_size,
        num_frames=cfg.diffusion_depth_size,
        channels=cfg.diffusion_num_channels,
        timesteps=200,  # fixed here rather than read from cfg.timesteps
        loss_type=cfg.loss_type,
        device=device
    ).to(device)

    noearly_checkpoint = torch.load(diffusion_ckpt, map_location=device)

    # The checkpoint's 'ema' entry is the state dict that gets loaded.
    noearly_diffusion.load_state_dict(noearly_checkpoint['ema'])

    noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine")
    return vqgan, noearly_sampler
170
+
171
def synthesize_colon_tumor(ct_volume, organ_mask, vqgan, sampler, ddim_ts=50):
    """Inpaint a synthetic tumor into ``ct_volume`` with the diffusion model.

    A random tumor mask is generated per batch item from ``organ_mask``, the
    masked volume is encoded with ``vqgan``, the sampler generates a latent
    conditioned on that encoding and the mask, and the decoded sample is
    blended into the original volume through a Gaussian-blurred mask.

    Args:
        ct_volume: 5-D tensor ``(batch, channel, x, y, z)``; it is rescaled
            with ``*2 - 1`` (presumably expected in ``[0, 1]`` -- confirm).
        organ_mask: Tensor indexed ``[batch, 0]`` to obtain each 3-D mask.
        vqgan: Model providing ``encode``, ``decode`` and
            ``codebook.embeddings``.
        sampler: Object providing ``sample(S=..., conditioning=..., ...)``.
        ddim_ts: Number of sampling steps passed as ``S``.

    Returns:
        ``(final_volume_, organ_tumor_mask)`` where the mask holds 1 for
        organ voxels and 2 for generated tumor voxels.
    """
    device = ct_volume.device

    # generate tumor mask
    total_tumor_mask = []
    organ_mask_np = organ_mask.cpu().numpy()
    with torch.no_grad():
        # get model input
        for bs in range(organ_mask_np.shape[0]):
            tumor_mask = generate_random_mask(organ_mask_np[bs, 0])
            total_tumor_mask.append(torch.from_numpy(tumor_mask)[None, :])
        total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device)

        volume = ct_volume*2.0 - 1.0
        mask = total_tumor_mask*2.0 - 1.0
        mask_ = 1-total_tumor_mask
        masked_volume = (volume*mask_).detach()

        # move the last spatial axis in front of the other two
        volume = volume.permute(0, 1, -1, -3, -2)
        masked_volume = masked_volume.permute(0, 1, -1, -3, -2)
        mask = mask.permute(0, 1, -1, -3, -2)

        # Codebook range used to map latents to and from [-1, 1]; computed
        # once instead of on every use.
        emb_min = vqgan.codebook.embeddings.min()
        emb_max = vqgan.codebook.embeddings.max()

        # vqgan encoder inference
        masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True)
        masked_volume_feat = ((masked_volume_feat - emb_min) / (emb_max - emb_min)) * 2.0 - 1.0

        cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:])
        cond = torch.cat((masked_volume_feat, cc), dim=1)

        # diffusion inference and decoder
        shape = masked_volume_feat.shape[-4:]
        samples_ddim, _ = sampler.sample(S=ddim_ts,
                                         conditioning=cond,
                                         # follow the real batch size so the
                                         # sample matches the conditioning
                                         batch_size=masked_volume_feat.shape[0],
                                         shape=shape,
                                         verbose=False)
        samples_ddim = (((samples_ddim + 1.0) / 2.0) * (emb_max - emb_min)) + emb_min

        sample = vqgan.decode(samples_ddim, quantize=True)

        # post-process: blur the mask spatially (not across batch/channel)
        mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0)
        sigma = np.random.uniform(1, 2)
        mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0, 0, sigma, sigma, sigma])

        volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0)
        sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0)

        mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device)
        final_volume_ = (1-mask_01_blur)*volume_ + mask_01_blur*sample_
        final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0)

        # undo the axis permutation applied before encoding
        final_volume_ = final_volume_.permute(0, 1, -2, -1, -3)
        organ_tumor_mask = torch.zeros_like(organ_mask)
        organ_tumor_mask[organ_mask == 1] = 1
        organ_tumor_mask[total_tumor_mask == 1] = 2

    return final_volume_, organ_tumor_mask
Generation_Pipeline_filter_all2/syn_colon/TumorGeneration/utils_.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Tumor Generation
2
+ import random
3
+ import cv2
4
+ import elasticdeform
5
+ import numpy as np
6
+ from scipy.ndimage import gaussian_filter
7
+
8
+ # Step 1: Random select (numbers) location for tumor.
9
def random_select(mask_scan):
    """Pick a random ``[x, y, z]`` point inside ``mask_scan``.

    The z index is drawn from the middle part (30%-70%) of the occupied z
    range; the x/y position is then drawn uniformly from that slice after
    eroding it with a 5x5 kernel, which keeps the point away from the edge.
    """
    # Occupied extent along z, then a slice somewhere in its middle.
    occupied_z = np.where(np.any(mask_scan, axis=(0, 1)))[0]
    z_start, z_end = occupied_z[[0, -1]]
    z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start

    # Erode the slice so edge voxels cannot be selected.
    eroded_slice = cv2.erode(mask_scan[..., z], np.ones((5, 5), dtype=np.uint8), iterations=1)

    candidates = np.argwhere(eroded_slice == 1)
    chosen = np.random.randint(0, len(candidates))
    return candidates[chosen].tolist() + [z]
29
+
30
+ # Step 2 : generate the ellipsoid
31
def get_ellipsoid(x, y, z):
    """Return a binary ellipsoid with radii ``x``, ``y``, ``z``.

    The result is an integer array of shape ``(4*x, 4*y, 4*z)`` with the
    ellipsoid centred at ``(2*x, 2*y, 2*z)``; voxels inside are 1, all
    others 0.
    """
    grid_shape = (4*x, 4*y, 4*z)
    out = np.zeros(grid_shape, int)
    best = np.zeros(grid_shape)
    radii = np.array([x, y, z])
    centre = np.array([2*x, 2*y, 2*z])

    # Bounding box of the ellipsoid, clipped to the grid.
    lower = np.floor(centre - radii).clip(0, None).astype(int)
    upper = (np.ceil(centre + radii) + 1).clip(None, grid_shape).astype(int)
    box = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    out_box = out[box]
    best_box = best[box]

    # Per-axis coordinates normalised by the radius; a complex step makes
    # np.ogrid produce that many evenly spaced points, end point included.
    axes = np.ogrid[tuple(
        slice(first, last, count)
        for first, last, count in zip((lower - centre) / radii,
                                      (upper - centre - 1) / radii,
                                      1j * (upper - lower))
    )]
    squared = [np.square(axis) for axis in axes]

    # 1 - normalised squared distance: positive strictly inside the surface.
    dst = (1 - sum(squared)).clip(0, None)
    inside = dst > best_box
    out_box[inside] = 1
    np.copyto(best_box, dst, where=inside)

    return out
54
+
55
# Nominal radius and elastic-deformation sigma sampler for each size class.
_TUMOR_SIZE_SPECS = {
    'tiny': (4, lambda: random.uniform(0.5, 1)),
    'small': (8, lambda: random.randint(1, 2)),
    'medium': (16, lambda: random.randint(3, 6)),
    'large': (32, lambda: random.randint(5, 10)),
}

# Inclusive (min, max) number of tumors when a single size class is requested.
_SINGLE_TYPE_COUNTS = {
    'tiny': (3, 10),
    'small': (3, 10),
    'medium': (2, 5),
    'large': (1, 3),
}

# Size classes and tumor counts generated, in this order, for the "mix" type.
_MIX_TYPE_COUNTS = (
    ('tiny', (3, 10)),
    ('small', (5, 10)),
    ('medium', (2, 5)),
    ('large', (1, 3)),
)


def _paste_random_tumors(geo_mask, mask_scan, size_name, count_range, margin):
    """Add a random number of deformed ellipsoids of one size class to ``geo_mask`` in place."""
    radius, draw_sigma = _TUMOR_SIZE_SPECS[size_name]
    radius_lo, radius_hi = int(0.75*radius), int(1.25*radius)

    num_tumor = random.randint(*count_range)
    for _ in range(num_tumor):
        x = random.randint(radius_lo, radius_hi)
        y = random.randint(radius_lo, radius_hi)
        z = random.randint(radius_lo, radius_hi)
        sigma = draw_sigma()

        geo = get_ellipsoid(x, y, z)
        # Deform in each pair of axes; order=0 keeps the shape binary.
        for axis in ((0, 1), (1, 2), (0, 2)):
            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=axis)

        # Centre the shape on a random organ point, shifted into the padded
        # canvas so it cannot run off the edge.
        point = random_select(mask_scan)
        window = tuple(
            slice(p + margin//2 - extent//2, p + margin//2 + extent//2)
            for p, extent in zip(point, geo.shape)
        )
        geo_mask[window] += geo


def get_fixed_geo(mask_scan, tumor_type):
    """Generate a boolean tumor-shape mask inside ``mask_scan``.

    Args:
        mask_scan: 3-D organ mask.
        tumor_type: One of ``'tiny'``, ``'small'``, ``'medium'``, ``'large'``
            or ``'mix'`` (all four classes in turn).  Any other value yields
            an all-False mask.

    Returns:
        Boolean array with the shape of ``mask_scan``; True where a generated
        tumor overlaps the organ.
    """
    # Shapes are pasted onto a canvas padded by `margin` voxels per axis.
    margin = 160
    geo_mask = np.zeros(tuple(extent + margin for extent in mask_scan.shape[:3]), dtype=np.int8)

    if tumor_type == 'mix':
        plan = _MIX_TYPE_COUNTS
    elif tumor_type in _SINGLE_TYPE_COUNTS:
        plan = ((tumor_type, _SINGLE_TYPE_COUNTS[tumor_type]),)
    else:
        plan = ()

    for size_name, count_range in plan:
        _paste_random_tumors(geo_mask, mask_scan, size_name, count_range, margin)

    # Crop the padding away again and keep only voxels inside the organ.
    geo_mask = geo_mask[margin//2:-margin//2, margin//2:-margin//2, margin//2:-margin//2]
    geo_mask = (geo_mask * mask_scan) >= 1

    return geo_mask
254
+
255
+
256
def get_tumor(volume_scan, mask_scan, tumor_type):
    """Blend a synthetic tumor into ``volume_scan`` and update the mask.

    Args:
        volume_scan: Intensity volume cropped to the organ.
        mask_scan: Organ mask of the same shape.
        tumor_type: Size class forwarded to ``get_fixed_geo``.

    Returns:
        ``(abnormally_full, abnormally_mask)``: the modified volume and
        ``mask_scan`` plus the tumor mask.
    """
    tumor_mask = get_fixed_geo(mask_scan, tumor_type)

    sigma = np.random.uniform(1, 2)
    difference = 1

    # blur the boundary
    tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma)

    # NOTE(review): the statement defining `abnormally` was missing, so this
    # function always raised NameError.  The line below is an assumed
    # reconstruction: inside the organ, subtract the blurred tumor mask
    # scaled by `difference`.  Confirm the intended intensity model.
    abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan

    abnormally_full = volume_scan * (1 - mask_scan) + abnormally
    abnormally_mask = mask_scan + tumor_mask

    return abnormally_full, abnormally_mask
271
+
272
def SynthesisTumor(volume_scan, mask_scan, tumor_type):
    """Synthesize tumors inside the organ region of a scan.

    Only the organ's bounding box (shrunk by one voxel per side) is passed
    to ``get_tumor``; the result is written back into that region.

    Args:
        volume_scan: 3-D intensity volume; modified in place.
        mask_scan: 3-D organ mask of the same shape; modified in place.
        tumor_type: Size class forwarded to ``get_tumor``.

    Returns:
        The same ``(volume_scan, mask_scan)`` objects after modification.
    """
    # for speed, only the organ's bounding box is sent to the generator
    x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]]
    y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]]
    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]

    # shrink the boundary
    x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1)
    y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1)
    z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1)

    box = (slice(x_start, x_end), slice(y_start, y_end), slice(z_start, z_end))
    ct_volume = volume_scan[box]
    organ_mask = mask_scan[box]

    # No random 64^3 crop offsets are drawn here: they were never used, and
    # drawing them raised ValueError for organs smaller than the crop size.
    ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type)
    volume_scan[box] = ct_volume
    mask_scan[box] = organ_tumor_mask

    return volume_scan, mask_scan
Generation_Pipeline_filter_all2/syn_colon/healthy_colon_1k.txt ADDED
@@ -0,0 +1,928 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BDMAP_00001823
2
+ BDMAP_00003074
3
+ BDMAP_00001305
4
+ BDMAP_00001635
5
+ BDMAP_00002359
6
+ BDMAP_00001265
7
+ BDMAP_00000701
8
+ BDMAP_00000771
9
+ BDMAP_00003581
10
+ BDMAP_00002523
11
+ BDMAP_00004028
12
+ BDMAP_00005151
13
+ BDMAP_00001183
14
+ BDMAP_00001656
15
+ BDMAP_00003898
16
+ BDMAP_00001845
17
+ BDMAP_00000481
18
+ BDMAP_00003324
19
+ BDMAP_00002688
20
+ BDMAP_00000948
21
+ BDMAP_00004796
22
+ BDMAP_00004198
23
+ BDMAP_00003514
24
+ BDMAP_00000432
25
+ BDMAP_00003832
26
+ BDMAP_00001296
27
+ BDMAP_00003683
28
+ BDMAP_00001607
29
+ BDMAP_00004745
30
+ BDMAP_00005167
31
+ BDMAP_00005154
32
+ BDMAP_00003598
33
+ BDMAP_00003551
34
+ BDMAP_00000176
35
+ BDMAP_00004719
36
+ BDMAP_00003722
37
+ BDMAP_00002690
38
+ BDMAP_00002244
39
+ BDMAP_00000883
40
+ BDMAP_00000926
41
+ BDMAP_00002849
42
+ BDMAP_00004549
43
+ BDMAP_00004017
44
+ BDMAP_00003482
45
+ BDMAP_00003225
46
+ BDMAP_00000416
47
+ BDMAP_00002387
48
+ BDMAP_00002022
49
+ BDMAP_00002909
50
+ BDMAP_00003236
51
+ BDMAP_00000465
52
+ BDMAP_00001784
53
+ BDMAP_00004103
54
+ BDMAP_00000656
55
+ BDMAP_00004850
56
+ BDMAP_00002955
57
+ BDMAP_00003633
58
+ BDMAP_00000137
59
+ BDMAP_00004529
60
+ BDMAP_00004903
61
+ BDMAP_00001309
62
+ BDMAP_00002216
63
+ BDMAP_00001444
64
+ BDMAP_00000263
65
+ BDMAP_00004066
66
+ BDMAP_00003920
67
+ BDMAP_00001434
68
+ BDMAP_00004890
69
+ BDMAP_00000400
70
+ BDMAP_00001238
71
+ BDMAP_00003592
72
+ BDMAP_00000431
73
+ BDMAP_00002304
74
+ BDMAP_00000285
75
+ BDMAP_00004995
76
+ BDMAP_00004264
77
+ BDMAP_00001440
78
+ BDMAP_00001383
79
+ BDMAP_00003614
80
+ BDMAP_00005157
81
+ BDMAP_00003608
82
+ BDMAP_00002619
83
+ BDMAP_00000615
84
+ BDMAP_00000084
85
+ BDMAP_00002804
86
+ BDMAP_00002592
87
+ BDMAP_00001868
88
+ BDMAP_00002021
89
+ BDMAP_00000297
90
+ BDMAP_00003202
91
+ BDMAP_00000411
92
+ BDMAP_00005070
93
+ BDMAP_00003364
94
+ BDMAP_00004395
95
+ BDMAP_00002075
96
+ BDMAP_00002844
97
+ BDMAP_00002712
98
+ BDMAP_00000714
99
+ BDMAP_00002717
100
+ BDMAP_00004895
101
+ BDMAP_00000698
102
+ BDMAP_00003384
103
+ BDMAP_00001286
104
+ BDMAP_00001562
105
+ BDMAP_00004228
106
+ BDMAP_00000831
107
+ BDMAP_00000855
108
+ BDMAP_00004672
109
+ BDMAP_00000882
110
+ BDMAP_00004992
111
+ BDMAP_00002232
112
+ BDMAP_00003849
113
+ BDMAP_00004880
114
+ BDMAP_00004074
115
+ BDMAP_00002626
116
+ BDMAP_00004262
117
+ BDMAP_00000368
118
+ BDMAP_00002826
119
+ BDMAP_00000837
120
+ BDMAP_00001911
121
+ BDMAP_00001557
122
+ BDMAP_00001126
123
+ BDMAP_00002328
124
+ BDMAP_00002959
125
+ BDMAP_00002562
126
+ BDMAP_00003600
127
+ BDMAP_00001057
128
+ BDMAP_00000940
129
+ BDMAP_00002120
130
+ BDMAP_00002227
131
+ BDMAP_00000122
132
+ BDMAP_00002479
133
+ BDMAP_00002805
134
+ BDMAP_00004980
135
+ BDMAP_00001862
136
+ BDMAP_00000778
137
+ BDMAP_00003749
138
+ BDMAP_00000245
139
+ BDMAP_00000989
140
+ BDMAP_00001247
141
+ BDMAP_00000623
142
+ BDMAP_00004113
143
+ BDMAP_00002278
144
+ BDMAP_00004841
145
+ BDMAP_00001602
146
+ BDMAP_00001464
147
+ BDMAP_00001712
148
+ BDMAP_00003815
149
+ BDMAP_00002407
150
+ BDMAP_00003150
151
+ BDMAP_00001711
152
+ BDMAP_00002273
153
+ BDMAP_00002751
154
+ BDMAP_00005074
155
+ BDMAP_00001068
156
+ BDMAP_00004447
157
+ BDMAP_00000977
158
+ BDMAP_00004297
159
+ BDMAP_00000812
160
+ BDMAP_00004641
161
+ BDMAP_00001422
162
+ BDMAP_00003385
163
+ BDMAP_00003164
164
+ BDMAP_00002475
165
+ BDMAP_00002166
166
+ BDMAP_00004232
167
+ BDMAP_00000826
168
+ BDMAP_00003769
169
+ BDMAP_00003569
170
+ BDMAP_00003853
171
+ BDMAP_00004494
172
+ BDMAP_00004011
173
+ BDMAP_00002776
174
+ BDMAP_00001517
175
+ BDMAP_00004304
176
+ BDMAP_00004645
177
+ BDMAP_00000091
178
+ BDMAP_00004738
179
+ BDMAP_00000725
180
+ BDMAP_00003771
181
+ BDMAP_00002524
182
+ BDMAP_00000161
183
+ BDMAP_00000902
184
+ BDMAP_00001786
185
+ BDMAP_00002332
186
+ BDMAP_00004175
187
+ BDMAP_00002419
188
+ BDMAP_00004077
189
+ BDMAP_00004295
190
+ BDMAP_00002871
191
+ BDMAP_00004148
192
+ BDMAP_00000676
193
+ BDMAP_00001782
194
+ BDMAP_00003947
195
+ BDMAP_00003513
196
+ BDMAP_00003130
197
+ BDMAP_00001545
198
+ BDMAP_00000667
199
+ BDMAP_00005078
200
+ BDMAP_00003435
201
+ BDMAP_00002545
202
+ BDMAP_00002498
203
+ BDMAP_00001255
204
+ BDMAP_00004065
205
+ BDMAP_00002099
206
+ BDMAP_00001504
207
+ BDMAP_00001863
208
+ BDMAP_00000542
209
+ BDMAP_00002326
210
+ BDMAP_00005155
211
+ BDMAP_00001476
212
+ BDMAP_00000388
213
+ BDMAP_00000159
214
+ BDMAP_00004060
215
+ BDMAP_00000332
216
+ BDMAP_00004087
217
+ BDMAP_00000516
218
+ BDMAP_00000574
219
+ BDMAP_00004943
220
+ BDMAP_00004514
221
+ BDMAP_00003329
222
+ BDMAP_00001597
223
+ BDMAP_00002172
224
+ BDMAP_00000833
225
+ BDMAP_00004187
226
+ BDMAP_00004744
227
+ BDMAP_00001676
228
+ BDMAP_00003558
229
+ BDMAP_00003438
230
+ BDMAP_00001957
231
+ BDMAP_00004128
232
+ BDMAP_00005140
233
+ BDMAP_00002656
234
+ BDMAP_00004817
235
+ BDMAP_00000745
236
+ BDMAP_00000205
237
+ BDMAP_00000671
238
+ BDMAP_00001962
239
+ BDMAP_00003543
240
+ BDMAP_00001620
241
+ BDMAP_00003128
242
+ BDMAP_00003409
243
+ BDMAP_00000982
244
+ BDMAP_00004015
245
+ BDMAP_00001707
246
+ BDMAP_00002068
247
+ BDMAP_00001236
248
+ BDMAP_00003973
249
+ BDMAP_00004870
250
+ BDMAP_00000366
251
+ BDMAP_00003685
252
+ BDMAP_00001096
253
+ BDMAP_00003347
254
+ BDMAP_00001892
255
+ BDMAP_00003740
256
+ BDMAP_00004773
257
+ BDMAP_00002260
258
+ BDMAP_00002815
259
+ BDMAP_00000972
260
+ BDMAP_00000998
261
+ BDMAP_00003063
262
+ BDMAP_00001791
263
+ BDMAP_00002085
264
+ BDMAP_00002275
265
+ BDMAP_00004016
266
+ BDMAP_00000438
267
+ BDMAP_00000709
268
+ BDMAP_00004416
269
+ BDMAP_00003884
270
+ BDMAP_00002237
271
+ BDMAP_00001794
272
+ BDMAP_00004378
273
+ BDMAP_00000713
274
+ BDMAP_00004286
275
+ BDMAP_00001109
276
+ BDMAP_00001223
277
+ BDMAP_00001027
278
+ BDMAP_00001001
279
+ BDMAP_00005097
280
+ BDMAP_00002942
281
+ BDMAP_00000607
282
+ BDMAP_00002940
283
+ BDMAP_00002930
284
+ BDMAP_00003377
285
+ BDMAP_00004509
286
+ BDMAP_00000923
287
+ BDMAP_00001413
288
+ BDMAP_00001636
289
+ BDMAP_00001705
290
+ BDMAP_00000273
291
+ BDMAP_00003840
292
+ BDMAP_00001333
293
+ BDMAP_00005092
294
+ BDMAP_00001368
295
+ BDMAP_00003994
296
+ BDMAP_00004925
297
+ BDMAP_00001370
298
+ BDMAP_00003455
299
+ BDMAP_00002631
300
+ BDMAP_00005174
301
+ BDMAP_00005009
302
+ BDMAP_00001549
303
+ BDMAP_00001941
304
+ BDMAP_00000154
305
+ BDMAP_00001521
306
+ BDMAP_00002653
307
+ BDMAP_00001148
308
+ BDMAP_00000774
309
+ BDMAP_00005105
310
+ BDMAP_00002421
311
+ BDMAP_00000139
312
+ BDMAP_00003867
313
+ BDMAP_00003479
314
+ BDMAP_00004741
315
+ BDMAP_00001516
316
+ BDMAP_00002396
317
+ BDMAP_00003481
318
+ BDMAP_00000324
319
+ BDMAP_00002841
320
+ BDMAP_00003326
321
+ BDMAP_00002437
322
+ BDMAP_00000100
323
+ BDMAP_00004586
324
+ BDMAP_00004867
325
+ BDMAP_00001040
326
+ BDMAP_00001185
327
+ BDMAP_00001461
328
+ BDMAP_00000692
329
+ BDMAP_00001563
330
+ BDMAP_00002289
331
+ BDMAP_00004901
332
+ BDMAP_00001632
333
+ BDMAP_00000558
334
+ BDMAP_00000469
335
+ BDMAP_00001966
336
+ BDMAP_00003315
337
+ BDMAP_00002313
338
+ BDMAP_00005006
339
+ BDMAP_00000439
340
+ BDMAP_00004551
341
+ BDMAP_00003294
342
+ BDMAP_00001807
343
+ BDMAP_00004579
344
+ BDMAP_00002057
345
+ BDMAP_00002060
346
+ BDMAP_00004508
347
+ BDMAP_00004104
348
+ BDMAP_00000052
349
+ BDMAP_00003439
350
+ BDMAP_00001502
351
+ BDMAP_00005186
352
+ BDMAP_00002529
353
+ BDMAP_00002775
354
+ BDMAP_00004834
355
+ BDMAP_00001496
356
+ BDMAP_00002319
357
+ BDMAP_00002856
358
+ BDMAP_00004552
359
+ BDMAP_00004878
360
+ BDMAP_00001331
361
+ BDMAP_00001912
362
+ BDMAP_00002758
363
+ BDMAP_00000414
364
+ BDMAP_00004288
365
+ BDMAP_00000805
366
+ BDMAP_00004597
367
+ BDMAP_00003178
368
+ BDMAP_00001752
369
+ BDMAP_00003943
370
+ BDMAP_00004652
371
+ BDMAP_00004541
372
+ BDMAP_00000614
373
+ BDMAP_00004639
374
+ BDMAP_00001804
375
+ BDMAP_00005063
376
+ BDMAP_00002807
377
+ BDMAP_00000062
378
+ BDMAP_00005119
379
+ BDMAP_00004417
380
+ BDMAP_00005075
381
+ BDMAP_00001441
382
+ BDMAP_00002373
383
+ BDMAP_00002041
384
+ BDMAP_00003727
385
+ BDMAP_00001483
386
+ BDMAP_00001128
387
+ BDMAP_00004927
388
+ BDMAP_00001119
389
+ BDMAP_00004106
390
+ BDMAP_00000355
391
+ BDMAP_00002354
392
+ BDMAP_00004030
393
+ BDMAP_00004847
394
+ BDMAP_00000618
395
+ BDMAP_00003736
396
+ BDMAP_00002803
397
+ BDMAP_00005099
398
+ BDMAP_00003168
399
+ BDMAP_00000941
400
+ BDMAP_00000243
401
+ BDMAP_00001664
402
+ BDMAP_00001747
403
+ BDMAP_00003774
404
+ BDMAP_00004917
405
+ BDMAP_00000867
406
+ BDMAP_00000435
407
+ BDMAP_00003822
408
+ BDMAP_00003411
409
+ BDMAP_00000965
410
+ BDMAP_00003612
411
+ BDMAP_00004023
412
+ BDMAP_00002333
413
+ BDMAP_00001270
414
+ BDMAP_00002616
415
+ BDMAP_00004511
416
+ BDMAP_00005130
417
+ BDMAP_00000642
418
+ BDMAP_00002471
419
+ BDMAP_00000589
420
+ BDMAP_00002509
421
+ BDMAP_00004561
422
+ BDMAP_00001275
423
+ BDMAP_00003133
424
+ BDMAP_00000626
425
+ BDMAP_00003491
426
+ BDMAP_00000993
427
+ BDMAP_00003493
428
+ BDMAP_00004499
429
+ BDMAP_00002065
430
+ BDMAP_00001175
431
+ BDMAP_00002696
432
+ BDMAP_00000319
433
+ BDMAP_00002410
434
+ BDMAP_00002485
435
+ BDMAP_00001258
436
+ BDMAP_00000660
437
+ BDMAP_00003272
438
+ BDMAP_00004183
439
+ BDMAP_00003359
440
+ BDMAP_00000956
441
+ BDMAP_00004462
442
+ BDMAP_00001704
443
+ BDMAP_00000039
444
+ BDMAP_00001853
445
+ BDMAP_00003857
446
+ BDMAP_00000572
447
+ BDMAP_00005168
448
+ BDMAP_00000304
449
+ BDMAP_00002426
450
+ BDMAP_00000244
451
+ BDMAP_00001646
452
+ BDMAP_00000413
453
+ BDMAP_00004735
454
+ BDMAP_00002476
455
+ BDMAP_00004039
456
+ BDMAP_00000219
457
+ BDMAP_00004651
458
+ BDMAP_00005065
459
+ BDMAP_00004281
460
+ BDMAP_00000113
461
+ BDMAP_00003956
462
+ BDMAP_00002226
463
+ BDMAP_00004130
464
+ BDMAP_00002707
465
+ BDMAP_00000430
466
+ BDMAP_00002661
467
+ BDMAP_00001617
468
+ BDMAP_00002298
469
+ BDMAP_00003930
470
+ BDMAP_00000687
471
+ BDMAP_00004195
472
+ BDMAP_00001647
473
+ BDMAP_00000487
474
+ BDMAP_00003367
475
+ BDMAP_00003277
476
+ BDMAP_00004600
477
+ BDMAP_00003497
478
+ BDMAP_00004546
479
+ BDMAP_00004808
480
+ BDMAP_00002981
481
+ BDMAP_00000229
482
+ BDMAP_00004185
483
+ BDMAP_00003406
484
+ BDMAP_00002422
485
+ BDMAP_00002947
486
+ BDMAP_00001261
487
+ BDMAP_00005037
488
+ BDMAP_00003590
489
+ BDMAP_00003058
490
+ BDMAP_00003461
491
+ BDMAP_00003151
492
+ BDMAP_00001035
493
+ BDMAP_00001289
494
+ BDMAP_00000087
495
+ BDMAP_00004981
496
+ BDMAP_00001836
497
+ BDMAP_00004712
498
+ BDMAP_00002363
499
+ BDMAP_00002495
500
+ BDMAP_00004398
501
+ BDMAP_00003457
502
+ BDMAP_00003752
503
+ BDMAP_00001891
504
+ BDMAP_00004373
505
+ BDMAP_00001590
506
+ BDMAP_00003506
507
+ BDMAP_00001921
508
+ BDMAP_00004229
509
+ BDMAP_00001898
510
+ BDMAP_00003483
511
+ BDMAP_00004616
512
+ BDMAP_00002648
513
+ BDMAP_00000562
514
+ BDMAP_00002403
515
+ BDMAP_00003361
516
+ BDMAP_00000887
517
+ BDMAP_00001283
518
+ BDMAP_00002719
519
+ BDMAP_00005064
520
+ BDMAP_00002793
521
+ BDMAP_00002242
522
+ BDMAP_00004278
523
+ BDMAP_00002117
524
+ BDMAP_00000320
525
+ BDMAP_00005191
526
+ BDMAP_00000809
527
+ BDMAP_00000859
528
+ BDMAP_00003955
529
+ BDMAP_00004253
530
+ BDMAP_00004031
531
+ BDMAP_00005139
532
+ BDMAP_00003244
533
+ BDMAP_00000149
534
+ BDMAP_00001414
535
+ BDMAP_00001945
536
+ BDMAP_00004510
537
+ BDMAP_00003824
538
+ BDMAP_00001361
539
+ BDMAP_00000662
540
+ BDMAP_00005022
541
+ BDMAP_00000434
542
+ BDMAP_00000241
543
+ BDMAP_00000710
544
+ BDMAP_00005120
545
+ BDMAP_00002383
546
+ BDMAP_00003036
547
+ BDMAP_00002609
548
+ BDMAP_00004922
549
+ BDMAP_00004407
550
+ BDMAP_00004481
551
+ BDMAP_00001225
552
+ BDMAP_00003556
553
+ BDMAP_00000329
554
+ BDMAP_00003052
555
+ BDMAP_00003396
556
+ BDMAP_00002164
557
+ BDMAP_00001077
558
+ BDMAP_00003153
559
+ BDMAP_00003776
560
+ BDMAP_00002710
561
+ BDMAP_00004746
562
+ BDMAP_00000066
563
+ BDMAP_00005085
564
+ BDMAP_00004435
565
+ BDMAP_00002695
566
+ BDMAP_00001828
567
+ BDMAP_00003392
568
+ BDMAP_00003976
569
+ BDMAP_00002744
570
+ BDMAP_00002214
571
+ BDMAP_00000569
572
+ BDMAP_00000571
573
+ BDMAP_00004888
574
+ BDMAP_00003301
575
+ BDMAP_00004956
576
+ BDMAP_00003809
577
+ BDMAP_00002265
578
+ BDMAP_00002944
579
+ BDMAP_00004457
580
+ BDMAP_00001768
581
+ BDMAP_00001020
582
+ BDMAP_00000541
583
+ BDMAP_00000101
584
+ BDMAP_00003664
585
+ BDMAP_00003255
586
+ BDMAP_00001379
587
+ BDMAP_00002347
588
+ BDMAP_00000128
589
+ BDMAP_00002252
590
+ BDMAP_00001697
591
+ BDMAP_00002953
592
+ BDMAP_00001122
593
+ BDMAP_00003525
594
+ BDMAP_00003070
595
+ BDMAP_00004829
596
+ BDMAP_00002233
597
+ BDMAP_00001288
598
+ BDMAP_00002791
599
+ BDMAP_00004199
600
+ BDMAP_00004184
601
+ BDMAP_00003381
602
+ BDMAP_00001766
603
+ BDMAP_00003114
604
+ BDMAP_00004804
605
+ BDMAP_00002184
606
+ BDMAP_00001138
607
+ BDMAP_00000044
608
+ BDMAP_00002271
609
+ BDMAP_00003603
610
+ BDMAP_00001523
611
+ BDMAP_00004097
612
+ BDMAP_00002440
613
+ BDMAP_00004664
614
+ BDMAP_00003808
615
+ BDMAP_00000427
616
+ BDMAP_00002362
617
+ BDMAP_00005169
618
+ BDMAP_00000023
619
+ BDMAP_00003833
620
+ BDMAP_00001710
621
+ BDMAP_00001518
622
+ BDMAP_00004482
623
+ BDMAP_00003549
624
+ BDMAP_00002171
625
+ BDMAP_00002309
626
+ BDMAP_00000338
627
+ BDMAP_00000715
628
+ BDMAP_00003897
629
+ BDMAP_00003812
630
+ BDMAP_00004257
631
+ BDMAP_00001753
632
+ BDMAP_00000117
633
+ BDMAP_00001456
634
+ BDMAP_00004115
635
+ BDMAP_00003319
636
+ BDMAP_00003744
637
+ BDMAP_00004154
638
+ BDMAP_00003658
639
+ BDMAP_00001214
640
+ BDMAP_00004293
641
+ BDMAP_00001842
642
+ BDMAP_00001420
643
+ BDMAP_00003343
644
+ BDMAP_00001325
645
+ BDMAP_00000921
646
+ BDMAP_00002582
647
+ BDMAP_00002864
648
+ BDMAP_00000889
649
+ BDMAP_00001092
650
+ BDMAP_00000968
651
+ BDMAP_00002402
652
+ BDMAP_00004427
653
+ BDMAP_00001605
654
+ BDMAP_00000462
655
+ BDMAP_00005081
656
+ BDMAP_00002463
657
+ BDMAP_00000839
658
+ BDMAP_00000437
659
+ BDMAP_00000604
660
+ BDMAP_00001104
661
+ BDMAP_00001281
662
+ BDMAP_00000679
663
+ BDMAP_00004717
664
+ BDMAP_00001511
665
+ BDMAP_00003281
666
+ BDMAP_00001977
667
+ BDMAP_00000653
668
+ BDMAP_00000232
669
+ BDMAP_00004328
670
+ BDMAP_00002496
671
+ BDMAP_00000987
672
+ BDMAP_00003717
673
+ BDMAP_00004897
674
+ BDMAP_00003713
675
+ BDMAP_00002889
676
+ BDMAP_00003657
677
+ BDMAP_00002829
678
+ BDMAP_00004839
679
+ BDMAP_00001397
680
+ BDMAP_00001908
681
+ BDMAP_00003911
682
+ BDMAP_00004843
683
+ BDMAP_00004969
684
+ BDMAP_00003918
685
+ BDMAP_00004216
686
+ BDMAP_00000034
687
+ BDMAP_00003923
688
+ BDMAP_00000225
689
+ BDMAP_00003576
690
+ BDMAP_00002884
691
+ BDMAP_00002472
692
+ BDMAP_00001688
693
+ BDMAP_00001246
694
+ BDMAP_00004620
695
+ BDMAP_00005017
696
+ BDMAP_00002990
697
+ BDMAP_00000971
698
+ BDMAP_00004578
699
+ BDMAP_00001735
700
+ BDMAP_00002655
701
+ BDMAP_00000233
702
+ BDMAP_00001205
703
+ BDMAP_00003073
704
+ BDMAP_00003957
705
+ BDMAP_00001093
706
+ BDMAP_00003440
707
+ BDMAP_00001251
708
+ BDMAP_00004793
709
+ BDMAP_00000162
710
+ BDMAP_00003444
711
+ BDMAP_00001533
712
+ BDMAP_00003971
713
+ BDMAP_00001584
714
+ BDMAP_00000036
715
+ BDMAP_00002251
716
+ BDMAP_00003141
717
+ BDMAP_00002484
718
+ BDMAP_00004770
719
+ BDMAP_00001487
720
+ BDMAP_00001754
721
+ BDMAP_00003356
722
+ BDMAP_00000353
723
+ BDMAP_00001419
724
+ BDMAP_00001802
725
+ BDMAP_00003701
726
+ BDMAP_00005141
727
+ BDMAP_00000321
728
+ BDMAP_00001746
729
+ BDMAP_00000364
730
+ BDMAP_00003900
731
+ BDMAP_00001995
732
+ BDMAP_00001025
733
+ BDMAP_00004231
734
+ BDMAP_00000918
735
+ BDMAP_00001130
736
+ BDMAP_00003443
737
+ BDMAP_00003215
738
+ BDMAP_00004815
739
+ BDMAP_00002933
740
+ BDMAP_00000192
741
+ BDMAP_00003615
742
+ BDMAP_00004704
743
+ BDMAP_00001218
744
+ BDMAP_00002295
745
+ BDMAP_00000429
746
+ BDMAP_00000532
747
+ BDMAP_00001474
748
+ BDMAP_00003961
749
+ BDMAP_00004129
750
+ BDMAP_00000362
751
+ BDMAP_00002863
752
+ BDMAP_00003267
753
+ BDMAP_00001198
754
+ BDMAP_00000259
755
+ BDMAP_00000683
756
+ BDMAP_00001256
757
+ BDMAP_00003252
758
+ BDMAP_00004475
759
+ BDMAP_00004250
760
+ BDMAP_00004887
761
+ BDMAP_00000240
762
+ BDMAP_00003767
763
+ BDMAP_00003427
764
+ BDMAP_00000043
765
+ BDMAP_00003448
766
+ BDMAP_00001114
767
+ BDMAP_00001067
768
+ BDMAP_00001089
769
+ BDMAP_00002133
770
+ BDMAP_00004033
771
+ BDMAP_00002896
772
+ BDMAP_00003138
773
+ BDMAP_00001010
774
+ BDMAP_00001059
775
+ BDMAP_00004990
776
+ BDMAP_00000936
777
+ BDMAP_00001359
778
+ BDMAP_00005077
779
+ BDMAP_00000582
780
+ BDMAP_00004296
781
+ BDMAP_00005114
782
+ BDMAP_00004389
783
+ BDMAP_00004673
784
+ BDMAP_00003254
785
+ BDMAP_00003516
786
+ BDMAP_00001475
787
+ BDMAP_00002580
788
+ BDMAP_00002689
789
+ BDMAP_00004671
790
+ BDMAP_00003762
791
+ BDMAP_00003330
792
+ BDMAP_00002188
793
+ BDMAP_00001736
794
+ BDMAP_00002404
795
+ BDMAP_00003502
796
+ BDMAP_00004117
797
+ BDMAP_00004964
798
+ BDMAP_00002742
799
+ BDMAP_00000093
800
+ BDMAP_00002361
801
+ BDMAP_00000794
802
+ BDMAP_00002349
803
+ BDMAP_00001273
804
+ BDMAP_00000449
805
+ BDMAP_00001628
806
+ BDMAP_00002250
807
+ BDMAP_00004479
808
+ BDMAP_00000608
809
+ BDMAP_00001834
810
+ BDMAP_00002267
811
+ BDMAP_00001125
812
+ BDMAP_00000447
813
+ BDMAP_00005113
814
+ BDMAP_00004014
815
+ BDMAP_00001701
816
+ BDMAP_00003952
817
+ BDMAP_00003520
818
+ BDMAP_00000347
819
+ BDMAP_00002936
820
+ BDMAP_00001624
821
+ BDMAP_00000104
822
+ BDMAP_00002487
823
+ BDMAP_00005020
824
+ BDMAP_00000511
825
+ BDMAP_00001564
826
+ BDMAP_00004294
827
+ BDMAP_00004858
828
+ BDMAP_00004608
829
+ BDMAP_00003650
830
+ BDMAP_00001171
831
+ BDMAP_00000059
832
+ BDMAP_00000871
833
+ BDMAP_00003996
834
+ BDMAP_00001169
835
+ BDMAP_00003363
836
+ BDMAP_00003376
837
+ BDMAP_00002167
838
+ BDMAP_00002737
839
+ BDMAP_00003694
840
+ BDMAP_00001396
841
+ BDMAP_00005083
842
+ BDMAP_00002918
843
+ BDMAP_00003580
844
+ BDMAP_00001324
845
+ BDMAP_00002855
846
+ BDMAP_00001649
847
+ BDMAP_00004459
848
+ BDMAP_00002288
849
+ BDMAP_00004830
850
+ BDMAP_00004775
851
+ BDMAP_00000279
852
+ BDMAP_00002114
853
+ BDMAP_00005185
854
+ BDMAP_00004885
855
+ BDMAP_00000236
856
+ BDMAP_00003928
857
+ BDMAP_00002663
858
+ BDMAP_00002282
859
+ BDMAP_00003798
860
+ BDMAP_00001055
861
+ BDMAP_00002945
862
+ BDMAP_00001316
863
+ BDMAP_00003451
864
+ BDMAP_00000696
865
+ BDMAP_00003143
866
+ BDMAP_00001522
867
+ BDMAP_00000452
868
+ BDMAP_00002603
869
+ BDMAP_00004131
870
+ BDMAP_00001045
871
+ BDMAP_00004954
872
+ BDMAP_00003358
873
+ BDMAP_00000980
874
+ BDMAP_00001343
875
+ BDMAP_00001410
876
+ BDMAP_00002173
877
+ BDMAP_00002840
878
+ BDMAP_00001200
879
+ BDMAP_00001905
880
+ BDMAP_00003425
881
+ BDMAP_00003672
882
+ BDMAP_00003781
883
+ BDMAP_00001906
884
+ BDMAP_00004636
885
+ BDMAP_00000836
886
+ BDMAP_00002076
887
+ BDMAP_00001230
888
+ BDMAP_00003932
889
+ BDMAP_00002029
890
+ BDMAP_00000331
891
+ BDMAP_00000197
892
+ BDMAP_00001539
893
+ BDMAP_00003524
894
+ BDMAP_00001692
895
+ BDMAP_00004550
896
+ BDMAP_00004331
897
+ BDMAP_00004825
898
+ BDMAP_00002411
899
+ BDMAP_00003484
900
+ BDMAP_00000480
901
+ BDMAP_00003886
902
+ BDMAP_00001907
903
+ BDMAP_00002746
904
+ BDMAP_00002899
905
+ BDMAP_00004420
906
+ BDMAP_00002748
907
+ BDMAP_00003002
908
+ BDMAP_00003300
909
+ BDMAP_00005108
910
+ BDMAP_00003400
911
+ BDMAP_00002283
912
+ BDMAP_00003486
913
+ BDMAP_00000935
914
+ BDMAP_00001618
915
+ BDMAP_00001565
916
+ BDMAP_00000616
917
+ BDMAP_00000810
918
+ BDMAP_00001826
919
+ BDMAP_00000470
920
+ BDMAP_00002017
921
+ BDMAP_00003631
922
+ BDMAP_00001242
923
+ BDMAP_00000907
924
+ BDMAP_00001806
925
+ BDMAP_00002854
926
+ BDMAP_00002902
927
+ BDMAP_00003017
928
+ BDMAP_00001471
Generation_Pipeline_filter_all2/syn_colon/requirements.txt ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.1.0
2
+ accelerate==0.11.0
3
+ aiohttp==3.8.1
4
+ aiosignal==1.2.0
5
+ antlr4-python3-runtime==4.9.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ autopep8==1.6.0
9
+ cachetools==5.2.0
10
+ certifi==2022.6.15
11
+ charset-normalizer==2.0.12
12
+ click==8.1.3
13
+ cycler==0.11.0
14
+ Deprecated==1.2.13
15
+ docker-pycreds==0.4.0
16
+ einops==0.4.1
17
+ einops-exts==0.0.3
18
+ ema-pytorch==0.0.8
19
+ fonttools==4.34.4
20
+ frozenlist==1.3.0
21
+ fsspec==2022.5.0
22
+ ftfy==6.1.1
23
+ future==0.18.2
24
+ gitdb==4.0.9
25
+ GitPython==3.1.27
26
+ google-auth==2.9.0
27
+ google-auth-oauthlib==0.4.6
28
+ grpcio==1.47.0
29
+ h5py==3.7.0
30
+ humanize==4.2.2
31
+ hydra-core==1.2.0
32
+ idna==3.3
33
+ imageio==2.19.3
34
+ imageio-ffmpeg==0.4.7
35
+ importlib-metadata==4.12.0
36
+ importlib-resources==5.9.0
37
+ joblib==1.1.0
38
+ kiwisolver==1.4.3
39
+ lxml==4.9.1
40
+ Markdown==3.3.7
41
+ matplotlib==3.5.2
42
+ multidict==6.0.2
43
+ networkx==2.8.5
44
+ nibabel==4.0.1
45
+ nilearn==0.9.1
46
+ numpy==1.23.0
47
+ oauthlib==3.2.0
48
+ omegaconf==2.2.3
49
+ pandas==1.4.3
50
+ Pillow==9.1.1
51
+ pyasn1==0.4.8
52
+ pyasn1-modules==0.2.8
53
+ pycodestyle==2.8.0
54
+ pyDeprecate==0.3.1
55
+ pydicom==2.3.0
56
+ pytorch-lightning==1.6.4
57
+ pytz==2022.1
58
+ PyWavelets==1.3.0
59
+ PyYAML==6.0
60
+ pyzmq==19.0.2
61
+ regex==2022.6.2
62
+ requests==2.28.0
63
+ requests-oauthlib==1.3.1
64
+ rotary-embedding-torch==0.1.5
65
+ rsa==4.8
66
+ scikit-image==0.19.3
67
+ scikit-learn==1.1.2
68
+ scikit-video==1.1.11
69
+ scipy==1.8.1
70
+ seaborn==0.11.2
71
+ sentry-sdk==1.7.2
72
+ setproctitle==1.2.3
73
+ shortuuid==1.0.9
74
+ SimpleITK==2.1.1.2
75
+ smmap==5.0.0
76
+ tensorboard==2.9.1
77
+ tensorboard-data-server==0.6.1
78
+ tensorboard-plugin-wit==1.8.1
79
+ threadpoolctl==3.1.0
80
+ tifffile==2022.8.3
81
+ toml==0.10.2
82
+ torch-tb-profiler==0.4.0
83
+ torchio==0.18.80
84
+ torchmetrics==0.9.1
85
+ tqdm==4.64.0
86
+ typing_extensions==4.2.0
87
+ urllib3==1.26.9
88
+ wandb==0.12.21
89
+ Werkzeug==2.1.2
90
+ wrapt==1.14.1
91
+ yarl==1.7.2
92
+ zipp==3.8.0
93
+ wandb
94
+ tensorboardX==2.4.1
Generation_Pipeline_filter_all2/syn_kidney/CT_syn_kidney_data_new.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, csv
2
+ import numpy as np
3
+ import torch
4
+ from sklearn.metrics import confusion_matrix
5
+ from scipy import ndimage
6
+ from scipy.ndimage import label
7
+ from functools import partial
8
+ import monai
9
+ from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged
10
+ from monai import transforms, data
11
+ from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare
12
+ import nibabel as nib
13
+
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ import argparse
18
# Command-line interface of the kidney tumor synthesis script.
parser = argparse.ArgumentParser(description='kidney tumor validation')

# Input / output locations.
parser.add_argument('--data_root', type=str, default=None)
parser.add_argument('--organ_type', type=str, default='kidney')
parser.add_argument('--save_dir', type=str, default='out')
parser.add_argument('--data_file', type=str, default='out')

# Sampling and crop-selection settings.
parser.add_argument('--ddim_ts', type=int, default=50)
parser.add_argument('--fg_thresh', type=int, default=30)

# Slice [start:end) of the case list to process.
parser.add_argument('--start', type=int, default=0)
parser.add_argument('--end', type=int, default=1000)
29
+
30
def voxel2R(A):
    """Return the radius of a sphere whose volume is ``A``.

    This inverts ``V = 4/3 * pi * R**3``. ``A`` is converted with
    ``np.array``, so a scalar or an array-like of volumes is accepted and
    the conversion is applied element-wise.
    """
    volume = np.array(A)
    radius_cubed = volume / 4 * 3 / np.pi
    return radius_cubed ** (1 / 3)
32
+
33
class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld):
    """Random pos/neg crop that re-samples until the crop has enough foreground.

    ``fg_thresh`` is interpreted as a radius: a crop is accepted once its
    number of positive label voxels exceeds the volume of a sphere of that
    radius. The required radius is relaxed in steps of 5 (never below 5) as
    the number of attempts grows, and after more than 50 attempts the last
    crop is accepted unconditionally.
    """

    def __init__(self, keys, label_key, spatial_size,
                 pos=1.0, neg=1.0, num_samples=1,
                 image_key=None, image_threshold=0.0, allow_missing_keys=True,
                 fg_thresh=0):
        super().__init__(
            keys=keys,
            label_key=label_key,
            spatial_size=spatial_size,
            pos=pos,
            neg=neg,
            num_samples=num_samples,
            image_key=image_key,
            image_threshold=image_threshold,
            allow_missing_keys=allow_missing_keys,
        )
        # Radius whose sphere volume is the minimum foreground voxel count.
        self.fg_thresh = fg_thresh

    def R2voxel(self, R):
        """Return the volume of a sphere of radius ``R``."""
        return 4 / 3 * np.pi * R ** 3

    def __call__(self, data):
        """Crop ``data`` and re-attach its ``'name'`` entry to the first crop.

        The ``'name'`` entry is removed before the parent transform runs and
        is written into element 0 of the returned crop list afterwards.
        """
        sample = dict(data)
        data_name = sample.pop('name')

        if '10_Decathlon' in data_name or '05_KiTS' in data_name:
            # Cases from these sources are cropped once, without the
            # foreground-size check.
            d_crop = super().__call__(sample)
        else:
            attempts = 0
            while True:
                attempts += 1

                d_crop = super().__call__(sample)
                pixel_num = (d_crop[0]['label'] > 0).sum()

                # Full requirement met.
                if pixel_num > self.R2voxel(self.fg_thresh):
                    break
                # Relaxed requirement: after more than `step` attempts, a
                # radius reduced by `step` (floored at 5) is good enough.
                if any(
                    attempts > step
                    and pixel_num > self.R2voxel(max(self.fg_thresh - step, 5))
                    for step in (5, 10, 15, 20, 25)
                ):
                    break
                # Give up and keep whatever the last attempt produced.
                if attempts > 50:
                    break

        d_crop[0]['name'] = data_name

        return d_crop
80
+
81
def _get_loader(args):
    """Build the data loader and the inverse transforms for the synthesis run.

    Case names are read from ``args.data_file`` (one per line; blank lines
    are ignored). For each case the CT is expected at
    ``<data_root>/<name>/ct.nii.gz`` and the organ mask at
    ``<data_root>/<name>/segmentations/kidney_left.nii.gz``. Only the slice
    ``[args.start:args.end]`` of the case list is kept.

    Args:
        args: Parsed command-line namespace; reads ``data_file``,
            ``data_root``, ``fg_thresh``, ``start`` and ``end``.

    Returns:
        A tuple ``(val_org_loader, post_transforms)``: the batch-size-1
        loader over the selected cases, and a ``Compose`` of two ``Invertd``
        transforms that undo ``val_org_transform`` for the ``'image'`` and
        ``'label'`` keys.
    """
    # 'raw_image' is the same file as 'image' but is only loaded, given a
    # channel axis and converted to a tensor; it is not reoriented,
    # resampled or intensity-scaled.
    val_org_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label", "raw_image"]),
            transforms.AddChanneld(keys=["image", "label", "raw_image"]),
            transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
            transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")),
            transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]),
            RandCropByPosNegLabeld_select(
                keys=["image", "label", "name"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=0,
                num_samples=1,
                image_key="image",
                image_threshold=0,
                fg_thresh=args.fg_thresh,
            ),
            transforms.ToTensord(keys=["image", "label", "raw_image"]),
        ]
    )

    # Read the case list with a context manager so the file is closed, and
    # drop blank lines so they do not turn into empty case names.
    with open(args.data_file) as case_list:
        case_names = [line.strip() for line in case_list if line.strip()]

    data_dicts_val = []
    for name in case_names:
        image_path = os.path.join(args.data_root, name, 'ct.nii.gz')
        label_path = os.path.join(args.data_root, name, 'segmentations/kidney_left.nii.gz')
        data_dicts_val.append(
            {'image': image_path, 'raw_image': image_path, 'label': label_path, 'name': name}
        )

    # Slicing already clamps an `end` beyond the list length.
    data_dicts_val = data_dicts_val[args.start:args.end]
    print('val len {}'.format(len(data_dicts_val)))
    val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform)
    val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True)

    # Map the (cropped, resampled) image and label back to the original grid.
    post_transforms = Compose([
        Invertd(
            keys=['image'],
            transform=val_org_transform,
            orig_keys="image",
            nearest_interp=False,
            to_tensor=True,
        ),
        Invertd(
            keys=['label'],
            transform=val_org_transform,
            orig_keys="label",
            nearest_interp=False,
            to_tensor=True,
        )
    ])
    return val_org_loader, post_transforms
150
+
151
def main():
    """Synthesize kidney tumors for every listed case and save the results.

    For each case produced by the loader, a tumor type is drawn at random
    ('early', 'medium' or 'large' with probabilities 0.5 / 0.4 / 0.1) and
    synthesis is repeated until the DenseNet121 classifier's confidence for
    the synthesized crop passes the acceptance threshold. The synthesized
    voxels are pasted into the raw CT, and the CT and tumor mask are written
    to ``<save_dir>/<case>/ct.nii.gz`` and
    ``<save_dir>/<case>/segmentations/kidney_tumor.nii.gz``.

    Cases whose cropped organ mask is empty get an all-zero tumor mask and
    no CT file.
    """
    args = parser.parse_args()
    output_dir = args.save_dir
    os.makedirs(output_dir, exist_ok=True)
    print("MAIN Argument values:")
    for k, v in vars(args).items():
        print(k, '=>', v)
    print('-----------------')

    ## loader and post_transform
    val_loader, post_transforms = _get_loader(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth"))
    model.eval()

    start_time = 0
    with torch.no_grad():
        # The generative models do not depend on the sample, so they are
        # prepared once instead of once per case.
        vqgan, early_sampler, noearly_sampler = synt_model_prepare(device=torch.device("cuda"), fold=0, organ=args.organ_type)

        for idx, val_data in enumerate(val_loader):
            print('idx', idx)
            if idx == 0:
                start_time = time.time()

            healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image']
            case_name = data_names[0].split('/')[-1]
            print('case_name', case_name)
            original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy()

            # Both branches below write into the segmentations directory, so
            # create it (and the case directory) up front.
            os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True)
            os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True)

            if healthy_target.sum() == 0:
                # No organ voxels in the crop: emit an all-zero tumor mask on
                # the original grid and move on.
                val_data = [post_transforms(i) for i in data.decollate_batch(val_data)]
                tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8)
                tumor_mask_ = np.zeros_like(tumor_mask)
                nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz'))
                continue

            healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda()
            # Binarize the mask while keeping its dtype and device.
            healthy_target = (healthy_target > 0).to(healthy_target)

            tumor_types = ['early', 'medium', 'large']
            tumor_probs = np.array([0.5, 0.4, 0.1])
            synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
            print('synthetic_tumor_type', synthetic_tumor_type)

            # Re-synthesize until the classifier accepts the sample; the
            # threshold drops from 0.005 to 0.001 after more than 20 tries.
            # NOTE(review): there is no upper bound on the number of tries,
            # so this loops forever if the confidence never exceeds 0.001.
            flag = 0
            while 1:
                if synthetic_tumor_type == 'early':
                    synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler)
                elif synthetic_tumor_type == 'medium':
                    synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)
                elif synthetic_tumor_type == 'large':
                    synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)

                syn_confidence = model(synt_data).sigmoid()[:, 1]
                flag += 1
                if syn_confidence > 0.005:
                    break
                elif flag > 20 and syn_confidence > 0.001:
                    break
            val_data['image'] = synt_data.detach()
            val_data['label'] = synt_target.detach()

            # Bring the synthesized crop back onto the original image grid.
            val_data = [post_transforms(i) for i in data.decollate_batch(val_data)]
            synt_data = val_data[0]['image'][0]
            synt_target = val_data[0]['label'][0]
            final_data = raw_data[0, 0]

            # Map the [0, 1] intensities back to the [-175, 250] range and
            # paste them into the raw CT wherever the label exceeds 1.
            # NOTE(review): the paste uses > 1 while the saved mask uses
            # >= 1.5 -- confirm the two thresholds are meant to differ.
            synt_data = (synt_data*(250+175)-175)
            final_data[synt_target > 1] = synt_data[synt_target > 1]
            final_data = final_data.cpu().numpy()
            final_label = (synt_target >= 1.5).cpu().numpy().astype(np.uint8)

            nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz'))
            nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz'))

            print('time = ', time.time()-start_time)
            start_time = time.time()


if __name__ == "__main__":
    main()
Generation_Pipeline_filter_all2/syn_kidney/CT_syn_kidney_data_new2.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, csv
2
+ import numpy as np
3
+ import torch
4
+ from sklearn.metrics import confusion_matrix
5
+ from scipy import ndimage
6
+ from scipy.ndimage import label
7
+ from functools import partial
8
+ import monai
9
+ from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged
10
+ from monai import transforms, data
11
+ from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare
12
+ import nibabel as nib
13
+
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ import argparse
18
# Command-line interface of the kidney tumor synthesis script.
# Option names, defaults and types are unchanged; only help text was added so
# that `--help` explains what each value is used for further down the file.
parser = argparse.ArgumentParser(description='kidney tumor validation')

# file dir
parser.add_argument('--data_root', default=None, type=str,
                    help="root directory; each case is read from <data_root>/<case name>/")
parser.add_argument('--organ_type', default='kidney', type=str,
                    help="organ name forwarded to the tumor synthesis functions")
parser.add_argument('--save_dir', default='out', type=str,
                    help="directory the synthesized CT volumes and tumor masks are written to")
parser.add_argument('--data_file', default='out', type=str,
                    help="text file listing one case name per line")
parser.add_argument('--ddim_ts', default=50, type=int,
                    help="ddim_ts value passed to the medium/large tumor synthesis functions")
parser.add_argument('--fg_thresh', default=30, type=int,
                    help="radius whose sphere volume is the minimum foreground size wanted in a crop")
parser.add_argument('--start', default=0, type=int,
                    help="index of the first case of the list to process")
parser.add_argument('--end', default=1000, type=int,
                    help="index one past the last case of the list to process")
29
+
30
def voxel2R(A):
    """Return the radius of a sphere whose volume is ``A``.

    Inverts ``V = 4/3 * pi * R**3``. ``A`` may be a scalar or any array-like;
    it is converted with ``np.array`` and the result is computed elementwise.
    """
    volume = np.array(A)
    radius_cubed = volume / 4 * 3 / np.pi
    return radius_cubed ** (1 / 3)
32
+
33
class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld):
    """Random pos/neg crop that re-draws until the crop holds enough foreground.

    Behaves like the parent transform, but for most cases the crop is repeated
    until the number of voxels with ``label > 0`` exceeds the volume of a
    sphere of radius ``fg_thresh``. The required radius is relaxed in steps of
    5 (never below 5) as the number of attempts grows, and the search gives up
    after more than 50 attempts, keeping the last crop.
    """

    def __init__(self, keys, label_key, spatial_size,
                 pos=1.0, neg=1.0, num_samples=1,
                 image_key=None, image_threshold=0.0, allow_missing_keys=True,
                 fg_thresh=0):
        super().__init__(
            keys=keys,
            label_key=label_key,
            spatial_size=spatial_size,
            pos=pos,
            neg=neg,
            num_samples=num_samples,
            image_key=image_key,
            image_threshold=image_threshold,
            allow_missing_keys=allow_missing_keys,
        )
        # Radius used to derive the minimum foreground voxel count of a crop.
        self.fg_thresh = fg_thresh

    def R2voxel(self, R):
        """Return the volume of a sphere of radius ``R``."""
        return (4/3*np.pi)*(R)**(3)

    def __call__(self, data):
        """Crop ``data`` and return the parent's list of cropped dicts.

        The ``'name'`` entry is removed before the parent transform runs and
        written back into the first returned crop afterwards.
        """
        d = dict(data)
        data_name = d.pop('name')

        if '10_Decathlon' in data_name or '05_KiTS' in data_name:
            # These sources take the first crop without any foreground check.
            d_crop = super().__call__(d)
        else:
            attempts = 0
            while True:
                attempts += 1

                d_crop = super().__call__(d)
                pixel_num = (d_crop[0]['label'] > 0).sum()

                # Full requirement met.
                if pixel_num > self.R2voxel(self.fg_thresh):
                    break

                # After more than `step` attempts, accept a radius reduced by
                # `step` (floored at 5).
                relaxed_ok = False
                for step in (5, 10, 15, 20, 25):
                    if attempts > step and pixel_num > self.R2voxel(max(self.fg_thresh - step, 5)):
                        relaxed_ok = True
                        break
                if relaxed_ok:
                    break

                # Hard stop: keep whatever the last crop was.
                if attempts > 50:
                    break

        d_crop[0]['name'] = data_name

        return d_crop
80
+
81
def _get_loader(args):
    """Build the data loader and the inverse post-transforms.

    Reads case names from ``args.data_file`` (one per line, blank lines are
    ignored), builds per-case dictionaries pointing at ``ct.nii.gz``,
    ``segmentations/kidney_right.nii.gz`` and
    ``segmentations/kidney_tumor.nii.gz`` under ``args.data_root``, and keeps
    the cases in ``[args.start, args.end)``.

    Args:
        args: parsed command-line namespace; uses ``data_file``, ``data_root``,
            ``fg_thresh``, ``start`` and ``end``.

    Returns:
        ``(val_org_loader, post_transforms)``: a batch-size-1 loader over the
        selected cases and a ``Compose`` of ``Invertd`` transforms for the
        ``image``, ``label`` and ``tumor_label`` keys.
    """
    transform_keys = ["image", "label", 'tumor_label', "raw_image"]
    val_org_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=transform_keys),
            transforms.AddChanneld(keys=transform_keys),
            transforms.Orientationd(keys=["image", "label", 'tumor_label'], axcodes="RAS"),
            transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")),
            # Intensities in [-175, 250] are mapped to [0, 1]; main() undoes this
            # with the same constants before writing the CT back out.
            transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]),
            RandCropByPosNegLabeld_select(
                keys=["image", "label", "name"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=0,
                num_samples=1,
                image_key="image",
                image_threshold=0,
                fg_thresh=args.fg_thresh,
            ),
            transforms.ToTensord(keys=transform_keys),
        ]
    )

    data_dicts_val = []
    # `with` guarantees the list file is closed even if a path join fails.
    with open(args.data_file) as case_list:
        for line in case_list:
            name = line.strip()
            if not name:
                # A blank (e.g. trailing) line is not a case name.
                continue
            case_dir = os.path.join(args.data_root, name)
            ct_path = os.path.join(case_dir, 'ct.nii.gz')
            data_dicts_val.append({
                'image': ct_path,
                'raw_image': ct_path,
                'label': os.path.join(case_dir, 'segmentations/kidney_right.nii.gz'),
                'tumor_label': os.path.join(case_dir, 'segmentations/kidney_tumor.nii.gz'),
                'name': name,
            })

    # Slicing clamps an `end` beyond the list length, so no length check is needed.
    data_dicts_val = data_dicts_val[args.start:args.end]
    print('val len {}'.format(len(data_dicts_val)))
    val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform)
    val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True)

    post_transforms = Compose([
        Invertd(
            keys=['image'],
            transform=val_org_transform,
            orig_keys="image",
            nearest_interp=False,
            to_tensor=True,
        ),
        Invertd(
            keys=['label'],
            transform=val_org_transform,
            orig_keys="label",
            nearest_interp=False,
            to_tensor=True,
        ),
        Invertd(
            keys=['tumor_label'],
            transform=val_org_transform,
            orig_keys="tumor_label",
            nearest_interp=False,
            to_tensor=True,
        ),
    ])
    return val_org_loader, post_transforms
158
+
159
def main():
    """Synthesize a tumor into every listed case and save the CT and tumor mask.

    For each case delivered by the loader, a tumor type is drawn at random,
    a tumor is synthesized into the cropped kidney region, and the attempt is
    repeated until the classifier's sigmoid output for channel 1 on the
    synthetic crop passes a threshold (0.005, relaxed to 0.001 after more than
    20 attempts). The accepted crop is inverted back to the original geometry,
    pasted into the raw CT and written to
    ``<save_dir>/<case>/ct.nii.gz`` and
    ``<save_dir>/<case>/segmentations/kidney_tumor.nii.gz``.
    """
    args = parser.parse_args()
    output_dir = args.save_dir
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    print("MAIN Argument values:")
    for k, v in vars(args).items():
        print(k, '=>', v)
    print('-----------------')

    ## loader and post_transform
    val_loader, post_transforms = _get_loader(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth", map_location=device))
    model.eval()

    tumor_types = ['early', 'medium', 'large']
    tumor_probs = np.array([0.5, 0.4, 0.1])

    start_time = 0
    with torch.no_grad():
        # The generative models depend only on fixed arguments, so they are
        # prepared once here instead of being re-created for every case.
        vqgan, early_sampler, noearly_sampler = synt_model_prepare(device=torch.device("cuda"), fold=0, organ=args.organ_type)

        for idx, val_data in enumerate(val_loader):
            print('idx', idx)
            if idx == 0:
                start_time = time.time()

            healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image']
            case_name = data_names[0].split('/')[-1]
            print('case_name', case_name)
            original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy()

            if healthy_target.sum() == 0:
                # Nothing labelled in this crop, so there is nowhere to place a tumor.
                continue

            healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda()
            healthy_target = (healthy_target > 0).to(healthy_target)

            synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel())
            print('synthetic_tumor_type', synthetic_tumor_type)
            flag = 0
            # NOTE(review): this loop has no upper bound; it only ends once the
            # classifier output exceeds one of the thresholds below.
            while 1:
                if synthetic_tumor_type == 'early':
                    synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler)
                elif synthetic_tumor_type == 'medium':
                    synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)
                elif synthetic_tumor_type == 'large':
                    synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts)

                syn_confidence = model(synt_data).sigmoid()[:, 1]
                flag += 1
                if syn_confidence > 0.005:
                    break
                elif flag > 20 and syn_confidence > 0.001:
                    break
            val_data['image'] = synt_data.detach()
            val_data['label'] = synt_target.detach()

            # Undo the spatial preprocessing so results line up with the raw CT.
            val_data = [post_transforms(i) for i in data.decollate_batch(val_data)]
            synt_data = val_data[0]['image'][0]
            synt_target = val_data[0]['label'][0]
            tumor_mask = val_data[0]['tumor_label'][0]
            final_data = raw_data[0, 0]

            # Reverse the [-175, 250] -> [0, 1] intensity scaling of the loader.
            synt_data = (synt_data*(250+175)-175)
            final_data[synt_target > 1] = synt_data[synt_target > 1]
            final_data = final_data.cpu().numpy()
            final_label = (synt_target >= 1.5).cpu().numpy().astype(np.uint8)
            # Voxels marked 1 in the case's existing tumor mask stay in the output mask.
            final_label[tumor_mask == 1] = 1

            os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True)
            os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True)
            nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz'))
            nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz'))

            print('time = ', time.time()-start_time)
            start_time = time.time()
244
+
245
+
246
+ if __name__ == "__main__":
247
+ main()
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ```bash
2
+ wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll
3
+ mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz
4
+ tar -xzvf model_weight.tar.gz
5
+ ```
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/TumorGenerated.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Hashable, Mapping, Dict
3
+
4
+ from monai.config import KeysCollection
5
+ from monai.config.type_definitions import NdarrayOrTensor
6
+ from monai.transforms.transform import MapTransform, RandomizableTransform
7
+
8
+ from .utils_ import SynthesisTumor
9
+ import numpy as np
10
+
11
class TumorGenerated(RandomizableTransform, MapTransform):
    """Dictionary transform that randomly synthesizes a tumor into a sample.

    With probability ``prob`` and only when the maximum of ``d['label']`` is
    at most 1, a tumor type is drawn from ``tumor_prob`` and
    ``SynthesisTumor`` rewrites channel 0 of ``d['image']`` and ``d['label']``.

    Args:
        keys: keys handed to ``MapTransform``.
        prob: probability of applying the transform.
        tumor_prob: five probabilities, in the order
            ``['tiny', 'small', 'medium', 'large', 'mix']``.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(self,
                 keys: KeysCollection,
                 prob: float = 0.1,
                 # A tuple rather than a list: default values are shared across calls.
                 tumor_prob=(0.2, 0.2, 0.2, 0.2, 0.2),
                 allow_missing_keys: bool = False
                 ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        # NOTE: constructing this transform reseeds the global `random` and
        # NumPy generators.
        random.seed(0)
        np.random.seed(0)

        self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix']

        assert len(tumor_prob) == 5, "tumor_prob needs one probability per tumor type (5)"
        self.tumor_prob = np.array(tumor_prob)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        """Return a shallow copy of ``data``, with a tumor synthesized when selected."""
        d = dict(data)
        self.randomize(None)

        if self._do_transform and (np.max(d['label']) <= 1):
            tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel())

            d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type)
        return d
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ### Online Version TumorGeneration ###
2
+
3
+ from .TumorGenerated import TumorGenerated
4
+
5
+ from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc ADDED
Binary file (1.64 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (320 Bytes). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc ADDED
Binary file (11 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc ADDED
Binary file (7.22 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ vqgan_ckpt: None
2
+
3
+ # Have to be derived from VQ-GAN Latent space dimensions
4
+ diffusion_img_size: 24
5
+ diffusion_depth_size: 24
6
+ diffusion_num_channels: 17 # 17
7
+ out_dim: 8
8
+ dim_mults: [1,2,4,8]
9
+ results_folder: checkpoints/ddpm/
10
+ results_folder_postfix: 'own_dataset_t2'
11
+ load_milestone: False # False
12
+
13
+ batch_size: 2 # 40
14
+ num_workers: 20
15
+ logger: wandb
16
+ objective: pred_x0
17
+ save_and_sample_every: 1000
18
+ denoising_fn: Unet3D
19
+ train_lr: 1e-4
20
+ timesteps: 2 # number of steps
21
+ sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
22
+ loss_type: l1 # L1 or L2
23
+ train_num_steps: 700000 # total training steps
24
+ gradient_accumulate_every: 2 # gradient accumulation steps
25
+ ema_decay: 0.995 # exponential moving average decay
26
+ amp: False # turn on mixed precision
27
+ num_sample_rows: 1
28
+ gpus: 0
29
+
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 1234
2
+ batch_size: 2 # 30
3
+ num_workers: 32 # 30
4
+
5
+ gpus: 1
6
+ accumulate_grad_batches: 1
7
+ default_root_dir: checkpoints/vq_gan/
8
+ default_root_dir_postfix: 'flair'
9
+ resume_from_checkpoint:
10
+ max_steps: -1
11
+ max_epochs: -1
12
+ precision: 16
13
+ gradient_clip_val: 1.0
14
+
15
+
16
+ embedding_dim: 8 # 256
17
+ n_codes: 16384 # 2048
18
+ n_hiddens: 16
19
+ lr: 3e-4
20
+ downsample: [2, 2, 2] # [4, 4, 4]
21
+ disc_channels: 64
22
+ disc_layers: 3
23
+ discriminator_iter_start: 10000 # 50000
24
+ disc_loss_type: hinge
25
+ image_gan_weight: 1.0
26
+ video_gan_weight: 1.0
27
+ l1_weight: 4.0
28
+ gan_feat_weight: 4.0 # 0.0
29
+ perceptual_weight: 4.0 # 0.0
30
+ i3d_feat: False
31
+ restart_thres: 1.0
32
+ no_random_restart: False
33
+ norm_type: group
34
+ padding_type: replicate
35
+ num_groups: 32
36
+
37
+
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .diffusion import Unet3D, GaussianDiffusion, Tester
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (258 Bytes). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (6.03 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc ADDED
Binary file (28.5 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc ADDED
Binary file (2.85 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc ADDED
Binary file (5.77 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc ADDED
Binary file (9.44 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
class DDIMSampler(object):
    """DDIM sampler wrapped around a diffusion model (sampling only).

    The wrapped ``model`` must expose ``num_timesteps``, ``betas``,
    ``alphas_cumprod``, ``alphas_cumprod_prev``, ``device``, ``denoise_fn``
    and ``q_sample``, all of which are read by the methods below.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` on the sampler as ``name``; tensors are moved to CUDA first."""
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):  # "uniform" 'quad'
        """Compute and store the DDIM timestep subset and its coefficients.

        Args:
            ddim_num_steps: number of DDIM steps to select.
            ddim_discretize: discretization method passed to ``make_ddim_timesteps``.
            ddim_eta: eta passed to ``make_ddim_sampling_parameters``.
            verbose: forwarded to the helper functions.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)

        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Build the schedule for ``S`` steps and run DDIM sampling.

        Args:
            S: number of DDIM steps.
            batch_size: number of samples; only a warning is printed when it
                disagrees with the first dimension of ``conditioning``.
            shape: ``(C, T, H, W)`` of one sample (without the batch dimension).

        The remaining arguments are forwarded to ``ddim_sampling``.

        Returns:
            ``(samples, intermediates)`` as returned by ``ddim_sampling``.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, T, H, W = shape
        size = (batch_size, C, T, H, W)

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Iterate the sampler from the last selected timestep down to the first.

        Starts from ``x_T`` (or Gaussian noise of ``shape`` when ``x_T`` is
        None). When ``mask`` is given, ``x0`` is required and, before every
        step, the current sample is replaced inside the mask by a noised copy
        of ``x0`` obtained from ``model.q_sample``.

        Returns:
            ``(img, intermediates)`` where ``intermediates`` holds the lists
            ``'x_inter'`` and ``'pred_x0'`` recorded every ``log_every_t``
            indices and at the first processed step.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]

        for i, step in enumerate(time_range):
            # `index` addresses the coefficient arrays; it counts down while `i` counts up.
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """Perform one DDIM update.

        Args:
            x: current sample; the first dimension is the batch.
            c: conditioning passed to ``model.denoise_fn``.
            t: per-sample timestep tensor.
            index: position of this step in the coefficient arrays.

        Returns:
            ``(x_prev, pred_x0)``: the sample for the next step and the
            current estimate of the clean sample.
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.denoise_fn(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            # corrector_kwargs defaults to None, which cannot be unpacked with **.
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        # One singleton per non-batch dimension of x, so the coefficients
        # broadcast for 4-D and 5-D inputs alike. A fixed (b, 1, 1, 1) shape
        # misaligns the batch axis against (B, C, T, H, W) when B > 1.
        coef_shape = (b,) + (1,) * (x.dim() - 1)
        a_t = torch.full(coef_shape, alphas[index], device=device)
        a_prev = torch.full(coef_shape, alphas_prev[index], device=device)
        sigma_t = torch.full(coef_shape, sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(coef_shape, sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py ADDED
@@ -0,0 +1,1016 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import math
4
+ import copy
5
+ import torch
6
+ from torch import nn, einsum
7
+ import torch.nn.functional as F
8
+ from functools import partial
9
+
10
+ from torch.utils import data
11
+ from pathlib import Path
12
+ from torch.optim import Adam
13
+ from torchvision import transforms as T, utils
14
+ from torch.cuda.amp import autocast, GradScaler
15
+ from PIL import Image
16
+
17
+ from tqdm import tqdm
18
+ from einops import rearrange
19
+ from einops_exts import check_shape, rearrange_many
20
+
21
+ from rotary_embedding_torch import RotaryEmbedding
22
+
23
+ from .text import tokenize, bert_embed, BERT_MODEL_DIM
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from ..vq_gan_3d.model.vqgan import VQGAN
26
+
27
+ import matplotlib.pyplot as plt
28
+
29
+ # helpers functions
30
+
31
+
32
def exists(x):
    """Return True for every value except ``None``."""
    return not (x is None)
34
+
35
+
36
def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``."""
    return None
38
+
39
+
40
def is_odd(n):
    """Return whether ``n`` leaves a remainder of exactly 1 when divided by 2."""
    remainder = n % 2
    return remainder == 1
42
+
43
+
44
def default(val, d):
    """Return ``val`` unless it is missing, otherwise the fallback ``d``.

    A callable fallback is invoked (with no arguments) and its result is
    returned, so expensive defaults are only built when actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
48
+
49
+
50
def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it after each full pass."""
    while True:
        yield from dl
54
+
55
+
56
def num_to_groups(num, divisor):
    """Split ``num`` into chunks of ``divisor`` plus one final smaller chunk.

    Returns a list of group sizes that sums to ``num``; the trailing entry is
    the remainder and is only present when it is positive.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor] * full_groups
    if leftover > 0:
        sizes.append(leftover)
    return sizes
63
+
64
+
65
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False
    masks without drawing random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    # Uniform draws in [0, 1); an entry is kept when its draw falls below prob.
    draws = torch.zeros(shape, device=device).float().uniform_(0, 1)
    return draws < prob
72
+
73
+
74
def is_list_str(x):
    """Return True when ``x`` is a list or tuple containing only exact ``str`` items.

    An empty list or tuple counts as True; any other container type is False.
    """
    is_sequence = isinstance(x, (list, tuple))
    return is_sequence and all(type(el) == str for el in x)
78
+
79
+ # relative positional bias
80
+
81
+
82
class RelativePositionBias(nn.Module):
    """Learned attention bias indexed by bucketed relative position.

    Relative offsets between query and key positions are mapped to one of
    ``num_buckets`` buckets, and every bucket has one learned value per head.
    """

    def __init__(
        self,
        heads=8,
        num_buckets=32,
        max_distance=128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        """Map a tensor of relative positions to integer bucket ids.

        Half of the buckets serve one sign of the offset and half the other.
        Within each half, small distances get one bucket each and larger
        distances share logarithmically spaced buckets, capped at the last one.
        """
        n = -relative_position

        half_buckets = num_buckets // 2
        # Offsets with n < 0 are shifted into the second half of the bucket range.
        sign_offset = (n < 0).long() * half_buckets
        distance = torch.abs(n)

        max_exact = half_buckets // 2
        is_small = distance < max_exact

        log_bucket = max_exact + (
            torch.log(distance.float() / max_exact) / math.log(max_distance /
                                                               max_exact) * (half_buckets - max_exact)
        ).long()
        log_bucket = torch.min(
            log_bucket, torch.full_like(log_bucket, half_buckets - 1))

        return sign_offset + torch.where(is_small, distance, log_bucket)

    def forward(self, n, device):
        """Return the bias for a sequence of length ``n``, laid out as (heads, i, j)."""
        positions_q = torch.arange(n, dtype=torch.long, device=device)
        positions_k = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(positions_k, 'j -> 1 j') - rearrange(positions_q, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(
            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
124
+
125
+ # small helper modules
126
+
127
+
128
class EMA():
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place."""
        param_pairs = zip(current_model.parameters(), ma_model.parameters())
        for live_param, avg_param in param_pairs:
            avg_param.data = self.update_average(avg_param.data, live_param.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``; ``new`` when ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
142
+
143
+
144
class Residual(nn.Module):
    """Wrap ``fn`` so that its output is added to its (first) input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded untouched to ``fn``.
        branch = self.fn(x, *args, **kwargs)
        return branch + x
151
+
152
+
153
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions / timesteps.

    The output has ``2 * (dim // 2)`` features: the sines of the scaled
    input followed by the cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (indexed as ``x[:, None]``, i.e. one value per row)."""
        half_dim = self.dim // 2
        # The original divided by ``half_dim - 1``, which is zero for
        # dim in (2, 3) and raised ZeroDivisionError. Clamping the divisor
        # leaves every dim >= 4 unchanged and makes the tiny case usable.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
166
+
167
+
168
def Upsample(dim):
    """Return a transposed 3-D conv with kernel (1, 4, 4) and stride (1, 2, 2).

    The first kernel/stride axis is 1, so only the last two axes are resized.
    """
    return nn.ConvTranspose3d(
        dim,
        dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
170
+
171
+
172
def Downsample(dim):
    """Return a strided 3-D conv with kernel (1, 4, 4) and stride (1, 2, 2).

    The first kernel/stride axis is 1, so only the last two axes are halved.
    """
    return nn.Conv3d(
        dim,
        dim,
        kernel_size=(1, 4, 4),
        stride=(1, 2, 2),
        padding=(0, 1, 1),
    )
174
+
175
+
176
class LayerNorm(nn.Module):
    """Normalise a 5-D tensor over dim 1 with a learned per-channel gain.

    There is no learned bias; ``gamma`` has shape ``(1, dim, 1, 1, 1)``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        mean = torch.mean(x, dim=1, keepdim=True)
        # Biased (population) variance over dim 1.
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        normalised = (x - mean) / (var + self.eps).sqrt()
        return normalised * self.gamma
186
+
187
+
188
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        # Keyword arguments are passed through to the wrapped module.
        return self.fn(self.norm(x), **kwargs)
197
+
198
+ # building block modules
199
+
200
+
201
class Block(nn.Module):
    """Conv3d (kernel (1, 3, 3)) -> GroupNorm -> optional scale/shift -> SiLU."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Run the block; ``scale_shift`` is an optional ``(scale, shift)`` pair."""
        normed = self.norm(self.proj(x))

        if exists(scale_shift):
            scale, shift = scale_shift
            # Modulation applied after normalisation, before the activation.
            normed = normed * (scale + 1) + shift

        return self.act(normed)
217
+
218
+
219
class ResnetBlock(nn.Module):
    """Two ``Block`` layers with a residual connection and optional conditioning.

    When ``time_emb_dim`` is given, an MLP turns the embedding into a
    (scale, shift) pair that modulates the first block only.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        # A 1x1x1 conv matches channel counts on the skip path when needed.
        if dim != dim_out:
            self.res_conv = nn.Conv3d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            projected = rearrange(self.mlp(time_emb), 'b c -> b c 1 1 1')
            # First half of the channels is the scale, second half the shift.
            scale_shift = projected.chunk(2, dim=1)

        hidden = self.block1(x, scale_shift=scale_shift)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
245
+
246
+
247
class SpatialLinearAttention(nn.Module):
    """Linear attention over the last two axes, applied per slice of axis 2.

    The input is expected as ``(b, c, f, h, w)``; axis ``f`` is folded into
    the batch so that a 2-D convolutional attention can be used.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        batch, _, _, height, width = x.shape
        folded = rearrange(x, 'b c f h w -> (b f) c h w')

        projections = self.to_qkv(folded).chunk(3, dim=1)
        q, k, v = rearrange_many(
            projections, 'b (h c) x y -> b h c (x y)', h=self.heads)

        # Softmax over the feature axis for queries, over positions for keys.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # Aggregate keys/values first so the cost is linear in positions.
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)

        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y',
                             h=self.heads, x=height, y=width)
        projected = self.to_out(attended)
        return rearrange(projected, '(b f) c h w -> b c f h w', b=batch)
275
+
276
+ # attention along space and time
277
+
278
+
279
class EinopsToAndFrom(nn.Module):
    """Rearrange the input to another einops layout, apply ``fn``, rearrange back.

    ``from_einops`` must be a space-separated list of plain axis names so
    that the input's axis sizes can be recorded for the inverse rearrange.
    """

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        # Remember every named axis size so grouped axes can be split again.
        axis_sizes = dict(zip(self.from_einops.split(' '), x.shape))

        inner = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        inner = self.fn(inner, **kwargs)
        return rearrange(
            inner, f'{self.to_einops} -> {self.from_einops}', **axis_sizes)
295
+
296
+
297
class Attention(nn.Module):
    """Multi-head softmax attention over the second-to-last axis of the input.

    Supports an additive positional bias, optional rotary embeddings, and a
    per-sample "focus on present" mask that restricts a sample's tokens to
    attending only to themselves.
    """

    def __init__(self, dim, heads=4, dim_head=32, rotary_emb=None):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x, pos_bias=None, focus_present_mask=None):
        n, device = x.shape[-2], x.device
        has_focus_mask = exists(focus_present_mask)

        qkv = self.to_qkv(x).chunk(3, dim=-1)

        if has_focus_mask and focus_present_mask.all():
            # Every sample attends only to itself, so attention reduces to
            # passing each token's values straight to the output projection.
            return self.to_out(qkv[-1])

        # Split out heads.
        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads)

        q = q * self.scale

        # Rotary position information on queries and keys, when configured.
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        sim = einsum('... h i d, ... h j d -> ... h i j', q, k)

        if exists(pos_bias):
            sim = sim + pos_bias

        # Mixed batch: masked samples get an identity mask, others attend to all.
        if has_focus_mask and not (~focus_present_mask).all():
            attend_all = torch.ones((n, n), device=device, dtype=torch.bool)
            attend_self = torch.eye(n, device=device, dtype=torch.bool)

            allowed = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(attend_self, 'i j -> 1 1 1 i j'),
                rearrange(attend_all, 'i j -> 1 1 1 i j'),
            )

            sim = sim.masked_fill(~allowed, -torch.finfo(sim.dtype).max)

        # Subtract the row maximum for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # Aggregate values and merge heads.
        out = einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        return self.to_out(out)
376
+
377
+ # model
378
+
379
+
380
class Unet3D(nn.Module):
    """U-Net over 5-D ``(b, c, f, h, w)`` tensors with attention along ``f`` and ``(h, w)``.

    ``forward`` concatenates ``cond`` to the input along the channel axis
    before the first convolution, so ``channels`` must already account for
    the conditioning channels.  ``block_type`` is accepted for interface
    compatibility but is not read by this implementation.
    """

    def __init__(
        self,
        dim,
        cond_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        attn_heads=8,
        attn_dim_head=32,
        use_bert_text_cond=False,
        init_dim=None,
        init_kernel_size=7,
        use_sparse_linear_attn=True,
        block_type='resnet',
        resnet_groups=8
    ):
        super().__init__()
        self.channels = channels

        # NOTE: submodules are created in the same order as before so that
        # parameter registration and random initialisation are unchanged.

        # Attention along the frame axis shares a single rotary embedding.
        rotary_emb = RotaryEmbedding(min(32, attn_dim_head))

        def temporal_attn(width):
            attention = Attention(
                width, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)
            return EinopsToAndFrom('b c f h w', 'b (h w) f c', attention)

        def sparse_spatial_attn(width):
            if not use_sparse_linear_attn:
                return nn.Identity()
            return Residual(PreNorm(width, SpatialLinearAttention(
                width, heads=attn_heads)))

        # Relative positional bias along the frame axis.
        self.time_rel_pos_bias = RelativePositionBias(
            heads=attn_heads, max_distance=32)

        # Initial convolution.
        init_dim = default(init_dim, dim)
        assert is_odd(init_kernel_size)

        init_padding = init_kernel_size // 2
        self.init_conv = nn.Conv3d(
            channels,
            init_dim,
            (1, init_kernel_size, init_kernel_size),
            padding=(0, init_padding, init_padding),
        )

        self.init_temporal_attn = Residual(
            PreNorm(init_dim, temporal_attn(init_dim)))

        # Channel widths per resolution and the (in, out) pair of each stage.
        dims = [init_dim, *(dim * mult for mult in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Timestep conditioning.
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # Optional embedding conditioning (classifier-free guidance).
        self.has_cond = exists(cond_dim) or use_bert_text_cond
        cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim

        self.null_cond_emb = nn.Parameter(
            torch.randn(1, cond_dim)) if self.has_cond else None

        cond_dim = time_dim + int(cond_dim or 0)

        # Layers.
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        num_resolutions = len(in_out)

        block_klass = partial(ResnetBlock, groups=resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim=cond_dim)

        # Encoder stages.
        for stage, (dim_in, dim_out) in enumerate(in_out):
            final_stage = stage >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                sparse_spatial_attn(dim_out),
                Residual(PreNorm(dim_out, temporal_attn(dim_out))),
                Downsample(dim_out) if not final_stage else nn.Identity()
            ]))

        # Bottleneck.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)

        spatial_attn = EinopsToAndFrom(
            'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads))

        self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn))
        self.mid_temporal_attn = Residual(
            PreNorm(mid_dim, temporal_attn(mid_dim)))

        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)

        # Decoder stages (mirror of the encoder).
        for stage, (dim_in, dim_out) in enumerate(reversed(in_out)):
            final_stage = stage >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                sparse_spatial_attn(dim_in),
                Residual(PreNorm(dim_in, temporal_attn(dim_in))),
                Upsample(dim_in) if not final_stage else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim * 2, dim),
            nn.Conv3d(dim, out_dim, 1)
        )

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale=2.,
        **kwargs
    ):
        """Run the model with classifier-free guidance of strength ``cond_scale``."""
        conditioned = self.forward(*args, null_cond_prob=0., **kwargs)
        if cond_scale == 1 or not self.has_cond:
            return conditioned

        unconditioned = self.forward(*args, null_cond_prob=1., **kwargs)
        return unconditioned + (conditioned - unconditioned) * cond_scale

    def forward(
        self,
        x,
        time,
        cond=None,
        null_cond_prob=0.,
        focus_present_mask=None,
        # probability at which a given batch sample will focus on the present
        # (0. is all off, 1. is completely arrested attention across time)
        prob_focus_present=0.
    ):
        assert (not self.has_cond) or exists(cond), \
            'cond must be passed in if cond_dim specified'
        # ``cond`` is concatenated on the channel axis, so it must be a tensor
        # here even when ``has_cond`` is False.
        x = torch.cat([x, cond], dim=1)

        batch, device = x.shape[0], x.device

        focus_present_mask = default(focus_present_mask, lambda: prob_mask_like(
            (batch,), prob_focus_present, device=device))

        time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device)

        x = self.init_conv(x)
        residual = x.clone()  # reused by the final skip connection

        x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Classifier-free guidance: randomly swap cond for the null embedding.
        if self.has_cond:
            batch, device = x.shape[0], x.device
            null_mask = prob_mask_like((batch,), null_cond_prob, device=device)
            cond = torch.where(rearrange(null_mask, 'b -> b 1'),
                               self.null_cond_emb, cond)
            t = torch.cat((t, cond), dim=-1)

        skips = []

        for block1, block2, spatial_attn, temporal_attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            skips.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_spatial_attn(x)
        x = self.mid_temporal_attn(
            x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask)
        x = self.mid_block2(x, t)

        for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
            # Skip connections are consumed in reverse (last saved, first used).
            x = torch.cat((x, skips.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = spatial_attn(x)
            x = temporal_attn(x, pos_bias=time_rel_pos_bias,
                              focus_present_mask=focus_present_mask)
            x = upsample(x)

        x = torch.cat((x, residual), dim=1)
        return self.final_conv(x)
581
+
582
+ # gaussian diffusion trainer class
583
+
584
+
585
def extract(a, t, x_shape):
    """Gather ``a[t]`` per batch entry and reshape it to broadcast against ``x_shape``.

    The result has shape ``(b, 1, ..., 1)`` with ``len(x_shape) - 1``
    trailing singleton axes, where ``b`` is the leading size of ``t``.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_axes)
589
+
590
+
591
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor of ``timesteps`` betas clipped to [0, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    alphas_cumprod = torch.cos(phase) ** 2
    # Normalise so that the cumulative product starts at exactly 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    step_ratio = alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(1 - step_ratio, 0, 0.9999)
603
+
604
+
605
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion wrapper around a noise-predicting ``denoise_fn``.

    Precomputes the cosine noise schedule and the derived posterior
    coefficients as float32 buffers, and provides the forward process
    (``q_sample``), the reverse process (``p_sample`` / ``p_sample_loop`` /
    ``sample``) and the training loss (``forward`` / ``p_losses``).

    When ``vqgan_ckpt`` is given, a VQGAN is loaded and used to encode
    inputs to latents for training and to decode samples.
    """

    def __init__(
        self,
        denoise_fn,
        *,
        image_size,
        num_frames,
        text_use_bert_cls=False,
        channels=3,
        timesteps=1000,
        loss_type='l1',
        use_dynamic_thres=False,  # from the Imagen paper
        dynamic_thres_percentile=0.9,
        vqgan_ckpt=None,
        device=None
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.num_frames = num_frames
        self.denoise_fn = denoise_fn
        self.device = device

        if vqgan_ckpt:
            # NOTE(review): the VQGAN is always moved to CUDA regardless of
            # ``device``; kept as-is to preserve behaviour.
            self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda()
            self.vqgan.eval()
        else:
            self.vqgan = None

        betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        def register_f32(name, val):
            # The schedule is computed in float64; buffers are stored as float32.
            self.register_buffer(name, val.to(torch.float32))

        register_f32('betas', betas)
        register_f32('alphas_cumprod', alphas_cumprod)
        register_f32('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_f32('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_f32('sqrt_one_minus_alphas_cumprod',
                     torch.sqrt(1. - alphas_cumprod))
        register_f32('log_one_minus_alphas_cumprod',
                     torch.log(1. - alphas_cumprod))
        register_f32('sqrt_recip_alphas_cumprod',
                     torch.sqrt(1. / alphas_cumprod))
        register_f32('sqrt_recipm1_alphas_cumprod',
                     torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        posterior_variance = betas * \
            (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        register_f32('posterior_variance', posterior_variance)

        # log is clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        register_f32('posterior_log_variance_clipped',
                     torch.log(posterior_variance.clamp(min=1e-20)))
        register_f32('posterior_mean_coef1', betas *
                     torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_f32('posterior_mean_coef2', (1. - alphas_cumprod_prev)
                     * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # text conditioning parameters
        self.text_use_bert_cls = text_use_bert_cls

        # dynamic thresholding when sampling
        self.use_dynamic_thres = use_dynamic_thres
        self.dynamic_thres_percentile = dynamic_thres_percentile

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_0)."""
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(
            self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the x_0 estimate from ``x_t`` and the predicted noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.):
        """Return the model's mean / variance / log-variance for the step at ``t``."""
        predicted_noise = self.denoise_fn.forward_with_cond_scale(
            x, t, cond=cond, cond_scale=cond_scale)
        x_recon = self.predict_start_from_noise(x, t=t, noise=predicted_noise)

        if clip_denoised:
            s = 1.
            if self.use_dynamic_thres:
                # Per-sample quantile of |x_0|, never below 1.
                s = torch.quantile(
                    rearrange(x_recon, 'b ... -> b (...)').abs(),
                    self.dynamic_thres_percentile,
                    dim=-1
                )

                s.clamp_(min=1.)
                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))

            # clip by threshold, depending on whether static or dynamic
            x_recon = x_recon.clamp(-s, s) / s

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.inference_mode()
    def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True):
        """Draw x_{t-1} from the model given ``x`` at timestep ``t``."""
        b = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(
            b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.inference_mode()
    def p_sample_loop(self, shape, cond=None, cond_scale=1.):
        """Run the full reverse chain starting from Gaussian noise of ``shape``."""
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)
        for i in reversed(range(0, self.num_timesteps)):
            step = torch.full((b,), i, device=device, dtype=torch.long)
            img = self.p_sample(img, step, cond=cond, cond_scale=cond_scale)

        return img

    @torch.inference_mode()
    def sample(self, cond=None, cond_scale=1., batch_size=16):
        """Generate ``batch_size`` samples, decoding through the VQGAN if present."""
        device = next(self.denoise_fn.parameters()).device

        if is_list_str(cond):
            cond = bert_embed(tokenize(cond)).to(device)

        image_size = self.image_size
        # NOTE(review): hard-coded to 8 instead of self.channels, presumably
        # the latent channel count of the VQGAN in use - confirm before changing.
        channels = 8
        num_frames = self.num_frames
        _sample = self.p_sample_loop(
            (batch_size, channels, num_frames, image_size, image_size),
            cond=cond, cond_scale=cond_scale)

        if isinstance(self.vqgan, VQGAN):
            # denormalize TODO: Remove eventually
            emb_min = self.vqgan.codebook.embeddings.min()
            emb_max = self.vqgan.codebook.embeddings.max()
            _sample = (((_sample + 1.0) / 2.0) * (emb_max - emb_min)) + emb_min

            _sample = self.vqgan.decode(_sample, quantize=True)
        else:
            # Fix: unnormalize_img is not in-place; its result was previously
            # discarded, so samples were returned still normalised.
            _sample = unnormalize_img(_sample)

        return _sample

    @torch.inference_mode()
    def interpolate(self, x1, x2, t=None, lam=0.5):
        """Noise ``x1``/``x2`` to step ``t``, blend with weight ``lam``, then denoise."""
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(img, torch.full(
                (b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t from q(x_t | x_0); fresh Gaussian noise is drawn if none is given."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod,
                    t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond=None, noise=None, **kwargs):
        """Return the L1 or L2 loss between the true and the predicted noise."""
        device = x_start.device
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        if is_list_str(cond):
            cond = bert_embed(
                tokenize(cond), return_cls_repr=self.text_use_bert_cls)
            cond = cond.to(device)

        x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs)

        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, x, *args, **kwargs):
        """Compute the training loss for a stacked ``[images; masks]`` batch.

        The first half of ``x`` along dim 0 is taken as images and the second
        half as masks.  The masked image and the resized mask are concatenated
        on the channel axis and passed to the denoiser as ``cond``.
        """
        bs = int(x.shape[0] / 2)
        img = x[:bs, ...]
        mask = x[bs:, ...]
        mask_ = (1 - mask).detach()
        masked_img = (img * mask_).detach()

        # Move the last axis in front of the other two trailing axes.
        masked_img = masked_img.permute(0, 1, -1, -3, -2)
        img = img.permute(0, 1, -1, -3, -2)
        mask = mask.permute(0, 1, -1, -3, -2)

        if isinstance(self.vqgan, VQGAN):
            with torch.no_grad():
                emb_min = self.vqgan.codebook.embeddings.min()
                emb_max = self.vqgan.codebook.embeddings.max()

                img = self.vqgan.encode(
                    img, quantize=False, include_embeddings=True)
                # normalize to -1 and 1
                img = ((img - emb_min) / (emb_max - emb_min)) * 2.0 - 1.0

                masked_img = self.vqgan.encode(
                    masked_img, quantize=False, include_embeddings=True)
                # normalize to -1 and 1
                masked_img = ((masked_img - emb_min) /
                              (emb_max - emb_min)) * 2.0 - 1.0
        else:
            # (a leftover debug ``print("Hi")`` was removed from this branch)
            img = normalize_img(img)
            masked_img = normalize_img(masked_img)

        mask = mask * 2.0 - 1.0
        # Resize the mask to the (possibly latent) resolution of masked_img.
        cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:])
        cond = torch.cat((masked_img, cc), dim=1)

        b, device = img.shape[0], img.device
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        return self.p_losses(img, t, cond=cond, *args, **kwargs)
875
+
876
+ # gif / image helper utilities
877
+
878
+
879
# Channel count -> PIL mode string used when converting frames.
CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}


def seek_all_images(img, channels=3):
    """Yield every frame of a multi-frame image, converted to the mode for ``channels``.

    Frames are visited by calling ``img.seek`` with increasing indices until
    an ``EOFError`` signals the end of the sequence.
    """
    assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
    mode = CHANNELS_TO_MODE[channels]

    frame_index = 0
    while True:
        try:
            img.seek(frame_index)
            yield img.convert(mode)
        except EOFError:
            return
        frame_index += 1
898
+
899
+ # tensor of shape (channels, frames, height, width) -> gif
900
+
901
+
902
def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True):
    """Save a ``(channels, frames, height, width)`` tensor as an animated gif.

    The tensor is min-max rescaled to [0, 1] before conversion.  Returns the
    list of PIL frames that were written.

    Fix: the frames used to be held in a ``map`` iterator that was consumed
    by the unpacking below, so the returned object was always empty.  They
    are now materialised into a list first.
    """
    # NOTE(review): a constant tensor makes the denominator zero - left as-is.
    tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0
    images = list(map(T.ToPILImage(), tensor.unbind(dim=1)))
    first_img, *rest_imgs = images
    first_img.save(path, save_all=True, append_images=rest_imgs,
                   duration=duration, loop=loop, optimize=optimize)
    return images
909
+
910
+ # gif -> (channels, frame, height, width) tensor
911
+
912
+
913
def gif_to_tensor(path, channels=3, transform=T.ToTensor()):
    """Load a gif into a tensor with its frames stacked along dim 1.

    Each frame is converted to the PIL mode matching ``channels`` and passed
    through ``transform``.  The image file is opened in a ``with`` block so
    its handle is released once all frames are read (it was previously
    never closed).
    """
    with Image.open(path) as img:
        tensors = tuple(map(transform, seek_all_images(img, channels=channels)))
    return torch.stack(tensors, dim=1)
917
+
918
+
919
def identity(t, *args, **kwargs):
    """Return ``t`` unchanged, ignoring any additional arguments."""
    return t
921
+
922
+
923
def normalize_img(t):
    """Affinely map values from [0, 1] to [-1, 1]."""
    doubled = t * 2
    return doubled - 1
925
+
926
+
927
def unnormalize_img(t):
    """Affinely map values from [-1, 1] back to [0, 1] (returns a new value)."""
    shifted = t + 1
    return shifted * 0.5
929
+
930
+
931
def cast_num_frames(t, *, frames):
    """Force dim 1 of ``t`` to have exactly ``frames`` entries.

    Longer inputs are truncated; shorter ones are zero-padded at the end of
    the third-from-last axis (which is dim 1 for a 4-D tensor).
    """
    current = t.shape[1]

    if current > frames:
        return t[:, :frames]

    if current < frames:
        return F.pad(t, (0, 0, 0, 0, 0, frames - current))

    return t
941
+
942
+
943
class Dataset(data.Dataset):
    """Dataset of gif files found recursively under ``folder``.

    Each item is the tensor returned by ``gif_to_tensor`` for one file,
    optionally truncated / zero-padded to ``num_frames`` frames.
    """

    def __init__(
        self,
        folder,
        image_size,
        channels=3,
        num_frames=16,
        horizontal_flip=False,
        force_num_frames=True,
        exts=['gif']
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.channels = channels

        # Collect matching files, extension by extension, in glob order.
        root = Path(f'{folder}')
        self.paths = []
        for ext in exts:
            self.paths.extend(root.glob(f'**/*.{ext}'))

        if force_num_frames:
            self.cast_num_frames_fn = partial(cast_num_frames, frames=num_frames)
        else:
            self.cast_num_frames_fn = identity

        flip = T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity)
        self.transform = T.Compose([
            T.Resize(image_size),
            flip,
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        frames = gif_to_tensor(
            self.paths[index], self.channels, transform=self.transform)
        return self.cast_num_frames_fn(frames)
978
+
979
+ # trainer class
980
+
981
+
982
class Tester(object):
    """Holds a diffusion model plus an EMA copy and restores both from a checkpoint."""

    def __init__(
        self,
        diffusion_model,
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema_model = copy.deepcopy(self.model)
        self.step = 0
        self.image_size = diffusion_model.image_size

        self.reset_parameters()

    def reset_parameters(self):
        """Copy the current model weights into the EMA model."""
        self.ema_model.load_state_dict(self.model.state_dict())

    def load(self, milestone, map_location=None, **kwargs):
        """Load ``step``, ``model`` and ``ema`` from a checkpoint file.

        ``milestone`` is the checkpoint path.  Passing ``-1`` selects, from
        all ``*.pt`` files under ``self.results_folder``, the one whose file
        stem ends in the largest ``-<int>`` suffix.  ``self.results_folder``
        is not set by ``__init__`` and must be assigned by the caller before
        using ``-1``.  Extra keyword arguments go to ``load_state_dict``.

        Fix: the ``-1`` branch used to replace ``milestone`` with the integer
        suffix itself and then hand that integer to ``torch.load``; the
        matching path is now loaded instead.
        """
        if milestone == -1:
            checkpoints = {
                int(p.stem.split('-')[-1]): p
                for p in Path(self.results_folder).glob('**/*.pt')
            }
            assert len(
                checkpoints) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)'
            milestone = checkpoints[max(checkpoints)]

        if map_location:
            data = torch.load(milestone, map_location=map_location)
        else:
            data = torch.load(milestone)

        self.step = data['step']
        self.model.load_state_dict(data['model'], **kwargs)
        self.ema_model.load_state_dict(data['ema'], **kwargs)
1015
+
1016
+
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/text.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch"
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
def exists(val):
    """Return True for anything other than ``None`` (falsy values included)."""
    return not (val is None)
9
+
10
+ # singleton globals
11
+
12
+
13
+ MODEL = None
14
+ TOKENIZER = None
15
+ BERT_MODEL_DIM = 768
16
+
17
+
18
def get_tokenizer():
    """Return the module-level tokenizer, loading it via ``torch.hub`` on first use."""
    global TOKENIZER
    if exists(TOKENIZER):
        return TOKENIZER

    TOKENIZER = torch.hub.load(
        'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
    return TOKENIZER
24
+
25
+
26
def get_bert():
    """Return the module-level BERT model, loading it via ``torch.hub`` on first use.

    The freshly loaded model is moved to CUDA when a GPU is available.
    """
    global MODEL
    if exists(MODEL):
        return MODEL

    MODEL = torch.hub.load(
        'huggingface/pytorch-transformers', 'model', 'bert-base-cased')
    if torch.cuda.is_available():
        MODEL = MODEL.cuda()
    return MODEL
35
+
36
+ # tokenize
37
+
38
+
39
def tokenize(texts, add_special_tokens=True):
    """Tokenize one text or a list/tuple of texts into a padded id tensor.

    A single non-list input is wrapped into a one-element batch.
    """
    batch = texts if isinstance(texts, (list, tuple)) else [texts]

    encoding = get_tokenizer().batch_encode_plus(
        batch,
        add_special_tokens=add_special_tokens,
        padding=True,
        return_tensors='pt'
    )
    return encoding.input_ids
54
+
55
+ # embedding function
56
+
57
+
58
@torch.no_grad()
def bert_embed(
    token_ids,
    return_cls_repr=False,
    eps=1e-8,
    pad_id=0.
):
    """Embed token ids with BERT and pool the last hidden layer.

    Args:
        token_ids: tensor of token ids, e.g. as returned by ``tokenize``.
        return_cls_repr: when True, return the hidden state at position 0
            instead of the masked mean.
        eps: added to the token count to avoid division by zero.
        pad_id: id treated as padding and excluded from attention / pooling.

    Returns:
        One vector per input row: either the position-0 hidden state or the
        mean over non-padding positions, position 0 excluded.

    Fix: the function previously ended with a bare ``return``, so the
    mean-pooled path returned ``None`` instead of ``masked_mean``.
    """
    model = get_bert()
    mask = token_ids != pad_id

    if torch.cuda.is_available():
        token_ids = token_ids.cuda()
        mask = mask.cuda()

    outputs = model(
        input_ids=token_ids,
        attention_mask=mask,
        output_hidden_states=True
    )

    hidden_state = outputs.hidden_states[-1]

    if return_cls_repr:
        # return [cls] as representation
        return hidden_state[:, 0]

    # NOTE(review): ``mask`` is always built above, so this branch looks
    # unreachable; kept for parity with the upstream implementation.
    if not exists(mask):
        return hidden_state.mean(dim=1)

    # mean all tokens excluding [cls], accounting for length
    mask = mask[:, 1:]
    mask = rearrange(mask, 'b n -> b n 1')

    numer = (hidden_state[:, 1:] * mask).sum(dim=1)
    denom = mask.sum(dim=1)
    masked_mean = numer / (denom + eps)
    return masked_mean
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from monai.networks.layers.utils import get_act_layer
6
+
7
+
8
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with configurable period and sin/cos ordering.

    The output is ``[sin | cos]`` (or ``[cos | sin]`` when
    ``flip_sin_to_cos`` is set), zero-padded by one column when ``emb_dim``
    is odd.
    """

    def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
        super().__init__()
        self.emb_dim = emb_dim
        self.downscale_freq_shift = downscale_freq_shift
        self.max_period = max_period
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, x):
        half_dim = self.emb_dim // 2
        decay = math.log(self.max_period) / \
            (half_dim - self.downscale_freq_shift)
        freqs = torch.exp(-decay * torch.arange(half_dim, device=x.device))
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)

        if self.flip_sin_to_cos:
            # Swap the two halves so the cosines come first.
            sines, cosines = emb[:, :half_dim], emb[:, half_dim:]
            emb = torch.cat([cosines, sines], dim=-1)

        if self.emb_dim % 2 == 1:
            # Odd width: append one zero column to reach emb_dim features.
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb
31
+
32
+
33
class LearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with learned sinusoidal pos emb

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    The output is the raw input followed by the sines and cosines of the
    learned frequencies, i.e. ``1 + 2 * (emb_dim // 2)`` features, plus one
    zero column when ``emb_dim`` is odd.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        half_dim = emb_dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        column = x[:, None]
        freqs = column * self.weights[None, :] * 2 * math.pi
        features = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # The raw input is prepended to the Fourier features.
        features = torch.cat((column, features), dim=-1)
        if self.emb_dim % 2 == 1:
            features = torch.nn.functional.pad(features, (0, 1, 0, 0))
        return features
51
+
52
+
53
class TimeEmbbeding(nn.Module):
    """Positional embedding of the timestep followed by a two-layer MLP.

    Args:
        emb_dim: width of the produced embedding.
        pos_embedder: class used to build the positional embedder.
        pos_embedder_kwargs: keyword arguments for ``pos_embedder``.  Its
            ``emb_dim`` entry defaults to ``emb_dim // 4`` when absent.
        act_name: activation spec passed to MONAI's ``get_act_layer``.

    Fix: ``pos_embedder_kwargs`` used to default to a shared ``{}`` that was
    mutated in ``__init__``, so the ``emb_dim`` chosen for the first instance
    leaked into every later instance built with the default (and caller
    dicts were modified in place).  A private copy is used now.
    """

    def __init__(
        self,
        emb_dim=64,
        pos_embedder=SinusoidalPosEmb,
        pos_embedder_kwargs=None,
        act_name=("SWISH", {})  # Swish = SiLU
    ):
        super().__init__()
        # Copy so neither a shared default nor the caller's dict is mutated.
        pos_embedder_kwargs = dict(pos_embedder_kwargs or {})

        self.emb_dim = emb_dim
        self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4)
        pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
        self.pos_embedder = pos_embedder(**pos_embedder_kwargs)

        self.time_emb = nn.Sequential(
            self.pos_embedder,
            nn.Linear(self.pos_emb_dim, self.emb_dim),
            get_act_layer(act_name),
            nn.Linear(self.emb_dim, self.emb_dim)
        )

    def forward(self, time):
        return self.time_emb(time)
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/unet.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ddpm.time_embedding import TimeEmbbeding
2
+
3
+ import monai.networks.nets as nets
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock
9
+ from monai.networks.layers.utils import get_act_layer
10
+
11
+
12
class DownBlock(nn.Module):
    """Encoder stage of the U-Net.

    Adds a per-channel projection of the time embedding (and, when present,
    of the condition embedding) to the input, then runs the result through a
    MONAI ``UnetBasicBlock``. Extra ``kwargs`` are forwarded to that block.
    """

    def __init__(
            self,
            spatial_dims,
            in_ch,
            out_ch,
            time_emb_dim,
            cond_emb_dim,
            act_name=("swish", {}),
            **kwargs):
        super().__init__()
        # Block-local projection: time embedding -> one value per input channel.
        self.loca_time_embedder = nn.Sequential(
            get_act_layer(name=act_name),
            nn.Linear(time_emb_dim, in_ch),
        )
        # The condition projection only exists when a condition embedder is configured.
        if cond_emb_dim is not None:
            self.loca_cond_embedder = nn.Sequential(
                get_act_layer(name=act_name),
                nn.Linear(cond_emb_dim, in_ch),
            )
        self.down_op = UnetBasicBlock(
            spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs)

    def forward(self, x, time_emb, cond_emb):
        """Add the embeddings to ``x`` (broadcast over spatial dims) and apply ``down_op``."""
        batch, channels, *_ = x.shape
        # (B, C, 1, ..., 1) so the embedding broadcasts over every spatial axis.
        broadcast_shape = (batch, channels) + (1,) * (x.ndim - 2)

        x = x + self.loca_time_embedder(time_emb).reshape(broadcast_shape)

        if cond_emb is not None:
            x = x + self.loca_cond_embedder(cond_emb).reshape(broadcast_shape)

        return self.down_op(x)
57
+
58
+
59
class UpBlock(nn.Module):
    """Decoder stage of the U-Net.

    Adds a per-channel projection of the time embedding (and, when present,
    of the condition embedding) to the deeper feature map ``x_enc``, then
    passes it together with the skip tensor to a MONAI ``UnetUpBlock``.

    :param spatial_dims: number of spatial dimensions, forwarded to ``UnetUpBlock``.
    :param skip_ch: channel count passed to ``UnetUpBlock`` as its output width.
    :param enc_ch: channel count of ``x_enc`` (input width of ``UnetUpBlock``).
    :param time_emb_dim: width of the global time embedding.
    :param cond_emb_dim: width of the condition embedding, or ``None`` to
        build no condition projection.
    :param act_name: activation spec forwarded to ``get_act_layer`` and ``UnetUpBlock``.
    :param kwargs: forwarded unchanged to ``UnetUpBlock``.
    """

    def __init__(
            self,
            spatial_dims,
            skip_ch,
            enc_ch,
            time_emb_dim,
            cond_emb_dim,
            act_name=("swish", {}),
            **kwargs):
        super().__init__()
        self.up_op = UnetUpBlock(spatial_dims, enc_ch,
                                 skip_ch, act_name=act_name, **kwargs)
        # The embeddings are added to x_enc, so they must have enc_ch channels.
        # The original projected to skip_ch * 2, which matches only when
        # enc_ch == 2 * skip_ch; for those configs the shapes are unchanged.
        self.loca_time_embedder = nn.Sequential(
            get_act_layer(name=act_name),
            nn.Linear(time_emb_dim, enc_ch),
        )
        if cond_emb_dim is not None:
            self.loca_cond_embedder = nn.Sequential(
                get_act_layer(name=act_name),
                nn.Linear(cond_emb_dim, enc_ch),
            )

    def forward(self, x_skip, x_enc, time_emb, cond_emb):
        """Add the embeddings to ``x_enc`` and return ``up_op(x_enc, x_skip)``."""
        b, c, *_ = x_enc.shape
        sp_dim = x_enc.ndim - 2
        # (B, C, 1, ..., 1) so the embedding broadcasts over every spatial axis.
        broadcast_shape = (b, c) + (1,) * sp_dim

        # ----------- Time --------------
        time_emb = self.loca_time_embedder(time_emb).reshape(broadcast_shape)
        x_enc = x_enc + time_emb

        # ----------- Condition ------------
        if cond_emb is not None:
            cond_emb = self.loca_cond_embedder(cond_emb).reshape(broadcast_shape)
            x_enc = x_enc + cond_emb

        # ----------- Image -------------
        return self.up_op(x_enc, x_skip)
108
+
109
+
110
class UNet(nn.Module):
    """U-Net with time (and optional condition) embeddings, built from MONAI blocks.

    ``forward`` returns only the output of ``outc`` applied to the shallowest
    feature map. The deep-supervision heads in ``outc_ver`` are still
    constructed, so their parameters remain in the state dict, but they are
    not evaluated because their outputs were never returned.

    :param in_ch: input channels (doubled internally when ``use_self_conditioning``).
    :param out_ch: output channels (doubled for ``outc`` when ``estimate_variance``).
    :param spatial_dims: number of spatial dimensions forwarded to the MONAI blocks.
    :param hid_chs: channel width of each resolution level.
    :param kernel_sizes: per-level convolution kernel sizes.
    :param strides: per-level strides; ``strides[0]`` belongs to the input block.
    :param upsample_kernel_sizes: per-decoder upsampling kernels; defaults to ``strides[1:]``.
    :param act_name: activation spec forwarded to all blocks.
    :param norm_name: normalization spec forwarded to all blocks.
    :param time_embedder: class of the time embedder; must expose ``emb_dim``.
    :param time_embedder_kwargs: keyword arguments for ``time_embedder``.
    :param cond_embedder: optional class of the condition embedder; must expose ``emb_dim``.
    :param cond_embedder_kwargs: keyword arguments for ``cond_embedder``.
    :param deep_ver_supervision: number of deep-supervision heads to build
        (``True`` = ``len(strides) - 2``, ``False``/0 = none).
    :param estimate_variance: if true, ``outc`` emits ``2 * out_ch`` channels.
    :param use_self_conditioning: if true, the input block expects ``2 * in_ch`` channels.
    :param kwargs: forwarded to the input, encoder and decoder blocks.
    """

    def __init__(self,
                 in_ch=1,
                 out_ch=1,
                 spatial_dims=3,
                 hid_chs=[32, 64, 128, 256, 512],
                 kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3],
                 strides=[1, (1, 2, 2), (1, 2, 2), 2, 2],
                 upsample_kernel_sizes=None,
                 act_name=("SWISH", {}),
                 norm_name=("INSTANCE", {"affine": True}),
                 time_embedder=TimeEmbbeding,
                 time_embedder_kwargs={},
                 cond_embedder=None,
                 cond_embedder_kwargs={},
                 # True = all but last layer, 0/False=disable, 1=only first layer, ...
                 deep_ver_supervision=True,
                 estimate_variance=False,
                 use_self_conditioning=False,
                 **kwargs
                 ):
        super().__init__()
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = strides[1:]

        # ------------- Time-Embedder-----------
        self.time_embedder = time_embedder(**time_embedder_kwargs)

        # ------------- Condition-Embedder-----------
        if cond_embedder is not None:
            self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
            cond_emb_dim = self.cond_embedder.emb_dim
        else:
            self.cond_embedder = None
            cond_emb_dim = None

        # ----------- In-Convolution ------------
        in_ch = in_ch * 2 if use_self_conditioning else in_ch
        self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
                                  act_name=act_name, norm_name=norm_name, **kwargs)

        # ----------- Encoder ----------------
        self.encoders = nn.ModuleList([
            DownBlock(spatial_dims, hid_chs[i - 1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim,
                      cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i], stride=strides[i],
                      act_name=act_name, norm_name=norm_name, **kwargs)
            for i in range(1, len(strides))
        ])

        # ------------ Decoder ----------
        self.decoders = nn.ModuleList([
            UpBlock(spatial_dims, hid_chs[i], hid_chs[i + 1], time_emb_dim=self.time_embedder.emb_dim,
                    cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + 1], stride=strides[i + 1],
                    act_name=act_name, norm_name=norm_name,
                    upsample_kernel_size=upsample_kernel_sizes[i], **kwargs)
            for i in range(len(strides) - 1)
        ])

        # --------------- Out-Convolution ----------------
        out_ch_hor = out_ch * 2 if estimate_variance else out_ch
        self.outc = UnetOutBlock(
            spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
        if isinstance(deep_ver_supervision, bool):
            deep_ver_supervision = len(strides) - 2 if deep_ver_supervision else 0
        # Kept for state-dict compatibility even though forward() does not use them.
        self.outc_ver = nn.ModuleList([
            UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
            for i in range(1, deep_ver_supervision + 1)
        ])

    def forward(self, x_t, t, cond=None, self_cond=None, **kwargs):
        """Run the U-Net and return the ``outc`` prediction.

        :param x_t: input tensor, x_t [B, C, (D), H, W].
        :param t: timesteps, t [B,].
        :param cond: optional condition; ignored when no condition embedder exists.
        :param self_cond: optional tensor concatenated to ``x_t`` along dim 1.
        :param kwargs: accepted and ignored.
        :return: output of ``outc`` on the shallowest feature map.
        """
        # -------- In-Convolution --------------
        x = [None for _ in range(len(self.encoders) + 1)]
        if self_cond is not None:
            x_t = torch.cat([x_t, self_cond], dim=1)
        x[0] = self.inc(x_t)

        # -------- Time Embedding (Global) -----------
        time_emb = self.time_embedder(t)  # [B, C]

        # -------- Condition Embedding (Global) -----------
        if (cond is None) or (self.cond_embedder is None):
            cond_emb = None
        else:
            cond_emb = self.cond_embedder(cond)  # [B, C]

        # --------- Encoder --------------
        for i, encoder in enumerate(self.encoders):
            x[i + 1] = encoder(x[i], time_emb, cond_emb)

        # -------- Decoder -----------
        # Walk back up from the deepest level, overwriting each skip slot
        # with the decoded feature map of that level.
        for i in range(len(self.decoders), 0, -1):
            x[i - 1] = self.decoders[i - 1](x[i - 1], x[i], time_emb, cond_emb)

        # ---------Out-Convolution ------------
        # The deep-supervision heads (outc_ver) are intentionally not run:
        # their outputs were computed and then discarded in the original.
        return self.outc(x[0])

    def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs):
        """Call ``forward``; ``cond_scale`` is accepted but not used."""
        return self.forward(*args, **kwargs)
219
+
220
+
221
if __name__ == '__main__':
    # Smoke test. UNet.forward returns a single tensor, so the result must
    # not be unpacked into two values (the original unpacking raised).
    model = UNet(in_ch=3)
    sample = torch.randn((1, 3, 16, 128, 128))  # avoid shadowing builtin `input`
    timestep = torch.randn((1,))
    prediction = model(sample, timestep)
    print(prediction.shape)
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/ddpm/util.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ # from ldm.util import instantiate_from_config
19
+
20
+
21
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta (noise variance) schedule as a float64 NumPy array.

    :param schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``, ``"sqrt"``.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value of the linspace-based schedules.
    :param linear_end: last value of the linspace-based schedules.
    :param cosine_s: offset added to the normalized timesteps in the cosine schedule.
    :return: 1-D ``numpy.ndarray`` of length ``n_timestep``.
    :raises ValueError: if ``schedule`` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so `betas` stays a tensor for the `.numpy()` call
        # below, instead of handing a tensor to np.clip.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
44
+
45
+
46
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the subset of DDPM timesteps used by the DDIM sampler.

    :param ddim_discr_method: ``'uniform'`` (constant stride) or ``'quad'``
        (quadratically spaced up to 80% of the DDPM range).
    :param num_ddim_timesteps: requested number of DDIM steps.
    :param num_ddpm_timesteps: number of timesteps of the trained DDPM.
    :param verbose: print the selected timesteps when true.
    :return: integer ``numpy.ndarray`` of selected timesteps.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
        # With stride 1 every step is already selected; the +1 shift is
        # skipped, presumably to keep the last index in range.
        shift_by_one = c != 1
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
        # The stride `c` is undefined for this method (the original raised
        # UnboundLocalError here); the steps stay below the range end, so shift.
        shift_by_one = True
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1 if shift_by_one else ddim_timesteps
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
64
+
65
+
66
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute the DDIM sigma schedule and the cumulative alphas it uses.

    :param alphacums: array of cumulative alpha products, indexed by timestep.
    :param ddim_timesteps: integer array of the timesteps selected for DDIM.
    :param eta: scale applied to the sigma schedule.
    :param verbose: print the selected alphas and resulting sigmas when true.
    :return: tuple ``(sigmas, alphas, alphas_prev)``.
    """
    # Cumulative alphas at the selected steps, and at the step before each of
    # them (the first "previous" value is alphacums[0]).
    alphas = alphacums[ddim_timesteps]
    earlier_steps = ddim_timesteps[:-1]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[earlier_steps].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    variance_term = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance_term)

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
79
+
80
+
81
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) function into a beta schedule.

    alpha_bar defines the cumulative product of (1-beta) over time for t in [0, 1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1-beta) up to that point of
                      the diffusion process.
    :param max_beta: upper bound applied to every beta; keep it below 1 to
                     prevent singularities.
    :return: a numpy array of num_diffusion_timesteps betas.
    """
    steps = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
98
+
99
+
100
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and reshape to ``(len(t), 1, ..., 1)``.

    The result has as many dimensions as ``x_shape``, so it can broadcast
    against a tensor of that shape.
    """
    batch, *_ = t.shape
    singleton_dims = (1,) * (len(x_shape) - 1)
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *singleton_dims)
104
+
105
+
106
def checkpoint(func, inputs, params, flag):
    """
    Run `func`, optionally through gradient checkpointing.

    With checkpointing enabled, intermediate activations are not cached,
    trading extra compute in the backward pass for reduced memory.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # Inputs and params travel as one flat tuple; the count tells
    # CheckpointFunction where to split them again.
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
121
+
122
+
123
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes the forward pass during backward."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Run ``run_function`` on the first ``length`` args without recording a graph."""
        ctx.run_function = run_function
        # The first `length` positional args are inputs, the rest are parameters.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            return ctx.run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, *output_grads):
        """Re-run the function with grad enabled and return grads for inputs and params."""
        ctx.input_tensors = [tensor.detach().requires_grad_(True) for tensor in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            input_views = [tensor.view_as(tensor) for tensor in ctx.input_tensors]
            recomputed = ctx.run_function(*input_views)
        grads = torch.autograd.grad(
            recomputed,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del recomputed
        # No gradients for `run_function` and `length`.
        return (None, None) + grads
153
+
154
+
155
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and just repeat each
                        timestep `dim` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    indices = torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * indices / half).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd widths get one trailing zero column.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
176
+
177
+
178
def zero_module(module):
    """
    Set every parameter of `module` to zero in place, then return the module.
    """
    for weight in module.parameters():
        # detach() so the in-place write is not tracked by autograd.
        weight.detach().zero_()
    return module
185
+
186
+
187
def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place, then return the module.
    """
    for weight in module.parameters():
        # detach() so the in-place write is not tracked by autograd.
        weight.detach().mul_(scale)
    return module
194
+
195
+
196
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
201
+
202
+
203
def normalization(channels):
    """
    Build the standard normalization layer: a GroupNorm32 with 32 groups.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    group_count = 32
    return GroupNorm32(group_count, channels)
210
+
211
+
212
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """SiLU (swish) activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
216
+
217
+
218
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
221
+
222
def conv_nd(dims, *args, **kwargs):
    """
    Build a 1D, 2D, or 3D convolution module.

    Remaining arguments are forwarded to the chosen nn.ConvNd constructor.
    Raises ValueError when `dims` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
233
+
234
+
235
def linear(*args, **kwargs):
    """
    Build an nn.Linear module from the given arguments.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
240
+
241
+
242
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build a 1D, 2D, or 3D average pooling module.

    Remaining arguments are forwarded to the chosen nn.AvgPoolNd constructor.
    Raises ValueError when `dims` is not 1, 2 or 3.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
253
+
254
+
255
class HybridConditioner(nn.Module):
    """Apply separate conditioners to the concat and cross-attention conditions."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        # NOTE(review): `instantiate_from_config` is not imported in the visible
        # part of this file (its import appears commented out) - confirm it is
        # in scope before constructing this class.
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Return both conditioned values, each wrapped in a one-element list."""
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
266
+
267
+
268
def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given shape on `device`.

    With `repeat` set, a single sample of shape (1, *shape[1:]) is drawn and
    tiled along the first dimension, so every batch entry is identical.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    tile_counts = (shape[0],) + (1,) * (len(shape) - 1)
    return single.repeat(*tile_counts)
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.86 kB). View file
 
Generation_Pipeline_filter_all2/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .vqgan import VQGAN
2
+ from .codebook import Codebook
3
+ from .lpips import LPIPS