ameerazam08 committed
Commit 2cf789d
1 Parent(s): 7193dd9

Upload folder using huggingface_hub
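The commit message notes the folder was pushed with huggingface_hub. A minimal sketch of that kind of upload using the library's public HfApi; the repo id and local path below are illustrative placeholders, not taken from this commit:

# Hypothetical upload call; repo_id / folder_path are placeholders.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./UDiffText",           # local folder with app.py, configs/, checkpoints/, ...
    repo_id="your-username/UDiffText",   # placeholder Space id
    repo_type="space",                   # this commit targets a Gradio Space
    commit_message="Upload folder using huggingface_hub",
)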

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +1 -0
  3. README.md +14 -0
  4. __pycache__/util.cpython-310.pyc +0 -0
  5. __pycache__/util.cpython-311.pyc +0 -0
  6. app.py +245 -0
  7. checkpoints/AEs/AE_inpainting_2.safetensors +3 -0
  8. checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt +3 -0
  9. checkpoints/st-step=100000+la-step=100000-v1.ckpt +3 -0
  10. configs/demo.yaml +29 -0
  11. configs/test/textdesign_sd_2.yaml +137 -0
  12. demo/examples/DIRTY_0_0.png +0 -0
  13. demo/examples/ENGINE_0_0.png +0 -0
  14. demo/examples/FAVOURITE_0_0.jpeg +0 -0
  15. demo/examples/FRONTIER_0_0.png +0 -0
  16. demo/examples/Peaceful_0_0.jpeg +0 -0
  17. demo/examples/Scamps_0_0.png +0 -0
  18. demo/examples/TREE_0_0.png +0 -0
  19. demo/examples/better_0_0.jpg +0 -0
  20. demo/examples/tested_0_0.png +0 -0
  21. demo/teaser.png +3 -0
  22. requirements.txt +28 -0
  23. sgm/__init__.py +2 -0
  24. sgm/__pycache__/__init__.cpython-310.pyc +0 -0
  25. sgm/__pycache__/__init__.cpython-311.pyc +0 -0
  26. sgm/__pycache__/lr_scheduler.cpython-311.pyc +0 -0
  27. sgm/__pycache__/util.cpython-310.pyc +0 -0
  28. sgm/__pycache__/util.cpython-311.pyc +0 -0
  29. sgm/lr_scheduler.py +135 -0
  30. sgm/models/__init__.py +2 -0
  31. sgm/models/__pycache__/__init__.cpython-310.pyc +0 -0
  32. sgm/models/__pycache__/__init__.cpython-311.pyc +0 -0
  33. sgm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  34. sgm/models/__pycache__/autoencoder.cpython-311.pyc +0 -0
  35. sgm/models/__pycache__/diffusion.cpython-310.pyc +0 -0
  36. sgm/models/__pycache__/diffusion.cpython-311.pyc +0 -0
  37. sgm/models/autoencoder.py +335 -0
  38. sgm/models/diffusion.py +328 -0
  39. sgm/modules/__init__.py +6 -0
  40. sgm/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  41. sgm/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  42. sgm/modules/__pycache__/attention.cpython-310.pyc +0 -0
  43. sgm/modules/__pycache__/attention.cpython-311.pyc +0 -0
  44. sgm/modules/__pycache__/ema.cpython-310.pyc +0 -0
  45. sgm/modules/__pycache__/ema.cpython-311.pyc +0 -0
  46. sgm/modules/attention.py +976 -0
  47. sgm/modules/autoencoding/__init__.py +0 -0
  48. sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc +0 -0
  49. sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc +0 -0
  50. sgm/modules/autoencoding/losses/__init__.py +246 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ /demo/**/* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/**/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ **/__pycache__
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: UDiffText
+ emoji: 😋
+ colorFrom: purple
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.41.0
+ python_version: 3.11.4
+ app_file: app.py
+ pinned: true
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/util.cpython-310.pyc ADDED
Binary file (1.77 kB)
__pycache__/util.cpython-311.pyc ADDED
Binary file (3.01 kB)
app.py ADDED
@@ -0,0 +1,245 @@
1
+ # -*- coding: utf-8 -*-
2
+ import cv2
3
+ import torch
4
+ import os, glob
5
+ import numpy as np
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from omegaconf import OmegaConf
9
+ from contextlib import nullcontext
10
+ from pytorch_lightning import seed_everything
11
+ from os.path import join as ospj
12
+ from random import randint
13
+ from torchvision.utils import save_image
14
+ from torchvision.transforms import Resize
15
+
16
+ from util import *
17
+
18
+
19
+ def process(image, mask):
20
+
21
+ img_h, img_w = image.shape[:2]
22
+
23
+ mask = mask[...,:1]//255
24
+ contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
25
+ if len(contours) != 1: raise gr.Error("One masked area only!")
26
+
27
+ m_x, m_y, m_w, m_h = cv2.boundingRect(contours[0])
28
+ c_x, c_y = m_x + m_w//2, m_y + m_h//2
29
+
30
+ if img_w > img_h:
31
+ if m_w > img_h: raise gr.Error("Illegal mask area!")
32
+ if c_x < img_w - c_x:
33
+ c_l = max(0, c_x - img_h//2)
34
+ c_r = c_l + img_h
35
+ else:
36
+ c_r = min(img_w, c_x + img_h//2)
37
+ c_l = c_r - img_h
38
+ image = image[:,c_l:c_r,:]
39
+ mask = mask[:,c_l:c_r,:]
40
+ else:
41
+ if m_h > img_w: raise gr.Error("Illegal mask area!")
42
+ if c_y < img_h - c_y:
43
+ c_t = max(0, c_y - img_w//2)
44
+ c_b = c_t + img_w
45
+ else:
46
+ c_b = min(img_h, c_y + img_w//2)
47
+ c_t = c_b - img_w
48
+ image = image[c_t:c_b,:,:]
49
+ mask = mask[c_t:c_b,:,:]
50
+
51
+ image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0
52
+ mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32)
53
+
54
+ image = resize(image[None])[0]
55
+ mask = resize(mask[None])[0]
56
+ masked = image * (1 - mask)
57
+
58
+ return image, mask, masked
59
+
60
+
61
+
62
+ def predict(cfgs, model, sampler, batch):
63
+
64
+ context = nullcontext if cfgs.aae_enabled else torch.no_grad
65
+
66
+ with context():
67
+
68
+ batch, batch_uc_1 = prepare_batch(cfgs, batch)
69
+
70
+ c, uc_1 = model.conditioner.get_unconditional_conditioning(
71
+ batch,
72
+ batch_uc=batch_uc_1,
73
+ force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
74
+ )
75
+
76
+ x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1)
77
+ samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0,
78
+ aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
79
+
80
+ samples_x = model.decode_first_stage(samples_z)
81
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
82
+
83
+ return samples, samples_z
84
+
85
+
86
+ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
87
+
88
+ global cfgs, global_index
89
+
90
+ if len(text) < cfgs.txt_len[0] or len(text) > cfgs.txt_len[1]:
91
+ raise gr.Error("Illegal text length!")
92
+
93
+ global_index += 1
94
+
95
+ if num_samples > 1: cfgs.noise_iters = 0
96
+
97
+ cfgs.batch_size = num_samples
98
+ cfgs.steps = steps
99
+ cfgs.scale[0] = scale
100
+ cfgs.detailed = show_detail
101
+ seed_everything(seed)
102
+
103
+ sampler.num_steps = steps
104
+ sampler.guider.scale_value = scale
105
+
106
+ image = input_blk["image"]
107
+ mask = input_blk["mask"]
108
+
109
+ image, mask, masked = process(image, mask)
110
+
111
+ seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text))))
112
+
113
+ # additional cond
114
+ txt = f"\"{text}\""
115
+ original_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
116
+ crop_coords_top_left = torch.tensor((0, 0))
117
+ target_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
118
+
119
+ image = torch.tile(image[None], (num_samples, 1, 1, 1))
120
+ mask = torch.tile(mask[None], (num_samples, 1, 1, 1))
121
+ masked = torch.tile(masked[None], (num_samples, 1, 1, 1))
122
+ seg_mask = torch.tile(seg_mask[None], (num_samples, 1))
123
+ original_size_as_tuple = torch.tile(original_size_as_tuple[None], (num_samples, 1))
124
+ crop_coords_top_left = torch.tile(crop_coords_top_left[None], (num_samples, 1))
125
+ target_size_as_tuple = torch.tile(target_size_as_tuple[None], (num_samples, 1))
126
+
127
+ text = [text for i in range(num_samples)]
128
+ txt = [txt for i in range(num_samples)]
129
+ name = [str(global_index) for i in range(num_samples)]
130
+
131
+ batch = {
132
+ "image": image,
133
+ "mask": mask,
134
+ "masked": masked,
135
+ "seg_mask": seg_mask,
136
+ "label": text,
137
+ "txt": txt,
138
+ "original_size_as_tuple": original_size_as_tuple,
139
+ "crop_coords_top_left": crop_coords_top_left,
140
+ "target_size_as_tuple": target_size_as_tuple,
141
+ "name": name
142
+ }
143
+
144
+ samples, samples_z = predict(cfgs, model, sampler, batch)
145
+ samples = samples.cpu().numpy().transpose(0, 2, 3, 1) * 255
146
+ results = [Image.fromarray(sample.astype(np.uint8)) for sample in samples]
147
+
148
+ if cfgs.detailed:
149
+ sections = []
150
+ attn_map = Image.open(f"./temp/attn_map/attn_map_{global_index}.png")
151
+ seg_maps = np.load(f"./temp/seg_map/seg_{global_index}.npy")
152
+ for i, seg_map in enumerate(seg_maps):
153
+ seg_map = cv2.resize(seg_map, (cfgs.W, cfgs.H))
154
+ sections.append((seg_map, text[0][i]))
155
+ seg = (results[0], sections)
156
+ else:
157
+ attn_map = None
158
+ seg = None
159
+
160
+ return results, attn_map, seg
161
+
162
+
163
+ if __name__ == "__main__":
164
+
165
+ os.makedirs("./temp", exist_ok=True)
166
+ os.makedirs("./temp/attn_map", exist_ok=True)
167
+ os.makedirs("./temp/seg_map", exist_ok=True)
168
+
169
+ cfgs = OmegaConf.load("./configs/demo.yaml")
170
+
171
+ model = init_model(cfgs)
172
+ sampler = init_sampling(cfgs)
173
+ global_index = 0
174
+ resize = Resize((cfgs.H, cfgs.W))
175
+
176
+ block = gr.Blocks().queue()
177
+ with block:
178
+
179
+ with gr.Row():
180
+
181
+ gr.HTML(
182
+ """
183
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
184
+ <h1 style="font-weight: 600; font-size: 2rem; margin: 0.5rem;">
185
+ UDiffText: A Unified Framework for High-quality Text Synthesis in Arbitrary Images via Character-aware Diffusion Models
186
+ </h1>
187
+ <ul style="text-align: center; margin: 0.5rem;">
188
+ <li style="display: inline-block; margin:auto;"><a href='https://arxiv.org/abs/2312.04884'><img src='https://img.shields.io/badge/Arxiv-2312.04884-DF826C'></a></li>
189
+ <li style="display: inline-block; margin:auto;"><a href='https://github.com/ZYM-PKU/UDiffText'><img src='https://img.shields.io/badge/Code-UDiffText-D0F288'></a></li>
190
+ <li style="display: inline-block; margin:auto;"><a href='https://udifftext.github.io'><img src='https://img.shields.io/badge/Project-UDiffText-8ADAB2'></a></li>
191
+ </ul>
192
+ <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin: 0.5rem;">
193
+ Our proposed UDiffText is capable of synthesizing accurate and harmonious text in either synthetic or real-world images, and can thus be applied to tasks like scene text editing (a), arbitrary text generation (b) and accurate T2I generation (c)
194
+ </h2>
195
+ <div align=center><img src="file/demo/teaser.png" alt="UDiffText" width="80%"></div>
196
+ </div>
197
+ """
198
+ )
199
+
200
+ with gr.Row():
201
+
202
+ with gr.Column():
203
+
204
+ input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
205
+ gr.Markdown("Notice: please draw horizontally to indicate only **one** masked area.")
206
+ text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
207
+ run_button = gr.Button(variant="primary")
208
+
209
+ with gr.Accordion("Advanced options", open=False):
210
+
211
+ num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1)
212
+ steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
213
+ scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=5.0, step=0.1)
214
+ seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
215
+ show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=False)
216
+
217
+ with gr.Column():
218
+
219
+ gallery = gr.Gallery(label="Output", height=512, preview=True)
220
+
221
+ with gr.Accordion("Visualization results", open=True):
222
+
223
+ with gr.Tab(label="Attention Maps"):
224
+ gr.Markdown("### Attention maps for each character (extracted from middle blocks at intermediate sampling step):")
225
+ attn_map = gr.Image(show_label=False, show_download_button=False)
226
+ with gr.Tab(label="Segmentation Maps"):
227
+ gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):")
228
+ seg_map = gr.AnnotatedImage(height=384, show_label=False)
229
+
230
+ # examples
231
+ examples = []
232
+ example_paths = sorted(glob.glob(ospj("./demo/examples", "*")))
233
+ for example_path in example_paths:
234
+ label = example_path.split(os.sep)[-1].split(".")[0].split("_")[0]
235
+ examples.append([example_path, label])
236
+
237
+ gr.Markdown("## Examples:")
238
+ gr.Examples(
239
+ examples=examples,
240
+ inputs=[input_blk, text]
241
+ )
242
+
243
+ run_button.click(fn=demo_predict, inputs=[input_blk, text, num_samples, steps, scale, seed, show_detail], outputs=[gallery, attn_map, seg_map])
244
+
245
+ block.launch()
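For reference, process() above reduces the inpainting input to a square crop: it locates the single masked region, centers a min(H, W)-sized window on it, then normalizes and resizes. A standalone sketch of that window-selection step (NumPy only; the function and variable names are illustrative, not part of app.py):

import numpy as np

def square_crop_around_mask(image: np.ndarray, mask: np.ndarray):
    """Crop image/mask to a square of side min(H, W) that contains the masked box,
    mirroring the window selection in app.py's process()."""
    img_h, img_w = image.shape[:2]
    ys, xs = np.nonzero(mask if mask.ndim == 2 else mask[..., 0])
    if xs.size == 0:
        raise ValueError("empty mask")
    m_x, m_y = xs.min(), ys.min()
    m_w, m_h = xs.max() - m_x + 1, ys.max() - m_y + 1
    c_x, c_y = m_x + m_w // 2, m_y + m_h // 2

    if img_w > img_h:  # landscape: slide an img_h-wide window along x
        if m_w > img_h:
            raise ValueError("mask wider than the square crop")
        c_l = max(0, c_x - img_h // 2) if c_x < img_w - c_x else min(img_w, c_x + img_h // 2) - img_h
        return image[:, c_l:c_l + img_h], mask[:, c_l:c_l + img_h]
    else:              # portrait/square: slide an img_w-tall window along y
        if m_h > img_w:
            raise ValueError("mask taller than the square crop")
        c_t = max(0, c_y - img_w // 2) if c_y < img_h - c_y else min(img_h, c_y + img_w // 2) - img_w
        return image[c_t:c_t + img_w], mask[c_t:c_t + img_w]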
checkpoints/AEs/AE_inpainting_2.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:547baac83984f8bf8b433882236b87e77eb4d2f5c71e3d7a04b8dec2fe02b81f
+ size 334640988
checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4076c90467a907dcb8cde15776bfda4473010fe845739490341db74e82cd2267
+ size 4059026213
checkpoints/st-step=100000+la-step=100000-v1.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:edea71eb83b6be72c33ef787a7122a810a7b9257bf97a276ef322707d5769878
+ size 6148465904
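These three entries are Git LFS pointers; only the oid/size metadata lives in the repo and the actual weights are fetched by LFS. Loading them at runtime follows the init_from_ckpt pattern defined later in sgm/models/autoencoder.py and sgm/models/diffusion.py; a short sketch of that pattern, with a path as in this repo:

import torch
from safetensors.torch import load_file as load_safetensors

def load_state_dict(path: str) -> dict:
    # Mirrors init_from_ckpt in sgm/models/*: .ckpt files hold a Lightning
    # checkpoint with a "state_dict" entry, .safetensors files hold tensors directly.
    if path.endswith("ckpt"):
        return torch.load(path, map_location="cpu")["state_dict"]
    if path.endswith("safetensors"):
        return load_safetensors(path)
    raise NotImplementedError(path)

sd = load_state_dict("./checkpoints/AEs/AE_inpainting_2.safetensors")
print(f"{len(sd)} tensors in the inpainting autoencoder checkpoint")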
configs/demo.yaml ADDED
@@ -0,0 +1,29 @@
+ type: "demo"
+
+ # path
+ load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-v1.ckpt"
+ model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
+
+ # param
+ H: 512
+ W: 512
+ txt_len: [1, 12]
+ seq_len: 12
+ batch_size: 1
+
+ channel: 4 # AE latent channel
+ factor: 8 # AE downsample factor
+ scale: [5.0, 0.0] # content scale, style scale
+ noise_iters: 0
+ force_uc_zero_embeddings: ["label"]
+ aae_enabled: False
+ detailed: False
+
+ # runtime
+ steps: 50
+ init_step: 0
+ num_workers: 0
+ use_gpu: True
+ gpu: 0
+ max_iter: 100
+
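app.py loads this file with OmegaConf and then overrides a few fields per request (batch size, steps, CFG scale). A short sketch of that access pattern, assuming the file above is on disk:

from omegaconf import OmegaConf

cfgs = OmegaConf.load("./configs/demo.yaml")
print(cfgs.H, cfgs.W, cfgs.txt_len)  # 512 512 [1, 12]

# demo_predict() mutates the loaded config per request, e.g.:
cfgs.batch_size = 1
cfgs.steps = 50
cfgs.scale[0] = 5.0  # content (CFG) scale; cfgs.scale[1] is the style scale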
configs/test/textdesign_sd_2.yaml ADDED
@@ -0,0 +1,137 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ input_key: image
5
+ scale_factor: 0.18215
6
+ disable_first_stage_autocast: True
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
10
+ params:
11
+ num_idx: 1000
12
+
13
+ weighting_config:
14
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
15
+ scaling_config:
16
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
17
+ discretization_config:
18
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
19
+
20
+ network_config:
21
+ target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
22
+ params:
23
+ use_checkpoint: False
24
+ in_channels: 9
25
+ out_channels: 4
26
+ ctrl_channels: 0
27
+ model_channels: 320
28
+ attention_resolutions: [4, 2, 1]
29
+ attn_type: add_attn
30
+ attn_layers:
31
+ - output_blocks.6.1
32
+ num_res_blocks: 2
33
+ channel_mult: [1, 2, 4, 4]
34
+ num_head_channels: 64
35
+ use_spatial_transformer: True
36
+ use_linear_in_transformer: True
37
+ transformer_depth: 1
38
+ context_dim: 0
39
+ add_context_dim: 2048
40
+ legacy: False
41
+
42
+ conditioner_config:
43
+ target: sgm.modules.GeneralConditioner
44
+ params:
45
+ emb_models:
46
+ # crossattn cond
47
+ # - is_trainable: False
48
+ # input_key: txt
49
+ # target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
50
+ # params:
51
+ # arch: ViT-H-14
52
+ # version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
53
+ # layer: penultimate
54
+ # add crossattn cond
55
+ - is_trainable: False
56
+ input_key: label
57
+ target: sgm.modules.encoders.modules.LabelEncoder
58
+ params:
59
+ is_add_embedder: True
60
+ max_len: 12
61
+ emb_dim: 2048
62
+ n_heads: 8
63
+ n_trans_layers: 12
64
+ ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
65
+ # concat cond
66
+ - is_trainable: False
67
+ input_key: mask
68
+ target: sgm.modules.encoders.modules.IdentityEncoder
69
+ - is_trainable: False
70
+ input_key: masked
71
+ target: sgm.modules.encoders.modules.LatentEncoder
72
+ params:
73
+ scale_factor: 0.18215
74
+ config:
75
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
76
+ params:
77
+ ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
78
+ embed_dim: 4
79
+ monitor: val/rec_loss
80
+ ddconfig:
81
+ attn_type: vanilla-xformers
82
+ double_z: true
83
+ z_channels: 4
84
+ resolution: 256
85
+ in_channels: 3
86
+ out_ch: 3
87
+ ch: 128
88
+ ch_mult: [1, 2, 4, 4]
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
94
+
95
+ first_stage_config:
96
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
97
+ params:
98
+ embed_dim: 4
99
+ monitor: val/rec_loss
100
+ ddconfig:
101
+ attn_type: vanilla-xformers
102
+ double_z: true
103
+ z_channels: 4
104
+ resolution: 256
105
+ in_channels: 3
106
+ out_ch: 3
107
+ ch: 128
108
+ ch_mult: [1, 2, 4, 4]
109
+ num_res_blocks: 2
110
+ attn_resolutions: []
111
+ dropout: 0.0
112
+ lossconfig:
113
+ target: torch.nn.Identity
114
+
115
+ loss_fn_config:
116
+ target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss
117
+ params:
118
+ seq_len: 12
119
+ kernel_size: 3
120
+ gaussian_sigma: 0.5
121
+ min_attn_size: 16
122
+ lambda_local_loss: 0.02
123
+ lambda_ocr_loss: 0.001
124
+ ocr_enabled: False
125
+
126
+ predictor_config:
127
+ target: sgm.modules.predictors.model.ParseqPredictor
128
+ params:
129
+ ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
130
+
131
+ sigma_sampler_config:
132
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
133
+ params:
134
+ num_idx: 1000
135
+
136
+ discretization_config:
137
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
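Every block in this config uses the target/params convention consumed by sgm.util.instantiate_from_config (re-exported in sgm/__init__.py below). That helper is not part of this diff; a common minimal implementation of the pattern, shown here only as an illustration of how the config is consumed, is:

import importlib

def get_obj_from_str(string: str):
    # "pkg.module.ClassName" -> the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # Build the object named by `target`, passing `params` as keyword arguments.
    return get_obj_from_str(config["target"])(**config.get("params", {}))

# e.g. the discretization block at the end of this file:
disc_cfg = {"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"}
# discretization = instantiate_from_config(disc_cfg)  # needs the sgm package on PYTHONPATH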
demo/examples/DIRTY_0_0.png ADDED
demo/examples/ENGINE_0_0.png ADDED
demo/examples/FAVOURITE_0_0.jpeg ADDED
demo/examples/FRONTIER_0_0.png ADDED
demo/examples/Peaceful_0_0.jpeg ADDED
demo/examples/Scamps_0_0.png ADDED
demo/examples/TREE_0_0.png ADDED
demo/examples/better_0_0.jpg ADDED
demo/examples/tested_0_0.png ADDED
demo/teaser.png ADDED

Git LFS Details

  • SHA256: dcd166cc9691c99a7ee93a028ab485472171ee348a5f4dbaf82f6bf1fb27c66d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.62 MB
requirements.txt ADDED
@@ -0,0 +1,28 @@
+ colorlover==0.3.0
+ einops==0.6.1
+ gradio==3.41.0
+ imageio==2.31.2
+ img2dataset==1.42.0
+ kornia==0.6.9
+ lpips==0.1.4
+ matplotlib==3.7.2
+ numpy==1.25.1
+ omegaconf==2.3.0
+ open-clip-torch==2.20.0
+ opencv-python==4.6.0.66
+ Pillow==9.5.0
+ pytorch-fid==0.3.0
+ pytorch-lightning==2.0.1
+ safetensors==0.3.1
+ scikit-learn==1.3.0
+ scipy==1.11.1
+ seaborn==0.12.2
+ tensorboard==2.14.0
+ timm==0.9.2
+ tokenizers==0.13.3
+ torch==2.1.0
+ torchvision==0.16.0
+ tqdm==4.65.0
+ transformers==4.30.2
+ xformers==0.0.22.post7
+
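The pinned torch/xformers pair matters at runtime: sgm/modules/attention.py (below) gates its attention backends on the PyTorch version and falls back when xformers is missing. A small environment check along the same lines, given only as an illustration:

import torch
from packaging import version

# The same checks attention.py performs at import time, pulled out as a quick sanity test.
print("torch", torch.__version__,
      "| SDP available:", version.parse(torch.__version__) >= version.parse("2.0.0"))
try:
    import xformers, xformers.ops  # noqa: F401
    print("xformers available")
except ImportError:
    print("no xformers; memory-efficient attention disabled")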
sgm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .models import AutoencodingEngine, DiffusionEngine
+ from .util import instantiate_from_config
sgm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (267 Bytes)
sgm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (314 Bytes)
sgm/__pycache__/lr_scheduler.cpython-311.pyc ADDED
Binary file (6.56 kB)
sgm/__pycache__/util.cpython-310.pyc ADDED
Binary file (8.07 kB)
sgm/__pycache__/util.cpython-311.pyc ADDED
Binary file (13.5 kB)
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
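As the docstrings note, these schedulers return an LR multiplier and are meant to drive torch.optim.lr_scheduler.LambdaLR on top of the optimizer's base learning rate. A usage sketch with made-up hyperparameters, assuming the sgm package is importable:

import torch
from sgm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Multiplier warms up from lr_start to lr_max over warm_up_steps, then cosine-decays to lr_min.
sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=100_000
)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for step in range(3):
    opt.step()
    scheduler.step()  # effective lr = 1e-4 * sched_fn(current step)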
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .autoencoder import AutoencodingEngine
+ from .diffusion import DiffusionEngine
sgm/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (250 Bytes)
sgm/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (291 Bytes)
sgm/models/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (11.5 kB)
sgm/models/__pycache__/autoencoder.cpython-311.pyc ADDED
Binary file (20.2 kB)
sgm/models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.8 kB)
sgm/models/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (20.2 kB)
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
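AutoencoderKLInferenceWrapper is the first-stage model referenced from the configs above: it maps 512x512 RGB to a 4-channel latent at 1/8 resolution and back. A hedged round-trip sketch, with the ddconfig copied from configs/test/textdesign_sd_2.yaml; it assumes the full sgm package and xformers are available:

import torch
from sgm.models.autoencoder import AutoencoderKLInferenceWrapper

ddconfig = dict(
    attn_type="vanilla-xformers", double_z=True, z_channels=4, resolution=256,
    in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4],
    num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
ae = AutoencoderKLInferenceWrapper(
    embed_dim=4,
    ddconfig=ddconfig,
    lossconfig={"target": "torch.nn.Identity"},
    ckpt_path="./checkpoints/AEs/AE_inpainting_2.safetensors",
).eval()

with torch.no_grad():
    x = torch.zeros(1, 3, 512, 512)  # placeholder image scaled to [-1, 1]
    z = ae.encode(x)                 # (1, 4, 64, 64), sampled from the posterior
    x_rec = ae.decode(z)             # (1, 3, 512, 512)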
sgm/models/diffusion.py ADDED
@@ -0,0 +1,328 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+
10
+ from ..modules import UNCONDITIONAL_CONFIG
11
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
12
+ from ..modules.ema import LitEma
13
+ from ..util import (
14
+ default,
15
+ disabled_train,
16
+ get_obj_from_str,
17
+ instantiate_from_config,
18
+ log_txt_as_img,
19
+ )
20
+
21
+
22
+ class DiffusionEngine(pl.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network_config,
26
+ denoiser_config,
27
+ first_stage_config,
28
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ network_wrapper: Union[None, str] = None,
34
+ ckpt_path: Union[None, str] = None,
35
+ use_ema: bool = False,
36
+ ema_decay_rate: float = 0.9999,
37
+ scale_factor: float = 1.0,
38
+ disable_first_stage_autocast=False,
39
+ input_key: str = "jpg",
40
+ log_keys: Union[List, None] = None,
41
+ no_cond_log: bool = False,
42
+ compile_model: bool = False,
43
+ opt_keys: Union[List, None] = None
44
+ ):
45
+ super().__init__()
46
+ self.opt_keys = opt_keys
47
+ self.log_keys = log_keys
48
+ self.input_key = input_key
49
+ self.optimizer_config = default(
50
+ optimizer_config, {"target": "torch.optim.AdamW"}
51
+ )
52
+ model = instantiate_from_config(network_config)
53
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
54
+ model, compile_model=compile_model
55
+ )
56
+
57
+ self.denoiser = instantiate_from_config(denoiser_config)
58
+ self.sampler = (
59
+ instantiate_from_config(sampler_config)
60
+ if sampler_config is not None
61
+ else None
62
+ )
63
+ self.conditioner = instantiate_from_config(
64
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
65
+ )
66
+ self.scheduler_config = scheduler_config
67
+ self._init_first_stage(first_stage_config)
68
+
69
+ self.loss_fn = (
70
+ instantiate_from_config(loss_fn_config)
71
+ if loss_fn_config is not None
72
+ else None
73
+ )
74
+
75
+ self.use_ema = use_ema
76
+ if self.use_ema:
77
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
78
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
79
+
80
+ self.scale_factor = scale_factor
81
+ self.disable_first_stage_autocast = disable_first_stage_autocast
82
+ self.no_cond_log = no_cond_log
83
+
84
+ if ckpt_path is not None:
85
+ self.init_from_ckpt(ckpt_path)
86
+
87
+ def init_from_ckpt(
88
+ self,
89
+ path: str,
90
+ ) -> None:
91
+ if path.endswith("ckpt"):
92
+ sd = torch.load(path, map_location="cpu")["state_dict"]
93
+ elif path.endswith("safetensors"):
94
+ sd = load_safetensors(path)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ missing, unexpected = self.load_state_dict(sd, strict=False)
99
+ print(
100
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
101
+ )
102
+ if len(missing) > 0:
103
+ print(f"Missing Keys: {missing}")
104
+ if len(unexpected) > 0:
105
+ print(f"Unexpected Keys: {unexpected}")
106
+
107
+ def freeze(self):
108
+
109
+ for param in self.parameters():
110
+ param.requires_grad_(False)
111
+
112
+ def _init_first_stage(self, config):
113
+ model = instantiate_from_config(config).eval()
114
+ model.train = disabled_train
115
+ for param in model.parameters():
116
+ param.requires_grad = False
117
+ self.first_stage_model = model
118
+
119
+ def get_input(self, batch):
120
+ # assuming unified data format, dataloader returns a dict.
121
+ # image tensors should be scaled to -1 ... 1 and in bchw format
122
+ return batch[self.input_key]
123
+
124
+ @torch.no_grad()
125
+ def decode_first_stage(self, z):
126
+ z = 1.0 / self.scale_factor * z
127
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
128
+ out = self.first_stage_model.decode(z)
129
+ return out
130
+
131
+ @torch.no_grad()
132
+ def encode_first_stage(self, x):
133
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
134
+ z = self.first_stage_model.encode(x)
135
+ z = self.scale_factor * z
136
+ return z
137
+
138
+ def forward(self, x, batch):
139
+
140
+ loss, loss_dict = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, self.first_stage_model, self.scale_factor)
141
+
142
+ return loss, loss_dict
143
+
144
+ def shared_step(self, batch: Dict) -> Any:
145
+ x = self.get_input(batch)
146
+ x = self.encode_first_stage(x)
147
+ batch["global_step"] = self.global_step
148
+ loss, loss_dict = self(x, batch)
149
+ return loss, loss_dict
150
+
151
+ def training_step(self, batch, batch_idx):
152
+ loss, loss_dict = self.shared_step(batch)
153
+
154
+ self.log_dict(
155
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
156
+ )
157
+
158
+ self.log(
159
+ "global_step",
160
+ float(self.global_step),
161
+ prog_bar=True,
162
+ logger=True,
163
+ on_step=True,
164
+ on_epoch=False,
165
+ )
166
+
167
+ lr = self.optimizers().param_groups[0]["lr"]
168
+ self.log(
169
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
170
+ )
171
+
172
+ return loss
173
+
174
+ def on_train_start(self, *args, **kwargs):
175
+ if self.sampler is None or self.loss_fn is None:
176
+ raise ValueError("Sampler and loss function need to be set for training.")
177
+
178
+ def on_train_batch_end(self, *args, **kwargs):
179
+ if self.use_ema:
180
+ self.model_ema(self.model)
181
+
182
+ @contextmanager
183
+ def ema_scope(self, context=None):
184
+ if self.use_ema:
185
+ self.model_ema.store(self.model.parameters())
186
+ self.model_ema.copy_to(self.model)
187
+ if context is not None:
188
+ print(f"{context}: Switched to EMA weights")
189
+ try:
190
+ yield None
191
+ finally:
192
+ if self.use_ema:
193
+ self.model_ema.restore(self.model.parameters())
194
+ if context is not None:
195
+ print(f"{context}: Restored training weights")
196
+
197
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
198
+ return get_obj_from_str(cfg["target"])(
199
+ params, lr=lr, **cfg.get("params", dict())
200
+ )
201
+
202
+ def configure_optimizers(self):
203
+ lr = self.learning_rate
204
+ params = []
205
+ print("Trainable parameter list: ")
206
+ print("-"*20)
207
+ for name, param in self.model.named_parameters():
208
+ if any([key in name for key in self.opt_keys]):
209
+ params.append(param)
210
+ print(name)
211
+ else:
212
+ param.requires_grad_(False)
213
+ for embedder in self.conditioner.embedders:
214
+ if embedder.is_trainable:
215
+ for name, param in embedder.named_parameters():
216
+ params.append(param)
217
+ print(name)
218
+ print("-"*20)
219
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
220
+ scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch)
221
+
222
+ return [opt], scheduler
223
+
224
+ @torch.no_grad()
225
+ def sample(
226
+ self,
227
+ cond: Dict,
228
+ uc: Union[Dict, None] = None,
229
+ batch_size: int = 16,
230
+ shape: Union[None, Tuple, List] = None,
231
+ **kwargs,
232
+ ):
233
+ randn = torch.randn(batch_size, *shape).to(self.device)
234
+
235
+ denoiser = lambda input, sigma, c: self.denoiser(
236
+ self.model, input, sigma, c, **kwargs
237
+ )
238
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
239
+ return samples
240
+
241
+ @torch.no_grad()
242
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
243
+ """
244
+ Defines heuristics to log different conditionings.
245
+ These can be lists of strings (text-to-image), tensors, ints, ...
246
+ """
247
+ image_h, image_w = batch[self.input_key].shape[2:]
248
+ log = dict()
249
+
250
+ for embedder in self.conditioner.embedders:
251
+ if (
252
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
253
+ ) and not self.no_cond_log:
254
+ x = batch[embedder.input_key][:n]
255
+ if isinstance(x, torch.Tensor):
256
+ if x.dim() == 1:
257
+ # class-conditional, convert integer to string
258
+ x = [str(x[i].item()) for i in range(x.shape[0])]
259
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
260
+ elif x.dim() == 2:
261
+ # size and crop cond and the like
262
+ x = [
263
+ "x".join([str(xx) for xx in x[i].tolist()])
264
+ for i in range(x.shape[0])
265
+ ]
266
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
267
+ else:
268
+ raise NotImplementedError()
269
+ elif isinstance(x, (List, ListConfig)):
270
+ if isinstance(x[0], str):
271
+ # strings
272
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
273
+ else:
274
+ raise NotImplementedError()
275
+ else:
276
+ raise NotImplementedError()
277
+ log[embedder.input_key] = xc
278
+ return log
279
+
280
+ @torch.no_grad()
281
+ def log_images(
282
+ self,
283
+ batch: Dict,
284
+ N: int = 8,
285
+ sample: bool = True,
286
+ ucg_keys: List[str] = None,
287
+ **kwargs,
288
+ ) -> Dict:
289
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
290
+ if ucg_keys:
291
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
292
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
293
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
294
+ )
295
+ else:
296
+ ucg_keys = conditioner_input_keys
297
+ log = dict()
298
+
299
+ x = self.get_input(batch)
300
+
301
+ c, uc = self.conditioner.get_unconditional_conditioning(
302
+ batch,
303
+ force_uc_zero_embeddings=ucg_keys
304
+ if len(self.conditioner.embedders) > 0
305
+ else [],
306
+ )
307
+
308
+ sampling_kwargs = {}
309
+
310
+ N = min(x.shape[0], N)
311
+ x = x.to(self.device)[:N]
312
+ log["inputs"] = x
313
+ z = self.encode_first_stage(x)
314
+ log["reconstructions"] = self.decode_first_stage(z)
315
+ log.update(self.log_conditionings(batch, N))
316
+
317
+ for k in c:
318
+ if isinstance(c[k], torch.Tensor):
319
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
320
+
321
+ if sample:
322
+ with self.ema_scope("Plotting"):
323
+ samples = self.sample(
324
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
325
+ )
326
+ samples = self.decode_first_stage(samples)
327
+ log["samples"] = samples
328
+ return log
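sample() above wraps the UNet and the denoiser into a single callable before handing it to the sampler, and app.py's predict() drives the demo sampler in the same style. A toy illustration of that wrapping pattern with stand-in callables (not the real model classes):

import torch

# Stand-ins only: in DiffusionEngine, `model` is the wrapped UNet and
# `denoiser_module` plays the role of e.g. DiscreteDenoiser from the config.
def denoiser_module(model, x, sigma, cond):
    return model(x, sigma, cond)

def model(x, sigma, cond):
    return x - 0.1 * sigma * torch.ones_like(x)  # dummy "denoised" output

# The same closure DiffusionEngine.sample() builds:
denoiser = lambda input, sigma, c: denoiser_module(model, input, sigma, c)

batch_size, shape = 2, (4, 64, 64)  # latent shape: channel=4, H/8, W/8
randn = torch.randn(batch_size, *shape)
out = denoiser(randn, torch.tensor(1.0), None)
print(out.shape)  # torch.Size([2, 4, 64, 64])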
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .encoders.modules import GeneralConditioner, DualConditioner
+
+ UNCONDITIONAL_CONFIG = {
+     "target": "sgm.modules.GeneralConditioner",
+     "params": {"emb_models": []},
+ }
sgm/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (337 Bytes)
sgm/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (388 Bytes)
sgm/modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (21.6 kB)
sgm/modules/__pycache__/attention.cpython-311.pyc ADDED
Binary file (45.1 kB)
sgm/modules/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.21 kB)
sgm/modules/__pycache__/ema.cpython-311.pyc ADDED
Binary file (5.82 kB)
sgm/modules/attention.py ADDED
@@ -0,0 +1,976 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from packaging import version
9
+ from torch import nn, einsum
10
+
11
+
12
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
13
+ SDP_IS_AVAILABLE = True
14
+ from torch.backends.cuda import SDPBackend, sdp_kernel
15
+
16
+ BACKEND_MAP = {
17
+ SDPBackend.MATH: {
18
+ "enable_math": True,
19
+ "enable_flash": False,
20
+ "enable_mem_efficient": False,
21
+ },
22
+ SDPBackend.FLASH_ATTENTION: {
23
+ "enable_math": False,
24
+ "enable_flash": True,
25
+ "enable_mem_efficient": False,
26
+ },
27
+ SDPBackend.EFFICIENT_ATTENTION: {
28
+ "enable_math": False,
29
+ "enable_flash": False,
30
+ "enable_mem_efficient": True,
31
+ },
32
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
33
+ }
34
+ else:
35
+ from contextlib import nullcontext
36
+
37
+ SDP_IS_AVAILABLE = False
38
+ sdp_kernel = nullcontext
39
+ BACKEND_MAP = {}
40
+ print(
41
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
42
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
43
+ )
44
+
45
+ try:
46
+ import xformers
47
+ import xformers.ops
48
+
49
+ XFORMERS_IS_AVAILABLE = True
50
+ except:
51
+ XFORMERS_IS_AVAILABLE = False
52
+ print("no module 'xformers'. Processing without...")
53
+
54
+ from .diffusionmodules.util import checkpoint
55
+
56
+
57
+ def exists(val):
58
+ return val is not None
59
+
60
+
61
+ def uniq(arr):
62
+ return {el: True for el in arr}.keys()
63
+
64
+
65
+ def default(val, d):
66
+ if exists(val):
67
+ return val
68
+ return d() if isfunction(d) else d
69
+
70
+
71
+ def max_neg_value(t):
72
+ return -torch.finfo(t.dtype).max
73
+
74
+
75
+ def init_(tensor):
76
+ dim = tensor.shape[-1]
77
+ std = 1 / math.sqrt(dim)
78
+ tensor.uniform_(-std, std)
79
+ return tensor
80
+
81
+ # feedforward
82
+ class GEGLU(nn.Module):
83
+ def __init__(self, dim_in, dim_out):
84
+ super().__init__()
85
+ self.proj = nn.Linear(dim_in, dim_out * 2)
86
+
87
+ def forward(self, x):
88
+ x, gate = self.proj(x).chunk(2, dim=-1)
89
+ return x * F.gelu(gate)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
94
+ super().__init__()
95
+ inner_dim = int(dim * mult)
96
+ dim_out = default(dim_out, dim)
97
+ project_in = (
98
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
99
+ if not glu
100
+ else GEGLU(dim, inner_dim)
101
+ )
102
+
103
+ self.net = nn.Sequential(
104
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
105
+ )
106
+
107
+ def forward(self, x):
108
+ return self.net(x)
109
+
110
+
111
+ def zero_module(module):
112
+ """
113
+ Zero out the parameters of a module and return it.
114
+ """
115
+ for p in module.parameters():
116
+ p.detach().zero_()
117
+ return module
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class LinearAttention(nn.Module):
127
+ def __init__(self, dim, heads=4, dim_head=32):
128
+ super().__init__()
129
+ self.heads = heads
130
+ hidden_dim = dim_head * heads
131
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
+
134
+ def forward(self, x):
135
+ b, c, h, w = x.shape
136
+ qkv = self.to_qkv(x)
137
+ q, k, v = rearrange(
138
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
+ )
140
+ k = k.softmax(dim=-1)
141
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
142
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
143
+ out = rearrange(
144
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
+ )
146
+ return self.to_out(out)
147
+
148
+
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ class CrossAttention(nn.Module):
195
+ def __init__(
196
+ self,
197
+ query_dim,
198
+ context_dim=None,
199
+ heads=8,
200
+ dim_head=64,
201
+ dropout=0.0,
202
+ backend=None,
203
+ ):
204
+ super().__init__()
205
+ inner_dim = dim_head * heads
206
+ context_dim = default(context_dim, query_dim)
207
+
208
+ self.scale = dim_head**-0.5
209
+ self.heads = heads
210
+
211
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
212
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
213
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
214
+
215
+ self.to_out = zero_module(nn.Sequential(
216
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
217
+ ))
218
+ self.backend = backend
219
+
220
+ self.attn_map_cache = None
221
+
222
+ def forward(
223
+ self,
224
+ x,
225
+ context=None,
226
+ mask=None,
227
+ additional_tokens=None,
228
+ n_times_crossframe_attn_in_self=0,
229
+ ):
230
+ h = self.heads
231
+
232
+ if additional_tokens is not None:
233
+ # get the number of masked tokens at the beginning of the output sequence
234
+ n_tokens_to_mask = additional_tokens.shape[1]
235
+ # add additional token
236
+ x = torch.cat([additional_tokens, x], dim=1)
237
+
238
+ q = self.to_q(x)
239
+ context = default(context, x)
240
+ k = self.to_k(context)
241
+ v = self.to_v(context)
242
+
243
+ if n_times_crossframe_attn_in_self:
244
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
245
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
246
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
247
+ k = repeat(
248
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
249
+ )
250
+ v = repeat(
251
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
252
+ )
253
+
254
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
255
+
256
+ ## old
257
+
258
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
259
+ del q, k
260
+
261
+ if exists(mask):
262
+ mask = rearrange(mask, 'b ... -> b (...)')
263
+ max_neg_value = -torch.finfo(sim.dtype).max
264
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
265
+ sim.masked_fill_(~mask, max_neg_value)
266
+
267
+ # attention, what we cannot get enough of
268
+ sim = sim.softmax(dim=-1)
269
+
270
+ # save attn_map
271
+ if self.attn_map_cache is not None:
272
+ bh, n, l = sim.shape
273
+ size = int(n**0.5)
274
+ self.attn_map_cache["size"] = size
275
+ self.attn_map_cache["attn_map"] = sim
276
+
277
+ out = einsum('b i j, b j d -> b i d', sim, v)
278
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
279
+
280
+ ## new
281
+ # with sdp_kernel(**BACKEND_MAP[self.backend]):
282
+ # # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
283
+ # out = F.scaled_dot_product_attention(
284
+ # q, k, v, attn_mask=mask
285
+ # ) # scale is dim_head ** -0.5 per default
286
+
287
+ # del q, k, v
288
+ # out = rearrange(out, "b h n d -> b n (h d)", h=h)
289
+
290
+ if additional_tokens is not None:
291
+ # remove additional token
292
+ out = out[:, n_tokens_to_mask:]
293
+ return self.to_out(out)
294
+
295
+
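# Usage sketch (editorial annotation, not part of the committed file):
# cross-attention from flattened image tokens to conditioning tokens, with the
# per-head attention maps cached for later inspection (e.g. attention-based
# losses or visualisation). All shapes below are examples only.
#
#     >>> attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#     >>> attn.attn_map_cache = {}                     # opt in to caching
#     >>> x = torch.randn(2, 64 * 64, 320)             # flattened 64x64 feature map
#     >>> ctx = torch.randn(2, 77, 768)                # conditioning tokens
#     >>> attn(x, context=ctx).shape
#     torch.Size([2, 4096, 320])
#     >>> attn.attn_map_cache["attn_map"].shape        # (batch * heads, h*w, tokens)
#     torch.Size([16, 4096, 77])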
296
+ class MemoryEfficientCrossAttention(nn.Module):
297
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
298
+ def __init__(
299
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
300
+ ):
301
+ super().__init__()
302
+ # print(
303
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
304
+ # f"{heads} heads with a dimension of {dim_head}."
305
+ # )
306
+ inner_dim = dim_head * heads
307
+ context_dim = default(context_dim, query_dim)
308
+
309
+ self.heads = heads
310
+ self.dim_head = dim_head
311
+
312
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
313
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
314
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
315
+
316
+ self.to_out = nn.Sequential(
317
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
318
+ )
319
+ self.attention_op: Optional[Any] = None
320
+
321
+ def forward(
322
+ self,
323
+ x,
324
+ context=None,
325
+ mask=None,
326
+ additional_tokens=None,
327
+ n_times_crossframe_attn_in_self=0,
328
+ ):
329
+ if additional_tokens is not None:
330
+ # get the number of masked tokens at the beginning of the output sequence
331
+ n_tokens_to_mask = additional_tokens.shape[1]
332
+ # add additional token
333
+ x = torch.cat([additional_tokens, x], dim=1)
334
+ q = self.to_q(x)
335
+ context = default(context, x)
336
+ k = self.to_k(context)
337
+ v = self.to_v(context)
338
+
339
+ if n_times_crossframe_attn_in_self:
340
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
341
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
342
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
343
+ k = repeat(
344
+ k[::n_times_crossframe_attn_in_self],
345
+ "b ... -> (b n) ...",
346
+ n=n_times_crossframe_attn_in_self,
347
+ )
348
+ v = repeat(
349
+ v[::n_times_crossframe_attn_in_self],
350
+ "b ... -> (b n) ...",
351
+ n=n_times_crossframe_attn_in_self,
352
+ )
353
+
354
+ b, _, _ = q.shape
355
+ q, k, v = map(
356
+ lambda t: t.unsqueeze(3)
357
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
358
+ .permute(0, 2, 1, 3)
359
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
360
+ .contiguous(),
361
+ (q, k, v),
362
+ )
363
+
364
+ # actually compute the attention, what we cannot get enough of
365
+ out = xformers.ops.memory_efficient_attention(
366
+ q, k, v, attn_bias=None, op=self.attention_op
367
+ )
368
+
369
+ # TODO: Use this directly in the attention operation, as a bias
370
+ if exists(mask):
371
+ raise NotImplementedError
372
+ out = (
373
+ out.unsqueeze(0)
374
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
375
+ .permute(0, 2, 1, 3)
376
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
377
+ )
378
+ if additional_tokens is not None:
379
+ # remove additional token
380
+ out = out[:, n_tokens_to_mask:]
381
+ return self.to_out(out)
382
+
383
+
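# Editorial note: this variant requires xformers to be importable and does not
# support an attention mask (passing one raises NotImplementedError). Otherwise
# the interface mirrors CrossAttention, e.g. (example shapes only):
#
#     >>> attn = MemoryEfficientCrossAttention(query_dim=320, context_dim=768,
#     ...                                      heads=8, dim_head=40)
#     >>> attn(torch.randn(2, 4096, 320), context=torch.randn(2, 77, 768)).shape
#     torch.Size([2, 4096, 320])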
384
+ class BasicTransformerBlock(nn.Module):
385
+ ATTENTION_MODES = {
386
+ "softmax": CrossAttention, # vanilla attention
387
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
388
+ }
389
+
390
+ def __init__(
391
+ self,
392
+ dim,
393
+ n_heads,
394
+ d_head,
395
+ dropout=0.0,
396
+ context_dim=None,
397
+ add_context_dim=None,
398
+ gated_ff=True,
399
+ checkpoint=True,
400
+ disable_self_attn=False,
401
+ attn_mode="softmax",
402
+ sdp_backend=None,
403
+ ):
404
+ super().__init__()
405
+ assert attn_mode in self.ATTENTION_MODES
406
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
407
+ print(
408
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
409
+ f"This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
410
+ )
411
+ attn_mode = "softmax"
412
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
413
+ print(
414
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
415
+ )
416
+ if not XFORMERS_IS_AVAILABLE:
417
+ assert (
418
+ False
419
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
420
+ else:
421
+ print("Falling back to xformers efficient attention.")
422
+ attn_mode = "softmax-xformers"
423
+ attn_cls = self.ATTENTION_MODES[attn_mode]
424
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
425
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
426
+ else:
427
+ assert sdp_backend is None
428
+ self.disable_self_attn = disable_self_attn
429
+ self.attn1 = MemoryEfficientCrossAttention(
430
+ query_dim=dim,
431
+ heads=n_heads,
432
+ dim_head=d_head,
433
+ dropout=dropout,
434
+ context_dim=context_dim if self.disable_self_attn else None,
435
+ backend=sdp_backend,
436
+ ) # is a self-attention if not self.disable_self_attn
437
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
438
+ if context_dim is not None and context_dim > 0:
439
+ self.attn2 = attn_cls(
440
+ query_dim=dim,
441
+ context_dim=context_dim,
442
+ heads=n_heads,
443
+ dim_head=d_head,
444
+ dropout=dropout,
445
+ backend=sdp_backend,
446
+ ) # is self-attn if context is none
447
+ if add_context_dim is not None and add_context_dim > 0:
448
+ self.add_attn = attn_cls(
449
+ query_dim=dim,
450
+ context_dim=add_context_dim,
451
+ heads=n_heads,
452
+ dim_head=d_head,
453
+ dropout=dropout,
454
+ backend=sdp_backend,
455
+ ) # is self-attn if context is none
456
+ self.add_norm = nn.LayerNorm(dim)
457
+ self.norm1 = nn.LayerNorm(dim)
458
+ self.norm2 = nn.LayerNorm(dim)
459
+ self.norm3 = nn.LayerNorm(dim)
460
+ self.checkpoint = checkpoint
461
+
462
+ def forward(
463
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
464
+ ):
465
+ kwargs = {"x": x}
466
+
467
+ if context is not None:
468
+ kwargs.update({"context": context})
469
+
470
+ if additional_tokens is not None:
471
+ kwargs.update({"additional_tokens": additional_tokens})
472
+
473
+ if n_times_crossframe_attn_in_self:
474
+ kwargs.update(
475
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
476
+ )
477
+
478
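# Editorial note: only (x, context, add_context) are routed through
# `checkpoint` below; the `additional_tokens` and
# `n_times_crossframe_attn_in_self` kwargs collected above are not forwarded
# to `_forward` on this path.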
+ return checkpoint(
479
+ self._forward, (x, context, add_context), self.parameters(), self.checkpoint
480
+ )
481
+
482
+ def _forward(
483
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
484
+ ):
485
+ x = (
486
+ self.attn1(
487
+ self.norm1(x),
488
+ context=context if self.disable_self_attn else None,
489
+ additional_tokens=additional_tokens,
490
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
491
+ if not self.disable_self_attn
492
+ else 0,
493
+ )
494
+ + x
495
+ )
496
+ if hasattr(self, "attn2"):
497
+ x = (
498
+ self.attn2(
499
+ self.norm2(x), context=context, additional_tokens=additional_tokens
500
+ )
501
+ + x
502
+ )
503
+ if hasattr(self, "add_attn"):
504
+ x = (
505
+ self.add_attn(
506
+ self.add_norm(x), context=add_context, additional_tokens=additional_tokens
507
+ )
508
+ + x
509
+ )
510
+ x = self.ff(self.norm3(x)) + x
511
+ return x
512
+
513
+
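# Usage sketch (editorial annotation, not part of the committed file): a block
# with both the usual cross-attention branch (context_dim) and the extra
# `add_attn` branch (add_context_dim). Assumes xformers is installed, since
# attn1 is hard-wired to MemoryEfficientCrossAttention; shapes are examples.
#
#     >>> blk = BasicTransformerBlock(dim=320, n_heads=8, d_head=40,
#     ...                             context_dim=768, add_context_dim=512)
#     >>> out = blk(torch.randn(2, 4096, 320),
#     ...           context=torch.randn(2, 77, 768),
#     ...           add_context=torch.randn(2, 16, 512))
#     >>> out.shape
#     torch.Size([2, 4096, 320])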
514
+ class BasicTransformerSingleLayerBlock(nn.Module):
515
+ ATTENTION_MODES = {
516
+ "softmax": CrossAttention, # vanilla attention
517
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
518
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
519
+ }
520
+
521
+ def __init__(
522
+ self,
523
+ dim,
524
+ n_heads,
525
+ d_head,
526
+ dropout=0.0,
527
+ context_dim=None,
528
+ gated_ff=True,
529
+ checkpoint=True,
530
+ attn_mode="softmax",
531
+ ):
532
+ super().__init__()
533
+ assert attn_mode in self.ATTENTION_MODES
534
+ attn_cls = self.ATTENTION_MODES[attn_mode]
535
+ self.attn1 = attn_cls(
536
+ query_dim=dim,
537
+ heads=n_heads,
538
+ dim_head=d_head,
539
+ dropout=dropout,
540
+ context_dim=context_dim,
541
+ )
542
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
543
+ self.norm1 = nn.LayerNorm(dim)
544
+ self.norm2 = nn.LayerNorm(dim)
545
+ self.checkpoint = checkpoint
546
+
547
+ def forward(self, x, context=None):
548
+ return checkpoint(
549
+ self._forward, (x, context), self.parameters(), self.checkpoint
550
+ )
551
+
552
+ def _forward(self, x, context=None):
553
+ x = self.attn1(self.norm1(x), context=context) + x
554
+ x = self.ff(self.norm2(x)) + x
555
+ return x
556
+
557
+
558
+ class SpatialTransformer(nn.Module):
559
+ """
560
+ Transformer block for image-like data.
561
+ First, project the input (aka embedding)
562
+ and reshape to b, t, d.
563
+ Then apply standard transformer action.
564
+ Finally, reshape back to an image.
565
+ NEW: use_linear for more efficiency instead of the 1x1 convs
566
+ """
567
+
568
+ def __init__(
569
+ self,
570
+ in_channels,
571
+ n_heads,
572
+ d_head,
573
+ depth=1,
574
+ dropout=0.0,
575
+ context_dim=None,
576
+ add_context_dim=None,
577
+ disable_self_attn=False,
578
+ use_linear=False,
579
+ attn_type="softmax",
580
+ use_checkpoint=True,
581
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
582
+ sdp_backend=None,
583
+ ):
584
+ super().__init__()
585
+ # print(
586
+ # f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
587
+ # )
588
+ from omegaconf import ListConfig
589
+
590
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
591
+ context_dim = [context_dim]
592
+ if exists(context_dim) and isinstance(context_dim, list):
593
+ if depth != len(context_dim):
594
+ # print(
595
+ # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
596
+ # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
597
+ # )
598
+ # depth does not match context dims.
599
+ assert all(
600
+ map(lambda x: x == context_dim[0], context_dim)
601
+ ), "need homogeneous context_dim to match depth automatically"
602
+ context_dim = depth * [context_dim[0]]
603
+ elif context_dim is None:
604
+ context_dim = [None] * depth
605
+ self.in_channels = in_channels
606
+ inner_dim = n_heads * d_head
607
+ self.norm = Normalize(in_channels)
608
+ if not use_linear:
609
+ self.proj_in = nn.Conv2d(
610
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
611
+ )
612
+ else:
613
+ self.proj_in = nn.Linear(in_channels, inner_dim)
614
+
615
+ self.transformer_blocks = nn.ModuleList(
616
+ [
617
+ BasicTransformerBlock(
618
+ inner_dim,
619
+ n_heads,
620
+ d_head,
621
+ dropout=dropout,
622
+ context_dim=context_dim[d],
623
+ add_context_dim=add_context_dim,
624
+ disable_self_attn=disable_self_attn,
625
+ attn_mode=attn_type,
626
+ checkpoint=use_checkpoint,
627
+ sdp_backend=sdp_backend,
628
+ )
629
+ for d in range(depth)
630
+ ]
631
+ )
632
+ if not use_linear:
633
+ self.proj_out = zero_module(
634
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
635
+ )
636
+ else:
637
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
638
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
639
+ self.use_linear = use_linear
640
+
641
+ def forward(self, x, context=None, add_context=None):
642
+ # note: if no context is given, cross-attention defaults to self-attention
643
+ if not isinstance(context, list):
644
+ context = [context]
645
+ b, c, h, w = x.shape
646
+ x_in = x
647
+ x = self.norm(x)
648
+ if not self.use_linear:
649
+ x = self.proj_in(x)
650
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
651
+ if self.use_linear:
652
+ x = self.proj_in(x)
653
+ for i, block in enumerate(self.transformer_blocks):
654
+ if i > 0 and len(context) == 1:
655
+ i = 0 # use same context for each block
656
+ x = block(x, context=context[i], add_context=add_context)
657
+ if self.use_linear:
658
+ x = self.proj_out(x)
659
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
660
+ if not self.use_linear:
661
+ x = self.proj_out(x)
662
+ return x + x_in
663
+
664
+
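# Usage sketch (editorial annotation, not part of the committed file): the
# transformer keeps the NCHW layout, so it can be dropped into a UNet stage;
# `context` may be a single tensor or a list with one entry per depth.
#
#     >>> st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
#     ...                         depth=1, context_dim=768, use_linear=True)
#     >>> st(torch.randn(2, 320, 64, 64), context=torch.randn(2, 77, 768)).shape
#     torch.Size([2, 320, 64, 64])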
665
+ def benchmark_attn():
666
+ # Lets define a helpful benchmarking function:
667
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
668
+ device = "cuda" if torch.cuda.is_available() else "cpu"
669
+ import torch.nn.functional as F
670
+ import torch.utils.benchmark as benchmark
671
+
672
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
673
+ t0 = benchmark.Timer(
674
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
675
+ )
676
+ return t0.blocked_autorange().mean * 1e6
677
+
678
+ # Lets define the hyper-parameters of our input
679
+ batch_size = 32
680
+ max_sequence_len = 1024
681
+ num_heads = 32
682
+ embed_dimension = 32
683
+
684
+ dtype = torch.float16
685
+
686
+ query = torch.rand(
687
+ batch_size,
688
+ num_heads,
689
+ max_sequence_len,
690
+ embed_dimension,
691
+ device=device,
692
+ dtype=dtype,
693
+ )
694
+ key = torch.rand(
695
+ batch_size,
696
+ num_heads,
697
+ max_sequence_len,
698
+ embed_dimension,
699
+ device=device,
700
+ dtype=dtype,
701
+ )
702
+ value = torch.rand(
703
+ batch_size,
704
+ num_heads,
705
+ max_sequence_len,
706
+ embed_dimension,
707
+ device=device,
708
+ dtype=dtype,
709
+ )
710
+
711
+ print("q/k/v shape:", query.shape, key.shape, value.shape)
712
+
713
+ # Lets explore the speed of each of the 3 implementations
714
+ from torch.backends.cuda import SDPBackend, sdp_kernel
715
+
716
+ # Helpful arguments mapper
717
+ backend_map = {
718
+ SDPBackend.MATH: {
719
+ "enable_math": True,
720
+ "enable_flash": False,
721
+ "enable_mem_efficient": False,
722
+ },
723
+ SDPBackend.FLASH_ATTENTION: {
724
+ "enable_math": False,
725
+ "enable_flash": True,
726
+ "enable_mem_efficient": False,
727
+ },
728
+ SDPBackend.EFFICIENT_ATTENTION: {
729
+ "enable_math": False,
730
+ "enable_flash": False,
731
+ "enable_mem_efficient": True,
732
+ },
733
+ }
734
+
735
+ from torch.profiler import ProfilerActivity, profile, record_function
736
+
737
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
738
+
739
+ print(
740
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
741
+ )
742
+ with profile(
743
+ activities=activities, record_shapes=False, profile_memory=True
744
+ ) as prof:
745
+ with record_function("Default detailed stats"):
746
+ for _ in range(25):
747
+ o = F.scaled_dot_product_attention(query, key, value)
748
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
749
+
750
+ print(
751
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
752
+ )
753
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
754
+ with profile(
755
+ activities=activities, record_shapes=False, profile_memory=True
756
+ ) as prof:
757
+ with record_function("Math implementation stats"):
758
+ for _ in range(25):
759
+ o = F.scaled_dot_product_attention(query, key, value)
760
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
761
+
762
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
763
+ try:
764
+ print(
765
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
766
+ )
767
+ except RuntimeError:
768
+ print("FlashAttention is not supported. See warnings for reasons.")
769
+ with profile(
770
+ activities=activities, record_shapes=False, profile_memory=True
771
+ ) as prof:
772
+ with record_function("FlashAttention stats"):
773
+ for _ in range(25):
774
+ o = F.scaled_dot_product_attention(query, key, value)
775
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
776
+
777
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
778
+ try:
779
+ print(
780
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
781
+ )
782
+ except RuntimeError:
783
+ print("EfficientAttention is not supported. See warnings for reasons.")
784
+ with profile(
785
+ activities=activities, record_shapes=False, profile_memory=True
786
+ ) as prof:
787
+ with record_function("EfficientAttention stats"):
788
+ for _ in range(25):
789
+ o = F.scaled_dot_product_attention(query, key, value)
790
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
791
+
792
+
793
+ def run_model(model, x, context):
794
+ return model(x, context)
795
+
796
+
797
+ def benchmark_transformer_blocks():
798
+ device = "cuda" if torch.cuda.is_available() else "cpu"
799
+ import torch.utils.benchmark as benchmark
800
+
801
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
802
+ t0 = benchmark.Timer(
803
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
804
+ )
805
+ return t0.blocked_autorange().mean * 1e6
806
+
807
+ checkpoint = True
808
+ compile = False
809
+
810
+ batch_size = 32
811
+ h, w = 64, 64
812
+ context_len = 77
813
+ embed_dimension = 1024
814
+ context_dim = 1024
815
+ d_head = 64
816
+
817
+ transformer_depth = 4
818
+
819
+ n_heads = embed_dimension // d_head
820
+
821
+ dtype = torch.float16
822
+
823
+ model_native = SpatialTransformer(
824
+ embed_dimension,
825
+ n_heads,
826
+ d_head,
827
+ context_dim=context_dim,
828
+ use_linear=True,
829
+ use_checkpoint=checkpoint,
830
+ attn_type="softmax",
831
+ depth=transformer_depth,
832
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
833
+ ).to(device)
834
+ model_efficient_attn = SpatialTransformer(
835
+ embed_dimension,
836
+ n_heads,
837
+ d_head,
838
+ context_dim=context_dim,
839
+ use_linear=True,
840
+ depth=transformer_depth,
841
+ use_checkpoint=checkpoint,
842
+ attn_type="softmax-xformers",
843
+ ).to(device)
844
+ if not checkpoint and compile:
845
+ print("compiling models")
846
+ model_native = torch.compile(model_native)
847
+ model_efficient_attn = torch.compile(model_efficient_attn)
848
+
849
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
850
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
851
+
852
+ from torch.profiler import ProfilerActivity, profile, record_function
853
+
854
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
855
+
856
+ with torch.autocast("cuda"):
857
+ print(
858
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
859
+ )
860
+ print(
861
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
862
+ )
863
+
864
+ print(75 * "+")
865
+ print("NATIVE")
866
+ print(75 * "+")
867
+ torch.cuda.reset_peak_memory_stats()
868
+ with profile(
869
+ activities=activities, record_shapes=False, profile_memory=True
870
+ ) as prof:
871
+ with record_function("NativeAttention stats"):
872
+ for _ in range(25):
873
+ model_native(x, c)
874
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
875
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
876
+
877
+ print(75 * "+")
878
+ print("Xformers")
879
+ print(75 * "+")
880
+ torch.cuda.reset_peak_memory_stats()
881
+ with profile(
882
+ activities=activities, record_shapes=False, profile_memory=True
883
+ ) as prof:
884
+ with record_function("xformers stats"):
885
+ for _ in range(25):
886
+ model_efficient_attn(x, c)
887
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
888
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
889
+
890
+
891
+ def test01():
892
+ # conv1x1 vs linear
893
+ from ..util import count_params
894
+
895
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
896
+ print(count_params(conv))
897
+ linear = torch.nn.Linear(3, 32).cuda()
898
+ print(count_params(linear))
899
+
900
+ print(conv.weight.shape)
901
+
902
+ # use same initialization
903
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
904
+ linear.bias = torch.nn.Parameter(conv.bias)
905
+
906
+ print(linear.weight.shape)
907
+
908
+ x = torch.randn(11, 3, 64, 64).cuda()
909
+
910
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
911
+ print(xr.shape)
912
+ out_linear = linear(xr)
913
+ print(out_linear.mean(), out_linear.shape)
914
+
915
+ out_conv = conv(x)
916
+ print(out_conv.mean(), out_conv.shape)
917
+ print("done with test01.\n")
918
+
919
+
920
+ def test02():
921
+ # try cosine flash attention
922
+ import time
923
+
924
+ torch.backends.cuda.matmul.allow_tf32 = True
925
+ torch.backends.cudnn.allow_tf32 = True
926
+ torch.backends.cudnn.benchmark = True
927
+ print("testing cosine flash attention...")
928
+ DIM = 1024
929
+ SEQLEN = 4096
930
+ BS = 16
931
+
932
+ print(" softmax (vanilla) first...")
933
+ model = BasicTransformerBlock(
934
+ dim=DIM,
935
+ n_heads=16,
936
+ d_head=64,
937
+ dropout=0.0,
938
+ context_dim=None,
939
+ attn_mode="softmax",
940
+ ).cuda()
941
+ try:
942
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
943
+ tic = time.time()
944
+ y = model(x)
945
+ toc = time.time()
946
+ print(y.shape, toc - tic)
947
+ except RuntimeError as e:
948
+ # likely oom
949
+ print(str(e))
950
+
951
+ print("\n now flash-cosine...")
952
+ model = BasicTransformerBlock(
953
+ dim=DIM,
954
+ n_heads=16,
955
+ d_head=64,
956
+ dropout=0.0,
957
+ context_dim=None,
958
+ attn_mode="flash-cosine",
959
+ ).cuda()
960
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
961
+ tic = time.time()
962
+ y = model(x)
963
+ toc = time.time()
964
+ print(y.shape, toc - tic)
965
+ print("done with test02.\n")
966
+
967
+
968
+ if __name__ == "__main__":
969
+ # test01()
970
+ # test02()
971
+ # test03()
972
+
973
+ # benchmark_attn()
974
+ benchmark_transformer_blocks()
975
+
976
+ print("done.")
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (161 Bytes).
 
sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes).
 
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+ from ....util import default, instantiate_from_config
11
+
12
+
13
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
14
+ if global_step < threshold:
15
+ weight = value
16
+ return weight
17
+
18
+
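# Editorial note: adopt_weight gates a loss weight until `global_step` reaches
# `threshold`, e.g. adopt_weight(1.0, global_step=100, threshold=500) == 0.0
# while adopt_weight(1.0, global_step=900, threshold=500) == 1.0. This is how
# the adversarial term below is switched on only after `disc_start` steps.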
19
+ class LatentLPIPS(nn.Module):
20
+ def __init__(
21
+ self,
22
+ decoder_config,
23
+ perceptual_weight=1.0,
24
+ latent_weight=1.0,
25
+ scale_input_to_tgt_size=False,
26
+ scale_tgt_to_input_size=False,
27
+ perceptual_weight_on_inputs=0.0,
28
+ ):
29
+ super().__init__()
30
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
31
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
32
+ self.init_decoder(decoder_config)
33
+ self.perceptual_loss = LPIPS().eval()
34
+ self.perceptual_weight = perceptual_weight
35
+ self.latent_weight = latent_weight
36
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
37
+
38
+ def init_decoder(self, config):
39
+ self.decoder = instantiate_from_config(config)
40
+ if hasattr(self.decoder, "encoder"):
41
+ del self.decoder.encoder
42
+
43
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
44
+ log = dict()
45
+ loss = (latent_inputs - latent_predictions) ** 2
46
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
47
+ image_reconstructions = None
48
+ if self.perceptual_weight > 0.0:
49
+ image_reconstructions = self.decoder.decode(latent_predictions)
50
+ image_targets = self.decoder.decode(latent_inputs)
51
+ perceptual_loss = self.perceptual_loss(
52
+ image_targets.contiguous(), image_reconstructions.contiguous()
53
+ )
54
+ loss = (
55
+ self.latent_weight * loss.mean()
56
+ + self.perceptual_weight * perceptual_loss.mean()
57
+ )
58
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
59
+
60
+ if self.perceptual_weight_on_inputs > 0.0:
61
+ image_reconstructions = default(
62
+ image_reconstructions, self.decoder.decode(latent_predictions)
63
+ )
64
+ if self.scale_input_to_tgt_size:
65
+ image_inputs = torch.nn.functional.interpolate(
66
+ image_inputs,
67
+ image_reconstructions.shape[2:],
68
+ mode="bicubic",
69
+ antialias=True,
70
+ )
71
+ elif self.scale_tgt_to_input_size:
72
+ image_reconstructions = torch.nn.functional.interpolate(
73
+ image_reconstructions,
74
+ image_inputs.shape[2:],
75
+ mode="bicubic",
76
+ antialias=True,
77
+ )
78
+
79
+ perceptual_loss2 = self.perceptual_loss(
80
+ image_inputs.contiguous(), image_reconstructions.contiguous()
81
+ )
82
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
83
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
84
+ return loss, log
85
+
86
+
87
+ class GeneralLPIPSWithDiscriminator(nn.Module):
88
+ def __init__(
89
+ self,
90
+ disc_start: int,
91
+ logvar_init: float = 0.0,
92
+ pixelloss_weight=1.0,
93
+ disc_num_layers: int = 3,
94
+ disc_in_channels: int = 3,
95
+ disc_factor: float = 1.0,
96
+ disc_weight: float = 1.0,
97
+ perceptual_weight: float = 1.0,
98
+ disc_loss: str = "hinge",
99
+ scale_input_to_tgt_size: bool = False,
100
+ dims: int = 2,
101
+ learn_logvar: bool = False,
102
+ regularization_weights: Union[None, dict] = None,
103
+ ):
104
+ super().__init__()
105
+ self.dims = dims
106
+ if self.dims > 2:
107
+ print(
108
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
109
+ f"the LPIPS loss will be applied to each frame independently. "
110
+ )
111
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
112
+ assert disc_loss in ["hinge", "vanilla"]
113
+ self.pixel_weight = pixelloss_weight
114
+ self.perceptual_loss = LPIPS().eval()
115
+ self.perceptual_weight = perceptual_weight
116
+ # output log variance
117
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
118
+ self.learn_logvar = learn_logvar
119
+
120
+ self.discriminator = NLayerDiscriminator(
121
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
122
+ ).apply(weights_init)
123
+ self.discriminator_iter_start = disc_start
124
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
125
+ self.disc_factor = disc_factor
126
+ self.discriminator_weight = disc_weight
127
+ self.regularization_weights = default(regularization_weights, {})
128
+
129
+ def get_trainable_parameters(self) -> Any:
130
+ return self.discriminator.parameters()
131
+
132
+ def get_trainable_autoencoder_parameters(self) -> Any:
133
+ if self.learn_logvar:
134
+ yield self.logvar
135
+ yield from ()
136
+
137
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
138
+ if last_layer is not None:
139
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
140
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
141
+ else:
142
+ nll_grads = torch.autograd.grad(
143
+ nll_loss, self.last_layer[0], retain_graph=True
144
+ )[0]
145
+ g_grads = torch.autograd.grad(
146
+ g_loss, self.last_layer[0], retain_graph=True
147
+ )[0]
148
+
149
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
150
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
151
+ d_weight = d_weight * self.discriminator_weight
152
+ return d_weight
153
+
154
+ def forward(
155
+ self,
156
+ regularization_log,
157
+ inputs,
158
+ reconstructions,
159
+ optimizer_idx,
160
+ global_step,
161
+ last_layer=None,
162
+ split="train",
163
+ weights=None,
164
+ ):
165
+ if self.scale_input_to_tgt_size:
166
+ inputs = torch.nn.functional.interpolate(
167
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
168
+ )
169
+
170
+ if self.dims > 2:
171
+ inputs, reconstructions = map(
172
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
173
+ (inputs, reconstructions),
174
+ )
175
+
176
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
177
+ if self.perceptual_weight > 0:
178
+ p_loss = self.perceptual_loss(
179
+ inputs.contiguous(), reconstructions.contiguous()
180
+ )
181
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
182
+
183
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
184
+ weighted_nll_loss = nll_loss
185
+ if weights is not None:
186
+ weighted_nll_loss = weights * nll_loss
187
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
188
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
189
+
190
+ # now the GAN part
191
+ if optimizer_idx == 0:
192
+ # generator update
193
+ logits_fake = self.discriminator(reconstructions.contiguous())
194
+ g_loss = -torch.mean(logits_fake)
195
+
196
+ if self.disc_factor > 0.0:
197
+ try:
198
+ d_weight = self.calculate_adaptive_weight(
199
+ nll_loss, g_loss, last_layer=last_layer
200
+ )
201
+ except RuntimeError:
202
+ assert not self.training
203
+ d_weight = torch.tensor(0.0)
204
+ else:
205
+ d_weight = torch.tensor(0.0)
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
211
+ log = dict()
212
+ for k in regularization_log:
213
+ if k in self.regularization_weights:
214
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
215
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
216
+
217
+ log.update(
218
+ {
219
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
220
+ "{}/logvar".format(split): self.logvar.detach(),
221
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
222
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
223
+ "{}/d_weight".format(split): d_weight.detach(),
224
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
225
+ "{}/g_loss".format(split): g_loss.detach().mean(),
226
+ }
227
+ )
228
+
229
+ return loss, log
230
+
231
+ if optimizer_idx == 1:
232
+ # second pass for discriminator update
233
+ logits_real = self.discriminator(inputs.contiguous().detach())
234
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
235
+
236
+ disc_factor = adopt_weight(
237
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
238
+ )
239
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
240
+
241
+ log = {
242
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
243
+ "{}/logits_real".format(split): logits_real.detach().mean(),
244
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
245
+ }
246
+ return d_loss, log
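# Usage sketch (editorial annotation, not part of the committed file): the loss
# is queried twice per training step, once per optimizer; `disc_start` delays
# the adversarial term. Values are examples only and assume the LPIPS weights
# can be loaded.
#
#     >>> loss_fn = GeneralLPIPSWithDiscriminator(disc_start=50001)
#     >>> inputs = torch.randn(2, 3, 256, 256)
#     >>> recons = torch.randn(2, 3, 256, 256, requires_grad=True)
#     >>> ae_loss, log = loss_fn({}, inputs, recons, optimizer_idx=0,
#     ...                        global_step=0, last_layer=recons)
#     >>> d_loss, log = loss_fn({}, inputs, recons, optimizer_idx=1, global_step=0)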