diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..56b27e14e56c93b06ad152c8da171f2b3abf2d5c 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+/demo/**/* filter=lfs diff=lfs merge=lfs -text
+checkpoints/**/* filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d5e4f15f8c02ac458b0019ffbb99d969fdc719db
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+**/__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b9df73b60743ae75687c8e2bae9293ce565dee3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: UDiffText
+emoji: 😋
+colorFrom: purple
+colorTo: blue
+sdk: gradio
+sdk_version: 3.41.0
+python_version: 3.11.4
+app_file: app.py
+pinned: true
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/__pycache__/util.cpython-310.pyc b/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85a9ed4dd239a1206898c9a03f621daaec8a3f16
Binary files /dev/null and b/__pycache__/util.cpython-310.pyc differ
diff --git a/__pycache__/util.cpython-311.pyc b/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a048ccb8f904da91f23ee4567045739650faca0d
Binary files /dev/null and b/__pycache__/util.cpython-311.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4277dc7a8ba4d51ad2f1769841b99151c5f666c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,245 @@
+# -*- coding: utf-8 -*-
+import cv2
+import torch
+import os, glob
+import numpy as np
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf
+from contextlib import nullcontext
+from pytorch_lightning import seed_everything
+from os.path import join as ospj
+from random import randint
+from torchvision.utils import save_image
+from torchvision.transforms import Resize
+
+from util import *
+
+
+def process(image, mask):
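+    # Crop a square region (side = min(img_h, img_w)) centered on the single masked area,
+    # shifted as needed to stay inside the image; then scale the image to [-1, 1], binarize
+    # the mask, resize both to the model resolution and build the masked (background-only) input.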
+
+ img_h, img_w = image.shape[:2]
+
+ mask = mask[...,:1]//255
+ contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+ if len(contours) != 1: raise gr.Error("One masked area only!")
+
+ m_x, m_y, m_w, m_h = cv2.boundingRect(contours[0])
+ c_x, c_y = m_x + m_w//2, m_y + m_h//2
+
+ if img_w > img_h:
+ if m_w > img_h: raise gr.Error("Illegal mask area!")
+ if c_x < img_w - c_x:
+ c_l = max(0, c_x - img_h//2)
+ c_r = c_l + img_h
+ else:
+ c_r = min(img_w, c_x + img_h//2)
+ c_l = c_r - img_h
+ image = image[:,c_l:c_r,:]
+ mask = mask[:,c_l:c_r,:]
+ else:
+ if m_h > img_w: raise gr.Error("Illegal mask area!")
+ if c_y < img_h - c_y:
+ c_t = max(0, c_y - img_w//2)
+ c_b = c_t + img_w
+ else:
+ c_b = min(img_h, c_y + img_w//2)
+ c_t = c_b - img_w
+ image = image[c_t:c_b,:,:]
+ mask = mask[c_t:c_b,:,:]
+
+ image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0
+ mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32)
+
+ image = resize(image[None])[0]
+ mask = resize(mask[None])[0]
+ masked = image * (1 - mask)
+
+ return image, mask, masked
+
+
+
+def predict(cfgs, model, sampler, batch):
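+    # Build the conditional / unconditional embeddings, draw the initial noise, run the sampler
+    # to obtain denoised latents, then decode them with the first-stage model and clamp to [0, 1].
+    # Gradients are only kept when aae_enabled is set; otherwise sampling runs under no_grad.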
+
+ context = nullcontext if cfgs.aae_enabled else torch.no_grad
+
+ with context():
+
+ batch, batch_uc_1 = prepare_batch(cfgs, batch)
+
+ c, uc_1 = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc_1,
+ force_uc_zero_embeddings=cfgs.force_uc_zero_embeddings,
+ )
+
+ x = sampler.get_init_noise(cfgs, model, cond=c, batch=batch, uc=uc_1)
+ samples_z = sampler(model, x, cond=c, batch=batch, uc=uc_1, init_step=0,
+ aae_enabled = cfgs.aae_enabled, detailed = cfgs.detailed)
+
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ return samples, samples_z
+
+
+def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
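+    # Gradio callback: validate the text length, crop the sketched input around its mask,
+    # tile all conditioning tensors to the requested batch size, run sampling, and (when
+    # "Show Detail" is enabled) load the attention / segmentation maps that the sampler is
+    # expected to write under ./temp.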
+
+ global cfgs, global_index
+
+ if len(text) < cfgs.txt_len[0] or len(text) > cfgs.txt_len[1]:
+ raise gr.Error("Illegal text length!")
+
+ global_index += 1
+
+ if num_samples > 1: cfgs.noise_iters = 0
+
+ cfgs.batch_size = num_samples
+ cfgs.steps = steps
+ cfgs.scale[0] = scale
+ cfgs.detailed = show_detail
+ seed_everything(seed)
+
+ sampler.num_steps = steps
+ sampler.guider.scale_value = scale
+
+ image = input_blk["image"]
+ mask = input_blk["mask"]
+
+ image, mask, masked = process(image, mask)
+
+ seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text))))
+
+ # additional cond
+ txt = f"\"{text}\""
+ original_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
+ crop_coords_top_left = torch.tensor((0, 0))
+ target_size_as_tuple = torch.tensor((cfgs.H, cfgs.W))
+
+ image = torch.tile(image[None], (num_samples, 1, 1, 1))
+ mask = torch.tile(mask[None], (num_samples, 1, 1, 1))
+ masked = torch.tile(masked[None], (num_samples, 1, 1, 1))
+ seg_mask = torch.tile(seg_mask[None], (num_samples, 1))
+ original_size_as_tuple = torch.tile(original_size_as_tuple[None], (num_samples, 1))
+ crop_coords_top_left = torch.tile(crop_coords_top_left[None], (num_samples, 1))
+ target_size_as_tuple = torch.tile(target_size_as_tuple[None], (num_samples, 1))
+
+ text = [text for i in range(num_samples)]
+ txt = [txt for i in range(num_samples)]
+ name = [str(global_index) for i in range(num_samples)]
+
+ batch = {
+ "image": image,
+ "mask": mask,
+ "masked": masked,
+ "seg_mask": seg_mask,
+ "label": text,
+ "txt": txt,
+ "original_size_as_tuple": original_size_as_tuple,
+ "crop_coords_top_left": crop_coords_top_left,
+ "target_size_as_tuple": target_size_as_tuple,
+ "name": name
+ }
+
+ samples, samples_z = predict(cfgs, model, sampler, batch)
+ samples = samples.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ results = [Image.fromarray(sample.astype(np.uint8)) for sample in samples]
+
+ if cfgs.detailed:
+ sections = []
+ attn_map = Image.open(f"./temp/attn_map/attn_map_{global_index}.png")
+ seg_maps = np.load(f"./temp/seg_map/seg_{global_index}.npy")
+ for i, seg_map in enumerate(seg_maps):
+ seg_map = cv2.resize(seg_map, (cfgs.W, cfgs.H))
+ sections.append((seg_map, text[0][i]))
+ seg = (results[0], sections)
+ else:
+ attn_map = None
+ seg = None
+
+ return results, attn_map, seg
+
+
+if __name__ == "__main__":
+
+ os.makedirs("./temp", exist_ok=True)
+ os.makedirs("./temp/attn_map", exist_ok=True)
+ os.makedirs("./temp/seg_map", exist_ok=True)
+
+ cfgs = OmegaConf.load("./configs/demo.yaml")
+
+ model = init_model(cfgs)
+ sampler = init_sampling(cfgs)
+ global_index = 0
+ resize = Resize((cfgs.H, cfgs.W))
+
+ block = gr.Blocks().queue()
+ with block:
+
+ with gr.Row():
+
+ gr.HTML(
+ """
+                UDiffText: A Unified Framework for High-quality Text Synthesis in Arbitrary Images via Character-aware Diffusion Models
+
+                Our proposed UDiffText is capable of synthesizing accurate and harmonious text in either synthetic or real-world images, and can thus be applied to tasks such as scene text editing (a), arbitrary text generation (b) and accurate T2I generation (c).
+ """
+ )
+
+ with gr.Row():
+
+ with gr.Column():
+
+ input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
+                gr.Markdown("Note: please draw horizontally and mark only **one** masked area.")
+                text = gr.Textbox(label="Text to render (1~12 characters)", info="the text you want to render in the masked region")
+ run_button = gr.Button(variant="primary")
+
+ with gr.Accordion("Advanced options", open=False):
+
+ num_samples = gr.Slider(label="Images", info="number of generated images, locked as 1", minimum=1, maximum=1, value=1, step=1)
+ steps = gr.Slider(label="Steps", info ="denoising sampling steps", minimum=1, maximum=200, value=50, step=1)
+ scale = gr.Slider(label="Guidance Scale", info="the scale of classifier-free guidance (CFG)", minimum=0.0, maximum=10.0, value=5.0, step=0.1)
+ seed = gr.Slider(label="Seed", info="random seed for noise initialization", minimum=0, maximum=2147483647, step=1, randomize=True)
+ show_detail = gr.Checkbox(label="Show Detail", info="show the additional visualization results", value=False)
+
+ with gr.Column():
+
+ gallery = gr.Gallery(label="Output", height=512, preview=True)
+
+ with gr.Accordion("Visualization results", open=True):
+
+ with gr.Tab(label="Attention Maps"):
+ gr.Markdown("### Attention maps for each character (extracted from middle blocks at intermediate sampling step):")
+ attn_map = gr.Image(show_label=False, show_download_button=False)
+ with gr.Tab(label="Segmentation Maps"):
+ gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):")
+ seg_map = gr.AnnotatedImage(height=384, show_label=False)
+
+ # examples
+ examples = []
+ example_paths = sorted(glob.glob(ospj("./demo/examples", "*")))
+ for example_path in example_paths:
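+        # example filenames encode the target text as "<LABEL>_<i>_<j>.<ext>" (e.g. ENGINE_0_0.png)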
+ label = example_path.split(os.sep)[-1].split(".")[0].split("_")[0]
+ examples.append([example_path, label])
+
+ gr.Markdown("## Examples:")
+ gr.Examples(
+ examples=examples,
+ inputs=[input_blk, text]
+ )
+
+ run_button.click(fn=demo_predict, inputs=[input_blk, text, num_samples, steps, scale, seed, show_detail], outputs=[gallery, attn_map, seg_map])
+
+ block.launch()
\ No newline at end of file
diff --git a/checkpoints/AEs/AE_inpainting_2.safetensors b/checkpoints/AEs/AE_inpainting_2.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..8cfdd04ec7695b20987d5323959d7d09b958f626
--- /dev/null
+++ b/checkpoints/AEs/AE_inpainting_2.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:547baac83984f8bf8b433882236b87e77eb4d2f5c71e3d7a04b8dec2fe02b81f
+size 334640988
diff --git a/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt b/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..ce122abfa90fe94f50cf2db8b7a9aafc3629bb2f
--- /dev/null
+++ b/checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4076c90467a907dcb8cde15776bfda4473010fe845739490341db74e82cd2267
+size 4059026213
diff --git a/checkpoints/st-step=100000+la-step=100000-v1.ckpt b/checkpoints/st-step=100000+la-step=100000-v1.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..6c5ff277cc700b911005c9c8bd0e73afa21d79eb
--- /dev/null
+++ b/checkpoints/st-step=100000+la-step=100000-v1.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edea71eb83b6be72c33ef787a7122a810a7b9257bf97a276ef322707d5769878
+size 6148465904
diff --git a/configs/demo.yaml b/configs/demo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26f5b4bd1d495f5615d895d287b304a4bb39ea23
--- /dev/null
+++ b/configs/demo.yaml
@@ -0,0 +1,29 @@
+type: "demo"
+
+# path
+load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-v1.ckpt"
+model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
+
+# param
+H: 512
+W: 512
+txt_len: [1, 12]
+seq_len: 12
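+# txt_len is the allowed [min, max] character count; character-level conditioning (e.g. seg_mask) is padded to seq_len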
+batch_size: 1
+
+channel: 4 # AE latent channel
+factor: 8 # AE downsample factor
+scale: [5.0, 0.0] # content scale, style scale
+noise_iters: 0
+force_uc_zero_embeddings: ["label"]
+aae_enabled: False
+detailed: False
+
+# runtime
+steps: 50
+init_step: 0
+num_workers: 0
+use_gpu: True
+gpu: 0
+max_iter: 100
+
diff --git a/configs/test/textdesign_sd_2.yaml b/configs/test/textdesign_sd_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..243f723f9182825cb382e23d2dbc0ea4221e5a54
--- /dev/null
+++ b/configs/test/textdesign_sd_2.yaml
@@ -0,0 +1,137 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ input_key: image
+ scale_factor: 0.18215
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
+ params:
+ use_checkpoint: False
+ in_channels: 9
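+      # 9 = 4 noisy-latent + 1 mask + 4 masked-image-latent channels (concat conditioning from the mask/masked embedders below)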
+ out_channels: 4
+ ctrl_channels: 0
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ attn_type: add_attn
+ attn_layers:
+ - output_blocks.6.1
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 0
+ add_context_dim: 2048
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ # - is_trainable: False
+ # input_key: txt
+ # target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ # params:
+ # arch: ViT-H-14
+ # version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
+ # layer: penultimate
+ # add crossattn cond
+ - is_trainable: False
+ input_key: label
+ target: sgm.modules.encoders.modules.LabelEncoder
+ params:
+ is_add_embedder: True
+ max_len: 12
+ emb_dim: 2048
+ n_heads: 8
+ n_trans_layers: 12
+ ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
+ # concat cond
+ - is_trainable: False
+ input_key: mask
+ target: sgm.modules.encoders.modules.IdentityEncoder
+ - is_trainable: False
+ input_key: masked
+ target: sgm.modules.encoders.modules.LatentEncoder
+ params:
+ scale_factor: 0.18215
+ config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.FullLoss # StandardDiffusionLoss
+ params:
+ seq_len: 12
+ kernel_size: 3
+ gaussian_sigma: 0.5
+ min_attn_size: 16
+ lambda_local_loss: 0.02
+ lambda_ocr_loss: 0.001
+ ocr_enabled: False
+
+ predictor_config:
+ target: sgm.modules.predictors.model.ParseqPredictor
+ params:
+ ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
+
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ num_idx: 1000
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
\ No newline at end of file
diff --git a/demo/examples/DIRTY_0_0.png b/demo/examples/DIRTY_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d2423f60773e1a0140c573f3b816fac225138b4
Binary files /dev/null and b/demo/examples/DIRTY_0_0.png differ
diff --git a/demo/examples/ENGINE_0_0.png b/demo/examples/ENGINE_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..592ca22a7277853103e24ffcccab33f5ad33859a
Binary files /dev/null and b/demo/examples/ENGINE_0_0.png differ
diff --git a/demo/examples/FAVOURITE_0_0.jpeg b/demo/examples/FAVOURITE_0_0.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..542a2a945f5d4c68c3c4249018577296479e931b
Binary files /dev/null and b/demo/examples/FAVOURITE_0_0.jpeg differ
diff --git a/demo/examples/FRONTIER_0_0.png b/demo/examples/FRONTIER_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..0eab04679bd107680e57afc378a312bbcf605e9c
Binary files /dev/null and b/demo/examples/FRONTIER_0_0.png differ
diff --git a/demo/examples/Peaceful_0_0.jpeg b/demo/examples/Peaceful_0_0.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..49daf09e7a91834ed9b224bb61c5b03f3fa6ce80
Binary files /dev/null and b/demo/examples/Peaceful_0_0.jpeg differ
diff --git a/demo/examples/Scamps_0_0.png b/demo/examples/Scamps_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..0513a2897bf32f7d5994dd1484619f4b8b474e90
Binary files /dev/null and b/demo/examples/Scamps_0_0.png differ
diff --git a/demo/examples/TREE_0_0.png b/demo/examples/TREE_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f5985d0f5f6f81b36656c7757ca81b2141fb778
Binary files /dev/null and b/demo/examples/TREE_0_0.png differ
diff --git a/demo/examples/better_0_0.jpg b/demo/examples/better_0_0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..64b3ad7ef917d473fe3e6e5c2c7af99dc1e72a85
Binary files /dev/null and b/demo/examples/better_0_0.jpg differ
diff --git a/demo/examples/tested_0_0.png b/demo/examples/tested_0_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..28311ea559b559528533e573c6380b68429ae51c
Binary files /dev/null and b/demo/examples/tested_0_0.png differ
diff --git a/demo/teaser.png b/demo/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..a36a109cfd785b093187cf75c453aacc7cbb04ab
--- /dev/null
+++ b/demo/teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcd166cc9691c99a7ee93a028ab485472171ee348a5f4dbaf82f6bf1fb27c66d
+size 2623749
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..669b84544de99f30dd40a5b20c7309f095d2c4e2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+colorlover==0.3.0
+einops==0.6.1
+gradio==3.41.0
+imageio==2.31.2
+img2dataset==1.42.0
+kornia==0.6.9
+lpips==0.1.4
+matplotlib==3.7.2
+numpy==1.25.1
+omegaconf==2.3.0
+open-clip-torch==2.20.0
+opencv-python==4.6.0.66
+Pillow==9.5.0
+pytorch-fid==0.3.0
+pytorch-lightning==2.0.1
+safetensors==0.3.1
+scikit-learn==1.3.0
+scipy==1.11.1
+seaborn==0.12.2
+tensorboard==2.14.0
+timm==0.9.2
+tokenizers==0.13.3
+torch==2.1.0
+torchvision==0.16.0
+tqdm==4.65.0
+transformers==4.30.2
+xformers==0.0.22.post7
+
diff --git a/sgm/__init__.py b/sgm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e273bdaa90e0ff822a6098dd531046b90a12ce3e
--- /dev/null
+++ b/sgm/__init__.py
@@ -0,0 +1,2 @@
+from .models import AutoencodingEngine, DiffusionEngine
+from .util import instantiate_from_config
diff --git a/sgm/__pycache__/__init__.cpython-310.pyc b/sgm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a4eb39b783460d0875ab7de452a57ba9d0add8f
Binary files /dev/null and b/sgm/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/__pycache__/__init__.cpython-311.pyc b/sgm/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43117e01f8831d30bc313c6f118866560228e70d
Binary files /dev/null and b/sgm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/__pycache__/lr_scheduler.cpython-311.pyc b/sgm/__pycache__/lr_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9c54ee8b655266ee6ae4d38dd3e589b7c462b6b
Binary files /dev/null and b/sgm/__pycache__/lr_scheduler.cpython-311.pyc differ
diff --git a/sgm/__pycache__/util.cpython-310.pyc b/sgm/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f97a94f63e65785a86d02cb5a7dcd88d6aaaafdd
Binary files /dev/null and b/sgm/__pycache__/util.cpython-310.pyc differ
diff --git a/sgm/__pycache__/util.cpython-311.pyc b/sgm/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..574f8371efe3e65c76cc5269d575794f5a372576
Binary files /dev/null and b/sgm/__pycache__/util.cpython-311.pyc differ
diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42
--- /dev/null
+++ b/sgm/lr_scheduler.py
@@ -0,0 +1,135 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (
+ self.lr_max - self.lr_start
+ ) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (
+ self.lr_max_decay_steps - self.lr_warm_up_steps
+ )
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+ ):
+ assert (
+ len(warm_up_steps)
+ == len(f_min)
+ == len(f_max)
+ == len(f_start)
+ == len(cycle_lengths)
+ )
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+ )
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+ self.cycle_lengths[cycle] - n
+ ) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b3747afc208e4204c8f140170e0a7808eace
--- /dev/null
+++ b/sgm/models/__init__.py
@@ -0,0 +1,2 @@
+from .autoencoder import AutoencodingEngine
+from .diffusion import DiffusionEngine
diff --git a/sgm/models/__pycache__/__init__.cpython-310.pyc b/sgm/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e32a1135ed5ff1adf7e8a66e63d5e2d06b28559
Binary files /dev/null and b/sgm/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/models/__pycache__/__init__.cpython-311.pyc b/sgm/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b800cd08b4d280ca659ccc4a6cb60969bbbab0f5
Binary files /dev/null and b/sgm/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/models/__pycache__/autoencoder.cpython-310.pyc b/sgm/models/__pycache__/autoencoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b6ac83302294858227ca4a09a792683bf4d4123
Binary files /dev/null and b/sgm/models/__pycache__/autoencoder.cpython-310.pyc differ
diff --git a/sgm/models/__pycache__/autoencoder.cpython-311.pyc b/sgm/models/__pycache__/autoencoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d8aa3fc40e71a77a8241846d44b73351c5b0379
Binary files /dev/null and b/sgm/models/__pycache__/autoencoder.cpython-311.pyc differ
diff --git a/sgm/models/__pycache__/diffusion.cpython-310.pyc b/sgm/models/__pycache__/diffusion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19575e3def7041d87e46864ed48dc30ee4ce8f5d
Binary files /dev/null and b/sgm/models/__pycache__/diffusion.cpython-310.pyc differ
diff --git a/sgm/models/__pycache__/diffusion.cpython-311.pyc b/sgm/models/__pycache__/diffusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf0b94797bf0aab0ca87df38acf3cef3af832148
Binary files /dev/null and b/sgm/models/__pycache__/diffusion.cpython-311.pyc differ
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..83c6863df153c147b343978ef37c201691abbe23
--- /dev/null
+++ b/sgm/models/autoencoder.py
@@ -0,0 +1,335 @@
+import re
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig
+from packaging import version
+from safetensors.torch import load_file as load_safetensors
+
+from ..modules.diffusionmodules.model import Decoder, Encoder
+from ..modules.distributions.distributions import DiagonalGaussianDistribution
+from ..modules.ema import LitEma
+from ..util import default, get_obj_from_str, instantiate_from_config
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ckpt_path: Union[None, str] = None,
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
+ ):
+ super().__init__()
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ def init_from_ckpt(
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if re.match(ik, k):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ # todo: add options to freeze encoder/decoder
+ self.encoder = instantiate_from_config(encoder_config)
+ self.decoder = instantiate_from_config(decoder_config)
+ self.loss = instantiate_from_config(loss_config)
+ self.regularization = instantiate_from_config(regularizer_config)
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.Adam"}
+ )
+ self.lr_g_factor = lr_g_factor
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+        # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = (
+ list(self.encoder.parameters())
+ + list(self.decoder.parameters())
+ + list(self.regularization.get_trainable_parameters())
+ + list(self.loss.get_trainable_autoencoder_parameters())
+ )
+ return params
+
+ def get_discriminator_params(self) -> list:
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
+ z = self.encoder(x)
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: Any) -> torch.Tensor:
+ x = self.decoder(z)
+ return x
+
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z)
+ return z, dec, reg_log
+
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
+ x = self.get_input(batch)
+ z, xrec, regularization_log = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+
+ self.log_dict(
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train",
+ )
+ self.log_dict(
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return discloss
+
+ def validation_step(self, batch, batch_idx) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ aeloss, log_dict_ae = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val" + postfix,
+ )
+
+ discloss, log_dict_disc = self.loss(
+ regularization_log,
+ x,
+ xrec,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val" + postfix,
+ )
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ log_dict_ae.update(log_dict_disc)
+ self.log_dict(log_dict_ae)
+ return log_dict_ae
+
+ def configure_optimizers(self) -> Any:
+ ae_params = self.get_autoencoder_params()
+ disc_params = self.get_discriminator_params()
+
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opt_disc = self.instantiate_optimizer_from_config(
+ disc_params, self.learning_rate, self.optimizer_config
+ )
+
+ return [opt_ae, opt_disc], []
+
+ @torch.no_grad()
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
+ log = dict()
+ x = self.get_input(batch)
+ _, xrec, _ = self(x)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+
+class AutoencoderKL(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", ())
+ super().__init__(
+ encoder_config={"target": "torch.nn.Identity"},
+ decoder_config={"target": "torch.nn.Identity"},
+ regularizer_config={"target": "torch.nn.Identity"},
+ loss_config=kwargs.pop("lossconfig"),
+ **kwargs,
+ )
+ assert ddconfig["double_z"]
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def encode(self, x):
+ assert (
+ not self.training
+ ), f"{self.__class__.__name__} only supports inference currently"
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, **decoder_kwargs):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z, **decoder_kwargs)
+ return dec
+
+
+class AutoencoderKLInferenceWrapper(AutoencoderKL):
+ def encode(self, x):
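+        # Return a sampled latent rather than the DiagonalGaussianDistribution produced by AutoencoderKL.encode.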
+ return super().encode(x).sample()
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c72ed364de456b081d85939c94bae3e53ec7128
--- /dev/null
+++ b/sgm/models/diffusion.py
@@ -0,0 +1,328 @@
+from contextlib import contextmanager
+from typing import Any, Dict, List, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+)
+
+
+class DiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str] = None,
+ ckpt_path: Union[None, str] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "jpg",
+ log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ opt_keys: Union[List, None] = None
+ ):
+ super().__init__()
+ self.opt_keys = opt_keys
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def freeze(self):
+
+ for param in self.parameters():
+ param.requires_grad_(False)
+
+ def _init_first_stage(self, config):
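+        # The first-stage autoencoder is frozen: eval mode, parameters without grad, and
+        # disabled_train keeps it from being switched back to training mode.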
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ out = self.first_stage_model.decode(z)
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ z = self.first_stage_model.encode(x)
+ z = self.scale_factor * z
+ return z
+
+ def forward(self, x, batch):
+
+ loss, loss_dict = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, self.first_stage_model, self.scale_factor)
+
+ return loss, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ x = self.get_input(batch)
+ x = self.encode_first_stage(x)
+ batch["global_step"] = self.global_step
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ float(self.global_step),
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ def on_train_start(self, *args, **kwargs):
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
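+        # Only model parameters whose names match one of self.opt_keys are optimized (the rest
+        # are frozen), plus the parameters of any conditioner embedder marked is_trainable.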
+ lr = self.learning_rate
+ params = []
+ print("Trainable parameter list: ")
+ print("-"*20)
+ for name, param in self.model.named_parameters():
+ if any([key in name for key in self.opt_keys]):
+ params.append(param)
+ print(name)
+ else:
+ param.requires_grad_(False)
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ for name, param in embedder.named_parameters():
+ params.append(param)
+ print(name)
+ print("-"*20)
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch)
+
+        return [opt], [scheduler]
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 8,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ log["reconstructions"] = self.decode_first_stage(z)
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ log["samples"] = samples
+ return log
diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a7e034952f8feb22a035dbfd4872025a9727b9
--- /dev/null
+++ b/sgm/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders.modules import GeneralConditioner, DualConditioner
+
+UNCONDITIONAL_CONFIG = {
+ "target": "sgm.modules.GeneralConditioner",
+ "params": {"emb_models": []},
+}
diff --git a/sgm/modules/__pycache__/__init__.cpython-310.pyc b/sgm/modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9a0f3535f61d9b2a3bb255f360ba1ad6e785410
Binary files /dev/null and b/sgm/modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/__pycache__/__init__.cpython-311.pyc b/sgm/modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6406885e0d95309f1c2c8d76309da1dbdd047255
Binary files /dev/null and b/sgm/modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/__pycache__/attention.cpython-310.pyc b/sgm/modules/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15c5ddfaed6c10a092cd294144afba5851cd7f08
Binary files /dev/null and b/sgm/modules/__pycache__/attention.cpython-310.pyc differ
diff --git a/sgm/modules/__pycache__/attention.cpython-311.pyc b/sgm/modules/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3864f4e4d7bed0f20beb2b02a14a3676b282ff9e
Binary files /dev/null and b/sgm/modules/__pycache__/attention.cpython-311.pyc differ
diff --git a/sgm/modules/__pycache__/ema.cpython-310.pyc b/sgm/modules/__pycache__/ema.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5719c32f3bd3ac007c9456238351e07bb1ae665
Binary files /dev/null and b/sgm/modules/__pycache__/ema.cpython-310.pyc differ
diff --git a/sgm/modules/__pycache__/ema.cpython-311.pyc b/sgm/modules/__pycache__/ema.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d139d6b613906bd6db63639b0e3316a31b1b54e
Binary files /dev/null and b/sgm/modules/__pycache__/ema.cpython-311.pyc differ
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b006003ba41e1327ea8df469f7441139a3d65a2
--- /dev/null
+++ b/sgm/modules/attention.py
@@ -0,0 +1,976 @@
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn, einsum
+
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ print(
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from .diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = zero_module(nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ ))
+ self.backend = backend
+
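+        # Optional cache set externally; when not None, the softmaxed attention map is stored in it (see forward).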
+ self.attn_map_cache = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+        ## explicit attention (kept so the softmaxed attention map can be cached below)
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ # save attn_map
+ if self.attn_map_cache is not None:
+ bh, n, l = sim.shape
+ size = int(n**0.5)
+ self.attn_map_cache["size"] = size
+ self.attn_map_cache["attn_map"] = sim
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+
+        ## alternative: fused scaled_dot_product_attention (commented out; it does not expose the attention map)
+ # with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ # out = F.scaled_dot_product_attention(
+ # q, k, v, attn_mask=mask
+ # ) # scale is dim_head ** -0.5 per default
+
+ # del q, k, v
+ # out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
+ ):
+ super().__init__()
+ # print(
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ # f"{heads} heads with a dimension of {dim_head}."
+ # )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ add_context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ print(
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ print(
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
+ )
+ if not XFORMERS_IS_AVAILABLE:
+ assert (
+ False
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ print("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = MemoryEfficientCrossAttention(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+        )  # self-attention (unless disable_self_attn); always uses the memory-efficient implementation
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ if context_dim is not None and context_dim > 0:
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+            )  # cross-attention over the main context
+ if add_context_dim is not None and add_context_dim > 0:
+ self.add_attn = attn_cls(
+ query_dim=dim,
+ context_dim=add_context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+            )  # cross-attention over the additional context
+ self.add_norm = nn.LayerNorm(dim)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update(
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+ )
+
+ return checkpoint(
+ self._forward, (x, context, add_context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+ if not self.disable_self_attn
+ else 0,
+ )
+ + x
+ )
+ if hasattr(self, "attn2"):
+ x = (
+ self.attn2(
+ self.norm2(x), context=context, additional_tokens=additional_tokens
+ )
+ + x
+ )
+ if hasattr(self, "add_attn"):
+ x = (
+ self.add_attn(
+ self.add_norm(x), context=add_context, additional_tokens=additional_tokens
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ add_context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ # print(
+ # f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
+ # )
+ from omegaconf import ListConfig
+
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ # print(
+ # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+ # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
+ # )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ add_context_dim=add_context_dim,
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, add_context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i], add_context=add_context)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
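+
+# Illustrative usage sketch: it traces the shape flow described in the
+# SpatialTransformer docstring. It assumes a CUDA device with xformers
+# installed, since the self-attention inside BasicTransformerBlock always
+# routes through the xformers kernel.
+def _example_spatial_transformer_shapes():
+    st = SpatialTransformer(
+        in_channels=64,
+        n_heads=4,
+        d_head=16,
+        depth=1,
+        context_dim=32,
+        use_linear=True,
+        attn_type="softmax",
+        use_checkpoint=False,
+    ).cuda()
+    x = torch.randn(2, 64, 8, 8).cuda()  # image-like input: (b, c, h, w)
+    ctx = torch.randn(2, 5, 32).cuda()  # context tokens: (b, t, context_dim)
+    out = st(x, context=ctx)  # internally: (b, c, h, w) -> (b, h*w, c) -> (b, c, h, w)
+    print(out.shape)  # torch.Size([2, 64, 8, 8])
+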
+
+def benchmark_attn():
+    # Let's define a helpful benchmarking function:
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ import torch.nn.functional as F
+ import torch.utils.benchmark as benchmark
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+    # Let's define the hyper-parameters of our input
+ batch_size = 32
+ max_sequence_len = 1024
+ num_heads = 32
+ embed_dimension = 32
+
+ dtype = torch.float16
+
+ query = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ key = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ value = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
+
+    # Let's explore the speed of each of the 3 implementations
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ # Helpful arguments mapper
+ backend_map = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ }
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ print(
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Default detailed stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ print(
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Math implmentation stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+ try:
+ print(
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("FlashAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("FlashAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+ try:
+ print(
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("EfficientAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("EfficientAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def run_model(model, x, context):
+ return model(x, context)
+
+
+def benchmark_transformer_blocks():
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ import torch.utils.benchmark as benchmark
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+ checkpoint = True
+ compile = False
+
+ batch_size = 32
+ h, w = 64, 64
+ context_len = 77
+ embed_dimension = 1024
+ context_dim = 1024
+ d_head = 64
+
+ transformer_depth = 4
+
+ n_heads = embed_dimension // d_head
+
+ dtype = torch.float16
+
+ model_native = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ use_checkpoint=checkpoint,
+ attn_type="softmax",
+ depth=transformer_depth,
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
+ ).to(device)
+ model_efficient_attn = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ depth=transformer_depth,
+ use_checkpoint=checkpoint,
+ attn_type="softmax-xformers",
+ ).to(device)
+ if not checkpoint and compile:
+ print("compiling models")
+ model_native = torch.compile(model_native)
+ model_efficient_attn = torch.compile(model_efficient_attn)
+
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ with torch.autocast("cuda"):
+ print(
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
+ )
+ print(
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
+ )
+
+ print(75 * "+")
+ print("NATIVE")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("NativeAttention stats"):
+ for _ in range(25):
+ model_native(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
+
+ print(75 * "+")
+ print("Xformers")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("xformers stats"):
+ for _ in range(25):
+ model_efficient_attn(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
+
+
+def test01():
+ # conv1x1 vs linear
+ from ..util import count_params
+
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
+ print(count_params(conv))
+ linear = torch.nn.Linear(3, 32).cuda()
+ print(count_params(linear))
+
+ print(conv.weight.shape)
+
+ # use same initialization
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
+ linear.bias = torch.nn.Parameter(conv.bias)
+
+ print(linear.weight.shape)
+
+ x = torch.randn(11, 3, 64, 64).cuda()
+
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ print(xr.shape)
+ out_linear = linear(xr)
+ print(out_linear.mean(), out_linear.shape)
+
+ out_conv = conv(x)
+ print(out_conv.mean(), out_conv.shape)
+ print("done with test01.\n")
+
+
+def test02():
+ # try cosine flash attention
+ import time
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ print("testing cosine flash attention...")
+ DIM = 1024
+ SEQLEN = 4096
+ BS = 16
+
+ print(" softmax (vanilla) first...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="softmax",
+ ).cuda()
+ try:
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ except RuntimeError as e:
+ # likely oom
+ print(str(e))
+
+ print("\n now flash-cosine...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="flash-cosine",
+ ).cuda()
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ print("done with test02.\n")
+
+
+if __name__ == "__main__":
+ # test01()
+ # test02()
+ # test03()
+
+ # benchmark_attn()
+ benchmark_transformer_blocks()
+
+ print("done.")
diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc b/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd2db470e0c4888df74d1714c70e68f3abd521c4
Binary files /dev/null and b/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c88109b1d9d242049d8c1dc21ba0ca8f9e1cc8b1
Binary files /dev/null and b/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a3b54f7284ae1be6a23b425f6c296efc1881a5c
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/__init__.py
@@ -0,0 +1,246 @@
+from typing import Any, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+from ....util import default, instantiate_from_config
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.0):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+class LatentLPIPS(nn.Module):
+ def __init__(
+ self,
+ decoder_config,
+ perceptual_weight=1.0,
+ latent_weight=1.0,
+ scale_input_to_tgt_size=False,
+ scale_tgt_to_input_size=False,
+ perceptual_weight_on_inputs=0.0,
+ ):
+ super().__init__()
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
+ self.init_decoder(decoder_config)
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.latent_weight = latent_weight
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+ def init_decoder(self, config):
+ self.decoder = instantiate_from_config(config)
+ if hasattr(self.decoder, "encoder"):
+ del self.decoder.encoder
+
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+ log = dict()
+ loss = (latent_inputs - latent_predictions) ** 2
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+ image_reconstructions = None
+ if self.perceptual_weight > 0.0:
+ image_reconstructions = self.decoder.decode(latent_predictions)
+ image_targets = self.decoder.decode(latent_inputs)
+ perceptual_loss = self.perceptual_loss(
+ image_targets.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = (
+ self.latent_weight * loss.mean()
+ + self.perceptual_weight * perceptual_loss.mean()
+ )
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+ if self.perceptual_weight_on_inputs > 0.0:
+ image_reconstructions = default(
+ image_reconstructions, self.decoder.decode(latent_predictions)
+ )
+ if self.scale_input_to_tgt_size:
+ image_inputs = torch.nn.functional.interpolate(
+ image_inputs,
+ image_reconstructions.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+ elif self.scale_tgt_to_input_size:
+ image_reconstructions = torch.nn.functional.interpolate(
+ image_reconstructions,
+ image_inputs.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+
+ perceptual_loss2 = self.perceptual_loss(
+ image_inputs.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+ return loss, log
+
+
+class GeneralLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start: int,
+ logvar_init: float = 0.0,
+ pixelloss_weight=1.0,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ perceptual_weight: float = 1.0,
+ disc_loss: str = "hinge",
+ scale_input_to_tgt_size: bool = False,
+ dims: int = 2,
+ learn_logvar: bool = False,
+ regularization_weights: Union[None, dict] = None,
+ ):
+ super().__init__()
+ self.dims = dims
+ if self.dims > 2:
+ print(
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
+ f"the LPIPS loss will be applied to each frame independently. "
+ )
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ assert disc_loss in ["hinge", "vanilla"]
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+ self.learn_logvar = learn_logvar
+
+ self.discriminator = NLayerDiscriminator(
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.regularization_weights = default(regularization_weights, {})
+
+ def get_trainable_parameters(self) -> Any:
+ return self.discriminator.parameters()
+
+ def get_trainable_autoencoder_parameters(self) -> Any:
+ if self.learn_logvar:
+ yield self.logvar
+ yield from ()
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(
+ nll_loss, self.last_layer[0], retain_graph=True
+ )[0]
+ g_grads = torch.autograd.grad(
+ g_loss, self.last_layer[0], retain_graph=True
+ )[0]
+
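+        # Balance the adversarial term against the reconstruction (NLL) term by
+        # the ratio of their gradient norms at the decoder's last layer, as in
+        # the taming-transformers / VQGAN adaptive-weight heuristic.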
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ regularization_log,
+ inputs,
+ reconstructions,
+ optimizer_idx,
+ global_step,
+ last_layer=None,
+ split="train",
+ weights=None,
+ ):
+ if self.scale_input_to_tgt_size:
+ inputs = torch.nn.functional.interpolate(
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
+ )
+
+ if self.dims > 2:
+ inputs, reconstructions = map(
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
+ (inputs, reconstructions),
+ )
+
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
+ log = dict()
+ for k in regularization_log:
+ if k in self.regularization_weights:
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
+
+ log.update(
+ {
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ )
+
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ disc_factor = adopt_weight(
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
+ )
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
+ }
+ return d_loss, log
diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8de3212d3be4f58e621e8caa6e31dd8dc32b6929
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/__init__.py
@@ -0,0 +1,53 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ....modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
+
+
+def measure_perplexity(predicted_indices, num_centroids):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = (
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ )
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
diff --git a/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78c77ab4c2e7d85bbdbda1c20a705a30bed6b9b8
Binary files /dev/null and b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d73b0f3b49223cddb2194676d9ae30c48843425
Binary files /dev/null and b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce7968af9224aff42a20023b4e14ca059939e034
--- /dev/null
+++ b/sgm/modules/diffusionmodules/__init__.py
@@ -0,0 +1,7 @@
+from .denoiser import Denoiser
+from .discretizer import Discretization
+from .loss import StandardDiffusionLoss
+from .model import Model, Encoder, Decoder
+from .openaimodel import UNetModel
+from .sampling import BaseDiffusionSampler
+from .wrappers import OpenAIWrapper
diff --git a/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eee33aa50fdfd415a2fb09cacf32faf1782850a1
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5dc76ad0946dc85c0e49b8998f6b60c2f608c93
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2a2a04e67ef637478cde2fc5ca873e8c7123319
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddfea30128caad253694b31fd9cf83b364cace50
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7384a2e3460c42bf6d90d7ff7b9e98e3c96716a8
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f71354da122fff53ded2a2c188ee64d02449f738
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f33259f8d0dd7e71c06bbb0518e0534fa2bd5dc
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e93cbd4a789e903d9b12d7f0b60ac96e5008229
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser_weighting.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4ab34df095729f447451b82a4d2728e85def50a
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3565ab3a49414fb9e884f996207993fb38b5783a
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..428cb689cc57290f4bf24f8b323f2f8efdb769d3
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78a1a6d7176176f4d74f95a934086d6e32f660ce
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43bf8e2ccfd6e5139b8f590861b0a938389d74f8
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96f13b5cbb2fa6cd5a90da4ba3429f94327390ec
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b66c3a3e1f168d706f572b84b54d4156068b888
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb03d345ce95b6d9208d402f08815d5f34620226
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4648b23cb30e58694f8f3de99c7f5bb9381054c9
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..546e1ee7599bde9e23e25377b4c26eab1b4cad09
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68af44c9466aa662540911ec8b1354dc5ea428d5
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e919213b0c5dd4127dc0cfdca708ab0afb6b621c
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c83b8719cfeba58fb686f9375da20f3c71851088
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d4c084c5cb4d4bcb6990ed370f2f6aac00bdb99
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3dd0413bf8c9cd5c3a474cfbd5d8404c9d957b69
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0d6329ad9a7cf480a0492cfe8554b1ae325c9bb
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbdb4713eb1c5864f37e635f974b3ceeaf5f191b
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9171c7711dc732f8d6ab51c8938511f4d752b61d
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..949dfe2ed217069a315b7e728c441ffafd85f363
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-310.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f17129ea6d66744ecfd47fe95709ceabf182352d
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..4651e7de5c4a90e0656843821d5d32a5ce0ddd0e
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser.py
@@ -0,0 +1,63 @@
+import torch.nn as nn
+
+from ...util import append_dims, instantiate_from_config
+
+
+class Denoiser(nn.Module):
+ def __init__(self, weighting_config, scaling_config):
+ super().__init__()
+
+ self.weighting = instantiate_from_config(weighting_config)
+ self.scaling = instantiate_from_config(scaling_config)
+
+ def possibly_quantize_sigma(self, sigma):
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise):
+ return c_noise
+
+ def w(self, sigma):
+ return self.weighting(sigma)
+
+ def __call__(self, network, input, sigma, cond):
+ sigma = self.possibly_quantize_sigma(sigma)
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
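+        # EDM-style preconditioning: D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise, cond)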
+ return network(input * c_in, c_noise, cond) * c_out + input * c_skip
+
+
+class DiscreteDenoiser(Denoiser):
+ def __init__(
+ self,
+ weighting_config,
+ scaling_config,
+ num_idx,
+ discretization_config,
+ do_append_zero=False,
+ quantize_c_noise=True,
+ flip=True,
+ ):
+ super().__init__(weighting_config, scaling_config)
+ sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+ self.register_buffer("sigmas", sigmas)
+ self.quantize_c_noise = quantize_c_noise
+
+ def sigma_to_idx(self, sigma):
+ dists = sigma - self.sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape)
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def possibly_quantize_sigma(self, sigma):
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
+
+ def possibly_quantize_c_noise(self, c_noise):
+ if self.quantize_c_noise:
+ return self.sigma_to_idx(c_noise)
+ else:
+ return c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a2ac6732ea78f1030b21bebd14063d52ac2a82
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_scaling.py
@@ -0,0 +1,31 @@
+import torch
+
+
+class EDMScaling:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class EpsScaling:
+ def __call__(self, sigma):
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = -sigma
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScaling:
+ def __call__(self, sigma):
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_weighting.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class UnitWeighting:
+ def __call__(self, sigma):
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting:
+ def __call__(self, sigma):
+ return sigma**-2.0
diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..397b8f386615b50bf83742eb0a8c02a95a6ffefc
--- /dev/null
+++ b/sgm/modules/diffusionmodules/discretizer.py
@@ -0,0 +1,68 @@
+import torch
+import numpy as np
+from functools import partial
+from abc import abstractmethod
+
+from ...util import append_zero
+from ...modules.diffusionmodules.util import make_beta_schedule
+
+
+def generate_roughly_equally_spaced_steps(
+ num_substeps: int, max_step: int
+) -> np.ndarray:
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+class Discretization:
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
+ sigmas = self.get_sigmas(n, device=device)
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
+ return sigmas if not flip else torch.flip(sigmas, (0,))
+
+ @abstractmethod
+ def get_sigmas(self, n, device):
+ pass
+
+
+class EDMDiscretization(Discretization):
+ def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def get_sigmas(self, n, device="cpu"):
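+        # Karras et al. (2022) rho-schedule:
+        # sigma_i = (sigma_max^(1/rho) + t_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho, t_i in [0, 1]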
+ ramp = torch.linspace(0, 1, n, device=device)
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+ return sigmas
+
+
+class LegacyDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ ):
+ super().__init__()
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule(
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ def get_sigmas(self, n, device="cpu"):
+ if n < self.num_timesteps:
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ elif n == self.num_timesteps:
+ alphas_cumprod = self.alphas_cumprod
+ else:
+            raise ValueError(f"n ({n}) must not exceed num_timesteps ({self.num_timesteps})")
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
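+        # Convert DDPM cumulative alphas to noise levels: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)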
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ return torch.flip(sigmas, (0,))
diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py
new file mode 100644
index 0000000000000000000000000000000000000000..078de2920a8b20b911c7a42e2ffad67c4425e20b
--- /dev/null
+++ b/sgm/modules/diffusionmodules/guiders.py
@@ -0,0 +1,81 @@
+from functools import partial
+
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class VanillaCFG:
+ """
+    Classifier-free guidance (CFG) in which the unconditional and conditional
+    branches are stacked along the batch dimension and evaluated in a single
+    forward pass ("parallelized" CFG).
+ """
+
+ def __init__(self, scale, dyn_thresh_config=None):
+ scale_schedule = lambda scale, sigma: scale # independent of step
+ self.scale_schedule = partial(scale_schedule, scale)
+ self.dyn_thresh = instantiate_from_config(
+ default(
+ dyn_thresh_config,
+ {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+ },
+ )
+ )
+
+ def __call__(self, x, sigma):
+ x_u, x_c = x.chunk(2)
+ scale_value = self.scale_schedule(sigma)
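+        # With the default NoDynamicThresholding this is the usual CFG combination,
+        # i.e. roughly x_u + scale * (x_c - x_u).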
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "add_crossattn", "concat"]:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class DualCFG:
+
+ def __init__(self, scale):
+ self.scale = scale
+ self.dyn_thresh = instantiate_from_config(
+ {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
+ },
+ )
+
+ def __call__(self, x, sigma):
+ x_u_1, x_u_2, x_c = x.chunk(3)
+ x_pred = self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc_1, uc_2):
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat", "add_crossattn"]:
+ c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0)
+ else:
+ assert c[k] == uc_1[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 3), torch.cat([s] * 3), c_out
+
+
+
+class IdentityGuider:
+ def __call__(self, x, sigma):
+ return x
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ c_out[k] = c[k]
+
+ return x, s, c_out
diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4ba0b417cb53abb00e7e067c4b5b8332ec152f5
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss.py
@@ -0,0 +1,268 @@
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import ListConfig
+from torchvision.utils import save_image
+from ...util import append_dims, instantiate_from_config
+
+
+class StandardDiffusionLoss(nn.Module):
+ def __init__(
+ self,
+ sigma_sampler_config,
+ type="l2",
+ offset_noise_level=0.0,
+ batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
+ ):
+ super().__init__()
+
+ assert type in ["l2", "l1"]
+
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+
+ self.type = type
+ self.offset_noise_level = offset_noise_level
+
+ if not batch2model_keys:
+ batch2model_keys = []
+
+ if isinstance(batch2model_keys, str):
+ batch2model_keys = [batch2model_keys]
+
+ self.batch2model_keys = set(batch2model_keys)
+
+ def __call__(self, network, denoiser, conditioner, input, batch, *args, **kwarg):
+ cond = conditioner(batch)
+ additional_model_inputs = {
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
+ }
+
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ noise = noise + self.offset_noise_level * append_dims(
+ torch.randn(input.shape[0], device=input.device), input.ndim
+ )
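+        # Forward (noising) process in the EDM parameterization: x_sigma = x_0 + sigma * noise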
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
+ model_output = denoiser(
+ network, noised_input, sigmas, cond, **additional_model_inputs
+ )
+ w = append_dims(denoiser.w(sigmas), input.ndim)
+
+ loss = self.get_diff_loss(model_output, input, w)
+ loss = loss.mean()
+ loss_dict = {"loss": loss}
+
+ return loss, loss_dict
+
+ def get_diff_loss(self, model_output, target, w):
+ if self.type == "l2":
+ return torch.mean(
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
+ )
+ elif self.type == "l1":
+ return torch.mean(
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
+ )
+
+
+class FullLoss(StandardDiffusionLoss):
+
+ def __init__(
+ self,
+ seq_len=12,
+ kernel_size=3,
+ gaussian_sigma=0.5,
+ min_attn_size=16,
+ lambda_local_loss=0.0,
+ lambda_ocr_loss=0.0,
+ ocr_enabled = False,
+ predictor_config = None,
+ *args, **kwarg
+ ):
+ super().__init__(*args, **kwarg)
+
+ self.gaussian_kernel_size = kernel_size
+ gaussian_kernel = self.get_gaussian_kernel(kernel_size=self.gaussian_kernel_size, sigma=gaussian_sigma, out_channels=seq_len)
+ self.register_buffer("g_kernel", gaussian_kernel.requires_grad_(False))
+
+ self.min_attn_size = min_attn_size
+ self.lambda_local_loss = lambda_local_loss
+ self.lambda_ocr_loss = lambda_ocr_loss
+
+ self.ocr_enabled = ocr_enabled
+ if ocr_enabled:
+ self.predictor = instantiate_from_config(predictor_config)
+
+ def get_gaussian_kernel(self, kernel_size=3, sigma=1, out_channels=3):
+ # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+ x_coord = torch.arange(kernel_size)
+ x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
+ y_grid = x_grid.t()
+ xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+ mean = (kernel_size - 1)/2.
+ variance = sigma**2.
+
+ # Calculate the 2-dimensional gaussian kernel which is
+ # the product of two gaussian distributions for two different
+ # variables (in this case called x and y)
+ gaussian_kernel = (1./(2.*torch.pi*variance)) *\
+ torch.exp(
+ -torch.sum((xy_grid - mean)**2., dim=-1) /\
+ (2*variance)
+ )
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+ # Reshape to 2d depthwise convolutional weight
+ gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+ gaussian_kernel = gaussian_kernel.tile(out_channels, 1, 1, 1)
+
+ return gaussian_kernel
+
+ def __call__(self, network, denoiser, conditioner, input, batch, first_stage_model, scaler):
+
+ cond = conditioner(batch)
+
+ sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ noise = noise + self.offset_noise_level * append_dims(
+ torch.randn(input.shape[0], device=input.device), input.ndim
+ )
+
+ noised_input = input + noise * append_dims(sigmas, input.ndim)
+ model_output = denoiser(network, noised_input, sigmas, cond)
+ w = append_dims(denoiser.w(sigmas), input.ndim)
+
+ diff_loss = self.get_diff_loss(model_output, input, w)
+ local_loss = self.get_local_loss(network.diffusion_model.attn_map_cache, batch["seg"], batch["seg_mask"])
+ diff_loss = diff_loss.mean()
+ local_loss = local_loss.mean()
+
+ if self.ocr_enabled:
+ ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler)
+ ocr_loss = ocr_loss.mean()
+
+ loss = diff_loss + self.lambda_local_loss * local_loss
+ if self.ocr_enabled:
+ loss += self.lambda_ocr_loss * ocr_loss
+
+ loss_dict = {
+ "loss/diff_loss": diff_loss,
+ "loss/local_loss": local_loss,
+ "loss/full_loss": loss
+ }
+
+ if self.ocr_enabled:
+ loss_dict["loss/ocr_loss"] = ocr_loss
+
+ return loss, loss_dict
+
+ def get_ocr_loss(self, model_output, r_bbox, label, first_stage_model, scaler):
+
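+        # Undo the latent scale factor, decode back to image space, crop each text
+        # region given by its bbox, and score the crops with the OCR predictor.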
+ model_output = 1 / scaler * model_output
+ model_output_decoded = first_stage_model.decode(model_output)
+ model_output_crops = []
+
+ for i, bbox in enumerate(r_bbox):
+ m_top, m_bottom, m_left, m_right = bbox
+ model_output_crops.append(model_output_decoded[i, :, m_top:m_bottom, m_left:m_right])
+
+ loss = self.predictor.calc_loss(model_output_crops, label)
+
+ return loss
+
+ def get_min_local_loss(self, attn_map_cache, mask, seg_mask):
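+        # Mask-based variant of the local loss: takes each valid token's attention
+        # peak inside the user mask (max over pixels), then the weakest such peak
+        # over tokens (min), and maximizes it so every valid token attends within
+        # the masked region.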
+
+ loss = 0
+ count = 0
+
+ for item in attn_map_cache:
+
+ heads = item["heads"]
+ size = item["size"]
+ attn_map = item["attn_map"]
+
+ if size < self.min_attn_size: continue
+
+ seg_l = seg_mask.shape[1]
+
+            bh, n, l = attn_map.shape  # bh: batch * heads, n: number of pixels (h*w), l: token length
+ attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
+
+ assert seg_l <= l
+ attn_map = attn_map[..., :seg_l]
+ attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
+ attn_map = attn_map.mean(dim = 1) # b, l, n
+
+ attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s
+ attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel
+ attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n
+
+ mask_map = F.interpolate(mask, (size, size))
+ mask_map = mask_map.tile((1, seg_l, 1, 1))
+ mask_map = mask_map.reshape((-1, seg_l, n)) # b, l, n
+
+ p_loss = (mask_map * attn_map).max(dim = -1)[0] # b, l
+ p_loss = p_loss + (1 - seg_mask) # b, l
+ p_loss = p_loss.min(dim = -1)[0] # b,
+
+ loss += -p_loss
+ count += 1
+
+ loss = loss / count
+
+ return loss
+
+ def get_local_loss(self, attn_map_cache, seg, seg_mask):
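+        # Encourages each valid text token's cross-attention to peak inside its own
+        # per-token segmentation map (p_loss) and stay low outside it (n_loss);
+        # minimizing n_loss - p_loss sharpens per-character attention.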
+
+ loss = 0
+ count = 0
+
+ for item in attn_map_cache:
+
+ heads = item["heads"]
+ size = item["size"]
+ attn_map = item["attn_map"]
+
+ if size < self.min_attn_size: continue
+
+ seg_l = seg_mask.shape[1]
+
+            bh, n, l = attn_map.shape  # bh: batch * heads, n: number of pixels (h*w), l: token length
+ attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
+
+ assert seg_l <= l
+ attn_map = attn_map[..., :seg_l]
+ attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
+ attn_map = attn_map.mean(dim = 1) # b, l, n
+
+ attn_map = attn_map.reshape((-1, seg_l, size, size)) # b, l, s, s
+ attn_map = F.conv2d(attn_map, self.g_kernel, padding = self.gaussian_kernel_size//2, groups=seg_l) # gaussian blur on each channel
+ attn_map = attn_map.reshape((-1, seg_l, n)) # b, l, n
+
+ seg_map = F.interpolate(seg, (size, size))
+ seg_map = seg_map.reshape((-1, seg_l, n)) # b, l, n
+ n_seg_map = 1 - seg_map
+
+ p_loss = (seg_map * attn_map).max(dim = -1)[0] # b, l
+ n_loss = (n_seg_map * attn_map).max(dim = -1)[0] # b, l
+
+ p_loss = p_loss * seg_mask # b, l
+ n_loss = n_loss * seg_mask # b, l
+
+ p_loss = p_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b,
+ n_loss = n_loss.sum(dim = -1) / seg_mask.sum(dim = -1) # b,
+
+ f_loss = n_loss - p_loss # b,
+ loss += f_loss
+ count += 1
+
+ loss = loss / count
+
+ return loss
\ No newline at end of file
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..26efd07840def93d44513a5dabd79edbb7cee662
--- /dev/null
+++ b/sgm/modules/diffusionmodules/model.py
@@ -0,0 +1,743 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ print("no module 'xformers'. Processing without...")
+
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
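+# Illustrative only (a sketch, not called anywhere in this module):
+#   get_timestep_embedding(torch.tensor([0, 10, 250]), 128) -> tensor of shape (3, 128);
+#   the first 64 channels are sin terms and the last 64 are cos terms of
+#   geometrically spaced frequencies.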
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
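+            # F.pad takes (left, right, top, bottom) for 4-D input, so this pads one pixel on
+            # the right and bottom only; the stride-2 3x3 conv then halves the resolution exactly.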
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
+ )
+ h_ = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v
+ ) # scale is dim ** -0.5 per default
+ # compute attention
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, "b c h w -> b (h w) c")
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+        # residual connection on the original (b, c, h, w) tensor; adding the flattened x
+        # would be a shape mismatch
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ if (
+ version.parse(torch.__version__) < version.parse("2.0.0")
+ and attn_type != "none"
+ ):
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
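+# Illustrative usage (a sketch of how make_attn is called from the Encoder/Decoder below):
+#   make_attn(512, attn_type="vanilla")           -> AttnBlock (single-head self-attention)
+#   make_attn(512, attn_type="vanilla-xformers")  -> MemoryEfficientAttnBlock (requires xformers)
+#   make_attn(512, attn_type="none")              -> nn.Identity
+# With torch < 2.0 the attention type is forced to "vanilla-xformers" whenever attn_type != "none".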
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print(
+ "Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)
+ )
+ )
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = make_conv_cls(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, **kwargs):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
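+
+
+# Illustrative round trip (a shape-only sketch with a typical SD-style VAE config; the values
+# are assumptions, not taken from this repository's configs):
+#   enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
+#                 in_channels=3, resolution=256, z_channels=4)
+#   dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=[],
+#                 in_channels=3, resolution=256, z_channels=4)
+#   enc(torch.randn(1, 3, 256, 256)).shape  # (1, 8, 32, 32) -- 2 * z_channels because double_z=True
+#   dec(torch.randn(1, 4, 32, 32)).shape    # (1, 3, 256, 256)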
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..2768d3232b6cb361b631661d428f272b9cd06f79
--- /dev/null
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,2070 @@
+import os
+import math
+from abc import abstractmethod
+from functools import partial
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from ...modules.attention import SpatialTransformer
+from ...modules.diffusionmodules.util import (
+ avg_pool_nd,
+ checkpoint,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from ...util import default, exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self,
+ x,
+ emb,
+ context=None,
+ add_context=None,
+ skip_time_mix=False,
+ time_context=None,
+ num_video_frames=None,
+ time_context_cat=None,
+ use_crossframe_attention_in_spatial_layers=False,
+ ):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context, add_context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.third_up = third_up
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ t_factor = 1 if not self.third_up else 2
+ x = F.interpolate(
+ x,
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode="nearest",
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(
+ self.channels, self.out_channels, kernel_size=ks, stride=2
+ )
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
+ if use_conv:
+ # print(f"Building a Downsample layer with {dims} dims.")
+ # print(
+ # f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
+ # f"kernel-size: 3, stride: {stride}, padding: {padding}"
+ # )
+ if dims == 3:
+ pass
+ # print(f" --> Downsampling third axis (time): {third_down}")
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ kernel_size=3,
+ exchange_temb_dims=False,
+ skip_t_emb=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = (
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ )
+ if self.skip_t_emb:
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, kernel_size, padding=padding
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = th.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
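+            # FiLM-style conditioning: the embedding provides (scale, shift) applied as
+            # norm(h) * (1 + scale) + shift; the else branch simply adds the embedding to h.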
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, **kwargs):
+ # TODO add crossframe attention and use mixed checkpoint
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
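+# Note: QKVAttentionLegacy expects qkv laid out as [N, heads * 3 * ch, T] (heads split first),
+# while QKVAttention expects [N, 3 * heads * ch, T] (q/k/v split first); both return an
+# [N, heads * ch, T] tensor.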
+
+class Timestep(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t):
+ return timestep_embedding(t, self.dim)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (
+ use_checkpoint or use_fairscale_checkpoint
+ )
+
+ self.use_fairscale_checkpoint = False
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
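+        # use_fairscale_checkpoint is hard-coded to False above, so the fairscale
+        # checkpoint_wrapper branch is never evaluated and checkpoint_wrapper_fn is a no-op.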
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or i < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ for i, module in enumerate(self.input_blocks):
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for i, module in enumerate(self.output_blocks):
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            raise NotImplementedError("predict_codebook_ids is no longer supported")
+        return self.out(h)
+
+
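+# Illustrative forward pass (a minimal sketch with SD-1.x style hyperparameters; the values are
+# assumptions, the real ones come from this project's config files):
+#   unet = UNetModel(in_channels=4, model_channels=320, out_channels=4, num_res_blocks=2,
+#                    attention_resolutions=[4, 2, 1], channel_mult=(1, 2, 4, 4), num_heads=8,
+#                    use_spatial_transformer=True, transformer_depth=1, context_dim=768)
+#   unet(torch.randn(1, 4, 64, 64), timesteps=torch.tensor([10]),
+#        context=torch.randn(1, 77, 768)).shape   # (1, 4, 64, 64)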
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (
+ use_checkpoint or use_fairscale_checkpoint
+ )
+
+ self.use_fairscale_checkpoint = False
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or i < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ for i, module in enumerate(self.input_blocks):
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for i, module in enumerate(self.output_blocks):
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ assert False, "not supported anymore. what the f*** are you doing?"
+ else:
+ return self.out(h)
+
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+class UNetAddModel(nn.Module):
+
+ def __init__(
+ self,
+ in_channels,
+ ctrl_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ attn_type="attn2",
+ attn_layers=[],
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ add_context_dim=None,
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ spatial_transformer_attn_type="softmax",
+ adm_in_channels=None,
+ use_fairscale_checkpoint=False,
+ offload_to_cpu=False,
+ transformer_depth_middle=None,
+ ):
+ super().__init__()
+ from omegaconf.listconfig import ListConfig
+
+ if use_spatial_transformer:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.ctrl_channels = ctrl_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ elif isinstance(transformer_depth, ListConfig):
+ transformer_depth = list(transformer_depth)
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+ # self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ print(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ ) # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ if use_fp16:
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
+ # self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ assert use_fairscale_checkpoint != use_checkpoint or not (
+ use_checkpoint or use_fairscale_checkpoint
+ )
+
+        self.use_fairscale_checkpoint = False  # note: fairscale checkpoint wrapping stays disabled regardless of the constructor flag
+ checkpoint_wrapper_fn = (
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
+ if self.use_fairscale_checkpoint
+ else lambda x: x
+ )
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = checkpoint_wrapper_fn(
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = checkpoint_wrapper_fn(
+ nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
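+        # Descriptive note: when ctrl_channels > 0, the extra control channels are
+        # encoded by the small conv stack below; its final conv is zero-initialized,
+        # so the control branch contributes nothing at the start of training and is
+        # simply added to the first input block's output in forward().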
+ if self.ctrl_channels > 0:
+ self.add_input_block = TimestepEmbedSequential(
+ conv_nd(dims, ctrl_channels, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 16, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 16, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 32, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 32, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 96, 3, padding=1),
+ nn.SiLU(),
+ conv_nd(dims, 96, 256, 3, padding=1),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+ )
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ add_context_dim=add_context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer( # always uses a self-attn
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ add_context_dim=add_context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ ),
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or i < num_attention_blocks[level]
+ ):
+ layers.append(
+ checkpoint_wrapper_fn(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if not use_spatial_transformer
+ else checkpoint_wrapper_fn(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ add_context_dim=add_context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ checkpoint_wrapper_fn(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = checkpoint_wrapper_fn(
+ nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+ )
+
+ # cache attn map
+ self.attn_type = attn_type
+ self.attn_layers = attn_layers
+ self.attn_map_cache = []
+ for name, module in self.named_modules():
+ if name.endswith(self.attn_type):
+ item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
+ self.attn_map_cache.append(item)
+ module.attn_map_cache = item
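+        # Note: the attention modules matched above are assumed to write their latest
+        # attention weights into item["attn_map"] (and the spatial size into
+        # item["size"]) during the forward pass; save_attn_map() reads these caches.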
+
+ def clear_attn_map(self):
+
+ for item in self.attn_map_cache:
+ if item["attn_map"] is not None:
+ del item["attn_map"]
+ item["attn_map"] = None
+
+ def save_attn_map(self, save_name="temp", tokens=""):
+
+ attn_maps = []
+ for item in self.attn_map_cache:
+ name = item["name"]
+ if any([name.startswith(block) for block in self.attn_layers]):
+ heads = item["heads"]
+ attn_maps.append(item["attn_map"].detach().cpu())
+
+ attn_map = th.stack(attn_maps, dim=0)
+ attn_map = th.mean(attn_map, dim=0)
+
+        # attn_map: (bh, n, l) with bh = batch_size * heads, n = h * w pixels, l = token length
+        bh, n, l = attn_map.shape
+ attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
+ b = attn_map.shape[0]
+
+ h = w = int(n**0.5)
+ attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
+
+ attn_map_i = attn_map[-1]
+
+ l = attn_map_i.shape[0]
+ fig = plt.figure(figsize=(12, 8), dpi=300)
+ for j in range(12):
+ if j >= l: break
+ ax = fig.add_subplot(3, 4, j+1)
+ sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
+ if j < len(tokens):
+ ax.set_title(tokens[j])
+ fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
+ plt.close()
+
+ return attn_map_i
+
+ def forward(self, x, timesteps=None, context=None, add_context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn
+        :param add_context: additional conditioning plugged in via the add_crossattn branch
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+
+ self.clear_attn_map()
+
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ # h = x.type(self.dtype)
+ h = x
+ if self.ctrl_channels > 0:
+ in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
+
+ for i, module in enumerate(self.input_blocks):
+ if self.ctrl_channels > 0 and i == 0:
+ h = module(in_h, emb, context, add_context) + self.add_input_block(add_h, emb, context, add_context)
+ else:
+ h = module(h, emb, context, add_context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context, add_context)
+ for i, module in enumerate(self.output_blocks):
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context, add_context)
+ h = h.type(x.dtype)
+
+ return self.out(h)
\ No newline at end of file
diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..506eb60ba7a81e734eb5ed2abafc613fa0633e6a
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling.py
@@ -0,0 +1,784 @@
+"""
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+"""
+
+
+from typing import Dict, Union
+
+import imageio
+import torch
+import json
+import numpy as np
+import torch.nn.functional as F
+from omegaconf import ListConfig, OmegaConf
+from tqdm import tqdm
+
+from ...modules.diffusionmodules.sampling_utils import (
+ get_ancestral_step,
+ linear_multistep_coeff,
+ to_d,
+ to_neg_log_sigma,
+ to_sigma,
+)
+from ...util import append_dims, default, instantiate_from_config
+from torchvision.utils import save_image
+
+DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+
+
+class BaseDiffusionSampler:
+ def __init__(
+ self,
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
+ num_steps: Union[int, None] = None,
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
+ verbose: bool = False,
+ device: str = "cuda",
+ ):
+ self.num_steps = num_steps
+ self.discretization = instantiate_from_config(discretization_config)
+ self.guider = instantiate_from_config(
+ default(
+ guider_config,
+ DEFAULT_GUIDER,
+ )
+ )
+ self.verbose = verbose
+ self.device = device
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ sigmas = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device
+ )
+ uc = default(uc, cond)
+
+        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)  # scale the initial latent by sqrt(1 + sigma_0**2) for the first (largest) noise level
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+ def denoise(self, x, model, sigma, cond, uc):
+ denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_sigma_gen(self, num_sigmas, init_step=0):
+ sigma_generator = range(init_step, num_sigmas - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=num_sigmas-1-init_step,
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas-1-init_step} steps",
+ )
+ return sigma_generator
+
+
+class SingleStepDiffusionSampler(BaseDiffusionSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
+ raise NotImplementedError
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
+
+
+class EDMSampler(SingleStepDiffusionSampler):
+ def __init__(
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class AncestralSampler(SingleStepDiffusionSampler):
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta = eta
+ self.s_noise = s_noise
+ self.noise_sampler = lambda x: torch.randn_like(x)
+
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(sigma_down - sigma, x.ndim)
+
+ return self.euler_step(x, d, dt)
+
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0,
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
+ x,
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ )
+
+ return x
+
+
+class LinearMultistepSampler(BaseDiffusionSampler):
+ def __init__(
+ self,
+ order=4,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.order = order
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ ds = []
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ for i in self.get_sigma_gen(num_sigmas):
+ sigma = s_in * sigmas[i]
+ denoised = denoiser(
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
+ )
+ denoised = self.guider(denoised, sigma)
+ d = to_d(x, sigma, denoised)
+ ds.append(d)
+ if len(ds) > self.order:
+ ds.pop(0)
+ cur_order = min(i + 1, self.order)
+ coeffs = [
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
+ for j in range(cur_order)
+ ]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+
+ return x
+
+
+class EulerEDMSampler(EDMSampler):
+
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ return euler_step
+
+ def get_c_noise(self, x, model, sigma):
+ sigma = model.denoiser.possibly_quantize_sigma(sigma)
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, x.ndim)
+ c_skip, c_out, c_in, c_noise = model.denoiser.scaling(sigma)
+ c_noise = model.denoiser.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ return c_noise
+
+ def attend_and_excite(self, x, model, sigma, cond, batch, alpha, iter_enabled, thres, max_iter=20):
+
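+        # Descriptive note (Attend-and-Excite style refinement): repeatedly run the UNet
+        # at this noise level, score the cached attention maps with the local loss, and
+        # take gradient steps of size `alpha` on the latent until the loss falls below
+        # `thres` (when iter_enabled) or max_iter is reached.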
+ # calc timestep
+ c_noise = self.get_c_noise(x, model, sigma)
+
+ x = x.clone().detach().requires_grad_(True) # https://github.com/yuval-alaluf/Attend-and-Excite/blob/main/pipeline_attend_and_excite.py#L288
+
+ iters = 0
+ while True:
+
+ model_output = model.model(x, c_noise, cond)
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
+ grad = torch.autograd.grad(local_loss.requires_grad_(True), [x], retain_graph=True)[0]
+ x = x - alpha * grad
+ iters += 1
+
+ if not iter_enabled or local_loss <= thres or iters > max_iter:
+ break
+
+ return x
+
+ def create_pascal_label_colormap(self):
+ """
+        Create the label colormap of the PASCAL VOC segmentation dataset.
+
+        Returns:
+            A colormap for visualizing segmentation results.
+ """
+ colormap = np.zeros((256, 3), dtype=int)
+ ind = np.arange(256, dtype=int)
+
+ for shift in reversed(range(8)):
+ for channel in range(3):
+ colormap[:, channel] |= ((ind >> channel) & 1) << shift
+ ind >>= 3
+
+ return colormap
+
+ def save_segment_map(self, image, attn_maps, tokens=None, save_name=None):
+
+ colormap = self.create_pascal_label_colormap()
+ H, W = image.shape[-2:]
+
+ image_ = image*0.3
+ sections = []
+ for i in range(len(tokens)):
+ attn_map = attn_maps[i]
+ attn_map_t = np.tile(attn_map[None], (1,3,1,1)) # b, 3, h, w
+ attn_map_t = torch.from_numpy(attn_map_t)
+            attn_map_t = F.interpolate(attn_map_t, (H, W))  # size is (out_h, out_w), matching the image
+
+ color = torch.from_numpy(colormap[i+1][None,:,None,None] / 255.0)
+ colored_attn_map = attn_map_t * color
+ colored_attn_map = colored_attn_map.to(device=image_.device)
+
+ image_ += colored_attn_map*0.7
+ sections.append(attn_map)
+
+ section = np.stack(sections)
+ np.save(f"temp/seg_map/seg_{save_name}.npy", section)
+
+ save_image(image_, f"temp/seg_map/seg_{save_name}.png", normalize=True)
+
+ def get_init_noise(self, cfgs, model, cond, batch, uc=None):
+
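+        # Descriptive note: best-of-N initial noise search. Draw cfgs.noise_iters
+        # candidate noise tensors, run a short two-step probe for each, and keep the
+        # candidate with the lowest attention-based local loss.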
+ H, W = batch["target_size_as_tuple"][0]
+ shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
+
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
+ x = randn.clone()
+
+ xs = []
+ self.verbose = False
+ for _ in range(cfgs.noise_iters):
+
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps=2
+ )
+
+ superv = {
+ "mask": batch["mask"] if "mask" in batch else None,
+ "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
+ }
+
+ local_losses = []
+
+ for i in self.get_sigma_gen(num_sigmas):
+
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+
+ x, inter, local_loss = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ model,
+ x,
+ cond,
+ superv,
+ uc,
+ gamma,
+ save_loss=True
+ )
+
+ local_losses.append(local_loss.item())
+
+ xs.append((randn, local_losses[-1]))
+
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
+ x = randn.clone()
+
+ self.verbose = True
+
+ xs.sort(key = lambda x: x[-1])
+
+ if len(xs) > 0:
+ print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
+ x = xs[0][0]
+
+ return x
+
+ def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc=None,
+ gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
+ name=None, save_loss=False, save_attn=False, save_inter=False):
+
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ if update:
+ x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
+
+ denoised = self.denoise(x, model, sigma_hat, cond, uc)
+ denoised_decode = model.decode_first_stage(denoised) if save_inter else None
+
+ if save_loss:
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
+ local_loss = local_loss[local_loss.shape[0]//2:]
+ else:
+ local_loss = torch.zeros(1)
+ if save_attn:
+ attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
+
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+
+ return euler_step, denoised_decode, local_loss
+
+ def __call__(self, model, x, cond, batch=None, uc=None, num_steps=None, init_step=0,
+ name=None, aae_enabled=False, detailed=False):
+
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ name = batch["name"][0]
+ inters = []
+ local_losses = []
+ scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
+ iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
+ thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
+
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
+
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+
+ alpha = 20 * np.sqrt(scales[i])
+ update = aae_enabled
+ save_loss = detailed
+ save_attn = detailed and (i == (num_sigmas-1)//2)
+ save_inter = aae_enabled
+
+ if i in iter_lst:
+ iter_enabled = True
+ thres = thres_lst[list(iter_lst).index(i)]
+ else:
+ iter_enabled = False
+ thres = 0.0
+
+ x, inter, local_loss = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ model,
+ x,
+ cond,
+ batch,
+ uc,
+ gamma,
+ alpha=alpha,
+ iter_enabled=iter_enabled,
+ thres=thres,
+ update=update,
+ name=name,
+ save_loss=save_loss,
+ save_attn=save_attn,
+ save_inter=save_inter
+ )
+
+ local_losses.append(local_loss.item())
+ if inter is not None:
+ inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
+ inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
+ inters.append(inter.astype(np.uint8))
+
+ print(f"Local losses: {local_losses}")
+
+ if len(inters) > 0:
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
+
+ return x
+
+
+class EulerEDMDualSampler(EulerEDMSampler):
+
+ def prepare_sampling_loop(self, x, cond, uc_1=None, uc_2=None, num_steps=None):
+ sigmas = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device
+ )
+ uc_1 = default(uc_1, cond)
+ uc_2 = default(uc_2, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2
+
+ def denoise(self, x, model, sigma, cond, uc_1, uc_2):
+ denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc_1, uc_2))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_init_noise(self, cfgs, model, cond, batch, uc_1=None, uc_2=None):
+
+ H, W = batch["target_size_as_tuple"][0]
+ shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
+
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
+ x = randn.clone()
+
+ xs = []
+ self.verbose = False
+ for _ in range(cfgs.noise_iters):
+
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
+ x, cond, uc_1, uc_2, num_steps=2
+ )
+
+ superv = {
+ "mask": batch["mask"] if "mask" in batch else None,
+ "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
+ }
+
+ local_losses = []
+
+ for i in self.get_sigma_gen(num_sigmas):
+
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+
+ x, inter, local_loss = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ model,
+ x,
+ cond,
+ superv,
+ uc_1,
+ uc_2,
+ gamma,
+ save_loss=True
+ )
+
+ local_losses.append(local_loss.item())
+
+ xs.append((randn, local_losses[-1]))
+
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
+ x = randn.clone()
+
+ self.verbose = True
+
+ xs.sort(key = lambda x: x[-1])
+
+ if len(xs) > 0:
+ print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
+ x = xs[0][0]
+
+ return x
+
+ def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc_1=None, uc_2=None,
+ gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
+ name=None, save_loss=False, save_attn=False, save_inter=False):
+
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ if update:
+ x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
+
+ denoised = self.denoise(x, model, sigma_hat, cond, uc_1, uc_2)
+ denoised_decode = model.decode_first_stage(denoised) if save_inter else None
+
+ if save_loss:
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
+ local_loss = local_loss[-local_loss.shape[0]//3:]
+ else:
+ local_loss = torch.zeros(1)
+ if save_attn:
+            attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
+
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+
+ return euler_step, denoised_decode, local_loss
+
+ def __call__(self, model, x, cond, batch=None, uc_1=None, uc_2=None, num_steps=None, init_step=0,
+ name=None, aae_enabled=False, detailed=False):
+
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
+ x, cond, uc_1, uc_2, num_steps
+ )
+
+ name = batch["name"][0]
+ inters = []
+ local_losses = []
+ scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
+ iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
+ thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
+
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
+
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+
+ alpha = 20 * np.sqrt(scales[i])
+ update = aae_enabled
+ save_loss = aae_enabled
+ save_attn = detailed and (i == (num_sigmas-1)//2)
+ save_inter = aae_enabled
+
+ if i in iter_lst:
+ iter_enabled = True
+ thres = thres_lst[list(iter_lst).index(i)]
+ else:
+ iter_enabled = False
+ thres = 0.0
+
+ x, inter, local_loss = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ model,
+ x,
+ cond,
+ batch,
+ uc_1,
+ uc_2,
+ gamma,
+ alpha=alpha,
+ iter_enabled=iter_enabled,
+ thres=thres,
+ update=update,
+ name=name,
+ save_loss=save_loss,
+ save_attn=save_attn,
+ save_inter=save_inter
+ )
+
+ local_losses.append(local_loss.item())
+ if inter is not None:
+ inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
+ inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
+ inters.append(inter.astype(np.uint8))
+
+ print(f"Local losses: {local_losses}")
+
+ if len(inters) > 0:
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1)
+
+ return x
+
+
+class HeunEDMSampler(EDMSampler):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ if torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ return euler_step
+ else:
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
+ d_new = to_d(euler_step, next_sigma, denoised)
+ d_prime = (d + d_new) / 2.0
+
+ # apply correction if noise level is not 0
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
+ )
+ return x
+
+
+class EulerAncestralSampler(AncestralSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+
+ return x
+
+
+class DPMPP2SAncestralSampler(AncestralSampler):
+ def get_variables(self, sigma, sigma_down):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
+ h = t_next - t
+ s = t + 0.5 * h
+ return h, s, t, t_next
+
+ def get_mult(self, h, s, t, t_next):
+ mult1 = to_sigma(s) / to_sigma(t)
+ mult2 = (-0.5 * h).expm1()
+ mult3 = to_sigma(t_next) / to_sigma(t)
+ mult4 = (-h).expm1()
+
+ return mult1, mult2, mult3, mult4
+
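+    # Reading aid (derived from get_variables/get_mult above, with t = -log(sigma)
+    # and h = t_next - t), matching the DPM-Solver++(2S) update implemented below:
+    #   x_mid  = (sigma(s) / sigma(t)) * x - (exp(-h/2) - 1) * D(x)
+    #   x_next = (sigma(t_next) / sigma(t)) * x - (exp(-h) - 1) * D(x_mid)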
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+
+ if torch.sum(sigma_down) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ x = x_euler
+ else:
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
+ mult = [
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
+ ]
+
+ x2 = mult[0] * x - mult[1] * denoised
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
+
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+ return x
+
+
+class DPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t)
+ mult2 = (-h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
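+    # Reading aid (derived from get_variables/get_mult above, with t = -log(sigma),
+    # h = t_next - t, r = h_last / h): the basic update below is
+    #   x_standard = (sigma(t_next) / sigma(t)) * x - (exp(-h) - 1) * D(x)
+    # and the multistep correction replaces D(x) with
+    #   (1 + 1/(2r)) * D(x) - 1/(2r) * D_old.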
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+ ]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+ )
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, init_step=0, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..839fce1f246f8867b38a101fc2075905bef9ef8b
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling_utils.py
@@ -0,0 +1,51 @@
+import torch
+from scipy import integrate
+
+from ...util import append_dims
+
+
+class NoDynamicThresholding:
+ def __call__(self, uncond, cond, scale):
+ return uncond + scale * (cond - uncond)
+
+class DualThresholding: # Dual condition CFG (from instructPix2Pix)
+ def __call__(self, uncond_1, uncond_2, cond, scale):
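+        # `scale` is assumed to be a 2-element sequence: scale[0] weights the second
+        # unconditional branch against the first, scale[1] weights the full condition,
+        # mirroring the two guidance scales of InstructPix2Pix-style dual CFG.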
+ return uncond_1 + scale[0] * (uncond_2 - uncond_1) + scale[1] * (cond - uncond_2)
+
+def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+ if order - 1 > i:
+ raise ValueError(f"Order {order} too high for step {i}")
+
+ def fn(tau):
+ prod = 1.0
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
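+    # Split the move from sigma_from to sigma_to into a deterministic part (down to
+    # sigma_down) and fresh noise of std sigma_up; by construction
+    # sigma_down**2 + sigma_up**2 == sigma_to**2.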
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = torch.minimum(
+ sigma_to,
+ eta
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
+ )
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
+
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def to_neg_log_sigma(sigma):
+ return sigma.log().neg()
+
+
+def to_sigma(neg_log_sigma):
+ return neg_log_sigma.neg().exp()
diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d54724c6ef6a7b8067784a4192b0fe2f41123063
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -0,0 +1,31 @@
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class EDMSampling:
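+    # Training-time sigma sampler: log(sigma) ~ N(p_mean, p_std**2), i.e. sigma is
+    # log-normal, following the parameterization used in the EDM paper.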
+ def __init__(self, p_mean=-1.2, p_std=1.2):
+ self.p_mean = p_mean
+ self.p_std = p_std
+
+ def __call__(self, n_samples, rand=None):
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
+ return log_sigma.exp()
+
+
+class DiscreteSampling:
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
+ self.num_idx = num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None):
+ idx = default(
+ rand,
+ torch.randint(0, self.num_idx, (n_samples,)),
+ )
+ return self.idx_to_sigma(idx)
diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..069ff131fb9789949203551e0efd8313a3d0cc08
--- /dev/null
+++ b/sgm/modules/diffusionmodules/util.py
@@ -0,0 +1,308 @@
+"""
+adopted from
+https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+and
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+and
+https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+
+thanks!
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import repeat
+
+
+def make_beta_schedule(
+ schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+        return betas.numpy()
+    raise NotImplementedError(f"unknown beta schedule: {schedule}")
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def mixed_checkpoint(func, inputs: dict, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
+ it also works with non-tensor inputs
+ :param func: the function to evaluate.
+ :param inputs: the argument dictionary to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ tensor_inputs = [
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_keys = [
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_inputs = [
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
+ return MixedCheckpointFunction.apply(
+ func,
+ len(tensor_inputs),
+ len(non_tensor_inputs),
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ )
+ else:
+ return func(**inputs)
+
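+# Minimal usage sketch (hypothetical names `fn`, `module`): tensor-valued entries of
+# `inputs` are checkpointed, non-tensor entries (e.g. a boolean flag) pass through:
+#   out = mixed_checkpoint(fn, {"x": x, "context": ctx, "reverse": False},
+#                          module.parameters(), flag=True)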
+
+class MixedCheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ run_function,
+ length_tensors,
+ length_non_tensors,
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ ):
+ ctx.end_tensors = length_tensors
+ ctx.end_non_tensors = length_tensors + length_non_tensors
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ assert (
+ len(tensor_keys) == length_tensors
+ and len(non_tensor_keys) == length_non_tensors
+ )
+
+ ctx.input_tensors = {
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
+ }
+ ctx.input_non_tensors = {
+ key: val
+ for (key, val) in zip(
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
+ )
+ }
+ ctx.run_function = run_function
+ ctx.input_params = list(args[ctx.end_non_tensors :])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(
+ **ctx.input_tensors, **ctx.input_non_tensors
+ )
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
+ ctx.input_tensors = {
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
+ for key in ctx.input_tensors
+ }
+
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = {
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
+ for key in ctx.input_tensors
+ }
+ # shallow_copies.update(additional_args)
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ list(ctx.input_tensors.values()) + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (
+ (None, None, None, None, None)
+ + input_grads[: ctx.end_tensors]
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ + input_grads[ctx.end_tensors :]
+ )
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
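+# Minimal usage sketch (hypothetical `block` module): with flag=True the activations
+# of `block` are recomputed during backward instead of being stored:
+#   out = checkpoint(block, (x, emb), block.parameters(), flag=True)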
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
+
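+# Shape sketch (assuming repeat_only=False and even dim): for 4 timesteps and dim=320
+# the result is a (4, 320) tensor, cosines in the first 160 columns and sines in the
+# last 160, over geometrically spaced frequencies:
+#   emb = timestep_embedding(torch.tensor([0, 10, 100, 999]), dim=320)  # -> (4, 320)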
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb494470026014f6c091828361b24d494dd37dd
--- /dev/null
+++ b/sgm/modules/diffusionmodules/wrappers.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from packaging import version
+
+OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
+
+
+class IdentityWrapper(nn.Module):
+ def __init__(self, diffusion_model, compile_model: bool = False):
+ super().__init__()
+ compile = (
+ torch.compile
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
+ and compile_model
+ else lambda x: x
+ )
+ self.diffusion_model = compile(diffusion_model)
+
+ def forward(self, *args, **kwargs):
+ return self.diffusion_model(*args, **kwargs)
+
+
+class OpenAIWrapper(IdentityWrapper):
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ add_context=c.get("add_crossattn", None),
+ y=c.get("vector", None),
+ **kwargs
+ )
diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc b/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f40850d96a355a3bf5435b1bc1b660f5469bdc44
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc b/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..453d311410d5bc23bd49f25882fbedd02aad1e36
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc b/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdd90b54e8f50ab4961fa1544e3cf174f41732f3
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/distributions.cpython-310.pyc differ
diff --git a/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc b/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebe38223465ba41db4acc3ef75bdcff3245b6380
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc differ
diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b61f03077358ce4737c85842d9871f70dabb656
--- /dev/null
+++ b/sgm/modules/distributions/distributions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
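+
+    For example, normal_kl(torch.zeros(4, 3), 0.0, torch.ones(4, 3), 0.0) returns
+    0.5 for every element, the KL divergence between N(0, 1) and N(1, 1).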
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68
--- /dev/null
+++ b/sgm/modules/ema.py
@@ -0,0 +1,86 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
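+    # Exponential moving average of a model's trainable parameters. A typical
+    # (illustrative) usage pattern around validation:
+    #   ema = LitEma(model)
+    #   ema(model)                       # after each optimizer step: update shadow weights
+    #   ema.store(model.parameters())    # stash the raw training weights
+    #   ema.copy_to(model)               # evaluate / save with EMA weights
+    #   ema.restore(model.parameters())  # put the raw training weights back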
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+                # remove '.' from the name since it is not allowed in buffer names
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff39c633653cf3c06772c82124849cce4677bd34
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc b/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ce7a1a39c17c52cffbf26fdd3263b5d51d6b6c0
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc b/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5d3f41ce23667009337aba0d8e4cea5f941ccb4
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/modules.cpython-310.pyc differ
diff --git a/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc b/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..342442149d536422276d5bf1dba9deffafd40175
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc differ
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..be44972af5374c8fb460eb318a78c3b6a3a0e7f3
--- /dev/null
+++ b/sgm/modules/encoders/modules.py
@@ -0,0 +1,1253 @@
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+
+import kornia
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+from torch.utils.checkpoint import checkpoint
+from transformers import (
+ ByT5Tokenizer,
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5Tokenizer,
+)
+
+from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
+from ...modules.diffusionmodules.model import Encoder
+from ...modules.diffusionmodules.openaimodel import Timestep
+from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.distributions.distributions import DiagonalGaussianDistribution
+from ...util import (
+ autocast,
+ count_params,
+ default,
+ disabled_train,
+ expand_dims_like,
+ instantiate_from_config,
+)
+
+import math
+import string
+import pytorch_lightning as pl
+from torchvision import transforms
+from timm.models.vision_transformer import VisionTransformer
+from safetensors.torch import load_file as load_safetensors
+
+# disable warning
+from transformers import logging
+logging.set_verbosity_error()
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self, is_add_embedder=False):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+ self.is_add_embedder = is_add_embedder
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+class GeneralConditioner(nn.Module):
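+    # Embedder outputs are routed by dimensionality: 2-D tensors become the "vector"
+    # conditioning (passed to the UNet as y), 3-D tensors go to cross-attention, and
+    # 4-D/5-D tensors are channel-concatenated onto the latent input. Embedders
+    # flagged with is_add_embedder are routed to "add_crossattn" instead.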
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
+
+ def __init__(self, emb_models: Union[List, ListConfig]):
+ super().__init__()
+ embedders = []
+ for n, embconfig in enumerate(emb_models):
+ embedder = instantiate_from_config(embconfig)
+ assert isinstance(
+ embedder, AbstractEmbModel
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
+ embedder.is_trainable = embconfig.get("is_trainable", False)
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
+ if not embedder.is_trainable:
+ embedder.train = disabled_train
+ embedder.freeze()
+ print(
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
+ )
+
+ if "input_key" in embconfig:
+ embedder.input_key = embconfig["input_key"]
+ elif "input_keys" in embconfig:
+ embedder.input_keys = embconfig["input_keys"]
+ else:
+ raise KeyError(
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+ )
+
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
+ if embedder.legacy_ucg_val is not None:
+ embedder.ucg_prng = np.random.RandomState()
+
+ embedders.append(embedder)
+ self.embedders = nn.ModuleList(embedders)
+
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ p = embedder.ucg_rate
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def forward(
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
+ ) -> Dict:
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+ for embedder in self.embedders:
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(
+ emb_out, (torch.Tensor, list, tuple)
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ for emb in emb_out:
+ if embedder.is_add_embedder:
+ out_key = "add_crossattn"
+ else:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.input_key == "mask":
+ H, W = batch["image"].shape[-2:]
+ emb = nn.functional.interpolate(emb, (H//8, W//8))
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb.shape[0], device=emb.device)
+ ),
+ emb,
+ )
+ * emb
+ )
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+ if out_key in output:
+ output[out_key] = torch.cat(
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
+ )
+ else:
+ output[out_key] = emb
+ return output
+
+ def get_unconditional_conditioning(
+ self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
+ ):
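+        # Temporarily disables ucg dropout on every embedder and returns the
+        # conditional / unconditional pair used for classifier-free guidance, e.g.
+        # (illustrative; the key names depend on the configured embedders):
+        #   c, uc = conditioner.get_unconditional_conditioning(
+        #       batch, force_uc_zero_embeddings=["label"])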
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+ c = self(batch_c)
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+ return c, uc
+
+
+class DualConditioner(GeneralConditioner):
+
+ def get_unconditional_conditioning(
+ self, batch_c, batch_uc_1=None, batch_uc_2=None, force_uc_zero_embeddings=None
+ ):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+
+ c = self(batch_c)
+ uc_1 = self(batch_uc_1, force_uc_zero_embeddings) if batch_uc_1 is not None else None
+ uc_2 = self(batch_uc_2, force_uc_zero_embeddings[:1]) if batch_uc_2 is not None else None
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+
+ return c, uc_1, uc_2
+
+
+class InceptionV3(nn.Module):
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
+ port with an additional squeeze at the end"""
+
+ def __init__(self, normalize_input=False, **kwargs):
+ super().__init__()
+ from pytorch_fid import inception
+
+ kwargs["resize_input"] = True
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
+
+ def forward(self, inp):
+ # inp = kornia.geometry.resize(inp, (299, 299),
+ # interpolation='bicubic',
+ # align_corners=False,
+ # antialias=True)
+ # inp = inp.clamp(min=-1, max=1)
+
+ outp = self.model(inp)
+
+ if len(outp) == 1:
+ return outp[0].squeeze()
+
+ return outp
+
+
+class IdentityEncoder(AbstractEmbModel):
+    def encode(self, x):
+        return x
+
+    def freeze(self):
+        return
+
+    def forward(self, x):
+        return x
+
+
+class ClassEmbedder(AbstractEmbModel):
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
+ super().__init__()
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.add_sequence_dim = add_sequence_dim
+
+ def forward(self, c):
+ c = self.embedding(c)
+ if self.add_sequence_dim:
+ c = c[:, None, :]
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = (
+ self.n_classes - 1
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc.long()}
+ return uc
+
+
+class ClassEmbedderForMultiCond(ClassEmbedder):
+ def forward(self, batch, key=None, disable_dropout=False):
+ out = batch
+ key = default(key, self.key)
+ islist = isinstance(batch[key], list)
+ if islist:
+ batch[key] = batch[key][0]
+ c_out = super().forward(batch, key, disable_dropout)
+ out[key] = [c_out] if islist else c_out
+ return out
+
+
+class FrozenT5Embedder(AbstractEmbModel):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenByT5Embedder(AbstractEmbModel):
+ """
+ Uses the ByT5 transformer encoder for text. Is character-aware.
+ """
+
+ def __init__(
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True, *args, **kwargs
+    ):  # other sizes include google/byt5-small and google/byt5-large
+ super().__init__(*args, **kwargs)
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(next(self.parameters()).device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state # l, 1536
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEmbModel):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+ LAYERS = ["last", "pooled", "hidden"]
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None,
+ always_return_pooled=False,
+ ): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ self.return_pooled = always_return_pooled
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ device = next(self.transformer.parameters()).device
+ tokens = batch_encoding["input_ids"].to(device)
+ outputs = self.transformer(
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
+ )
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ if self.return_pooled:
+ return z, outputs.pooler_output
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+
+ LAYERS = ["pooled", "last", "penultimate"]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ always_return_pooled=False,
+ legacy=True,
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ self.return_pooled = always_return_pooled
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+ self.legacy = legacy
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ device = next(self.model.parameters()).device
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(device))
+ if not self.return_pooled and self.legacy:
+ return z
+ if self.return_pooled:
+ assert not self.legacy
+ return z[self.layer], z["pooled"]
+ return z[self.layer]
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ if self.legacy:
+ x = x[self.layer]
+ x = self.model.ln_final(x)
+ return x
+ else:
+ # x is a dict and will stay a dict
+ o = x["last"]
+ o = self.model.ln_final(o)
+ pooled = self.pool(o, text)
+ x["pooled"] = pooled
+ return x
+
+ def pool(self, x, text):
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = (
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+ @ self.model.text_projection
+ )
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ outputs = {}
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - 1:
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
+ return outputs
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEmbModel):
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate",
+ ]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device("cpu"), pretrained=version
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ device = next(self.model.parameters()).device
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ assert not self.model.visual.output_tokens
+ x = self.model.visual(img)
+ tokens = None
+ else:
+ assert self.model.visual.output_tokens
+ x, tokens = self.model.visual(img)
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+    def encode(self, image):
+        return self(image)
+
+
+class FrozenCLIPT5Encoder(AbstractEmbModel):
+ def __init__(
+ self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77,
+ ):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(
+ clip_version, device, max_length=clip_max_length
+ )
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(
+ self,
+ n_stages=1,
+ method="bilinear",
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False,
+ wrap_video=False,
+ kernel_size=1,
+ remap_output=False,
+ ):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in [
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "bicubic",
+ "area",
+ ]
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None or remap_output
+ if self.remap_output:
+ print(
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+ )
+ self.channel_mapper = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ padding=kernel_size // 2,
+ )
+ self.wrap_video = wrap_video
+
+ def forward(self, x):
+ if self.wrap_video and x.ndim == 5:
+ B, C, T, H, W = x.shape
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+ if self.wrap_video:
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
+ x = rearrange(x, "b t c h w -> b c t h w")
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+class LowScaleEncoder(nn.Module):
+ def __init__(
+ self,
+ model_config,
+ linear_start,
+ linear_end,
+ timesteps=1000,
+ max_noise_level=250,
+ output_size=64,
+ scale_factor=1.0,
+ ):
+ super().__init__()
+ self.max_noise_level = max_noise_level
+ self.model = instantiate_from_config(model_config)
+ self.augmentation_schedule = self.register_schedule(
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ self.out_size = output_size
+ self.scale_factor = scale_factor
+
+ def register_schedule(
+ self,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def forward(self, x):
+ z = self.model.encode(x)
+ if isinstance(z, DiagonalGaussianDistribution):
+ z = z.sample()
+ z = z * self.scale_factor
+ noise_level = torch.randint(
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
+ ).long()
+ z = self.q_sample(z, noise_level)
+ if self.out_size is not None:
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
+ # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+ return z, noise_level
+
+ def decode(self, z):
+ z = z / self.scale_factor
+ return self.model.decode(z)
+
+
+class ConcatTimestepEmbedderND(AbstractEmbModel):
+ """embeds each dimension independently and concatenates them"""
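+    # e.g. with outdim=256, an input of shape (B, 2) - such as a (height, width)
+    # size conditioning - yields an embedding of shape (B, 512)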
+
+ def __init__(self, outdim):
+ super().__init__()
+ self.timestep = Timestep(outdim)
+ self.outdim = outdim
+
+ def freeze(self):
+ self.eval()
+
+ def forward(self, x):
+ if x.ndim == 1:
+ x = x[:, None]
+ assert len(x.shape) == 2
+ b, dims = x.shape[0], x.shape[1]
+ x = rearrange(x, "b d -> (b d)")
+ emb = self.timestep(x)
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+ return emb
+
+
+class GaussianEncoder(Encoder, AbstractEmbModel):
+ def __init__(
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.posterior = DiagonalGaussianRegularizer()
+ self.weight = weight
+ self.flatten_output = flatten_output
+
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
+ z = super().forward(x)
+ z, log = self.posterior(z)
+ log["loss"] = log["kl_loss"]
+ log["weight"] = self.weight
+ if self.flatten_output:
+ z = rearrange(z, "b c h w -> b (h w ) c")
+ return log, z
+
+
+class LatentEncoder(AbstractEmbModel):
+
+ def __init__(self, scale_factor, config, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.scale_factor = scale_factor
+ self.model = instantiate_from_config(config).eval()
+ self.model.train = disabled_train
+
+ def freeze(self):
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ z = self.model.encode(x)
+ z = self.scale_factor * z
+ return z
+
+
+class ViTSTREncoder(VisionTransformer):
+ '''
+ ViTSTREncoder is basically a ViT that uses ViTSTR weights
+ '''
+ def __init__(self, size=224, ckpt_path=None, freeze=True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.grayscale = transforms.Grayscale()
+ self.resize = transforms.Resize((size, size), transforms.InterpolationMode.BICUBIC, antialias=True)
+
+ self.character = string.printable[:-6]
+ self.reset_classifier(num_classes=len(self.character)+2)
+
+ if ckpt_path is not None:
+ self.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
+
+ if freeze:
+ self.freeze()
+
+ def reset_classifier(self, num_classes):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad_(False)
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+
+ x = self.forward_features(x)
+
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
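+        # standard sinusoidal encoding: pe[pos, 2i]   = sin(pos / 10000^(2i/d_model))
+        #                               pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))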
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + torch.tile(self.pe[None, ...].to(x.device), (x.shape[0], 1, 1))
+ return self.dropout(x)
+
+
+class LabelEncoder(AbstractEmbModel, pl.LightningModule):
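+    # Character-level label encoder: a text label is mapped to a (max_len, emb_dim)
+    # sequence via a learned character embedding, sinusoidal positional encoding and
+    # a Transformer encoder. When trainable, it is optimized with a CLIP-style
+    # contrastive loss against a visual encoder plus auxiliary character- and
+    # position-classification heads (see training_step below).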
+
+ def __init__(self, max_len, emb_dim, n_heads=8, n_trans_layers=12, ckpt_path=None, trainable=False,
+ lr=1e-4, lambda_cls=0.1, lambda_pos=0.1, clip_dim=1024, visual_len=197, visual_dim=768, visual_config=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.max_len = max_len
+ self.emd_dim = emb_dim
+ self.n_heads = n_heads
+ self.n_trans_layers = n_trans_layers
+ self.character = string.printable[:-6]
+ self.num_cls = len(self.character) + 1
+
+ self.label_embedding = nn.Embedding(self.num_cls, self.emd_dim)
+ self.pos_embedding = PositionalEncoding(d_model=self.emd_dim, max_len=self.max_len)
+ transformer_block = nn.TransformerEncoderLayer(d_model=self.emd_dim, nhead=self.n_heads, batch_first=True)
+ self.encoder = nn.TransformerEncoder(transformer_block, num_layers=self.n_trans_layers)
+
+ if ckpt_path is not None:
+ self.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"], strict=False)
+
+ if trainable:
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.visual_encoder = instantiate_from_config(visual_config)
+
+ self.learning_rate = lr
+ self.clip_dim = clip_dim
+ self.visual_len = visual_len
+ self.visual_dim = visual_dim
+ self.lambda_cls = lambda_cls
+ self.lambda_pos = lambda_pos
+
+ self.cls_head = nn.Sequential(*[
+ nn.InstanceNorm1d(self.max_len),
+ nn.Linear(self.emd_dim, self.emd_dim),
+ nn.GELU(),
+ nn.Linear(self.emd_dim, self.num_cls)
+ ])
+
+ self.pos_head = nn.Sequential(*[
+ nn.InstanceNorm1d(self.max_len),
+ nn.Linear(self.emd_dim, self.max_len, bias=False)
+ ])
+
+ self.text_head = nn.Sequential(*[
+ nn.InstanceNorm1d(self.max_len),
+ nn.Linear(self.emd_dim, self.clip_dim, bias=False),
+ nn.Conv1d(in_channels=self.max_len, out_channels=1, kernel_size=1)
+ ])
+
+ self.visual_head = nn.Sequential(*[
+ nn.InstanceNorm1d(self.visual_len),
+ nn.Linear(self.visual_dim, self.clip_dim, bias=False),
+ nn.Conv1d(in_channels=self.visual_len, out_channels=1, kernel_size=1)
+ ])
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def get_index(self, labels):
+
+ indexes = []
+ for label in labels:
+ assert len(label) <= self.max_len
+ index = [self.character.find(c)+1 for c in label]
+ index = index + [0] * (self.max_len - len(index))
+ indexes.append(index)
+
+ return torch.tensor(indexes, device=next(self.parameters()).device)
+
+ def get_embeddings(self, x):
+
+ emb = self.label_embedding(x)
+ emb = self.pos_embedding(emb)
+ out = self.encoder(emb)
+
+ return out
+
+ def forward(self, labels):
+
+ idx = self.get_index(labels)
+ out = self.get_embeddings(idx)
+
+ return out
+
+ def get_loss(self, text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target):
+
+ text_out = text_out / text_out.norm(dim=1, keepdim=True) # b, 1024
+ visual_out = visual_out / visual_out.norm(dim=1, keepdim=True) # b, 1024
+
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * visual_out @ text_out.T # b, b
+ logits_per_text = logits_per_image.T # b, b
+
+ clip_loss_image = nn.functional.cross_entropy(logits_per_image, clip_target)
+ clip_loss_text = nn.functional.cross_entropy(logits_per_text, clip_target)
+ clip_loss = (clip_loss_image + clip_loss_text) / 2
+
+ cls_loss = nn.functional.cross_entropy(cls_out.permute(0,2,1), cls_target)
+ pos_loss = nn.functional.cross_entropy(pos_out.permute(0,2,1), pos_target)
+
+ return clip_loss, cls_loss, pos_loss, logits_per_text
+
+ def training_step(self, batch, batch_idx):
+
+ text = batch["text"]
+ image = batch["image"]
+
+ idx = self.get_index(text)
+ text_emb = self.get_embeddings(idx) # b, l, d
+ visual_emb = self.visual_encoder(image) # b, n, d
+
+ cls_out = self.cls_head(text_emb) # b, l, c
+ pos_out = self.pos_head(text_emb) # b, l, p
+ text_out = self.text_head(text_emb).squeeze(1) # b, 1024
+ visual_out = self.visual_head(visual_emb).squeeze(1) # b, 1024
+
+ cls_target = idx # b, c
+ pos_target = torch.arange(start=0, end=self.max_len, step=1)
+ pos_target = pos_target[None].tile((idx.shape[0], 1)).to(cls_target) # b, c
+ clip_target = torch.arange(0, idx.shape[0], 1).to(cls_target) # b,
+
+ clip_loss, cls_loss, pos_loss, logits_per_text = self.get_loss(text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target)
+ loss = clip_loss + self.lambda_cls * cls_loss + self.lambda_pos * pos_loss
+
+ loss_dict = {}
+ loss_dict["loss/clip_loss"] = clip_loss
+ loss_dict["loss/cls_loss"] = cls_loss
+ loss_dict["loss/pos_loss"] = pos_loss
+ loss_dict["loss/full_loss"] = loss
+
+ clip_idx = torch.max(logits_per_text, dim=-1).indices # b,
+ clip_acc = (clip_idx == clip_target).to(dtype=torch.float32).mean()
+
+ cls_idx = torch.max(cls_out, dim=-1).indices # b, l
+ cls_acc = (cls_idx == cls_target).to(dtype=torch.float32).mean()
+
+ pos_idx = torch.max(pos_out, dim=-1).indices # b, l
+ pos_acc = (pos_idx == pos_target).to(dtype=torch.float32).mean()
+
+ loss_dict["acc/clip_acc"] = clip_acc
+ loss_dict["acc/cls_acc"] = cls_acc
+ loss_dict["acc/pos_acc"] = pos_acc
+
+ self.log_dict(loss_dict, prog_bar=True, batch_size=len(text),
+ logger=True, on_step=True, on_epoch=True, sync_dist=True)
+
+ return loss
+
+ def configure_optimizers(self):
+
+ lr = self.learning_rate
+ opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
+
+ return opt
+
+
diff --git a/sgm/modules/predictors/__pycache__/model.cpython-311.pyc b/sgm/modules/predictors/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..930a3e7f512cca49892fdd2fa66d6d56443682b3
Binary files /dev/null and b/sgm/modules/predictors/__pycache__/model.cpython-311.pyc differ
diff --git a/sgm/modules/predictors/__pycache__/model.cpython-38.pyc b/sgm/modules/predictors/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8922a6c80564798d867e4ad5fb5fd69ffd9dcbe
Binary files /dev/null and b/sgm/modules/predictors/__pycache__/model.cpython-38.pyc differ
diff --git a/sgm/modules/predictors/model.py b/sgm/modules/predictors/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc264b7fe70af379135c814dc393837534a80b3b
--- /dev/null
+++ b/sgm/modules/predictors/model.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms
+from torchvision.utils import save_image
+
+
+class ParseqPredictor(nn.Module):
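+    # Frozen PARSeq scene-text recognizer loaded from the local ./src/parseq hub
+    # package. img2txt() decodes recognized strings, and calc_loss() returns a
+    # per-sample cross-entropy (clamped at 1.0) between the recognizer's logits and
+    # the encoded target label, truncated at the end-of-sequence token.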
+
+ def __init__(self, ckpt_path=None, freeze=True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.parseq = torch.hub.load('./src/parseq', 'parseq', source='local').eval()
+ self.parseq.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
+ self.parseq_transform = transforms.Compose([
+ transforms.Resize(self.parseq.hparams.img_size, transforms.InterpolationMode.BICUBIC, antialias=True),
+ transforms.Normalize(0.5, 0.5)
+ ])
+
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ for param in self.parseq.parameters():
+ param.requires_grad_(False)
+
+ def forward(self, x):
+
+ x = torch.cat([self.parseq_transform(t[None]) for t in x])
+ logits = self.parseq(x.to(next(self.parameters()).device))
+
+ return logits
+
+ def img2txt(self, x):
+
+ pred = self(x)
+ label, confidence = self.parseq.tokenizer.decode(pred)
+ return label
+
+
+ def calc_loss(self, x, label):
+
+ preds = self(x) # (B, l, C) l=26, C=95
+ gt_ids = self.parseq.tokenizer.encode(label).to(preds.device) # (B, l_trun)
+
+ losses = []
+ for pred, gt_id in zip(preds, gt_ids):
+
+ eos_id = (gt_id == 0).nonzero().item()
+ gt_id = gt_id[1: eos_id]
+ pred = pred[:eos_id-1, :]
+
+ ce_loss = nn.functional.cross_entropy(pred.permute(1, 0)[None], gt_id[None])
+ ce_loss = torch.clamp(ce_loss, max = 1.0)
+ losses.append(ce_loss[None])
+
+ loss = torch.cat(losses)
+
+ return loss
\ No newline at end of file
diff --git a/sgm/util.py b/sgm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06f48a882de7b5b37d48fffe7b44d2bf883bd85d
--- /dev/null
+++ b/sgm/util.py
@@ -0,0 +1,231 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+ except:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
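+    # e.g. is_power_of_two(8) -> True (0b1000 & 0b0111 == 0), is_power_of_two(6) -> False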
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
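+# Decorator used by the frozen embedders in sgm/modules/encoders/modules.py: wraps a
+# forward() so that it runs under torch.cuda.amp.autocast with the globally
+# configured GPU dtype and cache settings.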
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+            print("Can't encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
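+    # e.g. append_dims(torch.ones(4), 3).shape == torch.Size([4, 1, 1])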
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
diff --git a/util.py b/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb4a070ead04a775c484a6ea34123a86471d8a6a
--- /dev/null
+++ b/util.py
@@ -0,0 +1,70 @@
+import torch
+from omegaconf import OmegaConf
+from sgm.util import instantiate_from_config
+from sgm.modules.diffusionmodules.sampling import *
+
+
+def init_model(cfgs):
+
+ model_cfg = OmegaConf.load(cfgs.model_cfg_path)
+ ckpt = cfgs.load_ckpt_path
+
+ model = instantiate_from_config(model_cfg.model)
+ model.init_from_ckpt(ckpt)
+
+ if cfgs.type == "train":
+ model.train()
+ else:
+ model.to(torch.device("cuda", index=cfgs.gpu))
+ model.eval()
+ model.freeze()
+
+ return model
+
+def init_sampling(cfgs):
+
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+ }
+
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+ "params": {"scale": cfgs.scale[0]},
+ }
+
+ sampler = EulerEDMSampler(
+ num_steps=cfgs.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=0.0,
+ s_tmin=0.0,
+ s_tmax=999.0,
+ s_noise=1.0,
+ verbose=True,
+ device=torch.device("cuda", index=cfgs.gpu)
+ )
+
+ return sampler
+
+def deep_copy(batch):
+
+ c_batch = {}
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ c_batch[key] = torch.clone(batch[key])
+ elif isinstance(batch[key], (tuple, list)):
+ c_batch[key] = batch[key].copy()
+ else:
+ c_batch[key] = batch[key]
+
+ return c_batch
+
+def prepare_batch(cfgs, batch):
+
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
+
+ batch_uc = batch
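+    # note: batch_uc aliases batch here; use deep_copy() above if an independent
+    # unconditional batch is required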
+
+ return batch, batch_uc
\ No newline at end of file