Dhenenjay committed on
Commit 1ab92c9 · verified · 1 Parent(s): 8870824

Upload unet.py with huggingface_hub

Files changed (1):
  1. unet.py +143 -98
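
The commit message says the file was pushed with the `huggingface_hub` client. For readers reproducing the workflow, here is a minimal sketch of such an upload; the `repo_id` below is a hypothetical placeholder, not taken from this page, and you would authenticate first (e.g. via `huggingface-cli login`):

    # Minimal sketch: push a single file to a Hub repo with huggingface_hub.
    from huggingface_hub import upload_file

    upload_file(
        path_or_fileobj="unet.py",        # local file to upload
        path_in_repo="unet.py",           # destination path inside the repo
        repo_id="Dhenenjay/<repo-name>",  # placeholder repo id (assumption)
        commit_message="Upload unet.py with huggingface_hub",
    )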
unet.py CHANGED
@@ -1,12 +1,9 @@
-"""E3Diff UNet Architecture - exact copy from original with fixed imports."""
-
 import math
 import torch
 from torch import nn
 import torch.nn.functional as F
 from inspect import isfunction
-from softpool import soft_pool2d, SoftPool2d
-
+import numpy as np
 
 def exists(x):
     return x is not None
@@ -17,7 +14,7 @@ def default(val, d):
         return val
     return d() if isfunction(d) else d
 
-
+# PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
 class PositionalEncoding(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -77,6 +74,9 @@ class Downsample(nn.Module):
         return self.conv(x)
 
 
+# building block modules
+
+
 class Block(nn.Module):
     def __init__(self, dim, dim_out, groups=32, dropout=0, stride=1):
         super().__init__()
@@ -96,7 +96,7 @@ class ResnetBlock(nn.Module):
         super().__init__()
         self.noise_func = FeatureWiseAffine(
             noise_level_emb_dim, dim_out, use_affine_level)
-        self.c_func = nn.Conv2d(dim_out, dim_out, 1)
+        self.c_func = nn.Conv2d(dim_out, dim_out, 1)
 
         self.block1 = Block(dim, dim_out, groups=norm_groups)
         self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
@@ -104,17 +104,22 @@ class ResnetBlock(nn.Module):
             dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb, c):
+        # b, c, h, w = x.shape
         h = self.block1(x)
         h = self.noise_func(h, time_emb)
         h = self.block2(h)
+
         h = self.c_func(c) + h
         return h + self.res_conv(x)
 
+
 
 class SelfAttention(nn.Module):
     def __init__(self, in_channel, n_head=1, norm_groups=32):
         super().__init__()
+
         self.n_head = n_head
+
         self.norm = nn.GroupNorm(norm_groups, in_channel)
         self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
         self.out = nn.Conv2d(in_channel, in_channel, 1)
@@ -126,7 +131,7 @@ class SelfAttention(nn.Module):
 
         norm = self.norm(input)
         qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
-        query, key, value = qkv.chunk(3, dim=2)
+        query, key, value = qkv.chunk(3, dim=2)  # bhdyx
 
         attn = torch.einsum(
             "bnchw, bncyx -> bnhwyx", query, key
@@ -140,6 +145,10 @@ class SelfAttention(nn.Module):
 
         return out + input
 
+
+
+
+
 
 class ResnetBlocWithAttn(nn.Module):
     def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, size=256):
@@ -151,76 +160,12 @@ class ResnetBlocWithAttn(nn.Module):
             self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
 
     def forward(self, x, time_emb, c, t=0, save_flag=False, file_i=0):
-        x = self.res_block(x, time_emb, c)
-        if self.with_attn:
+        x = self.res_block(x, time_emb, c)  # resblock(x + self.noise_func(noise_embed)) + con1_1(c)
+        if(self.with_attn):
             x = self.attn(x, t=t, save_flag=save_flag, file_num=file_i)
         return x
 
 
-class ResBlock_normal(nn.Module):
-    def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
-        super().__init__()
-        self.block1 = Block(dim, dim_out, groups=norm_groups)
-        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
-        self.res_conv = nn.Conv2d(
-            dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        h = self.block1(x)
-        h = self.block2(h)
-        return h + self.res_conv(x)
-
-
-class CPEN(nn.Module):
-    """Condition Pyramid Encoder Network - EXACT architecture from E3Diff."""
-    def __init__(self, inchannel=1):
-        super(CPEN, self).__init__()
-        self.pool = SoftPool2d(kernel_size=(2, 2), stride=(2, 2))
-
-        self.E1 = nn.Sequential(
-            nn.Conv2d(inchannel, 64, kernel_size=3, padding=1),
-            Swish()
-        )
-
-        self.E2 = nn.Sequential(
-            ResBlock_normal(64, 128, dropout=0, norm_groups=16),
-            ResBlock_normal(128, 128, dropout=0, norm_groups=16),
-        )
-
-        self.E3 = nn.Sequential(
-            ResBlock_normal(128, 256, dropout=0, norm_groups=16),
-            ResBlock_normal(256, 256, dropout=0, norm_groups=16),
-        )
-
-        self.E4 = nn.Sequential(
-            ResBlock_normal(256, 512, dropout=0, norm_groups=16),
-            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
-        )
-
-        self.E5 = nn.Sequential(
-            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
-            ResBlock_normal(512, 1024, dropout=0, norm_groups=16),
-        )
-
-    def forward(self, x):
-        x1 = self.E1(x)     # 256x256, 64ch
-
-        x2 = self.pool(x1)  # 128x128
-        x2 = self.E2(x2)    # 128x128, 128ch
-
-        x3 = self.pool(x2)  # 64x64
-        x3 = self.E3(x3)    # 64x64, 256ch
-
-        x4 = self.pool(x3)  # 32x32
-        x4 = self.E4(x4)    # 32x32, 512ch
-
-        x5 = self.pool(x4)  # 16x16
-        x5 = self.E5(x5)    # 16x16, 1024ch
-
-        return x1, x2, x3, x4, x5
-
-
 class UNet(nn.Module):
     def __init__(
         self,
@@ -229,7 +174,7 @@ class UNet(nn.Module):
         inner_channel=32,
         norm_groups=32,
         channel_mults=(1, 2, 4, 8, 8),
-        attn_res=(8,),
+        attn_res=(8),
         res_blocks=3,
         dropout=0,
         with_noise_level_emb=True,
@@ -251,35 +196,37 @@ class UNet(nn.Module):
             noise_level_channel = None
             self.noise_level_mlp = None
 
+
+
+
        self.res_blocks = res_blocks
        num_mults = len(channel_mults)
        self.num_mults = num_mults
        pre_channel = inner_channel
        feat_channels = [pre_channel]
        now_res = image_size
-
-        downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
+        downs = [nn.Conv2d(in_channel, inner_channel,
+                           kernel_size=3, padding=1)]
        for ind in range(num_mults):
            is_last = (ind == num_mults - 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(0, res_blocks):
                downs.append(ResnetBlocWithAttn(
-                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
-                    norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
+                    pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn,size=now_res))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if not is_last:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
-                now_res = now_res // 2
+                now_res = now_res//2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
-            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
-                               norm_groups=norm_groups, dropout=dropout, with_attn=True, size=now_res),
-            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
-                               norm_groups=norm_groups, dropout=dropout, with_attn=False, size=now_res)
+            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
+                               dropout=dropout, with_attn=True,size=now_res),
+            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
+                               dropout=dropout, with_attn=False,size=now_res)
        ])

        ups = []
@@ -287,66 +234,164 @@ class UNet(nn.Module):
            is_last = (ind < 1)
            use_attn = (now_res in attn_res)
            channel_mult = inner_channel * channel_mults[ind]
-            for _ in range(0, res_blocks + 1):
+            for _ in range(0, res_blocks+1):
                ups.append(ResnetBlocWithAttn(
-                    pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel,
-                    norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, size=now_res))
+                    pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
+                    dropout=dropout, with_attn=use_attn, size=now_res))
                pre_channel = channel_mult
            if not is_last:
                ups.append(Upsample(pre_channel))
-                now_res = now_res * 2
+                now_res = now_res*2
+
        self.ups = nn.ModuleList(ups)

        self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
+

-        self.condition = CPEN(inchannel=condition_ch)
+        self.condition = CPEN(inchannel = condition_ch)  # canny+sar
        self.condition_ch = condition_ch
+        # self.c_func2 = nn.Linear(128, 128) #128 256 512 1024
+        self.mi = 0
+

+
+
+
    def forward(self, x, time, img_s1=None, class_label=None, return_condition=False, t_ori=0):
-        condition = x[:, :self.condition_ch, ...].clone()
+        # x torch.cat([x_in['SR'], x_noisy], dim=1)
+        condition = x[:, :self.condition_ch, ...].clone()
        x = x[:, self.condition_ch:, ...]

+
        c1, c2, c3, c4, c5 = self.condition(condition)
        c_base = [c1, c2, c3, c4, c5]

+
+
+
+
        c = []
        for i in range(len(c_base)):
            for _ in range(self.res_blocks):
-                c.append(c_base[i])
+                c.append(c_base[i])

-        t = self.noise_level_mlp(time) if exists(self.noise_level_mlp) else None
+        t = self.noise_level_mlp(time) if exists(
+            self.noise_level_mlp) else None

+
+
        feats = []
-        i = 0
+        i=0
        for layer in self.downs:
            if isinstance(layer, ResnetBlocWithAttn):
+
                x = layer(x, t, c[i])
-                i += 1
+                # print(x.shape)
+                i+=1
            else:
                x = layer(x)
+
            feats.append(x)
+
+

        for layer in self.mid:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(x, t, c5)
+                # print(x.shape)
            else:
                x = layer(x)
+

+
        c_base = [c5, c4, c3, c2, c1]
        c = []
        for i in range(len(c_base)):
-            for _ in range(self.res_blocks + 1):
-                c.append(c_base[i])
-
+            for _ in range(self.res_blocks+1):
+                c.append(c_base[i])
        i = 0
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
+                # print(x.shape)
                x = layer(torch.cat((x, feats.pop()), dim=1), t, c[i])
-                i += 1
+                # print(x.shape)
+                i+=1
            else:
                x = layer(x)
-
+
        if not return_condition:
            return self.final_conv(x)
        else:
            return self.final_conv(x), [c1, c2, c3, c4, c5]
+
+
+
+class ResBlock_normal(nn.Module):
+    def __init__(self, dim, dim_out, dropout=0, norm_groups=32):
+        super().__init__()
+
+        self.block1 = Block(dim, dim_out, groups=norm_groups)
+        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
+        self.res_conv = nn.Conv2d(
+            dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        h = self.block1(x)
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+from SoftPool import soft_pool2d, SoftPool2d
+class CPEN(nn.Module):
+    def __init__(self, inchannel = 1):
+        super(CPEN, self).__init__()
+        self.pool = SoftPool2d(kernel_size=(2,2), stride=(2,2))
+        # self.scale=scale
+        # if scale == 2:
+
+        self.E1= nn.Sequential(nn.Conv2d(inchannel, 64, kernel_size=3, padding=1),
+                               Swish())
+
+
+
+        self.E2=nn.Sequential(
+            ResBlock_normal(64, 128, dropout=0, norm_groups=16),
+            ResBlock_normal(128, 128, dropout=0, norm_groups=16),
+        )
+
+        self.E3=nn.Sequential(
+            ResBlock_normal(128, 256, dropout=0, norm_groups=16),
+            ResBlock_normal(256, 256, dropout=0, norm_groups=16),
+        )
+
+        self.E4=nn.Sequential(
+            ResBlock_normal(256, 512, dropout=0, norm_groups=16),
+            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
+        )
+
+        self.E5=nn.Sequential(
+            ResBlock_normal(512, 512, dropout=0, norm_groups=16),
+            ResBlock_normal(512, 1024, dropout=0, norm_groups=16),
+        )
+
+
+
+    def forward(self, x):
+
+        x1 = self.E1(x)
+
+        x2 = self.pool(x1)
+        x2 = self.E2(x2)
+
+        x3 = self.pool(x2)
+        x3 = self.E3(x3)
+
+
+        x4 = self.pool(x3)
+        x4 = self.E4(x4)
+
+        x5 = self.pool(x4)
+        x5 = self.E5(x5)
+
+        return x1, x2, x3, x4, x5
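
For anyone pulling this file, here is a hedged smoke test, not part of the commit. It assumes the committed unet.py is on the import path, that the SoftPool package it imports mid-file is installed, and that Swish and FeatureWiseAffine are defined in the unchanged portions of the file. Two values below are inferences, labeled in comments: the new default attn_res=(8) is just the integer 8 (the trailing comma was dropped in this commit), so `now_res in attn_res` would raise TypeError and a real tuple should be passed; and since CPEN emits 64/128/256/512/1024-channel maps that ResnetBlock mixes in through a 1x1 conv sized to its own output width, inner_channel=64 with channel_mults=(1, 2, 4, 8, 16) is used so the widths line up, whereas the file's own defaults (32, (1, 2, 4, 8, 8)) would not.

    # Hedged smoke test (not part of the commit): shape-check UNet + CPEN.
    import torch
    from unet import UNet

    model = UNet(
        in_channel=1,                    # channels left after forward() slices the condition off
        out_channel=1,                   # assumption: single-channel output
        inner_channel=64,                # inferred so c_func widths match CPEN (64 ... 1024)
        channel_mults=(1, 2, 4, 8, 16),  # inferred for the same reason
        attn_res=(16,),                  # explicit tuple; the default (8) is the int 8 and would raise TypeError
        res_blocks=3,
        image_size=256,                  # input resolution; halved four times down to 16
        condition_ch=1,                  # forward() takes these leading channels as the CPEN condition
    )

    condition = torch.randn(1, 1, 256, 256)     # e.g. the canny/SAR condition map
    x_noisy = torch.randn(1, 1, 256, 256)       # noisy target channel
    x = torch.cat([condition, x_noisy], dim=1)  # matches the in-code comment: cat([x_in['SR'], x_noisy], dim=1)
    noise_level = torch.rand(1)                 # one continuous noise level per batch item (SR3/WaveGrad style);
                                                # some samplers pass shape (b, 1) instead

    out = model(x, noise_level)
    print(out.shape)                            # expected: torch.Size([1, 1, 256, 256])

With this configuration the condition pyramid lines up with the encoder levels both in width and in resolution: c1 (64ch, 256px) feeds the first three ResnetBlocWithAttn stages, c2 (128ch, 128px) the next three after the first Downsample, and so on down to c5 (1024ch, 16px) at the bottleneck, which is also what the mid blocks consume.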