ZYMPKU committed on
Commit
6497501
1 Parent(s): e79d136
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +2 -0
  2. README.md +1 -1
  3. app.py +207 -0
  4. checkpoints/AEs/AE_inpainting_2.safetensors +3 -0
  5. checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt +3 -0
  6. checkpoints/st-step=100000+la-step=100000-simp.ckpt +3 -0
  7. configs/demo.yaml +29 -0
  8. configs/test/textdesign_sd_2.yaml +137 -0
  9. demo/examples/CEFUL_1_0.jpeg +3 -0
  10. demo/examples/CLOTHES_0_0.png +3 -0
  11. demo/examples/COMPLICATED_0_1.jpeg +3 -0
  12. demo/examples/DELIGHT_0_1.jpeg +3 -0
  13. demo/examples/ECHOES_0_0.jpeg +3 -0
  14. demo/examples/ENGINE_0_0.png +3 -0
  15. demo/examples/FASCINATING_0_1.jpeg +3 -0
  16. demo/examples/FAVOURITE_0_0.jpeg +3 -0
  17. demo/examples/FINNAL_0_1.jpeg +3 -0
  18. demo/examples/FRONTIER_0_0.png +3 -0
  19. demo/examples/Innovate_0_0.jpeg +3 -0
  20. demo/examples/PRESERVE_0_0.jpeg +3 -0
  21. demo/examples/Peaceful_0_0.jpeg +3 -0
  22. demo/examples/Scamps_0_0.png +3 -0
  23. demo/examples/TREE_0_0.png +3 -0
  24. demo/examples/better_0_0.jpg +3 -0
  25. demo/examples/tested_0_0.png +3 -0
  26. demo/teaser.png +3 -0
  27. requirements.txt +24 -0
  28. sgm/__init__.py +2 -0
  29. sgm/lr_scheduler.py +135 -0
  30. sgm/models/__init__.py +2 -0
  31. sgm/models/autoencoder.py +335 -0
  32. sgm/models/diffusion.py +328 -0
  33. sgm/modules/__init__.py +6 -0
  34. sgm/modules/attention.py +976 -0
  35. sgm/modules/autoencoding/__init__.py +0 -0
  36. sgm/modules/autoencoding/losses/__init__.py +246 -0
  37. sgm/modules/autoencoding/regularizers/__init__.py +53 -0
  38. sgm/modules/diffusionmodules/__init__.py +7 -0
  39. sgm/modules/diffusionmodules/denoiser.py +63 -0
  40. sgm/modules/diffusionmodules/denoiser_scaling.py +31 -0
  41. sgm/modules/diffusionmodules/denoiser_weighting.py +24 -0
  42. sgm/modules/diffusionmodules/discretizer.py +68 -0
  43. sgm/modules/diffusionmodules/guiders.py +81 -0
  44. sgm/modules/diffusionmodules/loss.py +275 -0
  45. sgm/modules/diffusionmodules/model.py +743 -0
  46. sgm/modules/diffusionmodules/openaimodel.py +2070 -0
  47. sgm/modules/diffusionmodules/sampling.py +784 -0
  48. sgm/modules/diffusionmodules/sampling_utils.py +51 -0
  49. sgm/modules/diffusionmodules/sigma_sampling.py +31 -0
  50. sgm/modules/diffusionmodules/util.py +308 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ **/__pycache__
+ process.ipynb
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: UDiffText
- emoji: 🐢
+ emoji: 😋
  colorFrom: purple
  colorTo: blue
  sdk: gradio
app.py ADDED
@@ -0,0 +1,207 @@
1
+ import cv2
2
+ import torch
3
+ import os, glob
4
+ import numpy as np
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from omegaconf import OmegaConf
8
+ from contextlib import nullcontext
9
+ from pytorch_lightning import seed_everything
10
+ from os.path import join as ospj
11
+
12
+ from util import *
13
+
14
+
15
+ def predict(cfgs, model, sampler, batch):
16
+
17
+ context = nullcontext if cfgs.aae_enabled else torch.no_grad
18
+
19
+ with context():
20
+
21
+ batch, batch_uc_1, batch_uc_2 = prepare_batch(cfgs, batch)
22
+
23
+ if cfgs.dual_conditioner:
24
+ c, uc_1, uc_2 = model.conditioner.get_unconditional_conditioning(
25
+ batch,
26
+ batch_uc_1=batch_uc_1,
27
+ batch_uc_2=batch_uc_2,
28
+ force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
29
+ )
30
+ else:
31
+ c, uc_1 = model.conditioner.get_unconditional_conditioning(
32
+ batch,
33
+ batch_uc=batch_uc_1,
34
+ force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
35
+ )
36
+
37
+ if cfgs.dual_conditioner:
38
+ x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2)
39
+ samples_z = sampler(model, x, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2, init_step=0,
40
+ aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
41
+ else:
42
+ x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1)
43
+ samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0,
44
+ aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
45
+
46
+ samples_x = model.decode_first_stage(samples_z)
47
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
48
+
49
+ return samples, samples_z
50
+
51
+
52
+ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
53
+
54
+ global cfgs, global_index
55
+
56
+ global_index += 1
57
+
58
+ if num_samples > 1: cfgs.noise_iters = 0
59
+
60
+ cfgs.batch_size = num_samples
61
+ cfgs.steps = steps
62
+ cfgs.scale[0] = scale
63
+ cfgs.detailed = show_detail
64
+ seed_everything(seed)
65
+
66
+ sampler = init_sampling(cfgs)
67
+
68
+ image = input_blk["image"]
69
+ mask = input_blk["mask"]
70
+ image = cv2.resize(image, (cfgs.W, cfgs.H))
71
+ mask = cv2.resize(mask, (cfgs.W, cfgs.H))
72
+
73
+ mask = (mask == 0).astype(np.int32)
74
+
75
+ image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0
76
+ mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32).mean(dim=0, keepdim=True)
77
+ masked = image * mask
78
+ mask = 1 - mask
79
+
80
+ seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text))))
81
+
82
+ # additional cond
83
+ txt = f"\"{text}\""
84
+ original_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
85
+ crop_coords_top_left = torch.tensor((0, 0))
86
+ target_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
87
+
88
+ image = torch.tile(image[None], (num_samples, 1, 1, 1))
89
+ mask = torch.tile(mask[None], (num_samples, 1, 1, 1))
90
+ masked = torch.tile(masked[None], (num_samples, 1, 1, 1))
91
+ seg_mask = torch.tile(seg_mask[None], (num_samples, 1))
92
+ original_size_as_tuple = torch.tile(original_size_as_tuple[None], (num_samples, 1))
93
+ crop_coords_top_left = torch.tile(crop_coords_top_left[None], (num_samples, 1))
94
+ target_size_as_tuple = torch.tile(target_size_as_tuple[None], (num_samples, 1))
95
+
96
+ text = [text for i in range(num_samples)]
97
+ txt = [txt for i in range(num_samples)]
98
+ name = [str(global_index) for i in range(num_samples)]
99
+
100
+ batch = {
101
+ "image": image,
102
+ "mask": mask,
103
+ "masked": masked,
104
+ "seg_mask": seg_mask,
105
+ "label": text,
106
+ "txt": txt,
107
+ "original_size_as_tuple": original_size_as_tuple,
108
+ "crop_coords_top_left": crop_coords_top_left,
109
+ "target_size_as_tuple": target_size_as_tuple,
110
+ "name": name
111
+ }
112
+
113
+ samples, samples_z = predict(cfgs, model, sampler, batch)
114
+ samples = samples.cpu().numpy().transpose(0, 2, 3, 1) * 255
115
+ results = [Image.fromarray(sample.astype(np.uint8)) for sample in samples]
116
+
117
+ if cfgs.detailed:
118
+ sections = []
119
+ attn_map = Image.open(f"./temp/attn_map/attn_map_{global_index}.png")
120
+ seg_maps = np.load(f"./temp/seg_map/seg_{global_index}.npy")
121
+ for i, seg_map in enumerate(seg_maps):
122
+ seg_map = cv2.resize(seg_map, (cfgs.W, cfgs.H))
123
+ sections.append((seg_map, text[0][i]))
124
+ seg = (results[0], sections)
125
+ else:
126
+ attn_map = None
127
+ seg = None
128
+
129
+ return results, attn_map, seg
130
+
131
+
132
+ if __name__ == "__main__":
133
+
134
+ cfgs = OmegaConf.load("./configs/demo.yaml")
135
+
136
+ model = init_model(cfgs)
137
+ global_index = 0
138
+
139
+ block = gr.Blocks().queue()
140
+ with block:
141
+
142
+ with gr.Row():
143
+
144
+ gr.HTML(
145
+ """
146
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
147
+ <h1 style="font-weight: 600; font-size: 2rem; margin: 0rem">
148
+ UDiffText: A Unified Framework for High-quality Text Synthesis in Arbitrary Images via Character-aware Diffusion Models
149
+ </h1>
150
+ <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
151
+ [<a href="" style="color:blue;">arXiv</a>]
152
+ [<a href="" style="color:blue;">Code</a>]
153
+ [<a href="" style="color:blue;">ProjectPage</a>]
154
+ </h3>
155
+ <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
156
+ Our proposed UDiffText is capable of synthesizing accurate and harmonious text in either synthetic or real-word images, thus can be applied to tasks like scene text editing (a), arbitrary text generation (b) and accurate T2I generation (c)
157
+ </h2>
158
+ <div align=center><img src="file/demo/teaser.png" alt="UDiffText" width="80%"></div>
159
+ </div>
160
+ """
161
+ )
162
+
163
+ with gr.Row():
164
+
165
+ with gr.Column():
166
+
167
+ input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
168
+ text = gr.Textbox(label="Text to render:", info="the text you want to render at the masked region")
169
+ run_button = gr.Button(variant="primary")
170
+
171
+ with gr.Accordion("Advanced options", open=False):
172
+
173
+ num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1)
174
+ steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
175
+ scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=4.0, step=0.1)
176
+ seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
177
+ show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=True)
178
+
179
+ with gr.Column():
180
+
181
+ gallery = gr.Gallery(label="Output", height=512, preview=True)
182
+
183
+ with gr.Accordion("Visualization results", open=True):
184
+
185
+ with gr.Tab(label="Attention Maps"):
186
+ gr.Markdown("### Attention maps for each character (extracted from middle blocks at intermediate sampling step):")
187
+ attn_map = gr.Image(show_label=False, show_download_button=False)
188
+ with gr.Tab(label="Segmentation Maps"):
189
+ gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):")
190
+ seg_map = gr.AnnotatedImage(height=384, show_label=False, show_download_button=False)
191
+
192
+ # examples
193
+ examples = []
194
+ example_paths = sorted(glob.glob(ospj("./demo/examples", "*")))
195
+ for example_path in example_paths:
196
+ label = example_path.split(os.sep)[-1].split(".")[0].split("_")[0]
197
+ examples.append([example_path, label])
198
+
199
+ gr.Markdown("## Examples:")
200
+ gr.Examples(
201
+ examples=examples,
202
+ inputs=[input_blk, text]
203
+ )
204
+
205
+ run_button.click(fn=demo_predict, inputs=[input_blk, text, num_samples, steps, scale, seed, show_detail], outputs=[gallery, attn_map, seg_map])
206
+
207
+ block.launch()
checkpoints/AEs/AE_inpainting_2.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:547baac83984f8bf8b433882236b87e77eb4d2f5c71e3d7a04b8dec2fe02b81f
+ size 334640988
checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4076c90467a907dcb8cde15776bfda4473010fe845739490341db74e82cd2267
+ size 4059026213
checkpoints/st-step=100000+la-step=100000-simp.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:968397df8910f3324d94ce3df7e9d70f1bf2415a46d22edef1a510885ee0648e
+ size 2558065830
configs/demo.yaml ADDED
@@ -0,0 +1,29 @@
+ type: "demo"
+
+ # path
+ load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-simp.ckpt"
+ model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
+
+ # param
+ H: 512
+ W: 512
+ seq_len: 12
+ batch_size: 1
+
+ channel: 4 # AE latent channel
+ factor: 8 # AE downsample factor
+ scale: [4.0, 0.0] # content scale, style scale
+ noise_iters: 10
+ force_uc_zero_embeddings: ["ref", "label"]
+ aae_enabled: False
+ detailed: True
+ dual_conditioner: False
+
+
+ # runtime
+ steps: 50
+ init_step: 0
+ num_workers: 0
+ gpu: 0
+ max_iter: 100
+
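For reference, a minimal sketch of how this config is consumed: app.py in this commit loads it with OmegaConf and reads fields such as H, W, steps and scale. The override at the end is purely illustrative.

```python
from omegaconf import OmegaConf

# Load the demo configuration added above and read a few of its fields.
cfgs = OmegaConf.load("./configs/demo.yaml")
print(cfgs.H, cfgs.W)    # 512 512
print(cfgs.scale[0])     # content scale, 4.0

# Fields can be overridden at runtime before sampling (illustrative only).
cfgs.steps = 20
cfgs.detailed = False
```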
configs/test/textdesign_sd_2.yaml ADDED
@@ -0,0 +1,137 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ input_key: image
5
+ scale_factor: 0.18215
6
+ disable_first_stage_autocast: True
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
10
+ params:
11
+ num_idx: 1000
12
+
13
+ weighting_config:
14
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
15
+ scaling_config:
16
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
17
+ discretization_config:
18
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
19
+
20
+ network_config:
21
+ target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
22
+ params:
23
+ use_checkpoint: False
24
+ in_channels: 9
25
+ out_channels: 4
26
+ ctrl_channels: 0
27
+ model_channels: 320
28
+ attention_resolutions: [4, 2, 1]
29
+ attn_type: add_attn
30
+ attn_layers:
31
+ - output_blocks.6.1
32
+ num_res_blocks: 2
33
+ channel_mult: [1, 2, 4, 4]
34
+ num_head_channels: 64
35
+ use_spatial_transformer: True
36
+ use_linear_in_transformer: True
37
+ transformer_depth: 1
38
+ context_dim: 0
39
+ add_context_dim: 2048
40
+ legacy: False
41
+
42
+ conditioner_config:
43
+ target: sgm.modules.GeneralConditioner
44
+ params:
45
+ emb_models:
46
+ # crossattn cond
47
+ # - is_trainable: False
48
+ # input_key: txt
49
+ # target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
50
+ # params:
51
+ # arch: ViT-H-14
52
+ # version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
53
+ # layer: penultimate
54
+ # add crossattn cond
55
+ - is_trainable: False
56
+ input_key: label
57
+ target: sgm.modules.encoders.modules.LabelEncoder
58
+ params:
59
+ is_add_embedder: True
60
+ max_len: 12
61
+ emb_dim: 2048
62
+ n_heads: 8
63
+ n_trans_layers: 12
64
+ ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
65
+ # concat cond
66
+ - is_trainable: False
67
+ input_key: mask
68
+ target: sgm.modules.encoders.modules.IdentityEncoder
69
+ - is_trainable: False
70
+ input_key: masked
71
+ target: sgm.modules.encoders.modules.LatentEncoder
72
+ params:
73
+ scale_factor: 0.18215
74
+ config:
75
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
76
+ params:
77
+ ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
78
+ embed_dim: 4
79
+ monitor: val/rec_loss
80
+ ddconfig:
81
+ attn_type: vanilla-xformers
82
+ double_z: true
83
+ z_channels: 4
84
+ resolution: 256
85
+ in_channels: 3
86
+ out_ch: 3
87
+ ch: 128
88
+ ch_mult: [1, 2, 4, 4]
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
94
+
95
+ first_stage_config:
96
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
97
+ params:
98
+ embed_dim: 4
99
+ monitor: val/rec_loss
100
+ ddconfig:
101
+ attn_type: vanilla-xformers
102
+ double_z: true
103
+ z_channels: 4
104
+ resolution: 256
105
+ in_channels: 3
106
+ out_ch: 3
107
+ ch: 128
108
+ ch_mult: [1, 2, 4, 4]
109
+ num_res_blocks: 2
110
+ attn_resolutions: []
111
+ dropout: 0.0
112
+ lossconfig:
113
+ target: torch.nn.Identity
114
+
115
+ loss_fn_config:
116
+ target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss
117
+ params:
118
+ seq_len: 12
119
+ kernel_size: 3
120
+ gaussian_sigma: 0.5
121
+ min_attn_size: 16
122
+ lambda_local_loss: 0.02
123
+ lambda_ocr_loss: 0.001
124
+ ocr_enabled: False
125
+
126
+ predictor_config:
127
+ target: sgm.modules.predictors.model.ParseqPredictor
128
+ params:
129
+ ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
130
+
131
+ sigma_sampler_config:
132
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
133
+ params:
134
+ num_idx: 1000
135
+
136
+ discretization_config:
137
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
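As a rough, hypothetical sketch (not part of this commit): configs in this target/params style are typically turned into a model via instantiate_from_config, which sgm/__init__.py re-exports further down in this diff; the checkpoint path below is the one referenced by configs/demo.yaml.

```python
import torch
from omegaconf import OmegaConf

from sgm.util import instantiate_from_config

# Build the DiffusionEngine described by the config above (sketch only).
config = OmegaConf.load("./configs/test/textdesign_sd_2.yaml")
model = instantiate_from_config(config.model)

# Restore the weights that configs/demo.yaml points at; DiffusionEngine's
# own init_from_ckpt() does essentially the same thing.
sd = torch.load("./checkpoints/st-step=100000+la-step=100000-simp.ckpt", map_location="cpu")["state_dict"]
missing, unexpected = model.load_state_dict(sd, strict=False)
model.eval()
```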
demo/examples/CEFUL_1_0.jpeg ADDED

Git LFS Details

  • SHA256: d90a580083194c2130da6fd0176df3fde40b312f13f00b34b7ac6641e4ff1597
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
demo/examples/CLOTHES_0_0.png ADDED

Git LFS Details

  • SHA256: 7a7374b07e520fe86c4b0b587082f125dc826542caf5d9d1c08107bb1cfe0154
  • Pointer size: 131 Bytes
  • Size of remote file: 331 kB
demo/examples/COMPLICATED_0_1.jpeg ADDED

Git LFS Details

  • SHA256: 98ba496f8289dda423bf5d9d60493e599df61eb5d6de75f8b966786909c3a5ab
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
demo/examples/DELIGHT_0_1.jpeg ADDED

Git LFS Details

  • SHA256: 899b7388f742c28317c85944f46a181c70c8d89ce22229838fa9d2afbdae495a
  • Pointer size: 131 Bytes
  • Size of remote file: 343 kB
demo/examples/ECHOES_0_0.jpeg ADDED

Git LFS Details

  • SHA256: 0ade75487cac60d88684e41b8ddfcd492f386f22eeb8bb67f1d99a2514803477
  • Pointer size: 131 Bytes
  • Size of remote file: 286 kB
demo/examples/ENGINE_0_0.png ADDED

Git LFS Details

  • SHA256: fd1fd33cded3a9c8245a38cd82e0603e2f583dbe7b415dd13ed20cdec08e94b0
  • Pointer size: 131 Bytes
  • Size of remote file: 578 kB
demo/examples/FASCINATING_0_1.jpeg ADDED

Git LFS Details

  • SHA256: af5ea76ba8c5827f9ec83bc2b6a096511bc619975db5fc8f6e742ba1bb687570
  • Pointer size: 131 Bytes
  • Size of remote file: 311 kB
demo/examples/FAVOURITE_0_0.jpeg ADDED

Git LFS Details

  • SHA256: 38747d02015147fa4f1eb3ebca5a3757d908957cf8caf7de8b33d5a1750d6ada
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
demo/examples/FINNAL_0_1.jpeg ADDED

Git LFS Details

  • SHA256: e6dc56ca1ba9a1fc5e6899a9629a06d843ca10a96f5eca095c0cf1af9e38191a
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
demo/examples/FRONTIER_0_0.png ADDED

Git LFS Details

  • SHA256: c0231c43a100dcf5f95a3f79c9fcbe77b345e70f87273ec395f7ff857716483c
  • Pointer size: 131 Bytes
  • Size of remote file: 437 kB
demo/examples/Innovate_0_0.jpeg ADDED

Git LFS Details

  • SHA256: a74d91e607bceafe0ea45858dec08949eb93597cbed45a7bf194cf476d118b03
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
demo/examples/PRESERVE_0_0.jpeg ADDED

Git LFS Details

  • SHA256: e5608fe1a2ccd04f18ba8a24172bd484bf2a628e21756fb5121283cad9618f60
  • Pointer size: 131 Bytes
  • Size of remote file: 296 kB
demo/examples/Peaceful_0_0.jpeg ADDED

Git LFS Details

  • SHA256: 40e3adca8425b26c41f64ff62a29f299d172103989fbbe82c77f41875af9c86d
  • Pointer size: 130 Bytes
  • Size of remote file: 93.3 kB
demo/examples/Scamps_0_0.png ADDED

Git LFS Details

  • SHA256: 3fa97107ac42733873b451efa06b7ba2fefbfb182f905084e0fd2f511ec8a251
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
demo/examples/TREE_0_0.png ADDED

Git LFS Details

  • SHA256: 76e3f78050bd19efd0247befb40a5fc56a6d3067324606a9800cfdd91a6c142d
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
demo/examples/better_0_0.jpg ADDED

Git LFS Details

  • SHA256: 6473b82056e41fa74594e89c07c92640c375bdf568e3e9a5f296c9ec8c749145
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB
demo/examples/tested_0_0.png ADDED

Git LFS Details

  • SHA256: 2a3e38e5f1c63b1db4ce6d961aebf8f793ba19554f390b263340269d21b0d84a
  • Pointer size: 131 Bytes
  • Size of remote file: 305 kB
demo/teaser.png ADDED

Git LFS Details

  • SHA256: dcd166cc9691c99a7ee93a028ab485472171ee348a5f4dbaf82f6bf1fb27c66d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.62 MB
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ colorlover==0.3.0
+ gradio==3.41.0
+ imageio==2.31.2
+ img2dataset==1.42.0
+ lpips==0.1.4
+ matplotlib==3.7.2
+ numpy==1.25.1
+ omegaconf==2.3.0
+ open-clip-torch==2.20.0
+ opencv-python==4.6.0.66
+ Pillow==9.5.0
+ pytorch-fid==0.3.0
+ pytorch-lightning==2.0.1
+ safetensors==0.3.1
+ scikit-learn==1.3.0
+ scipy==1.11.1
+ seaborn==0.12.2
+ tensorboard==2.14.0
+ tokenizers==0.13.3
+ torch==2.1.0
+ torchvision==0.16.0
+ tqdm==4.65.0
+ transformers==4.30.2
+
sgm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .models import AutoencodingEngine, DiffusionEngine
+ from .util import instantiate_from_config
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .autoencoder import AutoencodingEngine
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
sgm/models/diffusion.py ADDED
@@ -0,0 +1,328 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+
10
+ from ..modules import UNCONDITIONAL_CONFIG
11
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
12
+ from ..modules.ema import LitEma
13
+ from ..util import (
14
+ default,
15
+ disabled_train,
16
+ get_obj_from_str,
17
+ instantiate_from_config,
18
+ log_txt_as_img,
19
+ )
20
+
21
+
22
+ class DiffusionEngine(pl.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network_config,
26
+ denoiser_config,
27
+ first_stage_config,
28
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ network_wrapper: Union[None, str] = None,
34
+ ckpt_path: Union[None, str] = None,
35
+ use_ema: bool = False,
36
+ ema_decay_rate: float = 0.9999,
37
+ scale_factor: float = 1.0,
38
+ disable_first_stage_autocast=False,
39
+ input_key: str = "jpg",
40
+ log_keys: Union[List, None] = None,
41
+ no_cond_log: bool = False,
42
+ compile_model: bool = False,
43
+ opt_keys: Union[List, None] = None
44
+ ):
45
+ super().__init__()
46
+ self.opt_keys = opt_keys
47
+ self.log_keys = log_keys
48
+ self.input_key = input_key
49
+ self.optimizer_config = default(
50
+ optimizer_config, {"target": "torch.optim.AdamW"}
51
+ )
52
+ model = instantiate_from_config(network_config)
53
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
54
+ model, compile_model=compile_model
55
+ )
56
+
57
+ self.denoiser = instantiate_from_config(denoiser_config)
58
+ self.sampler = (
59
+ instantiate_from_config(sampler_config)
60
+ if sampler_config is not None
61
+ else None
62
+ )
63
+ self.conditioner = instantiate_from_config(
64
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
65
+ )
66
+ self.scheduler_config = scheduler_config
67
+ self._init_first_stage(first_stage_config)
68
+
69
+ self.loss_fn = (
70
+ instantiate_from_config(loss_fn_config)
71
+ if loss_fn_config is not None
72
+ else None
73
+ )
74
+
75
+ self.use_ema = use_ema
76
+ if self.use_ema:
77
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
78
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
79
+
80
+ self.scale_factor = scale_factor
81
+ self.disable_first_stage_autocast = disable_first_stage_autocast
82
+ self.no_cond_log = no_cond_log
83
+
84
+ if ckpt_path is not None:
85
+ self.init_from_ckpt(ckpt_path)
86
+
87
+ def init_from_ckpt(
88
+ self,
89
+ path: str,
90
+ ) -> None:
91
+ if path.endswith("ckpt"):
92
+ sd = torch.load(path, map_location="cpu")["state_dict"]
93
+ elif path.endswith("safetensors"):
94
+ sd = load_safetensors(path)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ missing, unexpected = self.load_state_dict(sd, strict=False)
99
+ print(
100
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
101
+ )
102
+ if len(missing) > 0:
103
+ print(f"Missing Keys: {missing}")
104
+ if len(unexpected) > 0:
105
+ print(f"Unexpected Keys: {unexpected}")
106
+
107
+ def freeze(self):
108
+
109
+ for param in self.parameters():
110
+ param.requires_grad_(False)
111
+
112
+ def _init_first_stage(self, config):
113
+ model = instantiate_from_config(config).eval()
114
+ model.train = disabled_train
115
+ for param in model.parameters():
116
+ param.requires_grad = False
117
+ self.first_stage_model = model
118
+
119
+ def get_input(self, batch):
120
+ # assuming unified data format, dataloader returns a dict.
121
+ # image tensors should be scaled to -1 ... 1 and in bchw format
122
+ return batch[self.input_key]
123
+
124
+ @torch.no_grad()
125
+ def decode_first_stage(self, z):
126
+ z = 1.0 / self.scale_factor * z
127
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
128
+ out = self.first_stage_model.decode(z)
129
+ return out
130
+
131
+ @torch.no_grad()
132
+ def encode_first_stage(self, x):
133
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
134
+ z = self.first_stage_model.encode(x)
135
+ z = self.scale_factor * z
136
+ return z
137
+
138
+ def forward(self, x, batch):
139
+
140
+ loss, loss_dict = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, self.first_stage_model, self.scale_factor)
141
+
142
+ return loss, loss_dict
143
+
144
+ def shared_step(self, batch: Dict) -> Any:
145
+ x = self.get_input(batch)
146
+ x = self.encode_first_stage(x)
147
+ batch["global_step"] = self.global_step
148
+ loss, loss_dict = self(x, batch)
149
+ return loss, loss_dict
150
+
151
+ def training_step(self, batch, batch_idx):
152
+ loss, loss_dict = self.shared_step(batch)
153
+
154
+ self.log_dict(
155
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
156
+ )
157
+
158
+ self.log(
159
+ "global_step",
160
+ float(self.global_step),
161
+ prog_bar=True,
162
+ logger=True,
163
+ on_step=True,
164
+ on_epoch=False,
165
+ )
166
+
167
+ lr = self.optimizers().param_groups[0]["lr"]
168
+ self.log(
169
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
170
+ )
171
+
172
+ return loss
173
+
174
+ def on_train_start(self, *args, **kwargs):
175
+ if self.sampler is None or self.loss_fn is None:
176
+ raise ValueError("Sampler and loss function need to be set for training.")
177
+
178
+ def on_train_batch_end(self, *args, **kwargs):
179
+ if self.use_ema:
180
+ self.model_ema(self.model)
181
+
182
+ @contextmanager
183
+ def ema_scope(self, context=None):
184
+ if self.use_ema:
185
+ self.model_ema.store(self.model.parameters())
186
+ self.model_ema.copy_to(self.model)
187
+ if context is not None:
188
+ print(f"{context}: Switched to EMA weights")
189
+ try:
190
+ yield None
191
+ finally:
192
+ if self.use_ema:
193
+ self.model_ema.restore(self.model.parameters())
194
+ if context is not None:
195
+ print(f"{context}: Restored training weights")
196
+
197
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
198
+ return get_obj_from_str(cfg["target"])(
199
+ params, lr=lr, **cfg.get("params", dict())
200
+ )
201
+
202
+ def configure_optimizers(self):
203
+ lr = self.learning_rate
204
+ params = []
205
+ print("Trainable parameter list: ")
206
+ print("-"*20)
207
+ for name, param in self.model.named_parameters():
208
+ if any([key in name for key in self.opt_keys]):
209
+ params.append(param)
210
+ print(name)
211
+ else:
212
+ param.requires_grad_(False)
213
+ for embedder in self.conditioner.embedders:
214
+ if embedder.is_trainable:
215
+ for name, param in embedder.named_parameters():
216
+ params.append(param)
217
+ print(name)
218
+ print("-"*20)
219
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
220
+ scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch)
221
+
222
+ return [opt], scheduler
223
+
224
+ @torch.no_grad()
225
+ def sample(
226
+ self,
227
+ cond: Dict,
228
+ uc: Union[Dict, None] = None,
229
+ batch_size: int = 16,
230
+ shape: Union[None, Tuple, List] = None,
231
+ **kwargs,
232
+ ):
233
+ randn = torch.randn(batch_size, *shape).to(self.device)
234
+
235
+ denoiser = lambda input, sigma, c: self.denoiser(
236
+ self.model, input, sigma, c, **kwargs
237
+ )
238
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
239
+ return samples
240
+
241
+ @torch.no_grad()
242
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
243
+ """
244
+ Defines heuristics to log different conditionings.
245
+ These can be lists of strings (text-to-image), tensors, ints, ...
246
+ """
247
+ image_h, image_w = batch[self.input_key].shape[2:]
248
+ log = dict()
249
+
250
+ for embedder in self.conditioner.embedders:
251
+ if (
252
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
253
+ ) and not self.no_cond_log:
254
+ x = batch[embedder.input_key][:n]
255
+ if isinstance(x, torch.Tensor):
256
+ if x.dim() == 1:
257
+ # class-conditional, convert integer to string
258
+ x = [str(x[i].item()) for i in range(x.shape[0])]
259
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
260
+ elif x.dim() == 2:
261
+ # size and crop cond and the like
262
+ x = [
263
+ "x".join([str(xx) for xx in x[i].tolist()])
264
+ for i in range(x.shape[0])
265
+ ]
266
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
267
+ else:
268
+ raise NotImplementedError()
269
+ elif isinstance(x, (List, ListConfig)):
270
+ if isinstance(x[0], str):
271
+ # strings
272
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
273
+ else:
274
+ raise NotImplementedError()
275
+ else:
276
+ raise NotImplementedError()
277
+ log[embedder.input_key] = xc
278
+ return log
279
+
280
+ @torch.no_grad()
281
+ def log_images(
282
+ self,
283
+ batch: Dict,
284
+ N: int = 8,
285
+ sample: bool = True,
286
+ ucg_keys: List[str] = None,
287
+ **kwargs,
288
+ ) -> Dict:
289
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
290
+ if ucg_keys:
291
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
292
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
293
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
294
+ )
295
+ else:
296
+ ucg_keys = conditioner_input_keys
297
+ log = dict()
298
+
299
+ x = self.get_input(batch)
300
+
301
+ c, uc = self.conditioner.get_unconditional_conditioning(
302
+ batch,
303
+ force_uc_zero_embeddings=ucg_keys
304
+ if len(self.conditioner.embedders) > 0
305
+ else [],
306
+ )
307
+
308
+ sampling_kwargs = {}
309
+
310
+ N = min(x.shape[0], N)
311
+ x = x.to(self.device)[:N]
312
+ log["inputs"] = x
313
+ z = self.encode_first_stage(x)
314
+ log["reconstructions"] = self.decode_first_stage(z)
315
+ log.update(self.log_conditionings(batch, N))
316
+
317
+ for k in c:
318
+ if isinstance(c[k], torch.Tensor):
319
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
320
+
321
+ if sample:
322
+ with self.ema_scope("Plotting"):
323
+ samples = self.sample(
324
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
325
+ )
326
+ samples = self.decode_first_stage(samples)
327
+ log["samples"] = samples
328
+ return log
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .encoders.modules import GeneralConditioner, DualConditioner
+
+ UNCONDITIONAL_CONFIG = {
+     "target": "sgm.modules.GeneralConditioner",
+     "params": {"emb_models": []},
+ }
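For context, a small sketch of how UNCONDITIONAL_CONFIG is used: DiffusionEngine in sgm/models/diffusion.py (earlier in this diff) falls back to it when no conditioner_config is supplied, producing a GeneralConditioner with an empty embedder list. The snippet below only mirrors that fallback.

```python
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.util import default, instantiate_from_config

# Mirrors the fallback in DiffusionEngine.__init__: with no conditioner
# config provided, an "empty" GeneralConditioner (no embedders) is built.
conditioner_config = None  # assumed: the caller did not pass one
conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
```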
sgm/modules/attention.py ADDED
@@ -0,0 +1,976 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from packaging import version
9
+ from torch import nn, einsum
10
+
11
+
12
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
13
+ SDP_IS_AVAILABLE = True
14
+ from torch.backends.cuda import SDPBackend, sdp_kernel
15
+
16
+ BACKEND_MAP = {
17
+ SDPBackend.MATH: {
18
+ "enable_math": True,
19
+ "enable_flash": False,
20
+ "enable_mem_efficient": False,
21
+ },
22
+ SDPBackend.FLASH_ATTENTION: {
23
+ "enable_math": False,
24
+ "enable_flash": True,
25
+ "enable_mem_efficient": False,
26
+ },
27
+ SDPBackend.EFFICIENT_ATTENTION: {
28
+ "enable_math": False,
29
+ "enable_flash": False,
30
+ "enable_mem_efficient": True,
31
+ },
32
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
33
+ }
34
+ else:
35
+ from contextlib import nullcontext
36
+
37
+ SDP_IS_AVAILABLE = False
38
+ sdp_kernel = nullcontext
39
+ BACKEND_MAP = {}
40
+ print(
41
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
42
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
43
+ )
44
+
45
+ try:
46
+ import xformers
47
+ import xformers.ops
48
+
49
+ XFORMERS_IS_AVAILABLE = True
50
+ except:
51
+ XFORMERS_IS_AVAILABLE = False
52
+ print("no module 'xformers'. Processing without...")
53
+
54
+ from .diffusionmodules.util import checkpoint
55
+
56
+
57
+ def exists(val):
58
+ return val is not None
59
+
60
+
61
+ def uniq(arr):
62
+ return {el: True for el in arr}.keys()
63
+
64
+
65
+ def default(val, d):
66
+ if exists(val):
67
+ return val
68
+ return d() if isfunction(d) else d
69
+
70
+
71
+ def max_neg_value(t):
72
+ return -torch.finfo(t.dtype).max
73
+
74
+
75
+ def init_(tensor):
76
+ dim = tensor.shape[-1]
77
+ std = 1 / math.sqrt(dim)
78
+ tensor.uniform_(-std, std)
79
+ return tensor
80
+
81
+ # feedforward
82
+ class GEGLU(nn.Module):
83
+ def __init__(self, dim_in, dim_out):
84
+ super().__init__()
85
+ self.proj = nn.Linear(dim_in, dim_out * 2)
86
+
87
+ def forward(self, x):
88
+ x, gate = self.proj(x).chunk(2, dim=-1)
89
+ return x * F.gelu(gate)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
94
+ super().__init__()
95
+ inner_dim = int(dim * mult)
96
+ dim_out = default(dim_out, dim)
97
+ project_in = (
98
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
99
+ if not glu
100
+ else GEGLU(dim, inner_dim)
101
+ )
102
+
103
+ self.net = nn.Sequential(
104
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
105
+ )
106
+
107
+ def forward(self, x):
108
+ return self.net(x)
109
+
110
+
111
+ def zero_module(module):
112
+ """
113
+ Zero out the parameters of a module and return it.
114
+ """
115
+ for p in module.parameters():
116
+ p.detach().zero_()
117
+ return module
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class LinearAttention(nn.Module):
127
+ def __init__(self, dim, heads=4, dim_head=32):
128
+ super().__init__()
129
+ self.heads = heads
130
+ hidden_dim = dim_head * heads
131
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
+
134
+ def forward(self, x):
135
+ b, c, h, w = x.shape
136
+ qkv = self.to_qkv(x)
137
+ q, k, v = rearrange(
138
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
+ )
140
+ k = k.softmax(dim=-1)
141
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
142
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
143
+ out = rearrange(
144
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
+ )
146
+ return self.to_out(out)
147
+
148
+
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ class CrossAttention(nn.Module):
195
+ def __init__(
196
+ self,
197
+ query_dim,
198
+ context_dim=None,
199
+ heads=8,
200
+ dim_head=64,
201
+ dropout=0.0,
202
+ backend=None,
203
+ ):
204
+ super().__init__()
205
+ inner_dim = dim_head * heads
206
+ context_dim = default(context_dim, query_dim)
207
+
208
+ self.scale = dim_head**-0.5
209
+ self.heads = heads
210
+
211
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
212
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
213
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
214
+
215
+ self.to_out = zero_module(nn.Sequential(
216
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
217
+ ))
218
+ self.backend = backend
219
+
220
+ self.attn_map_cache = None
221
+
222
+ def forward(
223
+ self,
224
+ x,
225
+ context=None,
226
+ mask=None,
227
+ additional_tokens=None,
228
+ n_times_crossframe_attn_in_self=0,
229
+ ):
230
+ h = self.heads
231
+
232
+ if additional_tokens is not None:
233
+ # get the number of masked tokens at the beginning of the output sequence
234
+ n_tokens_to_mask = additional_tokens.shape[1]
235
+ # add additional token
236
+ x = torch.cat([additional_tokens, x], dim=1)
237
+
238
+ q = self.to_q(x)
239
+ context = default(context, x)
240
+ k = self.to_k(context)
241
+ v = self.to_v(context)
242
+
243
+ if n_times_crossframe_attn_in_self:
244
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
245
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
246
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
247
+ k = repeat(
248
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
249
+ )
250
+ v = repeat(
251
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
252
+ )
253
+
254
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
255
+
256
+ ## old
257
+
258
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
259
+ del q, k
260
+
261
+ if exists(mask):
262
+ mask = rearrange(mask, 'b ... -> b (...)')
263
+ max_neg_value = -torch.finfo(sim.dtype).max
264
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
265
+ sim.masked_fill_(~mask, max_neg_value)
266
+
267
+ # attention, what we cannot get enough of
268
+ sim = sim.softmax(dim=-1)
269
+
270
+ # save attn_map
271
+ if self.attn_map_cache is not None:
272
+ bh, n, l = sim.shape
273
+ size = int(n**0.5)
274
+ self.attn_map_cache["size"] = size
275
+ self.attn_map_cache["attn_map"] = sim
276
+
277
+ out = einsum('b i j, b j d -> b i d', sim, v)
278
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
279
+
280
+ ## new
281
+ # with sdp_kernel(**BACKEND_MAP[self.backend]):
282
+ # # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
283
+ # out = F.scaled_dot_product_attention(
284
+ # q, k, v, attn_mask=mask
285
+ # ) # scale is dim_head ** -0.5 per default
286
+
287
+ # del q, k, v
288
+ # out = rearrange(out, "b h n d -> b n (h d)", h=h)
289
+
290
+ if additional_tokens is not None:
291
+ # remove additional token
292
+ out = out[:, n_tokens_to_mask:]
293
+ return self.to_out(out)
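When `attn_map_cache` is set to a dict on a CrossAttention instance, the post-softmax attention map and its spatial size are stored during the forward pass; the local loss later in this commit reads them back. A small sketch, assuming the repository root is on PYTHONPATH and using made-up dimensions:

# Sketch: enable the attention-map cache and inspect what gets stored
# (dimensions here are illustrative, not taken from the configs).
import torch
from sgm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
attn.attn_map_cache = {}                        # any dict enables caching
x = torch.randn(1, 64 * 64, 320)                # flattened 64x64 feature map
ctx = torch.randn(1, 77, 768)                   # text-like context tokens
_ = attn(x, context=ctx)
print(attn.attn_map_cache["size"])              # 64  (sqrt of the pixel count)
print(attn.attn_map_cache["attn_map"].shape)    # torch.Size([8, 4096, 77]) = (b*heads, n, l)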
294
+
295
+
296
+ class MemoryEfficientCrossAttention(nn.Module):
297
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
298
+ def __init__(
299
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
300
+ ):
301
+ super().__init__()
302
+ # print(
303
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
304
+ # f"{heads} heads with a dimension of {dim_head}."
305
+ # )
306
+ inner_dim = dim_head * heads
307
+ context_dim = default(context_dim, query_dim)
308
+
309
+ self.heads = heads
310
+ self.dim_head = dim_head
311
+
312
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
313
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
314
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
315
+
316
+ self.to_out = nn.Sequential(
317
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
318
+ )
319
+ self.attention_op: Optional[Any] = None
320
+
321
+ def forward(
322
+ self,
323
+ x,
324
+ context=None,
325
+ mask=None,
326
+ additional_tokens=None,
327
+ n_times_crossframe_attn_in_self=0,
328
+ ):
329
+ if additional_tokens is not None:
330
+ # get the number of masked tokens at the beginning of the output sequence
331
+ n_tokens_to_mask = additional_tokens.shape[1]
332
+ # add additional token
333
+ x = torch.cat([additional_tokens, x], dim=1)
334
+ q = self.to_q(x)
335
+ context = default(context, x)
336
+ k = self.to_k(context)
337
+ v = self.to_v(context)
338
+
339
+ if n_times_crossframe_attn_in_self:
340
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
341
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
342
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
343
+ k = repeat(
344
+ k[::n_times_crossframe_attn_in_self],
345
+ "b ... -> (b n) ...",
346
+ n=n_times_crossframe_attn_in_self,
347
+ )
348
+ v = repeat(
349
+ v[::n_times_crossframe_attn_in_self],
350
+ "b ... -> (b n) ...",
351
+ n=n_times_crossframe_attn_in_self,
352
+ )
353
+
354
+ b, _, _ = q.shape
355
+ q, k, v = map(
356
+ lambda t: t.unsqueeze(3)
357
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
358
+ .permute(0, 2, 1, 3)
359
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
360
+ .contiguous(),
361
+ (q, k, v),
362
+ )
363
+
364
+ # actually compute the attention, what we cannot get enough of
365
+ out = xformers.ops.memory_efficient_attention(
366
+ q, k, v, attn_bias=None, op=self.attention_op
367
+ )
368
+
369
+ # TODO: Use this directly in the attention operation, as a bias
370
+ if exists(mask):
371
+ raise NotImplementedError
372
+ out = (
373
+ out.unsqueeze(0)
374
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
375
+ .permute(0, 2, 1, 3)
376
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
377
+ )
378
+ if additional_tokens is not None:
379
+ # remove additional token
380
+ out = out[:, n_tokens_to_mask:]
381
+ return self.to_out(out)
382
+
383
+
384
+ class BasicTransformerBlock(nn.Module):
385
+ ATTENTION_MODES = {
386
+ "softmax": CrossAttention, # vanilla attention
387
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
388
+ }
389
+
390
+ def __init__(
391
+ self,
392
+ dim,
393
+ n_heads,
394
+ d_head,
395
+ dropout=0.0,
396
+ context_dim=None,
397
+ add_context_dim=None,
398
+ gated_ff=True,
399
+ checkpoint=True,
400
+ disable_self_attn=False,
401
+ attn_mode="softmax",
402
+ sdp_backend=None,
403
+ ):
404
+ super().__init__()
405
+ assert attn_mode in self.ATTENTION_MODES
406
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
407
+ print(
408
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
409
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
410
+ )
411
+ attn_mode = "softmax"
412
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
413
+ print(
414
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
415
+ )
416
+ if not XFORMERS_IS_AVAILABLE:
417
+ assert (
418
+ False
419
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
420
+ else:
421
+ print("Falling back to xformers efficient attention.")
422
+ attn_mode = "softmax-xformers"
423
+ attn_cls = self.ATTENTION_MODES[attn_mode]
424
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
425
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
426
+ else:
427
+ assert sdp_backend is None
428
+ self.disable_self_attn = disable_self_attn
429
+ self.attn1 = attn_cls(
430
+ query_dim=dim,
431
+ heads=n_heads,
432
+ dim_head=d_head,
433
+ dropout=dropout,
434
+ context_dim=context_dim if self.disable_self_attn else None,
435
+ backend=sdp_backend,
436
+ ) # is a self-attention if not self.disable_self_attn
437
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
438
+ if context_dim is not None and context_dim > 0:
439
+ self.attn2 = attn_cls(
440
+ query_dim=dim,
441
+ context_dim=context_dim,
442
+ heads=n_heads,
443
+ dim_head=d_head,
444
+ dropout=dropout,
445
+ backend=sdp_backend,
446
+ ) # cross-attention over the main context (built only when context_dim > 0)
447
+ if add_context_dim is not None and add_context_dim > 0:
448
+ self.add_attn = attn_cls(
449
+ query_dim=dim,
450
+ context_dim=add_context_dim,
451
+ heads=n_heads,
452
+ dim_head=d_head,
453
+ dropout=dropout,
454
+ backend=sdp_backend,
455
+ ) # cross-attention over the additional context (built only when add_context_dim > 0)
456
+ self.add_norm = nn.LayerNorm(dim)
457
+ self.norm1 = nn.LayerNorm(dim)
458
+ self.norm2 = nn.LayerNorm(dim)
459
+ self.norm3 = nn.LayerNorm(dim)
460
+ self.checkpoint = checkpoint
461
+
462
+ def forward(
463
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
464
+ ):
465
+ kwargs = {"x": x}
466
+
467
+ if context is not None:
468
+ kwargs.update({"context": context})
469
+
470
+ if additional_tokens is not None:
471
+ kwargs.update({"additional_tokens": additional_tokens})
472
+
473
+ if n_times_crossframe_attn_in_self:
474
+ kwargs.update(
475
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
476
+ )
477
+
478
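+ # note: the kwargs assembled above are not routed through the checkpointed call;
+ # only (x, context, add_context) reach _forward.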
+ return checkpoint(
479
+ self._forward, (x, context, add_context), self.parameters(), self.checkpoint
480
+ )
481
+
482
+ def _forward(
483
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
484
+ ):
485
+ x = (
486
+ self.attn1(
487
+ self.norm1(x),
488
+ context=context if self.disable_self_attn else None,
489
+ additional_tokens=additional_tokens,
490
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
491
+ if not self.disable_self_attn
492
+ else 0,
493
+ )
494
+ + x
495
+ )
496
+ if hasattr(self, "attn2"):
497
+ x = (
498
+ self.attn2(
499
+ self.norm2(x), context=context, additional_tokens=additional_tokens
500
+ )
501
+ + x
502
+ )
503
+ if hasattr(self, "add_attn"):
504
+ x = (
505
+ self.add_attn(
506
+ self.add_norm(x), context=add_context, additional_tokens=additional_tokens
507
+ )
508
+ + x
509
+ )
510
+ x = self.ff(self.norm3(x)) + x
511
+ return x
512
+
513
+
514
+ class BasicTransformerSingleLayerBlock(nn.Module):
515
+ ATTENTION_MODES = {
516
+ "softmax": CrossAttention, # vanilla attention
517
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
518
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
519
+ }
520
+
521
+ def __init__(
522
+ self,
523
+ dim,
524
+ n_heads,
525
+ d_head,
526
+ dropout=0.0,
527
+ context_dim=None,
528
+ gated_ff=True,
529
+ checkpoint=True,
530
+ attn_mode="softmax",
531
+ ):
532
+ super().__init__()
533
+ assert attn_mode in self.ATTENTION_MODES
534
+ attn_cls = self.ATTENTION_MODES[attn_mode]
535
+ self.attn1 = attn_cls(
536
+ query_dim=dim,
537
+ heads=n_heads,
538
+ dim_head=d_head,
539
+ dropout=dropout,
540
+ context_dim=context_dim,
541
+ )
542
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
543
+ self.norm1 = nn.LayerNorm(dim)
544
+ self.norm2 = nn.LayerNorm(dim)
545
+ self.checkpoint = checkpoint
546
+
547
+ def forward(self, x, context=None):
548
+ return checkpoint(
549
+ self._forward, (x, context), self.parameters(), self.checkpoint
550
+ )
551
+
552
+ def _forward(self, x, context=None):
553
+ x = self.attn1(self.norm1(x), context=context) + x
554
+ x = self.ff(self.norm2(x)) + x
555
+ return x
556
+
557
+
558
+ class SpatialTransformer(nn.Module):
559
+ """
560
+ Transformer block for image-like data.
561
+ First, project the input (aka embedding)
562
+ and reshape to b, t, d.
563
+ Then apply standard transformer action.
564
+ Finally, reshape back to an image.
565
+ NEW: use_linear for more efficiency instead of the 1x1 convs
566
+ """
567
+
568
+ def __init__(
569
+ self,
570
+ in_channels,
571
+ n_heads,
572
+ d_head,
573
+ depth=1,
574
+ dropout=0.0,
575
+ context_dim=None,
576
+ add_context_dim=None,
577
+ disable_self_attn=False,
578
+ use_linear=False,
579
+ attn_type="softmax",
580
+ use_checkpoint=True,
581
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
582
+ sdp_backend=None,
583
+ ):
584
+ super().__init__()
585
+ # print(
586
+ # f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
587
+ # )
588
+ from omegaconf import ListConfig
589
+
590
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
591
+ context_dim = [context_dim]
592
+ if exists(context_dim) and isinstance(context_dim, list):
593
+ if depth != len(context_dim):
594
+ # print(
595
+ # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
596
+ # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
597
+ # )
598
+ # depth does not match context dims.
599
+ assert all(
600
+ map(lambda x: x == context_dim[0], context_dim)
601
+ ), "need homogenous context_dim to match depth automatically"
602
+ context_dim = depth * [context_dim[0]]
603
+ elif context_dim is None:
604
+ context_dim = [None] * depth
605
+ self.in_channels = in_channels
606
+ inner_dim = n_heads * d_head
607
+ self.norm = Normalize(in_channels)
608
+ if not use_linear:
609
+ self.proj_in = nn.Conv2d(
610
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
611
+ )
612
+ else:
613
+ self.proj_in = nn.Linear(in_channels, inner_dim)
614
+
615
+ self.transformer_blocks = nn.ModuleList(
616
+ [
617
+ BasicTransformerBlock(
618
+ inner_dim,
619
+ n_heads,
620
+ d_head,
621
+ dropout=dropout,
622
+ context_dim=context_dim[d],
623
+ add_context_dim=add_context_dim,
624
+ disable_self_attn=disable_self_attn,
625
+ attn_mode=attn_type,
626
+ checkpoint=use_checkpoint,
627
+ sdp_backend=sdp_backend,
628
+ )
629
+ for d in range(depth)
630
+ ]
631
+ )
632
+ if not use_linear:
633
+ self.proj_out = zero_module(
634
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
635
+ )
636
+ else:
637
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
638
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
639
+ self.use_linear = use_linear
640
+
641
+ def forward(self, x, context=None, add_context=None):
642
+ # note: if no context is given, cross-attention defaults to self-attention
643
+ if not isinstance(context, list):
644
+ context = [context]
645
+ b, c, h, w = x.shape
646
+ x_in = x
647
+ x = self.norm(x)
648
+ if not self.use_linear:
649
+ x = self.proj_in(x)
650
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
651
+ if self.use_linear:
652
+ x = self.proj_in(x)
653
+ for i, block in enumerate(self.transformer_blocks):
654
+ if i > 0 and len(context) == 1:
655
+ i = 0 # use same context for each block
656
+ x = block(x, context=context[i], add_context=add_context)
657
+ if self.use_linear:
658
+ x = self.proj_out(x)
659
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
660
+ if not self.use_linear:
661
+ x = self.proj_out(x)
662
+ return x + x_in
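A usage sketch for SpatialTransformer with both a main and an additional context. The channel, head, and context sizes below are illustrative only (the real values come from the UNet config), PyTorch >= 2.0 is assumed for the default attention mode, and the repo root is assumed to be on PYTHONPATH:

# Illustrative forward pass: flatten -> transformer blocks (self-attn, cross-attn
# on `context`, extra cross-attn on `add_context`) -> reshape + residual.
import torch
from sgm.modules.attention import SpatialTransformer

st = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40, depth=1,
    context_dim=768, add_context_dim=2048,
    use_linear=True, use_checkpoint=False,
)
feat = torch.randn(2, 320, 32, 32)     # image-like features
ctx = torch.randn(2, 77, 768)          # main cross-attention context
add_ctx = torch.randn(2, 12, 2048)     # additional context (e.g. per-character tokens)
print(st(feat, context=ctx, add_context=add_ctx).shape)  # torch.Size([2, 320, 32, 32])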
663
+
664
+
665
+ def benchmark_attn():
666
+ # Let's define a helpful benchmarking function:
667
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
668
+ device = "cuda" if torch.cuda.is_available() else "cpu"
669
+ import torch.nn.functional as F
670
+ import torch.utils.benchmark as benchmark
671
+
672
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
673
+ t0 = benchmark.Timer(
674
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
675
+ )
676
+ return t0.blocked_autorange().mean * 1e6
677
+
678
+ # Let's define the hyperparameters of our input
679
+ batch_size = 32
680
+ max_sequence_len = 1024
681
+ num_heads = 32
682
+ embed_dimension = 32
683
+
684
+ dtype = torch.float16
685
+
686
+ query = torch.rand(
687
+ batch_size,
688
+ num_heads,
689
+ max_sequence_len,
690
+ embed_dimension,
691
+ device=device,
692
+ dtype=dtype,
693
+ )
694
+ key = torch.rand(
695
+ batch_size,
696
+ num_heads,
697
+ max_sequence_len,
698
+ embed_dimension,
699
+ device=device,
700
+ dtype=dtype,
701
+ )
702
+ value = torch.rand(
703
+ batch_size,
704
+ num_heads,
705
+ max_sequence_len,
706
+ embed_dimension,
707
+ device=device,
708
+ dtype=dtype,
709
+ )
710
+
711
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
712
+
713
+ # Let's explore the speed of each of the three implementations
714
+ from torch.backends.cuda import SDPBackend, sdp_kernel
715
+
716
+ # Helpful arguments mapper
717
+ backend_map = {
718
+ SDPBackend.MATH: {
719
+ "enable_math": True,
720
+ "enable_flash": False,
721
+ "enable_mem_efficient": False,
722
+ },
723
+ SDPBackend.FLASH_ATTENTION: {
724
+ "enable_math": False,
725
+ "enable_flash": True,
726
+ "enable_mem_efficient": False,
727
+ },
728
+ SDPBackend.EFFICIENT_ATTENTION: {
729
+ "enable_math": False,
730
+ "enable_flash": False,
731
+ "enable_mem_efficient": True,
732
+ },
733
+ }
734
+
735
+ from torch.profiler import ProfilerActivity, profile, record_function
736
+
737
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
738
+
739
+ print(
740
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
741
+ )
742
+ with profile(
743
+ activities=activities, record_shapes=False, profile_memory=True
744
+ ) as prof:
745
+ with record_function("Default detailed stats"):
746
+ for _ in range(25):
747
+ o = F.scaled_dot_product_attention(query, key, value)
748
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
749
+
750
+ print(
751
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
752
+ )
753
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
754
+ with profile(
755
+ activities=activities, record_shapes=False, profile_memory=True
756
+ ) as prof:
757
+ with record_function("Math implmentation stats"):
758
+ for _ in range(25):
759
+ o = F.scaled_dot_product_attention(query, key, value)
760
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
761
+
762
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
763
+ try:
764
+ print(
765
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
766
+ )
767
+ except RuntimeError:
768
+ print("FlashAttention is not supported. See warnings for reasons.")
769
+ with profile(
770
+ activities=activities, record_shapes=False, profile_memory=True
771
+ ) as prof:
772
+ with record_function("FlashAttention stats"):
773
+ for _ in range(25):
774
+ o = F.scaled_dot_product_attention(query, key, value)
775
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
776
+
777
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
778
+ try:
779
+ print(
780
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
781
+ )
782
+ except RuntimeError:
783
+ print("EfficientAttention is not supported. See warnings for reasons.")
784
+ with profile(
785
+ activities=activities, record_shapes=False, profile_memory=True
786
+ ) as prof:
787
+ with record_function("EfficientAttention stats"):
788
+ for _ in range(25):
789
+ o = F.scaled_dot_product_attention(query, key, value)
790
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
791
+
792
+
793
+ def run_model(model, x, context):
794
+ return model(x, context)
795
+
796
+
797
+ def benchmark_transformer_blocks():
798
+ device = "cuda" if torch.cuda.is_available() else "cpu"
799
+ import torch.utils.benchmark as benchmark
800
+
801
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
802
+ t0 = benchmark.Timer(
803
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
804
+ )
805
+ return t0.blocked_autorange().mean * 1e6
806
+
807
+ checkpoint = True
808
+ compile = False
809
+
810
+ batch_size = 32
811
+ h, w = 64, 64
812
+ context_len = 77
813
+ embed_dimension = 1024
814
+ context_dim = 1024
815
+ d_head = 64
816
+
817
+ transformer_depth = 4
818
+
819
+ n_heads = embed_dimension // d_head
820
+
821
+ dtype = torch.float16
822
+
823
+ model_native = SpatialTransformer(
824
+ embed_dimension,
825
+ n_heads,
826
+ d_head,
827
+ context_dim=context_dim,
828
+ use_linear=True,
829
+ use_checkpoint=checkpoint,
830
+ attn_type="softmax",
831
+ depth=transformer_depth,
832
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
833
+ ).to(device)
834
+ model_efficient_attn = SpatialTransformer(
835
+ embed_dimension,
836
+ n_heads,
837
+ d_head,
838
+ context_dim=context_dim,
839
+ use_linear=True,
840
+ depth=transformer_depth,
841
+ use_checkpoint=checkpoint,
842
+ attn_type="softmax-xformers",
843
+ ).to(device)
844
+ if not checkpoint and compile:
845
+ print("compiling models")
846
+ model_native = torch.compile(model_native)
847
+ model_efficient_attn = torch.compile(model_efficient_attn)
848
+
849
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
850
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
851
+
852
+ from torch.profiler import ProfilerActivity, profile, record_function
853
+
854
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
855
+
856
+ with torch.autocast("cuda"):
857
+ print(
858
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
859
+ )
860
+ print(
861
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
862
+ )
863
+
864
+ print(75 * "+")
865
+ print("NATIVE")
866
+ print(75 * "+")
867
+ torch.cuda.reset_peak_memory_stats()
868
+ with profile(
869
+ activities=activities, record_shapes=False, profile_memory=True
870
+ ) as prof:
871
+ with record_function("NativeAttention stats"):
872
+ for _ in range(25):
873
+ model_native(x, c)
874
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
875
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
876
+
877
+ print(75 * "+")
878
+ print("Xformers")
879
+ print(75 * "+")
880
+ torch.cuda.reset_peak_memory_stats()
881
+ with profile(
882
+ activities=activities, record_shapes=False, profile_memory=True
883
+ ) as prof:
884
+ with record_function("xformers stats"):
885
+ for _ in range(25):
886
+ model_efficient_attn(x, c)
887
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
888
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
889
+
890
+
891
+ def test01():
892
+ # conv1x1 vs linear
893
+ from ..util import count_params
894
+
895
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
896
+ print(count_params(conv))
897
+ linear = torch.nn.Linear(3, 32).cuda()
898
+ print(count_params(linear))
899
+
900
+ print(conv.weight.shape)
901
+
902
+ # use same initialization
903
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
904
+ linear.bias = torch.nn.Parameter(conv.bias)
905
+
906
+ print(linear.weight.shape)
907
+
908
+ x = torch.randn(11, 3, 64, 64).cuda()
909
+
910
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
911
+ print(xr.shape)
912
+ out_linear = linear(xr)
913
+ print(out_linear.mean(), out_linear.shape)
914
+
915
+ out_conv = conv(x)
916
+ print(out_conv.mean(), out_conv.shape)
917
+ print("done with test01.\n")
918
+
919
+
920
+ def test02():
921
+ # try cosine flash attention
922
+ import time
923
+
924
+ torch.backends.cuda.matmul.allow_tf32 = True
925
+ torch.backends.cudnn.allow_tf32 = True
926
+ torch.backends.cudnn.benchmark = True
927
+ print("testing cosine flash attention...")
928
+ DIM = 1024
929
+ SEQLEN = 4096
930
+ BS = 16
931
+
932
+ print(" softmax (vanilla) first...")
933
+ model = BasicTransformerBlock(
934
+ dim=DIM,
935
+ n_heads=16,
936
+ d_head=64,
937
+ dropout=0.0,
938
+ context_dim=None,
939
+ attn_mode="softmax",
940
+ ).cuda()
941
+ try:
942
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
943
+ tic = time.time()
944
+ y = model(x)
945
+ toc = time.time()
946
+ print(y.shape, toc - tic)
947
+ except RuntimeError as e:
948
+ # likely oom
949
+ print(str(e))
950
+
951
+ print("\n now flash-cosine...")
952
+ model = BasicTransformerBlock(
953
+ dim=DIM,
954
+ n_heads=16,
955
+ d_head=64,
956
+ dropout=0.0,
957
+ context_dim=None,
958
+ attn_mode="flash-cosine",
959
+ ).cuda()
960
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
961
+ tic = time.time()
962
+ y = model(x)
963
+ toc = time.time()
964
+ print(y.shape, toc - tic)
965
+ print("done with test02.\n")
966
+
967
+
968
+ if __name__ == "__main__":
969
+ # test01()
970
+ # test02()
971
+ # test03()
972
+
973
+ # benchmark_attn()
974
+ benchmark_transformer_blocks()
975
+
976
+ print("done.")
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+ from ....util import default, instantiate_from_config
11
+
12
+
13
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
14
+ if global_step < threshold:
15
+ weight = value
16
+ return weight
17
+
18
+
19
+ class LatentLPIPS(nn.Module):
20
+ def __init__(
21
+ self,
22
+ decoder_config,
23
+ perceptual_weight=1.0,
24
+ latent_weight=1.0,
25
+ scale_input_to_tgt_size=False,
26
+ scale_tgt_to_input_size=False,
27
+ perceptual_weight_on_inputs=0.0,
28
+ ):
29
+ super().__init__()
30
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
31
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
32
+ self.init_decoder(decoder_config)
33
+ self.perceptual_loss = LPIPS().eval()
34
+ self.perceptual_weight = perceptual_weight
35
+ self.latent_weight = latent_weight
36
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
37
+
38
+ def init_decoder(self, config):
39
+ self.decoder = instantiate_from_config(config)
40
+ if hasattr(self.decoder, "encoder"):
41
+ del self.decoder.encoder
42
+
43
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
44
+ log = dict()
45
+ loss = (latent_inputs - latent_predictions) ** 2
46
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
47
+ image_reconstructions = None
48
+ if self.perceptual_weight > 0.0:
49
+ image_reconstructions = self.decoder.decode(latent_predictions)
50
+ image_targets = self.decoder.decode(latent_inputs)
51
+ perceptual_loss = self.perceptual_loss(
52
+ image_targets.contiguous(), image_reconstructions.contiguous()
53
+ )
54
+ loss = (
55
+ self.latent_weight * loss.mean()
56
+ + self.perceptual_weight * perceptual_loss.mean()
57
+ )
58
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
59
+
60
+ if self.perceptual_weight_on_inputs > 0.0:
61
+ image_reconstructions = default(
62
+ image_reconstructions, self.decoder.decode(latent_predictions)
63
+ )
64
+ if self.scale_input_to_tgt_size:
65
+ image_inputs = torch.nn.functional.interpolate(
66
+ image_inputs,
67
+ image_reconstructions.shape[2:],
68
+ mode="bicubic",
69
+ antialias=True,
70
+ )
71
+ elif self.scale_tgt_to_input_size:
72
+ image_reconstructions = torch.nn.functional.interpolate(
73
+ image_reconstructions,
74
+ image_inputs.shape[2:],
75
+ mode="bicubic",
76
+ antialias=True,
77
+ )
78
+
79
+ perceptual_loss2 = self.perceptual_loss(
80
+ image_inputs.contiguous(), image_reconstructions.contiguous()
81
+ )
82
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
83
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
84
+ return loss, log
85
+
86
+
87
+ class GeneralLPIPSWithDiscriminator(nn.Module):
88
+ def __init__(
89
+ self,
90
+ disc_start: int,
91
+ logvar_init: float = 0.0,
92
+ pixelloss_weight=1.0,
93
+ disc_num_layers: int = 3,
94
+ disc_in_channels: int = 3,
95
+ disc_factor: float = 1.0,
96
+ disc_weight: float = 1.0,
97
+ perceptual_weight: float = 1.0,
98
+ disc_loss: str = "hinge",
99
+ scale_input_to_tgt_size: bool = False,
100
+ dims: int = 2,
101
+ learn_logvar: bool = False,
102
+ regularization_weights: Union[None, dict] = None,
103
+ ):
104
+ super().__init__()
105
+ self.dims = dims
106
+ if self.dims > 2:
107
+ print(
108
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
109
+ f"the LPIPS loss will be applied to each frame independently. "
110
+ )
111
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
112
+ assert disc_loss in ["hinge", "vanilla"]
113
+ self.pixel_weight = pixelloss_weight
114
+ self.perceptual_loss = LPIPS().eval()
115
+ self.perceptual_weight = perceptual_weight
116
+ # output log variance
117
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
118
+ self.learn_logvar = learn_logvar
119
+
120
+ self.discriminator = NLayerDiscriminator(
121
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
122
+ ).apply(weights_init)
123
+ self.discriminator_iter_start = disc_start
124
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
125
+ self.disc_factor = disc_factor
126
+ self.discriminator_weight = disc_weight
127
+ self.regularization_weights = default(regularization_weights, {})
128
+
129
+ def get_trainable_parameters(self) -> Any:
130
+ return self.discriminator.parameters()
131
+
132
+ def get_trainable_autoencoder_parameters(self) -> Any:
133
+ if self.learn_logvar:
134
+ yield self.logvar
135
+ yield from ()
136
+
137
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
138
+ if last_layer is not None:
139
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
140
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
141
+ else:
142
+ nll_grads = torch.autograd.grad(
143
+ nll_loss, self.last_layer[0], retain_graph=True
144
+ )[0]
145
+ g_grads = torch.autograd.grad(
146
+ g_loss, self.last_layer[0], retain_graph=True
147
+ )[0]
148
+
149
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
150
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
151
+ d_weight = d_weight * self.discriminator_weight
152
+ return d_weight
153
+
154
+ def forward(
155
+ self,
156
+ regularization_log,
157
+ inputs,
158
+ reconstructions,
159
+ optimizer_idx,
160
+ global_step,
161
+ last_layer=None,
162
+ split="train",
163
+ weights=None,
164
+ ):
165
+ if self.scale_input_to_tgt_size:
166
+ inputs = torch.nn.functional.interpolate(
167
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
168
+ )
169
+
170
+ if self.dims > 2:
171
+ inputs, reconstructions = map(
172
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
173
+ (inputs, reconstructions),
174
+ )
175
+
176
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
177
+ if self.perceptual_weight > 0:
178
+ p_loss = self.perceptual_loss(
179
+ inputs.contiguous(), reconstructions.contiguous()
180
+ )
181
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
182
+
183
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
184
+ weighted_nll_loss = nll_loss
185
+ if weights is not None:
186
+ weighted_nll_loss = weights * nll_loss
187
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
188
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
189
+
190
+ # now the GAN part
191
+ if optimizer_idx == 0:
192
+ # generator update
193
+ logits_fake = self.discriminator(reconstructions.contiguous())
194
+ g_loss = -torch.mean(logits_fake)
195
+
196
+ if self.disc_factor > 0.0:
197
+ try:
198
+ d_weight = self.calculate_adaptive_weight(
199
+ nll_loss, g_loss, last_layer=last_layer
200
+ )
201
+ except RuntimeError:
202
+ assert not self.training
203
+ d_weight = torch.tensor(0.0)
204
+ else:
205
+ d_weight = torch.tensor(0.0)
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
211
+ log = dict()
212
+ for k in regularization_log:
213
+ if k in self.regularization_weights:
214
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
215
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
216
+
217
+ log.update(
218
+ {
219
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
220
+ "{}/logvar".format(split): self.logvar.detach(),
221
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
222
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
223
+ "{}/d_weight".format(split): d_weight.detach(),
224
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
225
+ "{}/g_loss".format(split): g_loss.detach().mean(),
226
+ }
227
+ )
228
+
229
+ return loss, log
230
+
231
+ if optimizer_idx == 1:
232
+ # second pass for discriminator update
233
+ logits_real = self.discriminator(inputs.contiguous().detach())
234
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
235
+
236
+ disc_factor = adopt_weight(
237
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
238
+ )
239
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
240
+
241
+ log = {
242
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
243
+ "{}/logits_real".format(split): logits_real.detach().mean(),
244
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
245
+ }
246
+ return d_loss, log
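GeneralLPIPSWithDiscriminator alternates a generator pass (optimizer_idx == 0) and a discriminator pass (optimizer_idx == 1). Two knobs gate the adversarial term in the generator pass: adopt_weight keeps it at zero until disc_start steps, and calculate_adaptive_weight rescales it by the ratio of gradient norms at the last decoder layer. A standalone sketch of both (numbers are illustrative):

# Step warm-up: the GAN loss only contributes once global_step reaches disc_start.
import torch

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    return value if global_step < threshold else weight

print([adopt_weight(1.0, s, threshold=1000) for s in (0, 999, 1000, 5000)])
# [0.0, 0.0, 1.0, 1.0]

# Gradient-norm balance: a strong adversarial gradient gets scaled down, a weak one up.
nll_grads, g_grads = torch.randn(64), torch.randn(64) * 10
d_weight = (nll_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1e4)
print(float(d_weight))   # roughly 0.1 for this 10x-larger adversarial gradient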
sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import DiagonalGaussianDistribution
9
+
10
+
11
+ class AbstractRegularizer(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16
+ raise NotImplementedError()
17
+
18
+ @abstractmethod
19
+ def get_trainable_parameters(self) -> Any:
20
+ raise NotImplementedError()
21
+
22
+
23
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
24
+ def __init__(self, sample: bool = True):
25
+ super().__init__()
26
+ self.sample = sample
27
+
28
+ def get_trainable_parameters(self) -> Any:
29
+ yield from ()
30
+
31
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
32
+ log = dict()
33
+ posterior = DiagonalGaussianDistribution(z)
34
+ if self.sample:
35
+ z = posterior.sample()
36
+ else:
37
+ z = posterior.mode()
38
+ kl_loss = posterior.kl()
39
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
40
+ log["kl_loss"] = kl_loss
41
+ return z, log
42
+
43
+
44
+ def measure_perplexity(predicted_indices, num_centroids):
45
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
46
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
47
+ encodings = (
48
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
49
+ )
50
+ avg_probs = encodings.mean(0)
51
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
52
+ cluster_use = torch.sum(avg_probs > 0)
53
+ return perplexity, cluster_use
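DiagonalGaussianRegularizer treats the encoder output as concatenated mean and log-variance channels, samples (or takes the mode), and returns the batch-averaged KL term. A shape-only sketch, assuming the repo root is importable and that the standard LDM-style DiagonalGaussianDistribution (which chunks channels into mean and logvar) is present in sgm.modules.distributions:

# Shape sketch: a 2*z_channels "moments" tensor in, a sampled latent and a scalar KL out.
import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 8, 32, 32)     # concat(mean, logvar) -> latent has 4 channels
z, log = reg(moments)
print(z.shape, float(log["kl_loss"]))   # torch.Size([2, 4, 32, 32]) and a scalar KL value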
sgm/modules/diffusionmodules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .denoiser import Denoiser
2
+ from .discretizer import Discretization
3
+ from .loss import StandardDiffusionLoss
4
+ from .model import Model, Encoder, Decoder
5
+ from .openaimodel import UNetModel
6
+ from .sampling import BaseDiffusionSampler
7
+ from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/denoiser.py ADDED
@@ -0,0 +1,63 @@
1
+ import torch.nn as nn
2
+
3
+ from ...util import append_dims, instantiate_from_config
4
+
5
+
6
+ class Denoiser(nn.Module):
7
+ def __init__(self, weighting_config, scaling_config):
8
+ super().__init__()
9
+
10
+ self.weighting = instantiate_from_config(weighting_config)
11
+ self.scaling = instantiate_from_config(scaling_config)
12
+
13
+ def possibly_quantize_sigma(self, sigma):
14
+ return sigma
15
+
16
+ def possibly_quantize_c_noise(self, c_noise):
17
+ return c_noise
18
+
19
+ def w(self, sigma):
20
+ return self.weighting(sigma)
21
+
22
+ def __call__(self, network, input, sigma, cond):
23
+ sigma = self.possibly_quantize_sigma(sigma)
24
+ sigma_shape = sigma.shape
25
+ sigma = append_dims(sigma, input.ndim)
26
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
27
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
28
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
29
+
30
+
31
+ class DiscreteDenoiser(Denoiser):
32
+ def __init__(
33
+ self,
34
+ weighting_config,
35
+ scaling_config,
36
+ num_idx,
37
+ discretization_config,
38
+ do_append_zero=False,
39
+ quantize_c_noise=True,
40
+ flip=True,
41
+ ):
42
+ super().__init__(weighting_config, scaling_config)
43
+ sigmas = instantiate_from_config(discretization_config)(
44
+ num_idx, do_append_zero=do_append_zero, flip=flip
45
+ )
46
+ self.register_buffer("sigmas", sigmas)
47
+ self.quantize_c_noise = quantize_c_noise
48
+
49
+ def sigma_to_idx(self, sigma):
50
+ dists = sigma - self.sigmas[:, None]
51
+ return dists.abs().argmin(dim=0).view(sigma.shape)
52
+
53
+ def idx_to_sigma(self, idx):
54
+ return self.sigmas[idx]
55
+
56
+ def possibly_quantize_sigma(self, sigma):
57
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
58
+
59
+ def possibly_quantize_c_noise(self, c_noise):
60
+ if self.quantize_c_noise:
61
+ return self.sigma_to_idx(c_noise)
62
+ else:
63
+ return c_noise
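Denoiser.__call__ applies EDM-style preconditioning: D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma), cond), where F is the wrapped UNet. A minimal sketch with a dummy network that predicts zero, so the output collapses to c_skip * x; it assumes the repo root is importable and that instantiate_from_config accepts plain dict configs, as in the stock sgm utilities:

# Minimal preconditioning sketch using the EpsScaling / UnitWeighting classes
# defined in the neighbouring files and a dummy zero-predicting network.
import torch
from sgm.modules.diffusionmodules.denoiser import Denoiser

denoiser = Denoiser(
    weighting_config={"target": "sgm.modules.diffusionmodules.denoiser_weighting.UnitWeighting"},
    scaling_config={"target": "sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling"},
)

def dummy_unet(x, c_noise, cond):
    return torch.zeros_like(x)       # stand-in for the real UNet

x = torch.randn(2, 4, 8, 8)
sigma = torch.full((2,), 0.5)
out = denoiser(dummy_unet, x, sigma, cond=None)
print(torch.allclose(out, x))        # True: with F == 0, D(x, sigma) == c_skip * x == x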
sgm/modules/diffusionmodules/denoiser_scaling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+
4
+ class EDMScaling:
5
+ def __init__(self, sigma_data=0.5):
6
+ self.sigma_data = sigma_data
7
+
8
+ def __call__(self, sigma):
9
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
10
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
11
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
12
+ c_noise = 0.25 * sigma.log()
13
+ return c_skip, c_out, c_in, c_noise
14
+
15
+
16
+ class EpsScaling:
17
+ def __call__(self, sigma):
18
+ c_skip = torch.ones_like(sigma, device=sigma.device)
19
+ c_out = -sigma
20
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
21
+ c_noise = sigma.clone()
22
+ return c_skip, c_out, c_in, c_noise
23
+
24
+
25
+ class VScaling:
26
+ def __call__(self, sigma):
27
+ c_skip = 1.0 / (sigma**2 + 1.0)
28
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
29
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
30
+ c_noise = sigma.clone()
31
+ return c_skip, c_out, c_in, c_noise
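The three scalings differ only in how they precondition around sigma; for VScaling the components satisfy c_skip = c_in**2 and c_out = -sigma * c_in, which this quick check confirms against the class above:

# Numeric check of the v-prediction scalings defined above.
import torch
from sgm.modules.diffusionmodules.denoiser_scaling import VScaling

sigma = torch.linspace(0.1, 80.0, 5)
c_skip, c_out, c_in, c_noise = VScaling()(sigma)
print(torch.allclose(c_skip, c_in ** 2), torch.allclose(c_out, -sigma * c_in))  # True True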
sgm/modules/diffusionmodules/denoiser_weighting.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+
3
+
4
+ class UnitWeighting:
5
+ def __call__(self, sigma):
6
+ return torch.ones_like(sigma, device=sigma.device)
7
+
8
+
9
+ class EDMWeighting:
10
+ def __init__(self, sigma_data=0.5):
11
+ self.sigma_data = sigma_data
12
+
13
+ def __call__(self, sigma):
14
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
15
+
16
+
17
+ class VWeighting(EDMWeighting):
18
+ def __init__(self):
19
+ super().__init__(sigma_data=1.0)
20
+
21
+
22
+ class EpsWeighting:
23
+ def __call__(self, sigma):
24
+ return sigma**-2.0
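A spot-check of the loss weightings above: VWeighting is EDMWeighting with sigma_data = 1, i.e. (sigma**2 + 1) / sigma**2, while EpsWeighting is simply sigma**-2:

import torch
from sgm.modules.diffusionmodules.denoiser_weighting import VWeighting, EpsWeighting

sigma = torch.tensor([0.5, 1.0, 2.0])
print(VWeighting()(sigma))     # tensor([5.0000, 2.0000, 1.2500])
print(EpsWeighting()(sigma))   # tensor([4.0000, 1.0000, 0.2500])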
sgm/modules/diffusionmodules/discretizer.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import numpy as np
3
+ from functools import partial
4
+ from abc import abstractmethod
5
+
6
+ from ...util import append_zero
7
+ from ...modules.diffusionmodules.util import make_beta_schedule
8
+
9
+
10
+ def generate_roughly_equally_spaced_steps(
11
+ num_substeps: int, max_step: int
12
+ ) -> np.ndarray:
13
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
14
+
15
+
16
+ class Discretization:
17
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
18
+ sigmas = self.get_sigmas(n, device=device)
19
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
20
+ return sigmas if not flip else torch.flip(sigmas, (0,))
21
+
22
+ @abstractmethod
23
+ def get_sigmas(self, n, device):
24
+ pass
25
+
26
+
27
+ class EDMDiscretization(Discretization):
28
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
29
+ self.sigma_min = sigma_min
30
+ self.sigma_max = sigma_max
31
+ self.rho = rho
32
+
33
+ def get_sigmas(self, n, device="cpu"):
34
+ ramp = torch.linspace(0, 1, n, device=device)
35
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
36
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
37
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
38
+ return sigmas
39
+
40
+
41
+ class LegacyDDPMDiscretization(Discretization):
42
+ def __init__(
43
+ self,
44
+ linear_start=0.00085,
45
+ linear_end=0.0120,
46
+ num_timesteps=1000,
47
+ ):
48
+ super().__init__()
49
+ self.num_timesteps = num_timesteps
50
+ betas = make_beta_schedule(
51
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
52
+ )
53
+ alphas = 1.0 - betas
54
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
55
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
56
+
57
+ def get_sigmas(self, n, device="cpu"):
58
+ if n < self.num_timesteps:
59
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
60
+ alphas_cumprod = self.alphas_cumprod[timesteps]
61
+ elif n == self.num_timesteps:
62
+ alphas_cumprod = self.alphas_cumprod
63
+ else:
64
+ raise ValueError
65
+
66
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
67
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
68
+ return torch.flip(sigmas, (0,))
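EDMDiscretization produces the Karras-style rho-interpolated schedule, decreasing from sigma_max to sigma_min; Discretization.__call__ optionally appends the terminal zero and/or flips the order (as DiscreteDenoiser does). A short look at the values, assuming the repo root is importable:

# The EDM / Karras schedule: sigmas interpolate from sigma_max down to sigma_min
# in rho-space; do_append_zero appends the terminal 0 used by the samplers.
from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

disc = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
print(disc(5, do_append_zero=True))
# roughly [80.0, 21.1, 4.0, 0.47, 0.02, 0.0] -- strictly decreasing, heavy-tailed at the top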
sgm/modules/diffusionmodules/guiders.py ADDED
@@ -0,0 +1,81 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+
5
+ from ...util import default, instantiate_from_config
6
+
7
+
8
+ class VanillaCFG:
9
+ """
10
+ implements parallelized CFG
11
+ """
12
+
13
+ def __init__(self, scale, dyn_thresh_config=None):
14
+ scale_schedule = lambda scale, sigma: scale # independent of step
15
+ self.scale_schedule = partial(scale_schedule, scale)
16
+ self.dyn_thresh = instantiate_from_config(
17
+ default(
18
+ dyn_thresh_config,
19
+ {
20
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
21
+ },
22
+ )
23
+ )
24
+
25
+ def __call__(self, x, sigma):
26
+ x_u, x_c = x.chunk(2)
27
+ scale_value = self.scale_schedule(sigma)
28
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
+ return x_pred
30
+
31
+ def prepare_inputs(self, x, s, c, uc):
32
+ c_out = dict()
33
+
34
+ for k in c:
35
+ if k in ["vector", "crossattn", "add_crossattn", "concat"]:
36
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
37
+ else:
38
+ assert c[k] == uc[k]
39
+ c_out[k] = c[k]
40
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
41
+
42
+
43
+ class DualCFG:
44
+
45
+ def __init__(self, scale):
46
+ self.scale = scale
47
+ self.dyn_thresh = instantiate_from_config(
48
+ {
49
+ "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
50
+ },
51
+ )
52
+
53
+ def __call__(self, x, sigma):
54
+ x_u_1, x_u_2, x_c = x.chunk(3)
55
+ x_pred = self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)
56
+ return x_pred
57
+
58
+ def prepare_inputs(self, x, s, c, uc_1, uc_2):
59
+ c_out = dict()
60
+
61
+ for k in c:
62
+ if k in ["vector", "crossattn", "concat", "add_crossattn"]:
63
+ c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0)
64
+ else:
65
+ assert c[k] == uc_1[k]
66
+ c_out[k] = c[k]
67
+ return torch.cat([x] * 3), torch.cat([s] * 3), c_out
68
+
69
+
70
+
71
+ class IdentityGuider:
72
+ def __call__(self, x, sigma):
73
+ return x
74
+
75
+ def prepare_inputs(self, x, s, c, uc):
76
+ c_out = dict()
77
+
78
+ for k in c:
79
+ c_out[k] = c[k]
80
+
81
+ return x, s, c_out
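VanillaCFG batches the unconditional and conditional branches (prepare_inputs concatenates them), runs the model once, then splits the output and hands both halves to the thresholding object; with the default NoDynamicThresholding this is expected to reduce to the usual combination x_u + scale * (x_c - x_u). A standalone sketch of just that step, with no model involved:

# Standalone sketch of the combination step VanillaCFG delegates to its dyn_thresh object.
import torch

def cfg_combine(x_u, x_c, scale):
    return x_u + scale * (x_c - x_u)   # plain classifier-free guidance

model_out = torch.randn(4, 4, 8, 8)    # denoiser output for cat([uncond, cond]) batch
x_u, x_c = model_out.chunk(2)          # mirrors VanillaCFG.__call__
print(cfg_combine(x_u, x_c, scale=5.0).shape)  # torch.Size([2, 4, 8, 8])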
sgm/modules/diffusionmodules/loss.py ADDED
@@ -0,0 +1,275 @@
1
+ from typing import List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from omegaconf import ListConfig
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from torchvision.utils import save_image
9
+ from ...util import append_dims, instantiate_from_config
10
+
11
+
12
+ class StandardDiffusionLoss(nn.Module):
13
+ def __init__(
14
+ self,
15
+ sigma_sampler_config,
16
+ type="l2",
17
+ offset_noise_level=0.0,
18
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
19
+ ):
20
+ super().__init__()
21
+
22
+ assert type in ["l2", "l1", "lpips"]
23
+
24
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
25
+
26
+ self.type = type
27
+ self.offset_noise_level = offset_noise_level
28
+
29
+ if type == "lpips":
30
+ self.lpips = LPIPS().eval()
31
+
32
+ if not batch2model_keys:
33
+ batch2model_keys = []
34
+
35
+ if isinstance(batch2model_keys, str):
36
+ batch2model_keys = [batch2model_keys]
37
+
38
+ self.batch2model_keys = set(batch2model_keys)
39
+
40
+ def __call__(self, network, denoiser, conditioner, input, batch, *args, **kwarg):
41
+ cond = conditioner(batch)
42
+ additional_model_inputs = {
43
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
44
+ }
45
+
46
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
47
+ noise = torch.randn_like(input)
48
+ if self.offset_noise_level > 0.0:
49
+ noise = noise + self.offset_noise_level * append_dims(
50
+ torch.randn(input.shape[0], device=input.device), input.ndim
51
+ )
52
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
53
+ model_output = denoiser(
54
+ network, noised_input, sigmas, cond, **additional_model_inputs
55
+ )
56
+ w = append_dims(denoiser.w(sigmas), input.ndim)
57
+
58
+ loss = self.get_diff_loss(model_output, input, w)
59
+ loss = loss.mean()
60
+ loss_dict = {"loss": loss}
61
+
62
+ return loss, loss_dict
63
+
64
+ def get_diff_loss(self, model_output, target, w):
65
+ if self.type == "l2":
66
+ return torch.mean(
67
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
68
+ )
69
+ elif self.type == "l1":
70
+ return torch.mean(
71
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
72
+ )
73
+ elif self.type == "lpips":
74
+ loss = self.lpips(model_output, target).reshape(-1)
75
+ return loss
76
+
77
+
78
+ class FullLoss(StandardDiffusionLoss):
79
+
80
+ def __init__(
81
+ self,
82
+ seq_len=12,
83
+ kernel_size=3,
84
+ gaussian_sigma=0.5,
85
+ min_attn_size=16,
86
+ lambda_local_loss=0.0,
87
+ lambda_ocr_loss=0.0,
88
+ ocr_enabled = False,
89
+ predictor_config = None,
90
+ *args, **kwarg
91
+ ):
92
+ super().__init__(*args, **kwarg)
93
+
94
+ self.gaussian_kernel_size = kernel_size
95
+ gaussian_kernel = self.get_gaussian_kernel(kernel_size=self.gaussian_kernel_size, sigma=gaussian_sigma, out_channels=seq_len)
96
+ self.register_buffer("g_kernel", gaussian_kernel.requires_grad_(False))
97
+
98
+ self.min_attn_size = min_attn_size
99
+ self.lambda_local_loss = lambda_local_loss
100
+ self.lambda_ocr_loss = lambda_ocr_loss
101
+
102
+ self.ocr_enabled = ocr_enabled
103
+ if ocr_enabled:
104
+ self.predictor = instantiate_from_config(predictor_config)
105
+
106
+ def get_gaussian_kernel(self, kernel_size=3, sigma=1, out_channels=3):
107
+ # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
108
+ x_coord = torch.arange(kernel_size)
109
+ x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
110
+ y_grid = x_grid.t()
111
+ xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
112
+
113
+ mean = (kernel_size - 1)/2.
114
+ variance = sigma**2.
115
+
116
+ # Calculate the 2-dimensional gaussian kernel which is
117
+ # the product of two gaussian distributions for two different
118
+ # variables (in this case called x and y)
119
+ gaussian_kernel = (1./(2.*torch.pi*variance)) *\
120
+ torch.exp(
121
+ -torch.sum((xy_grid - mean)**2., dim=-1) /\
122
+ (2*variance)
123
+ )
124
+
125
+ # Make sure sum of values in gaussian kernel equals 1.
126
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
127
+
128
+ # Reshape to 2d depthwise convolutional weight
129
+ gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
130
+ gaussian_kernel = gaussian_kernel.tile(out_channels, 1, 1, 1)
131
+
132
+ return gaussian_kernel
133
+
134
+ def __call__(self, network, denoiser, conditioner, input, batch, first_stage_model, scaler):
135
+
136
+ cond = conditioner(batch)
137
+
138
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
139
+ noise = torch.randn_like(input)
140
+ if self.offset_noise_level > 0.0:
141
+ noise = noise + self.offset_noise_level * append_dims(
142
+ torch.randn(input.shape[0], device=input.device), input.ndim
143
+ )
144
+
145
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
146
+ model_output = denoiser(network, noised_input, sigmas, cond)
147
+ w = append_dims(denoiser.w(sigmas), input.ndim)
148
+
149
+ diff_loss = self.get_diff_loss(model_output, input, w)
150
+ local_loss = self.get_local_loss(network.diffusion_model.attn_map_cache, batch["seg"], batch["seg_mask"])
151
+ diff_loss = diff_loss.mean()
152
+ local_loss = local_loss.mean()
153
+
154
+ if self.ocr_enabled:
155
+ ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler)
156
+ ocr_loss = ocr_loss.mean()
157
+
158
+ loss = diff_loss + self.lambda_local_loss * local_loss
159
+ if self.ocr_enabled:
160
+ loss += self.lambda_ocr_loss * ocr_loss
161
+
162
+ loss_dict = {
163
+ "loss/diff_loss": diff_loss,
164
+ "loss/local_loss": local_loss,
165
+ "loss/full_loss": loss
166
+ }
167
+
168
+ if self.ocr_enabled:
169
+ loss_dict["loss/ocr_loss"] = ocr_loss
170
+
171
+ return loss, loss_dict
172
+
173
+ def get_ocr_loss(self, model_output, r_bbox, label, first_stage_model, scaler):
174
+
175
+ model_output = 1 / scaler * model_output
176
+ model_output_decoded = first_stage_model.decode(model_output)
177
+ model_output_crops = []
178
+
179
+ for i, bbox in enumerate(r_bbox):
180
+ m_top, m_bottom, m_left, m_right = bbox
181
+ model_output_crops.append(model_output_decoded[i, :, m_top:m_bottom, m_left:m_right])
182
+
183
+ loss = self.predictor.calc_loss(model_output_crops, label)
184
+
185
+ return loss
186
+
187
+ def get_min_local_loss(self, attn_map_cache, mask, seg_mask):
188
+
189
+ loss = 0
190
+ count = 0
191
+
192
+ for item in attn_map_cache:
193
+
194
+ heads = item["heads"]
195
+ size = item["size"]
196
+ attn_map = item["attn_map"]
197
+
198
+ if size < self.min_attn_size: continue
199
+
200
+ seg_l = seg_mask.shape[1]
201
+
202
+ bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
203
+ attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
204
+
205
+ assert seg_l <= l
206
+ attn_map = attn_map[..., :seg_l]
207
+ attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
208
+ attn_map = attn_map.mean(dim = 1) # b, l, n
209
+
210
+ attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s
211
+ attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel
212
+ attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n
213
+
214
+ mask_map = F.interpolate(mask, (size, size))
215
+ mask_map = mask_map.tile((1, seg_l, 1, 1))
216
+ mask_map = mask_map.reshape((-1, seg_l, n)) # b, l, n
217
+
218
+ p_loss = (mask_map * attn_map).max(dim = -1)[0] # b, l
219
+ p_loss = p_loss + (1 - seg_mask) # b, l
220
+ p_loss = p_loss.min(dim = -1)[0] # b,
221
+
222
+ loss += -p_loss
223
+ count += 1
224
+
225
+ loss = loss / count
226
+
227
+ return loss
228
+
229
+ def get_local_loss(self, attn_map_cache, seg, seg_mask):
230
+
231
+ loss = 0
232
+ count = 0
233
+
234
+ for item in attn_map_cache:
235
+
236
+ heads = item["heads"]
237
+ size = item["size"]
238
+ attn_map = item["attn_map"]
239
+
240
+ if size < self.min_attn_size: continue
241
+
242
+ seg_l = seg_mask.shape[1]
243
+
244
+ bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
245
+ attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
246
+
247
+ assert seg_l <= l
248
+ attn_map = attn_map[..., :seg_l]
249
+ attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
250
+ attn_map = attn_map.mean(dim = 1) # b, l, n
251
+
252
+ attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s
253
+ attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel
254
+ attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n
255
+
256
+ seg_map = F.interpolate(seg, (size, size))
257
+ seg_map = seg_map.reshape((-1, seg_l, n)) # b, l, n
258
+ n_seg_map = 1 - seg_map
259
+
260
+ p_loss = (seg_map * attn_map).max(dim = -1)[0] # b, l
261
+ n_loss = (n_seg_map * attn_map).max(dim = -1)[0] # b, l
262
+
263
+ p_loss = p_loss * seg_mask # b, l
264
+ n_loss = n_loss * seg_mask # b, l
265
+
266
+ p_loss = p_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b,
267
+ n_loss = n_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b,
268
+
269
+ f_loss = n_loss - p_loss # b,
270
+ loss += f_loss
271
+ count += 1
272
+
273
+ loss = loss / count
274
+
275
+ return loss
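get_local_loss reshapes each cached cross-attention map to (b, l, size, size), blurs every token channel with the depthwise Gaussian kernel registered in __init__, and then contrasts the peak response inside versus outside the segmentation masks (batch["seg"]). The blur step in isolation, with illustrative sizes; the kernel is built the same way as get_gaussian_kernel (the leading normalisation constant cancels once the kernel is renormalised to sum to 1):

# Standalone sketch of the depthwise Gaussian blur applied to per-token attention maps.
import torch
import torch.nn.functional as F

k, sigma, seq_len, size = 3, 0.5, 12, 32
coords = torch.arange(k).float()
xy = torch.stack(torch.meshgrid(coords, coords, indexing="ij"), dim=-1)
g = torch.exp(-((xy - (k - 1) / 2) ** 2).sum(-1) / (2 * sigma ** 2))
g = (g / g.sum()).view(1, 1, k, k).tile(seq_len, 1, 1, 1)     # (seq_len, 1, k, k)

attn = torch.rand(2, seq_len, size, size)                     # (b, l, s, s) attention maps
blurred = F.conv2d(attn, g, padding=k // 2, groups=seq_len)   # one kernel per token channel
print(blurred.shape)  # torch.Size([2, 12, 32, 32])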
sgm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,743 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from typing import Any, Callable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from packaging import version
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILABLE = False
18
+ print("no module 'xformers'. Processing without...")
19
+
20
+ from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
21
+
22
+
23
+ def get_timestep_embedding(timesteps, embedding_dim):
24
+ """
25
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
26
+ From Fairseq.
27
+ Build sinusoidal embeddings.
28
+ This matches the implementation in tensor2tensor, but differs slightly
29
+ from the description in Section 3.5 of "Attention Is All You Need".
30
+ """
31
+ assert len(timesteps.shape) == 1
32
+
33
+ half_dim = embedding_dim // 2
34
+ emb = math.log(10000) / (half_dim - 1)
35
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
36
+ emb = emb.to(device=timesteps.device)
37
+ emb = timesteps.float()[:, None] * emb[None, :]
38
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
39
+ if embedding_dim % 2 == 1: # zero pad
40
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
41
+ return emb
42
+
43
+
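`get_timestep_embedding` maps a 1-D batch of integer timesteps to fixed sinusoidal vectors: the first `embedding_dim // 2` entries are sines at geometrically spaced frequencies, the rest cosines, with one zero-pad column for odd dimensions. A quick sanity check (values are arbitrary):

```python
import torch

t = torch.tensor([0, 10, 999])                      # a 1-D batch of timesteps
emb = get_timestep_embedding(t, embedding_dim=320)  # as defined above
print(emb.shape)  # torch.Size([3, 320]) -- sin terms first, then cos terms
```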
44
+ def nonlinearity(x):
45
+ # swish
46
+ return x * torch.sigmoid(x)
47
+
48
+
49
+ def Normalize(in_channels, num_groups=32):
50
+ return torch.nn.GroupNorm(
51
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
52
+ )
53
+
54
+
55
+ class Upsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ self.conv = torch.nn.Conv2d(
61
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
62
+ )
63
+
64
+ def forward(self, x):
65
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
66
+ if self.with_conv:
67
+ x = self.conv(x)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ def __init__(self, in_channels, with_conv):
73
+ super().__init__()
74
+ self.with_conv = with_conv
75
+ if self.with_conv:
76
+ # no asymmetric padding in torch conv, must do it ourselves
77
+ self.conv = torch.nn.Conv2d(
78
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
79
+ )
80
+
81
+ def forward(self, x):
82
+ if self.with_conv:
83
+ pad = (0, 1, 0, 1)
84
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
85
+ x = self.conv(x)
86
+ else:
87
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
88
+ return x
89
+
90
+
91
+ class ResnetBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ *,
95
+ in_channels,
96
+ out_channels=None,
97
+ conv_shortcut=False,
98
+ dropout,
99
+ temb_channels=512,
100
+ ):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+
107
+ self.norm1 = Normalize(in_channels)
108
+ self.conv1 = torch.nn.Conv2d(
109
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
110
+ )
111
+ if temb_channels > 0:
112
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
113
+ self.norm2 = Normalize(out_channels)
114
+ self.dropout = torch.nn.Dropout(dropout)
115
+ self.conv2 = torch.nn.Conv2d(
116
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
117
+ )
118
+ if self.in_channels != self.out_channels:
119
+ if self.use_conv_shortcut:
120
+ self.conv_shortcut = torch.nn.Conv2d(
121
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
122
+ )
123
+ else:
124
+ self.nin_shortcut = torch.nn.Conv2d(
125
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
126
+ )
127
+
128
+ def forward(self, x, temb):
129
+ h = x
130
+ h = self.norm1(h)
131
+ h = nonlinearity(h)
132
+ h = self.conv1(h)
133
+
134
+ if temb is not None:
135
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
136
+
137
+ h = self.norm2(h)
138
+ h = nonlinearity(h)
139
+ h = self.dropout(h)
140
+ h = self.conv2(h)
141
+
142
+ if self.in_channels != self.out_channels:
143
+ if self.use_conv_shortcut:
144
+ x = self.conv_shortcut(x)
145
+ else:
146
+ x = self.nin_shortcut(x)
147
+
148
+ return x + h
149
+
150
+
151
+ class LinAttnBlock(LinearAttention):
152
+ """to match AttnBlock usage"""
153
+
154
+ def __init__(self, in_channels):
155
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
156
+
157
+
158
+ class AttnBlock(nn.Module):
159
+ def __init__(self, in_channels):
160
+ super().__init__()
161
+ self.in_channels = in_channels
162
+
163
+ self.norm = Normalize(in_channels)
164
+ self.q = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+ self.k = torch.nn.Conv2d(
168
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
+ )
170
+ self.v = torch.nn.Conv2d(
171
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
+ )
173
+ self.proj_out = torch.nn.Conv2d(
174
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
+ )
176
+
177
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
178
+ h_ = self.norm(h_)
179
+ q = self.q(h_)
180
+ k = self.k(h_)
181
+ v = self.v(h_)
182
+
183
+ b, c, h, w = q.shape
184
+ q, k, v = map(
185
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
186
+ )
187
+ h_ = torch.nn.functional.scaled_dot_product_attention(
188
+ q, k, v
189
+ ) # scale is dim ** -0.5 per default
190
+ # compute attention
191
+
192
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
193
+
194
+ def forward(self, x, **kwargs):
195
+ h_ = x
196
+ h_ = self.attention(h_)
197
+ h_ = self.proj_out(h_)
198
+ return x + h_
199
+
200
+
201
+ class MemoryEfficientAttnBlock(nn.Module):
202
+ """
203
+ Uses xformers efficient implementation,
204
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
+ Note: this is a single-head self-attention operation
206
+ """
207
+
208
+ #
209
+ def __init__(self, in_channels):
210
+ super().__init__()
211
+ self.in_channels = in_channels
212
+
213
+ self.norm = Normalize(in_channels)
214
+ self.q = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.k = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.v = torch.nn.Conv2d(
221
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
+ )
223
+ self.proj_out = torch.nn.Conv2d(
224
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
+ )
226
+ self.attention_op: Optional[Any] = None
227
+
228
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
229
+ h_ = self.norm(h_)
230
+ q = self.q(h_)
231
+ k = self.k(h_)
232
+ v = self.v(h_)
233
+
234
+ # compute attention
235
+ B, C, H, W = q.shape
236
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
237
+
238
+ q, k, v = map(
239
+ lambda t: t.unsqueeze(3)
240
+ .reshape(B, t.shape[1], 1, C)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(B * 1, t.shape[1], C)
243
+ .contiguous(),
244
+ (q, k, v),
245
+ )
246
+ out = xformers.ops.memory_efficient_attention(
247
+ q, k, v, attn_bias=None, op=self.attention_op
248
+ )
249
+
250
+ out = (
251
+ out.unsqueeze(0)
252
+ .reshape(B, 1, out.shape[1], C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B, out.shape[1], C)
255
+ )
256
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
257
+
258
+ def forward(self, x, **kwargs):
259
+ h_ = x
260
+ h_ = self.attention(h_)
261
+ h_ = self.proj_out(h_)
262
+ return x + h_
263
+
264
+
265
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
266
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
267
+ b, c, h, w = x.shape
268
+ x_ = rearrange(x, "b c h w -> b (h w) c")  # keep x in its original (b, c, h, w) shape for the residual below
269
+ out = super().forward(x_, context=context, mask=mask)
270
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
271
+ return x + out
272
+
273
+
274
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
275
+ assert attn_type in [
276
+ "vanilla",
277
+ "vanilla-xformers",
278
+ "memory-efficient-cross-attn",
279
+ "linear",
280
+ "none",
281
+ ], f"attn_type {attn_type} unknown"
282
+ if (
283
+ version.parse(torch.__version__) < version.parse("2.0.0")
284
+ and attn_type != "none"
285
+ ):
286
+ assert XFORMERS_IS_AVAILABLE, (
287
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
288
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
289
+ )
290
+ attn_type = "vanilla-xformers"
291
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
292
+ if attn_type == "vanilla":
293
+ assert attn_kwargs is None
294
+ return AttnBlock(in_channels)
295
+ elif attn_type == "vanilla-xformers":
296
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
297
+ return MemoryEfficientAttnBlock(in_channels)
298
+ elif attn_type == "memory-efficient-cross-attn":
299
+ attn_kwargs["query_dim"] = in_channels
300
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
301
+ elif attn_type == "none":
302
+ return nn.Identity(in_channels)
303
+ else:
304
+ return LinAttnBlock(in_channels)
305
+
306
+
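`make_attn` is the dispatch point for the autoencoder's attention blocks: on torch < 2.0 every non-`"none"` request is rerouted to the xformers block, otherwise `attn_type` picks the implementation directly. A hedged usage sketch (the channel count is arbitrary, and the `"vanilla"` path assumes torch >= 2.0 for `scaled_dot_product_attention`):

```python
import torch
import torch.nn as nn

attn = make_attn(512, attn_type="vanilla")  # AttnBlock on torch >= 2.0
x = torch.randn(1, 512, 32, 32)
print(attn(x).shape)                        # torch.Size([1, 512, 32, 32])

# "none" keeps only the residual path
assert isinstance(make_attn(512, attn_type="none"), nn.Identity)
```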
307
+ class Model(nn.Module):
308
+ def __init__(
309
+ self,
310
+ *,
311
+ ch,
312
+ out_ch,
313
+ ch_mult=(1, 2, 4, 8),
314
+ num_res_blocks,
315
+ attn_resolutions,
316
+ dropout=0.0,
317
+ resamp_with_conv=True,
318
+ in_channels,
319
+ resolution,
320
+ use_timestep=True,
321
+ use_linear_attn=False,
322
+ attn_type="vanilla",
323
+ ):
324
+ super().__init__()
325
+ if use_linear_attn:
326
+ attn_type = "linear"
327
+ self.ch = ch
328
+ self.temb_ch = self.ch * 4
329
+ self.num_resolutions = len(ch_mult)
330
+ self.num_res_blocks = num_res_blocks
331
+ self.resolution = resolution
332
+ self.in_channels = in_channels
333
+
334
+ self.use_timestep = use_timestep
335
+ if self.use_timestep:
336
+ # timestep embedding
337
+ self.temb = nn.Module()
338
+ self.temb.dense = nn.ModuleList(
339
+ [
340
+ torch.nn.Linear(self.ch, self.temb_ch),
341
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
342
+ ]
343
+ )
344
+
345
+ # downsampling
346
+ self.conv_in = torch.nn.Conv2d(
347
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
348
+ )
349
+
350
+ curr_res = resolution
351
+ in_ch_mult = (1,) + tuple(ch_mult)
352
+ self.down = nn.ModuleList()
353
+ for i_level in range(self.num_resolutions):
354
+ block = nn.ModuleList()
355
+ attn = nn.ModuleList()
356
+ block_in = ch * in_ch_mult[i_level]
357
+ block_out = ch * ch_mult[i_level]
358
+ for i_block in range(self.num_res_blocks):
359
+ block.append(
360
+ ResnetBlock(
361
+ in_channels=block_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout,
365
+ )
366
+ )
367
+ block_in = block_out
368
+ if curr_res in attn_resolutions:
369
+ attn.append(make_attn(block_in, attn_type=attn_type))
370
+ down = nn.Module()
371
+ down.block = block
372
+ down.attn = attn
373
+ if i_level != self.num_resolutions - 1:
374
+ down.downsample = Downsample(block_in, resamp_with_conv)
375
+ curr_res = curr_res // 2
376
+ self.down.append(down)
377
+
378
+ # middle
379
+ self.mid = nn.Module()
380
+ self.mid.block_1 = ResnetBlock(
381
+ in_channels=block_in,
382
+ out_channels=block_in,
383
+ temb_channels=self.temb_ch,
384
+ dropout=dropout,
385
+ )
386
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
387
+ self.mid.block_2 = ResnetBlock(
388
+ in_channels=block_in,
389
+ out_channels=block_in,
390
+ temb_channels=self.temb_ch,
391
+ dropout=dropout,
392
+ )
393
+
394
+ # upsampling
395
+ self.up = nn.ModuleList()
396
+ for i_level in reversed(range(self.num_resolutions)):
397
+ block = nn.ModuleList()
398
+ attn = nn.ModuleList()
399
+ block_out = ch * ch_mult[i_level]
400
+ skip_in = ch * ch_mult[i_level]
401
+ for i_block in range(self.num_res_blocks + 1):
402
+ if i_block == self.num_res_blocks:
403
+ skip_in = ch * in_ch_mult[i_level]
404
+ block.append(
405
+ ResnetBlock(
406
+ in_channels=block_in + skip_in,
407
+ out_channels=block_out,
408
+ temb_channels=self.temb_ch,
409
+ dropout=dropout,
410
+ )
411
+ )
412
+ block_in = block_out
413
+ if curr_res in attn_resolutions:
414
+ attn.append(make_attn(block_in, attn_type=attn_type))
415
+ up = nn.Module()
416
+ up.block = block
417
+ up.attn = attn
418
+ if i_level != 0:
419
+ up.upsample = Upsample(block_in, resamp_with_conv)
420
+ curr_res = curr_res * 2
421
+ self.up.insert(0, up) # prepend to get consistent order
422
+
423
+ # end
424
+ self.norm_out = Normalize(block_in)
425
+ self.conv_out = torch.nn.Conv2d(
426
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
427
+ )
428
+
429
+ def forward(self, x, t=None, context=None):
430
+ # assert x.shape[2] == x.shape[3] == self.resolution
431
+ if context is not None:
432
+ # assume aligned context, cat along channel axis
433
+ x = torch.cat((x, context), dim=1)
434
+ if self.use_timestep:
435
+ # timestep embedding
436
+ assert t is not None
437
+ temb = get_timestep_embedding(t, self.ch)
438
+ temb = self.temb.dense[0](temb)
439
+ temb = nonlinearity(temb)
440
+ temb = self.temb.dense[1](temb)
441
+ else:
442
+ temb = None
443
+
444
+ # downsampling
445
+ hs = [self.conv_in(x)]
446
+ for i_level in range(self.num_resolutions):
447
+ for i_block in range(self.num_res_blocks):
448
+ h = self.down[i_level].block[i_block](hs[-1], temb)
449
+ if len(self.down[i_level].attn) > 0:
450
+ h = self.down[i_level].attn[i_block](h)
451
+ hs.append(h)
452
+ if i_level != self.num_resolutions - 1:
453
+ hs.append(self.down[i_level].downsample(hs[-1]))
454
+
455
+ # middle
456
+ h = hs[-1]
457
+ h = self.mid.block_1(h, temb)
458
+ h = self.mid.attn_1(h)
459
+ h = self.mid.block_2(h, temb)
460
+
461
+ # upsampling
462
+ for i_level in reversed(range(self.num_resolutions)):
463
+ for i_block in range(self.num_res_blocks + 1):
464
+ h = self.up[i_level].block[i_block](
465
+ torch.cat([h, hs.pop()], dim=1), temb
466
+ )
467
+ if len(self.up[i_level].attn) > 0:
468
+ h = self.up[i_level].attn[i_block](h)
469
+ if i_level != 0:
470
+ h = self.up[i_level].upsample(h)
471
+
472
+ # end
473
+ h = self.norm_out(h)
474
+ h = nonlinearity(h)
475
+ h = self.conv_out(h)
476
+ return h
477
+
478
+ def get_last_layer(self):
479
+ return self.conv_out.weight
480
+
481
+
482
+ class Encoder(nn.Module):
483
+ def __init__(
484
+ self,
485
+ *,
486
+ ch,
487
+ out_ch,
488
+ ch_mult=(1, 2, 4, 8),
489
+ num_res_blocks,
490
+ attn_resolutions,
491
+ dropout=0.0,
492
+ resamp_with_conv=True,
493
+ in_channels,
494
+ resolution,
495
+ z_channels,
496
+ double_z=True,
497
+ use_linear_attn=False,
498
+ attn_type="vanilla",
499
+ **ignore_kwargs,
500
+ ):
501
+ super().__init__()
502
+ if use_linear_attn:
503
+ attn_type = "linear"
504
+ self.ch = ch
505
+ self.temb_ch = 0
506
+ self.num_resolutions = len(ch_mult)
507
+ self.num_res_blocks = num_res_blocks
508
+ self.resolution = resolution
509
+ self.in_channels = in_channels
510
+
511
+ # downsampling
512
+ self.conv_in = torch.nn.Conv2d(
513
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
514
+ )
515
+
516
+ curr_res = resolution
517
+ in_ch_mult = (1,) + tuple(ch_mult)
518
+ self.in_ch_mult = in_ch_mult
519
+ self.down = nn.ModuleList()
520
+ for i_level in range(self.num_resolutions):
521
+ block = nn.ModuleList()
522
+ attn = nn.ModuleList()
523
+ block_in = ch * in_ch_mult[i_level]
524
+ block_out = ch * ch_mult[i_level]
525
+ for i_block in range(self.num_res_blocks):
526
+ block.append(
527
+ ResnetBlock(
528
+ in_channels=block_in,
529
+ out_channels=block_out,
530
+ temb_channels=self.temb_ch,
531
+ dropout=dropout,
532
+ )
533
+ )
534
+ block_in = block_out
535
+ if curr_res in attn_resolutions:
536
+ attn.append(make_attn(block_in, attn_type=attn_type))
537
+ down = nn.Module()
538
+ down.block = block
539
+ down.attn = attn
540
+ if i_level != self.num_resolutions - 1:
541
+ down.downsample = Downsample(block_in, resamp_with_conv)
542
+ curr_res = curr_res // 2
543
+ self.down.append(down)
544
+
545
+ # middle
546
+ self.mid = nn.Module()
547
+ self.mid.block_1 = ResnetBlock(
548
+ in_channels=block_in,
549
+ out_channels=block_in,
550
+ temb_channels=self.temb_ch,
551
+ dropout=dropout,
552
+ )
553
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
554
+ self.mid.block_2 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+
561
+ # end
562
+ self.norm_out = Normalize(block_in)
563
+ self.conv_out = torch.nn.Conv2d(
564
+ block_in,
565
+ 2 * z_channels if double_z else z_channels,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1,
569
+ )
570
+
571
+ def forward(self, x):
572
+ # timestep embedding
573
+ temb = None
574
+
575
+ # downsampling
576
+ hs = [self.conv_in(x)]
577
+ for i_level in range(self.num_resolutions):
578
+ for i_block in range(self.num_res_blocks):
579
+ h = self.down[i_level].block[i_block](hs[-1], temb)
580
+ if len(self.down[i_level].attn) > 0:
581
+ h = self.down[i_level].attn[i_block](h)
582
+ hs.append(h)
583
+ if i_level != self.num_resolutions - 1:
584
+ hs.append(self.down[i_level].downsample(hs[-1]))
585
+
586
+ # middle
587
+ h = hs[-1]
588
+ h = self.mid.block_1(h, temb)
589
+ h = self.mid.attn_1(h)
590
+ h = self.mid.block_2(h, temb)
591
+
592
+ # end
593
+ h = self.norm_out(h)
594
+ h = nonlinearity(h)
595
+ h = self.conv_out(h)
596
+ return h
597
+
598
+
599
+ class Decoder(nn.Module):
600
+ def __init__(
601
+ self,
602
+ *,
603
+ ch,
604
+ out_ch,
605
+ ch_mult=(1, 2, 4, 8),
606
+ num_res_blocks,
607
+ attn_resolutions,
608
+ dropout=0.0,
609
+ resamp_with_conv=True,
610
+ in_channels,
611
+ resolution,
612
+ z_channels,
613
+ give_pre_end=False,
614
+ tanh_out=False,
615
+ use_linear_attn=False,
616
+ attn_type="vanilla",
617
+ **ignorekwargs,
618
+ ):
619
+ super().__init__()
620
+ if use_linear_attn:
621
+ attn_type = "linear"
622
+ self.ch = ch
623
+ self.temb_ch = 0
624
+ self.num_resolutions = len(ch_mult)
625
+ self.num_res_blocks = num_res_blocks
626
+ self.resolution = resolution
627
+ self.in_channels = in_channels
628
+ self.give_pre_end = give_pre_end
629
+ self.tanh_out = tanh_out
630
+
631
+ # compute in_ch_mult, block_in and curr_res at lowest res
632
+ in_ch_mult = (1,) + tuple(ch_mult)
633
+ block_in = ch * ch_mult[self.num_resolutions - 1]
634
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
635
+ self.z_shape = (1, z_channels, curr_res, curr_res)
636
+ print(
637
+ "Working with z of shape {} = {} dimensions.".format(
638
+ self.z_shape, np.prod(self.z_shape)
639
+ )
640
+ )
641
+
642
+ make_attn_cls = self._make_attn()
643
+ make_resblock_cls = self._make_resblock()
644
+ make_conv_cls = self._make_conv()
645
+ # z to block_in
646
+ self.conv_in = torch.nn.Conv2d(
647
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ # middle
651
+ self.mid = nn.Module()
652
+ self.mid.block_1 = make_resblock_cls(
653
+ in_channels=block_in,
654
+ out_channels=block_in,
655
+ temb_channels=self.temb_ch,
656
+ dropout=dropout,
657
+ )
658
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
659
+ self.mid.block_2 = make_resblock_cls(
660
+ in_channels=block_in,
661
+ out_channels=block_in,
662
+ temb_channels=self.temb_ch,
663
+ dropout=dropout,
664
+ )
665
+
666
+ # upsampling
667
+ self.up = nn.ModuleList()
668
+ for i_level in reversed(range(self.num_resolutions)):
669
+ block = nn.ModuleList()
670
+ attn = nn.ModuleList()
671
+ block_out = ch * ch_mult[i_level]
672
+ for i_block in range(self.num_res_blocks + 1):
673
+ block.append(
674
+ make_resblock_cls(
675
+ in_channels=block_in,
676
+ out_channels=block_out,
677
+ temb_channels=self.temb_ch,
678
+ dropout=dropout,
679
+ )
680
+ )
681
+ block_in = block_out
682
+ if curr_res in attn_resolutions:
683
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
684
+ up = nn.Module()
685
+ up.block = block
686
+ up.attn = attn
687
+ if i_level != 0:
688
+ up.upsample = Upsample(block_in, resamp_with_conv)
689
+ curr_res = curr_res * 2
690
+ self.up.insert(0, up) # prepend to get consistent order
691
+
692
+ # end
693
+ self.norm_out = Normalize(block_in)
694
+ self.conv_out = make_conv_cls(
695
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
696
+ )
697
+
698
+ def _make_attn(self) -> Callable:
699
+ return make_attn
700
+
701
+ def _make_resblock(self) -> Callable:
702
+ return ResnetBlock
703
+
704
+ def _make_conv(self) -> Callable:
705
+ return torch.nn.Conv2d
706
+
707
+ def get_last_layer(self, **kwargs):
708
+ return self.conv_out.weight
709
+
710
+ def forward(self, z, **kwargs):
711
+ # assert z.shape[1:] == self.z_shape[1:]
712
+ self.last_z_shape = z.shape
713
+
714
+ # timestep embedding
715
+ temb = None
716
+
717
+ # z to block_in
718
+ h = self.conv_in(z)
719
+
720
+ # middle
721
+ h = self.mid.block_1(h, temb, **kwargs)
722
+ h = self.mid.attn_1(h, **kwargs)
723
+ h = self.mid.block_2(h, temb, **kwargs)
724
+
725
+ # upsampling
726
+ for i_level in reversed(range(self.num_resolutions)):
727
+ for i_block in range(self.num_res_blocks + 1):
728
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
729
+ if len(self.up[i_level].attn) > 0:
730
+ h = self.up[i_level].attn[i_block](h, **kwargs)
731
+ if i_level != 0:
732
+ h = self.up[i_level].upsample(h)
733
+
734
+ # end
735
+ if self.give_pre_end:
736
+ return h
737
+
738
+ h = self.norm_out(h)
739
+ h = nonlinearity(h)
740
+ h = self.conv_out(h, **kwargs)
741
+ if self.tanh_out:
742
+ h = torch.tanh(h)
743
+ return h
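Taken together, `Encoder` and `Decoder` form the convolutional autoencoder backbone: the encoder downsamples an image to a latent (with doubled channels when `double_z=True`, i.e. mean and log-variance stacked), and the decoder upsamples a latent back to image space. A minimal round-trip sketch with a deliberately small, hypothetical configuration (the channel settings actually used live in the YAML configs, not here):

```python
import torch

cfg = dict(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
           attn_resolutions=[], in_channels=3, resolution=64, z_channels=4)

enc = Encoder(double_z=True, **cfg)
dec = Decoder(**cfg)

x = torch.randn(1, 3, 64, 64)
moments = enc(x)        # (1, 2 * z_channels, 16, 16): mean and log-variance stacked
z = moments[:, :4]      # take the mean as a stand-in for sampling
print(dec(z).shape)     # torch.Size([1, 3, 64, 64])
```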
sgm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,2070 @@
1
+ import os
2
+ import math
3
+ from abc import abstractmethod
4
+ from functools import partial
5
+ from typing import Iterable
6
+
7
+ import numpy as np
8
+ import torch as th
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+
13
+ from ...modules.attention import SpatialTransformer
14
+ from ...modules.diffusionmodules.util import (
15
+ avg_pool_nd,
16
+ checkpoint,
17
+ conv_nd,
18
+ linear,
19
+ normalization,
20
+ timestep_embedding,
21
+ zero_module,
22
+ )
23
+ from ...util import default, exists
24
+
25
+
26
+ # dummy replace
27
+ def convert_module_to_f16(x):
28
+ pass
29
+
30
+
31
+ def convert_module_to_f32(x):
32
+ pass
33
+
34
+
35
+ ## go
36
+ class AttentionPool2d(nn.Module):
37
+ """
38
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ spacial_dim: int,
44
+ embed_dim: int,
45
+ num_heads_channels: int,
46
+ output_dim: int = None,
47
+ ):
48
+ super().__init__()
49
+ self.positional_embedding = nn.Parameter(
50
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
51
+ )
52
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
53
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
54
+ self.num_heads = embed_dim // num_heads_channels
55
+ self.attention = QKVAttention(self.num_heads)
56
+
57
+ def forward(self, x):
58
+ b, c, *_spatial = x.shape
59
+ x = x.reshape(b, c, -1) # NC(HW)
60
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
61
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
62
+ x = self.qkv_proj(x)
63
+ x = self.attention(x)
64
+ x = self.c_proj(x)
65
+ return x[:, :, 0]
66
+
67
+
68
+ class TimestepBlock(nn.Module):
69
+ """
70
+ Any module where forward() takes timestep embeddings as a second argument.
71
+ """
72
+
73
+ @abstractmethod
74
+ def forward(self, x, emb):
75
+ """
76
+ Apply the module to `x` given `emb` timestep embeddings.
77
+ """
78
+
79
+
80
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
81
+ """
82
+ A sequential module that passes timestep embeddings to the children that
83
+ support it as an extra input.
84
+ """
85
+
86
+ def forward(
87
+ self,
88
+ x,
89
+ emb,
90
+ context=None,
91
+ add_context=None,
92
+ skip_time_mix=False,
93
+ time_context=None,
94
+ num_video_frames=None,
95
+ time_context_cat=None,
96
+ use_crossframe_attention_in_spatial_layers=False,
97
+ ):
98
+ for layer in self:
99
+ if isinstance(layer, TimestepBlock):
100
+ x = layer(x, emb)
101
+ elif isinstance(layer, SpatialTransformer):
102
+ x = layer(x, context, add_context)
103
+ else:
104
+ x = layer(x)
105
+ return x
106
+
107
+
108
+ class Upsample(nn.Module):
109
+ """
110
+ An upsampling layer with an optional convolution.
111
+ :param channels: channels in the inputs and outputs.
112
+ :param use_conv: a bool determining if a convolution is applied.
113
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
114
+ upsampling occurs in the inner-two dimensions.
115
+ """
116
+
117
+ def __init__(
118
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.out_channels = out_channels or channels
123
+ self.use_conv = use_conv
124
+ self.dims = dims
125
+ self.third_up = third_up
126
+ if use_conv:
127
+ self.conv = conv_nd(
128
+ dims, self.channels, self.out_channels, 3, padding=padding
129
+ )
130
+
131
+ def forward(self, x):
132
+ assert x.shape[1] == self.channels
133
+ if self.dims == 3:
134
+ t_factor = 1 if not self.third_up else 2
135
+ x = F.interpolate(
136
+ x,
137
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
138
+ mode="nearest",
139
+ )
140
+ else:
141
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
142
+ if self.use_conv:
143
+ x = self.conv(x)
144
+ return x
145
+
146
+
147
+ class TransposedUpsample(nn.Module):
148
+ "Learned 2x upsampling without padding"
149
+
150
+ def __init__(self, channels, out_channels=None, ks=5):
151
+ super().__init__()
152
+ self.channels = channels
153
+ self.out_channels = out_channels or channels
154
+
155
+ self.up = nn.ConvTranspose2d(
156
+ self.channels, self.out_channels, kernel_size=ks, stride=2
157
+ )
158
+
159
+ def forward(self, x):
160
+ return self.up(x)
161
+
162
+
163
+ class Downsample(nn.Module):
164
+ """
165
+ A downsampling layer with an optional convolution.
166
+ :param channels: channels in the inputs and outputs.
167
+ :param use_conv: a bool determining if a convolution is applied.
168
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
169
+ downsampling occurs in the inner-two dimensions.
170
+ """
171
+
172
+ def __init__(
173
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
174
+ ):
175
+ super().__init__()
176
+ self.channels = channels
177
+ self.out_channels = out_channels or channels
178
+ self.use_conv = use_conv
179
+ self.dims = dims
180
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
181
+ if use_conv:
182
+ # print(f"Building a Downsample layer with {dims} dims.")
183
+ # print(
184
+ # f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
185
+ # f"kernel-size: 3, stride: {stride}, padding: {padding}"
186
+ # )
187
+ if dims == 3:
188
+ pass
189
+ # print(f" --> Downsampling third axis (time): {third_down}")
190
+ self.op = conv_nd(
191
+ dims,
192
+ self.channels,
193
+ self.out_channels,
194
+ 3,
195
+ stride=stride,
196
+ padding=padding,
197
+ )
198
+ else:
199
+ assert self.channels == self.out_channels
200
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
201
+
202
+ def forward(self, x):
203
+ assert x.shape[1] == self.channels
204
+ return self.op(x)
205
+
206
+
207
+ class ResBlock(TimestepBlock):
208
+ """
209
+ A residual block that can optionally change the number of channels.
210
+ :param channels: the number of input channels.
211
+ :param emb_channels: the number of timestep embedding channels.
212
+ :param dropout: the rate of dropout.
213
+ :param out_channels: if specified, the number of out channels.
214
+ :param use_conv: if True and out_channels is specified, use a spatial
215
+ convolution instead of a smaller 1x1 convolution to change the
216
+ channels in the skip connection.
217
+ :param dims: determines if the signal is 1D, 2D, or 3D.
218
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
219
+ :param up: if True, use this block for upsampling.
220
+ :param down: if True, use this block for downsampling.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ channels,
226
+ emb_channels,
227
+ dropout,
228
+ out_channels=None,
229
+ use_conv=False,
230
+ use_scale_shift_norm=False,
231
+ dims=2,
232
+ use_checkpoint=False,
233
+ up=False,
234
+ down=False,
235
+ kernel_size=3,
236
+ exchange_temb_dims=False,
237
+ skip_t_emb=False,
238
+ ):
239
+ super().__init__()
240
+ self.channels = channels
241
+ self.emb_channels = emb_channels
242
+ self.dropout = dropout
243
+ self.out_channels = out_channels or channels
244
+ self.use_conv = use_conv
245
+ self.use_checkpoint = use_checkpoint
246
+ self.use_scale_shift_norm = use_scale_shift_norm
247
+ self.exchange_temb_dims = exchange_temb_dims
248
+
249
+ if isinstance(kernel_size, Iterable):
250
+ padding = [k // 2 for k in kernel_size]
251
+ else:
252
+ padding = kernel_size // 2
253
+
254
+ self.in_layers = nn.Sequential(
255
+ normalization(channels),
256
+ nn.SiLU(),
257
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
258
+ )
259
+
260
+ self.updown = up or down
261
+
262
+ if up:
263
+ self.h_upd = Upsample(channels, False, dims)
264
+ self.x_upd = Upsample(channels, False, dims)
265
+ elif down:
266
+ self.h_upd = Downsample(channels, False, dims)
267
+ self.x_upd = Downsample(channels, False, dims)
268
+ else:
269
+ self.h_upd = self.x_upd = nn.Identity()
270
+
271
+ self.skip_t_emb = skip_t_emb
272
+ self.emb_out_channels = (
273
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
274
+ )
275
+ if self.skip_t_emb:
276
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
277
+ assert not self.use_scale_shift_norm
278
+ self.emb_layers = None
279
+ self.exchange_temb_dims = False
280
+ else:
281
+ self.emb_layers = nn.Sequential(
282
+ nn.SiLU(),
283
+ linear(
284
+ emb_channels,
285
+ self.emb_out_channels,
286
+ ),
287
+ )
288
+
289
+ self.out_layers = nn.Sequential(
290
+ normalization(self.out_channels),
291
+ nn.SiLU(),
292
+ nn.Dropout(p=dropout),
293
+ zero_module(
294
+ conv_nd(
295
+ dims,
296
+ self.out_channels,
297
+ self.out_channels,
298
+ kernel_size,
299
+ padding=padding,
300
+ )
301
+ ),
302
+ )
303
+
304
+ if self.out_channels == channels:
305
+ self.skip_connection = nn.Identity()
306
+ elif use_conv:
307
+ self.skip_connection = conv_nd(
308
+ dims, channels, self.out_channels, kernel_size, padding=padding
309
+ )
310
+ else:
311
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
312
+
313
+ def forward(self, x, emb):
314
+ """
315
+ Apply the block to a Tensor, conditioned on a timestep embedding.
316
+ :param x: an [N x C x ...] Tensor of features.
317
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
318
+ :return: an [N x C x ...] Tensor of outputs.
319
+ """
320
+ return checkpoint(
321
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
322
+ )
323
+
324
+ def _forward(self, x, emb):
325
+ if self.updown:
326
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
327
+ h = in_rest(x)
328
+ h = self.h_upd(h)
329
+ x = self.x_upd(x)
330
+ h = in_conv(h)
331
+ else:
332
+ h = self.in_layers(x)
333
+
334
+ if self.skip_t_emb:
335
+ emb_out = th.zeros_like(h)
336
+ else:
337
+ emb_out = self.emb_layers(emb).type(h.dtype)
338
+ while len(emb_out.shape) < len(h.shape):
339
+ emb_out = emb_out[..., None]
340
+ if self.use_scale_shift_norm:
341
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
342
+ scale, shift = th.chunk(emb_out, 2, dim=1)
343
+ h = out_norm(h) * (1 + scale) + shift
344
+ h = out_rest(h)
345
+ else:
346
+ if self.exchange_temb_dims:
347
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
348
+ h = h + emb_out
349
+ h = self.out_layers(h)
350
+ return self.skip_connection(x) + h
351
+
352
+
353
+ class AttentionBlock(nn.Module):
354
+ """
355
+ An attention block that allows spatial positions to attend to each other.
356
+ Originally ported from here, but adapted to the N-d case.
357
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
358
+ """
359
+
360
+ def __init__(
361
+ self,
362
+ channels,
363
+ num_heads=1,
364
+ num_head_channels=-1,
365
+ use_checkpoint=False,
366
+ use_new_attention_order=False,
367
+ ):
368
+ super().__init__()
369
+ self.channels = channels
370
+ if num_head_channels == -1:
371
+ self.num_heads = num_heads
372
+ else:
373
+ assert (
374
+ channels % num_head_channels == 0
375
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
376
+ self.num_heads = channels // num_head_channels
377
+ self.use_checkpoint = use_checkpoint
378
+ self.norm = normalization(channels)
379
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
380
+ if use_new_attention_order:
381
+ # split qkv before split heads
382
+ self.attention = QKVAttention(self.num_heads)
383
+ else:
384
+ # split heads before split qkv
385
+ self.attention = QKVAttentionLegacy(self.num_heads)
386
+
387
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
388
+
389
+ def forward(self, x, **kwargs):
390
+ # TODO add crossframe attention and use mixed checkpoint
391
+ return checkpoint(
392
+ self._forward, (x,), self.parameters(), True
393
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
394
+ # return pt_checkpoint(self._forward, x) # pytorch
395
+
396
+ def _forward(self, x):
397
+ b, c, *spatial = x.shape
398
+ x = x.reshape(b, c, -1)
399
+ qkv = self.qkv(self.norm(x))
400
+ h = self.attention(qkv)
401
+ h = self.proj_out(h)
402
+ return (x + h).reshape(b, c, *spatial)
403
+
404
+
405
+ def count_flops_attn(model, _x, y):
406
+ """
407
+ A counter for the `thop` package to count the operations in an
408
+ attention operation.
409
+ Meant to be used like:
410
+ macs, params = thop.profile(
411
+ model,
412
+ inputs=(inputs, timestamps),
413
+ custom_ops={QKVAttention: QKVAttention.count_flops},
414
+ )
415
+ """
416
+ b, c, *spatial = y[0].shape
417
+ num_spatial = int(np.prod(spatial))
418
+ # We perform two matmuls with the same number of ops.
419
+ # The first computes the weight matrix, the second computes
420
+ # the combination of the value vectors.
421
+ matmul_ops = 2 * b * (num_spatial**2) * c
422
+ model.total_ops += th.DoubleTensor([matmul_ops])
423
+
424
+
425
+ class QKVAttentionLegacy(nn.Module):
426
+ """
427
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
428
+ """
429
+
430
+ def __init__(self, n_heads):
431
+ super().__init__()
432
+ self.n_heads = n_heads
433
+
434
+ def forward(self, qkv):
435
+ """
436
+ Apply QKV attention.
437
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
438
+ :return: an [N x (H * C) x T] tensor after attention.
439
+ """
440
+ bs, width, length = qkv.shape
441
+ assert width % (3 * self.n_heads) == 0
442
+ ch = width // (3 * self.n_heads)
443
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
444
+ scale = 1 / math.sqrt(math.sqrt(ch))
445
+ weight = th.einsum(
446
+ "bct,bcs->bts", q * scale, k * scale
447
+ ) # More stable with f16 than dividing afterwards
448
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
449
+ a = th.einsum("bts,bcs->bct", weight, v)
450
+ return a.reshape(bs, -1, length)
451
+
452
+ @staticmethod
453
+ def count_flops(model, _x, y):
454
+ return count_flops_attn(model, _x, y)
455
+
456
+
457
+ class QKVAttention(nn.Module):
458
+ """
459
+ A module which performs QKV attention and splits in a different order.
460
+ """
461
+
462
+ def __init__(self, n_heads):
463
+ super().__init__()
464
+ self.n_heads = n_heads
465
+
466
+ def forward(self, qkv):
467
+ """
468
+ Apply QKV attention.
469
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
470
+ :return: an [N x (H * C) x T] tensor after attention.
471
+ """
472
+ bs, width, length = qkv.shape
473
+ assert width % (3 * self.n_heads) == 0
474
+ ch = width // (3 * self.n_heads)
475
+ q, k, v = qkv.chunk(3, dim=1)
476
+ scale = 1 / math.sqrt(math.sqrt(ch))
477
+ weight = th.einsum(
478
+ "bct,bcs->bts",
479
+ (q * scale).view(bs * self.n_heads, ch, length),
480
+ (k * scale).view(bs * self.n_heads, ch, length),
481
+ ) # More stable with f16 than dividing afterwards
482
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
483
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
484
+ return a.reshape(bs, -1, length)
485
+
486
+ @staticmethod
487
+ def count_flops(model, _x, y):
488
+ return count_flops_attn(model, _x, y)
489
+
490
+
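The only difference between `QKVAttentionLegacy` and `QKVAttention` is how the packed tensor is interpreted: the legacy module expects heads split before q/k/v (`[N, H * 3 * C, T]`), the newer one q/k/v split before heads (`[N, 3 * H * C, T]`). Both return `[N, H * C, T]`. A shape-only sketch (random values, so the two outputs differ numerically):

```python
import torch as th

N, H, C, T = 2, 4, 16, 10        # batch, heads, channels per head, tokens
qkv = th.randn(N, 3 * H * C, T)  # packed q, k, v

print(QKVAttentionLegacy(H)(qkv).shape)  # torch.Size([2, 64, 10])
print(QKVAttention(H)(qkv).shape)        # torch.Size([2, 64, 10])
```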
491
+ class Timestep(nn.Module):
492
+ def __init__(self, dim):
493
+ super().__init__()
494
+ self.dim = dim
495
+
496
+ def forward(self, t):
497
+ return timestep_embedding(t, self.dim)
498
+
499
+
500
+ class UNetModel(nn.Module):
501
+ """
502
+ The full UNet model with attention and timestep embedding.
503
+ :param in_channels: channels in the input Tensor.
504
+ :param model_channels: base channel count for the model.
505
+ :param out_channels: channels in the output Tensor.
506
+ :param num_res_blocks: number of residual blocks per downsample.
507
+ :param attention_resolutions: a collection of downsample rates at which
508
+ attention will take place. May be a set, list, or tuple.
509
+ For example, if this contains 4, then at 4x downsampling, attention
510
+ will be used.
511
+ :param dropout: the dropout probability.
512
+ :param channel_mult: channel multiplier for each level of the UNet.
513
+ :param conv_resample: if True, use learned convolutions for upsampling and
514
+ downsampling.
515
+ :param dims: determines if the signal is 1D, 2D, or 3D.
516
+ :param num_classes: if specified (as an int), then this model will be
517
+ class-conditional with `num_classes` classes.
518
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
519
+ :param num_heads: the number of attention heads in each attention layer.
520
+ :param num_heads_channels: if specified, ignore num_heads and instead use
521
+ a fixed channel width per attention head.
522
+ :param num_heads_upsample: works with num_heads to set a different number
523
+ of heads for upsampling. Deprecated.
524
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
525
+ :param resblock_updown: use residual blocks for up/downsampling.
526
+ :param use_new_attention_order: use a different attention pattern for potentially
527
+ increased efficiency.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ in_channels,
533
+ model_channels,
534
+ out_channels,
535
+ num_res_blocks,
536
+ attention_resolutions,
537
+ dropout=0,
538
+ channel_mult=(1, 2, 4, 8),
539
+ conv_resample=True,
540
+ dims=2,
541
+ num_classes=None,
542
+ use_checkpoint=False,
543
+ use_fp16=False,
544
+ num_heads=-1,
545
+ num_head_channels=-1,
546
+ num_heads_upsample=-1,
547
+ use_scale_shift_norm=False,
548
+ resblock_updown=False,
549
+ use_new_attention_order=False,
550
+ use_spatial_transformer=False, # custom transformer support
551
+ transformer_depth=1, # custom transformer support
552
+ context_dim=None, # custom transformer support
553
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
554
+ legacy=True,
555
+ disable_self_attentions=None,
556
+ num_attention_blocks=None,
557
+ disable_middle_self_attn=False,
558
+ use_linear_in_transformer=False,
559
+ spatial_transformer_attn_type="softmax",
560
+ adm_in_channels=None,
561
+ use_fairscale_checkpoint=False,
562
+ offload_to_cpu=False,
563
+ transformer_depth_middle=None,
564
+ ):
565
+ super().__init__()
566
+ from omegaconf.listconfig import ListConfig
567
+
568
+ if use_spatial_transformer:
569
+ assert (
570
+ context_dim is not None
571
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
572
+
573
+ if context_dim is not None:
574
+ assert (
575
+ use_spatial_transformer
576
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
577
+ if type(context_dim) == ListConfig:
578
+ context_dim = list(context_dim)
579
+
580
+ if num_heads_upsample == -1:
581
+ num_heads_upsample = num_heads
582
+
583
+ if num_heads == -1:
584
+ assert (
585
+ num_head_channels != -1
586
+ ), "Either num_heads or num_head_channels has to be set"
587
+
588
+ if num_head_channels == -1:
589
+ assert (
590
+ num_heads != -1
591
+ ), "Either num_heads or num_head_channels has to be set"
592
+
593
+ self.in_channels = in_channels
594
+ self.model_channels = model_channels
595
+ self.out_channels = out_channels
596
+ if isinstance(transformer_depth, int):
597
+ transformer_depth = len(channel_mult) * [transformer_depth]
598
+ elif isinstance(transformer_depth, ListConfig):
599
+ transformer_depth = list(transformer_depth)
600
+ transformer_depth_middle = default(
601
+ transformer_depth_middle, transformer_depth[-1]
602
+ )
603
+
604
+ if isinstance(num_res_blocks, int):
605
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
606
+ else:
607
+ if len(num_res_blocks) != len(channel_mult):
608
+ raise ValueError(
609
+ "provide num_res_blocks either as an int (globally constant) or "
610
+ "as a list/tuple (per-level) with the same length as channel_mult"
611
+ )
612
+ self.num_res_blocks = num_res_blocks
613
+ # self.num_res_blocks = num_res_blocks
614
+ if disable_self_attentions is not None:
615
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
616
+ assert len(disable_self_attentions) == len(channel_mult)
617
+ if num_attention_blocks is not None:
618
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
619
+ assert all(
620
+ map(
621
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
622
+ range(len(num_attention_blocks)),
623
+ )
624
+ )
625
+ print(
626
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
627
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
628
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
629
+ f"attention will still not be set."
630
+ ) # todo: convert to warning
631
+
632
+ self.attention_resolutions = attention_resolutions
633
+ self.dropout = dropout
634
+ self.channel_mult = channel_mult
635
+ self.conv_resample = conv_resample
636
+ self.num_classes = num_classes
637
+ self.use_checkpoint = use_checkpoint
638
+ if use_fp16:
639
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
640
+ # self.dtype = th.float16 if use_fp16 else th.float32
641
+ self.num_heads = num_heads
642
+ self.num_head_channels = num_head_channels
643
+ self.num_heads_upsample = num_heads_upsample
644
+ self.predict_codebook_ids = n_embed is not None
645
+
646
+ assert use_fairscale_checkpoint != use_checkpoint or not (
647
+ use_checkpoint or use_fairscale_checkpoint
648
+ )
649
+
650
+ self.use_fairscale_checkpoint = False
651
+ checkpoint_wrapper_fn = (
652
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
653
+ if self.use_fairscale_checkpoint
654
+ else lambda x: x
655
+ )
656
+
657
+ time_embed_dim = model_channels * 4
658
+ self.time_embed = checkpoint_wrapper_fn(
659
+ nn.Sequential(
660
+ linear(model_channels, time_embed_dim),
661
+ nn.SiLU(),
662
+ linear(time_embed_dim, time_embed_dim),
663
+ )
664
+ )
665
+
666
+ if self.num_classes is not None:
667
+ if isinstance(self.num_classes, int):
668
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
669
+ elif self.num_classes == "continuous":
670
+ print("setting up linear c_adm embedding layer")
671
+ self.label_emb = nn.Linear(1, time_embed_dim)
672
+ elif self.num_classes == "timestep":
673
+ self.label_emb = checkpoint_wrapper_fn(
674
+ nn.Sequential(
675
+ Timestep(model_channels),
676
+ nn.Sequential(
677
+ linear(model_channels, time_embed_dim),
678
+ nn.SiLU(),
679
+ linear(time_embed_dim, time_embed_dim),
680
+ ),
681
+ )
682
+ )
683
+ elif self.num_classes == "sequential":
684
+ assert adm_in_channels is not None
685
+ self.label_emb = nn.Sequential(
686
+ nn.Sequential(
687
+ linear(adm_in_channels, time_embed_dim),
688
+ nn.SiLU(),
689
+ linear(time_embed_dim, time_embed_dim),
690
+ )
691
+ )
692
+ else:
693
+ raise ValueError()
694
+
695
+ self.input_blocks = nn.ModuleList(
696
+ [
697
+ TimestepEmbedSequential(
698
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
699
+ )
700
+ ]
701
+ )
702
+ self._feature_size = model_channels
703
+ input_block_chans = [model_channels]
704
+ ch = model_channels
705
+ ds = 1
706
+ for level, mult in enumerate(channel_mult):
707
+ for nr in range(self.num_res_blocks[level]):
708
+ layers = [
709
+ checkpoint_wrapper_fn(
710
+ ResBlock(
711
+ ch,
712
+ time_embed_dim,
713
+ dropout,
714
+ out_channels=mult * model_channels,
715
+ dims=dims,
716
+ use_checkpoint=use_checkpoint,
717
+ use_scale_shift_norm=use_scale_shift_norm,
718
+ )
719
+ )
720
+ ]
721
+ ch = mult * model_channels
722
+ if ds in attention_resolutions:
723
+ if num_head_channels == -1:
724
+ dim_head = ch // num_heads
725
+ else:
726
+ num_heads = ch // num_head_channels
727
+ dim_head = num_head_channels
728
+ if legacy:
729
+ # num_heads = 1
730
+ dim_head = (
731
+ ch // num_heads
732
+ if use_spatial_transformer
733
+ else num_head_channels
734
+ )
735
+ if exists(disable_self_attentions):
736
+ disabled_sa = disable_self_attentions[level]
737
+ else:
738
+ disabled_sa = False
739
+
740
+ if (
741
+ not exists(num_attention_blocks)
742
+ or nr < num_attention_blocks[level]
743
+ ):
744
+ layers.append(
745
+ checkpoint_wrapper_fn(
746
+ AttentionBlock(
747
+ ch,
748
+ use_checkpoint=use_checkpoint,
749
+ num_heads=num_heads,
750
+ num_head_channels=dim_head,
751
+ use_new_attention_order=use_new_attention_order,
752
+ )
753
+ )
754
+ if not use_spatial_transformer
755
+ else checkpoint_wrapper_fn(
756
+ SpatialTransformer(
757
+ ch,
758
+ num_heads,
759
+ dim_head,
760
+ depth=transformer_depth[level],
761
+ context_dim=context_dim,
762
+ disable_self_attn=disabled_sa,
763
+ use_linear=use_linear_in_transformer,
764
+ attn_type=spatial_transformer_attn_type,
765
+ use_checkpoint=use_checkpoint,
766
+ )
767
+ )
768
+ )
769
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
770
+ self._feature_size += ch
771
+ input_block_chans.append(ch)
772
+ if level != len(channel_mult) - 1:
773
+ out_ch = ch
774
+ self.input_blocks.append(
775
+ TimestepEmbedSequential(
776
+ checkpoint_wrapper_fn(
777
+ ResBlock(
778
+ ch,
779
+ time_embed_dim,
780
+ dropout,
781
+ out_channels=out_ch,
782
+ dims=dims,
783
+ use_checkpoint=use_checkpoint,
784
+ use_scale_shift_norm=use_scale_shift_norm,
785
+ down=True,
786
+ )
787
+ )
788
+ if resblock_updown
789
+ else Downsample(
790
+ ch, conv_resample, dims=dims, out_channels=out_ch
791
+ )
792
+ )
793
+ )
794
+ ch = out_ch
795
+ input_block_chans.append(ch)
796
+ ds *= 2
797
+ self._feature_size += ch
798
+
799
+ if num_head_channels == -1:
800
+ dim_head = ch // num_heads
801
+ else:
802
+ num_heads = ch // num_head_channels
803
+ dim_head = num_head_channels
804
+ if legacy:
805
+ # num_heads = 1
806
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
807
+ self.middle_block = TimestepEmbedSequential(
808
+ checkpoint_wrapper_fn(
809
+ ResBlock(
810
+ ch,
811
+ time_embed_dim,
812
+ dropout,
813
+ dims=dims,
814
+ use_checkpoint=use_checkpoint,
815
+ use_scale_shift_norm=use_scale_shift_norm,
816
+ )
817
+ ),
818
+ checkpoint_wrapper_fn(
819
+ AttentionBlock(
820
+ ch,
821
+ use_checkpoint=use_checkpoint,
822
+ num_heads=num_heads,
823
+ num_head_channels=dim_head,
824
+ use_new_attention_order=use_new_attention_order,
825
+ )
826
+ )
827
+ if not use_spatial_transformer
828
+ else checkpoint_wrapper_fn(
829
+ SpatialTransformer( # always uses a self-attn
830
+ ch,
831
+ num_heads,
832
+ dim_head,
833
+ depth=transformer_depth_middle,
834
+ context_dim=context_dim,
835
+ disable_self_attn=disable_middle_self_attn,
836
+ use_linear=use_linear_in_transformer,
837
+ attn_type=spatial_transformer_attn_type,
838
+ use_checkpoint=use_checkpoint,
839
+ )
840
+ ),
841
+ checkpoint_wrapper_fn(
842
+ ResBlock(
843
+ ch,
844
+ time_embed_dim,
845
+ dropout,
846
+ dims=dims,
847
+ use_checkpoint=use_checkpoint,
848
+ use_scale_shift_norm=use_scale_shift_norm,
849
+ )
850
+ ),
851
+ )
852
+ self._feature_size += ch
853
+
854
+ self.output_blocks = nn.ModuleList([])
855
+ for level, mult in list(enumerate(channel_mult))[::-1]:
856
+ for i in range(self.num_res_blocks[level] + 1):
857
+ ich = input_block_chans.pop()
858
+ layers = [
859
+ checkpoint_wrapper_fn(
860
+ ResBlock(
861
+ ch + ich,
862
+ time_embed_dim,
863
+ dropout,
864
+ out_channels=model_channels * mult,
865
+ dims=dims,
866
+ use_checkpoint=use_checkpoint,
867
+ use_scale_shift_norm=use_scale_shift_norm,
868
+ )
869
+ )
870
+ ]
871
+ ch = model_channels * mult
872
+ if ds in attention_resolutions:
873
+ if num_head_channels == -1:
874
+ dim_head = ch // num_heads
875
+ else:
876
+ num_heads = ch // num_head_channels
877
+ dim_head = num_head_channels
878
+ if legacy:
879
+ # num_heads = 1
880
+ dim_head = (
881
+ ch // num_heads
882
+ if use_spatial_transformer
883
+ else num_head_channels
884
+ )
885
+ if exists(disable_self_attentions):
886
+ disabled_sa = disable_self_attentions[level]
887
+ else:
888
+ disabled_sa = False
889
+
890
+ if (
891
+ not exists(num_attention_blocks)
892
+ or i < num_attention_blocks[level]
893
+ ):
894
+ layers.append(
895
+ checkpoint_wrapper_fn(
896
+ AttentionBlock(
897
+ ch,
898
+ use_checkpoint=use_checkpoint,
899
+ num_heads=num_heads_upsample,
900
+ num_head_channels=dim_head,
901
+ use_new_attention_order=use_new_attention_order,
902
+ )
903
+ )
904
+ if not use_spatial_transformer
905
+ else checkpoint_wrapper_fn(
906
+ SpatialTransformer(
907
+ ch,
908
+ num_heads,
909
+ dim_head,
910
+ depth=transformer_depth[level],
911
+ context_dim=context_dim,
912
+ disable_self_attn=disabled_sa,
913
+ use_linear=use_linear_in_transformer,
914
+ attn_type=spatial_transformer_attn_type,
915
+ use_checkpoint=use_checkpoint,
916
+ )
917
+ )
918
+ )
919
+ if level and i == self.num_res_blocks[level]:
920
+ out_ch = ch
921
+ layers.append(
922
+ checkpoint_wrapper_fn(
923
+ ResBlock(
924
+ ch,
925
+ time_embed_dim,
926
+ dropout,
927
+ out_channels=out_ch,
928
+ dims=dims,
929
+ use_checkpoint=use_checkpoint,
930
+ use_scale_shift_norm=use_scale_shift_norm,
931
+ up=True,
932
+ )
933
+ )
934
+ if resblock_updown
935
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
936
+ )
937
+ ds //= 2
938
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
939
+ self._feature_size += ch
940
+
941
+ self.out = checkpoint_wrapper_fn(
942
+ nn.Sequential(
943
+ normalization(ch),
944
+ nn.SiLU(),
945
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
946
+ )
947
+ )
948
+ if self.predict_codebook_ids:
949
+ self.id_predictor = checkpoint_wrapper_fn(
950
+ nn.Sequential(
951
+ normalization(ch),
952
+ conv_nd(dims, model_channels, n_embed, 1),
953
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
954
+ )
955
+ )
956
+
957
+ def convert_to_fp16(self):
958
+ """
959
+ Convert the torso of the model to float16.
960
+ """
961
+ self.input_blocks.apply(convert_module_to_f16)
962
+ self.middle_block.apply(convert_module_to_f16)
963
+ self.output_blocks.apply(convert_module_to_f16)
964
+
965
+ def convert_to_fp32(self):
966
+ """
967
+ Convert the torso of the model to float32.
968
+ """
969
+ self.input_blocks.apply(convert_module_to_f32)
970
+ self.middle_block.apply(convert_module_to_f32)
971
+ self.output_blocks.apply(convert_module_to_f32)
972
+
973
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
974
+ """
975
+ Apply the model to an input batch.
976
+ :param x: an [N x C x ...] Tensor of inputs.
977
+ :param timesteps: a 1-D batch of timesteps.
978
+ :param context: conditioning plugged in via crossattn
979
+ :param y: an [N] Tensor of labels, if class-conditional.
980
+ :return: an [N x C x ...] Tensor of outputs.
981
+ """
982
+ assert (y is not None) == (
983
+ self.num_classes is not None
984
+ ), "must specify y if and only if the model is class-conditional"
985
+ hs = []
986
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
987
+ emb = self.time_embed(t_emb)
988
+
989
+ if self.num_classes is not None:
990
+ assert y.shape[0] == x.shape[0]
991
+ emb = emb + self.label_emb(y)
992
+
993
+ # h = x.type(self.dtype)
994
+ h = x
995
+ for i, module in enumerate(self.input_blocks):
996
+ h = module(h, emb, context)
997
+ hs.append(h)
998
+ h = self.middle_block(h, emb, context)
999
+ for i, module in enumerate(self.output_blocks):
1000
+ h = th.cat([h, hs.pop()], dim=1)
1001
+ h = module(h, emb, context)
1002
+ h = h.type(x.dtype)
1003
+ if self.predict_codebook_ids:
1004
+ assert False, "not supported anymore. what the f*** are you doing?"
1005
+ else:
1006
+ return self.out(h)
1007
+
1008
+
1009
+
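For orientation, a hypothetical instantiation of the UNet defined above with deliberately small settings; every value below is made up for illustration, and the shipped checkpoints use the much larger configurations under `configs/`:

```python
import torch as th

unet = UNetModel(
    in_channels=4, model_channels=64, out_channels=4,
    num_res_blocks=1, attention_resolutions=[2],
    channel_mult=(1, 2), num_head_channels=32,
    use_spatial_transformer=True, transformer_depth=1, context_dim=128,
)

x = th.randn(1, 4, 16, 16)      # latent input
t = th.randint(0, 1000, (1,))   # diffusion timesteps
context = th.randn(1, 77, 128)  # cross-attention conditioning
print(unet(x, timesteps=t, context=context).shape)  # torch.Size([1, 4, 16, 16])
```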
1010
+ class UNetModel(nn.Module):
1011
+ """
1012
+ The full UNet model with attention and timestep embedding.
1013
+ :param in_channels: channels in the input Tensor.
1014
+ :param model_channels: base channel count for the model.
1015
+ :param out_channels: channels in the output Tensor.
1016
+ :param num_res_blocks: number of residual blocks per downsample.
1017
+ :param attention_resolutions: a collection of downsample rates at which
1018
+ attention will take place. May be a set, list, or tuple.
1019
+ For example, if this contains 4, then at 4x downsampling, attention
1020
+ will be used.
1021
+ :param dropout: the dropout probability.
1022
+ :param channel_mult: channel multiplier for each level of the UNet.
1023
+ :param conv_resample: if True, use learned convolutions for upsampling and
1024
+ downsampling.
1025
+ :param dims: determines if the signal is 1D, 2D, or 3D.
1026
+ :param num_classes: if specified (as an int), then this model will be
1027
+ class-conditional with `num_classes` classes.
1028
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
1029
+ :param num_heads: the number of attention heads in each attention layer.
1030
+ :param num_heads_channels: if specified, ignore num_heads and instead use
1031
+ a fixed channel width per attention head.
1032
+ :param num_heads_upsample: works with num_heads to set a different number
1033
+ of heads for upsampling. Deprecated.
1034
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
1035
+ :param resblock_updown: use residual blocks for up/downsampling.
1036
+ :param use_new_attention_order: use a different attention pattern for potentially
1037
+ increased efficiency.
1038
+ """
1039
+
1040
+ def __init__(
1041
+ self,
1042
+ in_channels,
1043
+ model_channels,
1044
+ out_channels,
1045
+ num_res_blocks,
1046
+ attention_resolutions,
1047
+ dropout=0,
1048
+ channel_mult=(1, 2, 4, 8),
1049
+ conv_resample=True,
1050
+ dims=2,
1051
+ num_classes=None,
1052
+ use_checkpoint=False,
1053
+ use_fp16=False,
1054
+ num_heads=-1,
1055
+ num_head_channels=-1,
1056
+ num_heads_upsample=-1,
1057
+ use_scale_shift_norm=False,
1058
+ resblock_updown=False,
1059
+ use_new_attention_order=False,
1060
+ use_spatial_transformer=False, # custom transformer support
1061
+ transformer_depth=1, # custom transformer support
1062
+ context_dim=None, # custom transformer support
1063
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1064
+ legacy=True,
1065
+ disable_self_attentions=None,
1066
+ num_attention_blocks=None,
1067
+ disable_middle_self_attn=False,
1068
+ use_linear_in_transformer=False,
1069
+ spatial_transformer_attn_type="softmax",
1070
+ adm_in_channels=None,
1071
+ use_fairscale_checkpoint=False,
1072
+ offload_to_cpu=False,
1073
+ transformer_depth_middle=None,
1074
+ ):
1075
+ super().__init__()
1076
+ from omegaconf.listconfig import ListConfig
1077
+
1078
+ if use_spatial_transformer:
1079
+ assert (
1080
+ context_dim is not None
1081
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1082
+
1083
+ if context_dim is not None:
1084
+ assert (
1085
+ use_spatial_transformer
1086
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1087
+ if type(context_dim) == ListConfig:
1088
+ context_dim = list(context_dim)
1089
+
1090
+ if num_heads_upsample == -1:
1091
+ num_heads_upsample = num_heads
1092
+
1093
+ if num_heads == -1:
1094
+ assert (
1095
+ num_head_channels != -1
1096
+ ), "Either num_heads or num_head_channels has to be set"
1097
+
1098
+ if num_head_channels == -1:
1099
+ assert (
1100
+ num_heads != -1
1101
+ ), "Either num_heads or num_head_channels has to be set"
1102
+
1103
+ self.in_channels = in_channels
1104
+ self.model_channels = model_channels
1105
+ self.out_channels = out_channels
1106
+ if isinstance(transformer_depth, int):
1107
+ transformer_depth = len(channel_mult) * [transformer_depth]
1108
+ elif isinstance(transformer_depth, ListConfig):
1109
+ transformer_depth = list(transformer_depth)
1110
+ transformer_depth_middle = default(
1111
+ transformer_depth_middle, transformer_depth[-1]
1112
+ )
1113
+
1114
+ if isinstance(num_res_blocks, int):
1115
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1116
+ else:
1117
+ if len(num_res_blocks) != len(channel_mult):
1118
+ raise ValueError(
1119
+ "provide num_res_blocks either as an int (globally constant) or "
1120
+ "as a list/tuple (per-level) with the same length as channel_mult"
1121
+ )
1122
+ self.num_res_blocks = num_res_blocks
1123
+ # self.num_res_blocks = num_res_blocks
1124
+ if disable_self_attentions is not None:
1125
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1126
+ assert len(disable_self_attentions) == len(channel_mult)
1127
+ if num_attention_blocks is not None:
1128
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
1129
+ assert all(
1130
+ map(
1131
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1132
+ range(len(num_attention_blocks)),
1133
+ )
1134
+ )
1135
+ print(
1136
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1137
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1138
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1139
+ f"attention will still not be set."
1140
+ ) # todo: convert to warning
1141
+
1142
+ self.attention_resolutions = attention_resolutions
1143
+ self.dropout = dropout
1144
+ self.channel_mult = channel_mult
1145
+ self.conv_resample = conv_resample
1146
+ self.num_classes = num_classes
1147
+ self.use_checkpoint = use_checkpoint
1148
+ if use_fp16:
1149
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
1150
+ # self.dtype = th.float16 if use_fp16 else th.float32
1151
+ self.num_heads = num_heads
1152
+ self.num_head_channels = num_head_channels
1153
+ self.num_heads_upsample = num_heads_upsample
1154
+ self.predict_codebook_ids = n_embed is not None
1155
+
1156
+ assert use_fairscale_checkpoint != use_checkpoint or not (
1157
+ use_checkpoint or use_fairscale_checkpoint
1158
+ )
1159
+
1160
+ self.use_fairscale_checkpoint = False
1161
+ checkpoint_wrapper_fn = (
1162
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1163
+ if self.use_fairscale_checkpoint
1164
+ else lambda x: x
1165
+ )
1166
+
1167
+ time_embed_dim = model_channels * 4
1168
+ self.time_embed = checkpoint_wrapper_fn(
1169
+ nn.Sequential(
1170
+ linear(model_channels, time_embed_dim),
1171
+ nn.SiLU(),
1172
+ linear(time_embed_dim, time_embed_dim),
1173
+ )
1174
+ )
1175
+
1176
+ if self.num_classes is not None:
1177
+ if isinstance(self.num_classes, int):
1178
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1179
+ elif self.num_classes == "continuous":
1180
+ print("setting up linear c_adm embedding layer")
1181
+ self.label_emb = nn.Linear(1, time_embed_dim)
1182
+ elif self.num_classes == "timestep":
1183
+ self.label_emb = checkpoint_wrapper_fn(
1184
+ nn.Sequential(
1185
+ Timestep(model_channels),
1186
+ nn.Sequential(
1187
+ linear(model_channels, time_embed_dim),
1188
+ nn.SiLU(),
1189
+ linear(time_embed_dim, time_embed_dim),
1190
+ ),
1191
+ )
1192
+ )
1193
+ elif self.num_classes == "sequential":
1194
+ assert adm_in_channels is not None
1195
+ self.label_emb = nn.Sequential(
1196
+ nn.Sequential(
1197
+ linear(adm_in_channels, time_embed_dim),
1198
+ nn.SiLU(),
1199
+ linear(time_embed_dim, time_embed_dim),
1200
+ )
1201
+ )
1202
+ else:
1203
+ raise ValueError()
1204
+
1205
+ self.input_blocks = nn.ModuleList(
1206
+ [
1207
+ TimestepEmbedSequential(
1208
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1209
+ )
1210
+ ]
1211
+ )
1212
+ self._feature_size = model_channels
1213
+ input_block_chans = [model_channels]
1214
+ ch = model_channels
1215
+ ds = 1
1216
+ for level, mult in enumerate(channel_mult):
1217
+ for nr in range(self.num_res_blocks[level]):
1218
+ layers = [
1219
+ checkpoint_wrapper_fn(
1220
+ ResBlock(
1221
+ ch,
1222
+ time_embed_dim,
1223
+ dropout,
1224
+ out_channels=mult * model_channels,
1225
+ dims=dims,
1226
+ use_checkpoint=use_checkpoint,
1227
+ use_scale_shift_norm=use_scale_shift_norm,
1228
+ )
1229
+ )
1230
+ ]
1231
+ ch = mult * model_channels
1232
+ if ds in attention_resolutions:
1233
+ if num_head_channels == -1:
1234
+ dim_head = ch // num_heads
1235
+ else:
1236
+ num_heads = ch // num_head_channels
1237
+ dim_head = num_head_channels
1238
+ if legacy:
1239
+ # num_heads = 1
1240
+ dim_head = (
1241
+ ch // num_heads
1242
+ if use_spatial_transformer
1243
+ else num_head_channels
1244
+ )
1245
+ if exists(disable_self_attentions):
1246
+ disabled_sa = disable_self_attentions[level]
1247
+ else:
1248
+ disabled_sa = False
1249
+
1250
+ if (
1251
+ not exists(num_attention_blocks)
1252
+ or nr < num_attention_blocks[level]
1253
+ ):
1254
+ layers.append(
1255
+ checkpoint_wrapper_fn(
1256
+ AttentionBlock(
1257
+ ch,
1258
+ use_checkpoint=use_checkpoint,
1259
+ num_heads=num_heads,
1260
+ num_head_channels=dim_head,
1261
+ use_new_attention_order=use_new_attention_order,
1262
+ )
1263
+ )
1264
+ if not use_spatial_transformer
1265
+ else checkpoint_wrapper_fn(
1266
+ SpatialTransformer(
1267
+ ch,
1268
+ num_heads,
1269
+ dim_head,
1270
+ depth=transformer_depth[level],
1271
+ context_dim=context_dim,
1272
+ disable_self_attn=disabled_sa,
1273
+ use_linear=use_linear_in_transformer,
1274
+ attn_type=spatial_transformer_attn_type,
1275
+ use_checkpoint=use_checkpoint,
1276
+ )
1277
+ )
1278
+ )
1279
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1280
+ self._feature_size += ch
1281
+ input_block_chans.append(ch)
1282
+ if level != len(channel_mult) - 1:
1283
+ out_ch = ch
1284
+ self.input_blocks.append(
1285
+ TimestepEmbedSequential(
1286
+ checkpoint_wrapper_fn(
1287
+ ResBlock(
1288
+ ch,
1289
+ time_embed_dim,
1290
+ dropout,
1291
+ out_channels=out_ch,
1292
+ dims=dims,
1293
+ use_checkpoint=use_checkpoint,
1294
+ use_scale_shift_norm=use_scale_shift_norm,
1295
+ down=True,
1296
+ )
1297
+ )
1298
+ if resblock_updown
1299
+ else Downsample(
1300
+ ch, conv_resample, dims=dims, out_channels=out_ch
1301
+ )
1302
+ )
1303
+ )
1304
+ ch = out_ch
1305
+ input_block_chans.append(ch)
1306
+ ds *= 2
1307
+ self._feature_size += ch
1308
+
1309
+ if num_head_channels == -1:
1310
+ dim_head = ch // num_heads
1311
+ else:
1312
+ num_heads = ch // num_head_channels
1313
+ dim_head = num_head_channels
1314
+ if legacy:
1315
+ # num_heads = 1
1316
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1317
+ self.middle_block = TimestepEmbedSequential(
1318
+ checkpoint_wrapper_fn(
1319
+ ResBlock(
1320
+ ch,
1321
+ time_embed_dim,
1322
+ dropout,
1323
+ dims=dims,
1324
+ use_checkpoint=use_checkpoint,
1325
+ use_scale_shift_norm=use_scale_shift_norm,
1326
+ )
1327
+ ),
1328
+ checkpoint_wrapper_fn(
1329
+ AttentionBlock(
1330
+ ch,
1331
+ use_checkpoint=use_checkpoint,
1332
+ num_heads=num_heads,
1333
+ num_head_channels=dim_head,
1334
+ use_new_attention_order=use_new_attention_order,
1335
+ )
1336
+ )
1337
+ if not use_spatial_transformer
1338
+ else checkpoint_wrapper_fn(
1339
+ SpatialTransformer( # always uses a self-attn
1340
+ ch,
1341
+ num_heads,
1342
+ dim_head,
1343
+ depth=transformer_depth_middle,
1344
+ context_dim=context_dim,
1345
+ disable_self_attn=disable_middle_self_attn,
1346
+ use_linear=use_linear_in_transformer,
1347
+ attn_type=spatial_transformer_attn_type,
1348
+ use_checkpoint=use_checkpoint,
1349
+ )
1350
+ ),
1351
+ checkpoint_wrapper_fn(
1352
+ ResBlock(
1353
+ ch,
1354
+ time_embed_dim,
1355
+ dropout,
1356
+ dims=dims,
1357
+ use_checkpoint=use_checkpoint,
1358
+ use_scale_shift_norm=use_scale_shift_norm,
1359
+ )
1360
+ ),
1361
+ )
1362
+ self._feature_size += ch
1363
+
1364
+ self.output_blocks = nn.ModuleList([])
1365
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1366
+ for i in range(self.num_res_blocks[level] + 1):
1367
+ ich = input_block_chans.pop()
1368
+ layers = [
1369
+ checkpoint_wrapper_fn(
1370
+ ResBlock(
1371
+ ch + ich,
1372
+ time_embed_dim,
1373
+ dropout,
1374
+ out_channels=model_channels * mult,
1375
+ dims=dims,
1376
+ use_checkpoint=use_checkpoint,
1377
+ use_scale_shift_norm=use_scale_shift_norm,
1378
+ )
1379
+ )
1380
+ ]
1381
+ ch = model_channels * mult
1382
+ if ds in attention_resolutions:
1383
+ if num_head_channels == -1:
1384
+ dim_head = ch // num_heads
1385
+ else:
1386
+ num_heads = ch // num_head_channels
1387
+ dim_head = num_head_channels
1388
+ if legacy:
1389
+ # num_heads = 1
1390
+ dim_head = (
1391
+ ch // num_heads
1392
+ if use_spatial_transformer
1393
+ else num_head_channels
1394
+ )
1395
+ if exists(disable_self_attentions):
1396
+ disabled_sa = disable_self_attentions[level]
1397
+ else:
1398
+ disabled_sa = False
1399
+
1400
+ if (
1401
+ not exists(num_attention_blocks)
1402
+ or i < num_attention_blocks[level]
1403
+ ):
1404
+ layers.append(
1405
+ checkpoint_wrapper_fn(
1406
+ AttentionBlock(
1407
+ ch,
1408
+ use_checkpoint=use_checkpoint,
1409
+ num_heads=num_heads_upsample,
1410
+ num_head_channels=dim_head,
1411
+ use_new_attention_order=use_new_attention_order,
1412
+ )
1413
+ )
1414
+ if not use_spatial_transformer
1415
+ else checkpoint_wrapper_fn(
1416
+ SpatialTransformer(
1417
+ ch,
1418
+ num_heads,
1419
+ dim_head,
1420
+ depth=transformer_depth[level],
1421
+ context_dim=context_dim,
1422
+ disable_self_attn=disabled_sa,
1423
+ use_linear=use_linear_in_transformer,
1424
+ attn_type=spatial_transformer_attn_type,
1425
+ use_checkpoint=use_checkpoint,
1426
+ )
1427
+ )
1428
+ )
1429
+ if level and i == self.num_res_blocks[level]:
1430
+ out_ch = ch
1431
+ layers.append(
1432
+ checkpoint_wrapper_fn(
1433
+ ResBlock(
1434
+ ch,
1435
+ time_embed_dim,
1436
+ dropout,
1437
+ out_channels=out_ch,
1438
+ dims=dims,
1439
+ use_checkpoint=use_checkpoint,
1440
+ use_scale_shift_norm=use_scale_shift_norm,
1441
+ up=True,
1442
+ )
1443
+ )
1444
+ if resblock_updown
1445
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1446
+ )
1447
+ ds //= 2
1448
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1449
+ self._feature_size += ch
1450
+
1451
+ self.out = checkpoint_wrapper_fn(
1452
+ nn.Sequential(
1453
+ normalization(ch),
1454
+ nn.SiLU(),
1455
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1456
+ )
1457
+ )
1458
+ if self.predict_codebook_ids:
1459
+ self.id_predictor = checkpoint_wrapper_fn(
1460
+ nn.Sequential(
1461
+ normalization(ch),
1462
+ conv_nd(dims, model_channels, n_embed, 1),
1463
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1464
+ )
1465
+ )
1466
+
1467
+ def convert_to_fp16(self):
1468
+ """
1469
+ Convert the torso of the model to float16.
1470
+ """
1471
+ self.input_blocks.apply(convert_module_to_f16)
1472
+ self.middle_block.apply(convert_module_to_f16)
1473
+ self.output_blocks.apply(convert_module_to_f16)
1474
+
1475
+ def convert_to_fp32(self):
1476
+ """
1477
+ Convert the torso of the model to float32.
1478
+ """
1479
+ self.input_blocks.apply(convert_module_to_f32)
1480
+ self.middle_block.apply(convert_module_to_f32)
1481
+ self.output_blocks.apply(convert_module_to_f32)
1482
+
1483
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1484
+ """
1485
+ Apply the model to an input batch.
1486
+ :param x: an [N x C x ...] Tensor of inputs.
1487
+ :param timesteps: a 1-D batch of timesteps.
1488
+ :param context: conditioning plugged in via crossattn
1489
+ :param y: an [N] Tensor of labels, if class-conditional.
1490
+ :return: an [N x C x ...] Tensor of outputs.
1491
+ """
1492
+ assert (y is not None) == (
1493
+ self.num_classes is not None
1494
+ ), "must specify y if and only if the model is class-conditional"
1495
+ hs = []
1496
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
1497
+ emb = self.time_embed(t_emb)
1498
+
1499
+ if self.num_classes is not None:
1500
+ assert y.shape[0] == x.shape[0]
1501
+ emb = emb + self.label_emb(y)
1502
+
1503
+ # h = x.type(self.dtype)
1504
+ h = x
1505
+ for i, module in enumerate(self.input_blocks):
1506
+ h = module(h, emb, context)
1507
+ hs.append(h)
1508
+ h = self.middle_block(h, emb, context)
1509
+ for i, module in enumerate(self.output_blocks):
1510
+ h = th.cat([h, hs.pop()], dim=1)
1511
+ h = module(h, emb, context)
1512
+ h = h.type(x.dtype)
1513
+ if self.predict_codebook_ids:
1514
+ assert False, "not supported anymore. what the f*** are you doing?"
1515
+ else:
1516
+ return self.out(h)
1517
+
1518
+
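The UNetModel above is the standard attention-plus-timestep-embedding denoising UNet its docstring describes. A minimal construction sketch, using deliberately tiny, hypothetical hyperparameters (the repository's real settings live in its YAML configs, not here):

import torch
from sgm.modules.diffusionmodules.openaimodel import UNetModel

# Hypothetical, deliberately small hyperparameters for illustration only.
unet = UNetModel(
    in_channels=4,
    model_channels=64,
    out_channels=4,
    num_res_blocks=2,
    attention_resolutions=[1, 2],
    channel_mult=(1, 2),
    num_head_channels=32,
    use_spatial_transformer=True,
    transformer_depth=1,
    context_dim=64,
    legacy=False,
)

x = torch.randn(1, 4, 32, 32)                 # latent batch
t = torch.randint(0, 1000, (1,))              # diffusion timesteps
ctx = torch.randn(1, 77, 64)                  # cross-attention conditioning
out = unet(x, timesteps=t, context=ctx)       # same shape as x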
1519
+ import seaborn as sns
1520
+ import matplotlib.pyplot as plt
1521
+
1522
+ class UNetAddModel(nn.Module):
1523
+
1524
+ def __init__(
1525
+ self,
1526
+ in_channels,
1527
+ ctrl_channels,
1528
+ model_channels,
1529
+ out_channels,
1530
+ num_res_blocks,
1531
+ attention_resolutions,
1532
+ dropout=0,
1533
+ channel_mult=(1, 2, 4, 8),
1534
+ attn_type="attn2",
1535
+ attn_layers=[],
1536
+ conv_resample=True,
1537
+ dims=2,
1538
+ num_classes=None,
1539
+ use_checkpoint=False,
1540
+ use_fp16=False,
1541
+ num_heads=-1,
1542
+ num_head_channels=-1,
1543
+ num_heads_upsample=-1,
1544
+ use_scale_shift_norm=False,
1545
+ resblock_updown=False,
1546
+ use_new_attention_order=False,
1547
+ use_spatial_transformer=False, # custom transformer support
1548
+ transformer_depth=1, # custom transformer support
1549
+ context_dim=None, # custom transformer support
1550
+ add_context_dim=None,
1551
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1552
+ legacy=True,
1553
+ disable_self_attentions=None,
1554
+ num_attention_blocks=None,
1555
+ disable_middle_self_attn=False,
1556
+ use_linear_in_transformer=False,
1557
+ spatial_transformer_attn_type="softmax",
1558
+ adm_in_channels=None,
1559
+ use_fairscale_checkpoint=False,
1560
+ offload_to_cpu=False,
1561
+ transformer_depth_middle=None,
1562
+ ):
1563
+ super().__init__()
1564
+ from omegaconf.listconfig import ListConfig
1565
+
1566
+ if use_spatial_transformer:
1567
+ assert (
1568
+ context_dim is not None
1569
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1570
+
1571
+ if context_dim is not None:
1572
+ assert (
1573
+ use_spatial_transformer
1574
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1575
+ if type(context_dim) == ListConfig:
1576
+ context_dim = list(context_dim)
1577
+
1578
+ if num_heads_upsample == -1:
1579
+ num_heads_upsample = num_heads
1580
+
1581
+ if num_heads == -1:
1582
+ assert (
1583
+ num_head_channels != -1
1584
+ ), "Either num_heads or num_head_channels has to be set"
1585
+
1586
+ if num_head_channels == -1:
1587
+ assert (
1588
+ num_heads != -1
1589
+ ), "Either num_heads or num_head_channels has to be set"
1590
+
1591
+ self.in_channels = in_channels
1592
+ self.ctrl_channels = ctrl_channels
1593
+ self.model_channels = model_channels
1594
+ self.out_channels = out_channels
1595
+ if isinstance(transformer_depth, int):
1596
+ transformer_depth = len(channel_mult) * [transformer_depth]
1597
+ elif isinstance(transformer_depth, ListConfig):
1598
+ transformer_depth = list(transformer_depth)
1599
+ transformer_depth_middle = default(
1600
+ transformer_depth_middle, transformer_depth[-1]
1601
+ )
1602
+
1603
+ if isinstance(num_res_blocks, int):
1604
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1605
+ else:
1606
+ if len(num_res_blocks) != len(channel_mult):
1607
+ raise ValueError(
1608
+ "provide num_res_blocks either as an int (globally constant) or "
1609
+ "as a list/tuple (per-level) with the same length as channel_mult"
1610
+ )
1611
+ self.num_res_blocks = num_res_blocks
1612
+ # self.num_res_blocks = num_res_blocks
1613
+ if disable_self_attentions is not None:
1614
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1615
+ assert len(disable_self_attentions) == len(channel_mult)
1616
+ if num_attention_blocks is not None:
1617
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
1618
+ assert all(
1619
+ map(
1620
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1621
+ range(len(num_attention_blocks)),
1622
+ )
1623
+ )
1624
+ print(
1625
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1626
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1627
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1628
+ f"attention will still not be set."
1629
+ ) # todo: convert to warning
1630
+
1631
+ self.attention_resolutions = attention_resolutions
1632
+ self.dropout = dropout
1633
+ self.channel_mult = channel_mult
1634
+ self.conv_resample = conv_resample
1635
+ self.num_classes = num_classes
1636
+ self.use_checkpoint = use_checkpoint
1637
+ if use_fp16:
1638
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
1639
+ # self.dtype = th.float16 if use_fp16 else th.float32
1640
+ self.num_heads = num_heads
1641
+ self.num_head_channels = num_head_channels
1642
+ self.num_heads_upsample = num_heads_upsample
1643
+ self.predict_codebook_ids = n_embed is not None
1644
+
1645
+ assert use_fairscale_checkpoint != use_checkpoint or not (
1646
+ use_checkpoint or use_fairscale_checkpoint
1647
+ )
1648
+
1649
+ self.use_fairscale_checkpoint = False
1650
+ checkpoint_wrapper_fn = (
1651
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1652
+ if self.use_fairscale_checkpoint
1653
+ else lambda x: x
1654
+ )
1655
+
1656
+ time_embed_dim = model_channels * 4
1657
+ self.time_embed = checkpoint_wrapper_fn(
1658
+ nn.Sequential(
1659
+ linear(model_channels, time_embed_dim),
1660
+ nn.SiLU(),
1661
+ linear(time_embed_dim, time_embed_dim),
1662
+ )
1663
+ )
1664
+
1665
+ if self.num_classes is not None:
1666
+ if isinstance(self.num_classes, int):
1667
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1668
+ elif self.num_classes == "continuous":
1669
+ print("setting up linear c_adm embedding layer")
1670
+ self.label_emb = nn.Linear(1, time_embed_dim)
1671
+ elif self.num_classes == "timestep":
1672
+ self.label_emb = checkpoint_wrapper_fn(
1673
+ nn.Sequential(
1674
+ Timestep(model_channels),
1675
+ nn.Sequential(
1676
+ linear(model_channels, time_embed_dim),
1677
+ nn.SiLU(),
1678
+ linear(time_embed_dim, time_embed_dim),
1679
+ ),
1680
+ )
1681
+ )
1682
+ elif self.num_classes == "sequential":
1683
+ assert adm_in_channels is not None
1684
+ self.label_emb = nn.Sequential(
1685
+ nn.Sequential(
1686
+ linear(adm_in_channels, time_embed_dim),
1687
+ nn.SiLU(),
1688
+ linear(time_embed_dim, time_embed_dim),
1689
+ )
1690
+ )
1691
+ else:
1692
+ raise ValueError()
1693
+
1694
+ self.input_blocks = nn.ModuleList(
1695
+ [
1696
+ TimestepEmbedSequential(
1697
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1698
+ )
1699
+ ]
1700
+ )
1701
+ if self.ctrl_channels > 0:
1702
+ self.add_input_block = TimestepEmbedSequential(
1703
+ conv_nd(dims, ctrl_channels, 16, 3, padding=1),
1704
+ nn.SiLU(),
1705
+ conv_nd(dims, 16, 16, 3, padding=1),
1706
+ nn.SiLU(),
1707
+ conv_nd(dims, 16, 32, 3, padding=1),
1708
+ nn.SiLU(),
1709
+ conv_nd(dims, 32, 32, 3, padding=1),
1710
+ nn.SiLU(),
1711
+ conv_nd(dims, 32, 96, 3, padding=1),
1712
+ nn.SiLU(),
1713
+ conv_nd(dims, 96, 96, 3, padding=1),
1714
+ nn.SiLU(),
1715
+ conv_nd(dims, 96, 256, 3, padding=1),
1716
+ nn.SiLU(),
1717
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
1718
+ )
1719
+
1720
+ self._feature_size = model_channels
1721
+ input_block_chans = [model_channels]
1722
+ ch = model_channels
1723
+ ds = 1
1724
+ for level, mult in enumerate(channel_mult):
1725
+ for nr in range(self.num_res_blocks[level]):
1726
+ layers = [
1727
+ checkpoint_wrapper_fn(
1728
+ ResBlock(
1729
+ ch,
1730
+ time_embed_dim,
1731
+ dropout,
1732
+ out_channels=mult * model_channels,
1733
+ dims=dims,
1734
+ use_checkpoint=use_checkpoint,
1735
+ use_scale_shift_norm=use_scale_shift_norm,
1736
+ )
1737
+ )
1738
+ ]
1739
+ ch = mult * model_channels
1740
+ if ds in attention_resolutions:
1741
+ if num_head_channels == -1:
1742
+ dim_head = ch // num_heads
1743
+ else:
1744
+ num_heads = ch // num_head_channels
1745
+ dim_head = num_head_channels
1746
+ if legacy:
1747
+ # num_heads = 1
1748
+ dim_head = (
1749
+ ch // num_heads
1750
+ if use_spatial_transformer
1751
+ else num_head_channels
1752
+ )
1753
+ if exists(disable_self_attentions):
1754
+ disabled_sa = disable_self_attentions[level]
1755
+ else:
1756
+ disabled_sa = False
1757
+
1758
+ if (
1759
+ not exists(num_attention_blocks)
1760
+ or nr < num_attention_blocks[level]
1761
+ ):
1762
+ layers.append(
1763
+ checkpoint_wrapper_fn(
1764
+ AttentionBlock(
1765
+ ch,
1766
+ use_checkpoint=use_checkpoint,
1767
+ num_heads=num_heads,
1768
+ num_head_channels=dim_head,
1769
+ use_new_attention_order=use_new_attention_order,
1770
+ )
1771
+ )
1772
+ if not use_spatial_transformer
1773
+ else checkpoint_wrapper_fn(
1774
+ SpatialTransformer(
1775
+ ch,
1776
+ num_heads,
1777
+ dim_head,
1778
+ depth=transformer_depth[level],
1779
+ context_dim=context_dim,
1780
+ add_context_dim=add_context_dim,
1781
+ disable_self_attn=disabled_sa,
1782
+ use_linear=use_linear_in_transformer,
1783
+ attn_type=spatial_transformer_attn_type,
1784
+ use_checkpoint=use_checkpoint,
1785
+ )
1786
+ )
1787
+ )
1788
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1789
+ self._feature_size += ch
1790
+ input_block_chans.append(ch)
1791
+ if level != len(channel_mult) - 1:
1792
+ out_ch = ch
1793
+ self.input_blocks.append(
1794
+ TimestepEmbedSequential(
1795
+ checkpoint_wrapper_fn(
1796
+ ResBlock(
1797
+ ch,
1798
+ time_embed_dim,
1799
+ dropout,
1800
+ out_channels=out_ch,
1801
+ dims=dims,
1802
+ use_checkpoint=use_checkpoint,
1803
+ use_scale_shift_norm=use_scale_shift_norm,
1804
+ down=True,
1805
+ )
1806
+ )
1807
+ if resblock_updown
1808
+ else Downsample(
1809
+ ch, conv_resample, dims=dims, out_channels=out_ch
1810
+ )
1811
+ )
1812
+ )
1813
+ ch = out_ch
1814
+ input_block_chans.append(ch)
1815
+ ds *= 2
1816
+ self._feature_size += ch
1817
+
1818
+ if num_head_channels == -1:
1819
+ dim_head = ch // num_heads
1820
+ else:
1821
+ num_heads = ch // num_head_channels
1822
+ dim_head = num_head_channels
1823
+ if legacy:
1824
+ # num_heads = 1
1825
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1826
+ self.middle_block = TimestepEmbedSequential(
1827
+ checkpoint_wrapper_fn(
1828
+ ResBlock(
1829
+ ch,
1830
+ time_embed_dim,
1831
+ dropout,
1832
+ dims=dims,
1833
+ use_checkpoint=use_checkpoint,
1834
+ use_scale_shift_norm=use_scale_shift_norm,
1835
+ )
1836
+ ),
1837
+ checkpoint_wrapper_fn(
1838
+ AttentionBlock(
1839
+ ch,
1840
+ use_checkpoint=use_checkpoint,
1841
+ num_heads=num_heads,
1842
+ num_head_channels=dim_head,
1843
+ use_new_attention_order=use_new_attention_order,
1844
+ )
1845
+ )
1846
+ if not use_spatial_transformer
1847
+ else checkpoint_wrapper_fn(
1848
+ SpatialTransformer( # always uses a self-attn
1849
+ ch,
1850
+ num_heads,
1851
+ dim_head,
1852
+ depth=transformer_depth_middle,
1853
+ context_dim=context_dim,
1854
+ add_context_dim=add_context_dim,
1855
+ disable_self_attn=disable_middle_self_attn,
1856
+ use_linear=use_linear_in_transformer,
1857
+ attn_type=spatial_transformer_attn_type,
1858
+ use_checkpoint=use_checkpoint,
1859
+ )
1860
+ ),
1861
+ checkpoint_wrapper_fn(
1862
+ ResBlock(
1863
+ ch,
1864
+ time_embed_dim,
1865
+ dropout,
1866
+ dims=dims,
1867
+ use_checkpoint=use_checkpoint,
1868
+ use_scale_shift_norm=use_scale_shift_norm,
1869
+ )
1870
+ ),
1871
+ )
1872
+ self._feature_size += ch
1873
+
1874
+ self.output_blocks = nn.ModuleList([])
1875
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1876
+ for i in range(self.num_res_blocks[level] + 1):
1877
+ ich = input_block_chans.pop()
1878
+ layers = [
1879
+ checkpoint_wrapper_fn(
1880
+ ResBlock(
1881
+ ch + ich,
1882
+ time_embed_dim,
1883
+ dropout,
1884
+ out_channels=model_channels * mult,
1885
+ dims=dims,
1886
+ use_checkpoint=use_checkpoint,
1887
+ use_scale_shift_norm=use_scale_shift_norm,
1888
+ )
1889
+ )
1890
+ ]
1891
+ ch = model_channels * mult
1892
+ if ds in attention_resolutions:
1893
+ if num_head_channels == -1:
1894
+ dim_head = ch // num_heads
1895
+ else:
1896
+ num_heads = ch // num_head_channels
1897
+ dim_head = num_head_channels
1898
+ if legacy:
1899
+ # num_heads = 1
1900
+ dim_head = (
1901
+ ch // num_heads
1902
+ if use_spatial_transformer
1903
+ else num_head_channels
1904
+ )
1905
+ if exists(disable_self_attentions):
1906
+ disabled_sa = disable_self_attentions[level]
1907
+ else:
1908
+ disabled_sa = False
1909
+
1910
+ if (
1911
+ not exists(num_attention_blocks)
1912
+ or i < num_attention_blocks[level]
1913
+ ):
1914
+ layers.append(
1915
+ checkpoint_wrapper_fn(
1916
+ AttentionBlock(
1917
+ ch,
1918
+ use_checkpoint=use_checkpoint,
1919
+ num_heads=num_heads_upsample,
1920
+ num_head_channels=dim_head,
1921
+ use_new_attention_order=use_new_attention_order,
1922
+ )
1923
+ )
1924
+ if not use_spatial_transformer
1925
+ else checkpoint_wrapper_fn(
1926
+ SpatialTransformer(
1927
+ ch,
1928
+ num_heads,
1929
+ dim_head,
1930
+ depth=transformer_depth[level],
1931
+ context_dim=context_dim,
1932
+ add_context_dim=add_context_dim,
1933
+ disable_self_attn=disabled_sa,
1934
+ use_linear=use_linear_in_transformer,
1935
+ attn_type=spatial_transformer_attn_type,
1936
+ use_checkpoint=use_checkpoint,
1937
+ )
1938
+ )
1939
+ )
1940
+ if level and i == self.num_res_blocks[level]:
1941
+ out_ch = ch
1942
+ layers.append(
1943
+ checkpoint_wrapper_fn(
1944
+ ResBlock(
1945
+ ch,
1946
+ time_embed_dim,
1947
+ dropout,
1948
+ out_channels=out_ch,
1949
+ dims=dims,
1950
+ use_checkpoint=use_checkpoint,
1951
+ use_scale_shift_norm=use_scale_shift_norm,
1952
+ up=True,
1953
+ )
1954
+ )
1955
+ if resblock_updown
1956
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1957
+ )
1958
+ ds //= 2
1959
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1960
+ self._feature_size += ch
1961
+
1962
+ self.out = checkpoint_wrapper_fn(
1963
+ nn.Sequential(
1964
+ normalization(ch),
1965
+ nn.SiLU(),
1966
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1967
+ )
1968
+ )
1969
+ if self.predict_codebook_ids:
1970
+ self.id_predictor = checkpoint_wrapper_fn(
1971
+ nn.Sequential(
1972
+ normalization(ch),
1973
+ conv_nd(dims, model_channels, n_embed, 1),
1974
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1975
+ )
1976
+ )
1977
+
1978
+ # cache attn map
1979
+ self.attn_type = attn_type
1980
+ self.attn_layers = attn_layers
1981
+ self.attn_map_cache = []
1982
+ for name, module in self.named_modules():
1983
+ if name.endswith(self.attn_type):
1984
+ item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
1985
+ self.attn_map_cache.append(item)
1986
+ module.attn_map_cache = item
1987
+
1988
+ def clear_attn_map(self):
1989
+
1990
+ for item in self.attn_map_cache:
1991
+ if item["attn_map"] is not None:
1992
+ del item["attn_map"]
1993
+ item["attn_map"] = None
1994
+
1995
+ def save_attn_map(self, save_name="temp", tokens=""):
1996
+
1997
+ attn_maps = []
1998
+ for item in self.attn_map_cache:
1999
+ name = item["name"]
2000
+ if any([name.startswith(block) for block in self.attn_layers]):
2001
+ heads = item["heads"]
2002
+ attn_maps.append(item["attn_map"].detach().cpu())
2003
+
2004
+ attn_map = th.stack(attn_maps, dim=0)
2005
+ attn_map = th.mean(attn_map, dim=0)
2006
+
2007
+ # attn_map: bh * n * l
2008
+ bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
2009
+ attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
2010
+ b = attn_map.shape[0]
2011
+
2012
+ h = w = int(n**0.5)
2013
+ attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
2014
+
2015
+ attn_map_i = attn_map[-1]
2016
+
2017
+ l = attn_map_i.shape[0]
2018
+ fig = plt.figure(figsize=(12, 8), dpi=300)
2019
+ for j in range(12):
2020
+ if j >= l: break
2021
+ ax = fig.add_subplot(3, 4, j+1)
2022
+ sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
2023
+ if j < len(tokens):
2024
+ ax.set_title(tokens[j])
2025
+ fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
2026
+ plt.close()
2027
+
2028
+ return attn_map_i
2029
+
2030
+ def forward(self, x, timesteps=None, context=None, add_context=None, y=None, **kwargs):
2031
+ """
2032
+ Apply the model to an input batch.
2033
+ :param x: an [N x C x ...] Tensor of inputs.
2034
+ :param timesteps: a 1-D batch of timesteps.
2035
+ :param context: conditioning plugged in via crossattn
2036
+ :param y: an [N] Tensor of labels, if class-conditional.
2037
+ :return: an [N x C x ...] Tensor of outputs.
2038
+ """
2039
+ assert (y is not None) == (
2040
+ self.num_classes is not None
2041
+ ), "must specify y if and only if the model is class-conditional"
2042
+
2043
+ self.clear_attn_map()
2044
+
2045
+ hs = []
2046
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
2047
+ emb = self.time_embed(t_emb)
2048
+
2049
+ if self.num_classes is not None:
2050
+ assert y.shape[0] == x.shape[0]
2051
+ emb = emb + self.label_emb(y)
2052
+
2053
+ # h = x.type(self.dtype)
2054
+ h = x
2055
+ if self.ctrl_channels > 0:
2056
+ in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
2057
+
2058
+ for i, module in enumerate(self.input_blocks):
2059
+ if self.ctrl_channels > 0 and i == 0:
2060
+ h = module(in_h, emb, context, add_context) + self.add_input_block(add_h, emb, context, add_context)
2061
+ else:
2062
+ h = module(h, emb, context, add_context)
2063
+ hs.append(h)
2064
+ h = self.middle_block(h, emb, context, add_context)
2065
+ for i, module in enumerate(self.output_blocks):
2066
+ h = th.cat([h, hs.pop()], dim=1)
2067
+ h = module(h, emb, context, add_context)
2068
+ h = h.type(x.dtype)
2069
+
2070
+ return self.out(h)
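When ctrl_channels > 0, UNetAddModel.forward splits its input into the latent part and the control part, pushes the control part through add_input_block (which ends in a zero-initialized convolution), and adds the result to the output of the first input block. A minimal sketch of that merge, with hypothetical channel counts chosen only for illustration:

import torch
import torch.nn as nn

in_channels, ctrl_channels, model_channels = 4, 1, 8

first_block = nn.Conv2d(in_channels, model_channels, 3, padding=1)
add_block = nn.Conv2d(ctrl_channels, model_channels, 3, padding=1)
nn.init.zeros_(add_block.weight)   # mirrors zero_module(conv_nd(...))
nn.init.zeros_(add_block.bias)

x = torch.randn(2, in_channels + ctrl_channels, 64, 64)
in_h, add_h = torch.split(x, [in_channels, ctrl_channels], dim=1)

# At initialization the zeroed control conv contributes nothing, so the
# merged feature equals the plain UNet feature until the branch is trained.
h = first_block(in_h) + add_block(add_h)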
sgm/modules/diffusionmodules/sampling.py ADDED
@@ -0,0 +1,784 @@
1
+ """
2
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
+ """
4
+
5
+
6
+ from typing import Dict, Union
7
+
8
+ import imageio
9
+ import torch
10
+ import json
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from omegaconf import ListConfig, OmegaConf
14
+ from tqdm import tqdm
15
+
16
+ from ...modules.diffusionmodules.sampling_utils import (
17
+ get_ancestral_step,
18
+ linear_multistep_coeff,
19
+ to_d,
20
+ to_neg_log_sigma,
21
+ to_sigma,
22
+ )
23
+ from ...util import append_dims, default, instantiate_from_config
24
+ from torchvision.utils import save_image
25
+
26
+ DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
27
+
28
+
29
+ class BaseDiffusionSampler:
30
+ def __init__(
31
+ self,
32
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
33
+ num_steps: Union[int, None] = None,
34
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
35
+ verbose: bool = False,
36
+ device: str = "cuda",
37
+ ):
38
+ self.num_steps = num_steps
39
+ self.discretization = instantiate_from_config(discretization_config)
40
+ self.guider = instantiate_from_config(
41
+ default(
42
+ guider_config,
43
+ DEFAULT_GUIDER,
44
+ )
45
+ )
46
+ self.verbose = verbose
47
+ self.device = device
48
+
49
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
50
+ sigmas = self.discretization(
51
+ self.num_steps if num_steps is None else num_steps, device=self.device
52
+ )
53
+ uc = default(uc, cond)
54
+
55
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
56
+ num_sigmas = len(sigmas)
57
+
58
+ s_in = x.new_ones([x.shape[0]])
59
+
60
+ return x, s_in, sigmas, num_sigmas, cond, uc
61
+
62
+ def denoise(self, x, model, sigma, cond, uc):
63
+ denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc))
64
+ denoised = self.guider(denoised, sigma)
65
+ return denoised
66
+
67
+ def get_sigma_gen(self, num_sigmas, init_step=0):
68
+ sigma_generator = range(init_step, num_sigmas - 1)
69
+ if self.verbose:
70
+ print("#" * 30, " Sampling setting ", "#" * 30)
71
+ print(f"Sampler: {self.__class__.__name__}")
72
+ print(f"Discretization: {self.discretization.__class__.__name__}")
73
+ print(f"Guider: {self.guider.__class__.__name__}")
74
+ sigma_generator = tqdm(
75
+ sigma_generator,
76
+ total=num_sigmas-1-init_step,
77
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas-1-init_step} steps",
78
+ )
79
+ return sigma_generator
80
+
81
+
82
+ class SingleStepDiffusionSampler(BaseDiffusionSampler):
83
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
84
+ raise NotImplementedError
85
+
86
+ def euler_step(self, x, d, dt):
87
+ return x + dt * d
88
+
89
+
90
+ class EDMSampler(SingleStepDiffusionSampler):
91
+ def __init__(
92
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
93
+ ):
94
+ super().__init__(*args, **kwargs)
95
+
96
+ self.s_churn = s_churn
97
+ self.s_tmin = s_tmin
98
+ self.s_tmax = s_tmax
99
+ self.s_noise = s_noise
100
+
101
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
102
+ sigma_hat = sigma * (gamma + 1.0)
103
+ if gamma > 0:
104
+ eps = torch.randn_like(x) * self.s_noise
105
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
106
+
107
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
108
+ d = to_d(x, sigma_hat, denoised)
109
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
110
+
111
+ euler_step = self.euler_step(x, d, dt)
112
+ x = self.possible_correction_step(
113
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
114
+ )
115
+ return x
116
+
117
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
118
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
119
+ x, cond, uc, num_steps
120
+ )
121
+
122
+ for i in self.get_sigma_gen(num_sigmas):
123
+ gamma = (
124
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
125
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
126
+ else 0.0
127
+ )
128
+ x = self.sampler_step(
129
+ s_in * sigmas[i],
130
+ s_in * sigmas[i + 1],
131
+ denoiser,
132
+ x,
133
+ cond,
134
+ uc,
135
+ gamma,
136
+ )
137
+
138
+ return x
139
+
140
+
141
+ class AncestralSampler(SingleStepDiffusionSampler):
142
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
143
+ super().__init__(*args, **kwargs)
144
+
145
+ self.eta = eta
146
+ self.s_noise = s_noise
147
+ self.noise_sampler = lambda x: torch.randn_like(x)
148
+
149
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
150
+ d = to_d(x, sigma, denoised)
151
+ dt = append_dims(sigma_down - sigma, x.ndim)
152
+
153
+ return self.euler_step(x, d, dt)
154
+
155
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
156
+ x = torch.where(
157
+ append_dims(next_sigma, x.ndim) > 0.0,
158
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
159
+ x,
160
+ )
161
+ return x
162
+
163
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
164
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
165
+ x, cond, uc, num_steps
166
+ )
167
+
168
+ for i in self.get_sigma_gen(num_sigmas):
169
+ x = self.sampler_step(
170
+ s_in * sigmas[i],
171
+ s_in * sigmas[i + 1],
172
+ denoiser,
173
+ x,
174
+ cond,
175
+ uc,
176
+ )
177
+
178
+ return x
179
+
180
+
181
+ class LinearMultistepSampler(BaseDiffusionSampler):
182
+ def __init__(
183
+ self,
184
+ order=4,
185
+ *args,
186
+ **kwargs,
187
+ ):
188
+ super().__init__(*args, **kwargs)
189
+
190
+ self.order = order
191
+
192
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
193
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
194
+ x, cond, uc, num_steps
195
+ )
196
+
197
+ ds = []
198
+ sigmas_cpu = sigmas.detach().cpu().numpy()
199
+ for i in self.get_sigma_gen(num_sigmas):
200
+ sigma = s_in * sigmas[i]
201
+ denoised = denoiser(
202
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
203
+ )
204
+ denoised = self.guider(denoised, sigma)
205
+ d = to_d(x, sigma, denoised)
206
+ ds.append(d)
207
+ if len(ds) > self.order:
208
+ ds.pop(0)
209
+ cur_order = min(i + 1, self.order)
210
+ coeffs = [
211
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
212
+ for j in range(cur_order)
213
+ ]
214
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
215
+
216
+ return x
217
+
218
+
219
+ class EulerEDMSampler(EDMSampler):
220
+
221
+ def possible_correction_step(
222
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
223
+ ):
224
+ return euler_step
225
+
226
+ def get_c_noise(self, x, model, sigma):
227
+ sigma = model.denoiser.possibly_quantize_sigma(sigma)
228
+ sigma_shape = sigma.shape
229
+ sigma = append_dims(sigma, x.ndim)
230
+ c_skip, c_out, c_in, c_noise = model.denoiser.scaling(sigma)
231
+ c_noise = model.denoiser.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
232
+ return c_noise
233
+
234
+ def attend_and_excite(self, x, model, sigma, cond, batch, alpha, iter_enabled, thres, max_iter=20):
235
+
236
+ # calc timestep
237
+ c_noise = self.get_c_noise(x, model, sigma)
238
+
239
+ x = x.clone().detach().requires_grad_(True) # https://github.com/yuval-alaluf/Attend-and-Excite/blob/main/pipeline_attend_and_excite.py#L288
240
+
241
+ iters = 0
242
+ while True:
243
+
244
+ model_output = model.model(x, c_noise, cond)
245
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
246
+ grad = torch.autograd.grad(local_loss.requires_grad_(True), [x], retain_graph=True)[0]
247
+ x = x - alpha * grad
248
+ iters += 1
249
+
250
+ if not iter_enabled or local_loss <= thres or iters > max_iter:
251
+ break
252
+
253
+ return x
254
+
255
+ def create_pascal_label_colormap(self):
256
+ """
257
+ Create the label colormap used by the PASCAL VOC segmentation dataset.
+
+ Returns:
+ A colormap for visualizing the segmentation results.
261
+ """
262
+ colormap = np.zeros((256, 3), dtype=int)
263
+ ind = np.arange(256, dtype=int)
264
+
265
+ for shift in reversed(range(8)):
266
+ for channel in range(3):
267
+ colormap[:, channel] |= ((ind >> channel) & 1) << shift
268
+ ind >>= 3
269
+
270
+ return colormap
271
+
272
+ def save_segment_map(self, image, attn_maps, tokens=None, save_name=None):
273
+
274
+ colormap = self.create_pascal_label_colormap()
275
+ H, W = image.shape[-2:]
276
+
277
+ image_ = image*0.3
278
+ sections = []
279
+ for i in range(len(tokens)):
280
+ attn_map = attn_maps[i]
281
+ attn_map_t = np.tile(attn_map[None], (1,3,1,1)) # b, 3, h, w
282
+ attn_map_t = torch.from_numpy(attn_map_t)
283
+ attn_map_t = F.interpolate(attn_map_t, (H, W))  # F.interpolate expects (out_H, out_W)
284
+
285
+ color = torch.from_numpy(colormap[i+1][None,:,None,None] / 255.0)
286
+ colored_attn_map = attn_map_t * color
287
+ colored_attn_map = colored_attn_map.to(device=image_.device)
288
+
289
+ image_ += colored_attn_map*0.7
290
+ sections.append(attn_map)
291
+
292
+ section = np.stack(sections)
293
+ np.save(f"temp/seg_map/seg_{save_name}.npy", section)
294
+
295
+ save_image(image_, f"temp/seg_map/seg_{save_name}.png", normalize=True)
296
+
297
+ def get_init_noise(self, cfgs, model, cond, batch, uc=None):
298
+
299
+ H, W = batch["target_size_as_tuple"][0]
300
+ shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
301
+
302
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
303
+ x = randn.clone()
304
+
305
+ xs = []
306
+ self.verbose = False
307
+ for _ in range(cfgs.noise_iters):
308
+
309
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
310
+ x, cond, uc, num_steps=2
311
+ )
312
+
313
+ superv = {
314
+ "mask": batch["mask"] if "mask" in batch else None,
315
+ "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
316
+ }
317
+
318
+ local_losses = []
319
+
320
+ for i in self.get_sigma_gen(num_sigmas):
321
+
322
+ gamma = (
323
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
324
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
325
+ else 0.0
326
+ )
327
+
328
+ x, inter, local_loss = self.sampler_step(
329
+ s_in * sigmas[i],
330
+ s_in * sigmas[i + 1],
331
+ model,
332
+ x,
333
+ cond,
334
+ superv,
335
+ uc,
336
+ gamma,
337
+ save_loss=True
338
+ )
339
+
340
+ local_losses.append(local_loss.item())
341
+
342
+ xs.append((randn, local_losses[-1]))
343
+
344
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
345
+ x = randn.clone()
346
+
347
+ self.verbose = True
348
+
349
+ xs.sort(key = lambda x: x[-1])
350
+
351
+ if len(xs) > 0:
352
+ print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
353
+ x = xs[0][0]
354
+
355
+ return x
356
+
357
+ def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc=None,
358
+ gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
359
+ name=None, save_loss=False, save_attn=False, save_inter=False):
360
+
361
+ sigma_hat = sigma * (gamma + 1.0)
362
+ if gamma > 0:
363
+ eps = torch.randn_like(x) * self.s_noise
364
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
365
+
366
+ if update:
367
+ x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
368
+
369
+ denoised = self.denoise(x, model, sigma_hat, cond, uc)
370
+ denoised_decode = model.decode_first_stage(denoised) if save_inter else None
371
+
372
+ if save_loss:
373
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
374
+ local_loss = local_loss[local_loss.shape[0]//2:]
375
+ else:
376
+ local_loss = torch.zeros(1)
377
+ if save_attn:
378
+ attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
379
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
380
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
381
+
382
+ d = to_d(x, sigma_hat, denoised)
383
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
384
+
385
+ euler_step = self.euler_step(x, d, dt)
386
+
387
+ return euler_step, denoised_decode, local_loss
388
+
389
+ def __call__(self, model, x, cond, batch=None, uc=None, num_steps=None, init_step=0,
390
+ name=None, aae_enabled=False, detailed=False):
391
+
392
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
393
+ x, cond, uc, num_steps
394
+ )
395
+
396
+ name = batch["name"][0]
397
+ inters = []
398
+ local_losses = []
399
+ scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
400
+ iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
401
+ thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
402
+
403
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
404
+
405
+ gamma = (
406
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
407
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
408
+ else 0.0
409
+ )
410
+
411
+ alpha = 20 * np.sqrt(scales[i])
412
+ update = aae_enabled
413
+ save_loss = detailed
414
+ save_attn = detailed and (i == (num_sigmas-1)//2)
415
+ save_inter = detailed
416
+
417
+ if i in iter_lst:
418
+ iter_enabled = True
419
+ thres = thres_lst[list(iter_lst).index(i)]
420
+ else:
421
+ iter_enabled = False
422
+ thres = 0.0
423
+
424
+ x, inter, local_loss = self.sampler_step(
425
+ s_in * sigmas[i],
426
+ s_in * sigmas[i + 1],
427
+ model,
428
+ x,
429
+ cond,
430
+ batch,
431
+ uc,
432
+ gamma,
433
+ alpha=alpha,
434
+ iter_enabled=iter_enabled,
435
+ thres=thres,
436
+ update=update,
437
+ name=name,
438
+ save_loss=save_loss,
439
+ save_attn=save_attn,
440
+ save_inter=save_inter
441
+ )
442
+
443
+ local_losses.append(local_loss.item())
444
+ if inter is not None:
445
+ inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
446
+ inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
447
+ inters.append(inter.astype(np.uint8))
448
+
449
+ print(f"Local losses: {local_losses}")
450
+
451
+ if len(inters) > 0:
452
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
453
+
454
+ return x
455
+
456
+
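The attend_and_excite routine in EulerEDMSampler optimizes the latent itself: before a denoising step it repeatedly backpropagates a local attention loss to the latent and takes gradient-descent steps until the loss falls below a threshold or an iteration cap is reached. A hedged sketch of that inner loop, where loss_from_attention is a hypothetical stand-in for model.loss_fn.get_min_local_loss over the cached attention maps:

import torch

def refine_latent(x, loss_from_attention, alpha=0.1, thres=1e-3, max_iter=20):
    # Gradient-descent refinement of the latent against an attention-based loss.
    x = x.clone().detach().requires_grad_(True)
    for _ in range(max_iter):
        loss = loss_from_attention(x)
        grad = torch.autograd.grad(loss, [x], retain_graph=True)[0]
        x = x - alpha * grad
        if loss.item() <= thres:
            break
    return x.detach()

# Toy usage: shrink the latent's squared mean activation.
x0 = torch.randn(1, 4, 8, 8)
x1 = refine_latent(x0, lambda z: z.pow(2).mean(), alpha=0.5)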
457
+ class EulerEDMDualSampler(EulerEDMSampler):
458
+
459
+ def prepare_sampling_loop(self, x, cond, uc_1=None, uc_2=None, num_steps=None):
460
+ sigmas = self.discretization(
461
+ self.num_steps if num_steps is None else num_steps, device=self.device
462
+ )
463
+ uc_1 = default(uc_1, cond)
464
+ uc_2 = default(uc_2, cond)
465
+
466
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
467
+ num_sigmas = len(sigmas)
468
+
469
+ s_in = x.new_ones([x.shape[0]])
470
+
471
+ return x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2
472
+
473
+ def denoise(self, x, model, sigma, cond, uc_1, uc_2):
474
+ denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc_1, uc_2))
475
+ denoised = self.guider(denoised, sigma)
476
+ return denoised
477
+
478
+ def get_init_noise(self, cfgs, model, cond, batch, uc_1=None, uc_2=None):
479
+
480
+ H, W = batch["target_size_as_tuple"][0]
481
+ shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
482
+
483
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
484
+ x = randn.clone()
485
+
486
+ xs = []
487
+ self.verbose = False
488
+ for _ in range(cfgs.noise_iters):
489
+
490
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
491
+ x, cond, uc_1, uc_2, num_steps=2
492
+ )
493
+
494
+ superv = {
495
+ "mask": batch["mask"] if "mask" in batch else None,
496
+ "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
497
+ }
498
+
499
+ local_losses = []
500
+
501
+ for i in self.get_sigma_gen(num_sigmas):
502
+
503
+ gamma = (
504
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
505
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
506
+ else 0.0
507
+ )
508
+
509
+ x, inter, local_loss = self.sampler_step(
510
+ s_in * sigmas[i],
511
+ s_in * sigmas[i + 1],
512
+ model,
513
+ x,
514
+ cond,
515
+ superv,
516
+ uc_1,
517
+ uc_2,
518
+ gamma,
519
+ save_loss=True
520
+ )
521
+
522
+ local_losses.append(local_loss.item())
523
+
524
+ xs.append((randn, local_losses[-1]))
525
+
526
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
527
+ x = randn.clone()
528
+
529
+ self.verbose = True
530
+
531
+ xs.sort(key = lambda x: x[-1])
532
+
533
+ if len(xs) > 0:
534
+ print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
535
+ x = xs[0][0]
536
+
537
+ return x
538
+
539
+ def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc_1=None, uc_2=None,
540
+ gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
541
+ name=None, save_loss=False, save_attn=False, save_inter=False):
542
+
543
+ sigma_hat = sigma * (gamma + 1.0)
544
+ if gamma > 0:
545
+ eps = torch.randn_like(x) * self.s_noise
546
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
547
+
548
+ if update:
549
+ x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
550
+
551
+ denoised = self.denoise(x, model, sigma_hat, cond, uc_1, uc_2)
552
+ denoised_decode = model.decode_first_stage(denoised) if save_inter else None
553
+
554
+ if save_loss:
555
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
556
+ local_loss = local_loss[-local_loss.shape[0]//3:]
557
+ else:
558
+ local_loss = torch.zeros(1)
559
+ if save_attn:
560
+ attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])  # save_attn_map has no save_single argument
561
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
562
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
563
+
564
+ d = to_d(x, sigma_hat, denoised)
565
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
566
+
567
+ euler_step = self.euler_step(x, d, dt)
568
+
569
+ return euler_step, denoised_decode, local_loss
570
+
571
+ def __call__(self, model, x, cond, batch=None, uc_1=None, uc_2=None, num_steps=None, init_step=0,
572
+ name=None, aae_enabled=False, detailed=False):
573
+
574
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
575
+ x, cond, uc_1, uc_2, num_steps
576
+ )
577
+
578
+ name = batch["name"][0]
579
+ inters = []
580
+ local_losses = []
581
+ scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
582
+ iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
583
+ thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
584
+
585
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
586
+
587
+ gamma = (
588
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
589
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
590
+ else 0.0
591
+ )
592
+
593
+ alpha = 20 * np.sqrt(scales[i])
594
+ update = aae_enabled
595
+ save_loss = aae_enabled
596
+ save_attn = detailed and (i == (num_sigmas-1)//2)
597
+ save_inter = aae_enabled
598
+
599
+ if i in iter_lst:
600
+ iter_enabled = True
601
+ thres = thres_lst[list(iter_lst).index(i)]
602
+ else:
603
+ iter_enabled = False
604
+ thres = 0.0
605
+
606
+ x, inter, local_loss = self.sampler_step(
607
+ s_in * sigmas[i],
608
+ s_in * sigmas[i + 1],
609
+ model,
610
+ x,
611
+ cond,
612
+ batch,
613
+ uc_1,
614
+ uc_2,
615
+ gamma,
616
+ alpha=alpha,
617
+ iter_enabled=iter_enabled,
618
+ thres=thres,
619
+ update=update,
620
+ name=name,
621
+ save_loss=save_loss,
622
+ save_attn=save_attn,
623
+ save_inter=save_inter
624
+ )
625
+
626
+ local_losses.append(local_loss.item())
627
+ if inter is not None:
628
+ inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
629
+ inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
630
+ inters.append(inter.astype(np.uint8))
631
+
632
+ print(f"Local losses: {local_losses}")
633
+
634
+ if len(inters) > 0:
635
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1)
636
+
637
+ return x
638
+
639
+
640
+ class HeunEDMSampler(EDMSampler):
641
+ def possible_correction_step(
642
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
643
+ ):
644
+ if torch.sum(next_sigma) < 1e-14:
645
+ # Save a network evaluation if all noise levels are 0
646
+ return euler_step
647
+ else:
648
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
649
+ d_new = to_d(euler_step, next_sigma, denoised)
650
+ d_prime = (d + d_new) / 2.0
651
+
652
+ # apply correction if noise level is not 0
653
+ x = torch.where(
654
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
655
+ )
656
+ return x
657
+
658
+
659
+ class EulerAncestralSampler(AncestralSampler):
660
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
661
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
662
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
663
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
664
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
665
+
666
+ return x
667
+
668
+
669
+ class DPMPP2SAncestralSampler(AncestralSampler):
670
+ def get_variables(self, sigma, sigma_down):
671
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
672
+ h = t_next - t
673
+ s = t + 0.5 * h
674
+ return h, s, t, t_next
675
+
676
+ def get_mult(self, h, s, t, t_next):
677
+ mult1 = to_sigma(s) / to_sigma(t)
678
+ mult2 = (-0.5 * h).expm1()
679
+ mult3 = to_sigma(t_next) / to_sigma(t)
680
+ mult4 = (-h).expm1()
681
+
682
+ return mult1, mult2, mult3, mult4
683
+
684
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
685
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
686
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
687
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
688
+
689
+ if torch.sum(sigma_down) < 1e-14:
690
+ # Save a network evaluation if all noise levels are 0
691
+ x = x_euler
692
+ else:
693
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
694
+ mult = [
695
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
696
+ ]
697
+
698
+ x2 = mult[0] * x - mult[1] * denoised
699
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
700
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
701
+
702
+ # apply correction if noise level is not 0
703
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
704
+
705
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
706
+ return x
707
+
708
+
709
+ class DPMPP2MSampler(BaseDiffusionSampler):
710
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
711
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
712
+ h = t_next - t
713
+
714
+ if previous_sigma is not None:
715
+ h_last = t - to_neg_log_sigma(previous_sigma)
716
+ r = h_last / h
717
+ return h, r, t, t_next
718
+ else:
719
+ return h, None, t, t_next
720
+
721
+ def get_mult(self, h, r, t, t_next, previous_sigma):
722
+ mult1 = to_sigma(t_next) / to_sigma(t)
723
+ mult2 = (-h).expm1()
724
+
725
+ if previous_sigma is not None:
726
+ mult3 = 1 + 1 / (2 * r)
727
+ mult4 = 1 / (2 * r)
728
+ return mult1, mult2, mult3, mult4
729
+ else:
730
+ return mult1, mult2
731
+
732
+ def sampler_step(
733
+ self,
734
+ old_denoised,
735
+ previous_sigma,
736
+ sigma,
737
+ next_sigma,
738
+ denoiser,
739
+ x,
740
+ cond,
741
+ uc=None,
742
+ ):
743
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
744
+
745
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
746
+ mult = [
747
+ append_dims(mult, x.ndim)
748
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
749
+ ]
750
+
751
+ x_standard = mult[0] * x - mult[1] * denoised
752
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
753
+ # Save a network evaluation if all noise levels are 0 or on the first step
754
+ return x_standard, denoised
755
+ else:
756
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
757
+ x_advanced = mult[0] * x - mult[1] * denoised_d
758
+
759
+ # apply correction if noise level is not 0 and not first step
760
+ x = torch.where(
761
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
762
+ )
763
+
764
+ return x, denoised
765
+
766
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, init_step=0, **kwargs):
767
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
768
+ x, cond, uc, num_steps
769
+ )
770
+
771
+ old_denoised = None
772
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
773
+ x, old_denoised = self.sampler_step(
774
+ old_denoised,
775
+ None if i == 0 else s_in * sigmas[i - 1],
776
+ s_in * sigmas[i],
777
+ s_in * sigmas[i + 1],
778
+ denoiser,
779
+ x,
780
+ cond,
781
+ uc=uc,
782
+ )
783
+
784
+ return x
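
For reference, the second-order update implemented by DPMPP2MSampler.sampler_step reduces to the short standalone sketch below, written against plain Python floats for a single step (a hypothetical helper, not code from this commit; it assumes sigma_next > 0, whereas the sampler above falls back to the first-order result when the next noise level is zero):

    import math

    def dpmpp2m_update(x, denoised, old_denoised, sigma, sigma_next, sigma_prev=None):
        # negative-log-sigma variables, mirroring get_variables()
        t, t_next = -math.log(sigma), -math.log(sigma_next)
        h = t_next - t
        # first-order update: mult1 * x - mult2 * denoised
        first_order = (sigma_next / sigma) * x - math.expm1(-h) * denoised
        if old_denoised is None or sigma_prev is None:
            return first_order
        # reuse the previous denoised estimate for a second-order correction
        r = (t + math.log(sigma_prev)) / h  # h_last / h
        denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
        return (sigma_next / sigma) * x - math.expm1(-h) * denoised_d

The multistep form needs no extra network evaluation per step, which is why the loop above only carries old_denoised between iterations.
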
sgm/modules/diffusionmodules/sampling_utils.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ from scipy import integrate
3
+
4
+ from ...util import append_dims
5
+
6
+
7
+ class NoDynamicThresholding:
8
+ def __call__(self, uncond, cond, scale):
9
+ return uncond + scale * (cond - uncond)
10
+
11
+ class DualThresholding: # Dual-condition CFG (from InstructPix2Pix)
12
+ def __call__(self, uncond_1, uncond_2, cond, scale):
13
+ return uncond_1 + scale[0] * (uncond_2 - uncond_1) + scale[1] * (cond - uncond_2)
14
+
15
+ def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
16
+ if order - 1 > i:
17
+ raise ValueError(f"Order {order} too high for step {i}")
18
+
19
+ def fn(tau):
20
+ prod = 1.0
21
+ for k in range(order):
22
+ if j == k:
23
+ continue
24
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
25
+ return prod
26
+
27
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
28
+
29
+
30
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
31
+ if not eta:
32
+ return sigma_to, 0.0
33
+ sigma_up = torch.minimum(
34
+ sigma_to,
35
+ eta
36
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
37
+ )
38
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
39
+ return sigma_down, sigma_up
40
+
41
+
42
+ def to_d(x, sigma, denoised):
43
+ return (x - denoised) / append_dims(sigma, x.ndim)
44
+
45
+
46
+ def to_neg_log_sigma(sigma):
47
+ return sigma.log().neg()
48
+
49
+
50
+ def to_sigma(neg_log_sigma):
51
+ return neg_log_sigma.neg().exp()
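
As a quick sanity check on the ancestral split used by the samplers (illustrative only, assuming the module is importable from the path added above), the deterministic and stochastic parts returned by get_ancestral_step recombine to the target noise level whenever the eta term is not clamped by the torch.minimum:

    import torch
    from sgm.modules.diffusionmodules.sampling_utils import get_ancestral_step, to_d

    sigma_from, sigma_to = torch.tensor(2.0), torch.tensor(1.0)
    sigma_down, sigma_up = get_ancestral_step(sigma_from, sigma_to, eta=1.0)
    assert torch.allclose(sigma_down ** 2 + sigma_up ** 2, sigma_to ** 2)

    # to_d converts a denoised estimate into the probability-flow derivative dx/dsigma
    x, denoised = torch.randn(1, 4, 8, 8), torch.zeros(1, 4, 8, 8)
    d = to_d(x, sigma_from, denoised)  # equals x / sigma_from here
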
sgm/modules/diffusionmodules/sigma_sampling.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+
3
+ from ...util import default, instantiate_from_config
4
+
5
+
6
+ class EDMSampling:
7
+ def __init__(self, p_mean=-1.2, p_std=1.2):
8
+ self.p_mean = p_mean
9
+ self.p_std = p_std
10
+
11
+ def __call__(self, n_samples, rand=None):
12
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13
+ return log_sigma.exp()
14
+
15
+
16
+ class DiscreteSampling:
17
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
18
+ self.num_idx = num_idx
19
+ self.sigmas = instantiate_from_config(discretization_config)(
20
+ num_idx, do_append_zero=do_append_zero, flip=flip
21
+ )
22
+
23
+ def idx_to_sigma(self, idx):
24
+ return self.sigmas[idx]
25
+
26
+ def __call__(self, n_samples, rand=None):
27
+ idx = default(
28
+ rand,
29
+ torch.randint(0, self.num_idx, (n_samples,)),
30
+ )
31
+ return self.idx_to_sigma(idx)
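
Illustrative usage (not part of this commit): EDMSampling draws training sigmas log-normally, so over many draws log(sigma) should concentrate around p_mean with spread p_std, while DiscreteSampling instead picks uniformly from a fixed discretized schedule.

    import torch
    from sgm.modules.diffusionmodules.sigma_sampling import EDMSampling

    sigma_sampler = EDMSampling(p_mean=-1.2, p_std=1.2)
    sigmas = sigma_sampler(10_000)
    print(sigmas.log().mean().item(), sigmas.log().std().item())  # roughly -1.2 and 1.2
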
sgm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,308 @@
1
+ """
2
+ adapted from
3
+ https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
+ and
5
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
+ and
7
+ https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
+
9
+ thanks!
10
+ """
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import repeat
17
+
18
+
19
+ def make_beta_schedule(
20
+ schedule,
21
+ n_timestep,
22
+ linear_start=1e-4,
23
+ linear_end=2e-2,
24
+ ):
25
+ if schedule == "linear":
26
+ betas = (
27
+ torch.linspace(
28
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29
+ )
30
+ ** 2
31
+ )
32
+ return betas.numpy()
33
+
34
+
35
+ def extract_into_tensor(a, t, x_shape):
36
+ b, *_ = t.shape
37
+ out = a.gather(-1, t)
38
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
39
+
40
+
41
+ def mixed_checkpoint(func, inputs: dict, params, flag):
42
+ """
43
+ Evaluate a function without caching intermediate activations, allowing for
44
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
45
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
46
+ it also works with non-tensor inputs
47
+ :param func: the function to evaluate.
48
+ :param inputs: the argument dictionary to pass to `func`.
49
+ :param params: a sequence of parameters `func` depends on but does not
50
+ explicitly take as arguments.
51
+ :param flag: if False, disable gradient checkpointing.
52
+ """
53
+ if flag:
54
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
55
+ tensor_inputs = [
56
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
57
+ ]
58
+ non_tensor_keys = [
59
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
60
+ ]
61
+ non_tensor_inputs = [
62
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
63
+ ]
64
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
65
+ return MixedCheckpointFunction.apply(
66
+ func,
67
+ len(tensor_inputs),
68
+ len(non_tensor_inputs),
69
+ tensor_keys,
70
+ non_tensor_keys,
71
+ *args,
72
+ )
73
+ else:
74
+ return func(**inputs)
75
+
76
+
77
+ class MixedCheckpointFunction(torch.autograd.Function):
78
+ @staticmethod
79
+ def forward(
80
+ ctx,
81
+ run_function,
82
+ length_tensors,
83
+ length_non_tensors,
84
+ tensor_keys,
85
+ non_tensor_keys,
86
+ *args,
87
+ ):
88
+ ctx.end_tensors = length_tensors
89
+ ctx.end_non_tensors = length_tensors + length_non_tensors
90
+ ctx.gpu_autocast_kwargs = {
91
+ "enabled": torch.is_autocast_enabled(),
92
+ "dtype": torch.get_autocast_gpu_dtype(),
93
+ "cache_enabled": torch.is_autocast_cache_enabled(),
94
+ }
95
+ assert (
96
+ len(tensor_keys) == length_tensors
97
+ and len(non_tensor_keys) == length_non_tensors
98
+ )
99
+
100
+ ctx.input_tensors = {
101
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
102
+ }
103
+ ctx.input_non_tensors = {
104
+ key: val
105
+ for (key, val) in zip(
106
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
107
+ )
108
+ }
109
+ ctx.run_function = run_function
110
+ ctx.input_params = list(args[ctx.end_non_tensors :])
111
+
112
+ with torch.no_grad():
113
+ output_tensors = ctx.run_function(
114
+ **ctx.input_tensors, **ctx.input_non_tensors
115
+ )
116
+ return output_tensors
117
+
118
+ @staticmethod
119
+ def backward(ctx, *output_grads):
120
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
121
+ ctx.input_tensors = {
122
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
123
+ for key in ctx.input_tensors
124
+ }
125
+
126
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
127
+ # Fixes a bug where the first op in run_function modifies the
128
+ # Tensor storage in place, which is not allowed for detach()'d
129
+ # Tensors.
130
+ shallow_copies = {
131
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
132
+ for key in ctx.input_tensors
133
+ }
134
+ # shallow_copies.update(additional_args)
135
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
136
+ input_grads = torch.autograd.grad(
137
+ output_tensors,
138
+ list(ctx.input_tensors.values()) + ctx.input_params,
139
+ output_grads,
140
+ allow_unused=True,
141
+ )
142
+ del ctx.input_tensors
143
+ del ctx.input_params
144
+ del output_tensors
145
+ return (
146
+ (None, None, None, None, None)
147
+ + input_grads[: ctx.end_tensors]
148
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
149
+ + input_grads[ctx.end_tensors :]
150
+ )
151
+
152
+
153
+ def checkpoint(func, inputs, params, flag):
154
+ """
155
+ Evaluate a function without caching intermediate activations, allowing for
156
+ reduced memory at the expense of extra compute in the backward pass.
157
+ :param func: the function to evaluate.
158
+ :param inputs: the argument sequence to pass to `func`.
159
+ :param params: a sequence of parameters `func` depends on but does not
160
+ explicitly take as arguments.
161
+ :param flag: if False, disable gradient checkpointing.
162
+ """
163
+ if flag:
164
+ args = tuple(inputs) + tuple(params)
165
+ return CheckpointFunction.apply(func, len(inputs), *args)
166
+ else:
167
+ return func(*inputs)
168
+
169
+
170
+ class CheckpointFunction(torch.autograd.Function):
171
+ @staticmethod
172
+ def forward(ctx, run_function, length, *args):
173
+ ctx.run_function = run_function
174
+ ctx.input_tensors = list(args[:length])
175
+ ctx.input_params = list(args[length:])
176
+ ctx.gpu_autocast_kwargs = {
177
+ "enabled": torch.is_autocast_enabled(),
178
+ "dtype": torch.get_autocast_gpu_dtype(),
179
+ "cache_enabled": torch.is_autocast_cache_enabled(),
180
+ }
181
+ with torch.no_grad():
182
+ output_tensors = ctx.run_function(*ctx.input_tensors)
183
+ return output_tensors
184
+
185
+ @staticmethod
186
+ def backward(ctx, *output_grads):
187
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
188
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
189
+ # Fixes a bug where the first op in run_function modifies the
190
+ # Tensor storage in place, which is not allowed for detach()'d
191
+ # Tensors.
192
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
193
+ output_tensors = ctx.run_function(*shallow_copies)
194
+ input_grads = torch.autograd.grad(
195
+ output_tensors,
196
+ ctx.input_tensors + ctx.input_params,
197
+ output_grads,
198
+ allow_unused=True,
199
+ )
200
+ del ctx.input_tensors
201
+ del ctx.input_params
202
+ del output_tensors
203
+ return (None, None) + input_grads
204
+
205
+
206
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
207
+ """
208
+ Create sinusoidal timestep embeddings.
209
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
210
+ These may be fractional.
211
+ :param dim: the dimension of the output.
212
+ :param max_period: controls the minimum frequency of the embeddings.
213
+ :return: an [N x dim] Tensor of positional embeddings.
214
+ """
215
+ if not repeat_only:
216
+ half = dim // 2
217
+ freqs = torch.exp(
218
+ -math.log(max_period)
219
+ * torch.arange(start=0, end=half, dtype=torch.float32)
220
+ / half
221
+ ).to(device=timesteps.device)
222
+ args = timesteps[:, None].float() * freqs[None]
223
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
224
+ if dim % 2:
225
+ embedding = torch.cat(
226
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
227
+ )
228
+ else:
229
+ embedding = repeat(timesteps, "b -> b d", d=dim)
230
+ return embedding
231
+
232
+
233
+ def zero_module(module):
234
+ """
235
+ Zero out the parameters of a module and return it.
236
+ """
237
+ for p in module.parameters():
238
+ p.detach().zero_()
239
+ return module
240
+
241
+
242
+ def scale_module(module, scale):
243
+ """
244
+ Scale the parameters of a module and return it.
245
+ """
246
+ for p in module.parameters():
247
+ p.detach().mul_(scale)
248
+ return module
249
+
250
+
251
+ def mean_flat(tensor):
252
+ """
253
+ Take the mean over all non-batch dimensions.
254
+ """
255
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
256
+
257
+
258
+ def normalization(channels):
259
+ """
260
+ Make a standard normalization layer.
261
+ :param channels: number of input channels.
262
+ :return: an nn.Module for normalization.
263
+ """
264
+ return GroupNorm32(32, channels)
265
+
266
+
267
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
268
+ class SiLU(nn.Module):
269
+ def forward(self, x):
270
+ return x * torch.sigmoid(x)
271
+
272
+
273
+ class GroupNorm32(nn.GroupNorm):
274
+ def forward(self, x):
275
+ return super().forward(x.float()).type(x.dtype)
276
+
277
+
278
+ def conv_nd(dims, *args, **kwargs):
279
+ """
280
+ Create a 1D, 2D, or 3D convolution module.
281
+ """
282
+ if dims == 1:
283
+ return nn.Conv1d(*args, **kwargs)
284
+ elif dims == 2:
285
+ return nn.Conv2d(*args, **kwargs)
286
+ elif dims == 3:
287
+ return nn.Conv3d(*args, **kwargs)
288
+ raise ValueError(f"unsupported dimensions: {dims}")
289
+
290
+
291
+ def linear(*args, **kwargs):
292
+ """
293
+ Create a linear module.
294
+ """
295
+ return nn.Linear(*args, **kwargs)
296
+
297
+
298
+ def avg_pool_nd(dims, *args, **kwargs):
299
+ """
300
+ Create a 1D, 2D, or 3D average pooling module.
301
+ """
302
+ if dims == 1:
303
+ return nn.AvgPool1d(*args, **kwargs)
304
+ elif dims == 2:
305
+ return nn.AvgPool2d(*args, **kwargs)
306
+ elif dims == 3:
307
+ return nn.AvgPool3d(*args, **kwargs)
308
+ raise ValueError(f"unsupported dimensions: {dims}")
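
A short usage sketch for the helpers above (illustrative only; the shapes are arbitrary):

    import torch
    from sgm.modules.diffusionmodules.util import timestep_embedding, zero_module

    t = torch.tensor([0, 250, 999])
    emb = timestep_embedding(t, dim=320)   # sinusoidal features, shape [3, 320]

    out_proj = zero_module(torch.nn.Conv2d(320, 4, kernel_size=1))
    x = torch.randn(1, 320, 8, 8)
    print(out_proj(x).abs().max().item())  # 0.0 -- zero-initialized, so the branch contributes nothing at init
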