abyildirim committed on
Commit 2d42726 (1 parent: 2d134d2)

Gradio files are synced with the GitHub repo

.gitattributes DELETED
@@ -1,34 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: Inst Inpaint
- emoji: 📊
  colorFrom: purple
  colorTo: red
  sdk: gradio
- sdk_version: 3.24.1
  app_file: app.py
  pinned: false
  license: mit

  ---
  title: Inst Inpaint
+ emoji: 🖌️
  colorFrom: purple
  colorTo: red
  sdk: gradio
+ sdk_version: 3.39.0
  app_file: app.py
  pinned: false
  license: mit
app.py CHANGED
@@ -1,35 +1,53 @@
 
  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image
-
  import constants
  import utils

- PREDICTOR = None
-
- def inference(image: np.ndarray, text: str, center_crop: bool):
-     num_steps = 10
-     if not text.lower().startswith("remove the"):
          raise gr.Error("Instruction should start with 'Remove the' !")
-
      image = Image.fromarray(image)
      cropped_image, image = utils.preprocess_image(image, center_crop=center_crop)
-
-     utils.seed_everything()
-     prediction = PREDICTOR.predict(image, text, num_steps)
-
-     print("Num steps:", num_steps)
-
-     return cropped_image, prediction
-

  if __name__ == "__main__":
-     utils.setup_environment()

-     if not PREDICTOR:
-         PREDICTOR = utils.get_predictor()

      sample_image, sample_instruction, sample_step = constants.EXAMPLES[3]

+ import argparse
  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image
  import constants
  import utils
+ from ldm.util import instantiate_from_config
+ from omegaconf import OmegaConf
+ from zipfile import ZipFile
+ import gdown
+ import os

+ MODEL = None

+ def inference(image: np.ndarray, instruction: str, center_crop: bool):
+     if not instruction.lower().startswith("remove the"):
          raise gr.Error("Instruction should start with 'Remove the' !")
      image = Image.fromarray(image)
      cropped_image, image = utils.preprocess_image(image, center_crop=center_crop)
+     output_image = MODEL.inpaint(image, instruction, num_steps=10, device="cpu", return_pil=True, seed=0)
+     return cropped_image, output_image

  if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml",
+         help="Path of the model config file",
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default="models/gqa_inpaint/ldm/model.ckpt",
+         help="Path of the model checkpoint file",
+     )
+     args = parser.parse_args()

+     gdown.download(id="1tp0aHAS-ccrIfNz7XrGTSdNIPNZjOVSp", output="models/")
+     with ZipFile("models/gqa_inpaint.zip", 'r') as zObject:
+         zObject.extractall(path="models/")
+     os.remove("models/gqa_inpaint.zip")
+
+     parsed_config = OmegaConf.load(args.config)
+     MODEL = instantiate_from_config(parsed_config["model"])
+     model_state_dict = torch.load(args.checkpoint, map_location="cpu")["state_dict"]
+     MODEL.load_state_dict(model_state_dict)
+     MODEL.eval()
+     MODEL.to("cpu")

      sample_image, sample_instruction, sample_step = constants.EXAMPLES[3]
 
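The diff excerpt above ends right after the example constants are read, so the Gradio UI wiring of the new app.py is not visible in this commit view. As a rough orientation only, the sketch below shows how an inference function with this signature is typically exposed through the Gradio 3.x Interface API; the dummy inference body and the widget labels are assumptions for illustration, not code from this repository.

import numpy as np
import gradio as gr

def inference(image: np.ndarray, instruction: str, center_crop: bool):
    # Dummy stand-in for the model call shown in the diff above; it just echoes the input.
    return image, image

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(label="Source image"),
        gr.Textbox(label="Instruction", value="Remove the ..."),
        gr.Checkbox(label="Center crop", value=True),
    ],
    outputs=[gr.Image(label="Cropped input"), gr.Image(label="Inpainted output")],
)
demo.launch()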
configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml ADDED
@@ -0,0 +1,122 @@
 
+ model:
+   base_learning_rate: 2.0e-06
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.0015
+     linear_end: 0.0195
+     num_timesteps_cond: 1
+     log_every_t: 100
+     timesteps: 1000
+     first_stage_key: "target_image"
+     cond_stage_key: "source_image"
+     cond_stage_trainable: False
+     cond_stage_instruction_key: "text"
+     cond_stage_instruction_embedder_trainable: True
+     conditioning_key: "hybrid"
+     image_size: 32
+     channels: 4
+     monitor: val/loss_simple_ema
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32
+         in_channels: 8 # 4 (noisy image features) + 4 (source image features)
+         out_channels: 4
+         model_channels: 128
+         attention_resolutions: [8,4,2]
+         num_res_blocks: 2
+         channel_mult: [1,2,3,4]
+         num_heads: 8
+         resblock_updown: True
+
+         ###### Instruction embedding cross attention ######
+         use_spatial_transformer: true
+         transformer_depth: 1
+         context_dim: 512
+         ###################################################
+
+     first_stage_config:
+       target: ldm.models.autoencoder.VQModel
+       params:
+         ckpt_path: models/gqa_inpaint/first_stage/vq-f8-cb16384-openimages.ckpt
+         monitor: "val/rec_loss"
+         embed_dim: 4
+         n_embed: 16384
+         lossconfig:
+           target: torch.nn.Identity
+         ddconfig:
+           double_z: false
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult: [1,2,2,4]
+           num_res_blocks: 2
+           attn_resolutions: [32]
+           dropout: 0.0
+
+     cond_stage_config: __is_first_stage__
+
+     cond_stage_instruction_embedder_config:
+       target: ldm.modules.encoders.modules.BERTEmbedder
+       params:
+         n_embed: 512
+         n_layer: 16
+
+ data:
+   target: main.DataModuleFromConfig
+   params:
+     batch_size: 8
+     num_workers: 4
+
+     train:
+       target: dataset.gqa_inpaint.GQAInpaintTrain
+       params:
+         images_root: "data/gqa-inpaint/images"
+         images_inpainted_root: "data/gqa-inpaint/images_inpainted"
+         masks_root: "data/gqa-inpaint/masks"
+         scene_json_path: "data/gqa-inpaint/train_scenes.json"
+         max_relations: 1
+         simplify_augment: True
+         instruction_type: "remove"
+         size: 256
+         irrelevant_text_prob: 0.2
+
+     validation:
+       target: dataset.gqa_inpaint.GQAInpaintTest
+       params:
+         images_root: "data/gqa-inpaint/images"
+         images_inpainted_root: "data/gqa-inpaint/images_inpainted"
+         masks_root: "data/gqa-inpaint/masks"
+         scene_json_path: "data/gqa-inpaint/test_scenes.json"
+         max_relations: 1
+         simplify_augment: True
+         instruction_type: "remove"
+         size: 256
+
+     test:
+       target: dataset.gqa_inpaint.GQAInpaintTest
+       params:
+         images_root: "data/gqa-inpaint/images"
+         images_inpainted_root: "data/gqa-inpaint/images_inpainted"
+         masks_root: "data/gqa-inpaint/masks"
+         scene_json_path: "data/gqa-inpaint/test_scenes.json"
+         test_instructions_path: "data/gqa-inpaint/test_instructions.json"
+         max_relations: 1
+         simplify_augment: True
+         instruction_type: "remove"
+         size: 256
+
+ lightning:
+   callbacks:
+     image_logger:
+       target: main.ImageLogger
+       params:
+         batch_frequency: 5000
+         max_images: 8
+         increase_log_steps: True
+
+   trainer:
+     benchmark: True
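For reference, the nested target/params blocks in this config are consumed with the same OmegaConf plus instantiate_from_config pattern that the new app.py uses above. The sketch below is not part of the commit; it instantiates only the first-stage VQ autoencoder (the VQModel added in ldm/models/autoencoder.py) and runs a reconstruction round trip. It assumes the repository's dependencies (including taming-transformers) are installed and that the checkpoint referenced by ckpt_path has already been downloaded, for example by the gdown step in app.py.

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/latent-diffusion/gqa-inpaint-ldm-vq-f8-256x256.yaml")

# model.params.first_stage_config points at the VQModel; instantiating it loads
# the weights from ckpt_path, so that file must exist locally.
first_stage = instantiate_from_config(config.model.params.first_stage_config).eval()

with torch.no_grad():
    image = torch.randn(1, 3, 256, 256)   # in_channels: 3, resolution: 256
    latents = first_stage.encode(image)   # 4-channel codes on a 32x32 grid (f8, z_channels: 4)
    recon = first_stage.decode(latents)   # back to a 1x3x256x256 tensor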
constants.py CHANGED
@@ -14,7 +14,7 @@ DESCRIPTION = """
  EXAMPLES = [
      ["examples/kite-boy.png", "Remove the colorful kite", True],
      ["examples/cat-car.jpg", "Remove the car", True],
-     ["examples/bus-tree.jpg", "Remove the bus", True],
      ["examples/cups.webp", "Remove the cup at the left", True],
      ["examples/woman-fantasy.jpg", "Remove the woman", True],
      ["examples/clock.png", "Remove the round clock at the center", True],

  EXAMPLES = [
      ["examples/kite-boy.png", "Remove the colorful kite", True],
      ["examples/cat-car.jpg", "Remove the car", True],
+     ["examples/bus-tree.jpg", "Remove the red bus", True],
      ["examples/cups.webp", "Remove the cup at the left", True],
      ["examples/woman-fantasy.jpg", "Remove the woman", True],
      ["examples/clock.png", "Remove the round clock at the center", True],
ldm/__init__.py ADDED
File without changes
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
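These scheduler classes only compute learning-rate multipliers; in this codebase they are plugged into torch.optim.lr_scheduler.LambdaLR (see configure_optimizers in the ldm/models/autoencoder.py file added below). A minimal usage sketch, with a toy optimizer and hyperparameters chosen purely for illustration:

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

params = torch.nn.Linear(4, 4).parameters()            # toy parameters for illustration
opt = torch.optim.Adam(params, lr=2.0e-06)

# Warm the multiplier up from 0.0 to 1.0 over 1000 steps, then cosine-decay towards 0.1.
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.1, lr_max=1.0, lr_start=0.0, max_decay_steps=10000
)
scheduler = LambdaLR(opt, lr_lambda=schedule)          # schedule(step) returns the lr factor

for step in range(3):
    opt.step()
    scheduler.step()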
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,426 @@
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+ from packaging import version
6
+ import numpy as np
7
+
8
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+
10
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
11
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
12
+
13
+ from ldm.util import instantiate_from_config
14
+
15
+
16
+ class VQModel(pl.LightningModule):
17
+ def __init__(self,
18
+ ddconfig,
19
+ lossconfig,
20
+ n_embed,
21
+ embed_dim,
22
+ ckpt_path=None,
23
+ ignore_keys=[],
24
+ image_key="image",
25
+ colorize_nlabels=None,
26
+ monitor=None,
27
+ batch_resize_range=None,
28
+ scheduler_config=None,
29
+ lr_g_factor=1.0,
30
+ remap=None,
31
+ sane_index_shape=False, # Telling vector quantizer to return indices
32
+ use_ema=False
33
+ ):
34
+ super().__init__()
35
+ self.embed_dim = embed_dim
36
+ self.n_embed = n_embed
37
+ self.image_key = image_key
38
+ self.encoder = Encoder(**ddconfig)
39
+ self.decoder = Decoder(**ddconfig)
40
+ self.loss = instantiate_from_config(lossconfig)
41
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
42
+ remap=remap,
43
+ sane_index_shape=sane_index_shape)
44
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
45
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
46
+ if colorize_nlabels is not None:
47
+ assert type(colorize_nlabels)==int
48
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
49
+ if monitor is not None:
50
+ self.monitor = monitor
51
+ self.batch_resize_range = batch_resize_range
52
+ if self.batch_resize_range is not None:
53
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
54
+
55
+ self.use_ema = use_ema
56
+ if self.use_ema:
57
+ self.model_ema = LitEma(self)
58
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
59
+
60
+ if ckpt_path is not None:
61
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
62
+ self.scheduler_config = scheduler_config
63
+ self.lr_g_factor = lr_g_factor
64
+
65
+ @contextmanager
66
+ def ema_scope(self, context=None):
67
+ if self.use_ema:
68
+ self.model_ema.store(self.parameters())
69
+ self.model_ema.copy_to(self)
70
+ if context is not None:
71
+ print(f"{context}: Switched to EMA weights")
72
+ try:
73
+ yield None
74
+ finally:
75
+ if self.use_ema:
76
+ self.model_ema.restore(self.parameters())
77
+ if context is not None:
78
+ print(f"{context}: Restored training weights")
79
+
80
+ def init_from_ckpt(self, path, ignore_keys=list()):
81
+ sd = torch.load(path, map_location="cpu")["state_dict"]
82
+ keys = list(sd.keys())
83
+ for k in keys:
84
+ for ik in ignore_keys:
85
+ if k.startswith(ik):
86
+ print("Deleting key {} from state_dict.".format(k))
87
+ del sd[k]
88
+ missing, unexpected = self.load_state_dict(sd, strict=False)
89
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
90
+ if len(missing) > 0:
91
+ print(f"Missing Keys: {missing}")
92
+ print(f"Unexpected Keys: {unexpected}")
93
+
94
+ def on_train_batch_end(self, *args, **kwargs):
95
+ if self.use_ema:
96
+ self.model_ema(self)
97
+
98
+ def encode(self, x, return_all=False):
99
+ h = self.encoder(x)
100
+ h = self.quant_conv(h)
101
+ quant, emb_loss, info = self.quantize(h)
102
+ if return_all:
103
+ return quant, emb_loss, info
104
+ return quant
105
+
106
+ def encode_to_prequant(self, x):
107
+ h = self.encoder(x)
108
+ h = self.quant_conv(h)
109
+ return h
110
+
111
+ def decode(self, quant):
112
+ quant = self.post_quant_conv(quant)
113
+ dec = self.decoder(quant)
114
+ return dec
115
+
116
+ def decode_code(self, code_b):
117
+ quant_b = self.quantize.embed_code(code_b)
118
+ dec = self.decode(quant_b)
119
+ return dec
120
+
121
+ def forward(self, input, return_pred_indices=False):
122
+ quant, diff, (_,_,ind) = self.encode(input)
123
+ dec = self.decode(quant)
124
+ if return_pred_indices:
125
+ return dec, diff, ind
126
+ return dec, diff
127
+
128
+ def get_input(self, batch, k):
129
+ x = batch[k]
130
+ if len(x.shape) == 3:
131
+ x = x[..., None]
132
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
133
+ if self.batch_resize_range is not None:
134
+ lower_size = self.batch_resize_range[0]
135
+ upper_size = self.batch_resize_range[1]
136
+ if self.global_step <= 4:
137
+ new_resize = upper_size
138
+ else:
139
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
140
+ if new_resize != x.shape[2]:
141
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
142
+ x = x.detach()
143
+ return x
144
+
145
+ def training_step(self, batch, batch_idx, optimizer_idx):
146
+ # https://github.com/pytorch/pytorch/issues/37142
147
+ # Try not to fool the heuristics
148
+ x = self.get_input(batch, self.image_key)
149
+ xrec, qloss, ind = self(x, return_pred_indices=True)
150
+
151
+ if optimizer_idx == 0:
152
+ # autoencode
153
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
154
+ last_layer=self.get_last_layer(), split="train",
155
+ predicted_indices=ind)
156
+
157
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
158
+ return aeloss
159
+
160
+ if optimizer_idx == 1:
161
+ # Discriminator
162
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
163
+ last_layer=self.get_last_layer(), split="train")
164
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
165
+ return discloss
166
+
167
+ def validation_step(self, batch, batch_idx):
168
+ log_dict = self._validation_step(batch, batch_idx)
169
+ with self.ema_scope():
170
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
171
+ return log_dict
172
+
173
+ def _validation_step(self, batch, batch_idx, suffix=""):
174
+ x = self.get_input(batch, self.image_key)
175
+ xrec, qloss, ind = self(x, return_pred_indices=True)
176
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
177
+ self.global_step,
178
+ last_layer=self.get_last_layer(),
179
+ split="val"+suffix,
180
+ predicted_indices=ind
181
+ )
182
+
183
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
184
+ self.global_step,
185
+ last_layer=self.get_last_layer(),
186
+ split="val"+suffix,
187
+ predicted_indices=ind
188
+ )
189
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
190
+ self.log(f"val{suffix}/rec_loss", rec_loss,
191
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
192
+ self.log(f"val{suffix}/aeloss", aeloss,
193
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
194
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
195
+ del log_dict_ae[f"val{suffix}/rec_loss"]
196
+ self.log_dict(log_dict_ae)
197
+ self.log_dict(log_dict_disc)
198
+ return self.log_dict
199
+
200
+ def configure_optimizers(self):
201
+ lr_d = self.learning_rate
202
+ lr_g = self.lr_g_factor*self.learning_rate
203
+ print("lr_d", lr_d)
204
+ print("lr_g", lr_g)
205
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
206
+ list(self.decoder.parameters())+
207
+ list(self.quantize.parameters())+
208
+ list(self.quant_conv.parameters())+
209
+ list(self.post_quant_conv.parameters()),
210
+ lr=lr_g, betas=(0.5, 0.9))
211
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
212
+ lr=lr_d, betas=(0.5, 0.9))
213
+
214
+ if self.scheduler_config is not None:
215
+ scheduler = instantiate_from_config(self.scheduler_config)
216
+
217
+ print("Setting up LambdaLR scheduler...")
218
+ scheduler = [
219
+ {
220
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
221
+ 'interval': 'step',
222
+ 'frequency': 1
223
+ },
224
+ {
225
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
226
+ 'interval': 'step',
227
+ 'frequency': 1
228
+ },
229
+ ]
230
+ return [opt_ae, opt_disc], scheduler
231
+ return [opt_ae, opt_disc], []
232
+
233
+ def get_last_layer(self):
234
+ return self.decoder.conv_out.weight
235
+
236
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
237
+ log = dict()
238
+ x = self.get_input(batch, self.image_key)
239
+ x = x.to(self.device)
240
+ if only_inputs:
241
+ log["inputs"] = x
242
+ return log
243
+ xrec, _ = self(x)
244
+ if x.shape[1] > 3:
245
+ # Colorize with random projection
246
+ assert xrec.shape[1] > 3
247
+ x = self.to_rgb(x)
248
+ xrec = self.to_rgb(xrec)
249
+ log["inputs"] = x
250
+ log["reconstructions"] = xrec
251
+ if plot_ema:
252
+ with self.ema_scope():
253
+ xrec_ema, _ = self(x)
254
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
255
+ log["reconstructions_ema"] = xrec_ema
256
+ return log
257
+
258
+ def to_rgb(self, x):
259
+ assert self.image_key == "segmentation"
260
+ if not hasattr(self, "colorize"):
261
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
262
+ x = F.conv2d(x, weight=self.colorize)
263
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
264
+ return x
265
+
266
+
267
+ class VQModelInterface(VQModel):
268
+ def __init__(self, embed_dim, *args, **kwargs):
269
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
270
+ self.embed_dim = embed_dim
271
+
272
+ def encode(self, x):
273
+ h = self.encoder(x)
274
+ h = self.quant_conv(h)
275
+ return h
276
+
277
+ def decode(self, h, force_not_quantize=False):
278
+ # Also go through quantization layer
279
+ if not force_not_quantize:
280
+ quant, emb_loss, info = self.quantize(h)
281
+ else:
282
+ quant = h
283
+ quant = self.post_quant_conv(quant)
284
+ dec = self.decoder(quant)
285
+ return dec
286
+
287
+
288
+ class AutoencoderKL(pl.LightningModule):
289
+ def __init__(self,
290
+ ddconfig,
291
+ lossconfig,
292
+ embed_dim,
293
+ ckpt_path=None,
294
+ ignore_keys=[],
295
+ image_key="image",
296
+ colorize_nlabels=None,
297
+ monitor=None,
298
+ ):
299
+ super().__init__()
300
+ self.image_key = image_key
301
+ self.encoder = Encoder(**ddconfig)
302
+ self.decoder = Decoder(**ddconfig)
303
+ self.loss = instantiate_from_config(lossconfig)
304
+ assert ddconfig["double_z"]
305
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
306
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
307
+ self.embed_dim = embed_dim
308
+ if colorize_nlabels is not None:
309
+ assert type(colorize_nlabels)==int
310
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
311
+ if monitor is not None:
312
+ self.monitor = monitor
313
+ if ckpt_path is not None:
314
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
315
+
316
+ def init_from_ckpt(self, path, ignore_keys=list()):
317
+ sd = torch.load(path, map_location="cpu")["state_dict"]
318
+ keys = list(sd.keys())
319
+ for k in keys:
320
+ for ik in ignore_keys:
321
+ if k.startswith(ik):
322
+ print("Deleting key {} from state_dict.".format(k))
323
+ del sd[k]
324
+ self.load_state_dict(sd, strict=False)
325
+ print(f"Restored from {path}")
326
+
327
+ def encode(self, x):
328
+ h = self.encoder(x)
329
+ moments = self.quant_conv(h)
330
+ posterior = DiagonalGaussianDistribution(moments)
331
+ return posterior
332
+
333
+ def decode(self, z):
334
+ z = self.post_quant_conv(z)
335
+ dec = self.decoder(z)
336
+ return dec
337
+
338
+ def forward(self, input, sample_posterior=True):
339
+ posterior = self.encode(input)
340
+ if sample_posterior:
341
+ z = posterior.sample()
342
+ else:
343
+ z = posterior.mode()
344
+ dec = self.decode(z)
345
+ return dec, posterior
346
+
347
+ def get_input(self, batch, k):
348
+ x = batch[k]
349
+ if len(x.shape) == 3:
350
+ x = x[..., None]
351
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
352
+ return x
353
+
354
+ def training_step(self, batch, batch_idx, optimizer_idx):
355
+ inputs = self.get_input(batch, self.image_key)
356
+ reconstructions, posterior = self(inputs)
357
+
358
+ if optimizer_idx == 0:
359
+ # Training encoder + decoder + logvar
360
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
361
+ last_layer=self.get_last_layer(), split="train")
362
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
363
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
364
+ return aeloss
365
+
366
+ if optimizer_idx == 1:
367
+ # Training the discriminator
368
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
369
+ last_layer=self.get_last_layer(), split="train")
370
+
371
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
372
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
373
+ return discloss
374
+
375
+ def validation_step(self, batch, batch_idx):
376
+ inputs = self.get_input(batch, self.image_key)
377
+ reconstructions, posterior = self(inputs)
378
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
379
+ last_layer=self.get_last_layer(), split="val")
380
+
381
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
382
+ last_layer=self.get_last_layer(), split="val")
383
+
384
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
385
+ self.log_dict(log_dict_ae)
386
+ self.log_dict(log_dict_disc)
387
+ return self.log_dict
388
+
389
+ def configure_optimizers(self):
390
+ lr = self.learning_rate
391
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
392
+ list(self.decoder.parameters())+
393
+ list(self.quant_conv.parameters())+
394
+ list(self.post_quant_conv.parameters()),
395
+ lr=lr, betas=(0.5, 0.9))
396
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
397
+ lr=lr, betas=(0.5, 0.9))
398
+ return [opt_ae, opt_disc], []
399
+
400
+ def get_last_layer(self):
401
+ return self.decoder.conv_out.weight
402
+
403
+ @torch.no_grad()
404
+ def log_images(self, batch, only_inputs=False, **kwargs):
405
+ log = dict()
406
+ x = self.get_input(batch, self.image_key)
407
+ x = x.to(self.device)
408
+ if not only_inputs:
409
+ xrec, posterior = self(x)
410
+ if x.shape[1] > 3:
411
+ # Colorize with random projection
412
+ assert xrec.shape[1] > 3
413
+ x = self.to_rgb(x)
414
+ xrec = self.to_rgb(xrec)
415
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
416
+ log["reconstructions"] = xrec
417
+ log["inputs"] = x
418
+ return log
419
+
420
+ def to_rgb(self, x):
421
+ assert self.image_key == "segmentation"
422
+ if not hasattr(self, "colorize"):
423
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
424
+ x = F.conv2d(x, weight=self.colorize)
425
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
426
+ return x
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,204 @@
 
1
+ import torch
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
6
+
7
+
8
+ class DDIMSampler(object):
9
+ def __init__(self, model, schedule="linear", device="cuda", **kwargs):
10
+ super().__init__()
11
+ self.model = model
12
+ self.ddpm_num_timesteps = model.num_timesteps
13
+ self.schedule = schedule
14
+ self.device = device
15
+
16
+ def register_buffer(self, name, attr):
17
+ if type(attr) == torch.Tensor:
18
+ if self.device == "cuda" and attr.device != torch.device("cuda"):
19
+ attr = attr.to(torch.device("cuda"))
20
+ setattr(self, name, attr)
21
+
22
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
23
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
24
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
25
+ alphas_cumprod = self.model.alphas_cumprod
26
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
27
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
28
+
29
+ self.register_buffer('betas', to_torch(self.model.betas))
30
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
31
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
32
+
33
+ # Calculations for diffusion q(x_t | x_{t-1}) and others
34
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
35
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
36
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
39
+
40
+ # DDIM sampling parameters
41
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
42
+ ddim_timesteps=self.ddim_timesteps,
43
+ eta=ddim_eta,verbose=verbose)
44
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
45
+ self.register_buffer('ddim_alphas', ddim_alphas)
46
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
47
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
48
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
49
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
50
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
51
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
52
+
53
+ @torch.no_grad()
54
+ def sample(self,
55
+ S,
56
+ batch_size,
57
+ shape,
58
+ conditioning=None,
59
+ callback=None,
60
+ img_callback=None,
61
+ quantize_x0=False,
62
+ eta=0.,
63
+ mask=None,
64
+ x0=None,
65
+ temperature=1.,
66
+ noise_dropout=0.,
67
+ score_corrector=None,
68
+ corrector_kwargs=None,
69
+ verbose=True,
70
+ x_T=None,
71
+ log_every_t=100,
72
+ unconditional_guidance_scale=1.,
73
+ unconditional_conditioning=None,
74
+ keep_attn_maps=False,
75
+ **kwargs
76
+ ):
77
+ self.model.keep_attn_map_dict(keep_attn_maps)
78
+ if conditioning is not None:
79
+ if isinstance(conditioning, dict):
80
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
81
+ if cbs != batch_size:
82
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
83
+ else:
84
+ if conditioning.shape[0] != batch_size:
85
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
86
+
87
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
88
+ # Sampling
89
+ C, H, W = shape
90
+ size = (batch_size, C, H, W)
91
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
92
+
93
+ samples, intermediates = self.ddim_sampling(conditioning, size,
94
+ callback=callback,
95
+ img_callback=img_callback,
96
+ quantize_denoised=quantize_x0,
97
+ mask=mask, x0=x0,
98
+ ddim_use_original_steps=False,
99
+ noise_dropout=noise_dropout,
100
+ temperature=temperature,
101
+ score_corrector=score_corrector,
102
+ corrector_kwargs=corrector_kwargs,
103
+ x_T=x_T,
104
+ log_every_t=log_every_t,
105
+ unconditional_guidance_scale=unconditional_guidance_scale,
106
+ unconditional_conditioning=unconditional_conditioning,
107
+ )
108
+ return samples, intermediates
109
+
110
+ @torch.no_grad()
111
+ def ddim_sampling(self, cond, shape,
112
+ x_T=None, ddim_use_original_steps=False,
113
+ callback=None, timesteps=None, quantize_denoised=False,
114
+ mask=None, x0=None, img_callback=None, log_every_t=100,
115
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
116
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
117
+ device = self.model.betas.device
118
+ b = shape[0]
119
+ if x_T is None:
120
+ img = torch.randn(shape, device=device)
121
+ else:
122
+ img = x_T
123
+
124
+ if timesteps is None:
125
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
126
+ elif timesteps is not None and not ddim_use_original_steps:
127
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
128
+ timesteps = self.ddim_timesteps[:subset_end]
129
+
130
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
131
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
132
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
133
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
134
+
135
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
136
+
137
+ for i, step in enumerate(iterator):
138
+ index = total_steps - i - 1
139
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
140
+
141
+ if mask is not None:
142
+ assert x0 is not None
143
+ img_orig = self.model.q_sample(x0, ts)
144
+ img = img_orig * mask + (1. - mask) * img
145
+
146
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
147
+ quantize_denoised=quantize_denoised, temperature=temperature,
148
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
149
+ corrector_kwargs=corrector_kwargs,
150
+ unconditional_guidance_scale=unconditional_guidance_scale,
151
+ unconditional_conditioning=unconditional_conditioning)
152
+ img, pred_x0 = outs
153
+ if callback: callback(i)
154
+ if img_callback: img_callback(pred_x0, i)
155
+
156
+ if index % log_every_t == 0 or index == total_steps - 1:
157
+ intermediates['x_inter'].append(img)
158
+ intermediates['pred_x0'].append(pred_x0)
159
+
160
+ return img, intermediates
161
+
162
+ @torch.no_grad()
163
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
164
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
165
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
166
+ b, *_, device = *x.shape, x.device
167
+
168
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
169
+ e_t = self.model.apply_model(x, t, c, index=index)
170
+ else:
171
+ x_in = torch.cat([x] * 2)
172
+ t_in = torch.cat([t] * 2)
173
+ c_in = torch.cat([unconditional_conditioning, c])
174
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, index=index).chunk(2)
175
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
176
+
177
+ if score_corrector is not None:
178
+ assert self.model.parameterization == "eps"
179
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
180
+
181
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
182
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
183
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
184
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
185
+
186
+ # Selecting parameters corresponding to the currently considered timestep
187
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
188
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
189
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
190
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
191
+
192
+ # Current prediction for x_0
193
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
194
+ if quantize_denoised:
195
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
196
+
197
+ # Direction pointing to x_t
198
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200
+ if noise_dropout > 0.:
201
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203
+
204
+ return x_prev, pred_x0
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1062 @@
 
1
+ ##################################################################################################
2
+ # Adapted from: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py
3
+ ##################################################################################################
4
+ # Utilized resources:
5
+ # - https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
+ # - https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
7
+ # - https://github.com/CompVis/taming-transformers
8
+ ##################################################################################################
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import numpy as np
13
+ import pytorch_lightning as pl
14
+ from torch.optim.lr_scheduler import LambdaLR
15
+ from einops import rearrange, repeat
16
+ from contextlib import contextmanager
17
+ from functools import partial
18
+ from tqdm import tqdm
19
+ from torchvision.utils import make_grid
20
+ from pytorch_lightning.utilities.distributed import rank_zero_only
21
+ from ldm.util import log_txt_as_img, exists, default, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+ from PIL import Image
28
+ from ldm.util import seed_everything
29
+
30
+ def disabled_train(self, mode=True):
31
+ """Overwrite model.train with this function to make sure train/eval mode
32
+ does not change anymore."""
33
+ return self
34
+
35
+ class DDPM(pl.LightningModule):
36
+ # DDPM with Gaussian diffusion in image space.
37
+ def __init__(self,
38
+ unet_config,
39
+ timesteps=1000,
40
+ beta_schedule="linear",
41
+ loss_type="l2",
42
+ ckpt_path=None,
43
+ ignore_keys=[],
44
+ load_only_unet=False,
45
+ monitor="val/loss",
46
+ use_ema=True,
47
+ first_stage_key="image",
48
+ image_size=256,
49
+ channels=3,
50
+ log_every_t=100,
51
+ clip_denoised=True,
52
+ linear_start=1e-4,
53
+ linear_end=2e-2,
54
+ cosine_s=8e-3,
55
+ given_betas=None,
56
+ original_elbo_weight=0.,
57
+ v_posterior=0., # Weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
58
+ l_simple_weight=1.,
59
+ conditioning_key=None,
60
+ parameterization="eps", # All assuming fixed variance schedules
61
+ scheduler_config=None,
62
+ learn_logvar=False,
63
+ logvar_init=0.
64
+ ):
65
+ super().__init__()
66
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
67
+ self.parameterization = parameterization
68
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
69
+ self.cond_stage_model = None
70
+ self.clip_denoised = clip_denoised
71
+ self.log_every_t = log_every_t
72
+ self.first_stage_key = first_stage_key
73
+ self.image_size = image_size
74
+ self.channels = channels
75
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
76
+ count_params(self.model, verbose=True)
77
+ self.use_ema = use_ema
78
+ if self.use_ema:
79
+ self.model_ema = LitEma(self.model)
80
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
81
+
82
+ self.use_scheduler = scheduler_config is not None
83
+ if self.use_scheduler:
84
+ self.scheduler_config = scheduler_config
85
+
86
+ self.v_posterior = v_posterior
87
+ self.original_elbo_weight = original_elbo_weight
88
+ self.l_simple_weight = l_simple_weight
89
+
90
+ if monitor is not None:
91
+ self.monitor = monitor
92
+ if ckpt_path is not None:
93
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
94
+
95
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
96
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
97
+
98
+ self.loss_type = loss_type
99
+
100
+ self.learn_logvar = learn_logvar
101
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
102
+ if self.learn_logvar:
103
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
104
+
105
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
106
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
107
+ if exists(given_betas):
108
+ betas = given_betas
109
+ else:
110
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
111
+ cosine_s=cosine_s)
112
+ alphas = 1. - betas
113
+ alphas_cumprod = np.cumprod(alphas, axis=0)
114
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
115
+
116
+ timesteps, = betas.shape
117
+ self.num_timesteps = int(timesteps)
118
+ self.linear_start = linear_start
119
+ self.linear_end = linear_end
120
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
121
+
122
+ to_torch = partial(torch.tensor, dtype=torch.float32)
123
+
124
+ self.register_buffer('betas', to_torch(betas))
125
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
126
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
127
+
128
+ # Calculations for diffusion q(x_t | x_{t-1}) and others
129
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
130
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
131
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
132
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
133
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
134
+
135
+ # Calculations for posterior q(x_{t-1} | x_t, x_0)
136
+ # Equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
137
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + self.v_posterior * betas
138
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
139
+ # Log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain.
140
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
141
+ self.register_buffer('posterior_mean_coef1', to_torch(
142
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
143
+ self.register_buffer('posterior_mean_coef2', to_torch(
144
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
145
+
146
+ if self.parameterization == "eps":
147
+ lvlb_weights = self.betas ** 2 / (
148
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
149
+ elif self.parameterization == "x0":
150
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
151
+ else:
152
+ raise NotImplementedError("mu not supported")
153
+ lvlb_weights[0] = lvlb_weights[1]
154
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
155
+ assert not torch.isnan(self.lvlb_weights).all()
156
+
157
+ @contextmanager
158
+ def ema_scope(self, context=None):
159
+ if self.use_ema:
160
+ self.model_ema.store(self.model.parameters())
161
+ self.model_ema.copy_to(self.model)
162
+ if context is not None:
163
+ print(f"{context}: Switched to EMA weights")
164
+ try:
165
+ yield None
166
+ finally:
167
+ if self.use_ema:
168
+ self.model_ema.restore(self.model.parameters())
169
+ if context is not None:
170
+ print(f"{context}: Restored training weights")
171
+
172
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
173
+ sd = torch.load(path, map_location="cpu")
174
+ if "state_dict" in list(sd.keys()):
175
+ sd = sd["state_dict"]
176
+ keys = list(sd.keys())
177
+ for k in keys:
178
+ for ik in ignore_keys:
179
+ if k.startswith(ik):
180
+ print("Deleting key {} from state_dict.".format(k))
181
+ del sd[k]
182
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
183
+ sd, strict=False)
184
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
185
+ if len(missing) > 0:
186
+ print(f"Missing Keys: {missing}")
187
+ if len(unexpected) > 0:
188
+ print(f"Unexpected Keys: {unexpected}")
189
+
190
+ def q_mean_variance(self, x_start, t):
191
+ """
192
+ Get the distribution q(x_t | x_0).
193
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
194
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
195
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
196
+ """
197
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
198
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
199
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
200
+ return mean, variance, log_variance
201
+
202
+ def predict_start_from_noise(self, x_t, t, noise):
203
+ return (
204
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
205
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
206
+ )
207
+
208
+ def q_posterior(self, x_start, x_t, t):
209
+ posterior_mean = (
210
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
211
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
212
+ )
213
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
214
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
215
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
216
+
217
+ def p_mean_variance(self, x, t, clip_denoised: bool):
218
+ model_out = self.model(x, t)
219
+ if self.parameterization == "eps":
220
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
221
+ elif self.parameterization == "x0":
222
+ x_recon = model_out
223
+ if clip_denoised:
224
+ x_recon.clamp_(-1., 1.)
225
+
226
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
227
+ return model_mean, posterior_variance, posterior_log_variance
228
+
229
+ @torch.no_grad()
230
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
231
+ b, *_, device = *x.shape, x.device
232
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
233
+ noise = noise_like(x.shape, device, repeat_noise)
234
+ # No noise when t == 0
235
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
236
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
237
+
238
+ @torch.no_grad()
239
+ def p_sample_loop(self, shape, return_intermediates=False):
240
+ device = self.betas.device
241
+ b = shape[0]
242
+ img = torch.randn(shape, device=device)
243
+ intermediates = [img]
244
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
245
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
246
+ clip_denoised=self.clip_denoised)
247
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
248
+ intermediates.append(img)
249
+ if return_intermediates:
250
+ return img, intermediates
251
+ return img
252
+
253
+ @torch.no_grad()
254
+ def sample(self, batch_size=16, return_intermediates=False):
255
+ image_size = self.image_size
256
+ channels = self.channels
257
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
258
+ return_intermediates=return_intermediates)
259
+
260
+ def q_sample(self, x_start, t, noise=None):
261
+ noise = default(noise, lambda: torch.randn_like(x_start))
262
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
263
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
264
+
265
+ def get_loss(self, pred, target, mean=True):
266
+ if self.loss_type == 'l1':
267
+ loss = (target - pred).abs()
268
+ if mean:
269
+ loss = loss.mean()
270
+ elif self.loss_type == 'l2':
271
+ if mean:
272
+ loss = torch.nn.functional.mse_loss(target, pred)
273
+ else:
274
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
275
+ else:
276
+ raise NotImplementedError("unknown loss type '{loss_type}'")
277
+
278
+ return loss
279
+
280
+ def p_losses(self, x_start, t, noise=None):
281
+ noise = default(noise, lambda: torch.randn_like(x_start))
282
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
283
+ model_out = self.model(x_noisy, t)
284
+
285
+ loss_dict = {}
286
+ if self.parameterization == "eps":
287
+ target = noise
288
+ elif self.parameterization == "x0":
289
+ target = x_start
290
+ else:
291
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
292
+
293
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
294
+
295
+ log_prefix = 'train' if self.training else 'val'
296
+
297
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
298
+ loss_simple = loss.mean() * self.l_simple_weight
299
+
300
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
301
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
302
+
303
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
304
+
305
+ loss_dict.update({f'{log_prefix}/loss': loss})
306
+
307
+ return loss, loss_dict
308
+
309
+ def forward(self, x, *args, **kwargs):
310
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
311
+ return self.p_losses(x, t, *args, **kwargs)
312
+
313
+ def get_input(self, batch, k):
314
+ x = batch[k]
315
+ if isinstance(x, list):
316
+ return x
317
+ x = x.to(memory_format=torch.contiguous_format).float()
318
+ return x
319
+
320
+ def shared_step(self, batch):
321
+ x = self.get_input(batch, self.first_stage_key)
322
+ loss, loss_dict = self(x)
323
+ return loss, loss_dict
324
+
325
+ def training_step(self, batch, batch_idx):
326
+ loss, loss_dict = self.shared_step(batch)
327
+
328
+ self.log_dict(loss_dict, prog_bar=True,
329
+ logger=True, on_step=True, on_epoch=True)
330
+
331
+ self.log("global_step", self.global_step,
332
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
333
+
334
+ if self.use_scheduler:
335
+ lr = self.optimizers().param_groups[0]['lr']
336
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
337
+
338
+ return loss
339
+
340
+ @torch.no_grad()
341
+ def validation_step(self, batch, batch_idx):
342
+ _, loss_dict_no_ema = self.shared_step(batch)
343
+ with self.ema_scope():
344
+ _, loss_dict_ema = self.shared_step(batch)
345
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
346
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
347
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
348
+
349
+ def on_train_batch_end(self, *args, **kwargs):
350
+ if self.use_ema:
351
+ self.model_ema(self.model)
352
+
353
+ def _get_rows_from_list(self, samples):
354
+ n_imgs_per_row = len(samples)
355
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
356
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
357
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
358
+ return denoise_grid
359
+
360
+ @torch.no_grad()
361
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
362
+ log = dict()
363
+ x = self.get_input(batch, self.first_stage_key)
364
+ N = min(x.shape[0], N)
365
+ n_row = min(x.shape[0], n_row)
366
+ x = x.to(self.device)[:N]
367
+ log["inputs"] = x
368
+
369
+ # Getting diffusion row
370
+ diffusion_row = list()
371
+ x_start = x[:n_row]
372
+
373
+ for t in range(self.num_timesteps):
374
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
375
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
376
+ t = t.to(self.device).long()
377
+ noise = torch.randn_like(x_start)
378
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
379
+ diffusion_row.append(x_noisy)
380
+
381
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
382
+
383
+ if sample:
384
+ # Getting denoise row
385
+ with self.ema_scope("Plotting"):
386
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
387
+
388
+ log["samples"] = samples
389
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
390
+
391
+ if return_keys:
392
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
393
+ return log
394
+ else:
395
+ return {key: log[key] for key in return_keys}
396
+ return log
397
+
398
+ def configure_optimizers(self):
399
+ lr = self.learning_rate
400
+ params = list(self.model.parameters())
401
+ if self.learn_logvar:
402
+ params = params + [self.logvar]
403
+ opt = torch.optim.AdamW(params, lr=lr)
404
+ return opt
405
+
406
+
407
+ class LatentDiffusion(DDPM):
408
+ def __init__(self,
409
+ first_stage_config,
410
+ cond_stage_config,
411
+ cond_stage_instruction_embedder_config=None,
412
+ num_timesteps_cond=None,
413
+ cond_stage_key="image",
414
+ cond_stage_instruction_key=None,
415
+ cond_stage_trainable=False,
416
+ cond_stage_instruction_embedder_trainable=False,
417
+ concat_mode=True,
418
+ cond_stage_forward=None,
419
+ conditioning_key=None,
420
+ scale_factor=1.0,
421
+ scale_by_std=False,
422
+ *args, **kwargs):
423
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
424
+ self.scale_by_std = scale_by_std
425
+ assert self.num_timesteps_cond <= kwargs['timesteps']
426
+ # For backwards compatibility after implementation of DiffusionWrapper
427
+ if conditioning_key is None:
428
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
429
+ if cond_stage_config == '__is_unconditional__':
430
+ conditioning_key = None
431
+ ckpt_path = kwargs.pop("ckpt_path", None)
432
+ ignore_keys = kwargs.pop("ignore_keys", [])
433
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
434
+ self.concat_mode = concat_mode
435
+ self.cond_stage_trainable = cond_stage_trainable
436
+ self.cond_stage_key = cond_stage_key
437
+ self.cond_stage_instruction_key = cond_stage_instruction_key
438
+ self.cond_stage_instruction_embedder_config = cond_stage_instruction_embedder_config
439
+ self.cond_stage_instruction_embedder_trainable = cond_stage_instruction_embedder_trainable
440
+ try:
441
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
442
+ except Exception:
443
+ self.num_downs = 0
444
+ if not scale_by_std:
445
+ self.scale_factor = scale_factor
446
+ else:
447
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
448
+ self.instantiate_first_stage(first_stage_config)
449
+ self.instantiate_cond_stage(cond_stage_config)
450
+ self.instantiate_cond_stage_instruction_embedder(cond_stage_instruction_embedder_config)
451
+ self.cond_stage_forward = cond_stage_forward
452
+ self.clip_denoised = False
453
+ self.bbox_tokenizer = None
454
+
455
+ self.restarted_from_ckpt = False
456
+ if ckpt_path is not None:
457
+ self.init_from_ckpt(ckpt_path, ignore_keys)
458
+ self.restarted_from_ckpt = True
459
+
460
+ def keep_attn_map_dict(self, keep_attn_maps):
461
+ self.model.keep_attn_map_dict(keep_attn_maps)
462
+
463
+ def get_attn_map_dict(self):
464
+ return self.model.attn_dict
465
+
466
+ def make_cond_schedule(self, ):
467
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
468
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
469
+ self.cond_ids[:self.num_timesteps_cond] = ids
470
+
471
+ @rank_zero_only
472
+ @torch.no_grad()
473
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
474
+ # Only for the very first batch
475
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
476
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
477
+ # Set rescale weight to 1./std of encodings
478
+ print("### USING STD-RESCALING ###")
479
+ x = super().get_input(batch, self.first_stage_key)
480
+ x = x.to(self.device)
481
+ encoder_posterior = self.encode_first_stage(x)
482
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
483
+ del self.scale_factor
484
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
485
+ print(f"setting self.scale_factor to {self.scale_factor}")
486
+ print("### USING STD-RESCALING ###")
487
+
488
+ def register_schedule(self,
489
+ given_betas=None, beta_schedule="linear", timesteps=1000,
490
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
491
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
492
+
493
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
494
+ if self.shorten_cond_schedule:
495
+ self.make_cond_schedule()
496
+
497
+ def instantiate_first_stage(self, config):
498
+ model = instantiate_from_config(config)
499
+ self.first_stage_model = model.eval()
500
+ self.first_stage_model.train = disabled_train
501
+ for param in self.first_stage_model.parameters():
502
+ param.requires_grad = False
503
+
504
+ def instantiate_cond_stage(self, config):
505
+ if not self.cond_stage_trainable:
506
+ if config == "__is_first_stage__":
507
+ print("Using first stage also as cond stage.")
508
+ self.cond_stage_model = self.first_stage_model
509
+ elif config == "__is_unconditional__":
510
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
511
+ self.cond_stage_model = None
512
+ else:
513
+ model = instantiate_from_config(config)
514
+ self.cond_stage_model = model.eval()
515
+ self.cond_stage_model.train = disabled_train
516
+ for param in self.cond_stage_model.parameters():
517
+ param.requires_grad = False
518
+ else:
519
+ assert config != '__is_first_stage__'
520
+ assert config != '__is_unconditional__'
521
+ model = instantiate_from_config(config)
522
+ self.cond_stage_model = model
523
+
524
+ def instantiate_cond_stage_instruction_embedder(self, config):
525
+ if self.cond_stage_instruction_embedder_config is not None:
526
+ assert self.cond_stage_instruction_key is not None
527
+ self.cond_stage_instruction_embedder = instantiate_from_config(config)
528
+ if not self.cond_stage_instruction_embedder_trainable:
529
+ self.cond_stage_instruction_embedder = self.cond_stage_instruction_embedder.eval()
530
+ self.cond_stage_instruction_embedder.train = disabled_train
531
+ for param in self.cond_stage_instruction_embedder.parameters():
532
+ param.requires_grad = False
533
+
534
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
535
+ denoise_row = []
536
+ for zd in tqdm(samples, desc=desc):
537
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
538
+ force_not_quantize=force_no_decoder_quantization))
539
+ n_imgs_per_row = len(denoise_row)
540
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
541
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
542
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
543
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
544
+ return denoise_grid
545
+
546
+ def get_first_stage_encoding(self, encoder_posterior):
547
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
548
+ z = encoder_posterior.sample()
549
+ elif isinstance(encoder_posterior, torch.Tensor):
550
+ z = encoder_posterior
551
+ else:
552
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
553
+ return self.scale_factor * z
554
+
555
+ def get_learned_conditioning(self, c):
556
+ if self.cond_stage_forward is None:
557
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
558
+ c = self.cond_stage_model.encode(c)
559
+ if isinstance(c, DiagonalGaussianDistribution):
560
+ c = c.mode()
561
+ else:
562
+ c = self.cond_stage_model(c)
563
+ else:
564
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
565
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
566
+ return c
567
+
568
+ @torch.no_grad()
569
+ def get_main_input(self, batch, k, return_first_stage_outputs, force_c_encode,
570
+ cond_key, return_original_cond, bs):
571
+ x = super().get_input(batch, k)
572
+ check_condition_modification = False
573
+ if bs is not None:
574
+ x = x[:bs]
575
+ x = x.to(self.device)
576
+ encoder_posterior = self.encode_first_stage(x)
577
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
578
+
579
+ if self.model.conditioning_key is not None:
580
+ check_condition_modification = True
581
+ if cond_key is None:
582
+ cond_key = self.cond_stage_key
583
+ if cond_key != self.first_stage_key:
584
+ xc = super().get_input(batch, cond_key).to(self.device)
585
+ else:
586
+ xc = x
587
+ if not self.cond_stage_trainable or force_c_encode:
588
+ if isinstance(xc, dict) or isinstance(xc, list):
589
+ c = self.get_learned_conditioning(xc)
590
+ else:
591
+ c = self.get_learned_conditioning(xc.to(self.device))
592
+ else:
593
+ c = xc
594
+ if bs is not None:
595
+ c = c[:bs]
596
+ else:
597
+ c = None
598
+ xc = None
599
+ out = [z, c]
600
+ if return_first_stage_outputs:
601
+ xrec = self.decode_first_stage(z)
602
+ out.extend([x, xrec])
603
+ if return_original_cond:
604
+ out.append(xc)
605
+ return out, check_condition_modification
606
+
607
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
608
+ cond_key=None, return_original_cond=False, bs=None):
609
+
610
+ out, check_condition_modification = self.get_main_input(batch, k, return_first_stage_outputs, force_c_encode,
611
+ cond_key, return_original_cond, bs)
612
+ c = out[1]
613
+ # Implemented for inpainting model
614
+ if check_condition_modification:
615
+ if self.cond_stage_instruction_key and self.model.conditioning_key == "concat":
616
+ instructions = super().get_input(batch, self.cond_stage_instruction_key)
617
+ c = self.cond_stage_instruction_embedder(c, instructions)
618
+
619
+ if self.cond_stage_instruction_key and self.model.conditioning_key == "hybrid":
620
+ instructions = super().get_input(batch, self.cond_stage_instruction_key)
621
+ # Condition image feature is sent as None to the instruction embedder (instruction embedding is not concatenated)
622
+ instruction_embedding = self.cond_stage_instruction_embedder(None, instructions)
623
+ c = {'c_concat': c, 'c_crossattn': instruction_embedding}
624
+
625
+ out[1] = c
626
+ return out
627
+
628
+
629
+ @torch.no_grad()
630
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
631
+ if predict_cids:
632
+ if z.dim() == 4:
633
+ z = torch.argmax(z.exp(), dim=1).long()
634
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
635
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
636
+
637
+ z = 1. / self.scale_factor * z
638
+
639
+ if isinstance(self.first_stage_model, VQModelInterface):
640
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
641
+ else:
642
+ return self.first_stage_model.decode(z)
643
+
644
+ @torch.no_grad()
645
+ def encode_first_stage(self, x):
646
+ return self.first_stage_model.encode(x)
647
+
648
+ def shared_step(self, batch, **kwargs):
649
+ x, c = self.get_input(batch, self.first_stage_key)
650
+ loss = self(x, c)
651
+ return loss
652
+
653
+ def forward(self, x, c, *args, **kwargs):
654
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
655
+ if self.model.conditioning_key is not None:
656
+ assert c is not None
657
+ if self.cond_stage_trainable:
658
+ c = self.get_learned_conditioning(c)
659
+ if self.shorten_cond_schedule:
660
+ tc = self.cond_ids[t].to(self.device)
661
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
662
+ return self.p_losses(x, c, t, *args, **kwargs)
663
+
664
+ def apply_model(self, x_noisy, t, cond, index=None):
665
+ # If cond is not already a dict (i.e. conditioning_key is not 'hybrid'), wrap it under the matching key
666
+ if not isinstance(cond, dict):
667
+ if not isinstance(cond, list):
668
+ cond = [cond]
669
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
670
+ cond = {key: cond}
671
+
672
+ x_recon = self.model(x_noisy, t, **cond, index=index)
673
+
674
+ if isinstance(x_recon, tuple):
675
+ return x_recon[0]
676
+ else:
677
+ return x_recon
678
+
679
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
680
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
681
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
682
+
683
+ def _prior_bpd(self, x_start):
684
+ """
685
+ Get the prior KL term for the variational lower-bound, measured in
686
+ bits-per-dim.
687
+ This term can't be optimized, as it only depends on the encoder.
688
+ :param x_start: the [N x C x ...] tensor of inputs.
689
+ :return: a batch of [N] KL values (in bits), one per batch element.
690
+ """
691
+ batch_size = x_start.shape[0]
692
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
693
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
694
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
695
+ return mean_flat(kl_prior) / np.log(2.0)
696
+
697
+ def p_losses(self, x_start, cond, t, noise=None):
698
+ noise = default(noise, lambda: torch.randn_like(x_start))
699
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
700
+ model_output = self.apply_model(x_noisy, t, cond)
701
+
702
+ loss_dict = {}
703
+ prefix = 'train' if self.training else 'val'
704
+
705
+ if self.parameterization == "x0":
706
+ target = x_start
707
+ elif self.parameterization == "eps":
708
+ target = noise
709
+ else:
710
+ raise NotImplementedError()
711
+
712
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
713
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
714
+
715
+ logvar_t = self.logvar[t].to(self.device)
716
+
717
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
718
+ if self.learn_logvar:
719
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
720
+ loss_dict.update({'logvar': self.logvar.data.mean()})
721
+
722
+ loss = self.l_simple_weight * loss.mean()
723
+
724
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
725
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
726
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
727
+ loss += (self.original_elbo_weight * loss_vlb)
728
+ loss_dict.update({f'{prefix}/loss': loss})
729
+
730
+ return loss, loss_dict
731
+
732
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, quantize_denoised=False,
733
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
734
+ t_in = t
735
+ model_out = self.apply_model(x, t_in, c)
736
+
737
+ if score_corrector is not None:
738
+ assert self.parameterization == "eps"
739
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
740
+
741
+
742
+ if self.parameterization == "eps":
743
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
744
+ elif self.parameterization == "x0":
745
+ x_recon = model_out
746
+ else:
747
+ raise NotImplementedError()
748
+
749
+ if clip_denoised:
750
+ x_recon.clamp_(-1., 1.)
751
+ if quantize_denoised:
752
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
753
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
754
+
755
+ if return_x0:
756
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
757
+ else:
758
+ return model_mean, posterior_variance, posterior_log_variance
759
+
760
+ @torch.no_grad()
761
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, quantize_denoised=False, return_x0=False,
762
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
763
+ b, *_, device = *x.shape, x.device
764
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
765
+ quantize_denoised=quantize_denoised,
766
+ return_x0=return_x0,
767
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
768
+
769
+ if return_x0:
770
+ model_mean, _, model_log_variance, x0 = outputs
771
+ else:
772
+ model_mean, _, model_log_variance = outputs
773
+
774
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
775
+ if noise_dropout > 0.:
776
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
777
+ # No noise when t == 0
778
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
779
+
780
+ if return_x0:
781
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
782
+ else:
783
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
784
+
785
+ @torch.no_grad()
786
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
787
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
788
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
789
+ log_every_t=None):
790
+ if not log_every_t:
791
+ log_every_t = self.log_every_t
792
+ timesteps = self.num_timesteps
793
+ if batch_size is not None:
794
+ b = batch_size if batch_size is not None else shape[0]
795
+ shape = [batch_size] + list(shape)
796
+ else:
797
+ b = batch_size = shape[0]
798
+ if x_T is None:
799
+ img = torch.randn(shape, device=self.device)
800
+ else:
801
+ img = x_T
802
+ intermediates = []
803
+ if cond is not None:
804
+ if isinstance(cond, dict):
805
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
806
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
807
+ else:
808
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
809
+
810
+ if start_T is not None:
811
+ timesteps = min(timesteps, start_T)
812
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
813
+ total=timesteps) if verbose else reversed(
814
+ range(0, timesteps))
815
+ if type(temperature) == float:
816
+ temperature = [temperature] * timesteps
817
+
818
+ for i in iterator:
819
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
820
+ if self.shorten_cond_schedule:
821
+ assert self.model.conditioning_key != 'hybrid'
822
+ tc = self.cond_ids[ts].to(cond.device)
823
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
824
+
825
+ img, x0_partial = self.p_sample(img, cond, ts,
826
+ clip_denoised=self.clip_denoised,
827
+ quantize_denoised=quantize_denoised, return_x0=True,
828
+ temperature=temperature[i], noise_dropout=noise_dropout,
829
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
830
+ if mask is not None:
831
+ assert x0 is not None
832
+ img_orig = self.q_sample(x0, ts)
833
+ img = img_orig * mask + (1. - mask) * img
834
+
835
+ if i % log_every_t == 0 or i == timesteps - 1:
836
+ intermediates.append(x0_partial)
837
+ if callback: callback(i)
838
+ if img_callback: img_callback(img, i)
839
+ return img, intermediates
840
+
841
+ @torch.no_grad()
842
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
843
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
844
+ mask=None, x0=None, img_callback=None, start_T=None,
845
+ log_every_t=None):
846
+
847
+ if not log_every_t:
848
+ log_every_t = self.log_every_t
849
+ device = self.betas.device
850
+ b = shape[0]
851
+ if x_T is None:
852
+ img = torch.randn(shape, device=device)
853
+ else:
854
+ img = x_T
855
+
856
+ intermediates = [img]
857
+ if timesteps is None:
858
+ timesteps = self.num_timesteps
859
+
860
+ if start_T is not None:
861
+ timesteps = min(timesteps, start_T)
862
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
863
+ range(0, timesteps))
864
+
865
+ if mask is not None:
866
+ assert x0 is not None
867
+ assert x0.shape[2:3] == mask.shape[2:3] # Spatial size has to match
868
+
869
+ for i in iterator:
870
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
871
+ if self.shorten_cond_schedule:
872
+ assert self.model.conditioning_key != 'hybrid'
873
+ tc = self.cond_ids[ts].to(cond.device)
874
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
875
+
876
+ img = self.p_sample(img, cond, ts,
877
+ clip_denoised=self.clip_denoised,
878
+ quantize_denoised=quantize_denoised)
879
+ if mask is not None:
880
+ img_orig = self.q_sample(x0, ts)
881
+ img = img_orig * mask + (1. - mask) * img
882
+
883
+ if i % log_every_t == 0 or i == timesteps - 1:
884
+ intermediates.append(img)
885
+ if callback: callback(i)
886
+ if img_callback: img_callback(img, i)
887
+
888
+ if return_intermediates:
889
+ return img, intermediates
890
+ return img
891
+
892
+ @torch.no_grad()
893
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
894
+ verbose=True, timesteps=None, quantize_denoised=False,
895
+ mask=None, x0=None, shape=None,**kwargs):
896
+ if shape is None:
897
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
898
+ if cond is not None:
899
+ if isinstance(cond, dict):
900
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
901
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
902
+ else:
903
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
904
+ return self.p_sample_loop(cond,
905
+ shape,
906
+ return_intermediates=return_intermediates, x_T=x_T,
907
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
908
+ mask=mask, x0=x0)
909
+
910
+ @torch.no_grad()
911
+ def log_images(self, batch, N=8, n_row=4, plot_progressive_rows=True, instruction_img_size=256, **kwargs):
912
+
913
+ log = dict()
914
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
915
+ return_first_stage_outputs=True,
916
+ force_c_encode=True,
917
+ return_original_cond=True,
918
+ bs=N)
919
+ N = min(x.shape[0], N)
920
+ n_row = min(x.shape[0], n_row)
921
+ log["inputs"] = x
922
+ log["reconstruction"] = xrec
923
+ if self.model.conditioning_key is not None:
924
+ if hasattr(self.cond_stage_model, "decode"):
925
+ if self.cond_stage_instruction_key and self.model.conditioning_key == "concat":
926
+ c_cond = c[:,:-self.cond_stage_instruction_embedder.out_size,:,:]
927
+ else:
928
+ c_cond = c
929
+ if isinstance(c_cond, dict):
930
+ c_cond = c_cond["c_concat"]
931
+ xc = self.cond_stage_model.decode(c_cond)
932
+ log["conditioning"] = xc
933
+ elif isimage(xc):
934
+ log["conditioning"] = xc
935
+
936
+ if self.cond_stage_instruction_key is not None:
937
+ instructions = super().get_input(batch, self.cond_stage_instruction_key)
938
+ instructions_img = log_txt_as_img((instruction_img_size, instruction_img_size), instructions)
939
+ log['instructions'] = instructions_img
940
+
941
+ if plot_progressive_rows:
942
+ with self.ema_scope("Plotting Progressives"):
943
+ img, progressives = self.progressive_denoising(c,
944
+ shape=(self.channels, self.image_size, self.image_size),
945
+ batch_size=N)
946
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
947
+ log["progressive_row"] = prog_row
948
+
949
+ return log
950
+
951
+ @torch.no_grad()
952
+ def inpaint(self, image, instruction, num_steps=50, device="cuda", return_pil=True, seed=0):
953
+ assert len(image.shape) == 4 and image.shape[0] == 1, "Input image should be a tensor object with batch size 1"
954
+ assert isinstance(instruction, str), "Input instruction should be a string"
955
+ assert self.model.conditioning_key == "hybrid", "Inpaint function is only available for hybrid conditioning"
956
+
957
+ image = image.to(device)
958
+ sampler = DDIMSampler(self, device=device)
959
+
960
+ seed_everything(seed)
961
+ with torch.no_grad():
962
+ with self.ema_scope():
963
+ c = self.get_first_stage_encoding(self.cond_stage_model.encode(image))
964
+ shape = c.shape[1:]
965
+ instruction_embedding = self.cond_stage_instruction_embedder(None, [instruction])
966
+ c = {'c_concat': c, 'c_crossattn': instruction_embedding}
967
+ batch_size=c["c_concat"].shape[0]
968
+ output_latent, _ = sampler.sample(S=num_steps,
969
+ conditioning=c,
970
+ batch_size=batch_size,
971
+ shape=shape,
972
+ verbose=False)
973
+ output_image_tensor = self.decode_first_stage(output_latent)[0]
974
+ output_image_tensor = torch.clip(output_image_tensor, -1, 1)
975
+ output_image_np = ((output_image_tensor + 1) * 127.5).cpu().numpy()
976
+ output_image = Image.fromarray(output_image_np.transpose(1,2,0).astype(np.uint8))
977
+
978
+ if return_pil:
979
+ return output_image
980
+ return output_image_tensor
981
+
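+ # Illustrative usage sketch (names and preprocessing are assumptions, not part of this file):
+ # image = preprocess(pil_image)  # e.g. a (1, 3, H, W) tensor scaled to [-1, 1]
+ # result = model.inpaint(image, "Remove the car", num_steps=50)  # returns a PIL image by default
+ # Requires hybrid conditioning and a batch size of 1, as asserted above.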
982
+ def configure_optimizers(self):
983
+ lr = self.learning_rate
984
+ params = list(self.model.parameters())
985
+ if self.cond_stage_trainable:
986
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
987
+ params = params + list(self.cond_stage_model.parameters())
988
+ if self.cond_stage_instruction_embedder_trainable:
989
+ print(f"{self.__class__.__name__}: Also optimizing conditionaer (instruction embedder) params!")
990
+ params = params + list(self.cond_stage_instruction_embedder.parameters())
991
+ if self.learn_logvar:
992
+ print('Diffusion model optimizing logvar')
993
+ params.append(self.logvar)
994
+ opt = torch.optim.AdamW(params, lr=lr)
995
+ if self.use_scheduler:
996
+ assert 'target' in self.scheduler_config
997
+ scheduler = instantiate_from_config(self.scheduler_config)
998
+
999
+ print("Setting up LambdaLR scheduler...")
1000
+ scheduler = [
1001
+ {
1002
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1003
+ 'interval': 'step',
1004
+ 'frequency': 1
1005
+ }]
1006
+ return [opt], scheduler
1007
+ return opt
1008
+
1009
+
1010
+ class DiffusionWrapper(pl.LightningModule):
1011
+ def __init__(self, diff_model_config, conditioning_key):
1012
+ super().__init__()
1013
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1014
+ self.conditioning_key = conditioning_key
1015
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1016
+ self.attn_dict = None
1017
+ self.keep_attn_maps = False
1018
+
1019
+ def keep_attn_map_dict(self, keep_attn_maps):
1020
+ self.keep_attn_maps = keep_attn_maps
1021
+ if keep_attn_maps:
1022
+ if self.attn_dict is None:
1023
+ self.attn_dict = {}
1024
+ else:
1025
+ self.attn_dict.clear()
1026
+ else:
1027
+ self.attn_dict = None
1028
+
1029
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_mask: list = None, index=None):
1030
+ if self.keep_attn_maps:
1031
+ assert index is not None
1032
+ if index not in self.attn_dict:
1033
+ self.attn_dict[index] = {}
1034
+ else:
1035
+ raise Exception("Attention maps of the current time index has already been assigned.")
1036
+ if self.conditioning_key is None:
1037
+ out = self.diffusion_model(x, t)
1038
+ elif self.conditioning_key == 'concat':
1039
+ if not isinstance(c_concat, list):
1040
+ c_concat = [c_concat]
1041
+ xc = torch.cat([x] + c_concat, dim=1)
1042
+ out = self.diffusion_model(xc, t)
1043
+ elif self.conditioning_key == 'crossattn':
1044
+ cc = torch.cat(c_crossattn, 1)
1045
+ if self.keep_attn_maps:
1046
+ out = self.diffusion_model(x, t, context=cc, attn_dict=self.attn_dict[index])
1047
+ else:
1048
+ out = self.diffusion_model(x, t, context=cc)
1049
+ elif self.conditioning_key == 'hybrid':
1050
+ if not isinstance(c_concat, list):
1051
+ c_concat = [c_concat]
1052
+ if not isinstance(c_crossattn, list):
1053
+ c_crossattn = [c_crossattn]
1054
+ xc = torch.cat([x] + c_concat, dim=1)
1055
+ cc = torch.cat(c_crossattn, 1)
1056
+ if self.keep_attn_maps:
1057
+ out = self.diffusion_model(xc, t, context=cc, attn_dict=self.attn_dict[index])
1058
+ else:
1059
+ out = self.diffusion_model(xc, t, context=cc)
1060
+ else:
1061
+ raise NotImplementedError()
1062
+ return out
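+ # Note: for 'hybrid' conditioning the caller is expected to pass something like
+ # cond = {'c_concat': image_latent, 'c_crossattn': instruction_embedding};
+ # the latent is concatenated to x along the channel axis and the instruction
+ # embedding is consumed by cross-attention, as in apply_model/get_input above.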
ldm/modules/attention.py ADDED
@@ -0,0 +1,271 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def __add_attention_to_dict(self, attn_dict, layer_type, attn):
171
+ layer_key = "{}_layer".format(layer_type)
172
+ if layer_key not in attn_dict:
173
+ attn_dict[layer_key] = []
174
+ attn_dict[layer_key].append(attn.cpu())
175
+
176
+ def forward(self, x, context=None, mask=None, attn_dict=None, layer_type=None):
177
+ # Dimensions of the Simge dataset are written in comments
178
+ h = self.heads
179
+ q = self.to_q(x)
180
+ context = default(context, x)
181
+
182
+ # NOTE: Rest of the dimensions are reported for the cross-attn case
183
+
184
+ k = self.to_k(context)
185
+ v = self.to_v(context)
186
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
187
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
188
+
189
+ if exists(mask):
190
+ mask = rearrange(mask, 'b ... -> b (...)')
191
+ max_neg_value = -torch.finfo(sim.dtype).max
192
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
193
+ sim.masked_fill_(~mask, max_neg_value)
194
+
195
+ # attention, what we cannot get enough of
196
+ attn = sim.softmax(dim=-1)
197
+
198
+ if context is not None and attn_dict is not None and layer_type is not None:
199
+ self.__add_attention_to_dict(attn_dict, layer_type, attn)
200
+
201
+ out = einsum('b i j, b j d -> b i d', attn, v)
202
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
203
+ return self.to_out(out)
204
+
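+ # Note: this is standard scaled dot-product attention, softmax(Q K^T / sqrt(dim_head)) V,
+ # computed per head by folding heads into the batch dimension ('b n (h d) -> (b h) n d');
+ # with context=None it degenerates to self-attention.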
205
+
206
+ class BasicTransformerBlock(nn.Module):
207
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
208
+ super().__init__()
209
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
210
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
211
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
212
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
213
+ self.norm1 = nn.LayerNorm(dim)
214
+ self.norm2 = nn.LayerNorm(dim)
215
+ self.norm3 = nn.LayerNorm(dim)
216
+ self.checkpoint = checkpoint
217
+
218
+ def forward(self, x, context=None, attn_dict=None, layer_type=None):
219
+ return checkpoint(self._forward, (x, context, attn_dict, layer_type), self.parameters(), self.checkpoint)
220
+
221
+ def _forward(self, x, context=None, attn_dict=None, layer_type=None):
222
+ x = self.attn1(self.norm1(x)) + x
223
+ x = self.attn2(self.norm2(x), context=context, attn_dict=attn_dict, layer_type=layer_type) + x
224
+ x = self.ff(self.norm3(x)) + x
225
+ return x
226
+
227
+
228
+ class SpatialTransformer(nn.Module):
229
+ """
230
+ Transformer block for image-like data.
231
+ First, project the input (aka embedding)
232
+ and reshape to b, t, d.
233
+ Then apply standard transformer action.
234
+ Finally, reshape to image
235
+ """
236
+ def __init__(self, in_channels, n_heads, d_head,
237
+ depth=1, dropout=0., context_dim=None):
238
+ super().__init__()
239
+ self.in_channels = in_channels
240
+ inner_dim = n_heads * d_head
241
+ self.norm = Normalize(in_channels)
242
+
243
+ self.proj_in = nn.Conv2d(in_channels,
244
+ inner_dim,
245
+ kernel_size=1,
246
+ stride=1,
247
+ padding=0)
248
+
249
+ self.transformer_blocks = nn.ModuleList(
250
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
251
+ for d in range(depth)]
252
+ )
253
+
254
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
255
+ in_channels,
256
+ kernel_size=1,
257
+ stride=1,
258
+ padding=0))
259
+
260
+ def forward(self, x, context=None, attn_dict=None, layer_type=None):
261
+ # note: if no context is given, cross-attention defaults to self-attention
262
+ b, c, h, w = x.shape
263
+ x_in = x
264
+ x = self.norm(x)
265
+ x = self.proj_in(x)
266
+ x = rearrange(x, 'b c h w -> b (h w) c')
267
+ for block in self.transformer_blocks:
268
+ x = block(x, context=context, attn_dict=attn_dict, layer_type=layer_type)
269
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
270
+ x = self.proj_out(x)
271
+ return x + x_in
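+ # Note: the overall flow is (b, c, h, w) -> 1x1 conv -> tokens of shape (b, h*w, inner_dim)
+ # -> BasicTransformerBlock (self-attention, cross-attention over `context`, feed-forward)
+ # -> tokens reshaped back to spatial form, projected to in_channels by the zero-initialized
+ # proj_out, and added to the residual x_in.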
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,836 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
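+ # Note: this builds the usual sinusoidal embedding: for frequencies
+ # f_i = exp(-log(10000) * i / (half_dim - 1)), the output is [sin(t * f_i), cos(t * f_i)]
+ # concatenated into a vector of size embedding_dim (zero-padded when embedding_dim is odd).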
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+
537
+ #assert z.shape[1:] == self.z_shape[1:]
538
+ self.last_z_shape = z.shape
539
+
540
+ # timestep embedding
541
+ temb = None
542
+
543
+ # z to block_in
544
+ h = self.conv_in(z)
545
+
546
+ # middle
547
+ h = self.mid.block_1(h, temb)
548
+ h = self.mid.attn_1(h)
549
+ h = self.mid.block_2(h, temb)
550
+
551
+ # upsampling
552
+ for i_level in reversed(range(self.num_resolutions)):
553
+ for i_block in range(self.num_res_blocks+1):
554
+ h = self.up[i_level].block[i_block](h, temb)
555
+ if len(self.up[i_level].attn) > 0:
556
+ h = self.up[i_level].attn[i_block](h)
557
+ if i_level != 0:
558
+ h = self.up[i_level].upsample(h)
559
+
560
+ # end
561
+ if self.give_pre_end:
562
+ return h
563
+
564
+ h = self.norm_out(h)
565
+ h = nonlinearity(h)
566
+ h = self.conv_out(h)
567
+ if self.tanh_out:
568
+ h = torch.tanh(h)
569
+ return h
570
+
571
+
572
+ class SimpleDecoder(nn.Module):
573
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
574
+ super().__init__()
575
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
576
+ ResnetBlock(in_channels=in_channels,
577
+ out_channels=2 * in_channels,
578
+ temb_channels=0, dropout=0.0),
579
+ ResnetBlock(in_channels=2 * in_channels,
580
+ out_channels=4 * in_channels,
581
+ temb_channels=0, dropout=0.0),
582
+ ResnetBlock(in_channels=4 * in_channels,
583
+ out_channels=2 * in_channels,
584
+ temb_channels=0, dropout=0.0),
585
+ nn.Conv2d(2*in_channels, in_channels, 1),
586
+ Upsample(in_channels, with_conv=True)])
587
+ # end
588
+ self.norm_out = Normalize(in_channels)
589
+ self.conv_out = torch.nn.Conv2d(in_channels,
590
+ out_channels,
591
+ kernel_size=3,
592
+ stride=1,
593
+ padding=1)
594
+
595
+ def forward(self, x):
596
+ for i, layer in enumerate(self.model):
597
+ if i in [1,2,3]:
598
+ x = layer(x, None)
599
+ else:
600
+ x = layer(x)
601
+
602
+ h = self.norm_out(x)
603
+ h = nonlinearity(h)
604
+ x = self.conv_out(h)
605
+ return x
606
+
607
+
608
+ class UpsampleDecoder(nn.Module):
609
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
610
+ ch_mult=(2,2), dropout=0.0):
611
+ super().__init__()
612
+ # upsampling
613
+ self.temb_ch = 0
614
+ self.num_resolutions = len(ch_mult)
615
+ self.num_res_blocks = num_res_blocks
616
+ block_in = in_channels
617
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
618
+ self.res_blocks = nn.ModuleList()
619
+ self.upsample_blocks = nn.ModuleList()
620
+ for i_level in range(self.num_resolutions):
621
+ res_block = []
622
+ block_out = ch * ch_mult[i_level]
623
+ for i_block in range(self.num_res_blocks + 1):
624
+ res_block.append(ResnetBlock(in_channels=block_in,
625
+ out_channels=block_out,
626
+ temb_channels=self.temb_ch,
627
+ dropout=dropout))
628
+ block_in = block_out
629
+ self.res_blocks.append(nn.ModuleList(res_block))
630
+ if i_level != self.num_resolutions - 1:
631
+ self.upsample_blocks.append(Upsample(block_in, True))
632
+ curr_res = curr_res * 2
633
+
634
+ # end
635
+ self.norm_out = Normalize(block_in)
636
+ self.conv_out = torch.nn.Conv2d(block_in,
637
+ out_channels,
638
+ kernel_size=3,
639
+ stride=1,
640
+ padding=1)
641
+
642
+ def forward(self, x):
643
+ # upsampling
644
+ h = x
645
+ for k, i_level in enumerate(range(self.num_resolutions)):
646
+ for i_block in range(self.num_res_blocks + 1):
647
+ h = self.res_blocks[i_level][i_block](h, None)
648
+ if i_level != self.num_resolutions - 1:
649
+ h = self.upsample_blocks[k](h)
650
+ h = self.norm_out(h)
651
+ h = nonlinearity(h)
652
+ h = self.conv_out(h)
653
+ return h
654
+
655
+
656
+ class LatentRescaler(nn.Module):
657
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
658
+ super().__init__()
659
+ # residual block, interpolate, residual block
660
+ self.factor = factor
661
+ self.conv_in = nn.Conv2d(in_channels,
662
+ mid_channels,
663
+ kernel_size=3,
664
+ stride=1,
665
+ padding=1)
666
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
667
+ out_channels=mid_channels,
668
+ temb_channels=0,
669
+ dropout=0.0) for _ in range(depth)])
670
+ self.attn = AttnBlock(mid_channels)
671
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
672
+ out_channels=mid_channels,
673
+ temb_channels=0,
674
+ dropout=0.0) for _ in range(depth)])
675
+
676
+ self.conv_out = nn.Conv2d(mid_channels,
677
+ out_channels,
678
+ kernel_size=1,
679
+ )
680
+
681
+ def forward(self, x):
682
+ x = self.conv_in(x)
683
+ for block in self.res_block1:
684
+ x = block(x, None)
685
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
686
+ x = self.attn(x)
687
+ for block in self.res_block2:
688
+ x = block(x, None)
689
+ x = self.conv_out(x)
690
+ return x
691
+
692
+
693
+ class MergedRescaleEncoder(nn.Module):
694
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
695
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
696
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
697
+ super().__init__()
698
+ intermediate_chn = ch * ch_mult[-1]
699
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
700
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
701
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
702
+ out_ch=None)
703
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
704
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
705
+
706
+ def forward(self, x):
707
+ x = self.encoder(x)
708
+ x = self.rescaler(x)
709
+ return x
710
+
711
+
712
+ class MergedRescaleDecoder(nn.Module):
713
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
714
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
715
+ super().__init__()
716
+ tmp_chn = z_channels*ch_mult[-1]
717
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
718
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
719
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
720
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
721
+ out_channels=tmp_chn, depth=rescale_module_depth)
722
+
723
+ def forward(self, x):
724
+ x = self.rescaler(x)
725
+ x = self.decoder(x)
726
+ return x
727
+
728
+
729
+ class Upsampler(nn.Module):
730
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
731
+ super().__init__()
732
+ assert out_size >= in_size
733
+ num_blocks = int(np.log2(out_size//in_size))+1
734
+ factor_up = 1.+ (out_size % in_size)
735
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
736
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
737
+ out_channels=in_channels)
738
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
739
+ attn_resolutions=[], in_channels=None, ch=in_channels,
740
+ ch_mult=[ch_mult for _ in range(num_blocks)])
741
+
742
+ def forward(self, x):
743
+ x = self.rescaler(x)
744
+ x = self.decoder(x)
745
+ return x
746
+
747
+
748
+ class Resize(nn.Module):
749
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
750
+ super().__init__()
751
+ self.with_conv = learned
752
+ self.mode = mode
753
+ if self.with_conv:
754
+ print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
755
+ raise NotImplementedError()
756
+ assert in_channels is not None
757
+ # no asymmetric padding in torch conv, must do it ourselves
758
+ self.conv = torch.nn.Conv2d(in_channels,
759
+ in_channels,
760
+ kernel_size=4,
761
+ stride=2,
762
+ padding=1)
763
+
764
+ def forward(self, x, scale_factor=1.0):
765
+ if scale_factor==1.0:
766
+ return x
767
+ else:
768
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
769
+ return x
770
+
771
+ class FirstStagePostProcessor(nn.Module):
772
+
773
+ def __init__(self, ch_mult:list, in_channels,
774
+ pretrained_model:nn.Module=None,
775
+ reshape=False,
776
+ n_channels=None,
777
+ dropout=0.,
778
+ pretrained_config=None):
779
+ super().__init__()
780
+ if pretrained_config is None:
781
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
782
+ self.pretrained_model = pretrained_model
783
+ else:
784
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
785
+ self.instantiate_pretrained(pretrained_config)
786
+
787
+ self.do_reshape = reshape
788
+
789
+ if n_channels is None:
790
+ n_channels = self.pretrained_model.encoder.ch
791
+
792
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
793
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
794
+ stride=1,padding=1)
795
+
796
+ blocks = []
797
+ downs = []
798
+ ch_in = n_channels
799
+ for m in ch_mult:
800
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
801
+ ch_in = m * n_channels
802
+ downs.append(Downsample(ch_in, with_conv=False))
803
+
804
+ self.model = nn.ModuleList(blocks)
805
+ self.downsampler = nn.ModuleList(downs)
806
+
807
+
808
+ def instantiate_pretrained(self, config):
809
+ model = instantiate_from_config(config)
810
+ self.pretrained_model = model.eval()
811
+ # self.pretrained_model.train = False
812
+ for param in self.pretrained_model.parameters():
813
+ param.requires_grad = False
814
+
815
+
816
+ @torch.no_grad()
817
+ def encode_with_pretrained(self,x):
818
+ c = self.pretrained_model.encode(x)
819
+ if isinstance(c, DiagonalGaussianDistribution):
820
+ c = c.mode()
821
+ return c
822
+
823
+ def forward(self,x):
824
+ z_fs = self.encode_with_pretrained(x)
825
+ z = self.proj_norm(z_fs)
826
+ z = self.proj(z)
827
+ z = nonlinearity(z)
828
+
829
+ for submodel, downmodel in zip(self.model,self.downsampler):
830
+ z = submodel(z,temb=None)
831
+ z = downmodel(z)
832
+
833
+ if self.do_reshape:
834
+ z = rearrange(z,'b c h w -> b (h w) c')
835
+ return z
836
+
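A minimal usage sketch for the Decoder defined above, assuming illustrative KL-f8-style hyperparameters and the module path suggested by this diff (the real values come from the autoencoder config shipped with the checkpoint):

import torch
from ldm.modules.diffusionmodules.model import Decoder  # path assumed from this diff

# Illustrative settings: 4-channel latents, 256x256 images, downsampling factor 8.
decoder = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                  attn_resolutions=[], in_channels=3, resolution=256,
                  z_channels=4)

z = torch.randn(1, 4, 32, 32)   # latent at resolution // 2**(len(ch_mult)-1)
with torch.no_grad():
    img = decoder(z)            # -> [1, 3, 256, 256]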
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,973 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ """
64
+ Any module where forward() takes timestep embeddings as a second argument.
65
+ """
66
+
67
+ @abstractmethod
68
+ def forward(self, x, emb):
69
+ """
70
+ Apply the module to `x` given `emb` timestep embeddings.
71
+ """
72
+
73
+
74
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
+ """
76
+ A sequential module that passes timestep embeddings to the children that
77
+ support it as an extra input.
78
+ """
79
+
80
+ def forward(self, x, emb, context=None, attn_dict=None, layer_type=None):
81
+ for layer in self:
82
+ if isinstance(layer, TimestepBlock):
83
+ x = layer(x, emb)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context, attn_dict=attn_dict, layer_type=layer_type)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
90
+
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ 'Learned 2x upsampling without padding'
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
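As the docstrings note, both attention helpers consume a fused qkv tensor of shape [N x (3 * H * C) x T]; a short sketch of that calling convention with illustrative sizes (assuming the two classes are imported from this module):

import torch as th

n_heads, head_ch, tokens = 4, 32, 64
qkv = th.randn(2, 3 * n_heads * head_ch, tokens)

out_legacy = QKVAttentionLegacy(n_heads)(qkv)   # -> [2, n_heads * head_ch, tokens]
out_new = QKVAttention(n_heads)(qkv)            # same shape, heads split in a different order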
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size,
446
+ in_channels,
447
+ model_channels,
448
+ out_channels,
449
+ num_res_blocks,
450
+ attention_resolutions,
451
+ dropout=0,
452
+ channel_mult=(1, 2, 4, 8),
453
+ conv_resample=True,
454
+ dims=2,
455
+ num_classes=None,
456
+ use_checkpoint=False,
457
+ use_fp16=False,
458
+ num_heads=-1,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True
469
+ ):
470
+ super().__init__()
471
+ if use_spatial_transformer:
472
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
+
474
+ if context_dim is not None:
475
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
+ from omegaconf.listconfig import ListConfig
477
+ if type(context_dim) == ListConfig:
478
+ context_dim = list(context_dim)
479
+
480
+ if num_heads_upsample == -1:
481
+ num_heads_upsample = num_heads
482
+
483
+ if num_heads == -1:
484
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
485
+
486
+ if num_head_channels == -1:
487
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ self.image_size = image_size
490
+ self.in_channels = in_channels
491
+ self.model_channels = model_channels
492
+ self.out_channels = out_channels
493
+ self.num_res_blocks = num_res_blocks
494
+ self.attention_resolutions = attention_resolutions
495
+ self.dropout = dropout
496
+ self.channel_mult = channel_mult
497
+ self.conv_resample = conv_resample
498
+ self.num_classes = num_classes
499
+ self.use_checkpoint = use_checkpoint
500
+ self.dtype = th.float16 if use_fp16 else th.float32
501
+ self.num_heads = num_heads
502
+ self.num_head_channels = num_head_channels
503
+ self.num_heads_upsample = num_heads_upsample
504
+ self.predict_codebook_ids = n_embed is not None
505
+
506
+ time_embed_dim = model_channels * 4
507
+ self.time_embed = nn.Sequential(
508
+ linear(model_channels, time_embed_dim),
509
+ nn.SiLU(),
510
+ linear(time_embed_dim, time_embed_dim),
511
+ )
512
+
513
+ if self.num_classes is not None:
514
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
515
+
516
+ self.input_blocks = nn.ModuleList(
517
+ [
518
+ TimestepEmbedSequential(
519
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
520
+ )
521
+ ]
522
+ )
523
+ self._feature_size = model_channels
524
+ input_block_chans = [model_channels]
525
+ ch = model_channels
526
+ ds = 1
527
+ for level, mult in enumerate(channel_mult):
528
+ for _ in range(num_res_blocks):
529
+ layers = [
530
+ ResBlock(
531
+ ch,
532
+ time_embed_dim,
533
+ dropout,
534
+ out_channels=mult * model_channels,
535
+ dims=dims,
536
+ use_checkpoint=use_checkpoint,
537
+ use_scale_shift_norm=use_scale_shift_norm,
538
+ )
539
+ ]
540
+ ch = mult * model_channels
541
+ if ds in attention_resolutions:
542
+ if num_head_channels == -1:
543
+ dim_head = ch // num_heads
544
+ else:
545
+ num_heads = ch // num_head_channels
546
+ dim_head = num_head_channels
547
+ if legacy:
548
+ #num_heads = 1
549
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
550
+ layers.append(
551
+ AttentionBlock(
552
+ ch,
553
+ use_checkpoint=use_checkpoint,
554
+ num_heads=num_heads,
555
+ num_head_channels=dim_head,
556
+ use_new_attention_order=use_new_attention_order,
557
+ ) if not use_spatial_transformer else SpatialTransformer(
558
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
559
+ )
560
+ )
561
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
562
+ self._feature_size += ch
563
+ input_block_chans.append(ch)
564
+ if level != len(channel_mult) - 1:
565
+ out_ch = ch
566
+ self.input_blocks.append(
567
+ TimestepEmbedSequential(
568
+ ResBlock(
569
+ ch,
570
+ time_embed_dim,
571
+ dropout,
572
+ out_channels=out_ch,
573
+ dims=dims,
574
+ use_checkpoint=use_checkpoint,
575
+ use_scale_shift_norm=use_scale_shift_norm,
576
+ down=True,
577
+ )
578
+ if resblock_updown
579
+ else Downsample(
580
+ ch, conv_resample, dims=dims, out_channels=out_ch
581
+ )
582
+ )
583
+ )
584
+ ch = out_ch
585
+ input_block_chans.append(ch)
586
+ ds *= 2
587
+ self._feature_size += ch
588
+
589
+ if num_head_channels == -1:
590
+ dim_head = ch // num_heads
591
+ else:
592
+ num_heads = ch // num_head_channels
593
+ dim_head = num_head_channels
594
+ if legacy:
595
+ #num_heads = 1
596
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
597
+ self.middle_block = TimestepEmbedSequential(
598
+ ResBlock(
599
+ ch,
600
+ time_embed_dim,
601
+ dropout,
602
+ dims=dims,
603
+ use_checkpoint=use_checkpoint,
604
+ use_scale_shift_norm=use_scale_shift_norm,
605
+ ),
606
+ AttentionBlock(
607
+ ch,
608
+ use_checkpoint=use_checkpoint,
609
+ num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
+ use_new_attention_order=use_new_attention_order,
612
+ ) if not use_spatial_transformer else SpatialTransformer(
613
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
614
+ ),
615
+ ResBlock(
616
+ ch,
617
+ time_embed_dim,
618
+ dropout,
619
+ dims=dims,
620
+ use_checkpoint=use_checkpoint,
621
+ use_scale_shift_norm=use_scale_shift_norm,
622
+ ),
623
+ )
624
+ self._feature_size += ch
625
+
626
+ self.output_blocks = nn.ModuleList([])
627
+ for level, mult in list(enumerate(channel_mult))[::-1]:
628
+ for i in range(num_res_blocks + 1):
629
+ ich = input_block_chans.pop()
630
+ layers = [
631
+ ResBlock(
632
+ ch + ich,
633
+ time_embed_dim,
634
+ dropout,
635
+ out_channels=model_channels * mult,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ )
640
+ ]
641
+ ch = model_channels * mult
642
+ if ds in attention_resolutions:
643
+ if num_head_channels == -1:
644
+ dim_head = ch // num_heads
645
+ else:
646
+ num_heads = ch // num_head_channels
647
+ dim_head = num_head_channels
648
+ if legacy:
649
+ #num_heads = 1
650
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
651
+ layers.append(
652
+ AttentionBlock(
653
+ ch,
654
+ use_checkpoint=use_checkpoint,
655
+ num_heads=num_heads_upsample,
656
+ num_head_channels=dim_head,
657
+ use_new_attention_order=use_new_attention_order,
658
+ ) if not use_spatial_transformer else SpatialTransformer(
659
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
660
+ )
661
+ )
662
+ if level and i == num_res_blocks:
663
+ out_ch = ch
664
+ layers.append(
665
+ ResBlock(
666
+ ch,
667
+ time_embed_dim,
668
+ dropout,
669
+ out_channels=out_ch,
670
+ dims=dims,
671
+ use_checkpoint=use_checkpoint,
672
+ use_scale_shift_norm=use_scale_shift_norm,
673
+ up=True,
674
+ )
675
+ if resblock_updown
676
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
677
+ )
678
+ ds //= 2
679
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
680
+ self._feature_size += ch
681
+
682
+ self.out = nn.Sequential(
683
+ normalization(ch),
684
+ nn.SiLU(),
685
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
686
+ )
687
+ if self.predict_codebook_ids:
688
+ self.id_predictor = nn.Sequential(
689
+ normalization(ch),
690
+ conv_nd(dims, model_channels, n_embed, 1),
691
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
692
+ )
693
+
694
+ def convert_to_fp16(self):
695
+ """
696
+ Convert the torso of the model to float16.
697
+ """
698
+ self.input_blocks.apply(convert_module_to_f16)
699
+ self.middle_block.apply(convert_module_to_f16)
700
+ self.output_blocks.apply(convert_module_to_f16)
701
+
702
+ def convert_to_fp32(self):
703
+ """
704
+ Convert the torso of the model to float32.
705
+ """
706
+ self.input_blocks.apply(convert_module_to_f32)
707
+ self.middle_block.apply(convert_module_to_f32)
708
+ self.output_blocks.apply(convert_module_to_f32)
709
+
710
+ def forward(self, x, timesteps=None, context=None, y=None, attn_dict=None, **kwargs):
711
+ # attn_dict is used to return the attention values for visualization
712
+ """
713
+ Apply the model to an input batch.
714
+ :param x: an [N x C x ...] Tensor of inputs.
715
+ :param timesteps: a 1-D batch of timesteps.
716
+ :param context: conditioning plugged in via crossattn
717
+ :param y: an [N] Tensor of labels, if class-conditional.
718
+ :return: an [N x C x ...] Tensor of outputs.
719
+ """
720
+ keep_attns = attn_dict is not None
721
+
722
+ assert (y is not None) == (
723
+ self.num_classes is not None
724
+ ), "must specify y if and only if the model is class-conditional"
725
+ hs = []
726
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
727
+ emb = self.time_embed(t_emb)
728
+
729
+ if self.num_classes is not None:
730
+ assert y.shape == (x.shape[0],)
731
+ emb = emb + self.label_emb(y)
732
+
733
+ h = x.type(self.dtype)
734
+ for module in self.input_blocks:
735
+ if keep_attns:
736
+ h = module(h, emb, context, attn_dict=attn_dict, layer_type="down")
737
+ else:
738
+ h = module(h, emb, context)
739
+ hs.append(h)
740
+ if keep_attns:
741
+ h = self.middle_block(h, emb, context, attn_dict=attn_dict, layer_type="middle")
742
+ else:
743
+ h = self.middle_block(h, emb, context)
744
+ for module in self.output_blocks:
745
+ h = th.cat([h, hs.pop()], dim=1)
746
+ if keep_attns:
747
+ h = module(h, emb, context, attn_dict=attn_dict, layer_type="up")
748
+ else:
749
+ h = module(h, emb, context)
750
+ h = h.type(x.dtype)
751
+ if self.predict_codebook_ids:
752
+ return self.id_predictor(h)
753
+ else:
754
+ return self.out(h)
755
+
756
+
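A hedged sketch of driving the UNetModel forward pass above, including the attn_dict hook this diff threads through TimestepEmbedSequential for attention visualization. All sizes are illustrative, and the keys SpatialTransformer writes into attn_dict are defined elsewhere in this repo, so only the calling convention is shown:

import torch as th

unet = UNetModel(
    image_size=32, in_channels=4, model_channels=64, out_channels=4,
    num_res_blocks=1, attention_resolutions=(2,), channel_mult=(1, 2),
    num_heads=4, use_spatial_transformer=True, transformer_depth=1,
    context_dim=512,                  # illustrative conditioning width
)

x = th.randn(1, 4, 32, 32)            # latent input
t = th.tensor([500])                  # diffusion timestep
context = th.randn(1, 77, 512)        # cross-attention conditioning

attn_dict = {}                        # populated by the SpatialTransformer layers
eps = unet(x, timesteps=t, context=context, attn_dict=attn_dict)   # -> [1, 4, 32, 32]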
757
+ class EncoderUNetModel(nn.Module):
758
+ """
759
+ The half UNet model with attention and timestep embedding.
760
+ For usage, see UNet.
761
+ """
762
+
763
+ def __init__(
764
+ self,
765
+ image_size,
766
+ in_channels,
767
+ model_channels,
768
+ out_channels,
769
+ num_res_blocks,
770
+ attention_resolutions,
771
+ dropout=0,
772
+ channel_mult=(1, 2, 4, 8),
773
+ conv_resample=True,
774
+ dims=2,
775
+ use_checkpoint=False,
776
+ use_fp16=False,
777
+ num_heads=1,
778
+ num_head_channels=-1,
779
+ num_heads_upsample=-1,
780
+ use_scale_shift_norm=False,
781
+ resblock_updown=False,
782
+ use_new_attention_order=False,
783
+ pool="adaptive",
784
+ *args,
785
+ **kwargs
786
+ ):
787
+ super().__init__()
788
+
789
+ if num_heads_upsample == -1:
790
+ num_heads_upsample = num_heads
791
+
792
+ self.in_channels = in_channels
793
+ self.model_channels = model_channels
794
+ self.out_channels = out_channels
795
+ self.num_res_blocks = num_res_blocks
796
+ self.attention_resolutions = attention_resolutions
797
+ self.dropout = dropout
798
+ self.channel_mult = channel_mult
799
+ self.conv_resample = conv_resample
800
+ self.use_checkpoint = use_checkpoint
801
+ self.dtype = th.float16 if use_fp16 else th.float32
802
+ self.num_heads = num_heads
803
+ self.num_head_channels = num_head_channels
804
+ self.num_heads_upsample = num_heads_upsample
805
+
806
+ time_embed_dim = model_channels * 4
807
+ self.time_embed = nn.Sequential(
808
+ linear(model_channels, time_embed_dim),
809
+ nn.SiLU(),
810
+ linear(time_embed_dim, time_embed_dim),
811
+ )
812
+
813
+ self.input_blocks = nn.ModuleList(
814
+ [
815
+ TimestepEmbedSequential(
816
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
817
+ )
818
+ ]
819
+ )
820
+ self._feature_size = model_channels
821
+ input_block_chans = [model_channels]
822
+ ch = model_channels
823
+ ds = 1
824
+ for level, mult in enumerate(channel_mult):
825
+ for _ in range(num_res_blocks):
826
+ layers = [
827
+ ResBlock(
828
+ ch,
829
+ time_embed_dim,
830
+ dropout,
831
+ out_channels=mult * model_channels,
832
+ dims=dims,
833
+ use_checkpoint=use_checkpoint,
834
+ use_scale_shift_norm=use_scale_shift_norm,
835
+ )
836
+ ]
837
+ ch = mult * model_channels
838
+ if ds in attention_resolutions:
839
+ layers.append(
840
+ AttentionBlock(
841
+ ch,
842
+ use_checkpoint=use_checkpoint,
843
+ num_heads=num_heads,
844
+ num_head_channels=num_head_channels,
845
+ use_new_attention_order=use_new_attention_order,
846
+ )
847
+ )
848
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
849
+ self._feature_size += ch
850
+ input_block_chans.append(ch)
851
+ if level != len(channel_mult) - 1:
852
+ out_ch = ch
853
+ self.input_blocks.append(
854
+ TimestepEmbedSequential(
855
+ ResBlock(
856
+ ch,
857
+ time_embed_dim,
858
+ dropout,
859
+ out_channels=out_ch,
860
+ dims=dims,
861
+ use_checkpoint=use_checkpoint,
862
+ use_scale_shift_norm=use_scale_shift_norm,
863
+ down=True,
864
+ )
865
+ if resblock_updown
866
+ else Downsample(
867
+ ch, conv_resample, dims=dims, out_channels=out_ch
868
+ )
869
+ )
870
+ )
871
+ ch = out_ch
872
+ input_block_chans.append(ch)
873
+ ds *= 2
874
+ self._feature_size += ch
875
+
876
+ self.middle_block = TimestepEmbedSequential(
877
+ ResBlock(
878
+ ch,
879
+ time_embed_dim,
880
+ dropout,
881
+ dims=dims,
882
+ use_checkpoint=use_checkpoint,
883
+ use_scale_shift_norm=use_scale_shift_norm,
884
+ ),
885
+ AttentionBlock(
886
+ ch,
887
+ use_checkpoint=use_checkpoint,
888
+ num_heads=num_heads,
889
+ num_head_channels=num_head_channels,
890
+ use_new_attention_order=use_new_attention_order,
891
+ ),
892
+ ResBlock(
893
+ ch,
894
+ time_embed_dim,
895
+ dropout,
896
+ dims=dims,
897
+ use_checkpoint=use_checkpoint,
898
+ use_scale_shift_norm=use_scale_shift_norm,
899
+ ),
900
+ )
901
+ self._feature_size += ch
902
+ self.pool = pool
903
+ if pool == "adaptive":
904
+ self.out = nn.Sequential(
905
+ normalization(ch),
906
+ nn.SiLU(),
907
+ nn.AdaptiveAvgPool2d((1, 1)),
908
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
909
+ nn.Flatten(),
910
+ )
911
+ elif pool == "attention":
912
+ assert num_head_channels != -1
913
+ self.out = nn.Sequential(
914
+ normalization(ch),
915
+ nn.SiLU(),
916
+ AttentionPool2d(
917
+ (image_size // ds), ch, num_head_channels, out_channels
918
+ ),
919
+ )
920
+ elif pool == "spatial":
921
+ self.out = nn.Sequential(
922
+ nn.Linear(self._feature_size, 2048),
923
+ nn.ReLU(),
924
+ nn.Linear(2048, self.out_channels),
925
+ )
926
+ elif pool == "spatial_v2":
927
+ self.out = nn.Sequential(
928
+ nn.Linear(self._feature_size, 2048),
929
+ normalization(2048),
930
+ nn.SiLU(),
931
+ nn.Linear(2048, self.out_channels),
932
+ )
933
+ else:
934
+ raise NotImplementedError(f"Unexpected {pool} pooling")
935
+
936
+ def convert_to_fp16(self):
937
+ """
938
+ Convert the torso of the model to float16.
939
+ """
940
+ self.input_blocks.apply(convert_module_to_f16)
941
+ self.middle_block.apply(convert_module_to_f16)
942
+
943
+ def convert_to_fp32(self):
944
+ """
945
+ Convert the torso of the model to float32.
946
+ """
947
+ self.input_blocks.apply(convert_module_to_f32)
948
+ self.middle_block.apply(convert_module_to_f32)
949
+
950
+ def forward(self, x, timesteps):
951
+ """
952
+ Apply the model to an input batch.
953
+ :param x: an [N x C x ...] Tensor of inputs.
954
+ :param timesteps: a 1-D batch of timesteps.
955
+ :return: an [N x K] Tensor of outputs.
956
+ """
957
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
958
+
959
+ results = []
960
+ h = x.type(self.dtype)
961
+ for module in self.input_blocks:
962
+ h = module(h, emb)
963
+ if self.pool.startswith("spatial"):
964
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
965
+ h = self.middle_block(h, emb)
966
+ if self.pool.startswith("spatial"):
967
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
968
+ h = th.cat(results, axis=-1)
969
+ return self.out(h)
970
+ else:
971
+ h = h.type(x.dtype)
972
+ return self.out(h)
973
+
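For completeness, a minimal sketch of the half-UNet encoder above with adaptive pooling; the sizes are illustrative and not taken from any config in this repo:

import torch as th

enc = EncoderUNetModel(
    image_size=64, in_channels=3, model_channels=32, out_channels=10,
    num_res_blocks=1, attention_resolutions=(4,), channel_mult=(1, 2, 4),
    pool="adaptive",
)

x = th.randn(2, 3, 64, 64)
t = th.randint(0, 1000, (2,))
logits = enc(x, t)                    # -> [2, 10]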
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,270 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+
114
+ # For storing the attention maps, a dict object is sent to the forward function.
115
+ # To avoid an error in the detach operation of the backward pass below, dict and None objects are discarded.
116
+ inputs = [x for x in inputs if not isinstance(x, dict) and x is not None]
117
+
118
+ args = tuple(inputs) + tuple(params)
119
+ return CheckpointFunction.apply(func, len(inputs), *args)
120
+ else:
121
+ return func(*inputs)
122
+
123
+ class CheckpointFunction(torch.autograd.Function):
124
+ @staticmethod
125
+ def forward(ctx, run_function, length, *args):
126
+ ctx.run_function = run_function
127
+ ctx.input_tensors = list(args[:length])
128
+ ctx.input_params = list(args[length:])
129
+ with torch.no_grad():
130
+ output_tensors = ctx.run_function(*ctx.input_tensors)
131
+ return output_tensors
132
+
133
+ @staticmethod
134
+ def backward(ctx, *output_grads):
135
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
136
+ with torch.enable_grad():
137
+ # Fixes a bug where the first op in run_function modifies the
138
+ # Tensor storage in place, which is not allowed for detach()'d
139
+ # Tensors.
140
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141
+ output_tensors = ctx.run_function(*shallow_copies)
142
+ input_grads = torch.autograd.grad(
143
+ output_tensors,
144
+ ctx.input_tensors + ctx.input_params,
145
+ output_grads,
146
+ allow_unused=True,
147
+ )
148
+ del ctx.input_tensors
149
+ del ctx.input_params
150
+ del output_tensors
151
+ return (None, None) + input_grads
152
+
153
+
154
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
155
+ """
156
+ Create sinusoidal timestep embeddings.
157
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
158
+ These may be fractional.
159
+ :param dim: the dimension of the output.
160
+ :param max_period: controls the minimum frequency of the embeddings.
161
+ :return: an [N x dim] Tensor of positional embeddings.
162
+ """
163
+ if not repeat_only:
164
+ half = dim // 2
165
+ freqs = torch.exp(
166
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167
+ ).to(device=timesteps.device)
168
+ args = timesteps[:, None].float() * freqs[None]
169
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170
+ if dim % 2:
171
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172
+ else:
173
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
174
+ return embedding
175
+
176
+
177
+ def zero_module(module):
178
+ """
179
+ Zero out the parameters of a module and return it.
180
+ """
181
+ for p in module.parameters():
182
+ p.detach().zero_()
183
+ return module
184
+
185
+
186
+ def scale_module(module, scale):
187
+ """
188
+ Scale the parameters of a module and return it.
189
+ """
190
+ for p in module.parameters():
191
+ p.detach().mul_(scale)
192
+ return module
193
+
194
+
195
+ def mean_flat(tensor):
196
+ """
197
+ Take the mean over all non-batch dimensions.
198
+ """
199
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
200
+
201
+
202
+ def normalization(channels):
203
+ """
204
+ Make a standard normalization layer.
205
+ :param channels: number of input channels.
206
+ :return: an nn.Module for normalization.
207
+ """
208
+ return GroupNorm32(32, channels)
209
+
210
+
211
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
212
+ class SiLU(nn.Module):
213
+ def forward(self, x):
214
+ return x * torch.sigmoid(x)
215
+
216
+
217
+ class GroupNorm32(nn.GroupNorm):
218
+ def forward(self, x):
219
+ return super().forward(x.float()).type(x.dtype)
220
+
221
+ def conv_nd(dims, *args, **kwargs):
222
+ """
223
+ Create a 1D, 2D, or 3D convolution module.
224
+ """
225
+ if dims == 1:
226
+ return nn.Conv1d(*args, **kwargs)
227
+ elif dims == 2:
228
+ return nn.Conv2d(*args, **kwargs)
229
+ elif dims == 3:
230
+ return nn.Conv3d(*args, **kwargs)
231
+ raise ValueError(f"unsupported dimensions: {dims}")
232
+
233
+
234
+ def linear(*args, **kwargs):
235
+ """
236
+ Create a linear module.
237
+ """
238
+ return nn.Linear(*args, **kwargs)
239
+
240
+
241
+ def avg_pool_nd(dims, *args, **kwargs):
242
+ """
243
+ Create a 1D, 2D, or 3D average pooling module.
244
+ """
245
+ if dims == 1:
246
+ return nn.AvgPool1d(*args, **kwargs)
247
+ elif dims == 2:
248
+ return nn.AvgPool2d(*args, **kwargs)
249
+ elif dims == 3:
250
+ return nn.AvgPool3d(*args, **kwargs)
251
+ raise ValueError(f"unsupported dimensions: {dims}")
252
+
253
+
254
+ class HybridConditioner(nn.Module):
255
+
256
+ def __init__(self, c_concat_config, c_crossattn_config):
257
+ super().__init__()
258
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
259
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
260
+
261
+ def forward(self, c_concat, c_crossattn):
262
+ c_concat = self.concat_conditioner(c_concat)
263
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
264
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
265
+
266
+
267
+ def noise_like(shape, device, repeat=False):
268
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
269
+ noise = lambda: torch.randn(shape, device=device)
270
+ return repeat_noise() if repeat else noise()
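A short sketch of the schedule and embedding helpers defined above, with illustrative arguments (assuming they are imported from ldm.modules.diffusionmodules.util):

import torch

betas = make_beta_schedule("linear", n_timestep=1000)            # numpy array, length 1000
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=10,
                                 num_ddpm_timesteps=1000, verbose=False)

t = torch.tensor([0, 500, 999])
emb = timestep_embedding(t, dim=128)                             # -> [3, 128] sinusoidal embeddings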
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import numpy as np
+
+
+ class AbstractDistribution:
+     def sample(self):
+         raise NotImplementedError()
+
+     def mode(self):
+         raise NotImplementedError()
+
+
+ class DiracDistribution(AbstractDistribution):
+     def __init__(self, value):
+         self.value = value
+
+     def sample(self):
+         return self.value
+
+     def mode(self):
+         return self.value
+
+
+ class DiagonalGaussianDistribution(object):
+     def __init__(self, parameters, deterministic=False):
+         self.parameters = parameters
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+     def sample(self):
+         x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+         return x
+
+     def kl(self, other=None):
+         if self.deterministic:
+             return torch.Tensor([0.])
+         else:
+             if other is None:
+                 return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                        + self.var - 1.0 - self.logvar,
+                                        dim=[1, 2, 3])
+             else:
+                 return 0.5 * torch.sum(
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                     dim=[1, 2, 3])
+
+     def nll(self, sample, dims=[1,2,3]):
+         if self.deterministic:
+             return torch.Tensor([0.])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims)
+
+     def mode(self):
+         return self.mean
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+     Compute the KL divergence between two gaussians.
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, torch.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for torch.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+         for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + torch.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+     )
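For reference, a short sketch (not part of the diff) of how `DiagonalGaussianDistribution` is typically used: the parameter tensor is split along dim=1 into means and log-variances, so an 8-channel input yields a 4-channel latent.

    import torch
    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    params = torch.randn(2, 8, 16, 16)            # 4 mean + 4 log-variance channels
    posterior = DiagonalGaussianDistribution(params)
    z = posterior.sample()                        # shape (2, 4, 16, 16)
    kl = posterior.kl()                           # KL to a standard normal, shape (2,)
    nll = posterior.nll(z)                        # negative log-likelihood, shape (2,)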
ldm/modules/ema.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ from torch import nn
+
+
+ class LitEma(nn.Module):
+     def __init__(self, model, decay=0.9999, use_num_upates=True):
+         super().__init__()
+         if decay < 0.0 or decay > 1.0:
+             raise ValueError('Decay must be between 0 and 1')
+
+         self.m_name2s_name = {}
+         self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+         self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+                              else torch.tensor(-1,dtype=torch.int))
+
+         for name, p in model.named_parameters():
+             if p.requires_grad:
+                 #remove as '.'-character is not allowed in buffers
+                 s_name = name.replace('.','')
+                 self.m_name2s_name.update({name:s_name})
+                 self.register_buffer(s_name,p.clone().detach().data)
+
+         self.collected_params = []
+
+     def forward(self,model):
+         decay = self.decay
+
+         if self.num_updates >= 0:
+             self.num_updates += 1
+             decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+         one_minus_decay = 1.0 - decay
+
+         with torch.no_grad():
+             m_param = dict(model.named_parameters())
+             shadow_params = dict(self.named_buffers())
+
+             for key in m_param:
+                 if m_param[key].requires_grad:
+                     sname = self.m_name2s_name[key]
+                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                     shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                 else:
+                     assert not key in self.m_name2s_name
+
+     def copy_to(self, model):
+         m_param = dict(model.named_parameters())
+         shadow_params = dict(self.named_buffers())
+         for key in m_param:
+             if m_param[key].requires_grad:
+                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+             else:
+                 assert not key in self.m_name2s_name
+
+     def store(self, parameters):
+         """
+         Save the current parameters for restoring later.
+         Args:
+           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+             temporarily stored.
+         """
+         self.collected_params = [param.clone() for param in parameters]
+
+     def restore(self, parameters):
+         """
+         Restore the parameters stored with the `store` method.
+         Useful to validate the model with EMA parameters without affecting the
+         original optimization process. Store the parameters before the
+         `copy_to` method. After validation (or model saving), use this to
+         restore the former parameters.
+         Args:
+           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+             updated with the stored parameters.
+         """
+         for c_param, param in zip(self.collected_params, parameters):
+             param.data.copy_(c_param.data)
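A brief sketch (not part of the diff) of the intended LitEma workflow with a plain nn.Module: update the shadow weights after each optimizer step, then temporarily swap them in for evaluation.

    import torch
    from torch import nn
    from ldm.modules.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.999)

    # after every optimizer step:
    ema(model)                        # update the exponential moving average

    # evaluate with EMA weights, then restore the live ones
    ema.store(model.parameters())
    ema.copy_to(model)
    # ... run validation here ...
    ema.restore(model.parameters())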
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ import torch.nn as nn
+ from ldm.modules.x_transformer import Encoder, TransformerWrapper
+
+
+ class BERTTokenizer(nn.Module):
+     def __init__(self, vq_interface=True, max_length=77):
+         super().__init__()
+         from transformers import BertTokenizerFast
+         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+         self.vq_interface = vq_interface
+         self.max_length = max_length
+
+     def forward(self, text, return_batch_encoding=False):
+         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+         tokens = batch_encoding["input_ids"]
+         if return_batch_encoding:
+             return tokens, batch_encoding
+         return tokens
+
+     @torch.no_grad()
+     def encode(self, text):
+         tokens = self(text)
+         if not self.vq_interface:
+             return tokens
+         return None, None, [None, None, tokens]
+
+     def decode(self, text):
+         return text
+
+ class BERTEmbedder(nn.Module):
+     """Uses the BERT tokenizer model and adds some transformer encoder layers"""
+     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, use_tokenizer=True, embedding_dropout=0.0):
+         super().__init__()
+         self.use_tknz_fn = use_tokenizer
+         if self.use_tknz_fn:
+             self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+                                               attn_layers=Encoder(dim=n_embed, depth=n_layer),
+                                               emb_dropout=embedding_dropout)
+
+     def forward(self, cond, text):
+         assert cond is None  # Not supported for now (LDM conditioning key == "concat")
+         if self.use_tknz_fn:
+             tokens = self.tknz_fn(text)
+             if next(self.transformer.parameters()).is_cuda:
+                 tokens = tokens.cuda()
+         else:
+             tokens = text
+         z = self.transformer(tokens, return_embeddings=True)  # Size: [batch_size, max_seq_len, n_embed]
+         return z
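A usage sketch (not part of the diff) for the text encoder above; the sizes are illustrative, and the first call downloads the `bert-base-uncased` tokenizer via `transformers`.

    from ldm.modules.encoders.modules import BERTEmbedder

    embedder = BERTEmbedder(n_embed=640, n_layer=4, max_seq_len=77)
    z = embedder(None, ["remove the car in the middle"])   # cond must be None
    print(z.shape)                                         # torch.Size([1, 77, 640])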
ldm/modules/losses/contperceptual.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5
+
6
+
7
+ class LPIPSWithDiscriminator(nn.Module):
8
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11
+ disc_loss="hinge"):
12
+
13
+ super().__init__()
14
+ assert disc_loss in ["hinge", "vanilla"]
15
+ self.kl_weight = kl_weight
16
+ self.pixel_weight = pixelloss_weight
17
+ self.perceptual_loss = LPIPS().eval()
18
+ self.perceptual_weight = perceptual_weight
19
+ # output log variance
20
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21
+
22
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23
+ n_layers=disc_num_layers,
24
+ use_actnorm=use_actnorm
25
+ ).apply(weights_init)
26
+ self.discriminator_iter_start = disc_start
27
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28
+ self.disc_factor = disc_factor
29
+ self.discriminator_weight = disc_weight
30
+ self.disc_conditional = disc_conditional
31
+
32
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33
+ if last_layer is not None:
34
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36
+ else:
37
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39
+
40
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42
+ d_weight = d_weight * self.discriminator_weight
43
+ return d_weight
44
+
45
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46
+ global_step, last_layer=None, cond=None, split="train",
47
+ weights=None):
48
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49
+ if self.perceptual_weight > 0:
50
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
52
+
53
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54
+ weighted_nll_loss = nll_loss
55
+ if weights is not None:
56
+ weighted_nll_loss = weights*nll_loss
57
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59
+ kl_loss = posteriors.kl()
60
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61
+
62
+ # now the GAN part
63
+ if optimizer_idx == 0:
64
+ # generator update
65
+ if cond is None:
66
+ assert not self.disc_conditional
67
+ logits_fake = self.discriminator(reconstructions.contiguous())
68
+ else:
69
+ assert self.disc_conditional
70
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71
+ g_loss = -torch.mean(logits_fake)
72
+
73
+ if self.disc_factor > 0.0:
74
+ try:
75
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76
+ except RuntimeError:
77
+ assert not self.training
78
+ d_weight = torch.tensor(0.0)
79
+ else:
80
+ d_weight = torch.tensor(0.0)
81
+
82
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84
+
85
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
88
+ "{}/d_weight".format(split): d_weight.detach(),
89
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
90
+ "{}/g_loss".format(split): g_loss.detach().mean(),
91
+ }
92
+ return loss, log
93
+
94
+ if optimizer_idx == 1:
95
+ # second pass for discriminator update
96
+ if cond is None:
97
+ logits_real = self.discriminator(inputs.contiguous().detach())
98
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
99
+ else:
100
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102
+
103
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105
+
106
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107
+ "{}/logits_real".format(split): logits_real.detach().mean(),
108
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
109
+ }
110
+ return d_loss, log
111
+
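This loss is normally built from a first-stage (autoencoder) training config rather than constructed by hand. A hedged sketch follows: the parameter values are illustrative, the `taming` package must be installed for the module import, and LPIPS weights are fetched on first use.

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    loss_cfg = OmegaConf.create({
        "target": "ldm.modules.losses.contperceptual.LPIPSWithDiscriminator",
        "params": {"disc_start": 50001, "kl_weight": 1.0e-6, "disc_weight": 0.5},
    })
    loss_fn = instantiate_from_config(loss_cfg)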
ldm/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,167 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+
11
+ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15
+ loss_real = (weights * loss_real).sum() / weights.sum()
16
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
17
+ d_loss = 0.5 * (loss_real + loss_fake)
18
+ return d_loss
19
+
20
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
21
+ if global_step < threshold:
22
+ weight = value
23
+ return weight
24
+
25
+
26
+ def measure_perplexity(predicted_indices, n_embed):
27
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30
+ avg_probs = encodings.mean(0)
31
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32
+ cluster_use = torch.sum(avg_probs > 0)
33
+ return perplexity, cluster_use
34
+
35
+ def l1(x, y):
36
+ return torch.abs(x-y)
37
+
38
+
39
+ def l2(x, y):
40
+ return torch.pow((x-y), 2)
41
+
42
+
43
+ class VQLPIPSWithDiscriminator(nn.Module):
44
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48
+ pixel_loss="l1"):
49
+ super().__init__()
50
+ assert disc_loss in ["hinge", "vanilla"]
51
+ assert perceptual_loss in ["lpips", "clips", "dists"]
52
+ assert pixel_loss in ["l1", "l2"]
53
+ self.codebook_weight = codebook_weight
54
+ self.pixel_weight = pixelloss_weight
55
+ if perceptual_loss == "lpips":
56
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
57
+ self.perceptual_loss = LPIPS().eval()
58
+ else:
59
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60
+ self.perceptual_weight = perceptual_weight
61
+
62
+ if pixel_loss == "l1":
63
+ self.pixel_loss = l1
64
+ else:
65
+ self.pixel_loss = l2
66
+
67
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68
+ n_layers=disc_num_layers,
69
+ use_actnorm=use_actnorm,
70
+ ndf=disc_ndf
71
+ ).apply(weights_init)
72
+ self.discriminator_iter_start = disc_start
73
+ if disc_loss == "hinge":
74
+ self.disc_loss = hinge_d_loss
75
+ elif disc_loss == "vanilla":
76
+ self.disc_loss = vanilla_d_loss
77
+ else:
78
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80
+ self.disc_factor = disc_factor
81
+ self.discriminator_weight = disc_weight
82
+ self.disc_conditional = disc_conditional
83
+ self.n_classes = n_classes
84
+
85
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86
+ if last_layer is not None:
87
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89
+ else:
90
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92
+
93
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95
+ d_weight = d_weight * self.discriminator_weight
96
+ return d_weight
97
+
98
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100
+ if codebook_loss is None:
101
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
102
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104
+ if self.perceptual_weight > 0:
105
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
107
+ else:
108
+ p_loss = torch.tensor([0.0])
109
+
110
+ nll_loss = rec_loss
111
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112
+ nll_loss = torch.mean(nll_loss)
113
+
114
+ # now the GAN part
115
+ if optimizer_idx == 0:
116
+ # generator update
117
+ if cond is None:
118
+ assert not self.disc_conditional
119
+ logits_fake = self.discriminator(reconstructions.contiguous())
120
+ else:
121
+ assert self.disc_conditional
122
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123
+ g_loss = -torch.mean(logits_fake)
124
+
125
+ try:
126
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127
+ except RuntimeError:
128
+ assert not self.training
129
+ d_weight = torch.tensor(0.0)
130
+
131
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133
+
134
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
137
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
138
+ "{}/p_loss".format(split): p_loss.detach().mean(),
139
+ "{}/d_weight".format(split): d_weight.detach(),
140
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
141
+ "{}/g_loss".format(split): g_loss.detach().mean(),
142
+ }
143
+ # if predicted_indices is not None:
144
+ # assert self.n_classes is not None
145
+ # with torch.no_grad():
146
+ # perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147
+ # log[f"{split}/perplexity"] = perplexity
148
+ # log[f"{split}/cluster_usage"] = cluster_usage
149
+ return loss, log
150
+
151
+ if optimizer_idx == 1:
152
+ # second pass for discriminator update
153
+ if cond is None:
154
+ logits_real = self.discriminator(inputs.contiguous().detach())
155
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
156
+ else:
157
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159
+
160
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162
+
163
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164
+ "{}/logits_real".format(split): logits_real.detach().mean(),
165
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
166
+ }
167
+ return d_loss, log
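The `adopt_weight` helper above gates the adversarial term: the discriminator factor stays at `value` (0 by default) until `global_step` reaches `threshold`. A tiny sketch (not part of the diff), assuming the taming dependency flagged in the imports is available:

    from ldm.modules.losses.vqperceptual import adopt_weight

    assert adopt_weight(1.0, global_step=100, threshold=500) == 0.0   # still warming up
    assert adopt_weight(1.0, global_step=500, threshold=500) == 1.0   # GAN loss switched on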
ldm/modules/x_transformer.py ADDED
@@ -0,0 +1,641 @@
1
+ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
2
+ import torch
3
+ from torch import nn, einsum
4
+ import torch.nn.functional as F
5
+ from functools import partial
6
+ from inspect import isfunction
7
+ from collections import namedtuple
8
+ from einops import rearrange, repeat, reduce
9
+
10
+ # constants
11
+
12
+ DEFAULT_DIM_HEAD = 64
13
+
14
+ Intermediates = namedtuple('Intermediates', [
15
+ 'pre_softmax_attn',
16
+ 'post_softmax_attn'
17
+ ])
18
+
19
+ LayerIntermediates = namedtuple('Intermediates', [
20
+ 'hiddens',
21
+ 'attn_intermediates'
22
+ ])
23
+
24
+
25
+ class AbsolutePositionalEmbedding(nn.Module):
26
+ def __init__(self, dim, max_seq_len):
27
+ super().__init__()
28
+ self.emb = nn.Embedding(max_seq_len, dim)
29
+ self.init_()
30
+
31
+ def init_(self):
32
+ nn.init.normal_(self.emb.weight, std=0.02)
33
+
34
+ def forward(self, x):
35
+ n = torch.arange(x.shape[1], device=x.device)
36
+ return self.emb(n)[None, :, :]
37
+
38
+
39
+ class FixedPositionalEmbedding(nn.Module):
40
+ def __init__(self, dim):
41
+ super().__init__()
42
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
43
+ self.register_buffer('inv_freq', inv_freq)
44
+
45
+ def forward(self, x, seq_dim=1, offset=0):
46
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
47
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
48
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
49
+ return emb[None, :, :]
50
+
51
+
52
+ # helpers
53
+
54
+ def exists(val):
55
+ return val is not None
56
+
57
+
58
+ def default(val, d):
59
+ if exists(val):
60
+ return val
61
+ return d() if isfunction(d) else d
62
+
63
+
64
+ def always(val):
65
+ def inner(*args, **kwargs):
66
+ return val
67
+ return inner
68
+
69
+
70
+ def not_equals(val):
71
+ def inner(x):
72
+ return x != val
73
+ return inner
74
+
75
+
76
+ def equals(val):
77
+ def inner(x):
78
+ return x == val
79
+ return inner
80
+
81
+
82
+ def max_neg_value(tensor):
83
+ return -torch.finfo(tensor.dtype).max
84
+
85
+
86
+ # keyword argument helpers
87
+
88
+ def pick_and_pop(keys, d):
89
+ values = list(map(lambda key: d.pop(key), keys))
90
+ return dict(zip(keys, values))
91
+
92
+
93
+ def group_dict_by_key(cond, d):
94
+ return_val = [dict(), dict()]
95
+ for key in d.keys():
96
+ match = bool(cond(key))
97
+ ind = int(not match)
98
+ return_val[ind][key] = d[key]
99
+ return (*return_val,)
100
+
101
+
102
+ def string_begins_with(prefix, str):
103
+ return str.startswith(prefix)
104
+
105
+
106
+ def group_by_key_prefix(prefix, d):
107
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
108
+
109
+
110
+ def groupby_prefix_and_trim(prefix, d):
111
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
112
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
113
+ return kwargs_without_prefix, kwargs
114
+
115
+
116
+ # classes
117
+ class Scale(nn.Module):
118
+ def __init__(self, value, fn):
119
+ super().__init__()
120
+ self.value = value
121
+ self.fn = fn
122
+
123
+ def forward(self, x, **kwargs):
124
+ x, *rest = self.fn(x, **kwargs)
125
+ return (x * self.value, *rest)
126
+
127
+
128
+ class Rezero(nn.Module):
129
+ def __init__(self, fn):
130
+ super().__init__()
131
+ self.fn = fn
132
+ self.g = nn.Parameter(torch.zeros(1))
133
+
134
+ def forward(self, x, **kwargs):
135
+ x, *rest = self.fn(x, **kwargs)
136
+ return (x * self.g, *rest)
137
+
138
+
139
+ class ScaleNorm(nn.Module):
140
+ def __init__(self, dim, eps=1e-5):
141
+ super().__init__()
142
+ self.scale = dim ** -0.5
143
+ self.eps = eps
144
+ self.g = nn.Parameter(torch.ones(1))
145
+
146
+ def forward(self, x):
147
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
148
+ return x / norm.clamp(min=self.eps) * self.g
149
+
150
+
151
+ class RMSNorm(nn.Module):
152
+ def __init__(self, dim, eps=1e-8):
153
+ super().__init__()
154
+ self.scale = dim ** -0.5
155
+ self.eps = eps
156
+ self.g = nn.Parameter(torch.ones(dim))
157
+
158
+ def forward(self, x):
159
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
160
+ return x / norm.clamp(min=self.eps) * self.g
161
+
162
+
163
+ class Residual(nn.Module):
164
+ def forward(self, x, residual):
165
+ return x + residual
166
+
167
+
168
+ class GRUGating(nn.Module):
169
+ def __init__(self, dim):
170
+ super().__init__()
171
+ self.gru = nn.GRUCell(dim, dim)
172
+
173
+ def forward(self, x, residual):
174
+ gated_output = self.gru(
175
+ rearrange(x, 'b n d -> (b n) d'),
176
+ rearrange(residual, 'b n d -> (b n) d')
177
+ )
178
+
179
+ return gated_output.reshape_as(x)
180
+
181
+
182
+ # feedforward
183
+
184
+ class GEGLU(nn.Module):
185
+ def __init__(self, dim_in, dim_out):
186
+ super().__init__()
187
+ self.proj = nn.Linear(dim_in, dim_out * 2)
188
+
189
+ def forward(self, x):
190
+ x, gate = self.proj(x).chunk(2, dim=-1)
191
+ return x * F.gelu(gate)
192
+
193
+
194
+ class FeedForward(nn.Module):
195
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
196
+ super().__init__()
197
+ inner_dim = int(dim * mult)
198
+ dim_out = default(dim_out, dim)
199
+ project_in = nn.Sequential(
200
+ nn.Linear(dim, inner_dim),
201
+ nn.GELU()
202
+ ) if not glu else GEGLU(dim, inner_dim)
203
+
204
+ self.net = nn.Sequential(
205
+ project_in,
206
+ nn.Dropout(dropout),
207
+ nn.Linear(inner_dim, dim_out)
208
+ )
209
+
210
+ def forward(self, x):
211
+ return self.net(x)
212
+
213
+
214
+ # attention.
215
+ class Attention(nn.Module):
216
+ def __init__(
217
+ self,
218
+ dim,
219
+ dim_head=DEFAULT_DIM_HEAD,
220
+ heads=8,
221
+ causal=False,
222
+ mask=None,
223
+ talking_heads=False,
224
+ sparse_topk=None,
225
+ use_entmax15=False,
226
+ num_mem_kv=0,
227
+ dropout=0.,
228
+ on_attn=False
229
+ ):
230
+ super().__init__()
231
+ if use_entmax15:
232
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
233
+ self.scale = dim_head ** -0.5
234
+ self.heads = heads
235
+ self.causal = causal
236
+ self.mask = mask
237
+
238
+ inner_dim = dim_head * heads
239
+
240
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
241
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
242
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
243
+ self.dropout = nn.Dropout(dropout)
244
+
245
+ # talking heads
246
+ self.talking_heads = talking_heads
247
+ if talking_heads:
248
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
249
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
250
+
251
+ # explicit topk sparse attention
252
+ self.sparse_topk = sparse_topk
253
+
254
+ # entmax
255
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
256
+ self.attn_fn = F.softmax
257
+
258
+ # add memory key / values
259
+ self.num_mem_kv = num_mem_kv
260
+ if num_mem_kv > 0:
261
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
262
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
263
+
264
+ # attention on attention
265
+ self.attn_on_attn = on_attn
266
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
267
+
268
+ def forward(
269
+ self,
270
+ x,
271
+ context=None,
272
+ mask=None,
273
+ context_mask=None,
274
+ rel_pos=None,
275
+ sinusoidal_emb=None,
276
+ prev_attn=None,
277
+ mem=None
278
+ ):
279
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
280
+ kv_input = default(context, x)
281
+
282
+ q_input = x
283
+ k_input = kv_input
284
+ v_input = kv_input
285
+
286
+ if exists(mem):
287
+ k_input = torch.cat((mem, k_input), dim=-2)
288
+ v_input = torch.cat((mem, v_input), dim=-2)
289
+
290
+ if exists(sinusoidal_emb):
291
+ # in shortformer, the query would start at a position offset depending on the past cached memory
292
+ offset = k_input.shape[-2] - q_input.shape[-2]
293
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
294
+ k_input = k_input + sinusoidal_emb(k_input)
295
+
296
+ q = self.to_q(q_input)
297
+ k = self.to_k(k_input)
298
+ v = self.to_v(v_input)
299
+
300
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
301
+
302
+ input_mask = None
303
+ if any(map(exists, (mask, context_mask))):
304
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
305
+ k_mask = q_mask if not exists(context) else context_mask
306
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
307
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
308
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
309
+ input_mask = q_mask * k_mask
310
+
311
+ if self.num_mem_kv > 0:
312
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
313
+ k = torch.cat((mem_k, k), dim=-2)
314
+ v = torch.cat((mem_v, v), dim=-2)
315
+ if exists(input_mask):
316
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
317
+
318
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
319
+ mask_value = max_neg_value(dots)
320
+
321
+ if exists(prev_attn):
322
+ dots = dots + prev_attn
323
+
324
+ pre_softmax_attn = dots
325
+
326
+ if talking_heads:
327
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
328
+
329
+ if exists(rel_pos):
330
+ dots = rel_pos(dots)
331
+
332
+ if exists(input_mask):
333
+ dots.masked_fill_(~input_mask, mask_value)
334
+ del input_mask
335
+
336
+ if self.causal:
337
+ i, j = dots.shape[-2:]
338
+ r = torch.arange(i, device=device)
339
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
340
+ mask = F.pad(mask, (j - i, 0), value=False)
341
+ dots.masked_fill_(mask, mask_value)
342
+ del mask
343
+
344
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
345
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
346
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
347
+ mask = dots < vk
348
+ dots.masked_fill_(mask, mask_value)
349
+ del mask
350
+
351
+ attn = self.attn_fn(dots, dim=-1)
352
+ post_softmax_attn = attn
353
+
354
+ attn = self.dropout(attn)
355
+
356
+ if talking_heads:
357
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
358
+
359
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
360
+ out = rearrange(out, 'b h n d -> b n (h d)')
361
+
362
+ intermediates = Intermediates(
363
+ pre_softmax_attn=pre_softmax_attn,
364
+ post_softmax_attn=post_softmax_attn
365
+ )
366
+
367
+ return self.to_out(out), intermediates
368
+
369
+
370
+ class AttentionLayers(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ depth,
375
+ heads=8,
376
+ causal=False,
377
+ cross_attend=False,
378
+ only_cross=False,
379
+ use_scalenorm=False,
380
+ use_rmsnorm=False,
381
+ use_rezero=False,
382
+ rel_pos_num_buckets=32,
383
+ rel_pos_max_distance=128,
384
+ position_infused_attn=False,
385
+ custom_layers=None,
386
+ sandwich_coef=None,
387
+ par_ratio=None,
388
+ residual_attn=False,
389
+ cross_residual_attn=False,
390
+ macaron=False,
391
+ pre_norm=True,
392
+ gate_residual=False,
393
+ **kwargs
394
+ ):
395
+ super().__init__()
396
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
397
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
398
+
399
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
400
+
401
+ self.dim = dim
402
+ self.depth = depth
403
+ self.layers = nn.ModuleList([])
404
+
405
+ self.has_pos_emb = position_infused_attn
406
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
407
+ self.rotary_pos_emb = always(None)
408
+
409
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
410
+ self.rel_pos = None
411
+
412
+ self.pre_norm = pre_norm
413
+
414
+ self.residual_attn = residual_attn
415
+ self.cross_residual_attn = cross_residual_attn
416
+
417
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
418
+ norm_class = RMSNorm if use_rmsnorm else norm_class
419
+ norm_fn = partial(norm_class, dim)
420
+
421
+ norm_fn = nn.Identity if use_rezero else norm_fn
422
+ branch_fn = Rezero if use_rezero else None
423
+
424
+ if cross_attend and not only_cross:
425
+ default_block = ('a', 'c', 'f')
426
+ elif cross_attend and only_cross:
427
+ default_block = ('c', 'f')
428
+ else:
429
+ default_block = ('a', 'f')
430
+
431
+ if macaron:
432
+ default_block = ('f',) + default_block
433
+
434
+ if exists(custom_layers):
435
+ layer_types = custom_layers
436
+ elif exists(par_ratio):
437
+ par_depth = depth * len(default_block)
438
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
439
+ default_block = tuple(filter(not_equals('f'), default_block))
440
+ par_attn = par_depth // par_ratio
441
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
442
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
443
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
444
+ par_block = default_block + ('f',) * (par_width - len(default_block))
445
+ par_head = par_block * par_attn
446
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
447
+ elif exists(sandwich_coef):
448
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
449
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
450
+ else:
451
+ layer_types = default_block * depth
452
+
453
+ self.layer_types = layer_types
454
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
455
+
456
+ for layer_type in self.layer_types:
457
+ if layer_type == 'a':
458
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
459
+ elif layer_type == 'c':
460
+ layer = Attention(dim, heads=heads, **attn_kwargs)
461
+ elif layer_type == 'f':
462
+ layer = FeedForward(dim, **ff_kwargs)
463
+ layer = layer if not macaron else Scale(0.5, layer)
464
+ else:
465
+ raise Exception(f'invalid layer type {layer_type}')
466
+
467
+ if isinstance(layer, Attention) and exists(branch_fn):
468
+ layer = branch_fn(layer)
469
+
470
+ if gate_residual:
471
+ residual_fn = GRUGating(dim)
472
+ else:
473
+ residual_fn = Residual()
474
+
475
+ self.layers.append(nn.ModuleList([
476
+ norm_fn(),
477
+ layer,
478
+ residual_fn
479
+ ]))
480
+
481
+ def forward(
482
+ self,
483
+ x,
484
+ context=None,
485
+ mask=None,
486
+ context_mask=None,
487
+ mems=None,
488
+ return_hiddens=False
489
+ ):
490
+ hiddens = []
491
+ intermediates = []
492
+ prev_attn = None
493
+ prev_cross_attn = None
494
+
495
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
496
+
497
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
498
+ is_last = ind == (len(self.layers) - 1)
499
+
500
+ if layer_type == 'a':
501
+ hiddens.append(x)
502
+ layer_mem = mems.pop(0)
503
+
504
+ residual = x
505
+
506
+ if self.pre_norm:
507
+ x = norm(x)
508
+
509
+ if layer_type == 'a':
510
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
511
+ prev_attn=prev_attn, mem=layer_mem)
512
+ elif layer_type == 'c':
513
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
514
+ elif layer_type == 'f':
515
+ out = block(x)
516
+
517
+ x = residual_fn(out, residual)
518
+
519
+ if layer_type in ('a', 'c'):
520
+ intermediates.append(inter)
521
+
522
+ if layer_type == 'a' and self.residual_attn:
523
+ prev_attn = inter.pre_softmax_attn
524
+ elif layer_type == 'c' and self.cross_residual_attn:
525
+ prev_cross_attn = inter.pre_softmax_attn
526
+
527
+ if not self.pre_norm and not is_last:
528
+ x = norm(x)
529
+
530
+ if return_hiddens:
531
+ intermediates = LayerIntermediates(
532
+ hiddens=hiddens,
533
+ attn_intermediates=intermediates
534
+ )
535
+
536
+ return x, intermediates
537
+
538
+ return x
539
+
540
+
541
+ class Encoder(AttentionLayers):
542
+ def __init__(self, **kwargs):
543
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
544
+ super().__init__(causal=False, **kwargs)
545
+
546
+
547
+
548
+ class TransformerWrapper(nn.Module):
549
+ def __init__(
550
+ self,
551
+ *,
552
+ num_tokens,
553
+ max_seq_len,
554
+ attn_layers,
555
+ emb_dim=None,
556
+ max_mem_len=0.,
557
+ emb_dropout=0.,
558
+ num_memory_tokens=None,
559
+ tie_embedding=False,
560
+ use_pos_emb=True
561
+ ):
562
+ super().__init__()
563
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
564
+
565
+ dim = attn_layers.dim
566
+ emb_dim = default(emb_dim, dim)
567
+
568
+ self.max_seq_len = max_seq_len
569
+ self.max_mem_len = max_mem_len
570
+ self.num_tokens = num_tokens
571
+
572
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
573
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
574
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
575
+ self.emb_dropout = nn.Dropout(emb_dropout)
576
+
577
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
578
+ self.attn_layers = attn_layers
579
+ self.norm = nn.LayerNorm(dim)
580
+
581
+ self.init_()
582
+
583
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
584
+
585
+ # memory tokens (like [cls]) from Memory Transformers paper
586
+ num_memory_tokens = default(num_memory_tokens, 0)
587
+ self.num_memory_tokens = num_memory_tokens
588
+ if num_memory_tokens > 0:
589
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
590
+
591
+ # let funnel encoder know number of memory tokens, if specified
592
+ if hasattr(attn_layers, 'num_memory_tokens'):
593
+ attn_layers.num_memory_tokens = num_memory_tokens
594
+
595
+ def init_(self):
596
+ nn.init.normal_(self.token_emb.weight, std=0.02)
597
+
598
+ def forward(
599
+ self,
600
+ x,
601
+ return_embeddings=False,
602
+ mask=None,
603
+ return_mems=False,
604
+ return_attn=False,
605
+ mems=None,
606
+ **kwargs
607
+ ):
608
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
609
+ x = self.token_emb(x)
610
+ x += self.pos_emb(x)
611
+ x = self.emb_dropout(x)
612
+
613
+ x = self.project_emb(x)
614
+
615
+ if num_mem > 0:
616
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
617
+ x = torch.cat((mem, x), dim=1)
618
+
619
+ # auto-handle masking after appending memory tokens
620
+ if exists(mask):
621
+ mask = F.pad(mask, (num_mem, 0), value=True)
622
+
623
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
624
+ x = self.norm(x)
625
+
626
+ mem, x = x[:, :num_mem], x[:, num_mem:]
627
+
628
+ out = self.to_logits(x) if not return_embeddings else x
629
+
630
+ if return_mems:
631
+ hiddens = intermediates.hiddens
632
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
633
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
634
+ return out, new_mems
635
+
636
+ if return_attn:
637
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
638
+ return out, attn_maps
639
+
640
+ return out
641
+
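A minimal sketch (not part of the diff) of how `BERTEmbedder` drives this module: wrap an `Encoder` in a `TransformerWrapper` and request raw embeddings instead of logits. The dimensions are illustrative.

    import torch
    from ldm.modules.x_transformer import Encoder, TransformerWrapper

    model = TransformerWrapper(
        num_tokens=30522, max_seq_len=77,
        attn_layers=Encoder(dim=512, depth=4),
    )
    tokens = torch.randint(0, 30522, (2, 77))
    z = model(tokens, return_embeddings=True)   # shape (2, 77, 512)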
ldm/util.py ADDED
@@ -0,0 +1,211 @@
1
+ import importlib
2
+
3
+ import torch
4
+ import numpy as np
5
+ from collections import abc
6
+ import os
7
+ import random
8
+
9
+ import multiprocessing as mp
10
+ from threading import Thread
11
+ from queue import Queue
12
+
13
+ from inspect import isfunction
14
+ from PIL import Image, ImageDraw, ImageFont
15
+
16
+
17
+ def log_txt_as_img(wh, xc, size=10):
18
+ # wh a tuple of (width, height)
19
+ # xc a list of captions to plot
20
+ b = len(xc)
21
+ txts = list()
22
+ for bi in range(b):
23
+ txt = Image.new("RGB", wh, color="white")
24
+ draw = ImageDraw.Draw(txt)
25
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26
+ nc = int(40 * (wh[0] / 256))
27
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28
+
29
+ try:
30
+ draw.text((0, 0), lines, fill="black", font=font)
31
+ except UnicodeEncodeError:
32
+ print("Cant encode string for logging. Skipping.")
33
+
34
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35
+ txts.append(txt)
36
+ txts = np.stack(txts)
37
+ txts = torch.tensor(txts)
38
+ return txts
39
+
40
+
41
+ def ismap(x):
42
+ if not isinstance(x, torch.Tensor):
43
+ return False
44
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
45
+
46
+
47
+ def isimage(x):
48
+ if not isinstance(x, torch.Tensor):
49
+ return False
50
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51
+
52
+
53
+ def exists(x):
54
+ return x is not None
55
+
56
+
57
+ def default(val, d):
58
+ if exists(val):
59
+ return val
60
+ return d() if isfunction(d) else d
61
+
62
+
63
+ def mean_flat(tensor):
64
+ """
65
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66
+ Take the mean over all non-batch dimensions.
67
+ """
68
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
69
+
70
+
71
+ def count_params(model, verbose=False):
72
+ total_params = sum(p.numel() for p in model.parameters())
73
+ if verbose:
74
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75
+ return total_params
76
+
77
+
78
+ def instantiate_from_config(config):
79
+ if not "target" in config:
80
+ if config == '__is_first_stage__':
81
+ return None
82
+ elif config == "__is_unconditional__":
83
+ return None
84
+ raise KeyError("Expected key `target` to instantiate.")
85
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
86
+
87
+
88
+ def get_obj_from_str(string, reload=False):
89
+ module, cls = string.rsplit(".", 1)
90
+ if reload:
91
+ module_imp = importlib.import_module(module)
92
+ importlib.reload(module_imp)
93
+ return getattr(importlib.import_module(module, package=None), cls)
94
+
95
+
96
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97
+ # create dummy dataset instance
98
+
99
+ # run prefetching
100
+ if idx_to_fn:
101
+ res = func(data, worker_id=idx)
102
+ else:
103
+ res = func(data)
104
+ Q.put([idx, res])
105
+ Q.put("Done")
106
+
107
+
108
+ def parallel_data_prefetch(
109
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110
+ ):
111
+ # if target_data_type not in ["ndarray", "list"]:
112
+ # raise ValueError(
113
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114
+ # )
115
+ if isinstance(data, np.ndarray) and target_data_type == "list":
116
+ raise ValueError("list expected but function got ndarray.")
117
+ elif isinstance(data, abc.Iterable):
118
+ if isinstance(data, dict):
119
+ print(
120
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121
+ )
122
+ data = list(data.values())
123
+ if target_data_type == "ndarray":
124
+ data = np.asarray(data)
125
+ else:
126
+ data = list(data)
127
+ else:
128
+ raise TypeError(
129
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130
+ )
131
+
132
+ if cpu_intensive:
133
+ Q = mp.Queue(1000)
134
+ proc = mp.Process
135
+ else:
136
+ Q = Queue(1000)
137
+ proc = Thread
138
+ # spawn processes
139
+ if target_data_type == "ndarray":
140
+ arguments = [
141
+ [func, Q, part, i, use_worker_id]
142
+ for i, part in enumerate(np.array_split(data, n_proc))
143
+ ]
144
+ else:
145
+ step = (
146
+ int(len(data) / n_proc + 1)
147
+ if len(data) % n_proc != 0
148
+ else int(len(data) / n_proc)
149
+ )
150
+ arguments = [
151
+ [func, Q, part, i, use_worker_id]
152
+ for i, part in enumerate(
153
+ [data[i: i + step] for i in range(0, len(data), step)]
154
+ )
155
+ ]
156
+ processes = []
157
+ for i in range(n_proc):
158
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159
+ processes += [p]
160
+
161
+ # start processes
162
+ print(f"Start prefetching...")
163
+ import time
164
+
165
+ start = time.time()
166
+ gather_res = [[] for _ in range(n_proc)]
167
+ try:
168
+ for p in processes:
169
+ p.start()
170
+
171
+ k = 0
172
+ while k < n_proc:
173
+ # get result
174
+ res = Q.get()
175
+ if res == "Done":
176
+ k += 1
177
+ else:
178
+ gather_res[res[0]] = res[1]
179
+
180
+ except Exception as e:
181
+ print("Exception: ", e)
182
+ for p in processes:
183
+ p.terminate()
184
+
185
+ raise e
186
+ finally:
187
+ for p in processes:
188
+ p.join()
189
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
190
+
191
+ if target_data_type == 'ndarray':
192
+ if not isinstance(gather_res[0], np.ndarray):
193
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194
+
195
+ # order outputs
196
+ return np.concatenate(gather_res, axis=0)
197
+ elif target_data_type == 'list':
198
+ out = []
199
+ for r in gather_res:
200
+ out.extend(r)
201
+ return out
202
+ else:
203
+ return gather_res
204
+
205
+ def seed_everything(seed: int=0):
206
+ random.seed(seed)
207
+ os.environ["PYTHONHASHSEED"] = str(seed)
208
+ np.random.seed(seed)
209
+ torch.manual_seed(seed)
210
+ torch.cuda.manual_seed(seed)
211
+ torch.backends.cudnn.deterministic = True
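`instantiate_from_config` and `get_obj_from_str` are the glue used throughout the configs: `target` is a dotted import path and `params` are passed as keyword arguments. A small sketch (not part of the diff):

    from ldm.util import instantiate_from_config, get_obj_from_str

    layer = instantiate_from_config({
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 2},
    })
    norm_cls = get_obj_from_str("torch.nn.LayerNorm")   # returns the class itself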
requirements.txt CHANGED
@@ -9,5 +9,6 @@ kornia==0.6.11
9
  transformers==4.27.4
10
  dill==0.3.6
11
  gradio==3.26.0
 
12
  gdown==4.7.1
13
- torchmetrics==0.11.4
 
9
  transformers==4.27.4
10
  dill==0.3.6
11
  gradio==3.26.0
12
+ torchmetrics==0.11.4
13
  gdown==4.7.1
14
+ omegaconf==2.3.0
utils.py CHANGED
@@ -1,25 +1,13 @@
1
- import logging
2
- import os
3
- import random
4
- import tarfile
5
  from typing import Tuple
6
-
7
- import dill
8
- import gdown
9
- import numpy as np
10
- import torch
11
  from PIL import Image
12
  from torchvision.transforms import ToTensor
13
 
14
- logger = logging.getLogger(__file__)
15
-
16
  to_tensor = ToTensor()
17
 
18
-
19
  def preprocess_image(
20
  image: Image, resize_shape: Tuple[int, int] = (256, 256), center_crop=True
21
  ):
22
- processed_image = image
23
 
24
  if center_crop:
25
  width, height = image.size
@@ -30,49 +18,11 @@ def preprocess_image(
30
  right = (width + crop_size) // 2
31
  bottom = (height + crop_size) // 2
32
 
33
- processed_image = image.crop((left, top, right, bottom))
34
-
35
- processed_image = processed_image.resize(resize_shape)
36
-
37
- image = to_tensor(processed_image)
38
- image = image.unsqueeze(0) * 2 - 1
39
-
40
- return processed_image, image
41
-
42
-
43
- def download_artifacts(output_path: str):
44
- logger.error("Downloading the model artifacts...")
45
- if not os.path.exists(output_path):
46
- gdown.download(id=os.environ["GDRIVE_ID"], output=output_path, quiet=True)
47
-
48
-
49
- def extract_artifacts(path: str):
50
- logger.error("Extracting the model artifacts...")
51
- if not os.path.exists("model.pkl"):
52
- with tarfile.open(path) as tar:
53
- tar.extractall()
54
-
55
-
56
- def setup_environment():
57
- os.environ["PYTHONPATH"] = os.getcwd()
58
-
59
- artifacts_path = "artifacts.tar.gz"
60
-
61
- download_artifacts(output_path=artifacts_path)
62
-
63
- extract_artifacts(path=artifacts_path)
64
-
65
-
66
- def get_predictor():
67
- logger.error("Loading the predictor...")
68
- with open("model.pkl", "rb") as fp:
69
- return dill.load(fp)
70
 
 
 
 
 
71
 
72
- def seed_everything(seed: int = 0):
73
- random.seed(seed)
74
- os.environ["PYTHONHASHSEED"] = str(seed)
75
- np.random.seed(seed)
76
- torch.manual_seed(seed)
77
- torch.cuda.manual_seed(seed)
78
- torch.backends.cudnn.deterministic = True
 
 
 
 
 
1
  from typing import Tuple
 
 
 
 
 
2
  from PIL import Image
3
  from torchvision.transforms import ToTensor
4
 
 
 
5
  to_tensor = ToTensor()
6
 
 
7
  def preprocess_image(
8
  image: Image, resize_shape: Tuple[int, int] = (256, 256), center_crop=True
9
  ):
10
+ pil_image = image
11
 
12
  if center_crop:
13
  width, height = image.size
 
18
  right = (width + crop_size) // 2
19
  bottom = (height + crop_size) // 2
20
 
21
+ pil_image = image.crop((left, top, right, bottom))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ pil_image = pil_image.resize(resize_shape)
24
+
25
+ tensor_image = to_tensor(pil_image)
26
+ tensor_image = tensor_image.unsqueeze(0) * 2 - 1
27
 
28
+ return pil_image, tensor_image
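A quick sketch (not part of the diff) of the new return contract, using a hypothetical image path: for an RGB input, the first value is the cropped and resized PIL image for display, and the second is a (1, 3, 256, 256) tensor scaled to [-1, 1] for the model.

    from PIL import Image
    from utils import preprocess_image

    pil_image, tensor_image = preprocess_image(Image.open("example.jpg"), center_crop=True)
    print(tensor_image.shape, tensor_image.min().item() >= -1.0)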