ZYMPKU committed on
Commit 251f521
1 Parent(s): a34ca7e
.gitignore CHANGED
@@ -1,2 +1 @@
-**/__pycache__
-process.ipynb
+**/__pycache__
app.py CHANGED
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
 from contextlib import nullcontext
 from pytorch_lightning import seed_everything
 from os.path import join as ospj
-
+
 from util import *
 
 
@@ -18,30 +18,17 @@ def predict(cfgs, model, sampler, batch):
 
     with context():
 
-        batch, batch_uc_1, batch_uc_2 = prepare_batch(cfgs, batch)
-
-        if cfgs.dual_conditioner:
-            c, uc_1, uc_2 = model.conditioner.get_unconditional_conditioning(
-                batch,
-                batch_uc_1=batch_uc_1,
-                batch_uc_2=batch_uc_2,
-                force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
-            )
-        else:
-            c, uc_1 = model.conditioner.get_unconditional_conditioning(
-                batch,
-                batch_uc=batch_uc_1,
-                force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
-            )
+        batch, batch_uc_1 = prepare_batch(cfgs, batch)
+
+        c, uc_1 = model.conditioner.get_unconditional_conditioning(
+            batch,
+            batch_uc=batch_uc_1,
+            force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
+        )
 
-        if cfgs.dual_conditioner:
-            x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2)
-            samples_z = sampler(model, x, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2, init_step=0,
-                                aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
-        else:
-            x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1)
-            samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0,
-                                aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
+        x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1)
+        samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0,
+                            aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
 
     samples_x = model.decode_first_stage(samples_z)
     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -131,6 +118,7 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
 
 if __name__ == "__main__":
 
+    os.makedirs("./temp", exist_ok=True)
    os.makedirs("./temp/attn_map", exist_ok=True)
    os.makedirs("./temp/seg_map", exist_ok=True)
 
@@ -151,7 +139,7 @@ if __name__ == "__main__":
        UDiffText: A Unified Framework for High-quality Text Synthesis in Arbitrary Images via Character-aware Diffusion Models
        </h1>
        <ul style="text-align: center; margin: 0.5rem;">
-        <li style="display: inline-block; margin:auto;"><a href='https://arxiv.org/pdf/******'><img src='https://img.shields.io/badge/Arxiv-******-DF826C'></a></li>
+        <li style="display: inline-block; margin:auto;"><a href='https://arxiv.org/abs/2312.04884'><img src='https://img.shields.io/badge/Arxiv-2312.04884-DF826C'></a></li>
        <li style="display: inline-block; margin:auto;"><a href='https://github.com/ZYM-PKU/UDiffText'><img src='https://img.shields.io/badge/Code-UDiffText-D0F288'></a></li>
        <li style="display: inline-block; margin:auto;"><a href='https://udifftext.github.io'><img src='https://img.shields.io/badge/Project-UDiffText-8ADAB2'></a></li>
        </ul>
@@ -177,7 +165,7 @@ if __name__ == "__main__":
        steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
        scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=4.0, step=0.1)
        seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
-        show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=True)
+        show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=False)
 
        with gr.Column():
checkpoints/{st-step=100000+la-step=100000-simp.ckpt → st-step=100000+la-step=100000-v2.ckpt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:968397df8910f3324d94ce3df7e9d70f1bf2415a46d22edef1a510885ee0648e
-size 2558065830
+oid sha256:b87a307ed6e240208b415166e88c0f3e6467ec9330836d70c6d662f423bfbc15
+size 4173692086
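
Note: the checkpoint is stored as a Git LFS pointer, so the rename above only changes the pointer's object id and size. A minimal sketch (hypothetical local path, hashed in chunks to avoid loading ~4 GB into memory) for checking a downloaded file against the new pointer:

import hashlib

sha = hashlib.sha256()
with open("checkpoints/st-step=100000+la-step=100000-v2.ckpt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha.update(chunk)
print(sha.hexdigest())  # should match the oid recorded in the new LFS pointer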
configs/demo.yaml CHANGED
@@ -1,7 +1,7 @@
 type: "demo"
 
 # path
-load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-simp.ckpt"
+load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-v2.ckpt"
 model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
 
 # param
@@ -16,8 +16,7 @@ scale: [4.0, 0.0] # content scale, style scale
 noise_iters: 10
 force_uc_zero_embeddings: ["ref", "label"]
 aae_enabled: False
-detailed: True
-dual_conditioner: False
+detailed: False
 
 # runtime
 steps: 50
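
For reference, a minimal sketch of consuming this config with OmegaConf (app.py already imports OmegaConf; the exact loading code in the Space may differ):

from omegaconf import OmegaConf

cfgs = OmegaConf.load("./configs/demo.yaml")
print(cfgs.load_ckpt_path)  # "./checkpoints/st-step=100000+la-step=100000-v2.ckpt"
print(cfgs.detailed)        # False after this commit
print(cfgs.steps)           # 50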
configs/test/textdesign_sd_2.yaml CHANGED
@@ -1,6 +1,8 @@
 model:
   target: sgm.models.diffusion.DiffusionEngine
   params:
+    opt_keys:
+      - t_attn
     input_key: image
     scale_factor: 0.18215
     disable_first_stage_autocast: True
@@ -18,54 +20,45 @@ model:
       target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
 
     network_config:
-      target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
+      target: sgm.modules.diffusionmodules.openaimodel.UnifiedUNetModel
       params:
-        use_checkpoint: False
         in_channels: 9
         out_channels: 4
         ctrl_channels: 0
         model_channels: 320
         attention_resolutions: [4, 2, 1]
-        attn_type: add_attn
-        attn_layers:
-          - output_blocks.6.1
+        save_attn_type: [t_attn]
+        save_attn_layers: [output_blocks.6.1]
         num_res_blocks: 2
         channel_mult: [1, 2, 4, 4]
         num_head_channels: 64
-        use_spatial_transformer: True
         use_linear_in_transformer: True
         transformer_depth: 1
-        context_dim: 0
-        add_context_dim: 2048
-        legacy: False
+        t_context_dim: 2048
 
     conditioner_config:
       target: sgm.modules.GeneralConditioner
       params:
         emb_models:
-          # crossattn cond
-          # - is_trainable: False
-          #   input_key: txt
-          #   target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
-          #   params:
-          #     arch: ViT-H-14
-          #     version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
-          #     layer: penultimate
-          # add crossattn cond
+          # textual crossattn cond
           - is_trainable: False
+            emb_key: t_crossattn
+            ucg_rate: 0.1
            input_key: label
            target: sgm.modules.encoders.modules.LabelEncoder
            params:
-              is_add_embedder: True
              max_len: 12
              emb_dim: 2048
              n_heads: 8
              n_trans_layers: 12
-              ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
+              ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
          # concat cond
          - is_trainable: False
            input_key: mask
-            target: sgm.modules.encoders.modules.IdentityEncoder
+            target: sgm.modules.encoders.modules.SpatialRescaler
+            params:
+              in_channels: 1
+              multiplier: 0.125
          - is_trainable: False
            input_key: masked
            target: sgm.modules.encoders.modules.LatentEncoder
@@ -95,6 +88,7 @@ model:
     first_stage_config:
       target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
       params:
+        ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
         embed_dim: 4
         monitor: val/rec_loss
         ddconfig:
@@ -117,9 +111,9 @@ model:
       params:
         seq_len: 12
         kernel_size: 3
-        gaussian_sigma: 0.5
+        gaussian_sigma: 1.0
         min_attn_size: 16
-        lambda_local_loss: 0.02
+        lambda_local_loss: 0.01
         lambda_ocr_loss: 0.001
         ocr_enabled: False

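Every target/params pair in this config follows the usual sgm instantiation convention. A sketch of that pattern, mirroring (but not copied from) the instantiate_from_config helper in sgm.util:

import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # resolve "pkg.module.Class" and construct it with the params mapping
    module, cls = config["target"].rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)(**config.get("params", dict()))

model_cfg = OmegaConf.load("./configs/test/textdesign_sd_2.yaml")
# would build sgm.models.diffusion.DiffusionEngine with the params above
# (commented out here since it loads the checkpoints referenced in the config)
# model = instantiate_from_config(model_cfg.model)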
sgm/models/diffusion.py CHANGED
@@ -5,6 +5,7 @@ import pytorch_lightning as pl
 import torch
 from omegaconf import ListConfig, OmegaConf
 from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
 
 from ..modules import UNCONDITIONAL_CONFIG
 from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
sgm/modules/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .encoders.modules import GeneralConditioner, DualConditioner
+from .encoders.modules import GeneralConditioner
 
 UNCONDITIONAL_CONFIG = {
     "target": "sgm.modules.GeneralConditioner",
sgm/modules/attention.py CHANGED
@@ -5,53 +5,15 @@ from typing import Any, Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from packaging import version
 from torch import nn, einsum
 
-
-if version.parse(torch.__version__) >= version.parse("2.0.0"):
-    SDP_IS_AVAILABLE = True
-    from torch.backends.cuda import SDPBackend, sdp_kernel
-
-    BACKEND_MAP = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
-    }
-else:
-    from contextlib import nullcontext
-
-    SDP_IS_AVAILABLE = False
-    sdp_kernel = nullcontext
-    BACKEND_MAP = {}
-    print(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
-        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
-    )
-
 try:
     import xformers
     import xformers.ops
-
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
-
-from .diffusionmodules.util import checkpoint
+    print("No module 'xformers'.")
 
 
 def exists(val):
@@ -146,51 +108,6 @@ class LinearAttention(nn.Module):
         return self.to_out(out)
 
 
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
-
-
 class CrossAttention(nn.Module):
     def __init__(
         self,
@@ -198,8 +115,7 @@ class CrossAttention(nn.Module):
         context_dim=None,
         heads=8,
         dim_head=64,
-        dropout=0.0,
-        backend=None,
+        dropout=0.0
     ):
         super().__init__()
         inner_dim = dim_head * heads
@@ -212,60 +128,38 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = zero_module(nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        ))
-        self.backend = backend
+        self.to_out = zero_module(
+            nn.Sequential(
+                nn.Linear(inner_dim, query_dim),
+                nn.Dropout(dropout)
+            )
+        )
 
         self.attn_map_cache = None
 
     def forward(
         self,
         x,
-        context=None,
-        mask=None,
-        additional_tokens=None,
-        n_times_crossframe_attn_in_self=0,
+        context=None
     ):
         h = self.heads
 
-        if additional_tokens is not None:
-            # get the number of masked tokens at the beginning of the output sequence
-            n_tokens_to_mask = additional_tokens.shape[1]
-            # add additional token
-            x = torch.cat([additional_tokens, x], dim=1)
-
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
 
-        if n_times_crossframe_attn_in_self:
-            # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
-            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
-            n_cp = x.shape[0] // n_times_crossframe_attn_in_self
-            k = repeat(
-                k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
-            )
-            v = repeat(
-                v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
-            )
-
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
 
         ## old
-
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         del q, k
 
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
         # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
+        if sim.shape[-1] > 1:
+            sim = sim.softmax(dim=-1) # softmax on token dim
+        else:
+            sim = sim.sigmoid() # sigmoid on pixel dim
 
         # save attn_map
         if self.attn_map_cache is not None:
@@ -276,20 +170,7 @@ class CrossAttention(nn.Module):
 
         out = einsum('b i j, b j d -> b i d', sim, v)
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
-
-        ## new
-        # with sdp_kernel(**BACKEND_MAP[self.backend]):
-        #     # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-        #     out = F.scaled_dot_product_attention(
-        #         q, k, v, attn_mask=mask
-        #     )  # scale is dim_head ** -0.5 per default
-
-        #     del q, k, v
-        #     out = rearrange(out, "b h n d -> b n (h d)", h=h)
-
-        if additional_tokens is not None:
-            # remove additional token
-            out = out[:, n_tokens_to_mask:]
+
         return self.to_out(out)
 
 
@@ -382,10 +263,6 @@ class MemoryEfficientCrossAttention(nn.Module):
 
 
 class BasicTransformerBlock(nn.Module):
-    ATTENTION_MODES = {
-        "softmax": CrossAttention,  # vanilla attention
-        "softmax-xformers": MemoryEfficientCrossAttention,  # ampere
-    }
 
     def __init__(
         self,
@@ -393,169 +270,78 @@
         n_heads,
         d_head,
         dropout=0.0,
-        context_dim=None,
-        add_context_dim=None,
-        gated_ff=True,
-        checkpoint=True,
-        disable_self_attn=False,
-        attn_mode="softmax",
-        sdp_backend=None,
+        t_context_dim=None,
+        v_context_dim=None,
+        gated_ff=True
     ):
         super().__init__()
-        assert attn_mode in self.ATTENTION_MODES
-        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            print(
-                f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
-                f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
-            )
-            attn_mode = "softmax"
-        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            print(
-                "We do not support vanilla attention anymore, as it is too expensive. Sorry."
-            )
-            if not XFORMERS_IS_AVAILABLE:
-                assert (
-                    False
-                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-            else:
-                print("Falling back to xformers efficient attention.")
-                attn_mode = "softmax-xformers"
-        attn_cls = self.ATTENTION_MODES[attn_mode]
-        if version.parse(torch.__version__) >= version.parse("2.0.0"):
-            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
-        else:
-            assert sdp_backend is None
-        self.disable_self_attn = disable_self_attn
-        self.attn1 = attn_cls(
-            query_dim=dim,
-            heads=n_heads,
-            dim_head=d_head,
-            dropout=dropout,
-            context_dim=context_dim if self.disable_self_attn else None,
-            backend=sdp_backend,
-        )  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        if context_dim is not None and context_dim > 0:
-            self.attn2 = attn_cls(
-                query_dim=dim,
-                context_dim=context_dim,
-                heads=n_heads,
-                dim_head=d_head,
-                dropout=dropout,
-                backend=sdp_backend,
-            )  # is self-attn if context is none
-        if add_context_dim is not None and add_context_dim > 0:
-            self.add_attn = attn_cls(
-                query_dim=dim,
-                context_dim=add_context_dim,
-                heads=n_heads,
-                dim_head=d_head,
-                dropout=dropout,
-                backend=sdp_backend,
-            )  # is self-attn if context is none
-            self.add_norm = nn.LayerNorm(dim)
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
-
-    def forward(
-        self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
-    ):
-        kwargs = {"x": x}
-
-        if context is not None:
-            kwargs.update({"context": context})
-
-        if additional_tokens is not None:
-            kwargs.update({"additional_tokens": additional_tokens})
-
-        if n_times_crossframe_attn_in_self:
-            kwargs.update(
-                {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
-            )
-
-        return checkpoint(
-            self._forward, (x, context, add_context), self.parameters(), self.checkpoint
-        )
-
-    def _forward(
-        self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
-    ):
+
+        # self-attention
+        self.attn1 = MemoryEfficientCrossAttention(
+            query_dim=dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+            context_dim=None
+        )
+
+        # textual cross-attention
+        if t_context_dim is not None and t_context_dim > 0:
+            self.t_attn = CrossAttention(
+                query_dim=dim,
+                context_dim=t_context_dim,
+                heads=n_heads,
+                dim_head=d_head,
+                dropout=dropout
+            )
+            self.t_norm = nn.LayerNorm(dim)
+
+        # visual cross-attention
+        if v_context_dim is not None and v_context_dim > 0:
+            self.v_attn = CrossAttention(
+                query_dim=dim,
+                context_dim=v_context_dim,
+                heads=n_heads,
+                dim_head=d_head,
+                dropout=dropout
+            )
+            self.v_norm = nn.LayerNorm(dim)
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+
+    def forward(self, x, t_context=None, v_context=None):
         x = (
             self.attn1(
                 self.norm1(x),
-                context=context if self.disable_self_attn else None,
-                additional_tokens=additional_tokens,
-                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
-                if not self.disable_self_attn
-                else 0,
+                context=None
             )
             + x
         )
-        if hasattr(self, "attn2"):
+        if hasattr(self, "t_attn"):
             x = (
-                self.attn2(
-                    self.norm2(x), context=context, additional_tokens=additional_tokens
+                self.t_attn(
+                    self.t_norm(x),
+                    context=t_context
                 )
                 + x
             )
-        if hasattr(self, "add_attn"):
+        if hasattr(self, "v_attn"):
            x = (
-                self.add_attn(
-                    self.add_norm(x), context=add_context, additional_tokens=additional_tokens
+                self.v_attn(
+                    self.v_norm(x),
+                    context=v_context
                )
                + x
            )
-        x = self.ff(self.norm3(x)) + x
-        return x
-
-
-class BasicTransformerSingleLayerBlock(nn.Module):
-    ATTENTION_MODES = {
-        "softmax": CrossAttention,  # vanilla attention
-        "softmax-xformers": MemoryEfficientCrossAttention  # on the A100s not quite as fast as the above version
-        # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
-    }
-
-    def __init__(
-        self,
-        dim,
-        n_heads,
-        d_head,
-        dropout=0.0,
-        context_dim=None,
-        gated_ff=True,
-        checkpoint=True,
-        attn_mode="softmax",
-    ):
-        super().__init__()
-        assert attn_mode in self.ATTENTION_MODES
-        attn_cls = self.ATTENTION_MODES[attn_mode]
-        self.attn1 = attn_cls(
-            query_dim=dim,
-            heads=n_heads,
-            dim_head=d_head,
-            dropout=dropout,
-            context_dim=context_dim,
-        )
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
-
-    def forward(self, x, context=None):
-        return checkpoint(
-            self._forward, (x, context), self.parameters(), self.checkpoint
-        )
-
-    def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x), context=context) + x
-        x = self.ff(self.norm2(x)) + x
+
+        x = self.ff(self.norm3(x)) + x
+
         return x
 
 
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data.
     First, project the input (aka embedding)
@@ -572,36 +358,12 @@ class SpatialTransformer(nn.Module):
         d_head,
         depth=1,
         dropout=0.0,
-        context_dim=None,
-        add_context_dim=None,
-        disable_self_attn=False,
-        use_linear=False,
-        attn_type="softmax",
-        use_checkpoint=True,
-        # sdp_backend=SDPBackend.FLASH_ATTENTION
-        sdp_backend=None,
+        t_context_dim=None,
+        v_context_dim=None,
+        use_linear=False
     ):
         super().__init__()
-        # print(
-        #     f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
-        # )
-        from omegaconf import ListConfig
-
-        if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
-            context_dim = [context_dim]
-        if exists(context_dim) and isinstance(context_dim, list):
-            if depth != len(context_dim):
-                # print(
-                #     f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
-                #     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
-                # )
-                # depth does not match context dims.
-                assert all(
-                    map(lambda x: x == context_dim[0], context_dim)
-                ), "need homogenous context_dim to match depth automatically"
-                context_dim = depth * [context_dim[0]]
-        elif context_dim is None:
-            context_dim = [None] * depth
+
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels)
@@ -619,12 +381,8 @@ class SpatialTransformer(nn.Module):
                 n_heads,
                 d_head,
                 dropout=dropout,
-                context_dim=context_dim[d],
-                add_context_dim=add_context_dim,
-                disable_self_attn=disable_self_attn,
-                attn_mode=attn_type,
-                checkpoint=use_checkpoint,
-                sdp_backend=sdp_backend,
+                t_context_dim=t_context_dim,
+                v_context_dim=v_context_dim
             )
             for d in range(depth)
         ]
@@ -634,14 +392,11 @@ class SpatialTransformer(nn.Module):
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
-            # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear
 
-    def forward(self, x, context=None, add_context=None):
-        # note: if no context is given, cross-attention defaults to self-attention
-        if not isinstance(context, list):
-            context = [context]
+    def forward(self, x, t_context=None, v_context=None):
+
         b, c, h, w = x.shape
         x_in = x
         x = self.norm(x)
@@ -651,326 +406,11 @@ class SpatialTransformer(nn.Module):
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
-            if i > 0 and len(context) == 1:
-                i = 0  # use same context for each block
-            x = block(x, context=context[i], add_context=add_context)
+            x = block(x, t_context=t_context, v_context=v_context)
         if self.use_linear:
             x = self.proj_out(x)
         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
-        return x + x_in
-
-
-def benchmark_attn():
-    # Lets define a helpful benchmarking function:
-    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    import torch.nn.functional as F
-    import torch.utils.benchmark as benchmark
-
-    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
-        t0 = benchmark.Timer(
-            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-        )
-        return t0.blocked_autorange().mean * 1e6
-
-    # Lets define the hyper-parameters of our input
-    batch_size = 32
-    max_sequence_len = 1024
-    num_heads = 32
-    embed_dimension = 32
-
-    dtype = torch.float16
-
-    query = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-    key = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-    value = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-
-    print(f"q/k/v shape:", query.shape, key.shape, value.shape)
-
-    # Lets explore the speed of each of the 3 implementations
-    from torch.backends.cuda import SDPBackend, sdp_kernel
-
-    # Helpful arguments mapper
-    backend_map = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-    }
-
-    from torch.profiler import ProfilerActivity, profile, record_function
-
-    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
-
-    print(
-        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-    )
-    with profile(
-        activities=activities, record_shapes=False, profile_memory=True
-    ) as prof:
-        with record_function("Default detailed stats"):
-            for _ in range(25):
-                o = F.scaled_dot_product_attention(query, key, value)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    print(
-        f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-    )
-    with sdp_kernel(**backend_map[SDPBackend.MATH]):
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("Math implmentation stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
-        try:
-            print(
-                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-            )
-        except RuntimeError:
-            print("FlashAttention is not supported. See warnings for reasons.")
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("FlashAttention stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-        try:
-            print(
-                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-            )
-        except RuntimeError:
-            print("EfficientAttention is not supported. See warnings for reasons.")
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("EfficientAttention stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-
-def run_model(model, x, context):
-    return model(x, context)
-
-
-def benchmark_transformer_blocks():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    import torch.utils.benchmark as benchmark
-
-    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
-        t0 = benchmark.Timer(
-            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-        )
-        return t0.blocked_autorange().mean * 1e6
-
-    checkpoint = True
-    compile = False
-
-    batch_size = 32
-    h, w = 64, 64
-    context_len = 77
-    embed_dimension = 1024
-    context_dim = 1024
-    d_head = 64
-
-    transformer_depth = 4
-
-    n_heads = embed_dimension // d_head
-
-    dtype = torch.float16
-
-    model_native = SpatialTransformer(
-        embed_dimension,
-        n_heads,
-        d_head,
-        context_dim=context_dim,
-        use_linear=True,
-        use_checkpoint=checkpoint,
-        attn_type="softmax",
-        depth=transformer_depth,
-        sdp_backend=SDPBackend.FLASH_ATTENTION,
-    ).to(device)
-    model_efficient_attn = SpatialTransformer(
-        embed_dimension,
-        n_heads,
-        d_head,
-        context_dim=context_dim,
-        use_linear=True,
-        depth=transformer_depth,
-        use_checkpoint=checkpoint,
-        attn_type="softmax-xformers",
-    ).to(device)
-    if not checkpoint and compile:
-        print("compiling models")
-        model_native = torch.compile(model_native)
-        model_efficient_attn = torch.compile(model_efficient_attn)
-
-    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
-    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
-
-    from torch.profiler import ProfilerActivity, profile, record_function
-
-    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
-
-    with torch.autocast("cuda"):
-        print(
-            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
-        )
-        print(
-            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
-        )
-
-        print(75 * "+")
-        print("NATIVE")
-        print(75 * "+")
-        torch.cuda.reset_peak_memory_stats()
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("NativeAttention stats"):
-                for _ in range(25):
-                    model_native(x, c)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
-
-        print(75 * "+")
-        print("Xformers")
-        print(75 * "+")
-        torch.cuda.reset_peak_memory_stats()
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("xformers stats"):
-                for _ in range(25):
-                    model_efficient_attn(x, c)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
-
-
-def test01():
-    # conv1x1 vs linear
-    from ..util import count_params
-
-    conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
-    print(count_params(conv))
-    linear = torch.nn.Linear(3, 32).cuda()
-    print(count_params(linear))
-
-    print(conv.weight.shape)
-
-    # use same initialization
-    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
-    linear.bias = torch.nn.Parameter(conv.bias)
-
-    print(linear.weight.shape)
-
-    x = torch.randn(11, 3, 64, 64).cuda()
-
-    xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
-    print(xr.shape)
-    out_linear = linear(xr)
-    print(out_linear.mean(), out_linear.shape)
-
-    out_conv = conv(x)
-    print(out_conv.mean(), out_conv.shape)
-    print("done with test01.\n")
-
-
-def test02():
-    # try cosine flash attention
-    import time
-
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.backends.cudnn.benchmark = True
-    print("testing cosine flash attention...")
-    DIM = 1024
-    SEQLEN = 4096
-    BS = 16
-
-    print(" softmax (vanilla) first...")
-    model = BasicTransformerBlock(
-        dim=DIM,
-        n_heads=16,
-        d_head=64,
-        dropout=0.0,
-        context_dim=None,
-        attn_mode="softmax",
-    ).cuda()
-    try:
-        x = torch.randn(BS, SEQLEN, DIM).cuda()
-        tic = time.time()
-        y = model(x)
-        toc = time.time()
-        print(y.shape, toc - tic)
-    except RuntimeError as e:
-        # likely oom
-        print(str(e))
-
-    print("\n now flash-cosine...")
-    model = BasicTransformerBlock(
-        dim=DIM,
-        n_heads=16,
-        d_head=64,
-        dropout=0.0,
-        context_dim=None,
-        attn_mode="flash-cosine",
-    ).cuda()
-    x = torch.randn(BS, SEQLEN, DIM).cuda()
-    tic = time.time()
-    y = model(x)
-    toc = time.time()
-    print(y.shape, toc - tic)
-    print("done with test02.\n")
-
-
-if __name__ == "__main__":
-    # test01()
-    # test02()
-    # test03()
-
-    # benchmark_attn()
-    benchmark_transformer_blocks()
-
-    print("done.")
+
+        return x + x_in
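
A toy usage sketch of the refactored block, using only names visible in the diff above (attn1 is MemoryEfficientCrossAttention, so xformers must be installed for this to actually run):

import torch
from sgm.modules.attention import BasicTransformerBlock

blk = BasicTransformerBlock(dim=320, n_heads=5, d_head=64, t_context_dim=2048)
x = torch.randn(1, 64 * 64, 320)  # flattened spatial tokens: b, h*w, c
t_ctx = torch.randn(1, 12, 2048)  # LabelEncoder tokens: max_len=12, emb_dim=2048
out = blk(x, t_context=t_ctx)     # self-attn -> t_attn -> feed-forward, shape (1, 4096, 320)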
sgm/modules/diffusionmodules/__init__.py CHANGED
@@ -2,6 +2,6 @@ from .denoiser import Denoiser
 from .discretizer import Discretization
 from .loss import StandardDiffusionLoss
 from .model import Model, Encoder, Decoder
-from .openaimodel import UNetModel
+from .openaimodel import UnifiedUNetModel
 from .sampling import BaseDiffusionSampler
 from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/guiders.py CHANGED
@@ -32,7 +32,7 @@ class VanillaCFG:
         c_out = dict()
 
         for k in c:
-            if k in ["vector", "crossattn", "add_crossattn", "concat"]:
+            if k in ["vector", "t_crossattn", "v_crossattn", "concat"]:
                 c_out[k] = torch.cat((uc[k], c[k]), 0)
             else:
                 assert c[k] == uc[k]
@@ -40,34 +40,6 @@ class VanillaCFG:
         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
 
 
-class DualCFG:
-
-    def __init__(self, scale):
-        self.scale = scale
-        self.dyn_thresh = instantiate_from_config(
-            {
-                "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
-            },
-        )
-
-    def __call__(self, x, sigma):
-        x_u_1, x_u_2, x_c = x.chunk(3)
-        x_pred = self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)
-        return x_pred
-
-    def prepare_inputs(self, x, s, c, uc_1, uc_2):
-        c_out = dict()
-
-        for k in c:
-            if k in ["vector", "crossattn", "concat", "add_crossattn"]:
-                c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0)
-            else:
-                assert c[k] == uc_1[k]
-                c_out[k] = c[k]
-        return torch.cat([x] * 3), torch.cat([s] * 3), c_out
-
-
-
 class IdentityGuider:
     def __call__(self, x, sigma):
         return x
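
prepare_inputs stacks the unconditional and conditional batches so the denoiser processes both in one pass; the guided combine that pairs with it is the standard classifier-free guidance formula. A minimal sketch of that combine step (in sgm, VanillaCFG.__call__ delegates to a dynamic-thresholding helper, which reduces to the formula below in the default case):

import torch

def cfg_combine(denoised: torch.Tensor, scale: float) -> torch.Tensor:
    x_u, x_c = denoised.chunk(2)      # halves match the torch.cat order above
    return x_u + scale * (x_c - x_u)  # classifier-free guidance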
sgm/modules/diffusionmodules/loss.py CHANGED
@@ -4,7 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from omegaconf import ListConfig
-# from taming.modules.losses.lpips import LPIPS
 from torchvision.utils import save_image
 from ...util import append_dims, instantiate_from_config
 
@@ -19,16 +18,13 @@
     ):
         super().__init__()
 
-        assert type in ["l2", "l1", "lpips"]
+        assert type in ["l2", "l1"]
 
         self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
 
         self.type = type
         self.offset_noise_level = offset_noise_level
 
-        # if type == "lpips":
-        #     self.lpips = LPIPS().eval()
-
         if not batch2model_keys:
             batch2model_keys = []
 
@@ -70,9 +66,6 @@
             return torch.mean(
                 (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
             )
-        elif self.type == "lpips":
-            loss = self.lpips(model_output, target).reshape(-1)
-            return loss
 
 
 class FullLoss(StandardDiffusionLoss):
@@ -85,7 +78,9 @@ class FullLoss(StandardDiffusionLoss):
         min_attn_size=16,
         lambda_local_loss=0.0,
         lambda_ocr_loss=0.0,
+        lambda_style_loss=0.0,
         ocr_enabled = False,
+        style_enabled = False,
         predictor_config = None,
         *args, **kwarg
     ):
@@ -98,7 +93,9 @@ class FullLoss(StandardDiffusionLoss):
         self.min_attn_size = min_attn_size
         self.lambda_local_loss = lambda_local_loss
         self.lambda_ocr_loss = lambda_ocr_loss
+        self.lambda_style_loss = lambda_style_loss
 
+        self.style_enabled = style_enabled
         self.ocr_enabled = ocr_enabled
         if ocr_enabled:
             self.predictor = instantiate_from_config(predictor_config)
@@ -155,9 +152,15 @@ class FullLoss(StandardDiffusionLoss):
             ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler)
             ocr_loss = ocr_loss.mean()
 
+        if self.style_enabled:
+            style_loss = self.get_style_local_loss(network.diffusion_model.attn_map_cache, batch["mask"])
+            style_loss = style_loss.mean()
+
         loss = diff_loss + self.lambda_local_loss * local_loss
         if self.ocr_enabled:
             loss += self.lambda_ocr_loss * ocr_loss
+        if self.style_enabled:
+            loss += self.lambda_style_loss * style_loss
 
         loss_dict = {
             "loss/diff_loss": diff_loss,
@@ -167,6 +170,8 @@ class FullLoss(StandardDiffusionLoss):
 
         if self.ocr_enabled:
             loss_dict["loss/ocr_loss"] = ocr_loss
+        if self.style_enabled:
+            loss_dict["loss/style_loss"] = style_loss
 
         return loss, loss_dict
 
@@ -191,6 +196,9 @@ class FullLoss(StandardDiffusionLoss):
 
         for item in attn_map_cache:
 
+            name = item["name"]
+            if not name.endswith("t_attn"): continue
+
            heads = item["heads"]
            size = item["size"]
            attn_map = item["attn_map"]
@@ -233,6 +241,9 @@ class FullLoss(StandardDiffusionLoss):
 
         for item in attn_map_cache:
 
+            name = item["name"]
+            if not name.endswith("t_attn"): continue
+
            heads = item["heads"]
            size = item["size"]
            attn_map = item["attn_map"]
@@ -241,7 +252,7 @@ class FullLoss(StandardDiffusionLoss):
 
            seg_l = seg_mask.shape[1]
 
-            bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
+            bh, n, l = attn_map.shape # bh: batch size * heads / n: pixel length(h*w) / l: token length
            attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
 
            assert seg_l <= l
@@ -272,4 +283,43 @@ class FullLoss(StandardDiffusionLoss):
 
         loss = loss / count
 
+        return loss
+
+    def get_style_local_loss(self, attn_map_cache, mask):
+
+        loss = 0
+        count = 0
+
+        for item in attn_map_cache:
+
+            name = item["name"]
+            if not name.endswith("v_attn"): continue
+
+            heads = item["heads"]
+            size = item["size"]
+            attn_map = item["attn_map"]
+
+            if size < self.min_attn_size: continue
+
+            bh, n, l = attn_map.shape # bh: batch size * heads / n: pixel length(h*w) / l: token length
+            attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
+            attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
+            attn_map = attn_map.mean(dim = 1) # b, l, n
+
+            mask_map = F.interpolate(mask, (size, size))
+            mask_map = mask_map.reshape((-1, l, n)) # b, l, n
+            n_mask_map = 1 - mask_map
+
+            p_loss = (mask_map * attn_map).sum(dim = -1) / (mask_map.sum(dim = -1) + 1e-5) # b, l
+            n_loss = (n_mask_map * attn_map).sum(dim = -1) / (n_mask_map.sum(dim = -1) + 1e-5) # b, l
+
+            p_loss = p_loss.mean(dim = -1)
+            n_loss = n_loss.mean(dim = -1)
+
+            f_loss = n_loss - p_loss # b,
+            loss += f_loss
+            count += 1
+
+        loss = loss / count
+
         return loss
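
A toy illustration of the masked attention statistics behind get_style_local_loss: p_loss is the mean attention mass inside the mask, n_loss the mean outside it, and minimizing n_loss - p_loss pushes each token's visual attention onto the masked region (shapes are illustrative, not taken from training):

import torch

attn = torch.rand(2, 12, 256)                  # b, l, n: tokens attending over pixels
mask = (torch.rand(2, 12, 256) > 0.5).float()  # toy stand-in for the resized mask_map
p = (mask * attn).sum(-1) / (mask.sum(-1) + 1e-5)              # inside-mask mean, b, l
n = ((1 - mask) * attn).sum(-1) / ((1 - mask).sum(-1) + 1e-5)  # outside-mask mean
f_loss = (n - p).mean()  # negative when attention concentrates inside the mask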
sgm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -1,6 +1,4 @@
1
- import math
2
  from abc import abstractmethod
3
- from functools import partial
4
  from typing import Iterable
5
 
6
  import numpy as np
@@ -12,7 +10,6 @@ from einops import rearrange
12
  from ...modules.attention import SpatialTransformer
13
  from ...modules.diffusionmodules.util import (
14
  avg_pool_nd,
15
- checkpoint,
16
  conv_nd,
17
  linear,
18
  normalization,
@@ -22,47 +19,14 @@ from ...modules.diffusionmodules.util import (
22
  from ...util import default, exists
23
 
24
 
25
- # dummy replace
26
- def convert_module_to_f16(x):
27
- pass
28
-
29
-
30
- def convert_module_to_f32(x):
31
- pass
32
-
33
-
34
- ## go
35
- class AttentionPool2d(nn.Module):
36
- """
37
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
38
- """
39
-
40
- def __init__(
41
- self,
42
- spacial_dim: int,
43
- embed_dim: int,
44
- num_heads_channels: int,
45
- output_dim: int = None,
46
- ):
47
  super().__init__()
48
- self.positional_embedding = nn.Parameter(
49
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
50
- )
51
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
52
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
53
- self.num_heads = embed_dim // num_heads_channels
54
- self.attention = QKVAttention(self.num_heads)
55
-
56
- def forward(self, x):
57
- b, c, *_spatial = x.shape
58
- x = x.reshape(b, c, -1) # NC(HW)
59
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
60
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
61
- x = self.qkv_proj(x)
62
- x = self.attention(x)
63
- x = self.c_proj(x)
64
- return x[:, :, 0]
65
 
 
 
 
66
 
67
  class TimestepBlock(nn.Module):
68
  """
@@ -86,19 +50,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
86
  self,
87
  x,
88
  emb,
89
- context=None,
90
- add_context=None,
91
- skip_time_mix=False,
92
- time_context=None,
93
- num_video_frames=None,
94
- time_context_cat=None,
95
- use_crossframe_attention_in_spatial_layers=False,
96
  ):
97
  for layer in self:
98
  if isinstance(layer, TimestepBlock):
99
  x = layer(x, emb)
100
  elif isinstance(layer, SpatialTransformer):
101
- x = layer(x, context, add_context)
102
  else:
103
  x = layer(x)
104
  return x
@@ -143,22 +102,6 @@ class Upsample(nn.Module):
143
  return x
144
 
145
 
146
- class TransposedUpsample(nn.Module):
147
- "Learned 2x upsampling without padding"
148
-
149
- def __init__(self, channels, out_channels=None, ks=5):
150
- super().__init__()
151
- self.channels = channels
152
- self.out_channels = out_channels or channels
153
-
154
- self.up = nn.ConvTranspose2d(
155
- self.channels, self.out_channels, kernel_size=ks, stride=2
156
- )
157
-
158
- def forward(self, x):
159
- return self.up(x)
160
-
161
-
162
  class Downsample(nn.Module):
163
  """
164
  A downsampling layer with an optional convolution.
@@ -206,17 +149,6 @@ class Downsample(nn.Module):
206
  class ResBlock(TimestepBlock):
207
  """
208
  A residual block that can optionally change the number of channels.
209
- :param channels: the number of input channels.
210
- :param emb_channels: the number of timestep embedding channels.
211
- :param dropout: the rate of dropout.
212
- :param out_channels: if specified, the number of out channels.
213
- :param use_conv: if True and out_channels is specified, use a spatial
214
- convolution instead of a smaller 1x1 convolution to change the
215
- channels in the skip connection.
216
- :param dims: determines if the signal is 1D, 2D, or 3D.
217
- :param use_checkpoint: if True, use gradient checkpointing on this module.
218
- :param up: if True, use this block for upsampling.
219
- :param down: if True, use this block for downsampling.
220
  """
221
 
222
  def __init__(
@@ -228,12 +160,11 @@ class ResBlock(TimestepBlock):
228
  use_conv=False,
229
  use_scale_shift_norm=False,
230
  dims=2,
231
- use_checkpoint=False,
232
  up=False,
233
  down=False,
234
  kernel_size=3,
235
  exchange_temb_dims=False,
236
- skip_t_emb=False,
237
  ):
238
  super().__init__()
239
  self.channels = channels
@@ -241,7 +172,6 @@ class ResBlock(TimestepBlock):
241
  self.dropout = dropout
242
  self.out_channels = out_channels or channels
243
  self.use_conv = use_conv
244
- self.use_checkpoint = use_checkpoint
245
  self.use_scale_shift_norm = use_scale_shift_norm
246
  self.exchange_temb_dims = exchange_temb_dims
247
 
@@ -310,17 +240,6 @@ class ResBlock(TimestepBlock):
310
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
311
 
312
  def forward(self, x, emb):
313
- """
314
- Apply the block to a Tensor, conditioned on a timestep embedding.
315
- :param x: an [N x C x ...] Tensor of features.
316
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
317
- :return: an [N x C x ...] Tensor of outputs.
318
- """
319
- return checkpoint(
320
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
321
- )
322
-
323
- def _forward(self, x, emb):
324
  if self.updown:
325
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
326
  h = in_rest(x)
@@ -348,233 +267,42 @@ class ResBlock(TimestepBlock):
348
  h = self.out_layers(h)
349
  return self.skip_connection(x) + h
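
For reference, a minimal sketch of the scale-shift (FiLM-style) conditioning that the use_scale_shift_norm flag selects inside ResBlock. All sizes are illustrative and nothing here is part of this commit:

import torch as th
import torch.nn as nn

channels, emb_dim = 8, 32
norm = nn.GroupNorm(4, channels)
emb_proj = nn.Linear(emb_dim, 2 * channels)  # stands in for the emb projection in out_layers

h = th.randn(2, channels, 16, 16)
emb = th.randn(2, emb_dim)
# project the timestep embedding to (scale, shift) and apply after normalization
scale, shift = emb_proj(emb)[:, :, None, None].chunk(2, dim=1)
h = norm(h) * (1 + scale) + shift
print(h.shape)  # torch.Size([2, 8, 16, 16])
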
350
 
351
-
352
- class AttentionBlock(nn.Module):
353
- """
354
- An attention block that allows spatial positions to attend to each other.
355
- Originally ported from here, but adapted to the N-d case.
356
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
357
- """
358
-
359
- def __init__(
360
- self,
361
- channels,
362
- num_heads=1,
363
- num_head_channels=-1,
364
- use_checkpoint=False,
365
- use_new_attention_order=False,
366
- ):
367
- super().__init__()
368
- self.channels = channels
369
- if num_head_channels == -1:
370
- self.num_heads = num_heads
371
- else:
372
- assert (
373
- channels % num_head_channels == 0
374
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
375
- self.num_heads = channels // num_head_channels
376
- self.use_checkpoint = use_checkpoint
377
- self.norm = normalization(channels)
378
- self.qkv = conv_nd(1, channels, channels * 3, 1)
379
- if use_new_attention_order:
380
- # split qkv before split heads
381
- self.attention = QKVAttention(self.num_heads)
382
- else:
383
- # split heads before split qkv
384
- self.attention = QKVAttentionLegacy(self.num_heads)
385
-
386
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
387
-
388
- def forward(self, x, **kwargs):
389
- # TODO add crossframe attention and use mixed checkpoint
390
- return checkpoint(
391
- self._forward, (x,), self.parameters(), True
392
- ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
393
- # return pt_checkpoint(self._forward, x) # pytorch
394
-
395
- def _forward(self, x):
396
- b, c, *spatial = x.shape
397
- x = x.reshape(b, c, -1)
398
- qkv = self.qkv(self.norm(x))
399
- h = self.attention(qkv)
400
- h = self.proj_out(h)
401
- return (x + h).reshape(b, c, *spatial)
402
-
403
-
404
- def count_flops_attn(model, _x, y):
405
- """
406
- A counter for the `thop` package to count the operations in an
407
- attention operation.
408
- Meant to be used like:
409
- macs, params = thop.profile(
410
- model,
411
- inputs=(inputs, timestamps),
412
- custom_ops={QKVAttention: QKVAttention.count_flops},
413
- )
414
- """
415
- b, c, *spatial = y[0].shape
416
- num_spatial = int(np.prod(spatial))
417
- # We perform two matmuls with the same number of ops.
418
- # The first computes the weight matrix, the second computes
419
- # the combination of the value vectors.
420
- matmul_ops = 2 * b * (num_spatial**2) * c
421
- model.total_ops += th.DoubleTensor([matmul_ops])
422
-
423
-
424
- class QKVAttentionLegacy(nn.Module):
425
- """
426
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
427
- """
428
-
429
- def __init__(self, n_heads):
430
- super().__init__()
431
- self.n_heads = n_heads
432
-
433
- def forward(self, qkv):
434
- """
435
- Apply QKV attention.
436
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
437
- :return: an [N x (H * C) x T] tensor after attention.
438
- """
439
- bs, width, length = qkv.shape
440
- assert width % (3 * self.n_heads) == 0
441
- ch = width // (3 * self.n_heads)
442
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
443
- scale = 1 / math.sqrt(math.sqrt(ch))
444
- weight = th.einsum(
445
- "bct,bcs->bts", q * scale, k * scale
446
- ) # More stable with f16 than dividing afterwards
447
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
448
- a = th.einsum("bts,bcs->bct", weight, v)
449
- return a.reshape(bs, -1, length)
450
-
451
- @staticmethod
452
- def count_flops(model, _x, y):
453
- return count_flops_attn(model, _x, y)
454
-
455
-
456
- class QKVAttention(nn.Module):
457
- """
458
- A module which performs QKV attention and splits in a different order.
459
- """
460
-
461
- def __init__(self, n_heads):
462
- super().__init__()
463
- self.n_heads = n_heads
464
-
465
- def forward(self, qkv):
466
- """
467
- Apply QKV attention.
468
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
469
- :return: an [N x (H * C) x T] tensor after attention.
470
- """
471
- bs, width, length = qkv.shape
472
- assert width % (3 * self.n_heads) == 0
473
- ch = width // (3 * self.n_heads)
474
- q, k, v = qkv.chunk(3, dim=1)
475
- scale = 1 / math.sqrt(math.sqrt(ch))
476
- weight = th.einsum(
477
- "bct,bcs->bts",
478
- (q * scale).view(bs * self.n_heads, ch, length),
479
- (k * scale).view(bs * self.n_heads, ch, length),
480
- ) # More stable with f16 than dividing afterwards
481
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
482
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
483
- return a.reshape(bs, -1, length)
484
-
485
- @staticmethod
486
- def count_flops(model, _x, y):
487
- return count_flops_attn(model, _x, y)
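
The removed QKVAttention reduces to two einsums plus a float16-friendly scaling trick: multiplying q and k each by 1/sqrt(sqrt(ch)) is equivalent to the usual 1/sqrt(ch) on the logits, but avoids dividing a large intermediate. A standalone sketch with illustrative shapes:

import math
import torch as th

bs, heads, ch, length = 2, 4, 16, 64
q = th.randn(bs * heads, ch, length)
k = th.randn(bs * heads, ch, length)
v = th.randn(bs * heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # attention logits
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
out = th.einsum("bts,bcs->bct", weight, v)  # combine the value vectors
print(out.shape)  # torch.Size([8, 16, 64])
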
488
-
489
-
490
- class Timestep(nn.Module):
491
- def __init__(self, dim):
492
- super().__init__()
493
- self.dim = dim
494
-
495
- def forward(self, t):
496
- return timestep_embedding(t, self.dim)
497
 
498
 
499
- class UNetModel(nn.Module):
500
- """
501
- The full UNet model with attention and timestep embedding.
502
- :param in_channels: channels in the input Tensor.
503
- :param model_channels: base channel count for the model.
504
- :param out_channels: channels in the output Tensor.
505
- :param num_res_blocks: number of residual blocks per downsample.
506
- :param attention_resolutions: a collection of downsample rates at which
507
- attention will take place. May be a set, list, or tuple.
508
- For example, if this contains 4, then at 4x downsampling, attention
509
- will be used.
510
- :param dropout: the dropout probability.
511
- :param channel_mult: channel multiplier for each level of the UNet.
512
- :param conv_resample: if True, use learned convolutions for upsampling and
513
- downsampling.
514
- :param dims: determines if the signal is 1D, 2D, or 3D.
515
- :param num_classes: if specified (as an int), then this model will be
516
- class-conditional with `num_classes` classes.
517
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
518
- :param num_heads: the number of attention heads in each attention layer.
519
- :param num_heads_channels: if specified, ignore num_heads and instead use
520
- a fixed channel width per attention head.
521
- :param num_heads_upsample: works with num_heads to set a different number
522
- of heads for upsampling. Deprecated.
523
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
524
- :param resblock_updown: use residual blocks for up/downsampling.
525
- :param use_new_attention_order: use a different attention pattern for potentially
526
- increased efficiency.
527
- """
528
 
529
  def __init__(
530
  self,
531
  in_channels,
 
532
  model_channels,
533
  out_channels,
534
  num_res_blocks,
535
  attention_resolutions,
536
  dropout=0,
537
  channel_mult=(1, 2, 4, 8),
 
 
538
  conv_resample=True,
539
  dims=2,
540
- num_classes=None,
541
- use_checkpoint=False,
542
- use_fp16=False,
543
  num_heads=-1,
544
  num_head_channels=-1,
545
  num_heads_upsample=-1,
546
  use_scale_shift_norm=False,
547
  resblock_updown=False,
548
- use_new_attention_order=False,
549
- use_spatial_transformer=False, # custom transformer support
550
- transformer_depth=1, # custom transformer support
551
- context_dim=None, # custom transformer support
552
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
553
- legacy=True,
554
- disable_self_attentions=None,
555
  num_attention_blocks=None,
556
- disable_middle_self_attn=False,
557
  use_linear_in_transformer=False,
558
- spatial_transformer_attn_type="softmax",
559
  adm_in_channels=None,
560
- use_fairscale_checkpoint=False,
561
- offload_to_cpu=False,
562
- transformer_depth_middle=None,
563
  ):
564
  super().__init__()
565
- from omegaconf.listconfig import ListConfig
566
-
567
- if use_spatial_transformer:
568
- assert (
569
- context_dim is not None
570
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
571
-
572
- if context_dim is not None:
573
- assert (
574
- use_spatial_transformer
575
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
576
- if type(context_dim) == ListConfig:
577
- context_dim = list(context_dim)
578
 
579
  if num_heads_upsample == -1:
580
  num_heads_upsample = num_heads
@@ -590,106 +318,39 @@ class UNetModel(nn.Module):
590
  ), "Either num_heads or num_head_channels has to be set"
591
 
592
  self.in_channels = in_channels
 
593
  self.model_channels = model_channels
594
  self.out_channels = out_channels
595
- if isinstance(transformer_depth, int):
596
- transformer_depth = len(channel_mult) * [transformer_depth]
597
- elif isinstance(transformer_depth, ListConfig):
598
- transformer_depth = list(transformer_depth)
599
- transformer_depth_middle = default(
600
- transformer_depth_middle, transformer_depth[-1]
601
- )
602
 
603
- if isinstance(num_res_blocks, int):
604
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
605
- else:
606
- if len(num_res_blocks) != len(channel_mult):
607
- raise ValueError(
608
- "provide num_res_blocks either as an int (globally constant) or "
609
- "as a list/tuple (per-level) with the same length as channel_mult"
610
- )
611
- self.num_res_blocks = num_res_blocks
612
- # self.num_res_blocks = num_res_blocks
613
- if disable_self_attentions is not None:
614
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
615
- assert len(disable_self_attentions) == len(channel_mult)
616
- if num_attention_blocks is not None:
617
- assert len(num_attention_blocks) == len(self.num_res_blocks)
618
- assert all(
619
- map(
620
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
621
- range(len(num_attention_blocks)),
622
- )
623
- )
624
- print(
625
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
626
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
627
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
628
- f"attention will still not be set."
629
- ) # todo: convert to warning
630
 
631
  self.attention_resolutions = attention_resolutions
632
  self.dropout = dropout
633
  self.channel_mult = channel_mult
634
  self.conv_resample = conv_resample
635
- self.num_classes = num_classes
636
- self.use_checkpoint = use_checkpoint
637
- if use_fp16:
638
- print("WARNING: use_fp16 was dropped and has no effect anymore.")
639
- # self.dtype = th.float16 if use_fp16 else th.float32
640
  self.num_heads = num_heads
641
  self.num_head_channels = num_head_channels
642
  self.num_heads_upsample = num_heads_upsample
643
- self.predict_codebook_ids = n_embed is not None
644
-
645
- assert use_fairscale_checkpoint != use_checkpoint or not (
646
- use_checkpoint or use_fairscale_checkpoint
647
- )
648
-
649
- self.use_fairscale_checkpoint = False
650
- checkpoint_wrapper_fn = (
651
- partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
652
- if self.use_fairscale_checkpoint
653
- else lambda x: x
654
- )
655
 
656
  time_embed_dim = model_channels * 4
657
- self.time_embed = checkpoint_wrapper_fn(
658
- nn.Sequential(
659
- linear(model_channels, time_embed_dim),
660
- nn.SiLU(),
661
- linear(time_embed_dim, time_embed_dim),
662
- )
663
  )
664
-
665
- if self.num_classes is not None:
666
- if isinstance(self.num_classes, int):
667
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
668
- elif self.num_classes == "continuous":
669
- print("setting up linear c_adm embedding layer")
670
- self.label_emb = nn.Linear(1, time_embed_dim)
671
- elif self.num_classes == "timestep":
672
- self.label_emb = checkpoint_wrapper_fn(
673
- nn.Sequential(
674
- Timestep(model_channels),
675
- nn.Sequential(
676
- linear(model_channels, time_embed_dim),
677
- nn.SiLU(),
678
- linear(time_embed_dim, time_embed_dim),
679
- ),
680
- )
681
- )
682
- elif self.num_classes == "sequential":
683
- assert adm_in_channels is not None
684
- self.label_emb = nn.Sequential(
685
- nn.Sequential(
686
- linear(adm_in_channels, time_embed_dim),
687
- nn.SiLU(),
688
- linear(time_embed_dim, time_embed_dim),
689
- )
690
  )
691
- else:
692
- raise ValueError()
693
 
694
  self.input_blocks = nn.ModuleList(
695
  [
@@ -698,6 +359,26 @@ class UNetModel(nn.Module):
698
  )
699
  ]
700
  )
701
  self._feature_size = model_channels
702
  input_block_chans = [model_channels]
703
  ch = model_channels
@@ -705,16 +386,13 @@ class UNetModel(nn.Module):
705
  for level, mult in enumerate(channel_mult):
706
  for nr in range(self.num_res_blocks[level]):
707
  layers = [
708
- checkpoint_wrapper_fn(
709
- ResBlock(
710
- ch,
711
- time_embed_dim,
712
- dropout,
713
- out_channels=mult * model_channels,
714
- dims=dims,
715
- use_checkpoint=use_checkpoint,
716
- use_scale_shift_norm=use_scale_shift_norm,
717
- )
718
  )
719
  ]
720
  ch = mult * model_channels
@@ -724,45 +402,19 @@ class UNetModel(nn.Module):
724
  else:
725
  num_heads = ch // num_head_channels
726
  dim_head = num_head_channels
727
- if legacy:
728
- # num_heads = 1
729
- dim_head = (
730
- ch // num_heads
731
- if use_spatial_transformer
732
- else num_head_channels
733
- )
734
- if exists(disable_self_attentions):
735
- disabled_sa = disable_self_attentions[level]
736
- else:
737
- disabled_sa = False
738
-
739
  if (
740
  not exists(num_attention_blocks)
741
  or nr < num_attention_blocks[level]
742
  ):
743
  layers.append(
744
- checkpoint_wrapper_fn(
745
- AttentionBlock(
746
- ch,
747
- use_checkpoint=use_checkpoint,
748
- num_heads=num_heads,
749
- num_head_channels=dim_head,
750
- use_new_attention_order=use_new_attention_order,
751
- )
752
- )
753
- if not use_spatial_transformer
754
- else checkpoint_wrapper_fn(
755
- SpatialTransformer(
756
- ch,
757
- num_heads,
758
- dim_head,
759
- depth=transformer_depth[level],
760
- context_dim=context_dim,
761
- disable_self_attn=disabled_sa,
762
- use_linear=use_linear_in_transformer,
763
- attn_type=spatial_transformer_attn_type,
764
- use_checkpoint=use_checkpoint,
765
- )
766
  )
767
  )
768
  self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -772,17 +424,14 @@ class UNetModel(nn.Module):
772
  out_ch = ch
773
  self.input_blocks.append(
774
  TimestepEmbedSequential(
775
- checkpoint_wrapper_fn(
776
- ResBlock(
777
- ch,
778
- time_embed_dim,
779
- dropout,
780
- out_channels=out_ch,
781
- dims=dims,
782
- use_checkpoint=use_checkpoint,
783
- use_scale_shift_norm=use_scale_shift_norm,
784
- down=True,
785
- )
786
  )
787
  if resblock_updown
788
  else Downsample(
@@ -800,54 +449,33 @@ class UNetModel(nn.Module):
800
  else:
801
  num_heads = ch // num_head_channels
802
  dim_head = num_head_channels
803
- if legacy:
804
- # num_heads = 1
805
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
806
  self.middle_block = TimestepEmbedSequential(
807
- checkpoint_wrapper_fn(
808
- ResBlock(
809
- ch,
810
- time_embed_dim,
811
- dropout,
812
- dims=dims,
813
- use_checkpoint=use_checkpoint,
814
- use_scale_shift_norm=use_scale_shift_norm,
815
- )
816
- ),
817
- checkpoint_wrapper_fn(
818
- AttentionBlock(
819
- ch,
820
- use_checkpoint=use_checkpoint,
821
- num_heads=num_heads,
822
- num_head_channels=dim_head,
823
- use_new_attention_order=use_new_attention_order,
824
- )
825
- )
826
- if not use_spatial_transformer
827
- else checkpoint_wrapper_fn(
828
- SpatialTransformer( # always uses a self-attn
829
- ch,
830
- num_heads,
831
- dim_head,
832
- depth=transformer_depth_middle,
833
- context_dim=context_dim,
834
- disable_self_attn=disable_middle_self_attn,
835
- use_linear=use_linear_in_transformer,
836
- attn_type=spatial_transformer_attn_type,
837
- use_checkpoint=use_checkpoint,
838
- )
839
  ),
840
- checkpoint_wrapper_fn(
841
- ResBlock(
842
- ch,
843
- time_embed_dim,
844
- dropout,
845
- dims=dims,
846
- use_checkpoint=use_checkpoint,
847
- use_scale_shift_norm=use_scale_shift_norm,
848
- )
849
  ),
850
  )
 
851
  self._feature_size += ch
852
 
853
  self.output_blocks = nn.ModuleList([])
@@ -855,16 +483,13 @@ class UNetModel(nn.Module):
855
  for i in range(self.num_res_blocks[level] + 1):
856
  ich = input_block_chans.pop()
857
  layers = [
858
- checkpoint_wrapper_fn(
859
- ResBlock(
860
- ch + ich,
861
- time_embed_dim,
862
- dropout,
863
- out_channels=model_channels * mult,
864
- dims=dims,
865
- use_checkpoint=use_checkpoint,
866
- use_scale_shift_norm=use_scale_shift_norm,
867
- )
868
  )
869
  ]
870
  ch = model_channels * mult
@@ -874,61 +499,32 @@ class UNetModel(nn.Module):
874
  else:
875
  num_heads = ch // num_head_channels
876
  dim_head = num_head_channels
877
- if legacy:
878
- # num_heads = 1
879
- dim_head = (
880
- ch // num_heads
881
- if use_spatial_transformer
882
- else num_head_channels
883
- )
884
- if exists(disable_self_attentions):
885
- disabled_sa = disable_self_attentions[level]
886
- else:
887
- disabled_sa = False
888
-
889
  if (
890
  not exists(num_attention_blocks)
891
  or i < num_attention_blocks[level]
892
  ):
893
  layers.append(
894
- checkpoint_wrapper_fn(
895
- AttentionBlock(
896
- ch,
897
- use_checkpoint=use_checkpoint,
898
- num_heads=num_heads_upsample,
899
- num_head_channels=dim_head,
900
- use_new_attention_order=use_new_attention_order,
901
- )
902
- )
903
- if not use_spatial_transformer
904
- else checkpoint_wrapper_fn(
905
- SpatialTransformer(
906
- ch,
907
- num_heads,
908
- dim_head,
909
- depth=transformer_depth[level],
910
- context_dim=context_dim,
911
- disable_self_attn=disabled_sa,
912
- use_linear=use_linear_in_transformer,
913
- attn_type=spatial_transformer_attn_type,
914
- use_checkpoint=use_checkpoint,
915
- )
916
  )
917
  )
918
  if level and i == self.num_res_blocks[level]:
919
  out_ch = ch
920
  layers.append(
921
- checkpoint_wrapper_fn(
922
- ResBlock(
923
- ch,
924
- time_embed_dim,
925
- dropout,
926
- out_channels=out_ch,
927
- dims=dims,
928
- use_checkpoint=use_checkpoint,
929
- use_scale_shift_norm=use_scale_shift_norm,
930
- up=True,
931
- )
932
  )
933
  if resblock_updown
934
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
@@ -937,1133 +533,92 @@ class UNetModel(nn.Module):
937
  self.output_blocks.append(TimestepEmbedSequential(*layers))
938
  self._feature_size += ch
939
 
940
- self.out = checkpoint_wrapper_fn(
941
- nn.Sequential(
942
- normalization(ch),
943
- nn.SiLU(),
944
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
945
- )
946
  )
947
- if self.predict_codebook_ids:
948
- self.id_predictor = checkpoint_wrapper_fn(
949
- nn.Sequential(
950
- normalization(ch),
951
- conv_nd(dims, model_channels, n_embed, 1),
952
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
953
- )
954
- )
 
 
955
 
956
- def convert_to_fp16(self):
957
- """
958
- Convert the torso of the model to float16.
959
- """
960
- self.input_blocks.apply(convert_module_to_f16)
961
- self.middle_block.apply(convert_module_to_f16)
962
- self.output_blocks.apply(convert_module_to_f16)
963
 
964
- def convert_to_fp32(self):
965
- """
966
- Convert the torso of the model to float32.
967
- """
968
- self.input_blocks.apply(convert_module_to_f32)
969
- self.middle_block.apply(convert_module_to_f32)
970
- self.output_blocks.apply(convert_module_to_f32)
971
 
972
- def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
973
- """
974
- Apply the model to an input batch.
975
- :param x: an [N x C x ...] Tensor of inputs.
976
- :param timesteps: a 1-D batch of timesteps.
977
- :param context: conditioning plugged in via crossattn
978
- :param y: an [N] Tensor of labels, if class-conditional.
979
- :return: an [N x C x ...] Tensor of outputs.
980
- """
981
  assert (y is not None) == (
982
- self.num_classes is not None
983
  ), "must specify y if and only if the model is class-conditional"
984
  hs = []
985
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
986
  emb = self.time_embed(t_emb)
987
 
988
- if self.num_classes is not None:
989
  assert y.shape[0] == x.shape[0]
990
  emb = emb + self.label_emb(y)
991
 
992
- # h = x.type(self.dtype)
993
- h = x
994
- for i, module in enumerate(self.input_blocks):
995
- h = module(h, emb, context)
996
- hs.append(h)
997
- h = self.middle_block(h, emb, context)
998
- for i, module in enumerate(self.output_blocks):
999
- h = th.cat([h, hs.pop()], dim=1)
1000
- h = module(h, emb, context)
1001
- h = h.type(x.dtype)
1002
- if self.predict_codebook_ids:
1003
- assert False, "not supported anymore. what the f*** are you doing?"
1004
- else:
1005
- return self.out(h)
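
The removed forward follows the standard U-Net stack-and-concat control flow: input blocks push activations, output blocks pop and concatenate them as skip connections. A runnable toy sketch of that control flow (ToyBlock merely stands in for the ResBlock/attention stacks and ignores emb; none of it is from this repo):

import torch as th
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1)

    def forward(self, x, emb):  # emb ignored in this toy stand-in
        return self.conv(x)

inputs = nn.ModuleList([ToyBlock(4, 8), ToyBlock(8, 8)])
middle = ToyBlock(8, 8)
outputs = nn.ModuleList([ToyBlock(16, 8), ToyBlock(16, 4)])

x, emb, hs = th.randn(1, 4, 32, 32), None, []
h = x
for m in inputs:          # encoder: remember every activation
    h = m(h, emb)
    hs.append(h)
h = middle(h, emb)
for m in outputs:         # decoder: concat the matching skip, then process
    h = th.cat([h, hs.pop()], dim=1)
    h = m(h, emb)
print(h.shape)  # torch.Size([1, 4, 32, 32])
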
1006
-
1007
-
1008
-
1009
- class UNetModel(nn.Module):
1010
- """
1011
- The full UNet model with attention and timestep embedding.
1012
- :param in_channels: channels in the input Tensor.
1013
- :param model_channels: base channel count for the model.
1014
- :param out_channels: channels in the output Tensor.
1015
- :param num_res_blocks: number of residual blocks per downsample.
1016
- :param attention_resolutions: a collection of downsample rates at which
1017
- attention will take place. May be a set, list, or tuple.
1018
- For example, if this contains 4, then at 4x downsampling, attention
1019
- will be used.
1020
- :param dropout: the dropout probability.
1021
- :param channel_mult: channel multiplier for each level of the UNet.
1022
- :param conv_resample: if True, use learned convolutions for upsampling and
1023
- downsampling.
1024
- :param dims: determines if the signal is 1D, 2D, or 3D.
1025
- :param num_classes: if specified (as an int), then this model will be
1026
- class-conditional with `num_classes` classes.
1027
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
1028
- :param num_heads: the number of attention heads in each attention layer.
1029
- :param num_heads_channels: if specified, ignore num_heads and instead use
1030
- a fixed channel width per attention head.
1031
- :param num_heads_upsample: works with num_heads to set a different number
1032
- of heads for upsampling. Deprecated.
1033
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
1034
- :param resblock_updown: use residual blocks for up/downsampling.
1035
- :param use_new_attention_order: use a different attention pattern for potentially
1036
- increased efficiency.
1037
- """
1038
-
1039
- def __init__(
1040
- self,
1041
- in_channels,
1042
- model_channels,
1043
- out_channels,
1044
- num_res_blocks,
1045
- attention_resolutions,
1046
- dropout=0,
1047
- channel_mult=(1, 2, 4, 8),
1048
- conv_resample=True,
1049
- dims=2,
1050
- num_classes=None,
1051
- use_checkpoint=False,
1052
- use_fp16=False,
1053
- num_heads=-1,
1054
- num_head_channels=-1,
1055
- num_heads_upsample=-1,
1056
- use_scale_shift_norm=False,
1057
- resblock_updown=False,
1058
- use_new_attention_order=False,
1059
- use_spatial_transformer=False, # custom transformer support
1060
- transformer_depth=1, # custom transformer support
1061
- context_dim=None, # custom transformer support
1062
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1063
- legacy=True,
1064
- disable_self_attentions=None,
1065
- num_attention_blocks=None,
1066
- disable_middle_self_attn=False,
1067
- use_linear_in_transformer=False,
1068
- spatial_transformer_attn_type="softmax",
1069
- adm_in_channels=None,
1070
- use_fairscale_checkpoint=False,
1071
- offload_to_cpu=False,
1072
- transformer_depth_middle=None,
1073
- ):
1074
- super().__init__()
1075
- from omegaconf.listconfig import ListConfig
1076
-
1077
- if use_spatial_transformer:
1078
- assert (
1079
- context_dim is not None
1080
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1081
-
1082
- if context_dim is not None:
1083
- assert (
1084
- use_spatial_transformer
1085
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1086
- if type(context_dim) == ListConfig:
1087
- context_dim = list(context_dim)
1088
-
1089
- if num_heads_upsample == -1:
1090
- num_heads_upsample = num_heads
1091
-
1092
- if num_heads == -1:
1093
- assert (
1094
- num_head_channels != -1
1095
- ), "Either num_heads or num_head_channels has to be set"
1096
-
1097
- if num_head_channels == -1:
1098
- assert (
1099
- num_heads != -1
1100
- ), "Either num_heads or num_head_channels has to be set"
1101
-
1102
- self.in_channels = in_channels
1103
- self.model_channels = model_channels
1104
- self.out_channels = out_channels
1105
- if isinstance(transformer_depth, int):
1106
- transformer_depth = len(channel_mult) * [transformer_depth]
1107
- elif isinstance(transformer_depth, ListConfig):
1108
- transformer_depth = list(transformer_depth)
1109
- transformer_depth_middle = default(
1110
- transformer_depth_middle, transformer_depth[-1]
1111
- )
1112
-
1113
- if isinstance(num_res_blocks, int):
1114
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1115
- else:
1116
- if len(num_res_blocks) != len(channel_mult):
1117
- raise ValueError(
1118
- "provide num_res_blocks either as an int (globally constant) or "
1119
- "as a list/tuple (per-level) with the same length as channel_mult"
1120
- )
1121
- self.num_res_blocks = num_res_blocks
1122
- # self.num_res_blocks = num_res_blocks
1123
- if disable_self_attentions is not None:
1124
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1125
- assert len(disable_self_attentions) == len(channel_mult)
1126
- if num_attention_blocks is not None:
1127
- assert len(num_attention_blocks) == len(self.num_res_blocks)
1128
- assert all(
1129
- map(
1130
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1131
- range(len(num_attention_blocks)),
1132
- )
1133
- )
1134
- print(
1135
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1136
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1137
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1138
- f"attention will still not be set."
1139
- ) # todo: convert to warning
1140
-
1141
- self.attention_resolutions = attention_resolutions
1142
- self.dropout = dropout
1143
- self.channel_mult = channel_mult
1144
- self.conv_resample = conv_resample
1145
- self.num_classes = num_classes
1146
- self.use_checkpoint = use_checkpoint
1147
- if use_fp16:
1148
- print("WARNING: use_fp16 was dropped and has no effect anymore.")
1149
- # self.dtype = th.float16 if use_fp16 else th.float32
1150
- self.num_heads = num_heads
1151
- self.num_head_channels = num_head_channels
1152
- self.num_heads_upsample = num_heads_upsample
1153
- self.predict_codebook_ids = n_embed is not None
1154
-
1155
- assert use_fairscale_checkpoint != use_checkpoint or not (
1156
- use_checkpoint or use_fairscale_checkpoint
1157
- )
1158
-
1159
- self.use_fairscale_checkpoint = False
1160
- checkpoint_wrapper_fn = (
1161
- partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1162
- if self.use_fairscale_checkpoint
1163
- else lambda x: x
1164
- )
1165
-
1166
- time_embed_dim = model_channels * 4
1167
- self.time_embed = checkpoint_wrapper_fn(
1168
- nn.Sequential(
1169
- linear(model_channels, time_embed_dim),
1170
- nn.SiLU(),
1171
- linear(time_embed_dim, time_embed_dim),
1172
- )
1173
- )
1174
-
1175
- if self.num_classes is not None:
1176
- if isinstance(self.num_classes, int):
1177
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1178
- elif self.num_classes == "continuous":
1179
- print("setting up linear c_adm embedding layer")
1180
- self.label_emb = nn.Linear(1, time_embed_dim)
1181
- elif self.num_classes == "timestep":
1182
- self.label_emb = checkpoint_wrapper_fn(
1183
- nn.Sequential(
1184
- Timestep(model_channels),
1185
- nn.Sequential(
1186
- linear(model_channels, time_embed_dim),
1187
- nn.SiLU(),
1188
- linear(time_embed_dim, time_embed_dim),
1189
- ),
1190
- )
1191
- )
1192
- elif self.num_classes == "sequential":
1193
- assert adm_in_channels is not None
1194
- self.label_emb = nn.Sequential(
1195
- nn.Sequential(
1196
- linear(adm_in_channels, time_embed_dim),
1197
- nn.SiLU(),
1198
- linear(time_embed_dim, time_embed_dim),
1199
- )
1200
- )
1201
- else:
1202
- raise ValueError()
1203
-
1204
- self.input_blocks = nn.ModuleList(
1205
- [
1206
- TimestepEmbedSequential(
1207
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
1208
- )
1209
- ]
1210
- )
1211
- self._feature_size = model_channels
1212
- input_block_chans = [model_channels]
1213
- ch = model_channels
1214
- ds = 1
1215
- for level, mult in enumerate(channel_mult):
1216
- for nr in range(self.num_res_blocks[level]):
1217
- layers = [
1218
- checkpoint_wrapper_fn(
1219
- ResBlock(
1220
- ch,
1221
- time_embed_dim,
1222
- dropout,
1223
- out_channels=mult * model_channels,
1224
- dims=dims,
1225
- use_checkpoint=use_checkpoint,
1226
- use_scale_shift_norm=use_scale_shift_norm,
1227
- )
1228
- )
1229
- ]
1230
- ch = mult * model_channels
1231
- if ds in attention_resolutions:
1232
- if num_head_channels == -1:
1233
- dim_head = ch // num_heads
1234
- else:
1235
- num_heads = ch // num_head_channels
1236
- dim_head = num_head_channels
1237
- if legacy:
1238
- # num_heads = 1
1239
- dim_head = (
1240
- ch // num_heads
1241
- if use_spatial_transformer
1242
- else num_head_channels
1243
- )
1244
- if exists(disable_self_attentions):
1245
- disabled_sa = disable_self_attentions[level]
1246
- else:
1247
- disabled_sa = False
1248
-
1249
- if (
1250
- not exists(num_attention_blocks)
1251
- or nr < num_attention_blocks[level]
1252
- ):
1253
- layers.append(
1254
- checkpoint_wrapper_fn(
1255
- AttentionBlock(
1256
- ch,
1257
- use_checkpoint=use_checkpoint,
1258
- num_heads=num_heads,
1259
- num_head_channels=dim_head,
1260
- use_new_attention_order=use_new_attention_order,
1261
- )
1262
- )
1263
- if not use_spatial_transformer
1264
- else checkpoint_wrapper_fn(
1265
- SpatialTransformer(
1266
- ch,
1267
- num_heads,
1268
- dim_head,
1269
- depth=transformer_depth[level],
1270
- context_dim=context_dim,
1271
- disable_self_attn=disabled_sa,
1272
- use_linear=use_linear_in_transformer,
1273
- attn_type=spatial_transformer_attn_type,
1274
- use_checkpoint=use_checkpoint,
1275
- )
1276
- )
1277
- )
1278
- self.input_blocks.append(TimestepEmbedSequential(*layers))
1279
- self._feature_size += ch
1280
- input_block_chans.append(ch)
1281
- if level != len(channel_mult) - 1:
1282
- out_ch = ch
1283
- self.input_blocks.append(
1284
- TimestepEmbedSequential(
1285
- checkpoint_wrapper_fn(
1286
- ResBlock(
1287
- ch,
1288
- time_embed_dim,
1289
- dropout,
1290
- out_channels=out_ch,
1291
- dims=dims,
1292
- use_checkpoint=use_checkpoint,
1293
- use_scale_shift_norm=use_scale_shift_norm,
1294
- down=True,
1295
- )
1296
- )
1297
- if resblock_updown
1298
- else Downsample(
1299
- ch, conv_resample, dims=dims, out_channels=out_ch
1300
- )
1301
- )
1302
- )
1303
- ch = out_ch
1304
- input_block_chans.append(ch)
1305
- ds *= 2
1306
- self._feature_size += ch
1307
-
1308
- if num_head_channels == -1:
1309
- dim_head = ch // num_heads
1310
- else:
1311
- num_heads = ch // num_head_channels
1312
- dim_head = num_head_channels
1313
- if legacy:
1314
- # num_heads = 1
1315
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1316
- self.middle_block = TimestepEmbedSequential(
1317
- checkpoint_wrapper_fn(
1318
- ResBlock(
1319
- ch,
1320
- time_embed_dim,
1321
- dropout,
1322
- dims=dims,
1323
- use_checkpoint=use_checkpoint,
1324
- use_scale_shift_norm=use_scale_shift_norm,
1325
- )
1326
- ),
1327
- checkpoint_wrapper_fn(
1328
- AttentionBlock(
1329
- ch,
1330
- use_checkpoint=use_checkpoint,
1331
- num_heads=num_heads,
1332
- num_head_channels=dim_head,
1333
- use_new_attention_order=use_new_attention_order,
1334
- )
1335
- )
1336
- if not use_spatial_transformer
1337
- else checkpoint_wrapper_fn(
1338
- SpatialTransformer( # always uses a self-attn
1339
- ch,
1340
- num_heads,
1341
- dim_head,
1342
- depth=transformer_depth_middle,
1343
- context_dim=context_dim,
1344
- disable_self_attn=disable_middle_self_attn,
1345
- use_linear=use_linear_in_transformer,
1346
- attn_type=spatial_transformer_attn_type,
1347
- use_checkpoint=use_checkpoint,
1348
- )
1349
- ),
1350
- checkpoint_wrapper_fn(
1351
- ResBlock(
1352
- ch,
1353
- time_embed_dim,
1354
- dropout,
1355
- dims=dims,
1356
- use_checkpoint=use_checkpoint,
1357
- use_scale_shift_norm=use_scale_shift_norm,
1358
- )
1359
- ),
1360
- )
1361
- self._feature_size += ch
1362
-
1363
- self.output_blocks = nn.ModuleList([])
1364
- for level, mult in list(enumerate(channel_mult))[::-1]:
1365
- for i in range(self.num_res_blocks[level] + 1):
1366
- ich = input_block_chans.pop()
1367
- layers = [
1368
- checkpoint_wrapper_fn(
1369
- ResBlock(
1370
- ch + ich,
1371
- time_embed_dim,
1372
- dropout,
1373
- out_channels=model_channels * mult,
1374
- dims=dims,
1375
- use_checkpoint=use_checkpoint,
1376
- use_scale_shift_norm=use_scale_shift_norm,
1377
- )
1378
- )
1379
- ]
1380
- ch = model_channels * mult
1381
- if ds in attention_resolutions:
1382
- if num_head_channels == -1:
1383
- dim_head = ch // num_heads
1384
- else:
1385
- num_heads = ch // num_head_channels
1386
- dim_head = num_head_channels
1387
- if legacy:
1388
- # num_heads = 1
1389
- dim_head = (
1390
- ch // num_heads
1391
- if use_spatial_transformer
1392
- else num_head_channels
1393
- )
1394
- if exists(disable_self_attentions):
1395
- disabled_sa = disable_self_attentions[level]
1396
- else:
1397
- disabled_sa = False
1398
-
1399
- if (
1400
- not exists(num_attention_blocks)
1401
- or i < num_attention_blocks[level]
1402
- ):
1403
- layers.append(
1404
- checkpoint_wrapper_fn(
1405
- AttentionBlock(
1406
- ch,
1407
- use_checkpoint=use_checkpoint,
1408
- num_heads=num_heads_upsample,
1409
- num_head_channels=dim_head,
1410
- use_new_attention_order=use_new_attention_order,
1411
- )
1412
- )
1413
- if not use_spatial_transformer
1414
- else checkpoint_wrapper_fn(
1415
- SpatialTransformer(
1416
- ch,
1417
- num_heads,
1418
- dim_head,
1419
- depth=transformer_depth[level],
1420
- context_dim=context_dim,
1421
- disable_self_attn=disabled_sa,
1422
- use_linear=use_linear_in_transformer,
1423
- attn_type=spatial_transformer_attn_type,
1424
- use_checkpoint=use_checkpoint,
1425
- )
1426
- )
1427
- )
1428
- if level and i == self.num_res_blocks[level]:
1429
- out_ch = ch
1430
- layers.append(
1431
- checkpoint_wrapper_fn(
1432
- ResBlock(
1433
- ch,
1434
- time_embed_dim,
1435
- dropout,
1436
- out_channels=out_ch,
1437
- dims=dims,
1438
- use_checkpoint=use_checkpoint,
1439
- use_scale_shift_norm=use_scale_shift_norm,
1440
- up=True,
1441
- )
1442
- )
1443
- if resblock_updown
1444
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1445
- )
1446
- ds //= 2
1447
- self.output_blocks.append(TimestepEmbedSequential(*layers))
1448
- self._feature_size += ch
1449
-
1450
- self.out = checkpoint_wrapper_fn(
1451
- nn.Sequential(
1452
- normalization(ch),
1453
- nn.SiLU(),
1454
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1455
- )
1456
- )
1457
- if self.predict_codebook_ids:
1458
- self.id_predictor = checkpoint_wrapper_fn(
1459
- nn.Sequential(
1460
- normalization(ch),
1461
- conv_nd(dims, model_channels, n_embed, 1),
1462
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1463
- )
1464
- )
1465
-
1466
- def convert_to_fp16(self):
1467
- """
1468
- Convert the torso of the model to float16.
1469
- """
1470
- self.input_blocks.apply(convert_module_to_f16)
1471
- self.middle_block.apply(convert_module_to_f16)
1472
- self.output_blocks.apply(convert_module_to_f16)
1473
-
1474
- def convert_to_fp32(self):
1475
- """
1476
- Convert the torso of the model to float32.
1477
- """
1478
- self.input_blocks.apply(convert_module_to_f32)
1479
- self.middle_block.apply(convert_module_to_f32)
1480
- self.output_blocks.apply(convert_module_to_f32)
1481
-
1482
- def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1483
- """
1484
- Apply the model to an input batch.
1485
- :param x: an [N x C x ...] Tensor of inputs.
1486
- :param timesteps: a 1-D batch of timesteps.
1487
- :param context: conditioning plugged in via crossattn
1488
- :param y: an [N] Tensor of labels, if class-conditional.
1489
- :return: an [N x C x ...] Tensor of outputs.
1490
- """
1491
- assert (y is not None) == (
1492
- self.num_classes is not None
1493
- ), "must specify y if and only if the model is class-conditional"
1494
- hs = []
1495
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
1496
- emb = self.time_embed(t_emb)
1497
-
1498
- if self.num_classes is not None:
1499
- assert y.shape[0] == x.shape[0]
1500
- emb = emb + self.label_emb(y)
1501
-
1502
- # h = x.type(self.dtype)
1503
- h = x
1504
- for i, module in enumerate(self.input_blocks):
1505
- h = module(h, emb, context)
1506
- hs.append(h)
1507
- h = self.middle_block(h, emb, context)
1508
- for i, module in enumerate(self.output_blocks):
1509
- h = th.cat([h, hs.pop()], dim=1)
1510
- h = module(h, emb, context)
1511
- h = h.type(x.dtype)
1512
- if self.predict_codebook_ids:
1513
- assert False, "not supported anymore. what the f*** are you doing?"
1514
- else:
1515
- return self.out(h)
1516
-
1517
-
1518
- import seaborn as sns
1519
- import matplotlib.pyplot as plt
1520
-
1521
- class UNetAddModel(nn.Module):
1522
-
1523
- def __init__(
1524
- self,
1525
- in_channels,
1526
- ctrl_channels,
1527
- model_channels,
1528
- out_channels,
1529
- num_res_blocks,
1530
- attention_resolutions,
1531
- dropout=0,
1532
- channel_mult=(1, 2, 4, 8),
1533
- attn_type="attn2",
1534
- attn_layers=[],
1535
- conv_resample=True,
1536
- dims=2,
1537
- num_classes=None,
1538
- use_checkpoint=False,
1539
- use_fp16=False,
1540
- num_heads=-1,
1541
- num_head_channels=-1,
1542
- num_heads_upsample=-1,
1543
- use_scale_shift_norm=False,
1544
- resblock_updown=False,
1545
- use_new_attention_order=False,
1546
- use_spatial_transformer=False, # custom transformer support
1547
- transformer_depth=1, # custom transformer support
1548
- context_dim=None, # custom transformer support
1549
- add_context_dim=None,
1550
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1551
- legacy=True,
1552
- disable_self_attentions=None,
1553
- num_attention_blocks=None,
1554
- disable_middle_self_attn=False,
1555
- use_linear_in_transformer=False,
1556
- spatial_transformer_attn_type="softmax",
1557
- adm_in_channels=None,
1558
- use_fairscale_checkpoint=False,
1559
- offload_to_cpu=False,
1560
- transformer_depth_middle=None,
1561
- ):
1562
- super().__init__()
1563
- from omegaconf.listconfig import ListConfig
1564
-
1565
- if use_spatial_transformer:
1566
- assert (
1567
- context_dim is not None
1568
- ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1569
-
1570
- if context_dim is not None:
1571
- assert (
1572
- use_spatial_transformer
1573
- ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1574
- if type(context_dim) == ListConfig:
1575
- context_dim = list(context_dim)
1576
-
1577
- if num_heads_upsample == -1:
1578
- num_heads_upsample = num_heads
1579
-
1580
- if num_heads == -1:
1581
- assert (
1582
- num_head_channels != -1
1583
- ), "Either num_heads or num_head_channels has to be set"
1584
-
1585
- if num_head_channels == -1:
1586
- assert (
1587
- num_heads != -1
1588
- ), "Either num_heads or num_head_channels has to be set"
1589
-
1590
- self.in_channels = in_channels
1591
- self.ctrl_channels = ctrl_channels
1592
- self.model_channels = model_channels
1593
- self.out_channels = out_channels
1594
- if isinstance(transformer_depth, int):
1595
- transformer_depth = len(channel_mult) * [transformer_depth]
1596
- elif isinstance(transformer_depth, ListConfig):
1597
- transformer_depth = list(transformer_depth)
1598
- transformer_depth_middle = default(
1599
- transformer_depth_middle, transformer_depth[-1]
1600
- )
1601
-
1602
- if isinstance(num_res_blocks, int):
1603
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1604
- else:
1605
- if len(num_res_blocks) != len(channel_mult):
1606
- raise ValueError(
1607
- "provide num_res_blocks either as an int (globally constant) or "
1608
- "as a list/tuple (per-level) with the same length as channel_mult"
1609
- )
1610
- self.num_res_blocks = num_res_blocks
1611
- # self.num_res_blocks = num_res_blocks
1612
- if disable_self_attentions is not None:
1613
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1614
- assert len(disable_self_attentions) == len(channel_mult)
1615
- if num_attention_blocks is not None:
1616
- assert len(num_attention_blocks) == len(self.num_res_blocks)
1617
- assert all(
1618
- map(
1619
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1620
- range(len(num_attention_blocks)),
1621
- )
1622
- )
1623
- print(
1624
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1625
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1626
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1627
- f"attention will still not be set."
1628
- ) # todo: convert to warning
1629
-
1630
- self.attention_resolutions = attention_resolutions
1631
- self.dropout = dropout
1632
- self.channel_mult = channel_mult
1633
- self.conv_resample = conv_resample
1634
- self.num_classes = num_classes
1635
- self.use_checkpoint = use_checkpoint
1636
- if use_fp16:
1637
- print("WARNING: use_fp16 was dropped and has no effect anymore.")
1638
- # self.dtype = th.float16 if use_fp16 else th.float32
1639
- self.num_heads = num_heads
1640
- self.num_head_channels = num_head_channels
1641
- self.num_heads_upsample = num_heads_upsample
1642
- self.predict_codebook_ids = n_embed is not None
1643
-
1644
- assert use_fairscale_checkpoint != use_checkpoint or not (
1645
- use_checkpoint or use_fairscale_checkpoint
1646
- )
1647
-
1648
- self.use_fairscale_checkpoint = False
1649
- checkpoint_wrapper_fn = (
1650
- partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1651
- if self.use_fairscale_checkpoint
1652
- else lambda x: x
1653
- )
1654
-
1655
- time_embed_dim = model_channels * 4
1656
- self.time_embed = checkpoint_wrapper_fn(
1657
- nn.Sequential(
1658
- linear(model_channels, time_embed_dim),
1659
- nn.SiLU(),
1660
- linear(time_embed_dim, time_embed_dim),
1661
- )
1662
- )
1663
-
1664
- if self.num_classes is not None:
1665
- if isinstance(self.num_classes, int):
1666
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1667
- elif self.num_classes == "continuous":
1668
- print("setting up linear c_adm embedding layer")
1669
- self.label_emb = nn.Linear(1, time_embed_dim)
1670
- elif self.num_classes == "timestep":
1671
- self.label_emb = checkpoint_wrapper_fn(
1672
- nn.Sequential(
1673
- Timestep(model_channels),
1674
- nn.Sequential(
1675
- linear(model_channels, time_embed_dim),
1676
- nn.SiLU(),
1677
- linear(time_embed_dim, time_embed_dim),
1678
- ),
1679
- )
1680
- )
1681
- elif self.num_classes == "sequential":
1682
- assert adm_in_channels is not None
1683
- self.label_emb = nn.Sequential(
1684
- nn.Sequential(
1685
- linear(adm_in_channels, time_embed_dim),
1686
- nn.SiLU(),
1687
- linear(time_embed_dim, time_embed_dim),
1688
- )
1689
- )
1690
- else:
1691
- raise ValueError()
1692
-
1693
- self.input_blocks = nn.ModuleList(
1694
- [
1695
- TimestepEmbedSequential(
1696
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
1697
- )
1698
- ]
1699
- )
1700
- if self.ctrl_channels > 0:
1701
- self.add_input_block = TimestepEmbedSequential(
1702
- conv_nd(dims, ctrl_channels, 16, 3, padding=1),
1703
- nn.SiLU(),
1704
- conv_nd(dims, 16, 16, 3, padding=1),
1705
- nn.SiLU(),
1706
- conv_nd(dims, 16, 32, 3, padding=1),
1707
- nn.SiLU(),
1708
- conv_nd(dims, 32, 32, 3, padding=1),
1709
- nn.SiLU(),
1710
- conv_nd(dims, 32, 96, 3, padding=1),
1711
- nn.SiLU(),
1712
- conv_nd(dims, 96, 96, 3, padding=1),
1713
- nn.SiLU(),
1714
- conv_nd(dims, 96, 256, 3, padding=1),
1715
- nn.SiLU(),
1716
- zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
1717
- )
1718
-
1719
- self._feature_size = model_channels
1720
- input_block_chans = [model_channels]
1721
- ch = model_channels
1722
- ds = 1
1723
- for level, mult in enumerate(channel_mult):
1724
- for nr in range(self.num_res_blocks[level]):
1725
- layers = [
1726
- checkpoint_wrapper_fn(
1727
- ResBlock(
1728
- ch,
1729
- time_embed_dim,
1730
- dropout,
1731
- out_channels=mult * model_channels,
1732
- dims=dims,
1733
- use_checkpoint=use_checkpoint,
1734
- use_scale_shift_norm=use_scale_shift_norm,
1735
- )
1736
- )
1737
- ]
1738
- ch = mult * model_channels
1739
- if ds in attention_resolutions:
1740
- if num_head_channels == -1:
1741
- dim_head = ch // num_heads
1742
- else:
1743
- num_heads = ch // num_head_channels
1744
- dim_head = num_head_channels
1745
- if legacy:
1746
- # num_heads = 1
1747
- dim_head = (
1748
- ch // num_heads
1749
- if use_spatial_transformer
1750
- else num_head_channels
1751
- )
1752
- if exists(disable_self_attentions):
1753
- disabled_sa = disable_self_attentions[level]
1754
- else:
1755
- disabled_sa = False
1756
-
1757
- if (
1758
- not exists(num_attention_blocks)
1759
- or nr < num_attention_blocks[level]
1760
- ):
1761
- layers.append(
1762
- checkpoint_wrapper_fn(
1763
- AttentionBlock(
1764
- ch,
1765
- use_checkpoint=use_checkpoint,
1766
- num_heads=num_heads,
1767
- num_head_channels=dim_head,
1768
- use_new_attention_order=use_new_attention_order,
1769
- )
1770
- )
1771
- if not use_spatial_transformer
1772
- else checkpoint_wrapper_fn(
1773
- SpatialTransformer(
1774
- ch,
1775
- num_heads,
1776
- dim_head,
1777
- depth=transformer_depth[level],
1778
- context_dim=context_dim,
1779
- add_context_dim=add_context_dim,
1780
- disable_self_attn=disabled_sa,
1781
- use_linear=use_linear_in_transformer,
1782
- attn_type=spatial_transformer_attn_type,
1783
- use_checkpoint=use_checkpoint,
1784
- )
1785
- )
1786
- )
1787
- self.input_blocks.append(TimestepEmbedSequential(*layers))
1788
- self._feature_size += ch
1789
- input_block_chans.append(ch)
1790
- if level != len(channel_mult) - 1:
1791
- out_ch = ch
1792
- self.input_blocks.append(
1793
- TimestepEmbedSequential(
1794
- checkpoint_wrapper_fn(
1795
- ResBlock(
1796
- ch,
1797
- time_embed_dim,
1798
- dropout,
1799
- out_channels=out_ch,
1800
- dims=dims,
1801
- use_checkpoint=use_checkpoint,
1802
- use_scale_shift_norm=use_scale_shift_norm,
1803
- down=True,
1804
- )
1805
- )
1806
- if resblock_updown
1807
- else Downsample(
1808
- ch, conv_resample, dims=dims, out_channels=out_ch
1809
- )
1810
- )
1811
- )
1812
- ch = out_ch
1813
- input_block_chans.append(ch)
1814
- ds *= 2
1815
- self._feature_size += ch
1816
-
1817
- if num_head_channels == -1:
1818
- dim_head = ch // num_heads
1819
- else:
1820
- num_heads = ch // num_head_channels
1821
- dim_head = num_head_channels
1822
- if legacy:
1823
- # num_heads = 1
1824
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1825
- self.middle_block = TimestepEmbedSequential(
1826
- checkpoint_wrapper_fn(
1827
- ResBlock(
1828
- ch,
1829
- time_embed_dim,
1830
- dropout,
1831
- dims=dims,
1832
- use_checkpoint=use_checkpoint,
1833
- use_scale_shift_norm=use_scale_shift_norm,
1834
- )
1835
- ),
1836
- checkpoint_wrapper_fn(
1837
- AttentionBlock(
1838
- ch,
1839
- use_checkpoint=use_checkpoint,
1840
- num_heads=num_heads,
1841
- num_head_channels=dim_head,
1842
- use_new_attention_order=use_new_attention_order,
1843
- )
1844
- )
1845
- if not use_spatial_transformer
1846
- else checkpoint_wrapper_fn(
1847
- SpatialTransformer( # always uses a self-attn
1848
- ch,
1849
- num_heads,
1850
- dim_head,
1851
- depth=transformer_depth_middle,
1852
- context_dim=context_dim,
1853
- add_context_dim=add_context_dim,
1854
- disable_self_attn=disable_middle_self_attn,
1855
- use_linear=use_linear_in_transformer,
1856
- attn_type=spatial_transformer_attn_type,
1857
- use_checkpoint=use_checkpoint,
1858
- )
1859
- ),
1860
- checkpoint_wrapper_fn(
1861
- ResBlock(
1862
- ch,
1863
- time_embed_dim,
1864
- dropout,
1865
- dims=dims,
1866
- use_checkpoint=use_checkpoint,
1867
- use_scale_shift_norm=use_scale_shift_norm,
1868
- )
1869
- ),
1870
- )
1871
- self._feature_size += ch
1872
-
1873
- self.output_blocks = nn.ModuleList([])
1874
- for level, mult in list(enumerate(channel_mult))[::-1]:
1875
- for i in range(self.num_res_blocks[level] + 1):
1876
- ich = input_block_chans.pop()
1877
- layers = [
1878
- checkpoint_wrapper_fn(
1879
- ResBlock(
1880
- ch + ich,
1881
- time_embed_dim,
1882
- dropout,
1883
- out_channels=model_channels * mult,
1884
- dims=dims,
1885
- use_checkpoint=use_checkpoint,
1886
- use_scale_shift_norm=use_scale_shift_norm,
1887
- )
1888
- )
1889
- ]
1890
- ch = model_channels * mult
1891
- if ds in attention_resolutions:
1892
- if num_head_channels == -1:
1893
- dim_head = ch // num_heads
1894
- else:
1895
- num_heads = ch // num_head_channels
1896
- dim_head = num_head_channels
1897
- if legacy:
1898
- # num_heads = 1
1899
- dim_head = (
1900
- ch // num_heads
1901
- if use_spatial_transformer
1902
- else num_head_channels
1903
- )
1904
- if exists(disable_self_attentions):
1905
- disabled_sa = disable_self_attentions[level]
1906
- else:
1907
- disabled_sa = False
1908
-
1909
- if (
1910
- not exists(num_attention_blocks)
1911
- or i < num_attention_blocks[level]
1912
- ):
1913
- layers.append(
1914
- checkpoint_wrapper_fn(
1915
- AttentionBlock(
1916
- ch,
1917
- use_checkpoint=use_checkpoint,
1918
- num_heads=num_heads_upsample,
1919
- num_head_channels=dim_head,
1920
- use_new_attention_order=use_new_attention_order,
1921
- )
1922
- )
1923
- if not use_spatial_transformer
1924
- else checkpoint_wrapper_fn(
1925
- SpatialTransformer(
1926
- ch,
1927
- num_heads,
1928
- dim_head,
1929
- depth=transformer_depth[level],
1930
- context_dim=context_dim,
1931
- add_context_dim=add_context_dim,
1932
- disable_self_attn=disabled_sa,
1933
- use_linear=use_linear_in_transformer,
1934
- attn_type=spatial_transformer_attn_type,
1935
- use_checkpoint=use_checkpoint,
1936
- )
1937
- )
1938
- )
1939
- if level and i == self.num_res_blocks[level]:
1940
- out_ch = ch
1941
- layers.append(
1942
- checkpoint_wrapper_fn(
1943
- ResBlock(
1944
- ch,
1945
- time_embed_dim,
1946
- dropout,
1947
- out_channels=out_ch,
1948
- dims=dims,
1949
- use_checkpoint=use_checkpoint,
1950
- use_scale_shift_norm=use_scale_shift_norm,
1951
- up=True,
1952
- )
1953
- )
1954
- if resblock_updown
1955
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1956
- )
1957
- ds //= 2
1958
- self.output_blocks.append(TimestepEmbedSequential(*layers))
1959
- self._feature_size += ch
1960
-
1961
- self.out = checkpoint_wrapper_fn(
1962
- nn.Sequential(
1963
- normalization(ch),
1964
- nn.SiLU(),
1965
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1966
- )
1967
- )
1968
- if self.predict_codebook_ids:
1969
- self.id_predictor = checkpoint_wrapper_fn(
1970
- nn.Sequential(
1971
- normalization(ch),
1972
- conv_nd(dims, model_channels, n_embed, 1),
1973
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1974
- )
1975
- )
1976
-
1977
- # cache attn map
1978
- self.attn_type = attn_type
1979
- self.attn_layers = attn_layers
1980
- self.attn_map_cache = []
1981
- for name, module in self.named_modules():
1982
- if name.endswith(self.attn_type):
1983
- item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
1984
- self.attn_map_cache.append(item)
1985
- module.attn_map_cache = item
1986
-
1987
- def clear_attn_map(self):
1988
-
1989
- for item in self.attn_map_cache:
1990
- if item["attn_map"] is not None:
1991
- del item["attn_map"]
1992
- item["attn_map"] = None
1993
-
1994
- def save_attn_map(self, save_name="temp", tokens=""):
1995
-
1996
- attn_maps = []
1997
- for item in self.attn_map_cache:
1998
- name = item["name"]
1999
- if any([name.startswith(block) for block in self.attn_layers]):
2000
- heads = item["heads"]
2001
- attn_maps.append(item["attn_map"].detach().cpu())
2002
-
2003
- attn_map = th.stack(attn_maps, dim=0)
2004
- attn_map = th.mean(attn_map, dim=0)
2005
-
2006
- # attn_map: bh * n * l
2007
- bh, n, l = attn_map.shape # bh: batch size * heads, n: number of pixels (h*w), l: token length
2008
- attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
2009
- b = attn_map.shape[0]
2010
-
2011
- h = w = int(n**0.5)
2012
- attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
2013
-
2014
- attn_map_i = attn_map[-1]
2015
-
2016
- l = attn_map_i.shape[0]
2017
- fig = plt.figure(figsize=(12, 8), dpi=300)
2018
- for j in range(12):
2019
- if j >= l: break
2020
- ax = fig.add_subplot(3, 4, j+1)
2021
- sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
2022
- if j < len(tokens):
2023
- ax.set_title(tokens[j])
2024
- fig.savefig(f"./temp/attn_map/attn_map_{save_name}.png")
2025
- plt.close()
2026
-
2027
- return attn_map_i
2028
-
2029
- def forward(self, x, timesteps=None, context=None, add_context=None, y=None, **kwargs):
2030
- """
2031
- Apply the model to an input batch.
2032
- :param x: an [N x C x ...] Tensor of inputs.
2033
- :param timesteps: a 1-D batch of timesteps.
2034
- :param context: conditioning plugged in via crossattn
2035
- :param y: an [N] Tensor of labels, if class-conditional.
2036
- :return: an [N x C x ...] Tensor of outputs.
2037
- """
2038
- assert (y is not None) == (
2039
- self.num_classes is not None
2040
- ), "must specify y if and only if the model is class-conditional"
2041
-
2042
- self.clear_attn_map()
2043
-
2044
- hs = []
2045
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
2046
- emb = self.time_embed(t_emb)
2047
-
2048
- if self.num_classes is not None:
2049
- assert y.shape[0] == x.shape[0]
2050
- emb = emb + self.label_emb(y)
2051
-
2052
- # h = x.type(self.dtype)
2053
  h = x
2054
  if self.ctrl_channels > 0:
2055
  in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
2056
-
2057
  for i, module in enumerate(self.input_blocks):
2058
  if self.ctrl_channels > 0 and i == 0:
2059
- h = module(in_h, emb, context, add_context) + self.add_input_block(add_h, emb, context, add_context)
2060
  else:
2061
- h = module(h, emb, context, add_context)
2062
  hs.append(h)
2063
- h = self.middle_block(h, emb, context, add_context)
2064
  for i, module in enumerate(self.output_blocks):
2065
  h = th.cat([h, hs.pop()], dim=1)
2066
- h = module(h, emb, context, add_context)
2067
  h = h.type(x.dtype)
2068
 
2069
  return self.out(h)
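
The save_attn_map bookkeeping removed above reduces each cached map from [batch*heads, n, l] to one l-channel spatial map per sample: average over the head dimension, then fold the pixel axis n back into (h, w). A sketch with illustrative sizes:

import torch as th

heads = 4
attn = th.rand(2 * heads, 256, 12)             # bh x n x l
bh, n, l = attn.shape
attn = attn.reshape(-1, heads, n, l).mean(1)   # average heads -> b x n x l
h = w = int(n ** 0.5)                          # n is assumed to be h*w with h == w
maps = attn.permute(0, 2, 1).reshape(-1, l, h, w)
print(maps.shape)  # torch.Size([2, 12, 16, 16])
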
 
 
1
  from abc import abstractmethod
 
2
  from typing import Iterable
3
 
4
  import numpy as np
 
10
  from ...modules.attention import SpatialTransformer
11
  from ...modules.diffusionmodules.util import (
12
  avg_pool_nd,
 
13
  conv_nd,
14
  linear,
15
  normalization,
 
19
  from ...util import default, exists
20
 
21
 
22
+ class Timestep(nn.Module):
23
+ def __init__(self, dim):
24
  super().__init__()
25
+ self.dim = dim
26
 
27
+ def forward(self, t):
28
+ return timestep_embedding(t, self.dim)
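
Timestep is a thin wrapper around timestep_embedding from diffusionmodules/util.py. As a hedged reference (the exact dtype and odd-dimension handling live in util.py), that helper computes standard sinusoidal features over log-spaced frequencies:

import math
import torch as th

def sinusoidal_embedding(t, dim, max_period=10000):
    # illustrative re-implementation, not the imported helper itself
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(half, dtype=th.float32) / half)
    args = t[:, None].float() * freqs[None]
    return th.cat([th.cos(args), th.sin(args)], dim=-1)

print(sinusoidal_embedding(th.tensor([0, 10, 999]), 128).shape)  # torch.Size([3, 128])
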
29
+
30
 
31
  class TimestepBlock(nn.Module):
32
  """
 
50
  self,
51
  x,
52
  emb,
53
+ t_context=None,
54
+ v_context=None
55
  ):
56
  for layer in self:
57
  if isinstance(layer, TimestepBlock):
58
  x = layer(x, emb)
59
  elif isinstance(layer, SpatialTransformer):
60
+ x = layer(x, t_context, v_context)
61
  else:
62
  x = layer(x)
63
  return x
 
102
  return x
103
 
104
105
  class Downsample(nn.Module):
106
  """
107
  A downsampling layer with an optional convolution.
 
149
  class ResBlock(TimestepBlock):
150
  """
151
  A residual block that can optionally change the number of channels.
152
  """
153
 
154
  def __init__(
 
160
  use_conv=False,
161
  use_scale_shift_norm=False,
162
  dims=2,
 
163
  up=False,
164
  down=False,
165
  kernel_size=3,
166
  exchange_temb_dims=False,
167
+ skip_t_emb=False
168
  ):
169
  super().__init__()
170
  self.channels = channels
 
172
  self.dropout = dropout
173
  self.out_channels = out_channels or channels
174
  self.use_conv = use_conv
 
175
  self.use_scale_shift_norm = use_scale_shift_norm
176
  self.exchange_temb_dims = exchange_temb_dims
177
 
 
240
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
 
242
  def forward(self, x, emb):
243
  if self.updown:
244
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
245
  h = in_rest(x)
 
267
  h = self.out_layers(h)
268
  return self.skip_connection(x) + h
269
 
270
+
271
+ import seaborn as sns
272
+ import matplotlib.pyplot as plt
273
 
274
 
275
+ class UnifiedUNetModel(nn.Module):
276
 
277
  def __init__(
278
  self,
279
  in_channels,
280
+ ctrl_channels,
281
  model_channels,
282
  out_channels,
283
  num_res_blocks,
284
  attention_resolutions,
285
  dropout=0,
286
  channel_mult=(1, 2, 4, 8),
287
+ save_attn_type=None,
288
+ save_attn_layers=[],
289
  conv_resample=True,
290
  dims=2,
291
+ use_label=None,
292
  num_heads=-1,
293
  num_head_channels=-1,
294
  num_heads_upsample=-1,
295
  use_scale_shift_norm=False,
296
  resblock_updown=False,
297
+ transformer_depth=1,
298
+ t_context_dim=None,
299
+ v_context_dim=None,
300
  num_attention_blocks=None,
 
301
  use_linear_in_transformer=False,
 
302
  adm_in_channels=None,
303
+ transformer_depth_middle=None
304
  ):
305
  super().__init__()
306
 
307
  if num_heads_upsample == -1:
308
  num_heads_upsample = num_heads
 
318
  ), "Either num_heads or num_head_channels has to be set"
319
 
320
  self.in_channels = in_channels
321
+ self.ctrl_channels = ctrl_channels
322
  self.model_channels = model_channels
323
  self.out_channels = out_channels
324
 
325
+ transformer_depth = len(channel_mult) * [transformer_depth]
326
+ transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1])
327
+
328
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
329
 
330
  self.attention_resolutions = attention_resolutions
331
  self.dropout = dropout
332
  self.channel_mult = channel_mult
333
  self.conv_resample = conv_resample
334
+ self.use_label = use_label
335
  self.num_heads = num_heads
336
  self.num_head_channels = num_head_channels
337
  self.num_heads_upsample = num_heads_upsample
338
 
339
  time_embed_dim = model_channels * 4
340
+ self.time_embed = nn.Sequential(
341
+ linear(model_channels, time_embed_dim),
342
+ nn.SiLU(),
343
+ linear(time_embed_dim, time_embed_dim),
344
  )
345
+
346
+ if self.use_label is not None:
347
+ self.label_emb = nn.Sequential(
348
+ nn.Sequential(
349
+ linear(adm_in_channels, time_embed_dim),
350
+ nn.SiLU(),
351
+ linear(time_embed_dim, time_embed_dim),
352
  )
353
+ )
 
354
 
355
  self.input_blocks = nn.ModuleList(
356
  [
 
359
  )
360
  ]
361
  )
362
+
363
+ if self.ctrl_channels > 0:
364
+ self.ctrl_block = TimestepEmbedSequential(
365
+ conv_nd(dims, ctrl_channels, 16, 3, padding=1),
366
+ nn.SiLU(),
367
+ conv_nd(dims, 16, 16, 3, padding=1),
368
+ nn.SiLU(),
369
+ conv_nd(dims, 16, 32, 3, padding=1),
370
+ nn.SiLU(),
371
+ conv_nd(dims, 32, 32, 3, padding=1),
372
+ nn.SiLU(),
373
+ conv_nd(dims, 32, 96, 3, padding=1),
374
+ nn.SiLU(),
375
+ conv_nd(dims, 96, 96, 3, padding=1),
376
+ nn.SiLU(),
377
+ conv_nd(dims, 96, 256, 3, padding=1),
378
+ nn.SiLU(),
379
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
380
+ )
381
+
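The control branch above is a ControlNet-style stack: plain convolutions stepping 16 -> 32 -> 96 -> 256 channels, finished by a zero-initialized projection to model_channels, so at initialization the branch contributes nothing and is learned during fine-tuning. A minimal sketch of how it merges with the first input block (shapes are illustrative, not the shipped config):

import torch as th
x = th.randn(2, 4 + 1, 64, 64)            # [latent | control] stacked on channel dim
in_h, add_h = th.split(x, [4, 1], dim=1)  # same split used in forward() below
# h = input_blocks[0](in_h, ...) + ctrl_block(add_h, ...)
# zero_module() makes the last conv output zeros, so the sum is a no-op at init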
382
  self._feature_size = model_channels
383
  input_block_chans = [model_channels]
384
  ch = model_channels
 
386
  for level, mult in enumerate(channel_mult):
387
  for nr in range(self.num_res_blocks[level]):
388
  layers = [
389
+ ResBlock(
390
+ ch,
391
+ time_embed_dim,
392
+ dropout,
393
+ out_channels=mult * model_channels,
394
+ dims=dims,
395
+ use_scale_shift_norm=use_scale_shift_norm
396
  )
397
  ]
398
  ch = mult * model_channels
 
402
  else:
403
  num_heads = ch // num_head_channels
404
  dim_head = num_head_channels
405
  if (
406
  not exists(num_attention_blocks)
407
  or nr < num_attention_blocks[level]
408
  ):
409
  layers.append(
410
+ SpatialTransformer(
411
+ ch,
412
+ num_heads,
413
+ dim_head,
414
+ depth=transformer_depth[level],
415
+ t_context_dim=t_context_dim,
416
+ v_context_dim=v_context_dim,
417
+ use_linear=use_linear_in_transformer
418
  )
419
  )
420
  self.input_blocks.append(TimestepEmbedSequential(*layers))
 
424
  out_ch = ch
425
  self.input_blocks.append(
426
  TimestepEmbedSequential(
427
+ ResBlock(
428
+ ch,
429
+ time_embed_dim,
430
+ dropout,
431
+ out_channels=out_ch,
432
+ dims=dims,
433
+ use_scale_shift_norm=use_scale_shift_norm,
434
+ down=True
435
  )
436
  if resblock_updown
437
  else Downsample(
 
449
  else:
450
  num_heads = ch // num_head_channels
451
  dim_head = num_head_channels
452
+
453
  self.middle_block = TimestepEmbedSequential(
454
+ ResBlock(
455
+ ch,
456
+ time_embed_dim,
457
+ dropout,
458
+ dims=dims,
459
+ use_scale_shift_norm=use_scale_shift_norm
460
  ),
461
+ SpatialTransformer( # always uses self-attention
462
+ ch,
463
+ num_heads,
464
+ dim_head,
465
+ depth=transformer_depth_middle,
466
+ t_context_dim=t_context_dim,
467
+ v_context_dim=v_context_dim,
468
+ use_linear=use_linear_in_transformer
 
469
  ),
470
+ ResBlock(
471
+ ch,
472
+ time_embed_dim,
473
+ dropout,
474
+ dims=dims,
475
+ use_scale_shift_norm=use_scale_shift_norm
476
+ )
477
  )
478
+
479
  self._feature_size += ch
480
 
481
  self.output_blocks = nn.ModuleList([])
 
483
  for i in range(self.num_res_blocks[level] + 1):
484
  ich = input_block_chans.pop()
485
  layers = [
486
+ ResBlock(
487
+ ch + ich,
488
+ time_embed_dim,
489
+ dropout,
490
+ out_channels=model_channels * mult,
491
+ dims=dims,
492
+ use_scale_shift_norm=use_scale_shift_norm
493
  )
494
  ]
495
  ch = model_channels * mult
 
499
  else:
500
  num_heads = ch // num_head_channels
501
  dim_head = num_head_channels
502
  if (
503
  not exists(num_attention_blocks)
504
  or i < num_attention_blocks[level]
505
  ):
506
  layers.append(
507
+ SpatialTransformer(
508
+ ch,
509
+ num_heads,
510
+ dim_head,
511
+ depth=transformer_depth[level],
512
+ t_context_dim=t_context_dim,
513
+ v_context_dim=v_context_dim,
514
+ use_linear=use_linear_in_transformer
515
  )
516
  )
517
  if level and i == self.num_res_blocks[level]:
518
  out_ch = ch
519
  layers.append(
520
+ ResBlock(
521
+ ch,
522
+ time_embed_dim,
523
+ dropout,
524
+ out_channels=out_ch,
525
+ dims=dims,
526
+ use_scale_shift_norm=use_scale_shift_norm,
527
+ up=True
528
  )
529
  if resblock_updown
530
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
 
533
  self.output_blocks.append(TimestepEmbedSequential(*layers))
534
  self._feature_size += ch
535
 
536
+ self.out = nn.Sequential(
537
+ normalization(ch),
538
+ nn.SiLU(),
539
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1))
540
  )
541
+
542
+ # cache attn map
543
+ self.attn_type = save_attn_type
544
+ self.attn_layers = save_attn_layers
545
+ self.attn_map_cache = []
546
+ for name, module in self.named_modules():
547
+ if self.attn_type and any(name.endswith(attn_type) for attn_type in self.attn_type):
548
+ item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
549
+ self.attn_map_cache.append(item)
550
+ module.attn_map_cache = item
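Each matching attention module is handed a reference to its cache slot, so the attention layer itself is expected to fill in "attn_map" during the forward pass. A hedged sketch of the writing side inside such a layer (illustrative, not the repo's attention code):

def cached_attention(sim, v, cache=None):
    # sim: (batch*heads, n, l) logits; v: (batch*heads, l, d) values
    attn = sim.softmax(dim=-1)
    if cache is not None:
        cache["size"] = attn.shape[1]   # pixel count n
        cache["attn_map"] = attn        # later read by save_attn_map()
    return attn @ v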
551
 
552
+ def clear_attn_map(self):
553
 
554
+ for item in self.attn_map_cache:
555
+ if item["attn_map"] is not None:
556
+ del item["attn_map"]
557
+ item["attn_map"] = None
558
+
559
+ def save_attn_map(self, attn_type="t_attn", save_name="temp", tokens=""):
560
+
561
+ attn_maps = []
562
+ for item in self.attn_map_cache:
563
+ name = item["name"]
564
+ if any(name.startswith(block) for block in self.attn_layers) and name.endswith(attn_type):
565
+ heads = item["heads"]
566
+ attn_maps.append(item["attn_map"].detach().cpu())
567
+
568
+ attn_map = th.stack(attn_maps, dim=0)
569
+ attn_map = th.mean(attn_map, dim=0)
570
+
571
+ # attn_map: bh * n * l
572
+ bh, n, l = attn_map.shape # bh: batch size * heads; n: pixel count (h*w); l: token length
573
+ attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
574
+ b = attn_map.shape[0]
575
+
576
+ h = w = int(n**0.5)
577
+ attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
578
+ attn_map_i = attn_map[-1]
579
+
580
+ l = attn_map_i.shape[0]
581
+ fig = plt.figure(figsize=(12, 8), dpi=300)
582
+ for j in range(12):
583
+ if j >= l: break
584
+ ax = fig.add_subplot(3, 4, j+1)
585
+ sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
586
+ if j < len(tokens):
587
+ ax.set_title(tokens[j])
588
+ fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
589
+ plt.close()
590
+
591
+ return attn_map_i
592
+
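To make the reshaping above concrete: with batch size 2, 8 heads, a 32x32 latent and 12 tokens, the stacked map enters as (16, 1024, 12) and leaves as one 32x32 heatmap per token. A runnable shape walkthrough (sizes are an example, not fixed by the model):

import torch as th
attn = th.rand(16, 1024, 12)                      # bh = 2*8, n = 32*32, l = 12
attn = attn.reshape(-1, 8, 1024, 12).mean(dim=1)  # average heads -> (2, 1024, 12)
maps = attn.permute(0, 2, 1).reshape(2, 12, 32, 32)  # (b, l, h, w) heatmaps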
593
+ def forward(self, x, timesteps=None, t_context=None, v_context=None, y=None, **kwargs):
594
 
595
  assert (y is not None) == (
596
+ self.use_label is not None
597
  ), "must specify y if and only if the model is class-conditional"
598
+
599
+ self.clear_attn_map()
600
+
601
  hs = []
602
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
603
  emb = self.time_embed(t_emb)
604
 
605
+ if self.use_label is not None:
606
  assert y.shape[0] == x.shape[0]
607
  emb = emb + self.label_emb(y)
608
 
 
609
  h = x
610
  if self.ctrl_channels > 0:
611
  in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
 
612
  for i, module in enumerate(self.input_blocks):
613
  if self.ctrl_channels > 0 and i == 0:
614
+ h = module(in_h, emb, t_context, v_context) + self.ctrl_block(add_h, emb, t_context, v_context)
615
  else:
616
+ h = module(h, emb, t_context, v_context)
617
  hs.append(h)
618
+ h = self.middle_block(h, emb, t_context, v_context)
619
  for i, module in enumerate(self.output_blocks):
620
  h = th.cat([h, hs.pop()], dim=1)
621
+ h = module(h, emb, t_context, v_context)
622
  h = h.type(x.dtype)
623
 
624
  return self.out(h)
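Putting the pieces together, the model now takes two context streams instead of one. A hedged instantiation sketch (module path and sizes are assumptions, not taken from the shipped configs):

import torch
from sgm.modules.diffusionmodules.openaimodel import UnifiedUNetModel  # path assumed

unet = UnifiedUNetModel(
    in_channels=4, ctrl_channels=0, model_channels=320, out_channels=4,
    num_res_blocks=2, attention_resolutions=[4, 2, 1],
    num_head_channels=64, t_context_dim=1024, v_context_dim=1024,
)
eps = unet(
    torch.randn(1, 4, 64, 64), timesteps=torch.tensor([10]),
    t_context=torch.randn(1, 77, 1024),  # textual tokens
    v_context=torch.randn(1, 16, 1024),  # visual/character tokens
)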
sgm/modules/diffusionmodules/sampling.py CHANGED
@@ -412,194 +412,12 @@ class EulerEDMSampler(EDMSampler):
412
  inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
413
  inters.append(inter.astype(np.uint8))
414
 
415
- print(f"Local losses: {local_losses}")
416
 
417
  if len(inters) > 0:
418
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
419
 
420
  return x
421
-
422
-
423
- class EulerEDMDualSampler(EulerEDMSampler):
424
-
425
- def prepare_sampling_loop(self, x, cond, uc_1=None, uc_2=None, num_steps=None):
426
- sigmas = self.discretization(
427
- self.num_steps if num_steps is None else num_steps, device=self.device
428
- )
429
- uc_1 = default(uc_1, cond)
430
- uc_2 = default(uc_2, cond)
431
-
432
- x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
433
- num_sigmas = len(sigmas)
434
-
435
- s_in = x.new_ones([x.shape[0]])
436
-
437
- return x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2
438
-
439
- def denoise(self, x, model, sigma, cond, uc_1, uc_2):
440
- denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc_1, uc_2))
441
- denoised = self.guider(denoised, sigma)
442
- return denoised
443
-
444
- def get_init_noise(self, cfgs, model, cond, batch, uc_1=None, uc_2=None):
445
-
446
- H, W = batch["target_size_as_tuple"][0]
447
- shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
448
-
449
- randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
450
- x = randn.clone()
451
-
452
- xs = []
453
- self.verbose = False
454
- for _ in range(cfgs.noise_iters):
455
-
456
- x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
457
- x, cond, uc_1, uc_2, num_steps=2
458
- )
459
-
460
- superv = {
461
- "mask": batch["mask"] if "mask" in batch else None,
462
- "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
463
- }
464
-
465
- local_losses = []
466
-
467
- for i in self.get_sigma_gen(num_sigmas):
468
-
469
- gamma = (
470
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
471
- if self.s_tmin <= sigmas[i] <= self.s_tmax
472
- else 0.0
473
- )
474
-
475
- x, inter, local_loss = self.sampler_step(
476
- s_in * sigmas[i],
477
- s_in * sigmas[i + 1],
478
- model,
479
- x,
480
- cond,
481
- superv,
482
- uc_1,
483
- uc_2,
484
- gamma,
485
- save_loss=True
486
- )
487
-
488
- local_losses.append(local_loss.item())
489
-
490
- xs.append((randn, local_losses[-1]))
491
-
492
- randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
493
- x = randn.clone()
494
-
495
- self.verbose = True
496
-
497
- xs.sort(key = lambda x: x[-1])
498
-
499
- if len(xs) > 0:
500
- print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
501
- x = xs[0][0]
502
-
503
- return x
504
-
505
- def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc_1=None, uc_2=None,
506
- gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
507
- name=None, save_loss=False, save_attn=False, save_inter=False):
508
-
509
- sigma_hat = sigma * (gamma + 1.0)
510
- if gamma > 0:
511
- eps = torch.randn_like(x) * self.s_noise
512
- x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
513
-
514
- if update:
515
- x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
516
-
517
- denoised = self.denoise(x, model, sigma_hat, cond, uc_1, uc_2)
518
- denoised_decode = model.decode_first_stage(denoised) if save_inter else None
519
-
520
- if save_loss:
521
- local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
522
- local_loss = local_loss[-local_loss.shape[0]//3:]
523
- else:
524
- local_loss = torch.zeros(1)
525
- if save_attn:
526
- attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True)
527
- self.save_segment_map(attn_map, tokens=batch["label"][0], save_name=name)
528
-
529
- d = to_d(x, sigma_hat, denoised)
530
- dt = append_dims(next_sigma - sigma_hat, x.ndim)
531
-
532
- euler_step = self.euler_step(x, d, dt)
533
-
534
- return euler_step, denoised_decode, local_loss
535
-
536
- def __call__(self, model, x, cond, batch=None, uc_1=None, uc_2=None, num_steps=None, init_step=0,
537
- name=None, aae_enabled=False, detailed=False):
538
-
539
- x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
540
- x, cond, uc_1, uc_2, num_steps
541
- )
542
-
543
- name = batch["name"][0]
544
- inters = []
545
- local_losses = []
546
- scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
547
- iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
548
- thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
549
-
550
- for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
551
-
552
- gamma = (
553
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
554
- if self.s_tmin <= sigmas[i] <= self.s_tmax
555
- else 0.0
556
- )
557
-
558
- alpha = 20 * np.sqrt(scales[i])
559
- update = aae_enabled
560
- save_loss = aae_enabled
561
- save_attn = detailed and (i == (num_sigmas-1)//2)
562
- save_inter = aae_enabled
563
-
564
- if i in iter_lst:
565
- iter_enabled = True
566
- thres = thres_lst[list(iter_lst).index(i)]
567
- else:
568
- iter_enabled = False
569
- thres = 0.0
570
-
571
- x, inter, local_loss = self.sampler_step(
572
- s_in * sigmas[i],
573
- s_in * sigmas[i + 1],
574
- model,
575
- x,
576
- cond,
577
- batch,
578
- uc_1,
579
- uc_2,
580
- gamma,
581
- alpha=alpha,
582
- iter_enabled=iter_enabled,
583
- thres=thres,
584
- update=update,
585
- name=name,
586
- save_loss=save_loss,
587
- save_attn=save_attn,
588
- save_inter=save_inter
589
- )
590
-
591
- local_losses.append(local_loss.item())
592
- if inter is not None:
593
- inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
594
- inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
595
- inters.append(inter.astype(np.uint8))
596
-
597
- print(f"Local losses: {local_losses}")
598
-
599
- if len(inters) > 0:
600
- imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1)
601
-
602
- return x
603
 
604
 
605
  class HeunEDMSampler(EDMSampler):
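The deleted dual sampler above also carried the best-of-N noise initialization: draw several candidate noises, probe each with a two-step sampling pass, score it with the attention-based local loss, and keep the lowest. A condensed sketch of that selection loop, decoupled from the dual-conditioning machinery (score_fn stands in for the two-step probe plus local loss):

import torch

def pick_init_noise(shape, n_iters, score_fn, device="cuda"):
    # Condensed from the removed get_init_noise: lowest local loss wins.
    candidates = []
    for _ in range(n_iters):
        noise = torch.randn(shape, device=device)
        candidates.append((noise, score_fn(noise)))
    candidates.sort(key=lambda pair: pair[1])
    return candidates[0][0]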
 
412
  inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
413
  inters.append(inter.astype(np.uint8))
414
 
415
+ # print(f"Local losses: {local_losses}")
416
 
417
  if len(inters) > 0:
418
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
419
 
420
  return x
421
 
422
 
423
  class HeunEDMSampler(EDMSampler):
sgm/modules/diffusionmodules/sampling_utils.py CHANGED
@@ -7,10 +7,7 @@ from ...util import append_dims
7
  class NoDynamicThresholding:
8
  def __call__(self, uncond, cond, scale):
9
  return uncond + scale * (cond - uncond)
10
-
11
- class DualThresholding: # Dual condition CFG (from instructPix2Pix)
12
- def __call__(self, uncond_1, uncond_2, cond, scale):
13
- return uncond_1 + scale[0] * (uncond_2 - uncond_1) + scale[1] * (cond - uncond_2)
14
 
15
  def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
16
  if order - 1 > i:
 
7
  class NoDynamicThresholding:
8
  def __call__(self, uncond, cond, scale):
9
  return uncond + scale * (cond - uncond)
10
+
11
 
12
  def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13
  if order - 1 > i:
sgm/modules/diffusionmodules/wrappers.py CHANGED
@@ -28,8 +28,8 @@ class OpenAIWrapper(IdentityWrapper):
28
  return self.diffusion_model(
29
  x,
30
  timesteps=t,
31
- context=c.get("crossattn", None),
32
- add_context=c.get("add_crossattn", None),
33
  y=c.get("vector", None),
34
  **kwargs
35
  )
 
28
  return self.diffusion_model(
29
  x,
30
  timesteps=t,
31
+ t_context=c.get("t_crossattn", None),
32
+ v_context=c.get("v_crossattn", None),
33
  y=c.get("vector", None),
34
  **kwargs
35
  )
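With this change the conditioner output is expected to carry two cross-attention entries plus the optional vector; a sketch of the dict the wrapper reads (key names from the code above, tensor shapes illustrative):

import torch
c = {
    "t_crossattn": torch.randn(1, 77, 1024),  # textual tokens -> t_context
    "v_crossattn": torch.randn(1, 16, 1024),  # visual tokens -> v_context
    "vector": None,                           # optional label vector -> y
}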
sgm/modules/encoders/modules.py CHANGED
@@ -14,6 +14,7 @@ from transformers import (
14
  ByT5Tokenizer,
15
  CLIPTextModel,
16
  CLIPTokenizer,
 
17
  T5EncoderModel,
18
  T5Tokenizer,
19
  )
@@ -38,18 +39,19 @@ import pytorch_lightning as pl
38
  from torchvision import transforms
39
  from timm.models.vision_transformer import VisionTransformer
40
  from safetensors.torch import load_file as load_safetensors
 
41
 
42
  # disable warning
43
  from transformers import logging
44
  logging.set_verbosity_error()
45
 
46
  class AbstractEmbModel(nn.Module):
47
- def __init__(self, is_add_embedder=False):
48
  super().__init__()
49
  self._is_trainable = None
50
  self._ucg_rate = None
51
  self._input_key = None
52
- self.is_add_embedder = is_add_embedder
53
 
54
  @property
55
  def is_trainable(self) -> bool:
@@ -63,6 +65,10 @@ class AbstractEmbModel(nn.Module):
63
  def input_key(self) -> str:
64
  return self._input_key
65
 
 
66
  @is_trainable.setter
67
  def is_trainable(self, value: bool):
68
  self._is_trainable = value
@@ -75,6 +81,10 @@ class AbstractEmbModel(nn.Module):
75
  def input_key(self, value: str):
76
  self._input_key = value
77
 
 
78
  @is_trainable.deleter
79
  def is_trainable(self):
80
  del self._is_trainable
@@ -87,8 +97,13 @@ class AbstractEmbModel(nn.Module):
87
  def input_key(self):
88
  del self._input_key
89
 
 
90
 
91
  class GeneralConditioner(nn.Module):
 
92
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
93
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
94
 
@@ -109,7 +124,8 @@ class GeneralConditioner(nn.Module):
109
  f"Initialized embedder #{n}: {embedder.__class__.__name__} "
110
  f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
111
  )
112
-
 
113
  if "input_key" in embconfig:
114
  embedder.input_key = embconfig["input_key"]
115
  elif "input_keys" in embconfig:
@@ -156,13 +172,10 @@ class GeneralConditioner(nn.Module):
156
  if not isinstance(emb_out, (list, tuple)):
157
  emb_out = [emb_out]
158
  for emb in emb_out:
159
- if embedder.is_add_embedder:
160
- out_key = "add_crossattn"
161
  else:
162
  out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
163
- if embedder.input_key == "mask":
164
- H, W = batch["image"].shape[-2:]
165
- emb = nn.functional.interpolate(emb, (H//8, W//8))
166
  if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
167
  emb = (
168
  expand_dims_like(
@@ -204,28 +217,6 @@ class GeneralConditioner(nn.Module):
204
  return c, uc
205
 
206
 
207
- class DualConditioner(GeneralConditioner):
208
-
209
- def get_unconditional_conditioning(
210
- self, batch_c, batch_uc_1=None, batch_uc_2=None, force_uc_zero_embeddings=None
211
- ):
212
- if force_uc_zero_embeddings is None:
213
- force_uc_zero_embeddings = []
214
- ucg_rates = list()
215
- for embedder in self.embedders:
216
- ucg_rates.append(embedder.ucg_rate)
217
- embedder.ucg_rate = 0.0
218
-
219
- c = self(batch_c)
220
- uc_1 = self(batch_uc_1, force_uc_zero_embeddings) if batch_uc_1 is not None else None
221
- uc_2 = self(batch_uc_2, force_uc_zero_embeddings[:1]) if batch_uc_2 is not None else None
222
-
223
- for embedder, rate in zip(self.embedders, ucg_rates):
224
- embedder.ucg_rate = rate
225
-
226
- return c, uc_1, uc_2
227
-
228
-
229
  class InceptionV3(nn.Module):
230
  """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
231
  port with an additional squeeze at the end"""
@@ -409,7 +400,6 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
409
 
410
  def freeze(self):
411
  self.transformer = self.transformer.eval()
412
-
413
  for param in self.parameters():
414
  param.requires_grad = False
415
 
@@ -694,24 +684,24 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
694
  if self.output_tokens:
695
  z, tokens = z[0], z[1]
696
  z = z.to(image.dtype)
697
- if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
698
- z = (
699
- torch.bernoulli(
700
- (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
701
- )[:, None]
702
- * z
703
- )
704
- if tokens is not None:
705
- tokens = (
706
- expand_dims_like(
707
- torch.bernoulli(
708
- (1.0 - self.ucg_rate)
709
- * torch.ones(tokens.shape[0], device=tokens.device)
710
- ),
711
- tokens,
712
- )
713
- * tokens
714
- )
715
  if self.unsqueeze_dim:
716
  z = z[:, None, :]
717
  if self.output_tokens:
@@ -807,7 +797,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
807
  return [clip_z, t5_z]
808
 
809
 
810
- class SpatialRescaler(nn.Module):
811
  def __init__(
812
  self,
813
  n_stages=1,
@@ -846,6 +836,9 @@ class SpatialRescaler(nn.Module):
846
  padding=kernel_size // 2,
847
  )
848
  self.wrap_video = wrap_video
 
 
850
  def forward(self, x):
851
  if self.wrap_video and x.ndim == 5:
 
14
  ByT5Tokenizer,
15
  CLIPTextModel,
16
  CLIPTokenizer,
17
+ CLIPVisionModel,
18
  T5EncoderModel,
19
  T5Tokenizer,
20
  )
 
39
  from torchvision import transforms
40
  from timm.models.vision_transformer import VisionTransformer
41
  from safetensors.torch import load_file as load_safetensors
42
+ from torchvision.utils import save_image
43
 
44
  # disable warning
45
  from transformers import logging
46
  logging.set_verbosity_error()
47
 
48
  class AbstractEmbModel(nn.Module):
49
+ def __init__(self):
50
  super().__init__()
51
  self._is_trainable = None
52
  self._ucg_rate = None
53
  self._input_key = None
54
+ self._emb_key = None
55
 
56
  @property
57
  def is_trainable(self) -> bool:
 
65
  def input_key(self) -> str:
66
  return self._input_key
67
 
68
+ @property
69
+ def emb_key(self) -> str:
70
+ return self._emb_key
71
+
72
  @is_trainable.setter
73
  def is_trainable(self, value: bool):
74
  self._is_trainable = value
 
81
  def input_key(self, value: str):
82
  self._input_key = value
83
 
84
+ @emb_key.setter
85
+ def emb_key(self, value: str):
86
+ self._emb_key = value
87
+
88
  @is_trainable.deleter
89
  def is_trainable(self):
90
  del self._is_trainable
 
97
  def input_key(self):
98
  del self._input_key
99
 
100
+ @emb_key.deleter
101
+ def emb_key(self):
102
+ del self._emb_key
103
+
104
 
105
  class GeneralConditioner(nn.Module):
106
+
107
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
108
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
109
 
 
124
  f"Initialized embedder #{n}: {embedder.__class__.__name__} "
125
  f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
126
  )
127
+ if "emb_key" in embconfig:
128
+ embedder.emb_key = embconfig["emb_key"]
129
  if "input_key" in embconfig:
130
  embedder.input_key = embconfig["input_key"]
131
  elif "input_keys" in embconfig:
 
172
  if not isinstance(emb_out, (list, tuple)):
173
  emb_out = [emb_out]
174
  for emb in emb_out:
175
+ if embedder.emb_key is not None:
176
+ out_key = embedder.emb_key
177
  else:
178
  out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
179
  if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
180
  emb = (
181
  expand_dims_like(
 
217
  return c, uc
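The new emb_key field lets a config route an embedder's output to a named conditioning slot instead of the dimension-based default. A hypothetical embedder entry (class path and keys other than input_key/emb_key are placeholders following the sgm instantiate_from_config convention):

embedder_config = {
    "target": "sgm.modules.encoders.modules.FrozenCLIPEmbedder",  # example class
    "input_key": "txt",
    "emb_key": "t_crossattn",  # lands in c["t_crossattn"] instead of "crossattn"
    "params": {},
}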
218
 
219
 
220
  class InceptionV3(nn.Module):
221
  """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
222
  port with an additional squeeze at the end"""
 
400
 
401
  def freeze(self):
402
  self.transformer = self.transformer.eval()
 
403
  for param in self.parameters():
404
  param.requires_grad = False
405
 
 
684
  if self.output_tokens:
685
  z, tokens = z[0], z[1]
686
  z = z.to(image.dtype)
687
+ # if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
688
+ # z = (
689
+ # torch.bernoulli(
690
+ # (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
691
+ # )[:, None]
692
+ # * z
693
+ # )
694
+ # if tokens is not None:
695
+ # tokens = (
696
+ # expand_dims_like(
697
+ # torch.bernoulli(
698
+ # (1.0 - self.ucg_rate)
699
+ # * torch.ones(tokens.shape[0], device=tokens.device)
700
+ # ),
701
+ # tokens,
702
+ # )
703
+ # * tokens
704
+ # )
705
  if self.unsqueeze_dim:
706
  z = z[:, None, :]
707
  if self.output_tokens:
 
797
  return [clip_z, t5_z]
798
 
799
 
800
+ class SpatialRescaler(AbstractEmbModel):
801
  def __init__(
802
  self,
803
  n_stages=1,
 
836
  padding=kernel_size // 2,
837
  )
838
  self.wrap_video = wrap_video
839
+
840
+ def freeze(self):
841
+ pass
842
 
843
  def forward(self, x):
844
  if self.wrap_video and x.ndim == 5:
temp/attn_map/attn_map_3.png ADDED
temp/attn_map/attn_map_4.png ADDED
temp/attn_map/attn_map_5.png ADDED
temp/seg_map/seg_3.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff197cf810e4ba2d26b76265d48530ff03c7b753e1ae6b0b7dfc8d010801df26
3
+ size 20608
temp/seg_map/seg_4.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc96f8f8a39aa63faa8ece0d8f758520a41d59b881926a9ddcacb6f5d46099dd
3
+ size 20608
temp/seg_map/seg_5.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16f008e62ab6b2b5b1ca1f58390808b8c9096edb6ddd85570f17232c441114f2
3
+ size 24704
util.py CHANGED
@@ -3,34 +3,6 @@ from omegaconf import OmegaConf
3
  from sgm.util import instantiate_from_config
4
  from sgm.modules.diffusionmodules.sampling import *
5
 
6
- SD_XL_BASE_RATIOS = {
7
- "0.5": (704, 1408),
8
- "0.52": (704, 1344),
9
- "0.57": (768, 1344),
10
- "0.6": (768, 1280),
11
- "0.68": (832, 1216),
12
- "0.72": (832, 1152),
13
- "0.78": (896, 1152),
14
- "0.82": (896, 1088),
15
- "0.88": (960, 1088),
16
- "0.94": (960, 1024),
17
- "1.0": (1024, 1024),
18
- "1.07": (1024, 960),
19
- "1.13": (1088, 960),
20
- "1.21": (1088, 896),
21
- "1.29": (1152, 896),
22
- "1.38": (1152, 832),
23
- "1.46": (1216, 832),
24
- "1.67": (1280, 768),
25
- "1.75": (1344, 768),
26
- "1.91": (1344, 704),
27
- "2.0": (1408, 704),
28
- "2.09": (1472, 704),
29
- "2.4": (1536, 640),
30
- "2.5": (1600, 640),
31
- "2.89": (1664, 576),
32
- "3.0": (1728, 576),
33
- }
34
 
35
  def init_model(cfgs):
36
 
@@ -43,8 +15,7 @@ def init_model(cfgs):
43
  if cfgs.type == "train":
44
  model.train()
45
  else:
46
- if cfgs.use_gpu:
47
- model.to(torch.device("cuda", index=cfgs.gpu))
48
  model.eval()
49
  model.freeze()
50
 
@@ -56,40 +27,22 @@ def init_sampling(cfgs):
56
  "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
57
  }
58
 
59
- if cfgs.dual_conditioner:
60
- guider_config = {
61
- "target": "sgm.modules.diffusionmodules.guiders.DualCFG",
62
- "params": {"scale": cfgs.scale},
63
- }
64
-
65
- sampler = EulerEDMDualSampler(
66
- num_steps=cfgs.steps,
67
- discretization_config=discretization_config,
68
- guider_config=guider_config,
69
- s_churn=0.0,
70
- s_tmin=0.0,
71
- s_tmax=999.0,
72
- s_noise=1.0,
73
- verbose=True,
74
- device=torch.device("cuda", index=cfgs.gpu)
75
- )
76
- else:
77
- guider_config = {
78
- "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
79
- "params": {"scale": cfgs.scale[0]},
80
- }
81
-
82
- sampler = EulerEDMSampler(
83
- num_steps=cfgs.steps,
84
- discretization_config=discretization_config,
85
- guider_config=guider_config,
86
- s_churn=0.0,
87
- s_tmin=0.0,
88
- s_tmax=999.0,
89
- s_noise=1.0,
90
- verbose=True,
91
- device=torch.device("cuda", index=cfgs.gpu)
92
- )
93
 
94
  return sampler
95
 
@@ -109,29 +62,17 @@ def deep_copy(batch):
109
  def prepare_batch(cfgs, batch):
110
 
111
  for key in batch:
112
- if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
113
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
114
 
115
- if not cfgs.dual_conditioner:
116
- batch_uc = deep_copy(batch)
117
 
118
- if "ntxt" in batch:
119
- batch_uc["txt"] = batch["ntxt"]
120
- else:
121
- batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
122
-
123
- if "label" in batch:
124
- batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
125
-
126
- return batch, batch_uc, None
127
-
128
  else:
129
- batch_uc_1 = deep_copy(batch)
130
- batch_uc_2 = deep_copy(batch)
131
-
132
- batch_uc_1["ref"] = torch.zeros_like(batch["ref"])
133
- batch_uc_2["ref"] = torch.zeros_like(batch["ref"])
134
 
135
- batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))]
 
136
 
137
- return batch, batch_uc_1, batch_uc_2
 
3
  from sgm.util import instantiate_from_config
4
  from sgm.modules.diffusionmodules.sampling import *
5
 
6
 
7
  def init_model(cfgs):
8
 
 
15
  if cfgs.type == "train":
16
  model.train()
17
  else:
18
+ model.to(torch.device("cuda", index=cfgs.gpu))
 
19
  model.eval()
20
  model.freeze()
21
 
 
27
  "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
28
  }
29
 
30
+ guider_config = {
31
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
32
+ "params": {"scale": cfgs.scale[0]},
33
+ }
34
+
35
+ sampler = EulerEDMSampler(
36
+ num_steps=cfgs.steps,
37
+ discretization_config=discretization_config,
38
+ guider_config=guider_config,
39
+ s_churn=0.0,
40
+ s_tmin=0.0,
41
+ s_tmax=999.0,
42
+ s_noise=1.0,
43
+ verbose=True,
44
+ device=torch.device("cuda", index=cfgs.gpu)
45
+ )
46
 
47
  return sampler
48
 
 
62
  def prepare_batch(cfgs, batch):
63
 
64
  for key in batch:
65
+ if isinstance(batch[key], torch.Tensor):
66
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
67
 
68
+ batch_uc = deep_copy(batch)
 
69
 
70
+ if "ntxt" in batch:
71
+ batch_uc["txt"] = batch["ntxt"]
72
  else:
73
+ batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
74
 
75
+ if "label" in batch:
76
+ batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
77
 
78
+ return batch, batch_uc
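A hedged usage sketch of the simplified helper: the unconditional batch mirrors the conditional one, with the caption replaced by "ntxt" when present (otherwise emptied) and any label blanked. Here cfgs is whatever OmegaConf config the caller loaded:

batch = {"txt": ["hello"], "label": ["hello"]}   # illustrative contents
batch, batch_uc = prepare_batch(cfgs, batch)     # tensors are moved to cuda:cfgs.gpu
assert batch_uc["txt"] == [""] and batch_uc["label"] == [""]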