Fabrice-TIERCELIN committed
Commit c284f61
Parent: 235544b

Upload SUPIR_v0.py

Files changed (1):
  SUPIR/modules/SUPIR_v0.py  +718 -0
SUPIR/modules/SUPIR_v0.py ADDED
@@ -0,0 +1,718 @@
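# SUPIR_v0.py -- control modules for the SUPIR diffusion UNet:
#   * ZeroConv / ZeroSFT / ZeroCrossAttn: zero-initialized adapters that inject control
#     features into the UNet activations.
#   * GLVControl: a ControlNet-style control encoder that mirrors the UNet encoder and
#     middle block and returns a list of per-resolution control features.
#   * LightGLVUNet: a UNetModel whose decoder consumes those features through the adapters.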
# from einops._torch_specific import allow_ops_in_compiled_graph
# allow_ops_in_compiled_graph()
import einops
import torch
import torch as th
import torch.nn as nn
from einops import rearrange, repeat

from sgm.modules.diffusionmodules.util import (
    avg_pool_nd,
    checkpoint,
    conv_nd,
    linear,
    normalization,
    timestep_embedding,
    zero_module,
)

from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
    TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
import re
from functools import partial


try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False

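# The converters below are no-ops ("dummy replace"); dtype handling is done by casting
# tensors inside the forward passes rather than by converting module weights.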
# dummy replace
def convert_module_to_f16(x):
    pass


def convert_module_to_f32(x):
    pass


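# ZeroConv: adds a zero-initialized 1x1 projection of the control features `c` onto the
# UNet features `h`; optionally concatenates skip features `h_ori` along the channel dim.
# With mask=True the injected signal is multiplied by zero.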
class ZeroConv(nn.Module):
    def __init__(self, label_nc, norm_nc, mask=False):
        super().__init__()
        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.mask = mask

    def forward(self, c, h, h_ori=None):
        # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        if not self.mask:
            h = h + self.zero_conv(c)
        else:
            h = h + self.zero_conv(c) * torch.zeros_like(h)
        if h_ori is not None:
            h = th.cat([h_ori, h], dim=1)
        return h


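# ZeroSFT: zero-initialized SFT-style (scale/shift) modulation. The control features `c`
# are added to `h` through a zero 1x1 conv, then a small conv MLP on `c` predicts per-pixel
# gamma/beta that modulate the group-normalized features; `control_scale` blends the
# modulated result with the unmodified input.
# Illustrative shapes (hypothetical values): with label_nc=160, norm_nc=320 and
# concat_channels=320, forward(c=[N,160,H,W], h=[N,320,H,W], h_ori=[N,320,H,W])
# returns a [N,640,H,W] tensor.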
class ZeroSFT(nn.Module):
    def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
        super().__init__()

        # param_free_norm_type = str(parsed.group(1))
        ks = 3
        pw = ks // 2

        self.norm = norm
        if self.norm:
            self.param_free_norm = normalization(norm_nc + concat_channels)
        else:
            self.param_free_norm = nn.Identity()

        nhidden = 128

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.SiLU()
        )
        self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
        self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
        # self.zero_mul = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)
        # self.zero_add = nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)

        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.pre_concat = bool(concat_channels != 0)
        self.mask = mask

    def forward(self, c, h, h_ori=None, control_scale=1):
        assert self.mask is False
        if h_ori is not None and self.pre_concat:
            h_raw = th.cat([h_ori, h], dim=1)
        else:
            h_raw = h

        if self.mask:
            h = h + self.zero_conv(c) * torch.zeros_like(h)
        else:
            h = h + self.zero_conv(c)
        if h_ori is not None and self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        if self.mask:
            gamma = gamma * torch.zeros_like(gamma)
            beta = beta * torch.zeros_like(beta)
        h = self.param_free_norm(h) * (gamma + 1) + beta
        if h_ori is not None and not self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        return h * control_scale + h_raw * (1 - control_scale)


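# ZeroCrossAttn: cross-attention from the UNet features (`x`, queries) to the control
# features (`context`, keys/values). Both are normalized, flattened to token sequences,
# attended with 64-dim heads, and the result is added back to `x` scaled by `control_scale`.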
class ZeroCrossAttn(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn = attn_cls(query_dim=query_dim, context_dim=context_dim, heads=query_dim // 64, dim_head=64)
        self.norm1 = normalization(query_dim)
        self.norm2 = normalization(context_dim)

        self.mask = mask

        # if zero_out:
        #     # for p in self.attn.to_out.parameters():
        #     #     p.detach().zero_()
        #     self.attn.to_out = zero_module(self.attn.to_out)

    def forward(self, context, x, control_scale=1):
        assert self.mask is False
        x_in = x
        x = self.norm1(x)
        context = self.norm2(context)
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
        x = self.attn(x, context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if self.mask:
            x = x * torch.zeros_like(x)
        x = x_in + x * control_scale

        return x


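# GLVControl: ControlNet-style control encoder. It rebuilds the encoder and middle block of
# the diffusion UNet, embeds the control input `x` through the zero-initialized
# `input_hint_block`, adds it to the features of the noisy latent `xt`, and returns every
# intermediate feature map for consumption by LightGLVUNet.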
class GLVControl(nn.Module):
    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        spatial_transformer_attn_type="softmax",
        adm_in_channels=None,
        use_fairscale_checkpoint=False,
        offload_to_cpu=False,
        transformer_depth_middle=None,
        input_upscale=1,
    ):
        super().__init__()
        from omegaconf.listconfig import ListConfig

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        elif isinstance(transformer_depth, ListConfig):
            transformer_depth = list(transformer_depth)
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        # self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        if use_fp16:
            print("WARNING: use_fp16 was dropped and has no effect anymore.")
        # self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        assert use_fairscale_checkpoint != use_checkpoint or not (
            use_checkpoint or use_fairscale_checkpoint
        )

        self.use_fairscale_checkpoint = False
        checkpoint_wrapper_fn = (
            partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
            if self.use_fairscale_checkpoint
            else lambda x: x
        )

        time_embed_dim = model_channels * 4
        self.time_embed = checkpoint_wrapper_fn(
            nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = checkpoint_wrapper_fn(
                    nn.Sequential(
                        Timestep(model_channels),
                        nn.Sequential(
                            linear(model_channels, time_embed_dim),
                            nn.SiLU(),
                            linear(time_embed_dim, time_embed_dim),
                        ),
                    )
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                down=True,
                            )
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
            checkpoint_wrapper_fn(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
            )
            if not use_spatial_transformer
            else checkpoint_wrapper_fn(
                SpatialTransformer(  # always uses a self-attn
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth_middle,
                    context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn,
                    use_linear=use_linear_in_transformer,
                    attn_type=spatial_transformer_attn_type,
                    use_checkpoint=use_checkpoint,
                )
            ),
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
        )

        self.input_upscale = input_upscale
        self.input_hint_block = TimestepEmbedSequential(
            zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

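    # Returns one control tensor per input block plus one for the middle block.
    # `x` is the control/hint input, `xt` the noisy latent, `context` the cross-attention
    # conditioning and `y` the vector (ADM) conditioning.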
    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
        #     x = x.to(torch.float32)
        #     timesteps = timesteps.to(torch.float32)
        #     xt = xt.to(torch.float32)
        #     context = context.to(torch.float32)
        #     y = y.to(torch.float32)
        # print(x.dtype)
        xt, context, y = xt.to(x.dtype), context.to(x.dtype), y.to(x.dtype)

        if self.input_upscale != 1:
            x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        # import pdb
        # pdb.set_trace()
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == xt.shape[0]
            emb = emb + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        # h = x.type(self.dtype)
        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
            # print(module)
            # print(h.shape)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs


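# LightGLVUNet: a UNetModel whose middle block and decoder are augmented with ZeroSFT /
# ZeroCrossAttn adapters (`project_modules`). The channel lists below are hard-wired to the
# SDXL base ('XL-base') and refiner ('XL-refine') layouts; step_progressive_mask() pops the
# next entry of `progressive_mask_nums` and masks the first `mask_num` adapters.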
class LightGLVUNet(UNetModel):
    def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        if mode == 'XL-base':
            cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
            project_channels = [160] * 4 + [320] * 3 + [640] * 3
            concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
            cross_attn_insert_idx = [6, 3]
            self.progressive_mask_nums = [0, 3, 7, 11]
        elif mode == 'XL-refine':
            cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
            project_channels = [192] * 4 + [384] * 3 + [768] * 6
            concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
            cross_attn_insert_idx = [9, 6, 3]
            self.progressive_mask_nums = [0, 3, 6, 10, 14]
        else:
            raise NotImplementedError

        project_channels = [int(c * project_channel_scale) for c in project_channels]

        self.project_modules = nn.ModuleList()
        for i in range(len(cond_output_channels)):
            # if i == len(cond_output_channels) - 1:
            #     _project_type = 'ZeroCrossAttn'
            # else:
            #     _project_type = project_type
            _project_type = project_type
            if _project_type == 'ZeroSFT':
                self.project_modules.append(ZeroSFT(project_channels[i], cond_output_channels[i],
                                                    concat_channels=concat_channels[i]))
            elif _project_type == 'ZeroCrossAttn':
                self.project_modules.append(ZeroCrossAttn(cond_output_channels[i], project_channels[i]))
            else:
                raise NotImplementedError

        for i in cross_attn_insert_idx:
            self.project_modules.insert(i, ZeroCrossAttn(cond_output_channels[i], concat_channels[i]))
            # print(self.project_modules[i])

    def step_progressive_mask(self):
        if len(self.progressive_mask_nums) > 0:
            mask_num = self.progressive_mask_nums.pop()
            for i in range(len(self.project_modules)):
                if i < mask_num:
                    self.project_modules[i].mask = True
                else:
                    self.project_modules[i].mask = False
            return
            # print(f'step_progressive_mask, current masked layers: {mask_num}')
        else:
            return
            # print('step_progressive_mask, no more masked layers')
            # for i in range(len(self.project_modules)):
            #     print(self.project_modules[i].mask)


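    # `control` is the feature list produced by GLVControl.forward(); the adapters in
    # `project_modules` consume it back to front: first at the middle block, then at each
    # decoder stage, fusing the control features into the skip connections and (before each
    # Upsample) into the upsampling path.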
    def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []

        _dtype = control[0].dtype
        x, context, y = x.to(_dtype), context.to(_dtype), y.to(_dtype)

        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
            emb = self.time_embed(t_emb)

            if self.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.label_emb(y)

            # h = x.type(self.dtype)
            h = x
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)

        adapter_idx = len(self.project_modules) - 1
        control_idx = len(control) - 1
        h = self.middle_block(h, emb, context)
        h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
        adapter_idx -= 1
        control_idx -= 1

        for i, module in enumerate(self.output_blocks):
            _h = hs.pop()
            h = self.project_modules[adapter_idx](control[control_idx], _h, h, control_scale=control_scale)
            adapter_idx -= 1
            # h = th.cat([h, _h], dim=1)
            if len(module) == 3:
                assert isinstance(module[2], Upsample)
                for layer in module[:2]:
                    if isinstance(layer, TimestepBlock):
                        h = layer(h, emb)
                    elif isinstance(layer, SpatialTransformer):
                        h = layer(h, context)
                    else:
                        h = layer(h)
                # print('cross_attn_here')
                h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
                adapter_idx -= 1
                h = module[2](h)
            else:
                h = module(h, emb, context)
            control_idx -= 1
            # print(module)
            # print(h.shape)

        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            assert False, "not supported anymore. what the f*** are you doing?"
        else:
            return self.out(h)

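# Quick smoke test: instantiate the control encoder and the UNet from a YAML config and
# push random tensors through them. Requires a CUDA device and the referenced config file;
# the tensor shapes correspond to SDXL-sized latents and conditionings.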
if __name__ == '__main__':
    from omegaconf import OmegaConf

    # refiner
    # opt = OmegaConf.load('../../options/train/debug_p2_xl.yaml')
    #
    # model = instantiate_from_config(opt.model.params.control_stage_config)
    # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
    # hint = [h.cuda() for h in hint]
    # print(sum(map(lambda hint: hint.numel(), model.parameters())))
    #
    # unet = instantiate_from_config(opt.model.params.network_config)
    # unet = unet.cuda()
    #
    # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #                torch.randn([1, 2560]).cuda(), hint)
    # print(sum(map(lambda _output: _output.numel(), unet.parameters())))

    # base
    with torch.no_grad():
        opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')

        model = instantiate_from_config(opt.model.params.control_stage_config)
        model = model.cuda()

        hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 2048]).cuda(),
                     torch.randn([1, 2816]).cuda())

        for h in hint:
            print(h.shape)

        unet = instantiate_from_config(opt.model.params.network_config)
        unet = unet.cuda()
        _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 2048]).cuda(),
                       torch.randn([1, 2816]).cuda(), hint)

    # model = instantiate_from_config(opt.model.params.control_stage_config)
    # model = model.cuda()
    # # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
    # hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #              torch.randn([1, 2560]).cuda())
    # # hint = [h.cuda() for h in hint]
    #
    # for h in hint:
    #     print(h.shape)
    #
    # unet = instantiate_from_config(opt.model.params.network_config)
    # unet = unet.cuda()
    # _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
    #                torch.randn([1, 2560]).cuda(), hint)