diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c56393c38e3c1ea80308920229f4fdafb015eab7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +**/__pycache__ +process.ipynb \ No newline at end of file diff --git a/README.md b/README.md index 4ddd6c820473ed1bf20813d45a9f07cef6ca55d6..f558d534cadaacf1bcf254062a5826d719766b7b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- title: UDiffText -emoji: 🐢 +emoji: 😋 colorFrom: purple colorTo: blue sdk: gradio diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bde72cd39038824baa3b9137634c6ae5c1604b --- /dev/null +++ b/app.py @@ -0,0 +1,207 @@ +import cv2 +import torch +import os, glob +import numpy as np +import gradio as gr +from PIL import Image +from omegaconf import OmegaConf +from contextlib import nullcontext +from pytorch_lightning import seed_everything +from os.path import join as ospj + +from util import * + + +def predict(cfgs, model, sampler, batch): + + context = nullcontext if cfgs.aae_enabled else torch.no_grad + + with context(): + + batch, batch_uc_1, batch_uc_2 = prepare_batch(cfgs, batch) + + if cfgs.dual_conditioner: + c, uc_1, uc_2 = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc_1=batch_uc_1, + batch_uc_2=batch_uc_2, + force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings, + ) + else: + c, uc_1 = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc_1, + force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings, + ) + + if cfgs.dual_conditioner: + x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2) + samples_z = sampler(model, x, cond=c, batch=batch, uc_1=uc_1, uc_2=uc_2, init_step=0, + aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed) + else: + x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1) + samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0, + aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed) + + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + return samples, samples_z + + +def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail): + + global cfgs, global_index + + global_index += 1 + + if num_samples > 1: cfgs.noise_iters = 0 + + cfgs.batch_size = num_samples + cfgs.steps = steps + cfgs.scale[0] = scale + cfgs.detailed = show_detail + seed_everything(seed) + + sampler = init_sampling(cfgs) + + image = input_blk["image"] + mask = input_blk["mask"] + image = cv2.resize(image, (cfgs.W, cfgs.H)) + mask = cv2.resize(mask, (cfgs.W, cfgs.H)) + + mask = (mask == 0).astype(np.int32) + + image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0 + mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32).mean(dim=0, keepdim=True) + masked = image * mask + mask = 1 - mask + + seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text)))) + + # additional cond + txt = f"\"{text}\"" + original_size_as_tuple = torch.tensor((cfgs.H, cfgs.W)) + crop_coords_top_left = torch.tensor((0, 0)) + target_size_as_tuple = torch.tensor((cfgs.H, cfgs.W)) + + image = torch.tile(image[None], (num_samples, 1, 1, 1)) + mask = torch.tile(mask[None], (num_samples, 1, 1, 1)) + masked = torch.tile(masked[None], (num_samples, 1, 1, 1)) + seg_mask = torch.tile(seg_mask[None], (num_samples, 1)) + original_size_as_tuple = torch.tile(original_size_as_tuple[None], 
(num_samples, 1)) + crop_coords_top_left = torch.tile(crop_coords_top_left[None], (num_samples, 1)) + target_size_as_tuple = torch.tile(target_size_as_tuple[None], (num_samples, 1)) + + text = [text for i in range(num_samples)] + txt = [txt for i in range(num_samples)] + name = [str(global_index) for i in range(num_samples)] + + batch = { + "image": image, + "mask": mask, + "masked": masked, + "seg_mask": seg_mask, + "label": text, + "txt": txt, + "original_size_as_tuple": original_size_as_tuple, + "crop_coords_top_left": crop_coords_top_left, + "target_size_as_tuple": target_size_as_tuple, + "name": name + } + + samples, samples_z = predict(cfgs, model, sampler, batch) + samples = samples.cpu().numpy().transpose(0, 2, 3, 1) * 255 + results = [Image.fromarray(sample.astype(np.uint8)) for sample in samples] + + if cfgs.detailed: + sections = [] + attn_map = Image.open(f"./temp/attn_map/attn_map_{global_index}.png") + seg_maps = np.load(f"./temp/seg_map/seg_{global_index}.npy") + for i, seg_map in enumerate(seg_maps): + seg_map = cv2.resize(seg_map, (cfgs.W, cfgs.H)) + sections.append((seg_map, text[0][i])) + seg = (results[0], sections) + else: + attn_map = None + seg = None + + return results, attn_map, seg + + +if __name__ == "__main__": + + cfgs = OmegaConf.load("./configs/demo.yaml") + + model = init_model(cfgs) + global_index = 0 + + block = gr.Blocks().queue() + with block: + + with gr.Row(): + + gr.HTML( + """ +
+

+ UDiffText: A Unified Framework for High-quality Text Synthesis in Arbitrary Images via Character-aware Diffusion Models +

+

+ [arXiv] + [Code] + [ProjectPage] +

+

+ Our proposed UDiffText is capable of synthesizing accurate and harmonious text in either synthetic or real-world images, and can thus be applied to tasks such as scene text editing (a), arbitrary text generation (b), and accurate T2I generation (c)

+
UDiffText
+
+ """ + ) + + with gr.Row(): + + with gr.Column(): + + input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512) + text = gr.Textbox(label="Text to render:", info="the text you want to render at the masked region") + run_button = gr.Button(variant="primary") + + with gr.Accordion("Advanced options", open=False): + + num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1) + steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1) + scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=4.0, step=0.1) + seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True) + show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=True) + + with gr.Column(): + + gallery = gr.Gallery(label="Output", height=512, preview=True) + + with gr.Accordion("Visualization results", open=True): + + with gr.Tab(label="Attention Maps"): + gr.Markdown("### Attention maps for each character (extracted from middle blocks at intermediate sampling step):") + attn_map = gr.Image(show_label=False, show_download_button=False) + with gr.Tab(label="Segmentation Maps"): + gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):") + seg_map = gr.AnnotatedImage(height=384, show_label=False, show_download_button=False) + + # examples + examples = [] + example_paths = sorted(glob.glob(ospj("./demo/examples", "*"))) + for example_path in example_paths: + label = example_path.split(os.sep)[-1].split(".")[0].split("_")[0] + examples.append([example_path, label]) + + gr.Markdown("## Examples:") + gr.Examples( + examples=examples, + inputs=[input_blk, text] + ) + + run_button.click(fn=demo_predict, inputs=[input_blk, text, num_samples, steps, scale, seed, show_detail], outputs=[gallery, attn_map, seg_map]) + + block.launch() \ No newline at end of file diff --git a/checkpoints/AEs/AE_inpainting_2.safetensors b/checkpoints/AEs/AE_inpainting_2.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..8cfdd04ec7695b20987d5323959d7d09b958f626 --- /dev/null +++ b/checkpoints/AEs/AE_inpainting_2.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:547baac83984f8bf8b433882236b87e77eb4d2f5c71e3d7a04b8dec2fe02b81f +size 334640988 diff --git a/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt b/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..ce122abfa90fe94f50cf2db8b7a9aafc3629bb2f --- /dev/null +++ b/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4076c90467a907dcb8cde15776bfda4473010fe845739490341db74e82cd2267 +size 4059026213 diff --git a/checkpoints/st-step=100000+la-step=100000-simp.ckpt b/checkpoints/st-step=100000+la-step=100000-simp.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..d504886ecbfe9f294cdac9e448b8ca4910e401f7 --- /dev/null +++ b/checkpoints/st-step=100000+la-step=100000-simp.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:968397df8910f3324d94ce3df7e9d70f1bf2415a46d22edef1a510885ee0648e +size 2558065830 diff --git a/configs/demo.yaml b/configs/demo.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..ba306ebf85353177abff625b16442a6a10d9c6ec --- /dev/null +++ b/configs/demo.yaml @@ -0,0 +1,29 @@ +type: "demo" + +# path +load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-simp.ckpt" +model_cfg_path: "./configs/test/textdesign_sd_2.yaml" + +# param +H: 512 +W: 512 +seq_len: 12 +batch_size: 1 + +channel: 4 # AE latent channel +factor: 8 # AE downsample factor +scale: [4.0, 0.0] # content scale, style scale +noise_iters: 10 +force_uc_zero_embeddings: ["ref", "label"] +aae_enabled: False +detailed: True +dual_conditioner: False + + +# runtime +steps: 50 +init_step: 0 +num_workers: 0 +gpu: 0 +max_iter: 100 + diff --git a/configs/test/textdesign_sd_2.yaml b/configs/test/textdesign_sd_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..243f723f9182825cb382e23d2dbc0ea4221e5a54 --- /dev/null +++ b/configs/test/textdesign_sd_2.yaml @@ -0,0 +1,137 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + input_key: image + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel + params: + use_checkpoint: False + in_channels: 9 + out_channels: 4 + ctrl_channels: 0 + model_channels: 320 + attention_resolutions: [4, 2, 1] + attn_type: add_attn + attn_layers: + - output_blocks.6.1 + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 0 + add_context_dim: 2048 + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + # - is_trainable: False + # input_key: txt + # target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder + # params: + # arch: ViT-H-14 + # version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin + # layer: penultimate + # add crossattn cond + - is_trainable: False + input_key: label + target: sgm.modules.encoders.modules.LabelEncoder + params: + is_add_embedder: True + max_len: 12 + emb_dim: 2048 + n_heads: 8 + n_trans_layers: 12 + ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt + # concat cond + - is_trainable: False + input_key: mask + target: sgm.modules.encoders.modules.IdentityEncoder + - is_trainable: False + input_key: masked + target: sgm.modules.encoders.modules.LatentEncoder + params: + scale_factor: 0.18215 + config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + 
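+        # 4-channel latent at 1/8 spatial resolution (consistent with channel: 4 and factor: 8 in configs/demo.yaml)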
z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss + params: + seq_len: 12 + kernel_size: 3 + gaussian_sigma: 0.5 + min_attn_size: 16 + lambda_local_loss: 0.02 + lambda_ocr_loss: 0.001 + ocr_enabled: False + + predictor_config: + target: sgm.modules.predictors.model.ParseqPredictor + params: + ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt" + + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + num_idx: 1000 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization \ No newline at end of file diff --git a/demo/examples/CEFUL_1_0.jpeg b/demo/examples/CEFUL_1_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..7db55e1b9e4c00a166a2628156c4975a0bb7e5f9 --- /dev/null +++ b/demo/examples/CEFUL_1_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d90a580083194c2130da6fd0176df3fde40b312f13f00b34b7ac6641e4ff1597 +size 113124 diff --git a/demo/examples/CLOTHES_0_0.png b/demo/examples/CLOTHES_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b2c2118c2ad3738a44569a5d03eb783df31ac1e2 --- /dev/null +++ b/demo/examples/CLOTHES_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a7374b07e520fe86c4b0b587082f125dc826542caf5d9d1c08107bb1cfe0154 +size 330559 diff --git a/demo/examples/COMPLICATED_0_1.jpeg b/demo/examples/COMPLICATED_0_1.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..fa49a64a00e234f20ad7f1f40b5ea45db3deeb33 --- /dev/null +++ b/demo/examples/COMPLICATED_0_1.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98ba496f8289dda423bf5d9d60493e599df61eb5d6de75f8b966786909c3a5ab +size 207750 diff --git a/demo/examples/DELIGHT_0_1.jpeg b/demo/examples/DELIGHT_0_1.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..e59c6ffa31312d81a8983b389996063ee625987d --- /dev/null +++ b/demo/examples/DELIGHT_0_1.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:899b7388f742c28317c85944f46a181c70c8d89ce22229838fa9d2afbdae495a +size 343094 diff --git a/demo/examples/ECHOES_0_0.jpeg b/demo/examples/ECHOES_0_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..a782573f655bc4d0980b96d8a7a555604b3a1d52 --- /dev/null +++ b/demo/examples/ECHOES_0_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ade75487cac60d88684e41b8ddfcd492f386f22eeb8bb67f1d99a2514803477 +size 285925 diff --git a/demo/examples/ENGINE_0_0.png b/demo/examples/ENGINE_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..bf7b9e9292385d7c5e6a895a83b92b2c198a2831 --- /dev/null +++ b/demo/examples/ENGINE_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd1fd33cded3a9c8245a38cd82e0603e2f583dbe7b415dd13ed20cdec08e94b0 +size 577987 diff --git a/demo/examples/FASCINATING_0_1.jpeg b/demo/examples/FASCINATING_0_1.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..5968d79016425cf49566d58dcf320884321e2f98 --- /dev/null +++ b/demo/examples/FASCINATING_0_1.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:af5ea76ba8c5827f9ec83bc2b6a096511bc619975db5fc8f6e742ba1bb687570 +size 311074 diff --git a/demo/examples/FAVOURITE_0_0.jpeg b/demo/examples/FAVOURITE_0_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..cdc28a7becd458a8adde2687d281953ec31e9f43 --- /dev/null +++ b/demo/examples/FAVOURITE_0_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38747d02015147fa4f1eb3ebca5a3757d908957cf8caf7de8b33d5a1750d6ada +size 130413 diff --git a/demo/examples/FINNAL_0_1.jpeg b/demo/examples/FINNAL_0_1.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..0320eff1898bff3235123235fba94841aa7931b6 --- /dev/null +++ b/demo/examples/FINNAL_0_1.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6dc56ca1ba9a1fc5e6899a9629a06d843ca10a96f5eca095c0cf1af9e38191a +size 175671 diff --git a/demo/examples/FRONTIER_0_0.png b/demo/examples/FRONTIER_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b517c4db4c9d2b6313ae603345946edf382694b3 --- /dev/null +++ b/demo/examples/FRONTIER_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c0231c43a100dcf5f95a3f79c9fcbe77b345e70f87273ec395f7ff857716483c +size 437454 diff --git a/demo/examples/Innovate_0_0.jpeg b/demo/examples/Innovate_0_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..201be26bfce3b0d694b674423c632d6fe1425a4e --- /dev/null +++ b/demo/examples/Innovate_0_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a74d91e607bceafe0ea45858dec08949eb93597cbed45a7bf194cf476d118b03 +size 185887 diff --git a/demo/examples/PRESERVE_0_0.jpeg b/demo/examples/PRESERVE_0_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..337b5a53f709dde0f5f16065cc1ee9b31a22d140 --- /dev/null +++ b/demo/examples/PRESERVE_0_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5608fe1a2ccd04f18ba8a24172bd484bf2a628e21756fb5121283cad9618f60 +size 295856 diff --git a/demo/examples/Peaceful_0_0.jpeg b/demo/examples/Peaceful_0_0.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..d9267a92f353ced685789d570239ecf497406b25 --- /dev/null +++ b/demo/examples/Peaceful_0_0.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40e3adca8425b26c41f64ff62a29f299d172103989fbbe82c77f41875af9c86d +size 93331 diff --git a/demo/examples/Scamps_0_0.png b/demo/examples/Scamps_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..5ffa035aec7647e1722eb328add7442d7cb86282 --- /dev/null +++ b/demo/examples/Scamps_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fa97107ac42733873b451efa06b7ba2fefbfb182f905084e0fd2f511ec8a251 +size 267489 diff --git a/demo/examples/TREE_0_0.png b/demo/examples/TREE_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..ede24288e7e0e5438b7b2d395c3ad8cc2924749c --- /dev/null +++ b/demo/examples/TREE_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76e3f78050bd19efd0247befb40a5fc56a6d3067324606a9800cfdd91a6c142d +size 384257 diff --git a/demo/examples/better_0_0.jpg b/demo/examples/better_0_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a929f89c9dbbc0b27a61753f8db9ae46c1daf048 --- /dev/null +++ b/demo/examples/better_0_0.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6473b82056e41fa74594e89c07c92640c375bdf568e3e9a5f296c9ec8c749145 
+size 200512 diff --git a/demo/examples/tested_0_0.png b/demo/examples/tested_0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..bb823fc3b22a1350790e73097e33bf5198fdcf84 --- /dev/null +++ b/demo/examples/tested_0_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a3e38e5f1c63b1db4ce6d961aebf8f793ba19554f390b263340269d21b0d84a +size 305045 diff --git a/demo/teaser.png b/demo/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..a36a109cfd785b093187cf75c453aacc7cbb04ab --- /dev/null +++ b/demo/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcd166cc9691c99a7ee93a028ab485472171ee348a5f4dbaf82f6bf1fb27c66d +size 2623749 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a0de5086f34d85701bc3b6e517f59103b91646e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +colorlover==0.3.0 +gradio==3.41.0 +imageio==2.31.2 +img2dataset==1.42.0 +lpips==0.1.4 +matplotlib==3.7.2 +numpy==1.25.1 +omegaconf==2.3.0 +open-clip-torch==2.20.0 +opencv-python==4.6.0.66 +Pillow==9.5.0 +pytorch-fid==0.3.0 +pytorch-lightning==2.0.1 +safetensors==0.3.1 +scikit-learn==1.3.0 +scipy==1.11.1 +seaborn==0.12.2 +tensorboard==2.14.0 +tokenizers==0.13.3 +torch==2.1.0 +torchvision==0.16.0 +tqdm==4.65.0 +transformers==4.30.2 + diff --git a/sgm/__init__.py b/sgm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e273bdaa90e0ff822a6098dd531046b90a12ce3e --- /dev/null +++ b/sgm/__init__.py @@ -0,0 +1,2 @@ +from .models import AutoencodingEngine, DiffusionEngine +from .util import instantiate_from_config diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42 --- /dev/null +++ b/sgm/lr_scheduler.py @@ -0,0 +1,135 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = ( + self.lr_max - self.lr_start + ) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / ( + self.lr_max_decay_steps - self.lr_warm_up_steps + ) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi) + ) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. 
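+    each cycle warms up linearly from f_start to f_max over warm_up_steps, then decays from f_max to f_min along a cosine curve for the remainder of the cycle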
+ """ + + def __init__( + self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 + ): + assert ( + len(warm_up_steps) + == len(f_min) + == len(f_max) + == len(f_start) + == len(cycle_lengths) + ) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi) + ) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n + ) / (self.cycle_lengths[cycle]) + self.last_f = f + return f diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c410b3747afc208e4204c8f140170e0a7808eace --- /dev/null +++ b/sgm/models/__init__.py @@ -0,0 +1,2 @@ +from .autoencoder import AutoencodingEngine +from .diffusion import DiffusionEngine diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..83c6863df153c147b343978ef37c201691abbe23 --- /dev/null +++ b/sgm/models/autoencoder.py @@ -0,0 +1,335 @@ +import re +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Tuple, Union + +import pytorch_lightning as pl +import torch +from omegaconf import ListConfig +from packaging import version +from safetensors.torch import load_file as load_safetensors + +from ..modules.diffusionmodules.model import Decoder, Encoder +from ..modules.distributions.distributions import DiagonalGaussianDistribution +from ..modules.ema import LitEma +from ..util import default, get_obj_from_str, instantiate_from_config + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. 
discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list, ListConfig] = (), + ): + super().__init__() + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def init_from_ckpt( + self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple() + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if re.match(ik, k): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + print(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
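+    Training alternates two optimizers: optimizer_idx 0 updates the autoencoder (encoder, decoder, regularizer), while optimizer_idx 1 updates the discriminator owned by the loss.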
+ """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + # todo: add options to freeze encoder/decoder + self.encoder = instantiate_from_config(encoder_config) + self.decoder = instantiate_from_config(decoder_config) + self.loss = instantiate_from_config(loss_config) + self.regularization = instantiate_from_config(regularizer_config) + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.Adam"} + ) + self.lr_g_factor = lr_g_factor + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.regularization.get_trainable_parameters()) + + list(self.loss.get_trainable_autoencoder_parameters()) + ) + return params + + def get_discriminator_params(self) -> list: + params = list(self.loss.get_trainable_parameters()) # e.g., discriminator + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode(self, x: Any, return_reg_log: bool = False) -> Any: + z = self.encoder(x) + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: Any) -> torch.Tensor: + x = self.decoder(z) + return x + + def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z) + return z, dec, reg_log + + def training_step(self, batch, batch_idx, optimizer_idx) -> Any: + x = self.get_input(batch) + z, xrec, regularization_log = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss( + regularization_log, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss( + regularization_log, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return discloss + + def validation_step(self, batch, batch_idx) -> Dict: + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + log_dict.update(log_dict_ema) + return log_dict + + def _validation_step(self, batch, batch_idx, postfix="") -> Dict: + x = self.get_input(batch) + + z, xrec, regularization_log = self(x) + aeloss, log_dict_ae = self.loss( + regularization_log, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + regularization_log, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + log_dict_ae.update(log_dict_disc) + self.log_dict(log_dict_ae) + return log_dict_ae + + def configure_optimizers(self) -> Any: + ae_params 
= self.get_autoencoder_params() + disc_params = self.get_discriminator_params() + + opt_ae = self.instantiate_optimizer_from_config( + ae_params, + default(self.lr_g_factor, 1.0) * self.learning_rate, + self.optimizer_config, + ) + opt_disc = self.instantiate_optimizer_from_config( + disc_params, self.learning_rate, self.optimizer_config + ) + + return [opt_ae, opt_disc], [] + + @torch.no_grad() + def log_images(self, batch: Dict, **kwargs) -> Dict: + log = dict() + x = self.get_input(batch) + _, xrec, _ = self(x) + log["inputs"] = x + log["reconstructions"] = xrec + with self.ema_scope(): + _, xrec_ema, _ = self(x) + log["reconstructions_ema"] = xrec_ema + return log + + +class AutoencoderKL(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", ()) + super().__init__( + encoder_config={"target": "torch.nn.Identity"}, + decoder_config={"target": "torch.nn.Identity"}, + regularizer_config={"target": "torch.nn.Identity"}, + loss_config=kwargs.pop("lossconfig"), + **kwargs, + ) + assert ddconfig["double_z"] + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def encode(self, x): + assert ( + not self.training + ), f"{self.__class__.__name__} only supports inference currently" + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, **decoder_kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z, **decoder_kwargs) + return dec + + +class AutoencoderKLInferenceWrapper(AutoencoderKL): + def encode(self, x): + return super().encode(x).sample() + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..1c72ed364de456b081d85939c94bae3e53ec7128 --- /dev/null +++ b/sgm/models/diffusion.py @@ -0,0 +1,328 @@ +from contextlib import contextmanager +from typing import Any, Dict, List, Tuple, Union + +import pytorch_lightning as pl +import torch +from omegaconf import ListConfig, OmegaConf +from safetensors.torch import load_file as load_safetensors +from torch.optim.lr_scheduler import LambdaLR + +from ..modules import UNCONDITIONAL_CONFIG +from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from ..modules.ema import LitEma +from ..util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, +) + + +class DiffusionEngine(pl.LightningModule): + def __init__( + self, + network_config, + denoiser_config, + first_stage_config, + conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, + sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, + scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + loss_fn_config: Union[None, Dict, ListConfig, 
OmegaConf] = None, + network_wrapper: Union[None, str] = None, + ckpt_path: Union[None, str] = None, + use_ema: bool = False, + ema_decay_rate: float = 0.9999, + scale_factor: float = 1.0, + disable_first_stage_autocast=False, + input_key: str = "jpg", + log_keys: Union[List, None] = None, + no_cond_log: bool = False, + compile_model: bool = False, + opt_keys: Union[List, None] = None + ): + super().__init__() + self.opt_keys = opt_keys + self.log_keys = log_keys + self.input_key = input_key + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = ( + instantiate_from_config(sampler_config) + if sampler_config is not None + else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, UNCONDITIONAL_CONFIG) + ) + self.scheduler_config = scheduler_config + self._init_first_stage(first_stage_config) + + self.loss_fn = ( + instantiate_from_config(loss_fn_config) + if loss_fn_config is not None + else None + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model, decay=ema_decay_rate) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = no_cond_log + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def init_from_ckpt( + self, + path: str, + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def freeze(self): + + for param in self.parameters(): + param.requires_grad_(False) + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def get_input(self, batch): + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in bchw format + return batch[self.input_key] + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + out = self.first_stage_model.decode(z) + return out + + @torch.no_grad() + def encode_first_stage(self, x): + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + z = self.first_stage_model.encode(x) + z = self.scale_factor * z + return z + + def forward(self, x, batch): + + loss, loss_dict = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, self.first_stage_model, self.scale_factor) + + return loss, loss_dict + + def shared_step(self, batch: Dict) -> Any: + x = self.get_input(batch) + x = self.encode_first_stage(x) + batch["global_step"] = self.global_step + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + self.log( + "global_step", + float(self.global_step), + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + def on_train_start(self, *args, **kwargs): + if self.sampler is None or self.loss_fn is None: + raise ValueError("Sampler and loss function need to be set for training.") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self): + lr = self.learning_rate + params = [] + print("Trainable parameter list: ") + print("-"*20) + for name, param in self.model.named_parameters(): + if any([key in name for key in self.opt_keys]): + params.append(param) + print(name) + else: + param.requires_grad_(False) + for embedder in self.conditioner.embedders: + if embedder.is_trainable: + for name, param in embedder.named_parameters(): + params.append(param) + print(name) + print("-"*20) + opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) + scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) + + return [opt], scheduler + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(self.device) + + denoiser = lambda input, sigma, c: self.denoiser( + self.model, input, sigma, c, **kwargs + ) + samples = self.sampler(denoiser, randn, cond, uc=uc) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... 
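+        Integer and size/crop tensors, as well as lists of strings, are rendered as text images via log_txt_as_img.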
+ """ + image_h, image_w = batch[self.input_key].shape[2:] + log = dict() + + for embedder in self.conditioner.embedders: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = [ + "x".join([str(xx) for xx in x[i].tolist()]) + for i in range(x.shape[0]) + ] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + # strings + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + return log + + @torch.no_grad() + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. {conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys + if len(self.conditioner.embedders) > 0 + else [], + ) + + sampling_kwargs = {} + + N = min(x.shape[0], N) + x = x.to(self.device)[:N] + log["inputs"] = x + z = self.encode_first_stage(x) + log["reconstructions"] = self.decode_first_stage(z) + log.update(self.log_conditionings(batch, N)) + + for k in c: + if isinstance(c[k], torch.Tensor): + c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) + + if sample: + with self.ema_scope("Plotting"): + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) + samples = self.decode_first_stage(samples) + log["samples"] = samples + return log diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a7e034952f8feb22a035dbfd4872025a9727b9 --- /dev/null +++ b/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders.modules import GeneralConditioner, DualConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..791d5f3ca5016b224977831b28a6733d06b1ff88 --- /dev/null +++ b/sgm/modules/attention.py @@ -0,0 +1,976 @@ +import math +from inspect import isfunction +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn, einsum + + +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True + from torch.backends.cuda import SDPBackend, sdp_kernel + + BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + 
}, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, + } +else: + from contextlib import nullcontext + + SDP_IS_AVAILABLE = False + sdp_kernel = nullcontext + BACKEND_MAP = {} + print( + f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " + f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." + ) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + print("no module 'xformers'. Processing without...") + +from .diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
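+    Used in this module to zero-initialize output projections (e.g. CrossAttention.to_out and SpatialTransformer.proj_out) so that newly added branches start as an identity mapping in their residual connections.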
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = zero_module(nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + )) + self.backend = backend + + self.attn_map_cache = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], "b ... 
-> (b n) ...", n=n_cp + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp + ) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + ## old + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + # save attn_map + if self.attn_map_cache is not None: + bh, n, l = sim.shape + size = int(n**0.5) + self.attn_map_cache["size"] = size + self.attn_map_cache["attn_map"] = sim + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + + ## new + # with sdp_kernel(**BACKEND_MAP[self.backend]): + # # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) + # out = F.scaled_dot_product_attention( + # q, k, v, attn_mask=mask + # ) # scale is dim_head ** -0.5 per default + + # del q, k, v + # out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs + ): + super().__init__() + # print( + # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + # f"{heads} heads with a dimension of {dim_head}." + # ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + # n_cp = x.shape[0]//n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], + "b ... 
-> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # ampere + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + add_context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: + print( + f"Attention mode '{attn_mode}' is not available. Falling back to native attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_mode = "softmax" + elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: + print( + "We do not support vanilla attention anymore, as it is too expensive. Sorry." + ) + if not XFORMERS_IS_AVAILABLE: + assert ( + False + ), "Please install xformers via e.g. 
'pip install xformers==0.0.16'" + else: + print("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" + attn_cls = self.ATTENTION_MODES[attn_mode] + if version.parse(torch.__version__) >= version.parse("2.0.0"): + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) + else: + assert sdp_backend is None + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + if context_dim is not None and context_dim > 0: + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + if add_context_dim is not None and add_context_dim > 0: + self.add_attn = attn_cls( + query_dim=dim, + context_dim=add_context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.add_norm = nn.LayerNorm(dim) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward( + self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update( + {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} + ) + + return checkpoint( + self._forward, (x, context, add_context), self.parameters(), self.checkpoint + ) + + def _forward( + self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self + if not self.disable_self_attn + else 0, + ) + + x + ) + if hasattr(self, "attn2"): + x = ( + self.attn2( + self.norm2(x), context=context, additional_tokens=additional_tokens + ) + + x + ) + if hasattr(self, "add_attn"): + x = ( + self.add_attn( + self.add_norm(x), context=add_context, additional_tokens=additional_tokens + ) + + x + ) + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version + # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + 
self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + add_context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + # print( + # f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" + # ) + from omegaconf import ListConfig + + if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + # print( + # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." + # ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + add_context_dim=add_context_dim, + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None, add_context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i], add_context=add_context) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +def benchmark_attn(): + # Lets define a helpful benchmarking function: + # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html + device = "cuda" if 
torch.cuda.is_available() else "cpu" + import torch.nn.functional as F + import torch.utils.benchmark as benchmark + + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + return t0.blocked_autorange().mean * 1e6 + + # Lets define the hyper-parameters of our input + batch_size = 32 + max_sequence_len = 1024 + num_heads = 32 + embed_dimension = 32 + + dtype = torch.float16 + + query = torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + key = torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + value = torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + + print(f"q/k/v shape:", query.shape, key.shape, value.shape) + + # Lets explore the speed of each of the 3 implementations + from torch.backends.cuda import SDPBackend, sdp_kernel + + # Helpful arguments mapper + backend_map = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + } + + from torch.profiler import ProfilerActivity, profile, record_function + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + print( + f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("Default detailed stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + print( + f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + with sdp_kernel(**backend_map[SDPBackend.MATH]): + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("Math implmentation stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): + try: + print( + f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + except RuntimeError: + print("FlashAttention is not supported. See warnings for reasons.") + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("FlashAttention stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): + try: + print( + f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + except RuntimeError: + print("EfficientAttention is not supported. 
See warnings for reasons.") + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("EfficientAttention stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def run_model(model, x, context): + return model(x, context) + + +def benchmark_transformer_blocks(): + device = "cuda" if torch.cuda.is_available() else "cpu" + import torch.utils.benchmark as benchmark + + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + return t0.blocked_autorange().mean * 1e6 + + checkpoint = True + compile = False + + batch_size = 32 + h, w = 64, 64 + context_len = 77 + embed_dimension = 1024 + context_dim = 1024 + d_head = 64 + + transformer_depth = 4 + + n_heads = embed_dimension // d_head + + dtype = torch.float16 + + model_native = SpatialTransformer( + embed_dimension, + n_heads, + d_head, + context_dim=context_dim, + use_linear=True, + use_checkpoint=checkpoint, + attn_type="softmax", + depth=transformer_depth, + sdp_backend=SDPBackend.FLASH_ATTENTION, + ).to(device) + model_efficient_attn = SpatialTransformer( + embed_dimension, + n_heads, + d_head, + context_dim=context_dim, + use_linear=True, + depth=transformer_depth, + use_checkpoint=checkpoint, + attn_type="softmax-xformers", + ).to(device) + if not checkpoint and compile: + print("compiling models") + model_native = torch.compile(model_native) + model_efficient_attn = torch.compile(model_efficient_attn) + + x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype) + c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype) + + from torch.profiler import ProfilerActivity, profile, record_function + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + with torch.autocast("cuda"): + print( + f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds" + ) + print( + f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds" + ) + + print(75 * "+") + print("NATIVE") + print(75 * "+") + torch.cuda.reset_peak_memory_stats() + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("NativeAttention stats"): + for _ in range(25): + model_native(x, c) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block") + + print(75 * "+") + print("Xformers") + print(75 * "+") + torch.cuda.reset_peak_memory_stats() + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("xformers stats"): + for _ in range(25): + model_efficient_attn(x, c) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block") + + +def test01(): + # conv1x1 vs linear + from ..util import count_params + + conv = nn.Conv2d(3, 32, kernel_size=1).cuda() + print(count_params(conv)) + linear = torch.nn.Linear(3, 32).cuda() + print(count_params(linear)) + + print(conv.weight.shape) + + # use same initialization + linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1)) + linear.bias = 
torch.nn.Parameter(conv.bias) + + print(linear.weight.shape) + + x = torch.randn(11, 3, 64, 64).cuda() + + xr = rearrange(x, "b c h w -> b (h w) c").contiguous() + print(xr.shape) + out_linear = linear(xr) + print(out_linear.mean(), out_linear.shape) + + out_conv = conv(x) + print(out_conv.mean(), out_conv.shape) + print("done with test01.\n") + + +def test02(): + # try cosine flash attention + import time + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + print("testing cosine flash attention...") + DIM = 1024 + SEQLEN = 4096 + BS = 16 + + print(" softmax (vanilla) first...") + model = BasicTransformerBlock( + dim=DIM, + n_heads=16, + d_head=64, + dropout=0.0, + context_dim=None, + attn_mode="softmax", + ).cuda() + try: + x = torch.randn(BS, SEQLEN, DIM).cuda() + tic = time.time() + y = model(x) + toc = time.time() + print(y.shape, toc - tic) + except RuntimeError as e: + # likely oom + print(str(e)) + + print("\n now flash-cosine...") + model = BasicTransformerBlock( + dim=DIM, + n_heads=16, + d_head=64, + dropout=0.0, + context_dim=None, + attn_mode="flash-cosine", + ).cuda() + x = torch.randn(BS, SEQLEN, DIM).cuda() + tic = time.time() + y = model(x) + toc = time.time() + print(y.shape, toc - tic) + print("done with test02.\n") + + +if __name__ == "__main__": + # test01() + # test02() + # test03() + + # benchmark_attn() + benchmark_transformer_blocks() + + print("done.") diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a3b54f7284ae1be6a23b425f6c296efc1881a5c --- /dev/null +++ b/sgm/modules/autoencoding/losses/__init__.py @@ -0,0 +1,246 @@ +from typing import Any, Union + +import torch +import torch.nn as nn +from einops import rearrange +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + +from ....util import default, instantiate_from_config + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +class LatentLPIPS(nn.Module): + def __init__( + self, + decoder_config, + perceptual_weight=1.0, + latent_weight=1.0, + scale_input_to_tgt_size=False, + scale_tgt_to_input_size=False, + perceptual_weight_on_inputs=0.0, + ): + super().__init__() + self.scale_input_to_tgt_size = scale_input_to_tgt_size + self.scale_tgt_to_input_size = scale_tgt_to_input_size + self.init_decoder(decoder_config) + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.latent_weight = latent_weight + self.perceptual_weight_on_inputs = perceptual_weight_on_inputs + + def init_decoder(self, config): + self.decoder = instantiate_from_config(config) + if hasattr(self.decoder, "encoder"): + del self.decoder.encoder + + def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): + log = dict() + loss = (latent_inputs - latent_predictions) ** 2 + log[f"{split}/latent_l2_loss"] = loss.mean().detach() + image_reconstructions = None + if self.perceptual_weight > 0.0: + image_reconstructions = self.decoder.decode(latent_predictions) + 
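+            # Both the target latents and the predicted latents are decoded back to
+            # image space so that LPIPS, an image-space perceptual metric, can compare
+            # them; the returned loss mixes latent_weight * L2 on the latents with
+            # perceptual_weight * LPIPS on the decoded images.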
image_targets = self.decoder.decode(latent_inputs) + perceptual_loss = self.perceptual_loss( + image_targets.contiguous(), image_reconstructions.contiguous() + ) + loss = ( + self.latent_weight * loss.mean() + + self.perceptual_weight * perceptual_loss.mean() + ) + log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() + + if self.perceptual_weight_on_inputs > 0.0: + image_reconstructions = default( + image_reconstructions, self.decoder.decode(latent_predictions) + ) + if self.scale_input_to_tgt_size: + image_inputs = torch.nn.functional.interpolate( + image_inputs, + image_reconstructions.shape[2:], + mode="bicubic", + antialias=True, + ) + elif self.scale_tgt_to_input_size: + image_reconstructions = torch.nn.functional.interpolate( + image_reconstructions, + image_inputs.shape[2:], + mode="bicubic", + antialias=True, + ) + + perceptual_loss2 = self.perceptual_loss( + image_inputs.contiguous(), image_reconstructions.contiguous() + ) + loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() + log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() + return loss, log + + +class GeneralLPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start: int, + logvar_init: float = 0.0, + pixelloss_weight=1.0, + disc_num_layers: int = 3, + disc_in_channels: int = 3, + disc_factor: float = 1.0, + disc_weight: float = 1.0, + perceptual_weight: float = 1.0, + disc_loss: str = "hinge", + scale_input_to_tgt_size: bool = False, + dims: int = 2, + learn_logvar: bool = False, + regularization_weights: Union[None, dict] = None, + ): + super().__init__() + self.dims = dims + if self.dims > 2: + print( + f"running with dims={dims}. This means that for perceptual loss calculation, " + f"the LPIPS loss will be applied to each frame independently. 
" + ) + self.scale_input_to_tgt_size = scale_input_to_tgt_size + assert disc_loss in ["hinge", "vanilla"] + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + self.learn_logvar = learn_logvar + + self.discriminator = NLayerDiscriminator( + input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.regularization_weights = default(regularization_weights, {}) + + def get_trainable_parameters(self) -> Any: + return self.discriminator.parameters() + + def get_trainable_autoencoder_parameters(self) -> Any: + if self.learn_logvar: + yield self.logvar + yield from () + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad( + nll_loss, self.last_layer[0], retain_graph=True + )[0] + g_grads = torch.autograd.grad( + g_loss, self.last_layer[0], retain_graph=True + )[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + regularization_log, + inputs, + reconstructions, + optimizer_idx, + global_step, + last_layer=None, + split="train", + weights=None, + ): + if self.scale_input_to_tgt_size: + inputs = torch.nn.functional.interpolate( + inputs, reconstructions.shape[2:], mode="bicubic", antialias=True + ) + + if self.dims > 2: + inputs, reconstructions = map( + lambda x: rearrange(x, "b c t h w -> (b t) c h w"), + (inputs, reconstructions), + ) + + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + loss = weighted_nll_loss + d_weight * disc_factor * g_loss + log = dict() + for k in regularization_log: + if k in self.regularization_weights: + loss = loss + self.regularization_weights[k] * regularization_log[k] + log[f"{split}/{k}"] = regularization_log[k].detach().mean() + + log.update( + { + "{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/logvar".format(split): self.logvar.detach(), + 
"{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + ) + + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + disc_factor = adopt_weight( + self.disc_factor, global_step, threshold=self.discriminator_iter_start + ) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = { + "{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean(), + } + return d_loss, log diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8de3212d3be4f58e621e8caa6e31dd8dc32b6929 --- /dev/null +++ b/sgm/modules/autoencoding/regularizers/__init__.py @@ -0,0 +1,53 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ....modules.distributions.distributions import DiagonalGaussianDistribution + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class DiagonalGaussianRegularizer(AbstractRegularizer): + def __init__(self, sample: bool = True): + super().__init__() + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log + + +def measure_perplexity(predicted_indices, num_centroids): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + encodings = ( + F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) + ) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7968af9224aff42a20023b4e14ca059939e034 --- /dev/null +++ b/sgm/modules/diffusionmodules/__init__.py @@ -0,0 +1,7 @@ +from .denoiser import Denoiser +from .discretizer import Discretization +from .loss import StandardDiffusionLoss +from .model import Model, Encoder, Decoder +from .openaimodel import UNetModel +from .sampling import BaseDiffusionSampler +from .wrappers import OpenAIWrapper diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..4651e7de5c4a90e0656843821d5d32a5ce0ddd0e --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,63 @@ +import torch.nn as nn + +from ...util import append_dims, instantiate_from_config + + +class Denoiser(nn.Module): + def __init__(self, weighting_config, scaling_config): + super().__init__() + + self.weighting = instantiate_from_config(weighting_config) + self.scaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma): + return sigma + + def possibly_quantize_c_noise(self, c_noise): + return c_noise + + def w(self, sigma): + return self.weighting(sigma) + + def __call__(self, network, input, sigma, cond): + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return network(input * c_in, c_noise, cond) * c_out + input * c_skip + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + weighting_config, + scaling_config, + num_idx, + discretization_config, + do_append_zero=False, + quantize_c_noise=True, + flip=True, + ): + super().__init__(weighting_config, scaling_config) + sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) + self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + + def sigma_to_idx(self, sigma): + dists = sigma - self.sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def possibly_quantize_sigma(self, sigma): + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise): + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a2ac6732ea78f1030b21bebd14063d52ac2a82 --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,31 @@ +import torch + + +class EDMScaling: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + 
self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__(self, sigma): + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__(self, sigma): + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00 --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 0000000000000000000000000000000000000000..397b8f386615b50bf83742eb0a8c02a95a6ffefc --- /dev/null +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -0,0 +1,68 @@ +import torch +import numpy as np +from functools import partial +from abc import abstractmethod + +from ...util import append_zero +from ...modules.diffusionmodules.util import make_beta_schedule + + +def generate_roughly_equally_spaced_steps( + num_substeps: int, max_step: int +) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False): + sigmas = self.get_sigmas(n, device=device) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule( + "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end + ) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + 
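+            # Requesting more steps than the underlying DDPM schedule provides is not
+            # supported, hence the error below; for n <= num_timesteps the sigmas are
+            # derived from the cumulative alphas as sigma_t = sqrt((1 - a_bar_t) / a_bar_t).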
raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..078de2920a8b20b911c7a42e2ffad67c4425e20b --- /dev/null +++ b/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,81 @@ +from functools import partial + +import torch + +from ...util import default, instantiate_from_config + + +class VanillaCFG: + """ + implements parallelized CFG + """ + + def __init__(self, scale, dyn_thresh_config=None): + scale_schedule = lambda scale, sigma: scale # independent of step + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + }, + ) + ) + + def __call__(self, x, sigma): + x_u, x_c = x.chunk(2) + scale_value = self.scale_schedule(sigma) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "add_crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class DualCFG: + + def __init__(self, scale): + self.scale = scale + self.dyn_thresh = instantiate_from_config( + { + "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding" + }, + ) + + def __call__(self, x, sigma): + x_u_1, x_u_2, x_c = x.chunk(3) + x_pred = self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale) + return x_pred + + def prepare_inputs(self, x, s, c, uc_1, uc_2): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat", "add_crossattn"]: + c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0) + else: + assert c[k] == uc_1[k] + c_out[k] = c[k] + return torch.cat([x] * 3), torch.cat([s] * 3), c_out + + + +class IdentityGuider: + def __call__(self, x, sigma): + return x + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8a27d1dd96703c30617ec01d2da2860615aa012a --- /dev/null +++ b/sgm/modules/diffusionmodules/loss.py @@ -0,0 +1,275 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import ListConfig +from taming.modules.losses.lpips import LPIPS +from torchvision.utils import save_image +from ...util import append_dims, instantiate_from_config + + +class StandardDiffusionLoss(nn.Module): + def __init__( + self, + sigma_sampler_config, + type="l2", + offset_noise_level=0.0, + batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None, + ): + super().__init__() + + assert type in ["l2", "l1", "lpips"] + + self.sigma_sampler = instantiate_from_config(sigma_sampler_config) + + self.type = type + self.offset_noise_level = offset_noise_level + + if type == "lpips": + self.lpips = LPIPS().eval() + + if not batch2model_keys: + batch2model_keys = [] + + if isinstance(batch2model_keys, str): + batch2model_keys = [batch2model_keys] + + self.batch2model_keys = set(batch2model_keys) + + def __call__(self, network, denoiser, conditioner, 
input, batch, *args, **kwarg): + cond = conditioner(batch) + additional_model_inputs = { + key: batch[key] for key in self.batch2model_keys.intersection(batch) + } + + sigmas = self.sigma_sampler(input.shape[0]).to(input.device) + noise = torch.randn_like(input) + if self.offset_noise_level > 0.0: + noise = noise + self.offset_noise_level * append_dims( + torch.randn(input.shape[0], device=input.device), input.ndim + ) + noised_input = input + noise * append_dims(sigmas, input.ndim) + model_output = denoiser( + network, noised_input, sigmas, cond, **additional_model_inputs + ) + w = append_dims(denoiser.w(sigmas), input.ndim) + + loss = self.get_diff_loss(model_output, input, w) + loss = loss.mean() + loss_dict = {"loss": loss} + + return loss, loss_dict + + def get_diff_loss(self, model_output, target, w): + if self.type == "l2": + return torch.mean( + (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 + ) + elif self.type == "l1": + return torch.mean( + (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 + ) + elif self.type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss + + +class FullLoss(StandardDiffusionLoss): + + def __init__( + self, + seq_len=12, + kernel_size=3, + gaussian_sigma=0.5, + min_attn_size=16, + lambda_local_loss=0.0, + lambda_ocr_loss=0.0, + ocr_enabled = False, + predictor_config = None, + *args, **kwarg + ): + super().__init__(*args, **kwarg) + + self.gaussian_kernel_size = kernel_size + gaussian_kernel = self.get_gaussian_kernel(kernel_size=self.gaussian_kernel_size, sigma=gaussian_sigma, out_channels=seq_len) + self.register_buffer("g_kernel", gaussian_kernel.requires_grad_(False)) + + self.min_attn_size = min_attn_size + self.lambda_local_loss = lambda_local_loss + self.lambda_ocr_loss = lambda_ocr_loss + + self.ocr_enabled = ocr_enabled + if ocr_enabled: + self.predictor = instantiate_from_config(predictor_config) + + def get_gaussian_kernel(self, kernel_size=3, sigma=1, out_channels=3): + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_coord = torch.arange(kernel_size) + x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() + + mean = (kernel_size - 1)/2. + variance = sigma**2. + + # Calculate the 2-dimensional gaussian kernel which is + # the product of two gaussian distributions for two different + # variables (in this case called x and y) + gaussian_kernel = (1./(2.*torch.pi*variance)) *\ + torch.exp( + -torch.sum((xy_grid - mean)**2., dim=-1) /\ + (2*variance) + ) + + # Make sure sum of values in gaussian kernel equals 1. 
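+        # The normalized kernel integrates to 1 and is registered as a non-trainable
+        # buffer; get_local_loss() later applies it as a depthwise convolution
+        # (groups equal to the sequence length) to smooth each token's attention map
+        # before comparing it with the segmentation map (batch["seg"]).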
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + + # Reshape to 2d depthwise convolutional weight + gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) + gaussian_kernel = gaussian_kernel.tile(out_channels, 1, 1, 1) + + return gaussian_kernel + + def __call__(self, network, denoiser, conditioner, input, batch, first_stage_model, scaler): + + cond = conditioner(batch) + + sigmas = self.sigma_sampler(input.shape[0]).to(input.device) + noise = torch.randn_like(input) + if self.offset_noise_level > 0.0: + noise = noise + self.offset_noise_level * append_dims( + torch.randn(input.shape[0], device=input.device), input.ndim + ) + + noised_input = input + noise * append_dims(sigmas, input.ndim) + model_output = denoiser(network, noised_input, sigmas, cond) + w = append_dims(denoiser.w(sigmas), input.ndim) + + diff_loss = self.get_diff_loss(model_output, input, w) + local_loss = self.get_local_loss(network.diffusion_model.attn_map_cache, batch["seg"], batch["seg_mask"]) + diff_loss = diff_loss.mean() + local_loss = local_loss.mean() + + if self.ocr_enabled: + ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler) + ocr_loss = ocr_loss.mean() + + loss = diff_loss + self.lambda_local_loss * local_loss + if self.ocr_enabled: + loss += self.lambda_ocr_loss * ocr_loss + + loss_dict = { + "loss/diff_loss": diff_loss, + "loss/local_loss": local_loss, + "loss/full_loss": loss + } + + if self.ocr_enabled: + loss_dict["loss/ocr_loss"] = ocr_loss + + return loss, loss_dict + + def get_ocr_loss(self, model_output, r_bbox, label, first_stage_model, scaler): + + model_output = 1 / scaler * model_output + model_output_decoded = first_stage_model.decode(model_output) + model_output_crops = [] + + for i, bbox in enumerate(r_bbox): + m_top, m_bottom, m_left, m_right = bbox + model_output_crops.append(model_output_decoded[i, :, m_top:m_bottom, m_left:m_right]) + + loss = self.predictor.calc_loss(model_output_crops, label) + + return loss + + def get_min_local_loss(self, attn_map_cache, mask, seg_mask): + + loss = 0 + count = 0 + + for item in attn_map_cache: + + heads = item["heads"] + size = item["size"] + attn_map = item["attn_map"] + + if size < self.min_attn_size: continue + + seg_l = seg_mask.shape[1] + + bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length + attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l + + assert seg_l <= l + attn_map = attn_map[..., :seg_l] + attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n + attn_map = attn_map.mean(dim = 1) # b, l, n + + attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s + attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel + attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n + + mask_map = F.interpolate(mask, (size, size)) + mask_map = mask_map.tile((1, seg_l, 1, 1)) + mask_map = mask_map.reshape((-1, seg_l, n)) # b, l, n + + p_loss = (mask_map * attn_map).max(dim = -1)[0] # b, l + p_loss = p_loss + (1 - seg_mask) # b, l + p_loss = p_loss.min(dim = -1)[0] # b, + + loss += -p_loss + count += 1 + + loss = loss / count + + return loss + + def get_local_loss(self, attn_map_cache, seg, seg_mask): + + loss = 0 + count = 0 + + for item in attn_map_cache: + + heads = item["heads"] + size = item["size"] + attn_map = item["attn_map"] + + if size < self.min_attn_size: continue + + seg_l = seg_mask.shape[1] + + bh, n, l = attn_map.shape # bh: batch 
size * heads / n : pixel length(h*w) / l: token length + attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l + + assert seg_l <= l + attn_map = attn_map[..., :seg_l] + attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n + attn_map = attn_map.mean(dim = 1) # b, l, n + + attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s + attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel + attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n + + seg_map = F.interpolate(seg, (size, size)) + seg_map = seg_map.reshape((-1, seg_l, n)) # b, l, n + n_seg_map = 1 - seg_map + + p_loss = (seg_map * attn_map).max(dim = -1)[0] # b, l + n_loss = (n_seg_map * attn_map).max(dim = -1)[0] # b, l + + p_loss = p_loss * seg_mask # b, l + n_loss = n_loss * seg_mask # b, l + + p_loss = p_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b, + n_loss = n_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b, + + f_loss = n_loss - p_loss # b, + loss += f_loss + count += 1 + + loss = loss / count + + return loss \ No newline at end of file diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..26efd07840def93d44513a5dabd79edbb7cee662 --- /dev/null +++ b/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,743 @@ +# pytorch_diffusion + derived encoder decoder +import math +from typing import Any, Callable, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + print("no module 'xformers'. Processing without...") + +from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
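+    Concretely, matching the code below: with half_dim = embedding_dim // 2 and
+    freq_i = 10000 ** (-i / (half_dim - 1)) for i in [0, half_dim), the embedding of
+    timestep t is the concatenation [sin(t * freq), cos(t * freq)] along the feature
+    axis, zero-padded by one column when embedding_dim is odd.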
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = 
torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map( + lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) + ) + h_ = torch.nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default + # compute attention + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op: Optional[Any] = None + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None, **unused_kwargs): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if ( + version.parse(torch.__version__) < version.parse("2.0.0") + and attn_type != "none" + ): + assert XFORMERS_IS_AVAILABLE, ( + f"We do not support vanilla attention in {torch.__version__} anymore, " + f"as it is too expensive. Please install xformers via e.g. 
'pip install xformers==0.0.16'" + ) + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = 
Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, 
+ padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to 
block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..2768d3232b6cb361b631661d428f272b9cd06f79 --- /dev/null +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,2070 @@ +import os +import math +from abc import abstractmethod +from functools import partial +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...modules.attention import SpatialTransformer +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from ...util import default, exists + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward( + self, + x, + emb, + context=None, + add_context=None, + skip_time_mix=False, + time_context=None, + num_video_frames=None, + time_context_cat=None, + use_crossframe_attention_in_spatial_layers=False, + ): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, add_context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + t_factor = 1 if not self.third_up else 2 + x = F.interpolate( + x, + (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + # print(f"Building a Downsample layer with {dims} dims.") + # print( + # f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + # f"kernel-size: 3, stride: {stride}, padding: {padding}" + # ) + if dims == 3: + pass + # print(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
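+
+    Illustrative usage (a minimal sketch; the channel counts, embedding size and
+    tensor shapes are assumptions for the example, not values used elsewhere in
+    this file):
+
+        import torch
+        block = ResBlock(channels=64, emb_channels=256, dropout=0.0, out_channels=128)
+        x = torch.randn(2, 64, 32, 32)   # [N, channels, H, W]
+        emb = torch.randn(2, 256)        # [N, emb_channels]
+        h = block(x, emb)                # -> [2, 128, 32, 32]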
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = ( + 2 * self.out_channels if use_scale_shift_norm else self.out_channels + ) + if self.skip_t_emb: + print(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
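+
+    Illustrative usage (a sketch; the shapes below are assumptions for the example only):
+
+        import torch
+        attn = AttentionBlock(channels=64, num_heads=4)
+        x = torch.randn(2, 64, 16, 16)   # [N, channels, H, W]
+        y = attn(x)                      # -> [2, 64, 16, 16], self-attention over the 16*16 positions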
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, **kwargs): + # TODO add crossframe attention and use mixed checkpoint + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
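+
+        Shape example (illustrative, assuming n_heads=4 and 32 channels per head):
+            qkv: [N, 3 * 4 * 32, T] = [N, 384, T]  ->  output: [N, 4 * 32, T] = [N, 128, T]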
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! 
You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not ( + use_checkpoint or use_fairscale_checkpoint + ) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + 
self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + 
use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + for i, module in enumerate(self.input_blocks): + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for i, module in enumerate(self.output_blocks): + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert False, "not supported anymore. what the f*** are you doing?" + else: + return self.out(h) + + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. 
+ :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
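+            # context_dim may arrive as an OmegaConf ListConfig (e.g. when read from a
+            # YAML config); convert it to a plain Python list before it is used below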
+ if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not ( + use_checkpoint or use_fairscale_checkpoint + ) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + 
self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + 
use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + for i, module in enumerate(self.input_blocks): + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for i, module in enumerate(self.output_blocks): + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert False, "not supported anymore. what the f*** are you doing?" 
+ else: + return self.out(h) + + +import seaborn as sns +import matplotlib.pyplot as plt + +class UNetAddModel(nn.Module): + + def __init__( + self, + in_channels, + ctrl_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + attn_type="attn2", + attn_layers=[], + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + add_context_dim=None, + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.ctrl_channels = ctrl_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not ( + use_checkpoint or use_fairscale_checkpoint + ) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + if self.ctrl_channels > 0: + self.add_input_block = TimestepEmbedSequential( + conv_nd(dims, ctrl_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + 
use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + add_context_dim=add_context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + add_context_dim=add_context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + 
use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + add_context_dim=add_context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + # cache attn map + self.attn_type = attn_type + self.attn_layers = attn_layers + self.attn_map_cache = [] + for name, module in self.named_modules(): + if name.endswith(self.attn_type): + item = {"name": name, "heads": module.heads, "size": None, "attn_map": None} + self.attn_map_cache.append(item) + module.attn_map_cache = item + + def clear_attn_map(self): + + for item in self.attn_map_cache: + if item["attn_map"] is not None: + del item["attn_map"] + item["attn_map"] = None + + def save_attn_map(self, save_name="temp", tokens=""): + + attn_maps = [] + for item in self.attn_map_cache: + name = item["name"] + if any([name.startswith(block) for block in self.attn_layers]): + heads = item["heads"] + attn_maps.append(item["attn_map"].detach().cpu()) + + attn_map = th.stack(attn_maps, dim=0) + attn_map = th.mean(attn_map, dim=0) + + # attn_map: bh * n * l + bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length + attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1) + b = attn_map.shape[0] + + h = w = int(n**0.5) + attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy() + + attn_map_i = attn_map[-1] + + l = attn_map_i.shape[0] + fig = plt.figure(figsize=(12, 8), dpi=300) + for j in range(12): + if j >= l: break + ax = fig.add_subplot(3, 4, j+1) + sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False) + if j < len(tokens): + ax.set_title(tokens[j]) + fig.savefig(f"temp/attn_map/attn_map_{save_name}.png") + plt.close() + + return attn_map_i + + def forward(self, x, timesteps=None, context=None, add_context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + self.clear_attn_map() + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + if self.ctrl_channels > 0: + in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1) + + for i, module in enumerate(self.input_blocks): + if self.ctrl_channels > 0 and i == 0: + h = module(in_h, emb, context, add_context) + self.add_input_block(add_h, emb, context, add_context) + else: + h = module(h, emb, context, add_context) + hs.append(h) + h = self.middle_block(h, emb, context, add_context) + for i, module in enumerate(self.output_blocks): + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, add_context) + h = h.type(x.dtype) + + return self.out(h) \ No newline at end of file diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d66eb9e9722f3e7b8ea98a9a757c3fb09f33ac90 --- /dev/null +++ b/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,784 @@ +""" + Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +""" + + +from typing import Dict, Union + +import imageio +import torch +import json +import numpy as np +import torch.nn.functional as F +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm + +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config +from torchvision.utils import save_image + +DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "cuda", + ): + self.num_steps = num_steps + self.discretization = instantiate_from_config(discretization_config) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, model, sigma, cond, uc): + denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas, init_step=0): + sigma_generator = range(init_step, num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas-1-init_step, + desc=f"Sampling with {self.__class__.__name__} for 
{num_sigmas-1-init_step} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__( + self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser( + *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs + ) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + + def possible_correction_step( 
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + return euler_step + + def get_c_noise(self, x, model, sigma): + sigma = model.denoiser.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, x.ndim) + c_skip, c_out, c_in, c_noise = model.denoiser.scaling(sigma) + c_noise = model.denoiser.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return c_noise + + def attend_and_excite(self, x, model, sigma, cond, batch, alpha, iter_enabled, thres, max_iter=20): + + # calc timestep + c_noise = self.get_c_noise(x, model, sigma) + + x = x.clone().detach().requires_grad_(True) # https://github.com/yuval-alaluf/Attend-and-Excite/blob/main/pipeline_attend_and_excite.py#L288 + + iters = 0 + while True: + + model_output = model.model(x, c_noise, cond) + local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"]) + grad = torch.autograd.grad(local_loss.requires_grad_(True), [x], retain_graph=True)[0] + x = x - alpha * grad + iters += 1 + + if not iter_enabled or local_loss <= thres or iters > max_iter: + break + + return x + + def create_pascal_label_colormap(self): + """ + Label colormap for the class labels of the PASCAL VOC segmentation dataset. + + Returns: + A colormap for visualizing segmentation results. + """ + colormap = np.zeros((256, 3), dtype=int) + ind = np.arange(256, dtype=int) + + for shift in reversed(range(8)): + for channel in range(3): + colormap[:, channel] |= ((ind >> channel) & 1) << shift + ind >>= 3 + + return colormap + + def save_segment_map(self, image, attn_maps, tokens=None, save_name=None): + + colormap = self.create_pascal_label_colormap() + H, W = image.shape[-2:] + + image_ = image*0.3 + sections = [] + for i in range(len(tokens)): + attn_map = attn_maps[i] + attn_map_t = np.tile(attn_map[None], (1,3,1,1)) # b, 3, h, w + attn_map_t = torch.from_numpy(attn_map_t) + attn_map_t = F.interpolate(attn_map_t, (W, H)) + + color = torch.from_numpy(colormap[i+1][None,:,None,None] / 255.0) + colored_attn_map = attn_map_t * color + colored_attn_map = colored_attn_map.to(device=image_.device) + + image_ += colored_attn_map*0.7 + sections.append(attn_map) + + section = np.stack(sections) + np.save(f"temp/seg_map/seg_{save_name}.npy", section) + + save_image(image_, f"temp/seg_map/seg_{save_name}.png", normalize=True) + + def get_init_noise(self, cfgs, model, cond, batch, uc=None): + + H, W = batch["target_size_as_tuple"][0] + shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor) + + randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu)) + x = randn.clone() + + xs = [] + self.verbose = False + for _ in range(cfgs.noise_iters): + + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps=2 + ) + + superv = { + "mask": batch["mask"] if "mask" in batch else None, + "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None + } + + local_losses = [] + + for i in self.get_sigma_gen(num_sigmas): + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + + x, inter, local_loss = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + model, + x, + cond, + superv, + uc, + gamma, + save_loss=True + ) + + local_losses.append(local_loss.item()) + + xs.append((randn, local_losses[-1])) + + randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu)) + x = randn.clone() + + self.verbose = True + + xs.sort(key = lambda x: x[-1]) + + if len(xs) > 0: + print(f"Init local loss: Best
{xs[0][1]} Worst {xs[-1][1]}") + x = xs[0][0] + + return x + + def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc=None, + gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False, + name=None, save_loss=False, save_attn=False, save_inter=False): + + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + if update: + x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres) + + denoised = self.denoise(x, model, sigma_hat, cond, uc) + denoised_decode = model.decode_first_stage(denoised) if save_inter else None + + if save_loss: + local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"]) + local_loss = local_loss[local_loss.shape[0]//2:] + else: + local_loss = torch.zeros(1) + if save_attn: + attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0]) + denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode + self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name) + + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + + return euler_step, denoised_decode, local_loss + + def __call__(self, model, x, cond, batch=None, uc=None, num_steps=None, init_step=0, + name=None, aae_enabled=False, detailed=False): + + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + name = batch["name"][0] + inters = [] + local_losses = [] + scales = np.linspace(start=1.0, stop=0, num=num_sigmas) + iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32) + thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6) + + for i in self.get_sigma_gen(num_sigmas, init_step=init_step): + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + + alpha = 20 * np.sqrt(scales[i]) + update = aae_enabled + save_loss = detailed + save_attn = detailed and (i == (num_sigmas-1)//2) + save_inter = detailed + + if i in iter_lst: + iter_enabled = True + thres = thres_lst[list(iter_lst).index(i)] + else: + iter_enabled = False + thres = 0.0 + + x, inter, local_loss = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + model, + x, + cond, + batch, + uc, + gamma, + alpha=alpha, + iter_enabled=iter_enabled, + thres=thres, + update=update, + name=name, + save_loss=save_loss, + save_attn=save_attn, + save_inter=save_inter + ) + + local_losses.append(local_loss.item()) + if inter is not None: + inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0] + inter = inter.cpu().numpy().transpose(1, 2, 0) * 255 + inters.append(inter.astype(np.uint8)) + + print(f"Local losses: {local_losses}") + + if len(inters) > 0: + imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02) + + return x + + +class EulerEDMDualSampler(EulerEDMSampler): + + def prepare_sampling_loop(self, x, cond, uc_1=None, uc_2=None, num_steps=None): + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) + uc_1 = default(uc_1, cond) + uc_2 = default(uc_2, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 + + def denoise(self, x, model, 
sigma, cond, uc_1, uc_2): + denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc_1, uc_2)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_init_noise(self, cfgs, model, cond, batch, uc_1=None, uc_2=None): + + H, W = batch["target_size_as_tuple"][0] + shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor) + + randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu)) + x = randn.clone() + + xs = [] + self.verbose = False + for _ in range(cfgs.noise_iters): + + x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop( + x, cond, uc_1, uc_2, num_steps=2 + ) + + superv = { + "mask": batch["mask"] if "mask" in batch else None, + "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None + } + + local_losses = [] + + for i in self.get_sigma_gen(num_sigmas): + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + + x, inter, local_loss = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + model, + x, + cond, + superv, + uc_1, + uc_2, + gamma, + save_loss=True + ) + + local_losses.append(local_loss.item()) + + xs.append((randn, local_losses[-1])) + + randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu)) + x = randn.clone() + + self.verbose = True + + xs.sort(key = lambda x: x[-1]) + + if len(xs) > 0: + print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}") + x = xs[0][0] + + return x + + def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc_1=None, uc_2=None, + gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False, + name=None, save_loss=False, save_attn=False, save_inter=False): + + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + if update: + x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres) + + denoised = self.denoise(x, model, sigma_hat, cond, uc_1, uc_2) + denoised_decode = model.decode_first_stage(denoised) if save_inter else None + + if save_loss: + local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"]) + local_loss = local_loss[-local_loss.shape[0]//3:] + else: + local_loss = torch.zeros(1) + if save_attn: + attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True) + denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode + self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name) + + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + + return euler_step, denoised_decode, local_loss + + def __call__(self, model, x, cond, batch=None, uc_1=None, uc_2=None, num_steps=None, init_step=0, + name=None, aae_enabled=False, detailed=False): + + x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop( + x, cond, uc_1, uc_2, num_steps + ) + + name = batch["name"][0] + inters = [] + local_losses = [] + scales = np.linspace(start=1.0, stop=0, num=num_sigmas) + iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32) + thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6) + + for i in self.get_sigma_gen(num_sigmas, init_step=init_step): + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] 
<= self.s_tmax + else 0.0 + ) + + alpha = 20 * np.sqrt(scales[i]) + update = aae_enabled + save_loss = aae_enabled + save_attn = detailed and (i == (num_sigmas-1)//2) + save_inter = aae_enabled + + if i in iter_lst: + iter_enabled = True + thres = thres_lst[list(iter_lst).index(i)] + else: + iter_enabled = False + thres = 0.0 + + x, inter, local_loss = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + model, + x, + cond, + batch, + uc_1, + uc_2, + gamma, + alpha=alpha, + iter_enabled=iter_enabled, + thres=thres, + update=update, + name=name, + save_loss=save_loss, + save_attn=save_attn, + save_inter=save_inter + ) + + local_losses.append(local_loss.item()) + if inter is not None: + inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0] + inter = inter.cpu().numpy().transpose(1, 2, 0) * 255 + inters.append(inter.astype(np.uint8)) + + print(f"Local losses: {local_losses}") + + if len(inters) > 0: + imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1) + + return x + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step + ) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [ + append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) + ] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - 
to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, init_step=0, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas, init_step=init_step): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..839fce1f246f8867b38a101fc2075905bef9ef8b --- /dev/null +++ b/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,51 @@ +import torch +from scipy import integrate + +from ...util import append_dims + + +class NoDynamicThresholding: + def __call__(self, uncond, cond, scale): + return uncond + scale * (cond - uncond) + +class DualThresholding: # Dual condition CFG (from instructPix2Pix) + def __call__(self, uncond_1, uncond_2, cond, scale): + return uncond_1 + scale[0] * (uncond_2 - uncond_1) + scale[1] * (cond - uncond_2) + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta + * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d54724c6ef6a7b8067784a4192b0fe2f41123063 --- /dev/null +++ 
b/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,31 @@ +import torch + +from ...util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..069ff131fb9789949203551e0efd8313a3d0cc08 --- /dev/null +++ b/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,308 @@ +""" +adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! +""" + +import math + +import torch +import torch.nn as nn +from einops import repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [ + inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) + ] + non_tensor_keys = [ + key for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + non_tensor_inputs = [ + inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert ( + len(tensor_keys) == length_tensors + and len(non_tensor_keys) == length_non_tensors + ) + + ctx.input_tensors = { + key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) + } + ctx.input_non_tensors = { + key: val + for (key, val) in zip( + non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) + ) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors :]) + + with torch.no_grad(): + output_tensors = ctx.run_function( + **ctx.input_tensors, **ctx.input_non_tensors + ) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} + ctx.input_tensors = { + key: ctx.input_tensors[key].detach().requires_grad_(True) + for key in ctx.input_tensors + } + + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = { + key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) + for key in ctx.input_tensors + } + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors :] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. 
+ """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb494470026014f6c091828361b24d494dd37dd --- /dev/null +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from packaging import version + +OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) + and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward( + self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs + ) -> torch.Tensor: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + add_context=c.get("add_crossattn", None), + y=c.get("vector", None), + **kwargs + ) diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..0b61f03077358ce4737c85842d9871f70dabb656 --- /dev/null +++ b/sgm/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + 
else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68 --- /dev/null +++ b/sgm/modules/ema.py @@ -0,0 +1,86 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. 
+ Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..be44972af5374c8fb460eb318a78c3b6a3a0e7f3 --- /dev/null +++ b/sgm/modules/encoders/modules.py @@ -0,0 +1,1253 @@ +from contextlib import nullcontext +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import kornia +import numpy as np +import open_clip +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig +from torch.utils.checkpoint import checkpoint +from transformers import ( + ByT5Tokenizer, + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5Tokenizer, +) + +from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer +from ...modules.diffusionmodules.model import Encoder +from ...modules.diffusionmodules.openaimodel import Timestep +from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ...modules.distributions.distributions import DiagonalGaussianDistribution +from ...util import ( + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + instantiate_from_config, +) + +import math +import string +import pytorch_lightning as pl +from torchvision import transforms +from timm.models.vision_transformer import VisionTransformer +from safetensors.torch import load_file as load_safetensors + +# disable warning +from transformers import logging +logging.set_verbosity_error() + +class AbstractEmbModel(nn.Module): + def __init__(self, is_add_embedder=False): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + self.is_add_embedder = is_add_embedder + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + 
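+ # Embedder outputs are routed by tensor rank via OUTPUT_DIM2KEYS (2-D -> "vector", 3-D -> "crossattn", 4-/5-D -> "concat"; embedders flagged is_add_embedder go to "add_crossattn" instead), and KEY2CATDIM gives the dimension along which outputs sharing the same key are concatenated in forward().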
def __init__(self, emb_models: Union[List, ListConfig]): + super().__init__() + embedders = [] + for n, embconfig in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = embconfig.get("is_trainable", False) + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + if not embedder.is_trainable: + embedder.train = disabled_train + embedder.freeze() + print( + f"Initialized embedder #{n}: {embedder.__class__.__name__} " + f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" + ) + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError( + f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" + ) + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + self.embedders = nn.ModuleList(embedders) + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def forward( + self, batch: Dict, force_zero_embeddings: Optional[List] = None + ) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + for embedder in self.embedders: + embedding_context = nullcontext if embedder.is_trainable else torch.no_grad + with embedding_context(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + batch = self.possibly_get_ucg_val(embedder, batch) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + if embedder.is_add_embedder: + out_key = "add_crossattn" + else: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.input_key == "mask": + H, W = batch["image"].shape[-2:] + emb = nn.functional.interpolate(emb, (H//8, W//8)) + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + emb = ( + expand_dims_like( + torch.bernoulli( + (1.0 - embedder.ucg_rate) + * torch.ones(emb.shape[0], device=emb.device) + ), + emb, + ) + * emb + ) + if ( + hasattr(embedder, "input_key") + and embedder.input_key in force_zero_embeddings + ): + emb = torch.zeros_like(emb) + if out_key in output: + output[out_key] = torch.cat( + (output[out_key], emb), self.KEY2CATDIM[out_key] + ) + else: + output[out_key] = emb + return output + + def get_unconditional_conditioning( + self, batch_c, batch_uc=None, force_uc_zero_embeddings=None + ): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + c = self(batch_c) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for 
embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + return c, uc + + +class DualConditioner(GeneralConditioner): + + def get_unconditional_conditioning( + self, batch_c, batch_uc_1=None, batch_uc_2=None, force_uc_zero_embeddings=None + ): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + + c = self(batch_c) + uc_1 = self(batch_uc_1, force_uc_zero_embeddings) if batch_uc_1 is not None else None + uc_2 = self(batch_uc_2, force_uc_zero_embeddings[:1]) if batch_uc_2 is not None else None + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + + return c, uc_1, uc_2 + + +class InceptionV3(nn.Module): + """Wrapper around the https://github.com/mseitzer/pytorch-fid inception + port with an additional squeeze at the end""" + + def __init__(self, normalize_input=False, **kwargs): + super().__init__() + from pytorch_fid import inception + + kwargs["resize_input"] = True + self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) + + def forward(self, inp): + # inp = kornia.geometry.resize(inp, (299, 299), + # interpolation='bicubic', + # align_corners=False, + # antialias=True) + # inp = inp.clamp(min=-1, max=1) + + outp = self.model(inp) + + if len(outp) == 1: + return outp[0].squeeze() + + return outp + + +class IdentityEncoder(AbstractEmbModel): + def encode(self, x): + return x + def freeze(self): + return + def forward(self, x): + return x + + +class ClassEmbedder(AbstractEmbModel): + def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): + super().__init__() + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.add_sequence_dim = add_sequence_dim + + def forward(self, c): + c = self.embedding(c) + if self.add_sequence_dim: + c = c[:, None, :] + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc.long()} + return uc + + +class ClassEmbedderForMultiCond(ClassEmbedder): + def forward(self, batch, key=None, disable_dropout=False): + out = batch + key = default(key, self.key) + islist = isinstance(batch[key], list) + if islist: + batch[key] = batch[key][0] + c_out = super().forward(batch, key, disable_dropout) + out[key] = [c_out] if islist else c_out + return out + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + # @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenByT5Embedder(AbstractEmbModel): + """ + Uses the ByT5 transformer encoder for text. Is character-aware. + """ + + def __init__( + self, version="google/byt5-base", device="cuda", max_length=77, freeze=True, *args, **kwargs + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__(*args, **kwargs) + self.tokenizer = ByT5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(next(self.parameters()).device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state # l, 1536 + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEmbModel): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + always_return_pooled=False, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + self.return_pooled = always_return_pooled + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def 
freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + device = next(self.transformer.parameters()).device + tokens = batch_encoding["input_ids"].to(device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.return_pooled: + return z, outputs.pooler_output + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder2(AbstractEmbModel): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = ["pooled", "last", "penultimate"] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + always_return_pooled=False, + legacy=True, + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + self.return_pooled = always_return_pooled + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + self.legacy = legacy + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + device = next(self.model.parameters()).device + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(device)) + if not self.return_pooled and self.legacy: + return z + if self.return_pooled: + assert not self.legacy + return z[self.layer], z["pooled"] + return z[self.layer] + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + if self.legacy: + x = x[self.layer] + x = self.model.ln_final(x) + return x + else: + # x is a dict and will stay a dict + o = x["last"] + o = self.model.ln_final(o) + pooled = self.pool(o, text) + x["pooled"] = pooled + return x + + def pool(self, x, text): + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = ( + x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + @ self.model.text_projection + ) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + outputs = {} + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - 1: + outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + outputs["last"] = x.permute(1, 0, 2) # LND -> NLD + return outputs + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEmbModel): + LAYERS = [ + # "pooled", 
+ "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu"), pretrained=version + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + device = next(self.model.parameters()).device + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + antialias=True, + ucg_rate=0.0, + unsqueeze_dim=False, + repeat_to_max_len=False, + num_image_crops=0, + output_tokens=False, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.transformer + self.model = model + self.max_crops = num_image_crops + self.pad_to_max_len = self.max_crops > 0 + self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + self.unsqueeze_dim = unsqueeze_dim + self.stored_batch = None + self.model.visual.output_tokens = output_tokens + self.output_tokens = output_tokens + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + tokens = None + if self.output_tokens: + z, tokens = z[0], z[1] + z = z.to(image.dtype) 
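+ # Classifier-free-guidance dropout: with probability ucg_rate, a per-sample Bernoulli mask zeroes the whole image embedding (and its tokens, if returned).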
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) + )[:, None] + * z + ) + if tokens is not None: + tokens = ( + expand_dims_like( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(tokens.shape[0], device=tokens.device) + ), + tokens, + ) + * tokens + ) + if self.unsqueeze_dim: + z = z[:, None, :] + if self.output_tokens: + assert not self.repeat_to_max_len + assert not self.pad_to_max_len + return tokens, z + if self.repeat_to_max_len: + if z.dim() == 2: + z_ = z[:, None, :] + else: + z_ = z + return repeat(z_, "b 1 d -> b n d", n=self.max_length), z + elif self.pad_to_max_len: + assert z.dim() == 3 + z_pad = torch.cat( + ( + z, + torch.zeros( + z.shape[0], + self.max_length - z.shape[1], + z.shape[2], + device=z.device, + ), + ), + 1, + ) + return z_pad, z_pad[:, 0, ...] + return z + + def encode_with_vision_transformer(self, img): + # if self.max_crops > 0: + # img = self.preprocess_by_cropping(img) + if img.dim() == 5: + assert self.max_crops == img.shape[1] + img = rearrange(img, "b n c h w -> (b n) c h w") + img = self.preprocess(img) + if not self.output_tokens: + assert not self.model.visual.output_tokens + x = self.model.visual(img) + tokens = None + else: + assert self.model.visual.output_tokens + x, tokens = self.model.visual(img) + if self.max_crops > 0: + x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) + # drop out between 0 and all along the sequence axis + x = ( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) + ) + * x + ) + if tokens is not None: + tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) + print( + f"You are running very experimental token-concat in {self.__class__.__name__}. " + f"Check what you are doing, and then remove this message." + ) + if self.output_tokens: + return x, tokens + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEmbModel): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + +class SpatialRescaler(nn.Module): + def __init__( + self, + n_stages=1, + method="bilinear", + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False, + wrap_video=False, + kernel_size=1, + remap_output=False, + ): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in [ + "nearest", + "linear", + "bilinear", + "trilinear", + "bicubic", + "area", + ] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None or remap_output + if self.remap_output: + print( + f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." 
+ ) + self.channel_mapper = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + bias=bias, + padding=kernel_size // 2, + ) + self.wrap_video = wrap_video + + def forward(self, x): + if self.wrap_video and x.ndim == 5: + B, C, T, H, W = x.shape + x = rearrange(x, "b c t h w -> b t c h w") + x = rearrange(x, "b t c h w -> (b t) c h w") + + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.wrap_video: + x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) + x = rearrange(x, "b t c h w -> b c t h w") + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class LowScaleEncoder(nn.Module): + def __init__( + self, + model_config, + linear_start, + linear_end, + timesteps=1000, + max_noise_level=250, + output_size=64, + scale_factor=1.0, + ): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule( + timesteps=timesteps, linear_start=linear_start, linear_end=linear_end + ) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule( + self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x): + z = self.model.encode(x) + if isinstance(z, DiagonalGaussianDistribution): + z = z.sample() + z = z * self.scale_factor + noise_level = torch.randint( + 0, self.max_noise_level, (x.shape[0],), device=x.device + ).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") + # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +class ConcatTimestepEmbedderND(AbstractEmbModel): + 
"""embeds each dimension independently and concatenates them""" + + def __init__(self, outdim): + super().__init__() + self.timestep = Timestep(outdim) + self.outdim = outdim + + def freeze(self): + self.eval() + + def forward(self, x): + if x.ndim == 1: + x = x[:, None] + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = rearrange(x, "b d -> (b d)") + emb = self.timestep(x) + emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return emb + + +class GaussianEncoder(Encoder, AbstractEmbModel): + def __init__( + self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.posterior = DiagonalGaussianRegularizer() + self.weight = weight + self.flatten_output = flatten_output + + def forward(self, x) -> Tuple[Dict, torch.Tensor]: + z = super().forward(x) + z, log = self.posterior(z) + log["loss"] = log["kl_loss"] + log["weight"] = self.weight + if self.flatten_output: + z = rearrange(z, "b c h w -> b (h w ) c") + return log, z + + +class LatentEncoder(AbstractEmbModel): + + def __init__(self, scale_factor, config, *args, **kwargs): + super().__init__(*args, **kwargs) + self.scale_factor = scale_factor + self.model = instantiate_from_config(config).eval() + self.model.train = disabled_train + + def freeze(self): + for param in self.model.parameters(): + param.requires_grad = False + + def forward(self, x): + z = self.model.encode(x) + z = self.scale_factor * z + return z + + +class ViTSTREncoder(VisionTransformer): + ''' + ViTSTREncoder is basically a ViT that uses ViTSTR weights + ''' + def __init__(self, size=224, ckpt_path=None, freeze=True, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.grayscale = transforms.Grayscale() + self.resize = transforms.Resize((size, size), transforms.InterpolationMode.BICUBIC, antialias=True) + + self.character = string.printable[:-6] + self.reset_classifier(num_classes=len(self.character)+2) + + if ckpt_path is not None: + self.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) + + if freeze: + self.freeze() + + def reset_classifier(self, num_classes): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def freeze(self): + for param in self.parameters(): + param.requires_grad_(False) + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + def forward(self, x): + + x = self.forward_features(x) + + return x + + def encode(self, x): + return self(x) + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + torch.tile(self.pe[None, ...].to(x.device), (x.shape[0], 1, 1)) + return self.dropout(x) + + +class LabelEncoder(AbstractEmbModel, pl.LightningModule): + + def __init__(self, max_len, 
emb_dim, n_heads=8, n_trans_layers=12, ckpt_path=None, trainable=False, + lr=1e-4, lambda_cls=0.1, lambda_pos=0.1, clip_dim=1024, visual_len=197, visual_dim=768, visual_config=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.max_len = max_len + self.emd_dim = emb_dim + self.n_heads = n_heads + self.n_trans_layers = n_trans_layers + self.character = string.printable[:-6] + self.num_cls = len(self.character) + 1 + + self.label_embedding = nn.Embedding(self.num_cls, self.emd_dim) + self.pos_embedding = PositionalEncoding(d_model=self.emd_dim, max_len=self.max_len) + transformer_block = nn.TransformerEncoderLayer(d_model=self.emd_dim, nhead=self.n_heads, batch_first=True) + self.encoder = nn.TransformerEncoder(transformer_block, num_layers=self.n_trans_layers) + + if ckpt_path is not None: + self.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"], strict=False) + + if trainable: + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.visual_encoder = instantiate_from_config(visual_config) + + self.learning_rate = lr + self.clip_dim = clip_dim + self.visual_len = visual_len + self.visual_dim = visual_dim + self.lambda_cls = lambda_cls + self.lambda_pos = lambda_pos + + self.cls_head = nn.Sequential(*[ + nn.InstanceNorm1d(self.max_len), + nn.Linear(self.emd_dim, self.emd_dim), + nn.GELU(), + nn.Linear(self.emd_dim, self.num_cls) + ]) + + self.pos_head = nn.Sequential(*[ + nn.InstanceNorm1d(self.max_len), + nn.Linear(self.emd_dim, self.max_len, bias=False) + ]) + + self.text_head = nn.Sequential(*[ + nn.InstanceNorm1d(self.max_len), + nn.Linear(self.emd_dim, self.clip_dim, bias=False), + nn.Conv1d(in_channels=self.max_len, out_channels=1, kernel_size=1) + ]) + + self.visual_head = nn.Sequential(*[ + nn.InstanceNorm1d(self.visual_len), + nn.Linear(self.visual_dim, self.clip_dim, bias=False), + nn.Conv1d(in_channels=self.visual_len, out_channels=1, kernel_size=1) + ]) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def get_index(self, labels): + + indexes = [] + for label in labels: + assert len(label) <= self.max_len + index = [self.character.find(c)+1 for c in label] + index = index + [0] * (self.max_len - len(index)) + indexes.append(index) + + return torch.tensor(indexes, device=next(self.parameters()).device) + + def get_embeddings(self, x): + + emb = self.label_embedding(x) + emb = self.pos_embedding(emb) + out = self.encoder(emb) + + return out + + def forward(self, labels): + + idx = self.get_index(labels) + out = self.get_embeddings(idx) + + return out + + def get_loss(self, text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target): + + text_out = text_out / text_out.norm(dim=1, keepdim=True) # b, 1024 + visual_out = visual_out / visual_out.norm(dim=1, keepdim=True) # b, 1024 + + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * visual_out @ text_out.T # b, b + logits_per_text = logits_per_image.T # b, b + + clip_loss_image = nn.functional.cross_entropy(logits_per_image, clip_target) + clip_loss_text = nn.functional.cross_entropy(logits_per_text, clip_target) + clip_loss = (clip_loss_image + clip_loss_text) / 2 + + cls_loss = nn.functional.cross_entropy(cls_out.permute(0,2,1), cls_target) + pos_loss = nn.functional.cross_entropy(pos_out.permute(0,2,1), pos_target) + + return clip_loss, cls_loss, pos_loss, logits_per_text + + def training_step(self, batch, batch_idx): + + text = batch["text"] + image = batch["image"] + + idx = self.get_index(text) 
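+ # training step: the character-index sequence is embedded by the label transformer and the paired image by the visual encoder; both are projected into a shared clip_dim space for a CLIP-style contrastive loss, with auxiliary character-class (cls) and position (pos) prediction heads supervising the text branch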
+ text_emb = self.get_embeddings(idx) # b, l, d + visual_emb = self.visual_encoder(image) # b, n, d + + cls_out = self.cls_head(text_emb) # b, l, c + pos_out = self.pos_head(text_emb) # b, l, p + text_out = self.text_head(text_emb).squeeze(1) # b, 1024 + visual_out = self.visual_head(visual_emb).squeeze(1) # b, 1024 + + cls_target = idx # b, c + pos_target = torch.arange(start=0, end=self.max_len, step=1) + pos_target = pos_target[None].tile((idx.shape[0], 1)).to(cls_target) # b, c + clip_target = torch.arange(0, idx.shape[0], 1).to(cls_target) # b, + + clip_loss, cls_loss, pos_loss, logits_per_text = self.get_loss(text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target) + loss = clip_loss + self.lambda_cls * cls_loss + self.lambda_pos * pos_loss + + loss_dict = {} + loss_dict["loss/clip_loss"] = clip_loss + loss_dict["loss/cls_loss"] = cls_loss + loss_dict["loss/pos_loss"] = pos_loss + loss_dict["loss/full_loss"] = loss + + clip_idx = torch.max(logits_per_text, dim=-1).indices # b, + clip_acc = (clip_idx == clip_target).to(dtype=torch.float32).mean() + + cls_idx = torch.max(cls_out, dim=-1).indices # b, l + cls_acc = (cls_idx == cls_target).to(dtype=torch.float32).mean() + + pos_idx = torch.max(pos_out, dim=-1).indices # b, l + pos_acc = (pos_idx == pos_target).to(dtype=torch.float32).mean() + + loss_dict["acc/clip_acc"] = clip_acc + loss_dict["acc/cls_acc"] = cls_acc + loss_dict["acc/pos_acc"] = pos_acc + + self.log_dict(loss_dict, prog_bar=True, batch_size=len(text), + logger=True, on_step=True, on_epoch=True, sync_dist=True) + + return loss + + def configure_optimizers(self): + + lr = self.learning_rate + opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=lr) + + return opt + + diff --git a/sgm/modules/predictors/model.py b/sgm/modules/predictors/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cc264b7fe70af379135c814dc393837534a80b3b --- /dev/null +++ b/sgm/modules/predictors/model.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from torchvision import transforms +from torchvision.utils import save_image + + +class ParseqPredictor(nn.Module): + + def __init__(self, ckpt_path=None, freeze=True, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.parseq = torch.hub.load('./src/parseq', 'parseq', source='local').eval() + self.parseq.load_state_dict(torch.load(ckpt_path, map_location="cpu")) + self.parseq_transform = transforms.Compose([ + transforms.Resize(self.parseq.hparams.img_size, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.Normalize(0.5, 0.5) + ]) + + if freeze: + self.freeze() + + def freeze(self): + for param in self.parseq.parameters(): + param.requires_grad_(False) + + def forward(self, x): + + x = torch.cat([self.parseq_transform(t[None]) for t in x]) + logits = self.parseq(x.to(next(self.parameters()).device)) + + return logits + + def img2txt(self, x): + + pred = self(x) + label, confidence = self.parseq.tokenizer.decode(pred) + return label + + + def calc_loss(self, x, label): + + preds = self(x) # (B, l, C) l=26, C=95 + gt_ids = self.parseq.tokenizer.encode(label).to(preds.device) # (B, l_trun) + + losses = [] + for pred, gt_id in zip(preds, gt_ids): + + eos_id = (gt_id == 0).nonzero().item() + gt_id = gt_id[1: eos_id] + pred = pred[:eos_id-1, :] + + ce_loss = nn.functional.cross_entropy(pred.permute(1, 0)[None], gt_id[None]) + ce_loss = torch.clamp(ce_loss, max = 1.0) + losses.append(ce_loss[None]) + + loss = torch.cat(losses) + + return loss \ No 
newline at end of file diff --git a/sgm/util.py b/sgm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06f48a882de7b5b37d48fffe7b44d2bf883bd85d --- /dev/null +++ b/sgm/util.py @@ -0,0 +1,231 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join( + text_seq[start : start + nc] for start in range(0, len(text_seq), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + sd = pl_sd["state_dict"] + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model diff --git a/util.py b/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..8f94b3fdff00957405b9ff6a94036b344b93548c --- /dev/null +++ b/util.py @@ -0,0 +1,136 @@ +import torch +from omegaconf import OmegaConf +from sgm.util import instantiate_from_config +from sgm.modules.diffusionmodules.sampling import * + +SD_XL_BASE_RATIOS = { + "0.5": (704, 1408), + "0.52": (704, 1344), + "0.57": (768, 1344), + "0.6": (768, 1280), + "0.68": (832, 1216), + "0.72": (832, 1152), + "0.78": (896, 1152), + "0.82": (896, 1088), + "0.88": (960, 1088), + "0.94": (960, 1024), + "1.0": (1024, 1024), + "1.07": (1024, 960), + "1.13": (1088, 960), + "1.21": (1088, 896), + "1.29": (1152, 896), + "1.38": (1152, 832), + "1.46": (1216, 832), + "1.67": (1280, 768), + "1.75": (1344, 768), + "1.91": (1344, 704), + "2.0": (1408, 704), + "2.09": (1472, 704), + "2.4": (1536, 640), + "2.5": (1600, 640), + "2.89": (1664, 576), + "3.0": (1728, 576), +} + +def init_model(cfg): + + model_cfg = OmegaConf.load(cfg.model_cfg_path) + ckpt = cfg.load_ckpt_path + + model = instantiate_from_config(model_cfg.model) + model.init_from_ckpt(ckpt) + + if cfg.type == "train": + model.train() + else: + model.to(torch.device("cuda", index=cfg.gpu)) + model.eval() + model.freeze() + + return model + +def init_sampling(cfgs): + + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + + if cfgs.dual_conditioner: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.DualCFG", + "params": {"scale": cfgs.scale}, + } + + sampler = EulerEDMDualSampler( + num_steps=cfgs.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=0.0, + s_tmin=0.0, + s_tmax=999.0, + s_noise=1.0, + verbose=True, + device=torch.device("cuda", index=cfgs.gpu) + ) + else: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": cfgs.scale[0]}, + } + + sampler = EulerEDMSampler( + num_steps=cfgs.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=0.0, + s_tmin=0.0, + s_tmax=999.0, + s_noise=1.0, + verbose=True, + device=torch.device("cuda", index=cfgs.gpu) + ) + + return sampler + +def deep_copy(batch): + + c_batch = {} + for key in batch: + if isinstance(batch[key], torch.Tensor): + c_batch[key] = torch.clone(batch[key]) + elif isinstance(batch[key], (tuple, list)): + c_batch[key] = batch[key].copy() + else: + c_batch[key] = batch[key] + + return c_batch + +def prepare_batch(cfgs, batch): + + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu)) + + if not cfgs.dual_conditioner: + batch_uc = deep_copy(batch) + + if "ntxt" in batch: + batch_uc["txt"] = batch["ntxt"] + else: + batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))] + + if "label" in batch: + batch_uc["label"] = ["" for _ in range(len(batch["label"]))] + + return batch, batch_uc, None + + else: + batch_uc_1 = deep_copy(batch) + batch_uc_2 = deep_copy(batch) + + batch_uc_1["ref"] = torch.zeros_like(batch["ref"]) + batch_uc_2["ref"] = torch.zeros_like(batch["ref"]) + + batch_uc_1["label"] = ["" for _ in range(len(batch["label"]))] + + return batch, batch_uc_1, batch_uc_2 \ No newline at end of file
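For reference, a minimal standalone sketch of exercising the LabelEncoder added in this patch; the import path and the max_len/emb_dim values are illustrative assumptions, not taken from the project configs:

import torch
from sgm.modules.encoders.modules import LabelEncoder  # module path assumed from this diff's layout

# illustrative hyperparameters; the real values come from the model config YAML
encoder = LabelEncoder(max_len=12, emb_dim=2048, n_heads=8, n_trans_layers=12)
encoder.freeze()  # inference only: disable gradients

with torch.no_grad():
    # labels use characters from string.printable[:-6] and must be at most max_len long
    emb = encoder(["hello", "world"])

print(emb.shape)  # torch.Size([2, 12, 2048]): one embedding per character slot, index 0 used as padding past each label's length

The encoder maps each label to character indices, adds sinusoidal positional encodings, and runs them through the transformer encoder, so the output keeps one vector per character position up to max_len.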