apolinario commited on
Commit
24c0900
1 Parent(s): 9fdf2c0

Reinstate the app

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +5 -10
  2. app.py +149 -0
  3. latent-diffusion/LICENSE +21 -0
  4. latent-diffusion/README.md +274 -0
  5. latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml +54 -0
  6. latent-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml +53 -0
  7. latent-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml +54 -0
  8. latent-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml +53 -0
  9. latent-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml +86 -0
  10. latent-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml +98 -0
  11. latent-diffusion/configs/latent-diffusion/cin256-v2.yaml +68 -0
  12. latent-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml +85 -0
  13. latent-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml +85 -0
  14. latent-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml +91 -0
  15. latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
  16. latent-diffusion/environment.yaml +27 -0
  17. latent-diffusion/ldm/__pycache__/util.cpython-39.pyc +0 -0
  18. latent-diffusion/ldm/data/__init__.py +0 -0
  19. latent-diffusion/ldm/data/base.py +23 -0
  20. latent-diffusion/ldm/data/imagenet.py +394 -0
  21. latent-diffusion/ldm/data/lsun.py +92 -0
  22. latent-diffusion/ldm/lr_scheduler.py +98 -0
  23. latent-diffusion/ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
  24. latent-diffusion/ldm/models/autoencoder.py +443 -0
  25. latent-diffusion/ldm/models/diffusion/__init__.py +0 -0
  26. latent-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  27. latent-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  28. latent-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  29. latent-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-39.pyc +0 -0
  30. latent-diffusion/ldm/models/diffusion/classifier.py +267 -0
  31. latent-diffusion/ldm/models/diffusion/ddim.py +203 -0
  32. latent-diffusion/ldm/models/diffusion/ddpm.py +1445 -0
  33. latent-diffusion/ldm/models/diffusion/plms.py +236 -0
  34. latent-diffusion/ldm/modules/__pycache__/attention.cpython-39.pyc +0 -0
  35. latent-diffusion/ldm/modules/__pycache__/ema.cpython-39.pyc +0 -0
  36. latent-diffusion/ldm/modules/__pycache__/x_transformer.cpython-39.pyc +0 -0
  37. latent-diffusion/ldm/modules/attention.py +261 -0
  38. latent-diffusion/ldm/modules/diffusionmodules/__init__.py +0 -0
  39. latent-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc +0 -0
  40. latent-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc +0 -0
  41. latent-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc +0 -0
  42. latent-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc +0 -0
  43. latent-diffusion/ldm/modules/diffusionmodules/model.py +835 -0
  44. latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py +961 -0
  45. latent-diffusion/ldm/modules/diffusionmodules/util.py +267 -0
  46. latent-diffusion/ldm/modules/distributions/__init__.py +0 -0
  47. latent-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc +0 -0
  48. latent-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc +0 -0
  49. latent-diffusion/ldm/modules/distributions/distributions.py +92 -0
  50. latent-diffusion/ldm/modules/ema.py +76 -0
README.md CHANGED
@@ -3,14 +3,9 @@ title: Latent Diffusion
3
  emoji: 👁
4
  colorFrom: pink
5
  colorTo: indigo
6
- sdk: static
 
 
7
  pinned: false
8
- ---
9
- <div>
10
- <h2>Under maintenance due to concerns on the biases of the model</h2>
11
- <p class="lg:col-span-3">
12
- This space contained a way to quickly execute the open source text-to-image <a href="https://github.com/CompVis/latent-diffusion" target="_blank">Latent Diffusion</a> model by CompVis, trained on the <a href="https://laion.ai/laion-400-open-dataset/" target="_blank">LAION-400M</a> dataset.<br><br>
13
- We believe transparency and bringing and open and honest discussions about model biases is the best way to deal with such challegnges. However after some users reported seeing some gruesome NSFW images with certain prompts from this model, we decided to take down this Spaces for now.
14
- </p>
15
-
16
- </div>
3
  emoji: 👁
4
  colorFrom: pink
5
  colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 2.9.1
8
+ app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ ---
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydoc import describe
2
+ import gradio as gr
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+ import sys
6
+ sys.path.append(".")
7
+ sys.path.append('./taming-transformers')
8
+ sys.path.append('./latent-diffusion')
9
+ from taming.models import vqgan
10
+ from ldm.util import instantiate_from_config
11
+
12
+ torch.hub.download_url_to_file('http://batbot.ai/models/latent-diffusion/models/ldm/text2img-large/model.ckpt','txt2img-f8-large.ckpt')
13
+
14
+ #@title Import stuff
15
+ import argparse, os, sys, glob
16
+ import numpy as np
17
+ from PIL import Image
18
+ from einops import rearrange
19
+ from torchvision.utils import make_grid
20
+ import transformers
21
+ import gc
22
+ from ldm.util import instantiate_from_config
23
+ from ldm.models.diffusion.ddim import DDIMSampler
24
+ from ldm.models.diffusion.plms import PLMSSampler
25
+ from open_clip import tokenizer
26
+ import open_clip
27
+
28
+ def load_model_from_config(config, ckpt, verbose=False):
29
+ print(f"Loading model from {ckpt}")
30
+ pl_sd = torch.load(ckpt, map_location="cuda")
31
+ sd = pl_sd["state_dict"]
32
+ model = instantiate_from_config(config.model)
33
+ m, u = model.load_state_dict(sd, strict=False)
34
+ if len(m) > 0 and verbose:
35
+ print("missing keys:")
36
+ print(m)
37
+ if len(u) > 0 and verbose:
38
+ print("unexpected keys:")
39
+ print(u)
40
+
41
+ model = model.half().cuda()
42
+ model.eval()
43
+ return model
44
+
45
+ config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
46
+ model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
47
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
48
+ model = model.to(device)
49
+ #NSFW CLIP Filter
50
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
51
+ text = tokenizer.tokenize(["NSFW", "adult content", "porn", "naked people","genitalia","penis","vagina"])
52
+ with torch.no_grad():
53
+ text_features = clip_model.encode_text(text)
54
+
55
+ def run(prompt, steps, width, height, images, scale):
56
+ opt = argparse.Namespace(
57
+ prompt = prompt,
58
+ outdir='latent-diffusion/outputs',
59
+ ddim_steps = int(steps),
60
+ ddim_eta = 0,
61
+ n_iter = 1,
62
+ W=int(width),
63
+ H=int(height),
64
+ n_samples=int(images),
65
+ scale=scale,
66
+ plms=True
67
+ )
68
+
69
+ if opt.plms:
70
+ opt.ddim_eta = 0
71
+ sampler = PLMSSampler(model)
72
+ else:
73
+ sampler = DDIMSampler(model)
74
+
75
+ os.makedirs(opt.outdir, exist_ok=True)
76
+ outpath = opt.outdir
77
+
78
+ prompt = opt.prompt
79
+
80
+
81
+ sample_path = os.path.join(outpath, "samples")
82
+ os.makedirs(sample_path, exist_ok=True)
83
+ base_count = len(os.listdir(sample_path))
84
+
85
+ all_samples=list()
86
+ all_samples_images=list()
87
+ with torch.no_grad():
88
+ with torch.cuda.amp.autocast():
89
+ with model.ema_scope():
90
+ uc = None
91
+ if opt.scale > 0:
92
+ uc = model.get_learned_conditioning(opt.n_samples * [""])
93
+ for n in range(opt.n_iter):
94
+ c = model.get_learned_conditioning(opt.n_samples * [prompt])
95
+ shape = [4, opt.H//8, opt.W//8]
96
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
97
+ conditioning=c,
98
+ batch_size=opt.n_samples,
99
+ shape=shape,
100
+ verbose=False,
101
+ unconditional_guidance_scale=opt.scale,
102
+ unconditional_conditioning=uc,
103
+ eta=opt.ddim_eta)
104
+
105
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
106
+ x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
107
+
108
+ for x_sample in x_samples_ddim:
109
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
110
+ image_vector = Image.fromarray(x_sample.astype(np.uint8))
111
+ image = preprocess(image_vector).unsqueeze(0)
112
+ image_features = clip_model.encode_image(image)
113
+ sims = image_features @ text_features.T
114
+ if(sims.max()<18):
115
+ all_samples_images.append(image_vector)
116
+ else:
117
+ return(None,None,"Sorry, NSFW content was detected on your outputs. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
118
+ #Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
119
+ base_count += 1
120
+ all_samples.append(x_samples_ddim)
121
+
122
+
123
+ # additionally, save as grid
124
+ grid = torch.stack(all_samples, 0)
125
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
126
+ grid = make_grid(grid, nrow=2)
127
+ # to image
128
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
129
+
130
+ Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
131
+ return(Image.fromarray(grid.astype(np.uint8)),all_samples_images,None)
132
+
133
+ image = gr.outputs.Image(type="pil", label="Your result")
134
+ css = ".output-image{height: 528px !important} .output-carousel .output-image{height:272px !important} a{text-decoration: underline}"
135
+ iface = gr.Interface(fn=run, inputs=[
136
+ gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
137
+ gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=45,maximum=50,minimum=1,step=1),
138
+ gr.inputs.Radio(label="Width", choices=[32,64,128,256],default=256),
139
+ gr.inputs.Radio(label="Height", choices=[32,64,128,256],default=256),
140
+ gr.inputs.Slider(label="Images - How many images you wish to generate", default=2, step=1, minimum=1, maximum=4),
141
+ gr.inputs.Slider(label="Diversity scale - How different from one another you wish the images to be",default=5.0, minimum=1.0, maximum=15.0),
142
+ #gr.inputs.Slider(label="ETA - between 0 and 1. Lower values can provide better quality, higher values can be more diverse",default=0.0,minimum=0.0, maximum=1.0,step=0.1),
143
+ ],
144
+ outputs=[image,gr.outputs.Carousel(label="Individual images",components=["image"]),gr.outputs.Textbox(label="Error")],
145
+ css=css,
146
+ title="Generate images from text with Latent Diffusion LAION-400M",
147
+ description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/CompVis/latent-diffusion' target='_blank'>Latent Diffusion</a> is a text-to-image model created by <a href='https://github.com/CompVis' target='_blank'>CompVis</a>, trained on the <a href='https://laion.ai/laion-400-open-dataset/'>LAION-400M dataset.</a><br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
148
+ article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exarcbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>:<i> \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The model was trained on an unfiltered version the LAION-400M dataset, which scrapped non-curated image-text-pairs from the internet (the exception being the the removal of illegal content) and is meant to be used for research purposes, such as this one. <a href='https://laion.ai/laion-400-open-dataset/' target='_blank'>You can read more on LAION's website</a></div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definetly not me! Probably you do. I say probably because the Copyright discussion about AI generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>")
149
+ iface.launch(enable_queue=True)
latent-diffusion/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
latent-diffusion/README.md ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Latent Diffusion Models
2
+ [arXiv](https://arxiv.org/abs/2112.10752) | [BibTeX](#bibtex)
3
+
4
+ <p align="center">
5
+ <img src=assets/results.gif />
6
+ </p>
7
+
8
+
9
+
10
+ [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
11
+ [Robin Rombach](https://github.com/rromb)\*,
12
+ [Andreas Blattmann](https://github.com/ablattmann)\*,
13
+ [Dominik Lorenz](https://github.com/qp-qp)\,
14
+ [Patrick Esser](https://github.com/pesser),
15
+ [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
16
+ \* equal contribution
17
+
18
+ <p align="center">
19
+ <img src=assets/modelfigure.png />
20
+ </p>
21
+
22
+ ## News
23
+ ### April 2022
24
+ - More pre-trained LDMs are available:
25
+ - A 1.45B [model](#text-to-image) trained on the [LAION-400M](https://arxiv.org/abs/2111.02114) database.
26
+ - A class-conditional model on ImageNet, achieving a FID of 3.6 when using [classifier-free guidance](https://openreview.net/pdf?id=qw8AKxfYbI) Available via a [colab notebook](https://colab.research.google.com/github/CompVis/latent-diffusion/blob/main/scripts/latent_imagenet_diffusion.ipynb) [![][colab]][colab-cin].
27
+
28
+ ## Requirements
29
+ A suitable [conda](https://conda.io/) environment named `ldm` can be created
30
+ and activated with:
31
+
32
+ ```
33
+ conda env create -f environment.yaml
34
+ conda activate ldm
35
+ ```
36
+
37
+ # Pretrained Models
38
+ A general list of all available checkpoints is available in via our [model zoo](#model-zoo).
39
+ If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).
40
+
41
+ ## Text-to-Image
42
+ ![text2img-figure](assets/txt2img-preview.png)
43
+
44
+
45
+ Download the pre-trained weights (5.7GB)
46
+ ```
47
+ mkdir -p models/ldm/text2img-large/
48
+ wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
49
+ ```
50
+ and sample with
51
+ ```
52
+ python scripts/txt2img.py --prompt "a virus monster is playing guitar, oil on canvas" --ddim_eta 0.0 --n_samples 4 --n_iter 4 --scale 5.0 --ddim_steps 50
53
+ ```
54
+ This will save each sample individually as well as a grid of size `n_iter` x `n_samples` at the specified output location (default: `outputs/txt2img-samples`).
55
+ Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` arguments.
56
+ As a rule of thumb, higher values of `scale` produce better samples at the cost of a reduced output diversity.
57
+ Furthermore, increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250.
58
+ Fast sampling (i.e. low values of `ddim_steps`) while retaining good quality can be achieved by using `--ddim_eta 0.0`.
59
+ Faster sampling (i.e. even lower values of `ddim_steps`) while retaining good quality can be achieved by using `--ddim_eta 0.0` and `--plms` (added by Katherine Crowson, see [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)).
60
+
61
+ #### Beyond 256²
62
+
63
+ For certain inputs, simply running the model in a convolutional fashion on larger features than it was trained on
64
+ can sometimes result in interesting results. To try it out, tune the `H` and `W` arguments (which will be integer-divided
65
+ by 8 in order to calculate the corresponding latent size), e.g. run
66
+
67
+ ```
68
+ python scripts/txt2img.py --prompt "a sunset behind a mountain range, vector image" --ddim_eta 1.0 --n_samples 1 --n_iter 1 --H 384 --W 1024 --scale 5.0
69
+ ```
70
+ to create a sample of size 384x1024. Note, however, that controllability is reduced compared to the 256x256 setting.
71
+
72
+ The example below was generated using the above command.
73
+ ![text2img-figure-conv](assets/txt2img-convsample.png)
74
+
75
+
76
+
77
+ ## Inpainting
78
+ ![inpainting](assets/inpainting.png)
79
+
80
+ Download the pre-trained weights
81
+ ```
82
+ wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
83
+ ```
84
+
85
+ and sample with
86
+ ```
87
+ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results
88
+ ```
89
+ `indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
90
+ the examples provided in `data/inpainting_examples`.
91
+
92
+ ## Class-Conditional ImageNet
93
+
94
+ Available via a [notebook](scripts/latent_imagenet_diffusion.ipynb) [![][colab]][colab-cin].
95
+ ![class-conditional](assets/birdhouse.png)
96
+
97
+ [colab]: <https://colab.research.google.com/assets/colab-badge.svg>
98
+ [colab-cin]: <https://colab.research.google.com/github/CompVis/latent-diffusion/blob/main/scripts/latent_imagenet_diffusion.ipynb>
99
+
100
+
101
+ ## Unconditional Models
102
+
103
+ We also provide a script for sampling from unconditional LDMs (e.g. LSUN, FFHQ, ...). Start it via
104
+
105
+ ```shell script
106
+ CUDA_VISIBLE_DEVICES=<GPU_ID> python scripts/sample_diffusion.py -r models/ldm/<model_spec>/model.ckpt -l <logdir> -n <\#samples> --batch_size <batch_size> -c <\#ddim steps> -e <\#eta>
107
+ ```
108
+
109
+ # Train your own LDMs
110
+
111
+ ## Data preparation
112
+
113
+ ### Faces
114
+ For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
115
+ repository.
116
+
117
+ ### LSUN
118
+
119
+ The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
120
+ We performed a custom split into training and validation images, and provide the corresponding filenames
121
+ at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
122
+ After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
123
+ also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
124
+
125
+ ### ImageNet
126
+ The code will try to download (through [Academic
127
+ Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
128
+ is used. However, since ImageNet is quite large, this requires a lot of disk
129
+ space and time. If you already have ImageNet on your disk, you can speed things
130
+ up by putting the data into
131
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
132
+ `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
133
+ of `train`/`validation`. It should have the following structure:
134
+
135
+ ```
136
+ ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
137
+ ├── n01440764
138
+ │ ├── n01440764_10026.JPEG
139
+ │ ├── n01440764_10027.JPEG
140
+ │ ├── ...
141
+ ├── n01443537
142
+ │ ├── n01443537_10007.JPEG
143
+ │ ├── n01443537_10014.JPEG
144
+ │ ├── ...
145
+ ├── ...
146
+ ```
147
+
148
+ If you haven't extracted the data, you can also place
149
+ `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
150
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
151
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
152
+ extracted into above structure without downloading it again. Note that this
153
+ will only happen if neither a folder
154
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
155
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
156
+ if you want to force running the dataset preparation again.
157
+
158
+
159
+ ## Model Training
160
+
161
+ Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
162
+
163
+ ### Training autoencoder models
164
+
165
+ Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
166
+ Training can be started by running
167
+ ```
168
+ CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
169
+ ```
170
+ where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
171
+ `autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
172
+
173
+ For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
174
+ repository.
175
+
176
+ ### Training LDMs
177
+
178
+ In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
179
+ Training can be started by running
180
+
181
+ ```shell script
182
+ CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
183
+ ```
184
+
185
+ where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
186
+ `lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
187
+ `lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
188
+
189
+ # Model Zoo
190
+
191
+ ## Pretrained Autoencoding Models
192
+ ![rec2](assets/reconstruction2.png)
193
+
194
+ All models were trained until convergence (no further substantial improvement in rFID).
195
+
196
+ | Model | rFID vs val | train steps |PSNR | PSIM | Link | Comments
197
+ |-------------------------|------------|----------------|----------------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------|
198
+ | f=4, VQ (Z=8192, d=3) | 0.58 | 533066 | 27.43 +/- 4.26 | 0.53 +/- 0.21 | https://ommer-lab.com/files/latent-diffusion/vq-f4.zip | |
199
+ | f=4, VQ (Z=8192, d=3) | 1.06 | 658131 | 25.21 +/- 4.17 | 0.72 +/- 0.26 | https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 | no attention |
200
+ | f=8, VQ (Z=16384, d=4) | 1.14 | 971043 | 23.07 +/- 3.99 | 1.17 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | |
201
+ | f=8, VQ (Z=256, d=4) | 1.49 | 1608649 | 22.35 +/- 3.81 | 1.26 +/- 0.37 | https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip |
202
+ | f=16, VQ (Z=16384, d=8) | 5.15 | 1101166 | 20.83 +/- 3.61 | 1.73 +/- 0.43 | https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 | |
203
+ | | | | | | | |
204
+ | f=4, KL | 0.27 | 176991 | 27.53 +/- 4.54 | 0.55 +/- 0.24 | https://ommer-lab.com/files/latent-diffusion/kl-f4.zip | |
205
+ | f=8, KL | 0.90 | 246803 | 24.19 +/- 4.19 | 1.02 +/- 0.35 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip | |
206
+ | f=16, KL (d=16) | 0.87 | 442998 | 24.08 +/- 4.22 | 1.07 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/kl-f16.zip | |
207
+ | f=32, KL (d=64) | 2.04 | 406763 | 22.27 +/- 3.93 | 1.41 +/- 0.40 | https://ommer-lab.com/files/latent-diffusion/kl-f32.zip | |
208
+
209
+ ### Get the models
210
+
211
+ Running the following script downloads und extracts all available pretrained autoencoding models.
212
+ ```shell script
213
+ bash scripts/download_first_stages.sh
214
+ ```
215
+
216
+ The first stage models can then be found in `models/first_stage_models/<model_spec>`
217
+
218
+
219
+
220
+ ## Pretrained LDMs
221
+ | Datset | Task | Model | FID | IS | Prec | Recall | Link | Comments
222
+ |---------------------------------|------|--------------|---------------|-----------------|------|------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------|
223
+ | CelebA-HQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0)| 5.11 (5.11) | 3.29 | 0.72 | 0.49 | https://ommer-lab.com/files/latent-diffusion/celeba.zip | |
224
+ | FFHQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 4.98 (4.98) | 4.50 (4.50) | 0.73 | 0.50 | https://ommer-lab.com/files/latent-diffusion/ffhq.zip | |
225
+ | LSUN-Churches | Unconditional Image Synthesis | LDM-KL-8 (400 DDIM steps, eta=0)| 4.02 (4.02) | 2.72 | 0.64 | 0.52 | https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip | |
226
+ | LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
227
+ | ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
228
+ | Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
229
+ | OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
230
+ | OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
231
+ | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
232
+ | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
233
+
234
+
235
+ ### Get the models
236
+
237
+ The LDMs listed above can jointly be downloaded and extracted via
238
+
239
+ ```shell script
240
+ bash scripts/download_models.sh
241
+ ```
242
+
243
+ The models can then be found in `models/ldm/<model_spec>`.
244
+
245
+
246
+
247
+ ## Coming Soon...
248
+
249
+ * More inference scripts for conditional LDMs.
250
+ * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
251
+
252
+ ## Comments
253
+
254
+ - Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
255
+ and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
256
+ Thanks for open-sourcing!
257
+
258
+ - The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
259
+
260
+
261
+ ## BibTeX
262
+
263
+ ```
264
+ @misc{rombach2021highresolution,
265
+ title={High-Resolution Image Synthesis with Latent Diffusion Models},
266
+ author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
267
+ year={2021},
268
+ eprint={2112.10752},
269
+ archivePrefix={arXiv},
270
+ primaryClass={cs.CV}
271
+ }
272
+ ```
273
+
274
+
latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 16
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 16
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [16]
24
+ dropout: 0.0
25
+
26
+
27
+ data:
28
+ target: main.DataModuleFromConfig
29
+ params:
30
+ batch_size: 12
31
+ wrap: True
32
+ train:
33
+ target: ldm.data.imagenet.ImageNetSRTrain
34
+ params:
35
+ size: 256
36
+ degradation: pil_nearest
37
+ validation:
38
+ target: ldm.data.imagenet.ImageNetSRValidation
39
+ params:
40
+ size: 256
41
+ degradation: pil_nearest
42
+
43
+ lightning:
44
+ callbacks:
45
+ image_logger:
46
+ target: main.ImageLogger
47
+ params:
48
+ batch_frequency: 1000
49
+ max_images: 8
50
+ increase_log_steps: True
51
+
52
+ trainer:
53
+ benchmark: True
54
+ accumulate_grad_batches: 2
latent-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 4
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 4
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [ ]
24
+ dropout: 0.0
25
+
26
+ data:
27
+ target: main.DataModuleFromConfig
28
+ params:
29
+ batch_size: 12
30
+ wrap: True
31
+ train:
32
+ target: ldm.data.imagenet.ImageNetSRTrain
33
+ params:
34
+ size: 256
35
+ degradation: pil_nearest
36
+ validation:
37
+ target: ldm.data.imagenet.ImageNetSRValidation
38
+ params:
39
+ size: 256
40
+ degradation: pil_nearest
41
+
42
+ lightning:
43
+ callbacks:
44
+ image_logger:
45
+ target: main.ImageLogger
46
+ params:
47
+ batch_frequency: 1000
48
+ max_images: 8
49
+ increase_log_steps: True
50
+
51
+ trainer:
52
+ benchmark: True
53
+ accumulate_grad_batches: 2
latent-diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 3
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 3
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [ ]
24
+ dropout: 0.0
25
+
26
+
27
+ data:
28
+ target: main.DataModuleFromConfig
29
+ params:
30
+ batch_size: 12
31
+ wrap: True
32
+ train:
33
+ target: ldm.data.imagenet.ImageNetSRTrain
34
+ params:
35
+ size: 256
36
+ degradation: pil_nearest
37
+ validation:
38
+ target: ldm.data.imagenet.ImageNetSRValidation
39
+ params:
40
+ size: 256
41
+ degradation: pil_nearest
42
+
43
+ lightning:
44
+ callbacks:
45
+ image_logger:
46
+ target: main.ImageLogger
47
+ params:
48
+ batch_frequency: 1000
49
+ max_images: 8
50
+ increase_log_steps: True
51
+
52
+ trainer:
53
+ benchmark: True
54
+ accumulate_grad_batches: 2
latent-diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 64
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 64
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [16,8]
24
+ dropout: 0.0
25
+
26
+ data:
27
+ target: main.DataModuleFromConfig
28
+ params:
29
+ batch_size: 12
30
+ wrap: True
31
+ train:
32
+ target: ldm.data.imagenet.ImageNetSRTrain
33
+ params:
34
+ size: 256
35
+ degradation: pil_nearest
36
+ validation:
37
+ target: ldm.data.imagenet.ImageNetSRValidation
38
+ params:
39
+ size: 256
40
+ degradation: pil_nearest
41
+
42
+ lightning:
43
+ callbacks:
44
+ image_logger:
45
+ target: main.ImageLogger
46
+ params:
47
+ batch_frequency: 1000
48
+ max_images: 8
49
+ increase_log_steps: True
50
+
51
+ trainer:
52
+ benchmark: True
53
+ accumulate_grad_batches: 2
latent-diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+
15
+ unet_config:
16
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ image_size: 64
19
+ in_channels: 3
20
+ out_channels: 3
21
+ model_channels: 224
22
+ attention_resolutions:
23
+ # note: this isn\t actually the resolution but
24
+ # the downsampling factor, i.e. this corresnponds to
25
+ # attention on spatial resolution 8,16,32, as the
26
+ # spatial reolution of the latents is 64 for f4
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 4
36
+ num_head_channels: 32
37
+ first_stage_config:
38
+ target: ldm.models.autoencoder.VQModelInterface
39
+ params:
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43
+ ddconfig:
44
+ double_z: false
45
+ z_channels: 3
46
+ resolution: 256
47
+ in_channels: 3
48
+ out_ch: 3
49
+ ch: 128
50
+ ch_mult:
51
+ - 1
52
+ - 2
53
+ - 4
54
+ num_res_blocks: 2
55
+ attn_resolutions: []
56
+ dropout: 0.0
57
+ lossconfig:
58
+ target: torch.nn.Identity
59
+ cond_stage_config: __is_unconditional__
60
+ data:
61
+ target: main.DataModuleFromConfig
62
+ params:
63
+ batch_size: 48
64
+ num_workers: 5
65
+ wrap: false
66
+ train:
67
+ target: taming.data.faceshq.CelebAHQTrain
68
+ params:
69
+ size: 256
70
+ validation:
71
+ target: taming.data.faceshq.CelebAHQValidation
72
+ params:
73
+ size: 256
74
+
75
+
76
+ lightning:
77
+ callbacks:
78
+ image_logger:
79
+ target: main.ImageLogger
80
+ params:
81
+ batch_frequency: 5000
82
+ max_images: 8
83
+ increase_log_steps: False
84
+
85
+ trainer:
86
+ benchmark: True
latent-diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ unet_config:
18
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ image_size: 32
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 256
24
+ attention_resolutions:
25
+ #note: this isn\t actually the resolution but
26
+ # the downsampling factor, i.e. this corresnponds to
27
+ # attention on spatial resolution 8,16,32, as the
28
+ # spatial reolution of the latents is 32 for f8
29
+ - 4
30
+ - 2
31
+ - 1
32
+ num_res_blocks: 2
33
+ channel_mult:
34
+ - 1
35
+ - 2
36
+ - 4
37
+ num_head_channels: 32
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 512
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 4
45
+ n_embed: 16384
46
+ ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47
+ ddconfig:
48
+ double_z: false
49
+ z_channels: 4
50
+ resolution: 256
51
+ in_channels: 3
52
+ out_ch: 3
53
+ ch: 128
54
+ ch_mult:
55
+ - 1
56
+ - 2
57
+ - 2
58
+ - 4
59
+ num_res_blocks: 2
60
+ attn_resolutions:
61
+ - 32
62
+ dropout: 0.0
63
+ lossconfig:
64
+ target: torch.nn.Identity
65
+ cond_stage_config:
66
+ target: ldm.modules.encoders.modules.ClassEmbedder
67
+ params:
68
+ embed_dim: 512
69
+ key: class_label
70
+ data:
71
+ target: main.DataModuleFromConfig
72
+ params:
73
+ batch_size: 64
74
+ num_workers: 12
75
+ wrap: false
76
+ train:
77
+ target: ldm.data.imagenet.ImageNetTrain
78
+ params:
79
+ config:
80
+ size: 256
81
+ validation:
82
+ target: ldm.data.imagenet.ImageNetValidation
83
+ params:
84
+ config:
85
+ size: 256
86
+
87
+
88
+ lightning:
89
+ callbacks:
90
+ image_logger:
91
+ target: main.ImageLogger
92
+ params:
93
+ batch_frequency: 5000
94
+ max_images: 8
95
+ increase_log_steps: False
96
+
97
+ trainer:
98
+ benchmark: True
latent-diffusion/configs/latent-diffusion/cin256-v2.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 64
13
+ channels: 3
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss
17
+ use_ema: False
18
+
19
+ unet_config:
20
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ image_size: 64
23
+ in_channels: 3
24
+ out_channels: 3
25
+ model_channels: 192
26
+ attention_resolutions:
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 5
36
+ num_heads: 1
37
+ use_spatial_transformer: true
38
+ transformer_depth: 1
39
+ context_dim: 512
40
+
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 3
45
+ n_embed: 8192
46
+ ddconfig:
47
+ double_z: false
48
+ z_channels: 3
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: []
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.ClassEmbedder
65
+ params:
66
+ n_classes: 1001
67
+ embed_dim: 512
68
+ key: class_label
latent-diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ embed_dim: 3
40
+ n_embed: 8192
41
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 42
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: taming.data.faceshq.FFHQTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: taming.data.faceshq.FFHQValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
latent-diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 48
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: ldm.data.lsun.LSUNBedroomsTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: ldm.data.lsun.LSUNBedroomsValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
latent-diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0155
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ loss_type: l1
11
+ first_stage_key: "image"
12
+ cond_stage_key: "image"
13
+ image_size: 32
14
+ channels: 4
15
+ cond_stage_trainable: False
16
+ concat_mode: False
17
+ scale_by_std: True
18
+ monitor: 'val/loss_simple_ema'
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [10000]
24
+ cycle_lengths: [10000000000000]
25
+ f_start: [1.e-6]
26
+ f_max: [1.]
27
+ f_min: [ 1.]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 192
36
+ attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39
+ num_heads: 8
40
+ use_scale_shift_norm: True
41
+ resblock_updown: True
42
+
43
+ first_stage_config:
44
+ target: ldm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: "val/rec_loss"
48
+ ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49
+ ddconfig:
50
+ double_z: True
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57
+ num_res_blocks: 2
58
+ attn_resolutions: [ ]
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config: "__is_unconditional__"
64
+
65
+ data:
66
+ target: main.DataModuleFromConfig
67
+ params:
68
+ batch_size: 96
69
+ num_workers: 5
70
+ wrap: False
71
+ train:
72
+ target: ldm.data.lsun.LSUNChurchesTrain
73
+ params:
74
+ size: 256
75
+ validation:
76
+ target: ldm.data.lsun.LSUNChurchesValidation
77
+ params:
78
+ size: 256
79
+
80
+ lightning:
81
+ callbacks:
82
+ image_logger:
83
+ target: main.ImageLogger
84
+ params:
85
+ batch_frequency: 5000
86
+ max_images: 8
87
+ increase_log_steps: False
88
+
89
+
90
+ trainer:
91
+ benchmark: True
latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ image_size: 32
24
+ in_channels: 4
25
+ out_channels: 4
26
+ model_channels: 320
27
+ attention_resolutions:
28
+ - 4
29
+ - 2
30
+ - 1
31
+ num_res_blocks: 2
32
+ channel_mult:
33
+ - 1
34
+ - 2
35
+ - 4
36
+ - 4
37
+ num_heads: 8
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 1280
41
+ use_checkpoint: true
42
+ legacy: False
43
+
44
+ first_stage_config:
45
+ target: ldm.models.autoencoder.AutoencoderKL
46
+ params:
47
+ embed_dim: 4
48
+ monitor: val/rec_loss
49
+ ddconfig:
50
+ double_z: true
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult:
57
+ - 1
58
+ - 2
59
+ - 4
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions: []
63
+ dropout: 0.0
64
+ lossconfig:
65
+ target: torch.nn.Identity
66
+
67
+ cond_stage_config:
68
+ target: ldm.modules.encoders.modules.BERTEmbedder
69
+ params:
70
+ n_embed: 1280
71
+ n_layer: 32
latent-diffusion/environment.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ldm
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.0
9
+ - pytorch=1.7.0
10
+ - torchvision=0.8.1
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - opencv-python==4.1.2.30
15
+ - pudb==2019.2
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.4.2
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - streamlit>=0.73.1
22
+ - einops==0.3.0
23
+ - torch-fidelity==0.3.0
24
+ - transformers==4.3.1
25
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
27
+ - -e .
latent-diffusion/ldm/__pycache__/util.cpython-39.pyc ADDED
Binary file (3.23 kB). View file
latent-diffusion/ldm/data/__init__.py ADDED
File without changes
latent-diffusion/ldm/data/base.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
+
4
+
5
+ class Txt2ImgIterableBaseDataset(IterableDataset):
6
+ '''
7
+ Define an interface to make the IterableDatasets for text2img data chainable
8
+ '''
9
+ def __init__(self, num_records=0, valid_ids=None, size=256):
10
+ super().__init__()
11
+ self.num_records = num_records
12
+ self.valid_ids = valid_ids
13
+ self.sample_ids = valid_ids
14
+ self.size = size
15
+
16
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17
+
18
+ def __len__(self):
19
+ return self.num_records
20
+
21
+ @abstractmethod
22
+ def __iter__(self):
23
+ pass
latent-diffusion/ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.load(f)
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
latent-diffusion/ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
+ class LSUNBase(Dataset):
10
+ def __init__(self,
11
+ txt_file,
12
+ data_root,
13
+ size=None,
14
+ interpolation="bicubic",
15
+ flip_p=0.5
16
+ ):
17
+ self.data_paths = txt_file
18
+ self.data_root = data_root
19
+ with open(self.data_paths, "r") as f:
20
+ self.image_paths = f.read().splitlines()
21
+ self._length = len(self.image_paths)
22
+ self.labels = {
23
+ "relative_file_path_": [l for l in self.image_paths],
24
+ "file_path_": [os.path.join(self.data_root, l)
25
+ for l in self.image_paths],
26
+ }
27
+
28
+ self.size = size
29
+ self.interpolation = {"linear": PIL.Image.LINEAR,
30
+ "bilinear": PIL.Image.BILINEAR,
31
+ "bicubic": PIL.Image.BICUBIC,
32
+ "lanczos": PIL.Image.LANCZOS,
33
+ }[interpolation]
34
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
+
36
+ def __len__(self):
37
+ return self._length
38
+
39
+ def __getitem__(self, i):
40
+ example = dict((k, self.labels[k][i]) for k in self.labels)
41
+ image = Image.open(example["file_path_"])
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # default to score-sde preprocessing
46
+ img = np.array(image).astype(np.uint8)
47
+ crop = min(img.shape[0], img.shape[1])
48
+ h, w, = img.shape[0], img.shape[1]
49
+ img = img[(h - crop) // 2:(h + crop) // 2,
50
+ (w - crop) // 2:(w + crop) // 2]
51
+
52
+ image = Image.fromarray(img)
53
+ if self.size is not None:
54
+ image = image.resize((self.size, self.size), resample=self.interpolation)
55
+
56
+ image = self.flip(image)
57
+ image = np.array(image).astype(np.uint8)
58
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
+ return example
60
+
61
+
62
+ class LSUNChurchesTrain(LSUNBase):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
+ class LSUNChurchesValidation(LSUNBase):
68
+ def __init__(self, flip_p=0., **kwargs):
69
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
+ flip_p=flip_p, **kwargs)
71
+
72
+
73
+ class LSUNBedroomsTrain(LSUNBase):
74
+ def __init__(self, **kwargs):
75
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
+ class LSUNBedroomsValidation(LSUNBase):
79
+ def __init__(self, flip_p=0.0, **kwargs):
80
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
+ flip_p=flip_p, **kwargs)
82
+
83
+
84
+ class LSUNCatsTrain(LSUNBase):
85
+ def __init__(self, **kwargs):
86
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
+ class LSUNCatsValidation(LSUNBase):
90
+ def __init__(self, flip_p=0., **kwargs):
91
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
+ flip_p=flip_p, **kwargs)
latent-diffusion/ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
latent-diffusion/ldm/models/__pycache__/autoencoder.cpython-39.pyc ADDED
Binary file (13.6 kB). View file
latent-diffusion/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+
264
+ class VQModelInterface(VQModel):
265
+ def __init__(self, embed_dim, *args, **kwargs):
266
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
+ self.embed_dim = embed_dim
268
+
269
+ def encode(self, x):
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ return h
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,
295
+ ):
296
+ super().__init__()
297
+ self.image_key = image_key
298
+ self.encoder = Encoder(**ddconfig)
299
+ self.decoder = Decoder(**ddconfig)
300
+ self.loss = instantiate_from_config(lossconfig)
301
+ assert ddconfig["double_z"]
302
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
+ self.embed_dim = embed_dim
305
+ if colorize_nlabels is not None:
306
+ assert type(colorize_nlabels)==int
307
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
+ if monitor is not None:
309
+ self.monitor = monitor
310
+ if ckpt_path is not None:
311
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
+
313
+ def init_from_ckpt(self, path, ignore_keys=list()):
314
+ sd = torch.load(path, map_location="cpu")["state_dict"]
315
+ keys = list(sd.keys())
316
+ for k in keys:
317
+ for ik in ignore_keys:
318
+ if k.startswith(ik):
319
+ print("Deleting key {} from state_dict.".format(k))
320
+ del sd[k]
321
+ self.load_state_dict(sd, strict=False)
322
+ print(f"Restored from {path}")
323
+
324
+ def encode(self, x):
325
+ h = self.encoder(x)
326
+ moments = self.quant_conv(h)
327
+ posterior = DiagonalGaussianDistribution(moments)
328
+ return posterior
329
+
330
+ def decode(self, z):
331
+ z = self.post_quant_conv(z)
332
+ dec = self.decoder(z)
333
+ return dec
334
+
335
+ def forward(self, input, sample_posterior=True):
336
+ posterior = self.encode(input)
337
+ if sample_posterior:
338
+ z = posterior.sample()
339
+ else:
340
+ z = posterior.mode()
341
+ dec = self.decode(z)
342
+ return dec, posterior
343
+
344
+ def get_input(self, batch, k):
345
+ x = batch[k]
346
+ if len(x.shape) == 3:
347
+ x = x[..., None]
348
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
+ return x
350
+
351
+ def training_step(self, batch, batch_idx, optimizer_idx):
352
+ inputs = self.get_input(batch, self.image_key)
353
+ reconstructions, posterior = self(inputs)
354
+
355
+ if optimizer_idx == 0:
356
+ # train encoder+decoder+logvar
357
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
+ last_layer=self.get_last_layer(), split="train")
359
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
+ return aeloss
362
+
363
+ if optimizer_idx == 1:
364
+ # train the discriminator
365
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
+ last_layer=self.get_last_layer(), split="train")
367
+
368
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
+ return discloss
371
+
372
+ def validation_step(self, batch, batch_idx):
373
+ inputs = self.get_input(batch, self.image_key)
374
+ reconstructions, posterior = self(inputs)
375
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
+ last_layer=self.get_last_layer(), split="val")
377
+
378
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
+ last_layer=self.get_last_layer(), split="val")
380
+
381
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
+ self.log_dict(log_dict_ae)
383
+ self.log_dict(log_dict_disc)
384
+ return self.log_dict
385
+
386
+ def configure_optimizers(self):
387
+ lr = self.learning_rate
388
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
+ list(self.decoder.parameters())+
390
+ list(self.quant_conv.parameters())+
391
+ list(self.post_quant_conv.parameters()),
392
+ lr=lr, betas=(0.5, 0.9))
393
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
+ lr=lr, betas=(0.5, 0.9))
395
+ return [opt_ae, opt_disc], []
396
+
397
+ def get_last_layer(self):
398
+ return self.decoder.conv_out.weight
399
+
400
+ @torch.no_grad()
401
+ def log_images(self, batch, only_inputs=False, **kwargs):
402
+ log = dict()
403
+ x = self.get_input(batch, self.image_key)
404
+ x = x.to(self.device)
405
+ if not only_inputs:
406
+ xrec, posterior = self(x)
407
+ if x.shape[1] > 3:
408
+ # colorize with random projection
409
+ assert xrec.shape[1] > 3
410
+ x = self.to_rgb(x)
411
+ xrec = self.to_rgb(xrec)
412
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
+ log["reconstructions"] = xrec
414
+ log["inputs"] = x
415
+ return log
416
+
417
+ def to_rgb(self, x):
418
+ assert self.image_key == "segmentation"
419
+ if not hasattr(self, "colorize"):
420
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
+ x = F.conv2d(x, weight=self.colorize)
422
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
+ return x
424
+
425
+
426
+ class IdentityFirstStage(torch.nn.Module):
427
+ def __init__(self, *args, vq_interface=False, **kwargs):
428
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
+ super().__init__()
430
+
431
+ def encode(self, x, *args, **kwargs):
432
+ return x
433
+
434
+ def decode(self, x, *args, **kwargs):
435
+ return x
436
+
437
+ def quantize(self, x, *args, **kwargs):
438
+ if self.vq_interface:
439
+ return x, None, [None, None, None]
440
+ return x
441
+
442
+ def forward(self, x, *args, **kwargs):
443
+ return x
latent-diffusion/ldm/models/diffusion/__init__.py ADDED
File without changes
latent-diffusion/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (193 Bytes). View file
latent-diffusion/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc ADDED
Binary file (6.19 kB). View file
latent-diffusion/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc ADDED
Binary file (44 kB). View file
latent-diffusion/ldm/models/diffusion/__pycache__/plms.cpython-39.pyc ADDED
Binary file (7.33 kB). View file
latent-diffusion/ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
latent-diffusion/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
+ alphas_cumprod = self.model.alphas_cumprod
28
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30
+
31
+ self.register_buffer('betas', to_torch(self.model.betas))
32
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34
+
35
+ # calculations for diffusion q(x_t | x_{t-1}) and others
36
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41
+
42
+ # ddim sampling parameters
43
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
+ ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta,verbose=verbose)
46
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
47
+ self.register_buffer('ddim_alphas', ddim_alphas)
48
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
+
55
+ @torch.no_grad()
56
+ def sample(self,
57
+ S,
58
+ batch_size,
59
+ shape,
60
+ conditioning=None,
61
+ callback=None,
62
+ normals_sequence=None,
63
+ img_callback=None,
64
+ quantize_x0=False,
65
+ eta=0.,
66
+ mask=None,
67
+ x0=None,
68
+ temperature=1.,
69
+ noise_dropout=0.,
70
+ score_corrector=None,
71
+ corrector_kwargs=None,
72
+ verbose=True,
73
+ x_T=None,
74
+ log_every_t=100,
75
+ unconditional_guidance_scale=1.,
76
+ unconditional_conditioning=None,
77
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
83
+ if cbs != batch_size:
84
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
85
+ else:
86
+ if conditioning.shape[0] != batch_size:
87
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
88
+
89
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
90
+ # sampling
91
+ C, H, W = shape
92
+ size = (batch_size, C, H, W)
93
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
94
+
95
+ samples, intermediates = self.ddim_sampling(conditioning, size,
96
+ callback=callback,
97
+ img_callback=img_callback,
98
+ quantize_denoised=quantize_x0,
99
+ mask=mask, x0=x0,
100
+ ddim_use_original_steps=False,
101
+ noise_dropout=noise_dropout,
102
+ temperature=temperature,
103
+ score_corrector=score_corrector,
104
+ corrector_kwargs=corrector_kwargs,
105
+ x_T=x_T,
106
+ log_every_t=log_every_t,
107
+ unconditional_guidance_scale=unconditional_guidance_scale,
108
+ unconditional_conditioning=unconditional_conditioning,
109
+ )
110
+ return samples, intermediates
111
+
112
+ @torch.no_grad()
113
+ def ddim_sampling(self, cond, shape,
114
+ x_T=None, ddim_use_original_steps=False,
115
+ callback=None, timesteps=None, quantize_denoised=False,
116
+ mask=None, x0=None, img_callback=None, log_every_t=100,
117
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
118
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
119
+ device = self.model.betas.device
120
+ b = shape[0]
121
+ if x_T is None:
122
+ img = torch.randn(shape, device=device)
123
+ else:
124
+ img = x_T
125
+
126
+ if timesteps is None:
127
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
128
+ elif timesteps is not None and not ddim_use_original_steps:
129
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
130
+ timesteps = self.ddim_timesteps[:subset_end]
131
+
132
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
133
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
134
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
135
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
136
+
137
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
138
+
139
+ for i, step in enumerate(iterator):
140
+ index = total_steps - i - 1
141
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
142
+
143
+ if mask is not None:
144
+ assert x0 is not None
145
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
146
+ img = img_orig * mask + (1. - mask) * img
147
+
148
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
149
+ quantize_denoised=quantize_denoised, temperature=temperature,
150
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
151
+ corrector_kwargs=corrector_kwargs,
152
+ unconditional_guidance_scale=unconditional_guidance_scale,
153
+ unconditional_conditioning=unconditional_conditioning)
154
+ img, pred_x0 = outs
155
+ if callback: callback(i)
156
+ if img_callback: img_callback(pred_x0, i)
157
+
158
+ if index % log_every_t == 0 or index == total_steps - 1:
159
+ intermediates['x_inter'].append(img)
160
+ intermediates['pred_x0'].append(pred_x0)
161
+
162
+ return img, intermediates
163
+
164
+ @torch.no_grad()
165
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
166
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
167
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
168
+ b, *_, device = *x.shape, x.device
169
+
170
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
171
+ e_t = self.model.apply_model(x, t, c)
172
+ else:
173
+ x_in = torch.cat([x] * 2)
174
+ t_in = torch.cat([t] * 2)
175
+ c_in = torch.cat([unconditional_conditioning, c])
176
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
177
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
178
+
179
+ if score_corrector is not None:
180
+ assert self.model.parameterization == "eps"
181
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
182
+
183
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
184
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
185
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
186
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
187
+ # select parameters corresponding to the currently considered timestep
188
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
189
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
190
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
191
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
192
+
193
+ # current prediction for x_0
194
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
195
+ if quantize_denoised:
196
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
197
+ # direction pointing to x_t
198
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
199
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
200
+ if noise_dropout > 0.:
201
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
203
+ return x_prev, pred_x0
latent-diffusion/ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+
21
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+
28
+
29
+ __conditioning_keys__ = {'concat': 'c_concat',
30
+ 'crossattn': 'c_crossattn',
31
+ 'adm': 'y'}
32
+
33
+
34
+ def disabled_train(self, mode=True):
35
+ """Overwrite model.train with this function to make sure train/eval mode
36
+ does not change anymore."""
37
+ return self
38
+
39
+
40
+ def uniform_on_device(r1, r2, shape, device):
41
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
42
+
43
+
44
+ class DDPM(pl.LightningModule):
45
+ # classic DDPM with Gaussian diffusion, in image space
46
+ def __init__(self,
47
+ unet_config,
48
+ timesteps=1000,
49
+ beta_schedule="linear",
50
+ loss_type="l2",
51
+ ckpt_path=None,
52
+ ignore_keys=[],
53
+ load_only_unet=False,
54
+ monitor="val/loss",
55
+ use_ema=True,
56
+ first_stage_key="image",
57
+ image_size=256,
58
+ channels=3,
59
+ log_every_t=100,
60
+ clip_denoised=True,
61
+ linear_start=1e-4,
62
+ linear_end=2e-2,
63
+ cosine_s=8e-3,
64
+ given_betas=None,
65
+ original_elbo_weight=0.,
66
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
67
+ l_simple_weight=1.,
68
+ conditioning_key=None,
69
+ parameterization="eps", # all assuming fixed variance schedules
70
+ scheduler_config=None,
71
+ use_positional_encodings=False,
72
+ learn_logvar=False,
73
+ logvar_init=0.,
74
+ ):
75
+ super().__init__()
76
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
77
+ self.parameterization = parameterization
78
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
79
+ self.cond_stage_model = None
80
+ self.clip_denoised = clip_denoised
81
+ self.log_every_t = log_every_t
82
+ self.first_stage_key = first_stage_key
83
+ self.image_size = image_size # try conv?
84
+ self.channels = channels
85
+ self.use_positional_encodings = use_positional_encodings
86
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
87
+ count_params(self.model, verbose=True)
88
+ self.use_ema = use_ema
89
+ if self.use_ema:
90
+ self.model_ema = LitEma(self.model)
91
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
92
+
93
+ self.use_scheduler = scheduler_config is not None
94
+ if self.use_scheduler:
95
+ self.scheduler_config = scheduler_config
96
+
97
+ self.v_posterior = v_posterior
98
+ self.original_elbo_weight = original_elbo_weight
99
+ self.l_simple_weight = l_simple_weight
100
+
101
+ if monitor is not None:
102
+ self.monitor = monitor
103
+ if ckpt_path is not None:
104
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
105
+
106
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
107
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
108
+
109
+ self.loss_type = loss_type
110
+
111
+ self.learn_logvar = learn_logvar
112
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
113
+ if self.learn_logvar:
114
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
115
+
116
+
117
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
118
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
119
+ if exists(given_betas):
120
+ betas = given_betas
121
+ else:
122
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
123
+ cosine_s=cosine_s)
124
+ alphas = 1. - betas
125
+ alphas_cumprod = np.cumprod(alphas, axis=0)
126
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
127
+
128
+ timesteps, = betas.shape
129
+ self.num_timesteps = int(timesteps)
130
+ self.linear_start = linear_start
131
+ self.linear_end = linear_end
132
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
133
+
134
+ to_torch = partial(torch.tensor, dtype=torch.float32)
135
+
136
+ self.register_buffer('betas', to_torch(betas))
137
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
138
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
139
+
140
+ # calculations for diffusion q(x_t | x_{t-1}) and others
141
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
142
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
143
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
144
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
145
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
146
+
147
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
148
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
149
+ 1. - alphas_cumprod) + self.v_posterior * betas
150
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
151
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
152
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
153
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
154
+ self.register_buffer('posterior_mean_coef1', to_torch(
155
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
156
+ self.register_buffer('posterior_mean_coef2', to_torch(
157
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
158
+
159
+ if self.parameterization == "eps":
160
+ lvlb_weights = self.betas ** 2 / (
161
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
162
+ elif self.parameterization == "x0":
163
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
164
+ else:
165
+ raise NotImplementedError("mu not supported")
166
+ # TODO how to choose this term
167
+ lvlb_weights[0] = lvlb_weights[1]
168
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
169
+ assert not torch.isnan(self.lvlb_weights).all()
170
+
171
+ @contextmanager
172
+ def ema_scope(self, context=None):
173
+ if self.use_ema:
174
+ self.model_ema.store(self.model.parameters())
175
+ self.model_ema.copy_to(self.model)
176
+ if context is not None:
177
+ print(f"{context}: Switched to EMA weights")
178
+ try:
179
+ yield None
180
+ finally:
181
+ if self.use_ema:
182
+ self.model_ema.restore(self.model.parameters())
183
+ if context is not None:
184
+ print(f"{context}: Restored training weights")
185
+
186
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
187
+ sd = torch.load(path, map_location="cpu")
188
+ if "state_dict" in list(sd.keys()):
189
+ sd = sd["state_dict"]
190
+ keys = list(sd.keys())
191
+ for k in keys:
192
+ for ik in ignore_keys:
193
+ if k.startswith(ik):
194
+ print("Deleting key {} from state_dict.".format(k))
195
+ del sd[k]
196
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
197
+ sd, strict=False)
198
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
199
+ if len(missing) > 0:
200
+ print(f"Missing Keys: {missing}")
201
+ if len(unexpected) > 0:
202
+ print(f"Unexpected Keys: {unexpected}")
203
+
204
+ def q_mean_variance(self, x_start, t):
205
+ """
206
+ Get the distribution q(x_t | x_0).
207
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
208
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210
+ """
211
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
212
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
214
+ return mean, variance, log_variance
215
+
216
+ def predict_start_from_noise(self, x_t, t, noise):
217
+ return (
218
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
219
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
220
+ )
221
+
222
+ def q_posterior(self, x_start, x_t, t):
223
+ posterior_mean = (
224
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
225
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
226
+ )
227
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
228
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
229
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
230
+
231
+ def p_mean_variance(self, x, t, clip_denoised: bool):
232
+ model_out = self.model(x, t)
233
+ if self.parameterization == "eps":
234
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
235
+ elif self.parameterization == "x0":
236
+ x_recon = model_out
237
+ if clip_denoised:
238
+ x_recon.clamp_(-1., 1.)
239
+
240
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
241
+ return model_mean, posterior_variance, posterior_log_variance
242
+
243
+ @torch.no_grad()
244
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
245
+ b, *_, device = *x.shape, x.device
246
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
247
+ noise = noise_like(x.shape, device, repeat_noise)
248
+ # no noise when t == 0
249
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
250
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
251
+
252
+ @torch.no_grad()
253
+ def p_sample_loop(self, shape, return_intermediates=False):
254
+ device = self.betas.device
255
+ b = shape[0]
256
+ img = torch.randn(shape, device=device)
257
+ intermediates = [img]
258
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
259
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
260
+ clip_denoised=self.clip_denoised)
261
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
262
+ intermediates.append(img)
263
+ if return_intermediates:
264
+ return img, intermediates
265
+ return img
266
+
267
+ @torch.no_grad()
268
+ def sample(self, batch_size=16, return_intermediates=False):
269
+ image_size = self.image_size
270
+ channels = self.channels
271
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
272
+ return_intermediates=return_intermediates)
273
+
274
+ def q_sample(self, x_start, t, noise=None):
275
+ noise = default(noise, lambda: torch.randn_like(x_start))
276
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
277
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
278
+
279
+ def get_loss(self, pred, target, mean=True):
280
+ if self.loss_type == 'l1':
281
+ loss = (target - pred).abs()
282
+ if mean:
283
+ loss = loss.mean()
284
+ elif self.loss_type == 'l2':
285
+ if mean:
286
+ loss = torch.nn.functional.mse_loss(target, pred)
287
+ else:
288
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
289
+ else:
290
+ raise NotImplementedError("unknown loss type '{loss_type}'")
291
+
292
+ return loss
293
+
294
+ def p_losses(self, x_start, t, noise=None):
295
+ noise = default(noise, lambda: torch.randn_like(x_start))
296
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
297
+ model_out = self.model(x_noisy, t)
298
+
299
+ loss_dict = {}
300
+ if self.parameterization == "eps":
301
+ target = noise
302
+ elif self.parameterization == "x0":
303
+ target = x_start
304
+ else:
305
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
306
+
307
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
308
+
309
+ log_prefix = 'train' if self.training else 'val'
310
+
311
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
312
+ loss_simple = loss.mean() * self.l_simple_weight
313
+
314
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
315
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
316
+
317
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
318
+
319
+ loss_dict.update({f'{log_prefix}/loss': loss})
320
+
321
+ return loss, loss_dict
322
+
323
+ def forward(self, x, *args, **kwargs):
324
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
325
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
326
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
327
+ return self.p_losses(x, t, *args, **kwargs)
328
+
329
+ def get_input(self, batch, k):
330
+ x = batch[k]
331
+ if len(x.shape) == 3:
332
+ x = x[..., None]
333
+ x = rearrange(x, 'b h w c -> b c h w')
334
+ x = x.to(memory_format=torch.contiguous_format).float()
335
+ return x
336
+
337
+ def shared_step(self, batch):
338
+ x = self.get_input(batch, self.first_stage_key)
339
+ loss, loss_dict = self(x)
340
+ return loss, loss_dict
341
+
342
+ def training_step(self, batch, batch_idx):
343
+ loss, loss_dict = self.shared_step(batch)
344
+
345
+ self.log_dict(loss_dict, prog_bar=True,
346
+ logger=True, on_step=True, on_epoch=True)
347
+
348
+ self.log("global_step", self.global_step,
349
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
350
+
351
+ if self.use_scheduler:
352
+ lr = self.optimizers().param_groups[0]['lr']
353
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
354
+
355
+ return loss
356
+
357
+ @torch.no_grad()
358
+ def validation_step(self, batch, batch_idx):
359
+ _, loss_dict_no_ema = self.shared_step(batch)
360
+ with self.ema_scope():
361
+ _, loss_dict_ema = self.shared_step(batch)
362
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
363
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
364
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
365
+
366
+ def on_train_batch_end(self, *args, **kwargs):
367
+ if self.use_ema:
368
+ self.model_ema(self.model)
369
+
370
+ def _get_rows_from_list(self, samples):
371
+ n_imgs_per_row = len(samples)
372
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
373
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
374
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
375
+ return denoise_grid
376
+
377
+ @torch.no_grad()
378
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
379
+ log = dict()
380
+ x = self.get_input(batch, self.first_stage_key)
381
+ N = min(x.shape[0], N)
382
+ n_row = min(x.shape[0], n_row)
383
+ x = x.to(self.device)[:N]
384
+ log["inputs"] = x
385
+
386
+ # get diffusion row
387
+ diffusion_row = list()
388
+ x_start = x[:n_row]
389
+
390
+ for t in range(self.num_timesteps):
391
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
392
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
393
+ t = t.to(self.device).long()
394
+ noise = torch.randn_like(x_start)
395
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
396
+ diffusion_row.append(x_noisy)
397
+
398
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
399
+
400
+ if sample:
401
+ # get denoise row
402
+ with self.ema_scope("Plotting"):
403
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
404
+
405
+ log["samples"] = samples
406
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
407
+
408
+ if return_keys:
409
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
410
+ return log
411
+ else:
412
+ return {key: log[key] for key in return_keys}
413
+ return log
414
+
415
+ def configure_optimizers(self):
416
+ lr = self.learning_rate
417
+ params = list(self.model.parameters())
418
+ if self.learn_logvar:
419
+ params = params + [self.logvar]
420
+ opt = torch.optim.AdamW(params, lr=lr)
421
+ return opt
422
+
423
+
424
+ class LatentDiffusion(DDPM):
425
+ """main class"""
426
+ def __init__(self,
427
+ first_stage_config,
428
+ cond_stage_config,
429
+ num_timesteps_cond=None,
430
+ cond_stage_key="image",
431
+ cond_stage_trainable=False,
432
+ concat_mode=True,
433
+ cond_stage_forward=None,
434
+ conditioning_key=None,
435
+ scale_factor=1.0,
436
+ scale_by_std=False,
437
+ *args, **kwargs):
438
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
439
+ self.scale_by_std = scale_by_std
440
+ assert self.num_timesteps_cond <= kwargs['timesteps']
441
+ # for backwards compatibility after implementation of DiffusionWrapper
442
+ if conditioning_key is None:
443
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
444
+ if cond_stage_config == '__is_unconditional__':
445
+ conditioning_key = None
446
+ ckpt_path = kwargs.pop("ckpt_path", None)
447
+ ignore_keys = kwargs.pop("ignore_keys", [])
448
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
449
+ self.concat_mode = concat_mode
450
+ self.cond_stage_trainable = cond_stage_trainable
451
+ self.cond_stage_key = cond_stage_key
452
+ try:
453
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
454
+ except:
455
+ self.num_downs = 0
456
+ if not scale_by_std:
457
+ self.scale_factor = scale_factor
458
+ else:
459
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
460
+ self.instantiate_first_stage(first_stage_config)
461
+ self.instantiate_cond_stage(cond_stage_config)
462
+ self.cond_stage_forward = cond_stage_forward
463
+ self.clip_denoised = False
464
+ self.bbox_tokenizer = None
465
+
466
+ self.restarted_from_ckpt = False
467
+ if ckpt_path is not None:
468
+ self.init_from_ckpt(ckpt_path, ignore_keys)
469
+ self.restarted_from_ckpt = True
470
+
471
+ def make_cond_schedule(self, ):
472
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
473
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
474
+ self.cond_ids[:self.num_timesteps_cond] = ids
475
+
476
+ @rank_zero_only
477
+ @torch.no_grad()
478
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
479
+ # only for very first batch
480
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
481
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
482
+ # set rescale weight to 1./std of encodings
483
+ print("### USING STD-RESCALING ###")
484
+ x = super().get_input(batch, self.first_stage_key)
485
+ x = x.to(self.device)
486
+ encoder_posterior = self.encode_first_stage(x)
487
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
488
+ del self.scale_factor
489
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
490
+ print(f"setting self.scale_factor to {self.scale_factor}")
491
+ print("### USING STD-RESCALING ###")
492
+
493
+ def register_schedule(self,
494
+ given_betas=None, beta_schedule="linear", timesteps=1000,
495
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
496
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
497
+
498
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
499
+ if self.shorten_cond_schedule:
500
+ self.make_cond_schedule()
501
+
502
+ def instantiate_first_stage(self, config):
503
+ model = instantiate_from_config(config)
504
+ self.first_stage_model = model.eval()
505
+ self.first_stage_model.train = disabled_train
506
+ for param in self.first_stage_model.parameters():
507
+ param.requires_grad = False
508
+
509
+ def instantiate_cond_stage(self, config):
510
+ if not self.cond_stage_trainable:
511
+ if config == "__is_first_stage__":
512
+ print("Using first stage also as cond stage.")
513
+ self.cond_stage_model = self.first_stage_model
514
+ elif config == "__is_unconditional__":
515
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
516
+ self.cond_stage_model = None
517
+ # self.be_unconditional = True
518
+ else:
519
+ model = instantiate_from_config(config)
520
+ self.cond_stage_model = model.eval()
521
+ self.cond_stage_model.train = disabled_train
522
+ for param in self.cond_stage_model.parameters():
523
+ param.requires_grad = False
524
+ else:
525
+ assert config != '__is_first_stage__'
526
+ assert config != '__is_unconditional__'
527
+ model = instantiate_from_config(config)
528
+ self.cond_stage_model = model
529
+
530
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
531
+ denoise_row = []
532
+ for zd in tqdm(samples, desc=desc):
533
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
534
+ force_not_quantize=force_no_decoder_quantization))
535
+ n_imgs_per_row = len(denoise_row)
536
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
537
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
538
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
539
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
540
+ return denoise_grid
541
+
542
+ def get_first_stage_encoding(self, encoder_posterior):
543
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
544
+ z = encoder_posterior.sample()
545
+ elif isinstance(encoder_posterior, torch.Tensor):
546
+ z = encoder_posterior
547
+ else:
548
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
549
+ return self.scale_factor * z
550
+
551
+ def get_learned_conditioning(self, c):
552
+ if self.cond_stage_forward is None:
553
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
554
+ c = self.cond_stage_model.encode(c)
555
+ if isinstance(c, DiagonalGaussianDistribution):
556
+ c = c.mode()
557
+ else:
558
+ c = self.cond_stage_model(c)
559
+ else:
560
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
561
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
562
+ return c
563
+
564
+ def meshgrid(self, h, w):
565
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
566
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
567
+
568
+ arr = torch.cat([y, x], dim=-1)
569
+ return arr
570
+
571
+ def delta_border(self, h, w):
572
+ """
573
+ :param h: height
574
+ :param w: width
575
+ :return: normalized distance to image border,
576
+ wtith min distance = 0 at border and max dist = 0.5 at image center
577
+ """
578
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
579
+ arr = self.meshgrid(h, w) / lower_right_corner
580
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
581
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
582
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
583
+ return edge_dist
584
+
585
+ def get_weighting(self, h, w, Ly, Lx, device):
586
+ weighting = self.delta_border(h, w)
587
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
588
+ self.split_input_params["clip_max_weight"], )
589
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
590
+
591
+ if self.split_input_params["tie_braker"]:
592
+ L_weighting = self.delta_border(Ly, Lx)
593
+ L_weighting = torch.clip(L_weighting,
594
+ self.split_input_params["clip_min_tie_weight"],
595
+ self.split_input_params["clip_max_tie_weight"])
596
+
597
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
598
+ weighting = weighting * L_weighting
599
+ return weighting
600
+
601
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
602
+ """
603
+ :param x: img of size (bs, c, h, w)
604
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
605
+ """
606
+ bs, nc, h, w = x.shape
607
+
608
+ # number of crops in image
609
+ Ly = (h - kernel_size[0]) // stride[0] + 1
610
+ Lx = (w - kernel_size[1]) // stride[1] + 1
611
+
612
+ if uf == 1 and df == 1:
613
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
614
+ unfold = torch.nn.Unfold(**fold_params)
615
+
616
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
617
+
618
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
619
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
620
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
621
+
622
+ elif uf > 1 and df == 1:
623
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
624
+ unfold = torch.nn.Unfold(**fold_params)
625
+
626
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
627
+ dilation=1, padding=0,
628
+ stride=(stride[0] * uf, stride[1] * uf))
629
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
630
+
631
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
632
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
633
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
634
+
635
+ elif df > 1 and uf == 1:
636
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
637
+ unfold = torch.nn.Unfold(**fold_params)
638
+
639
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
640
+ dilation=1, padding=0,
641
+ stride=(stride[0] // df, stride[1] // df))
642
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
643
+
644
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
645
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
646
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
647
+
648
+ else:
649
+ raise NotImplementedError
650
+
651
+ return fold, unfold, normalization, weighting
652
+
653
+ @torch.no_grad()
654
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
655
+ cond_key=None, return_original_cond=False, bs=None):
656
+ x = super().get_input(batch, k)
657
+ if bs is not None:
658
+ x = x[:bs]
659
+ x = x.to(self.device)
660
+ encoder_posterior = self.encode_first_stage(x)
661
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
662
+
663
+ if self.model.conditioning_key is not None:
664
+ if cond_key is None:
665
+ cond_key = self.cond_stage_key
666
+ if cond_key != self.first_stage_key:
667
+ if cond_key in ['caption', 'coordinates_bbox']:
668
+ xc = batch[cond_key]
669
+ elif cond_key == 'class_label':
670
+ xc = batch
671
+ else:
672
+ xc = super().get_input(batch, cond_key).to(self.device)
673
+ else:
674
+ xc = x
675
+ if not self.cond_stage_trainable or force_c_encode:
676
+ if isinstance(xc, dict) or isinstance(xc, list):
677
+ # import pudb; pudb.set_trace()
678
+ c = self.get_learned_conditioning(xc)
679
+ else:
680
+ c = self.get_learned_conditioning(xc.to(self.device))
681
+ else:
682
+ c = xc
683
+ if bs is not None:
684
+ c = c[:bs]
685
+
686
+ if self.use_positional_encodings:
687
+ pos_x, pos_y = self.compute_latent_shifts(batch)
688
+ ckey = __conditioning_keys__[self.model.conditioning_key]
689
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
690
+
691
+ else:
692
+ c = None
693
+ xc = None
694
+ if self.use_positional_encodings:
695
+ pos_x, pos_y = self.compute_latent_shifts(batch)
696
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
697
+ out = [z, c]
698
+ if return_first_stage_outputs:
699
+ xrec = self.decode_first_stage(z)
700
+ out.extend([x, xrec])
701
+ if return_original_cond:
702
+ out.append(xc)
703
+ return out
704
+
705
+ @torch.no_grad()
706
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
707
+ if predict_cids:
708
+ if z.dim() == 4:
709
+ z = torch.argmax(z.exp(), dim=1).long()
710
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
711
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
712
+
713
+ z = 1. / self.scale_factor * z
714
+
715
+ if hasattr(self, "split_input_params"):
716
+ if self.split_input_params["patch_distributed_vq"]:
717
+ ks = self.split_input_params["ks"] # eg. (128, 128)
718
+ stride = self.split_input_params["stride"] # eg. (64, 64)
719
+ uf = self.split_input_params["vqf"]
720
+ bs, nc, h, w = z.shape
721
+ if ks[0] > h or ks[1] > w:
722
+ ks = (min(ks[0], h), min(ks[1], w))
723
+ print("reducing Kernel")
724
+
725
+ if stride[0] > h or stride[1] > w:
726
+ stride = (min(stride[0], h), min(stride[1], w))
727
+ print("reducing stride")
728
+
729
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
730
+
731
+ z = unfold(z) # (bn, nc * prod(**ks), L)
732
+ # 1. Reshape to img shape
733
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
734
+
735
+ # 2. apply model loop over last dim
736
+ if isinstance(self.first_stage_model, VQModelInterface):
737
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
738
+ force_not_quantize=predict_cids or force_not_quantize)
739
+ for i in range(z.shape[-1])]
740
+ else:
741
+
742
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
743
+ for i in range(z.shape[-1])]
744
+
745
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
746
+ o = o * weighting
747
+ # Reverse 1. reshape to img shape
748
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
749
+ # stitch crops together
750
+ decoded = fold(o)
751
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
752
+ return decoded
753
+ else:
754
+ if isinstance(self.first_stage_model, VQModelInterface):
755
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
756
+ else:
757
+ return self.first_stage_model.decode(z)
758
+
759
+ else:
760
+ if isinstance(self.first_stage_model, VQModelInterface):
761
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
762
+ else:
763
+ return self.first_stage_model.decode(z)
764
+
765
+ # same as above but without decorator
766
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
767
+ if predict_cids:
768
+ if z.dim() == 4:
769
+ z = torch.argmax(z.exp(), dim=1).long()
770
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
771
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
772
+
773
+ z = 1. / self.scale_factor * z
774
+
775
+ if hasattr(self, "split_input_params"):
776
+ if self.split_input_params["patch_distributed_vq"]:
777
+ ks = self.split_input_params["ks"] # eg. (128, 128)
778
+ stride = self.split_input_params["stride"] # eg. (64, 64)
779
+ uf = self.split_input_params["vqf"]
780
+ bs, nc, h, w = z.shape
781
+ if ks[0] > h or ks[1] > w:
782
+ ks = (min(ks[0], h), min(ks[1], w))
783
+ print("reducing Kernel")
784
+
785
+ if stride[0] > h or stride[1] > w:
786
+ stride = (min(stride[0], h), min(stride[1], w))
787
+ print("reducing stride")
788
+
789
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
790
+
791
+ z = unfold(z) # (bn, nc * prod(**ks), L)
792
+ # 1. Reshape to img shape
793
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
794
+
795
+ # 2. apply model loop over last dim
796
+ if isinstance(self.first_stage_model, VQModelInterface):
797
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
798
+ force_not_quantize=predict_cids or force_not_quantize)
799
+ for i in range(z.shape[-1])]
800
+ else:
801
+
802
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
803
+ for i in range(z.shape[-1])]
804
+
805
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
806
+ o = o * weighting
807
+ # Reverse 1. reshape to img shape
808
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
809
+ # stitch crops together
810
+ decoded = fold(o)
811
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
812
+ return decoded
813
+ else:
814
+ if isinstance(self.first_stage_model, VQModelInterface):
815
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
816
+ else:
817
+ return self.first_stage_model.decode(z)
818
+
819
+ else:
820
+ if isinstance(self.first_stage_model, VQModelInterface):
821
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
822
+ else:
823
+ return self.first_stage_model.decode(z)
824
+
825
+ @torch.no_grad()
826
+ def encode_first_stage(self, x):
827
+ if hasattr(self, "split_input_params"):
828
+ if self.split_input_params["patch_distributed_vq"]:
829
+ ks = self.split_input_params["ks"] # eg. (128, 128)
830
+ stride = self.split_input_params["stride"] # eg. (64, 64)
831
+ df = self.split_input_params["vqf"]
832
+ self.split_input_params['original_image_size'] = x.shape[-2:]
833
+ bs, nc, h, w = x.shape
834
+ if ks[0] > h or ks[1] > w:
835
+ ks = (min(ks[0], h), min(ks[1], w))
836
+ print("reducing Kernel")
837
+
838
+ if stride[0] > h or stride[1] > w:
839
+ stride = (min(stride[0], h), min(stride[1], w))
840
+ print("reducing stride")
841
+
842
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
843
+ z = unfold(x) # (bn, nc * prod(**ks), L)
844
+ # Reshape to img shape
845
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
846
+
847
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
848
+ for i in range(z.shape[-1])]
849
+
850
+ o = torch.stack(output_list, axis=-1)
851
+ o = o * weighting
852
+
853
+ # Reverse reshape to img shape
854
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
855
+ # stitch crops together
856
+ decoded = fold(o)
857
+ decoded = decoded / normalization
858
+ return decoded
859
+
860
+ else:
861
+ return self.first_stage_model.encode(x)
862
+ else:
863
+ return self.first_stage_model.encode(x)
864
+
865
+ def shared_step(self, batch, **kwargs):
866
+ x, c = self.get_input(batch, self.first_stage_key)
867
+ loss = self(x, c)
868
+ return loss
869
+
870
+ def forward(self, x, c, *args, **kwargs):
871
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
872
+ if self.model.conditioning_key is not None:
873
+ assert c is not None
874
+ if self.cond_stage_trainable:
875
+ c = self.get_learned_conditioning(c)
876
+ if self.shorten_cond_schedule: # TODO: drop this option
877
+ tc = self.cond_ids[t].to(self.device)
878
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
879
+ return self.p_losses(x, c, t, *args, **kwargs)
880
+
881
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
882
+ def rescale_bbox(bbox):
883
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
884
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
885
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
886
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
887
+ return x0, y0, w, h
888
+
889
+ return [rescale_bbox(b) for b in bboxes]
890
+
891
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
892
+
893
+ if isinstance(cond, dict):
894
+ # hybrid case, cond is exptected to be a dict
895
+ pass
896
+ else:
897
+ if not isinstance(cond, list):
898
+ cond = [cond]
899
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
900
+ cond = {key: cond}
901
+
902
+ if hasattr(self, "split_input_params"):
903
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
904
+ assert not return_ids
905
+ ks = self.split_input_params["ks"] # eg. (128, 128)
906
+ stride = self.split_input_params["stride"] # eg. (64, 64)
907
+
908
+ h, w = x_noisy.shape[-2:]
909
+
910
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
911
+
912
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
913
+ # Reshape to img shape
914
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
915
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
916
+
917
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
918
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
919
+ c_key = next(iter(cond.keys())) # get key
920
+ c = next(iter(cond.values())) # get value
921
+ assert (len(c) == 1) # todo extend to list with more than one elem
922
+ c = c[0] # get element
923
+
924
+ c = unfold(c)
925
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
926
+
927
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
928
+
929
+ elif self.cond_stage_key == 'coordinates_bbox':
930
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
931
+
932
+ # assuming padding of unfold is always 0 and its dilation is always 1
933
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
934
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
935
+ # as we are operating on latents, we need the factor from the original image size to the
936
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
937
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
938
+ rescale_latent = 2 ** (num_downs)
939
+
940
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
941
+ # need to rescale the tl patch coordinates to be in between (0,1)
942
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
943
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
944
+ for patch_nr in range(z.shape[-1])]
945
+
946
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
947
+ patch_limits = [(x_tl, y_tl,
948
+ rescale_latent * ks[0] / full_img_w,
949
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
950
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
951
+
952
+ # tokenize crop coordinates for the bounding boxes of the respective patches
953
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
954
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
955
+ print(patch_limits_tknzd[0].shape)
956
+ # cut tknzd crop position from conditioning
957
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
958
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
959
+ print(cut_cond.shape)
960
+
961
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
962
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
963
+ print(adapted_cond.shape)
964
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
965
+ print(adapted_cond.shape)
966
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
967
+ print(adapted_cond.shape)
968
+
969
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
970
+
971
+ else:
972
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
973
+
974
+ # apply model by loop over crops
975
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
976
+ assert not isinstance(output_list[0],
977
+ tuple) # todo cant deal with multiple model outputs check this never happens
978
+
979
+ o = torch.stack(output_list, axis=-1)
980
+ o = o * weighting
981
+ # Reverse reshape to img shape
982
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
983
+ # stitch crops together
984
+ x_recon = fold(o) / normalization
985
+
986
+ else:
987
+ x_recon = self.model(x_noisy, t, **cond)
988
+
989
+ if isinstance(x_recon, tuple) and not return_ids:
990
+ return x_recon[0]
991
+ else:
992
+ return x_recon
993
+
994
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
995
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
996
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
997
+
998
+ def _prior_bpd(self, x_start):
999
+ """
1000
+ Get the prior KL term for the variational lower-bound, measured in
1001
+ bits-per-dim.
1002
+ This term can't be optimized, as it only depends on the encoder.
1003
+ :param x_start: the [N x C x ...] tensor of inputs.
1004
+ :return: a batch of [N] KL values (in bits), one per batch element.
1005
+ """
1006
+ batch_size = x_start.shape[0]
1007
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1008
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1009
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1010
+ return mean_flat(kl_prior) / np.log(2.0)
1011
+
1012
+ def p_losses(self, x_start, cond, t, noise=None):
1013
+ noise = default(noise, lambda: torch.randn_like(x_start))
1014
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1015
+ model_output = self.apply_model(x_noisy, t, cond)
1016
+
1017
+ loss_dict = {}
1018
+ prefix = 'train' if self.training else 'val'
1019
+
1020
+ if self.parameterization == "x0":
1021
+ target = x_start
1022
+ elif self.parameterization == "eps":
1023
+ target = noise
1024
+ else:
1025
+ raise NotImplementedError()
1026
+
1027
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1028
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1029
+
1030
+ logvar_t = self.logvar[t].to(self.device)
1031
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1032
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1033
+ if self.learn_logvar:
1034
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1035
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1036
+
1037
+ loss = self.l_simple_weight * loss.mean()
1038
+
1039
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1040
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1041
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1042
+ loss += (self.original_elbo_weight * loss_vlb)
1043
+ loss_dict.update({f'{prefix}/loss': loss})
1044
+
1045
+ return loss, loss_dict
1046
+
1047
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1048
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1049
+ t_in = t
1050
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1051
+
1052
+ if score_corrector is not None:
1053
+ assert self.parameterization == "eps"
1054
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1055
+
1056
+ if return_codebook_ids:
1057
+ model_out, logits = model_out
1058
+
1059
+ if self.parameterization == "eps":
1060
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1061
+ elif self.parameterization == "x0":
1062
+ x_recon = model_out
1063
+ else:
1064
+ raise NotImplementedError()
1065
+
1066
+ if clip_denoised:
1067
+ x_recon.clamp_(-1., 1.)
1068
+ if quantize_denoised:
1069
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1070
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1071
+ if return_codebook_ids:
1072
+ return model_mean, posterior_variance, posterior_log_variance, logits
1073
+ elif return_x0:
1074
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1075
+ else:
1076
+ return model_mean, posterior_variance, posterior_log_variance
1077
+
1078
+ @torch.no_grad()
1079
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1080
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1081
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1082
+ b, *_, device = *x.shape, x.device
1083
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1084
+ return_codebook_ids=return_codebook_ids,
1085
+ quantize_denoised=quantize_denoised,
1086
+ return_x0=return_x0,
1087
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1088
+ if return_codebook_ids:
1089
+ raise DeprecationWarning("Support dropped.")
1090
+ model_mean, _, model_log_variance, logits = outputs
1091
+ elif return_x0:
1092
+ model_mean, _, model_log_variance, x0 = outputs
1093
+ else:
1094
+ model_mean, _, model_log_variance = outputs
1095
+
1096
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1097
+ if noise_dropout > 0.:
1098
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1099
+ # no noise when t == 0
1100
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1101
+
1102
+ if return_codebook_ids:
1103
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1104
+ if return_x0:
1105
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1106
+ else:
1107
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1108
+
1109
+ @torch.no_grad()
1110
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1111
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1112
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1113
+ log_every_t=None):
1114
+ if not log_every_t:
1115
+ log_every_t = self.log_every_t
1116
+ timesteps = self.num_timesteps
1117
+ if batch_size is not None:
1118
+ b = batch_size if batch_size is not None else shape[0]
1119
+ shape = [batch_size] + list(shape)
1120
+ else:
1121
+ b = batch_size = shape[0]
1122
+ if x_T is None:
1123
+ img = torch.randn(shape, device=self.device)
1124
+ else:
1125
+ img = x_T
1126
+ intermediates = []
1127
+ if cond is not None:
1128
+ if isinstance(cond, dict):
1129
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1130
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1131
+ else:
1132
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1133
+
1134
+ if start_T is not None:
1135
+ timesteps = min(timesteps, start_T)
1136
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1137
+ total=timesteps) if verbose else reversed(
1138
+ range(0, timesteps))
1139
+ if type(temperature) == float:
1140
+ temperature = [temperature] * timesteps
1141
+
1142
+ for i in iterator:
1143
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1144
+ if self.shorten_cond_schedule:
1145
+ assert self.model.conditioning_key != 'hybrid'
1146
+ tc = self.cond_ids[ts].to(cond.device)
1147
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1148
+
1149
+ img, x0_partial = self.p_sample(img, cond, ts,
1150
+ clip_denoised=self.clip_denoised,
1151
+ quantize_denoised=quantize_denoised, return_x0=True,
1152
+ temperature=temperature[i], noise_dropout=noise_dropout,
1153
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1154
+ if mask is not None:
1155
+ assert x0 is not None
1156
+ img_orig = self.q_sample(x0, ts)
1157
+ img = img_orig * mask + (1. - mask) * img
1158
+
1159
+ if i % log_every_t == 0 or i == timesteps - 1:
1160
+ intermediates.append(x0_partial)
1161
+ if callback: callback(i)
1162
+ if img_callback: img_callback(img, i)
1163
+ return img, intermediates
1164
+
1165
+ @torch.no_grad()
1166
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1167
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1168
+ mask=None, x0=None, img_callback=None, start_T=None,
1169
+ log_every_t=None):
1170
+
1171
+ if not log_every_t:
1172
+ log_every_t = self.log_every_t
1173
+ device = self.betas.device
1174
+ b = shape[0]
1175
+ if x_T is None:
1176
+ img = torch.randn(shape, device=device)
1177
+ else:
1178
+ img = x_T
1179
+
1180
+ intermediates = [img]
1181
+ if timesteps is None:
1182
+ timesteps = self.num_timesteps
1183
+
1184
+ if start_T is not None:
1185
+ timesteps = min(timesteps, start_T)
1186
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1187
+ range(0, timesteps))
1188
+
1189
+ if mask is not None:
1190
+ assert x0 is not None
1191
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1192
+
1193
+ for i in iterator:
1194
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1195
+ if self.shorten_cond_schedule:
1196
+ assert self.model.conditioning_key != 'hybrid'
1197
+ tc = self.cond_ids[ts].to(cond.device)
1198
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1199
+
1200
+ img = self.p_sample(img, cond, ts,
1201
+ clip_denoised=self.clip_denoised,
1202
+ quantize_denoised=quantize_denoised)
1203
+ if mask is not None:
1204
+ img_orig = self.q_sample(x0, ts)
1205
+ img = img_orig * mask + (1. - mask) * img
1206
+
1207
+ if i % log_every_t == 0 or i == timesteps - 1:
1208
+ intermediates.append(img)
1209
+ if callback: callback(i)
1210
+ if img_callback: img_callback(img, i)
1211
+
1212
+ if return_intermediates:
1213
+ return img, intermediates
1214
+ return img
1215
+
1216
+ @torch.no_grad()
1217
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1218
+ verbose=True, timesteps=None, quantize_denoised=False,
1219
+ mask=None, x0=None, shape=None,**kwargs):
1220
+ if shape is None:
1221
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1222
+ if cond is not None:
1223
+ if isinstance(cond, dict):
1224
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1225
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1226
+ else:
1227
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1228
+ return self.p_sample_loop(cond,
1229
+ shape,
1230
+ return_intermediates=return_intermediates, x_T=x_T,
1231
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1232
+ mask=mask, x0=x0)
1233
+
1234
+ @torch.no_grad()
1235
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1236
+
1237
+ if ddim:
1238
+ ddim_sampler = DDIMSampler(self)
1239
+ shape = (self.channels, self.image_size, self.image_size)
1240
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1241
+ shape,cond,verbose=False,**kwargs)
1242
+
1243
+ else:
1244
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1245
+ return_intermediates=True,**kwargs)
1246
+
1247
+ return samples, intermediates
1248
+
1249
+
1250
+ @torch.no_grad()
1251
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1252
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1253
+ plot_diffusion_rows=True, **kwargs):
1254
+
1255
+ use_ddim = ddim_steps is not None
1256
+
1257
+ log = dict()
1258
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1259
+ return_first_stage_outputs=True,
1260
+ force_c_encode=True,
1261
+ return_original_cond=True,
1262
+ bs=N)
1263
+ N = min(x.shape[0], N)
1264
+ n_row = min(x.shape[0], n_row)
1265
+ log["inputs"] = x
1266
+ log["reconstruction"] = xrec
1267
+ if self.model.conditioning_key is not None:
1268
+ if hasattr(self.cond_stage_model, "decode"):
1269
+ xc = self.cond_stage_model.decode(c)
1270
+ log["conditioning"] = xc
1271
+ elif self.cond_stage_key in ["caption"]:
1272
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1273
+ log["conditioning"] = xc
1274
+ elif self.cond_stage_key == 'class_label':
1275
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1276
+ log['conditioning'] = xc
1277
+ elif isimage(xc):
1278
+ log["conditioning"] = xc
1279
+ if ismap(xc):
1280
+ log["original_conditioning"] = self.to_rgb(xc)
1281
+
1282
+ if plot_diffusion_rows:
1283
+ # get diffusion row
1284
+ diffusion_row = list()
1285
+ z_start = z[:n_row]
1286
+ for t in range(self.num_timesteps):
1287
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1288
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1289
+ t = t.to(self.device).long()
1290
+ noise = torch.randn_like(z_start)
1291
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1292
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1293
+
1294
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1295
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1296
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1297
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1298
+ log["diffusion_row"] = diffusion_grid
1299
+
1300
+ if sample:
1301
+ # get denoise row
1302
+ with self.ema_scope("Plotting"):
1303
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1304
+ ddim_steps=ddim_steps,eta=ddim_eta)
1305
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1306
+ x_samples = self.decode_first_stage(samples)
1307
+ log["samples"] = x_samples
1308
+ if plot_denoise_rows:
1309
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1310
+ log["denoise_row"] = denoise_grid
1311
+
1312
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1313
+ self.first_stage_model, IdentityFirstStage):
1314
+ # also display when quantizing x0 while sampling
1315
+ with self.ema_scope("Plotting Quantized Denoised"):
1316
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1317
+ ddim_steps=ddim_steps,eta=ddim_eta,
1318
+ quantize_denoised=True)
1319
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1320
+ # quantize_denoised=True)
1321
+ x_samples = self.decode_first_stage(samples.to(self.device))
1322
+ log["samples_x0_quantized"] = x_samples
1323
+
1324
+ if inpaint:
1325
+ # make a simple center square
1326
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1327
+ mask = torch.ones(N, h, w).to(self.device)
1328
+ # zeros will be filled in
1329
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1330
+ mask = mask[:, None, ...]
1331
+ with self.ema_scope("Plotting Inpaint"):
1332
+
1333
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1334
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1335
+ x_samples = self.decode_first_stage(samples.to(self.device))
1336
+ log["samples_inpainting"] = x_samples
1337
+ log["mask"] = mask
1338
+
1339
+ # outpaint
1340
+ with self.ema_scope("Plotting Outpaint"):
1341
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1342
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1343
+ x_samples = self.decode_first_stage(samples.to(self.device))
1344
+ log["samples_outpainting"] = x_samples
1345
+
1346
+ if plot_progressive_rows:
1347
+ with self.ema_scope("Plotting Progressives"):
1348
+ img, progressives = self.progressive_denoising(c,
1349
+ shape=(self.channels, self.image_size, self.image_size),
1350
+ batch_size=N)
1351
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1352
+ log["progressive_row"] = prog_row
1353
+
1354
+ if return_keys:
1355
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1356
+ return log
1357
+ else:
1358
+ return {key: log[key] for key in return_keys}
1359
+ return log
1360
+
1361
+ def configure_optimizers(self):
1362
+ lr = self.learning_rate
1363
+ params = list(self.model.parameters())
1364
+ if self.cond_stage_trainable:
1365
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1366
+ params = params + list(self.cond_stage_model.parameters())
1367
+ if self.learn_logvar:
1368
+ print('Diffusion model optimizing logvar')
1369
+ params.append(self.logvar)
1370
+ opt = torch.optim.AdamW(params, lr=lr)
1371
+ if self.use_scheduler:
1372
+ assert 'target' in self.scheduler_config
1373
+ scheduler = instantiate_from_config(self.scheduler_config)
1374
+
1375
+ print("Setting up LambdaLR scheduler...")
1376
+ scheduler = [
1377
+ {
1378
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1379
+ 'interval': 'step',
1380
+ 'frequency': 1
1381
+ }]
1382
+ return [opt], scheduler
1383
+ return opt
1384
+
1385
+ @torch.no_grad()
1386
+ def to_rgb(self, x):
1387
+ x = x.float()
1388
+ if not hasattr(self, "colorize"):
1389
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1390
+ x = nn.functional.conv2d(x, weight=self.colorize)
1391
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1392
+ return x
1393
+
1394
+
1395
+ class DiffusionWrapper(pl.LightningModule):
1396
+ def __init__(self, diff_model_config, conditioning_key):
1397
+ super().__init__()
1398
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1399
+ self.conditioning_key = conditioning_key
1400
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
1401
+
1402
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
1403
+ if self.conditioning_key is None:
1404
+ out = self.diffusion_model(x, t)
1405
+ elif self.conditioning_key == 'concat':
1406
+ xc = torch.cat([x] + c_concat, dim=1)
1407
+ out = self.diffusion_model(xc, t)
1408
+ elif self.conditioning_key == 'crossattn':
1409
+ cc = torch.cat(c_crossattn, 1)
1410
+ out = self.diffusion_model(x, t, context=cc)
1411
+ elif self.conditioning_key == 'hybrid':
1412
+ xc = torch.cat([x] + c_concat, dim=1)
1413
+ cc = torch.cat(c_crossattn, 1)
1414
+ out = self.diffusion_model(xc, t, context=cc)
1415
+ elif self.conditioning_key == 'adm':
1416
+ cc = c_crossattn[0]
1417
+ out = self.diffusion_model(x, t, y=cc)
1418
+ else:
1419
+ raise NotImplementedError()
1420
+
1421
+ return out
1422
+
1423
+
1424
+ class Layout2ImgDiffusion(LatentDiffusion):
1425
+ # TODO: move all layout-specific hacks to this class
1426
+ def __init__(self, cond_stage_key, *args, **kwargs):
1427
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1428
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1429
+
1430
+ def log_images(self, batch, N=8, *args, **kwargs):
1431
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1432
+
1433
+ key = 'train' if self.training else 'validation'
1434
+ dset = self.trainer.datamodule.datasets[key]
1435
+ mapper = dset.conditional_builders[self.cond_stage_key]
1436
+
1437
+ bbox_imgs = []
1438
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1439
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1440
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1441
+ bbox_imgs.append(bboximg)
1442
+
1443
+ cond_img = torch.stack(bbox_imgs, dim=0)
1444
+ logs['bbox_image'] = cond_img
1445
+ return logs
latent-diffusion/ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def plms_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
138
+
139
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140
+ old_eps = []
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning,
158
+ old_eps=old_eps, t_next=ts_next)
159
+ img, pred_x0, e_t = outs
160
+ old_eps.append(e_t)
161
+ if len(old_eps) >= 4:
162
+ old_eps.pop(0)
163
+ if callback: callback(i)
164
+ if img_callback: img_callback(pred_x0, i)
165
+
166
+ if index % log_every_t == 0 or index == total_steps - 1:
167
+ intermediates['x_inter'].append(img)
168
+ intermediates['pred_x0'].append(pred_x0)
169
+
170
+ return img, intermediates
171
+
172
+ @torch.no_grad()
173
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176
+ b, *_, device = *x.shape, x.device
177
+
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+
218
+ e_t = get_model_output(x, t)
219
+ if len(old_eps) == 0:
220
+ # Pseudo Improved Euler (2nd order)
221
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222
+ e_t_next = get_model_output(x_prev, t_next)
223
+ e_t_prime = (e_t + e_t_next) / 2
224
+ elif len(old_eps) == 1:
225
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
227
+ elif len(old_eps) == 2:
228
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230
+ elif len(old_eps) >= 3:
231
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233
+
234
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235
+
236
+ return x_prev, pred_x0, e_t
latent-diffusion/ldm/modules/__pycache__/attention.cpython-39.pyc ADDED
Binary file (8.8 kB). View file
latent-diffusion/ldm/modules/__pycache__/ema.cpython-39.pyc ADDED
Binary file (3.01 kB). View file
latent-diffusion/ldm/modules/__pycache__/x_transformer.cpython-39.pyc ADDED
Binary file (18.3 kB). View file
latent-diffusion/ldm/modules/attention.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return{el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)
174
+ context = default(context, x)
175
+ k = self.to_k(context)
176
+ v = self.to_v(context)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ if exists(mask):
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198
+ super().__init__()
199
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203
+ self.norm1 = nn.LayerNorm(dim)
204
+ self.norm2 = nn.LayerNorm(dim)
205
+ self.norm3 = nn.LayerNorm(dim)
206
+ self.checkpoint = checkpoint
207
+
208
+ def forward(self, x, context=None):
209
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210
+
211
+ def _forward(self, x, context=None):
212
+ x = self.attn1(self.norm1(x)) + x
213
+ x = self.attn2(self.norm2(x), context=context) + x
214
+ x = self.ff(self.norm3(x)) + x
215
+ return x
216
+
217
+
218
+ class SpatialTransformer(nn.Module):
219
+ """
220
+ Transformer block for image-like data.
221
+ First, project the input (aka embedding)
222
+ and reshape to b, t, d.
223
+ Then apply standard transformer action.
224
+ Finally, reshape to image
225
+ """
226
+ def __init__(self, in_channels, n_heads, d_head,
227
+ depth=1, dropout=0., context_dim=None):
228
+ super().__init__()
229
+ self.in_channels = in_channels
230
+ inner_dim = n_heads * d_head
231
+ self.norm = Normalize(in_channels)
232
+
233
+ self.proj_in = nn.Conv2d(in_channels,
234
+ inner_dim,
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0)
238
+
239
+ self.transformer_blocks = nn.ModuleList(
240
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241
+ for d in range(depth)]
242
+ )
243
+
244
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
245
+ in_channels,
246
+ kernel_size=1,
247
+ stride=1,
248
+ padding=0))
249
+
250
+ def forward(self, x, context=None):
251
+ # note: if no context is given, cross-attention defaults to self-attention
252
+ b, c, h, w = x.shape
253
+ x_in = x
254
+ x = self.norm(x)
255
+ x = self.proj_in(x)
256
+ x = rearrange(x, 'b c h w -> b (h w) c')
257
+ for block in self.transformer_blocks:
258
+ x = block(x, context=context)
259
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260
+ x = self.proj_out(x)
261
+ return x + x_in
latent-diffusion/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
latent-diffusion/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (201 Bytes). View file
latent-diffusion/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc ADDED
Binary file (20.6 kB). View file
latent-diffusion/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc ADDED
Binary file (22.7 kB). View file
latent-diffusion/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc ADDED
Binary file (9.55 kB). View file
latent-diffusion/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from ldm.util import instantiate_from_config
9
+ from ldm.modules.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x*torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
40
+
41
+
42
+ class Upsample(nn.Module):
43
+ def __init__(self, in_channels, with_conv):
44
+ super().__init__()
45
+ self.with_conv = with_conv
46
+ if self.with_conv:
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=1,
51
+ padding=1)
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class Downsample(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ # no asymmetric padding in torch conv, must do it ourselves
66
+ self.conv = torch.nn.Conv2d(in_channels,
67
+ in_channels,
68
+ kernel_size=3,
69
+ stride=2,
70
+ padding=0)
71
+
72
+ def forward(self, x):
73
+ if self.with_conv:
74
+ pad = (0,1,0,1)
75
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
76
+ x = self.conv(x)
77
+ else:
78
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
79
+ return x
80
+
81
+
82
+ class ResnetBlock(nn.Module):
83
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
84
+ dropout, temb_channels=512):
85
+ super().__init__()
86
+ self.in_channels = in_channels
87
+ out_channels = in_channels if out_channels is None else out_channels
88
+ self.out_channels = out_channels
89
+ self.use_conv_shortcut = conv_shortcut
90
+
91
+ self.norm1 = Normalize(in_channels)
92
+ self.conv1 = torch.nn.Conv2d(in_channels,
93
+ out_channels,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1)
97
+ if temb_channels > 0:
98
+ self.temb_proj = torch.nn.Linear(temb_channels,
99
+ out_channels)
100
+ self.norm2 = Normalize(out_channels)
101
+ self.dropout = torch.nn.Dropout(dropout)
102
+ self.conv2 = torch.nn.Conv2d(out_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ if self.in_channels != self.out_channels:
108
+ if self.use_conv_shortcut:
109
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ else:
115
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
116
+ out_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+
121
+ def forward(self, x, temb):
122
+ h = x
123
+ h = self.norm1(h)
124
+ h = nonlinearity(h)
125
+ h = self.conv1(h)
126
+
127
+ if temb is not None:
128
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
129
+
130
+ h = self.norm2(h)
131
+ h = nonlinearity(h)
132
+ h = self.dropout(h)
133
+ h = self.conv2(h)
134
+
135
+ if self.in_channels != self.out_channels:
136
+ if self.use_conv_shortcut:
137
+ x = self.conv_shortcut(x)
138
+ else:
139
+ x = self.nin_shortcut(x)
140
+
141
+ return x+h
142
+
143
+
144
+ class LinAttnBlock(LinearAttention):
145
+ """to match AttnBlock usage"""
146
+ def __init__(self, in_channels):
147
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
148
+
149
+
150
+ class AttnBlock(nn.Module):
151
+ def __init__(self, in_channels):
152
+ super().__init__()
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = Normalize(in_channels)
156
+ self.q = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.k = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.v = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+ self.proj_out = torch.nn.Conv2d(in_channels,
172
+ in_channels,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0)
176
+
177
+
178
+ def forward(self, x):
179
+ h_ = x
180
+ h_ = self.norm(h_)
181
+ q = self.q(h_)
182
+ k = self.k(h_)
183
+ v = self.v(h_)
184
+
185
+ # compute attention
186
+ b,c,h,w = q.shape
187
+ q = q.reshape(b,c,h*w)
188
+ q = q.permute(0,2,1) # b,hw,c
189
+ k = k.reshape(b,c,h*w) # b,c,hw
190
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
191
+ w_ = w_ * (int(c)**(-0.5))
192
+ w_ = torch.nn.functional.softmax(w_, dim=2)
193
+
194
+ # attend to values
195
+ v = v.reshape(b,c,h*w)
196
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
197
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
198
+ h_ = h_.reshape(b,c,h,w)
199
+
200
+ h_ = self.proj_out(h_)
201
+
202
+ return x+h_
203
+
204
+
205
+ def make_attn(in_channels, attn_type="vanilla"):
206
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
207
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
208
+ if attn_type == "vanilla":
209
+ return AttnBlock(in_channels)
210
+ elif attn_type == "none":
211
+ return nn.Identity(in_channels)
212
+ else:
213
+ return LinAttnBlock(in_channels)
214
+
215
+
216
+ class Model(nn.Module):
217
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
218
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
219
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
220
+ super().__init__()
221
+ if use_linear_attn: attn_type = "linear"
222
+ self.ch = ch
223
+ self.temb_ch = self.ch*4
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.in_channels = in_channels
228
+
229
+ self.use_timestep = use_timestep
230
+ if self.use_timestep:
231
+ # timestep embedding
232
+ self.temb = nn.Module()
233
+ self.temb.dense = nn.ModuleList([
234
+ torch.nn.Linear(self.ch,
235
+ self.temb_ch),
236
+ torch.nn.Linear(self.temb_ch,
237
+ self.temb_ch),
238
+ ])
239
+
240
+ # downsampling
241
+ self.conv_in = torch.nn.Conv2d(in_channels,
242
+ self.ch,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ curr_res = resolution
248
+ in_ch_mult = (1,)+tuple(ch_mult)
249
+ self.down = nn.ModuleList()
250
+ for i_level in range(self.num_resolutions):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_in = ch*in_ch_mult[i_level]
254
+ block_out = ch*ch_mult[i_level]
255
+ for i_block in range(self.num_res_blocks):
256
+ block.append(ResnetBlock(in_channels=block_in,
257
+ out_channels=block_out,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout))
260
+ block_in = block_out
261
+ if curr_res in attn_resolutions:
262
+ attn.append(make_attn(block_in, attn_type=attn_type))
263
+ down = nn.Module()
264
+ down.block = block
265
+ down.attn = attn
266
+ if i_level != self.num_resolutions-1:
267
+ down.downsample = Downsample(block_in, resamp_with_conv)
268
+ curr_res = curr_res // 2
269
+ self.down.append(down)
270
+
271
+ # middle
272
+ self.mid = nn.Module()
273
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
274
+ out_channels=block_in,
275
+ temb_channels=self.temb_ch,
276
+ dropout=dropout)
277
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
278
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
279
+ out_channels=block_in,
280
+ temb_channels=self.temb_ch,
281
+ dropout=dropout)
282
+
283
+ # upsampling
284
+ self.up = nn.ModuleList()
285
+ for i_level in reversed(range(self.num_resolutions)):
286
+ block = nn.ModuleList()
287
+ attn = nn.ModuleList()
288
+ block_out = ch*ch_mult[i_level]
289
+ skip_in = ch*ch_mult[i_level]
290
+ for i_block in range(self.num_res_blocks+1):
291
+ if i_block == self.num_res_blocks:
292
+ skip_in = ch*in_ch_mult[i_level]
293
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
294
+ out_channels=block_out,
295
+ temb_channels=self.temb_ch,
296
+ dropout=dropout))
297
+ block_in = block_out
298
+ if curr_res in attn_resolutions:
299
+ attn.append(make_attn(block_in, attn_type=attn_type))
300
+ up = nn.Module()
301
+ up.block = block
302
+ up.attn = attn
303
+ if i_level != 0:
304
+ up.upsample = Upsample(block_in, resamp_with_conv)
305
+ curr_res = curr_res * 2
306
+ self.up.insert(0, up) # prepend to get consistent order
307
+
308
+ # end
309
+ self.norm_out = Normalize(block_in)
310
+ self.conv_out = torch.nn.Conv2d(block_in,
311
+ out_ch,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ def forward(self, x, t=None, context=None):
317
+ #assert x.shape[2] == x.shape[3] == self.resolution
318
+ if context is not None:
319
+ # assume aligned context, cat along channel axis
320
+ x = torch.cat((x, context), dim=1)
321
+ if self.use_timestep:
322
+ # timestep embedding
323
+ assert t is not None
324
+ temb = get_timestep_embedding(t, self.ch)
325
+ temb = self.temb.dense[0](temb)
326
+ temb = nonlinearity(temb)
327
+ temb = self.temb.dense[1](temb)
328
+ else:
329
+ temb = None
330
+
331
+ # downsampling
332
+ hs = [self.conv_in(x)]
333
+ for i_level in range(self.num_resolutions):
334
+ for i_block in range(self.num_res_blocks):
335
+ h = self.down[i_level].block[i_block](hs[-1], temb)
336
+ if len(self.down[i_level].attn) > 0:
337
+ h = self.down[i_level].attn[i_block](h)
338
+ hs.append(h)
339
+ if i_level != self.num_resolutions-1:
340
+ hs.append(self.down[i_level].downsample(hs[-1]))
341
+
342
+ # middle
343
+ h = hs[-1]
344
+ h = self.mid.block_1(h, temb)
345
+ h = self.mid.attn_1(h)
346
+ h = self.mid.block_2(h, temb)
347
+
348
+ # upsampling
349
+ for i_level in reversed(range(self.num_resolutions)):
350
+ for i_block in range(self.num_res_blocks+1):
351
+ h = self.up[i_level].block[i_block](
352
+ torch.cat([h, hs.pop()], dim=1), temb)
353
+ if len(self.up[i_level].attn) > 0:
354
+ h = self.up[i_level].attn[i_block](h)
355
+ if i_level != 0:
356
+ h = self.up[i_level].upsample(h)
357
+
358
+ # end
359
+ h = self.norm_out(h)
360
+ h = nonlinearity(h)
361
+ h = self.conv_out(h)
362
+ return h
363
+
364
+ def get_last_layer(self):
365
+ return self.conv_out.weight
366
+
367
+
368
+ class Encoder(nn.Module):
369
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
370
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
371
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
372
+ **ignore_kwargs):
373
+ super().__init__()
374
+ if use_linear_attn: attn_type = "linear"
375
+ self.ch = ch
376
+ self.temb_ch = 0
377
+ self.num_resolutions = len(ch_mult)
378
+ self.num_res_blocks = num_res_blocks
379
+ self.resolution = resolution
380
+ self.in_channels = in_channels
381
+
382
+ # downsampling
383
+ self.conv_in = torch.nn.Conv2d(in_channels,
384
+ self.ch,
385
+ kernel_size=3,
386
+ stride=1,
387
+ padding=1)
388
+
389
+ curr_res = resolution
390
+ in_ch_mult = (1,)+tuple(ch_mult)
391
+ self.in_ch_mult = in_ch_mult
392
+ self.down = nn.ModuleList()
393
+ for i_level in range(self.num_resolutions):
394
+ block = nn.ModuleList()
395
+ attn = nn.ModuleList()
396
+ block_in = ch*in_ch_mult[i_level]
397
+ block_out = ch*ch_mult[i_level]
398
+ for i_block in range(self.num_res_blocks):
399
+ block.append(ResnetBlock(in_channels=block_in,
400
+ out_channels=block_out,
401
+ temb_channels=self.temb_ch,
402
+ dropout=dropout))
403
+ block_in = block_out
404
+ if curr_res in attn_resolutions:
405
+ attn.append(make_attn(block_in, attn_type=attn_type))
406
+ down = nn.Module()
407
+ down.block = block
408
+ down.attn = attn
409
+ if i_level != self.num_resolutions-1:
410
+ down.downsample = Downsample(block_in, resamp_with_conv)
411
+ curr_res = curr_res // 2
412
+ self.down.append(down)
413
+
414
+ # middle
415
+ self.mid = nn.Module()
416
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
417
+ out_channels=block_in,
418
+ temb_channels=self.temb_ch,
419
+ dropout=dropout)
420
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
421
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
422
+ out_channels=block_in,
423
+ temb_channels=self.temb_ch,
424
+ dropout=dropout)
425
+
426
+ # end
427
+ self.norm_out = Normalize(block_in)
428
+ self.conv_out = torch.nn.Conv2d(block_in,
429
+ 2*z_channels if double_z else z_channels,
430
+ kernel_size=3,
431
+ stride=1,
432
+ padding=1)
433
+
434
+ def forward(self, x):
435
+ # timestep embedding
436
+ temb = None
437
+
438
+ # downsampling
439
+ hs = [self.conv_in(x)]
440
+ for i_level in range(self.num_resolutions):
441
+ for i_block in range(self.num_res_blocks):
442
+ h = self.down[i_level].block[i_block](hs[-1], temb)
443
+ if len(self.down[i_level].attn) > 0:
444
+ h = self.down[i_level].attn[i_block](h)
445
+ hs.append(h)
446
+ if i_level != self.num_resolutions-1:
447
+ hs.append(self.down[i_level].downsample(hs[-1]))
448
+
449
+ # middle
450
+ h = hs[-1]
451
+ h = self.mid.block_1(h, temb)
452
+ h = self.mid.attn_1(h)
453
+ h = self.mid.block_2(h, temb)
454
+
455
+ # end
456
+ h = self.norm_out(h)
457
+ h = nonlinearity(h)
458
+ h = self.conv_out(h)
459
+ return h
460
+
461
+
462
+ class Decoder(nn.Module):
463
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
464
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
465
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
466
+ attn_type="vanilla", **ignorekwargs):
467
+ super().__init__()
468
+ if use_linear_attn: attn_type = "linear"
469
+ self.ch = ch
470
+ self.temb_ch = 0
471
+ self.num_resolutions = len(ch_mult)
472
+ self.num_res_blocks = num_res_blocks
473
+ self.resolution = resolution
474
+ self.in_channels = in_channels
475
+ self.give_pre_end = give_pre_end
476
+ self.tanh_out = tanh_out
477
+
478
+ # compute in_ch_mult, block_in and curr_res at lowest res
479
+ in_ch_mult = (1,)+tuple(ch_mult)
480
+ block_in = ch*ch_mult[self.num_resolutions-1]
481
+ curr_res = resolution // 2**(self.num_resolutions-1)
482
+ self.z_shape = (1,z_channels,curr_res,curr_res)
483
+ print("Working with z of shape {} = {} dimensions.".format(
484
+ self.z_shape, np.prod(self.z_shape)))
485
+
486
+ # z to block_in
487
+ self.conv_in = torch.nn.Conv2d(z_channels,
488
+ block_in,
489
+ kernel_size=3,
490
+ stride=1,
491
+ padding=1)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
496
+ out_channels=block_in,
497
+ temb_channels=self.temb_ch,
498
+ dropout=dropout)
499
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
500
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch*ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks+1):
512
+ block.append(ResnetBlock(in_channels=block_in,
513
+ out_channels=block_out,
514
+ temb_channels=self.temb_ch,
515
+ dropout=dropout))
516
+ block_in = block_out
517
+ if curr_res in attn_resolutions:
518
+ attn.append(make_attn(block_in, attn_type=attn_type))
519
+ up = nn.Module()
520
+ up.block = block
521
+ up.attn = attn
522
+ if i_level != 0:
523
+ up.upsample = Upsample(block_in, resamp_with_conv)
524
+ curr_res = curr_res * 2
525
+ self.up.insert(0, up) # prepend to get consistent order
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(block_in,
530
+ out_ch,
531
+ kernel_size=3,
532
+ stride=1,
533
+ padding=1)
534
+
535
+ def forward(self, z):
536
+ #assert z.shape[1:] == self.z_shape[1:]
537
+ self.last_z_shape = z.shape
538
+
539
+ # timestep embedding
540
+ temb = None
541
+
542
+ # z to block_in
543
+ h = self.conv_in(z)
544
+
545
+ # middle
546
+ h = self.mid.block_1(h, temb)
547
+ h = self.mid.attn_1(h)
548
+ h = self.mid.block_2(h, temb)
549
+
550
+ # upsampling
551
+ for i_level in reversed(range(self.num_resolutions)):
552
+ for i_block in range(self.num_res_blocks+1):
553
+ h = self.up[i_level].block[i_block](h, temb)
554
+ if len(self.up[i_level].attn) > 0:
555
+ h = self.up[i_level].attn[i_block](h)
556
+ if i_level != 0:
557
+ h = self.up[i_level].upsample(h)
558
+
559
+ # end
560
+ if self.give_pre_end:
561
+ return h
562
+
563
+ h = self.norm_out(h)
564
+ h = nonlinearity(h)
565
+ h = self.conv_out(h)
566
+ if self.tanh_out:
567
+ h = torch.tanh(h)
568
+ return h
569
+
570
+
571
+ class SimpleDecoder(nn.Module):
572
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
573
+ super().__init__()
574
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
575
+ ResnetBlock(in_channels=in_channels,
576
+ out_channels=2 * in_channels,
577
+ temb_channels=0, dropout=0.0),
578
+ ResnetBlock(in_channels=2 * in_channels,
579
+ out_channels=4 * in_channels,
580
+ temb_channels=0, dropout=0.0),
581
+ ResnetBlock(in_channels=4 * in_channels,
582
+ out_channels=2 * in_channels,
583
+ temb_channels=0, dropout=0.0),
584
+ nn.Conv2d(2*in_channels, in_channels, 1),
585
+ Upsample(in_channels, with_conv=True)])
586
+ # end
587
+ self.norm_out = Normalize(in_channels)
588
+ self.conv_out = torch.nn.Conv2d(in_channels,
589
+ out_channels,
590
+ kernel_size=3,
591
+ stride=1,
592
+ padding=1)
593
+
594
+ def forward(self, x):
595
+ for i, layer in enumerate(self.model):
596
+ if i in [1,2,3]:
597
+ x = layer(x, None)
598
+ else:
599
+ x = layer(x)
600
+
601
+ h = self.norm_out(x)
602
+ h = nonlinearity(h)
603
+ x = self.conv_out(h)
604
+ return x
605
+
606
+
607
+ class UpsampleDecoder(nn.Module):
608
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
609
+ ch_mult=(2,2), dropout=0.0):
610
+ super().__init__()
611
+ # upsampling
612
+ self.temb_ch = 0
613
+ self.num_resolutions = len(ch_mult)
614
+ self.num_res_blocks = num_res_blocks
615
+ block_in = in_channels
616
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
617
+ self.res_blocks = nn.ModuleList()
618
+ self.upsample_blocks = nn.ModuleList()
619
+ for i_level in range(self.num_resolutions):
620
+ res_block = []
621
+ block_out = ch * ch_mult[i_level]
622
+ for i_block in range(self.num_res_blocks + 1):
623
+ res_block.append(ResnetBlock(in_channels=block_in,
624
+ out_channels=block_out,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout))
627
+ block_in = block_out
628
+ self.res_blocks.append(nn.ModuleList(res_block))
629
+ if i_level != self.num_resolutions - 1:
630
+ self.upsample_blocks.append(Upsample(block_in, True))
631
+ curr_res = curr_res * 2
632
+
633
+ # end
634
+ self.norm_out = Normalize(block_in)
635
+ self.conv_out = torch.nn.Conv2d(block_in,
636
+ out_channels,
637
+ kernel_size=3,
638
+ stride=1,
639
+ padding=1)
640
+
641
+ def forward(self, x):
642
+ # upsampling
643
+ h = x
644
+ for k, i_level in enumerate(range(self.num_resolutions)):
645
+ for i_block in range(self.num_res_blocks + 1):
646
+ h = self.res_blocks[i_level][i_block](h, None)
647
+ if i_level != self.num_resolutions - 1:
648
+ h = self.upsample_blocks[k](h)
649
+ h = self.norm_out(h)
650
+ h = nonlinearity(h)
651
+ h = self.conv_out(h)
652
+ return h
653
+
654
+
655
+ class LatentRescaler(nn.Module):
656
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
657
+ super().__init__()
658
+ # residual block, interpolate, residual block
659
+ self.factor = factor
660
+ self.conv_in = nn.Conv2d(in_channels,
661
+ mid_channels,
662
+ kernel_size=3,
663
+ stride=1,
664
+ padding=1)
665
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
666
+ out_channels=mid_channels,
667
+ temb_channels=0,
668
+ dropout=0.0) for _ in range(depth)])
669
+ self.attn = AttnBlock(mid_channels)
670
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
671
+ out_channels=mid_channels,
672
+ temb_channels=0,
673
+ dropout=0.0) for _ in range(depth)])
674
+
675
+ self.conv_out = nn.Conv2d(mid_channels,
676
+ out_channels,
677
+ kernel_size=1,
678
+ )
679
+
680
+ def forward(self, x):
681
+ x = self.conv_in(x)
682
+ for block in self.res_block1:
683
+ x = block(x, None)
684
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
685
+ x = self.attn(x)
686
+ for block in self.res_block2:
687
+ x = block(x, None)
688
+ x = self.conv_out(x)
689
+ return x
690
+
691
+
692
+ class MergedRescaleEncoder(nn.Module):
693
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
694
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
695
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
696
+ super().__init__()
697
+ intermediate_chn = ch * ch_mult[-1]
698
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
699
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
700
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
701
+ out_ch=None)
702
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
703
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
704
+
705
+ def forward(self, x):
706
+ x = self.encoder(x)
707
+ x = self.rescaler(x)
708
+ return x
709
+
710
+
711
+ class MergedRescaleDecoder(nn.Module):
712
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
713
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
714
+ super().__init__()
715
+ tmp_chn = z_channels*ch_mult[-1]
716
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
717
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
718
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
719
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
720
+ out_channels=tmp_chn, depth=rescale_module_depth)
721
+
722
+ def forward(self, x):
723
+ x = self.rescaler(x)
724
+ x = self.decoder(x)
725
+ return x
726
+
727
+
728
+ class Upsampler(nn.Module):
729
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
730
+ super().__init__()
731
+ assert out_size >= in_size
732
+ num_blocks = int(np.log2(out_size//in_size))+1
733
+ factor_up = 1.+ (out_size % in_size)
734
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
735
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
736
+ out_channels=in_channels)
737
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
738
+ attn_resolutions=[], in_channels=None, ch=in_channels,
739
+ ch_mult=[ch_mult for _ in range(num_blocks)])
740
+
741
+ def forward(self, x):
742
+ x = self.rescaler(x)
743
+ x = self.decoder(x)
744
+ return x
745
+
746
+
747
+ class Resize(nn.Module):
748
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
749
+ super().__init__()
750
+ self.with_conv = learned
751
+ self.mode = mode
752
+ if self.with_conv:
753
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
754
+ raise NotImplementedError()
755
+ assert in_channels is not None
756
+ # no asymmetric padding in torch conv, must do it ourselves
757
+ self.conv = torch.nn.Conv2d(in_channels,
758
+ in_channels,
759
+ kernel_size=4,
760
+ stride=2,
761
+ padding=1)
762
+
763
+ def forward(self, x, scale_factor=1.0):
764
+ if scale_factor==1.0:
765
+ return x
766
+ else:
767
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
768
+ return x
769
+
770
+ class FirstStagePostProcessor(nn.Module):
771
+
772
+ def __init__(self, ch_mult:list, in_channels,
773
+ pretrained_model:nn.Module=None,
774
+ reshape=False,
775
+ n_channels=None,
776
+ dropout=0.,
777
+ pretrained_config=None):
778
+ super().__init__()
779
+ if pretrained_config is None:
780
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
781
+ self.pretrained_model = pretrained_model
782
+ else:
783
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
784
+ self.instantiate_pretrained(pretrained_config)
785
+
786
+ self.do_reshape = reshape
787
+
788
+ if n_channels is None:
789
+ n_channels = self.pretrained_model.encoder.ch
790
+
791
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
792
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
793
+ stride=1,padding=1)
794
+
795
+ blocks = []
796
+ downs = []
797
+ ch_in = n_channels
798
+ for m in ch_mult:
799
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
800
+ ch_in = m * n_channels
801
+ downs.append(Downsample(ch_in, with_conv=False))
802
+
803
+ self.model = nn.ModuleList(blocks)
804
+ self.downsampler = nn.ModuleList(downs)
805
+
806
+
807
+ def instantiate_pretrained(self, config):
808
+ model = instantiate_from_config(config)
809
+ self.pretrained_model = model.eval()
810
+ # self.pretrained_model.train = False
811
+ for param in self.pretrained_model.parameters():
812
+ param.requires_grad = False
813
+
814
+
815
+ @torch.no_grad()
816
+ def encode_with_pretrained(self,x):
817
+ c = self.pretrained_model.encode(x)
818
+ if isinstance(c, DiagonalGaussianDistribution):
819
+ c = c.mode()
820
+ return c
821
+
822
+ def forward(self,x):
823
+ z_fs = self.encode_with_pretrained(x)
824
+ z = self.proj_norm(z_fs)
825
+ z = self.proj(z)
826
+ z = nonlinearity(z)
827
+
828
+ for submodel, downmodel in zip(self.model,self.downsampler):
829
+ z = submodel(z,temb=None)
830
+ z = downmodel(z)
831
+
832
+ if self.do_reshape:
833
+ z = rearrange(z,'b c h w -> b (h w) c')
834
+ return z
835
+
latent-diffusion/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+ import math
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from ldm.modules.diffusionmodules.util import (
12
+ checkpoint,
13
+ conv_nd,
14
+ linear,
15
+ avg_pool_nd,
16
+ zero_module,
17
+ normalization,
18
+ timestep_embedding,
19
+ )
20
+ from ldm.modules.attention import SpatialTransformer
21
+
22
+
23
+ # dummy replace
24
+ def convert_module_to_f16(x):
25
+ pass
26
+
27
+ def convert_module_to_f32(x):
28
+ pass
29
+
30
+
31
+ ## go
32
+ class AttentionPool2d(nn.Module):
33
+ """
34
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ spacial_dim: int,
40
+ embed_dim: int,
41
+ num_heads_channels: int,
42
+ output_dim: int = None,
43
+ ):
44
+ super().__init__()
45
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
46
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
47
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
48
+ self.num_heads = embed_dim // num_heads_channels
49
+ self.attention = QKVAttention(self.num_heads)
50
+
51
+ def forward(self, x):
52
+ b, c, *_spatial = x.shape
53
+ x = x.reshape(b, c, -1) # NC(HW)
54
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
55
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
56
+ x = self.qkv_proj(x)
57
+ x = self.attention(x)
58
+ x = self.c_proj(x)
59
+ return x[:, :, 0]
60
+
61
+
62
+ class TimestepBlock(nn.Module):
63
+ """
64
+ Any module where forward() takes timestep embeddings as a second argument.
65
+ """
66
+
67
+ @abstractmethod
68
+ def forward(self, x, emb):
69
+ """
70
+ Apply the module to `x` given `emb` timestep embeddings.
71
+ """
72
+
73
+
74
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
75
+ """
76
+ A sequential module that passes timestep embeddings to the children that
77
+ support it as an extra input.
78
+ """
79
+
80
+ def forward(self, x, emb, context=None):
81
+ for layer in self:
82
+ if isinstance(layer, TimestepBlock):
83
+ x = layer(x, emb)
84
+ elif isinstance(layer, SpatialTransformer):
85
+ x = layer(x, context)
86
+ else:
87
+ x = layer(x)
88
+ return x
89
+
90
+
91
+ class Upsample(nn.Module):
92
+ """
93
+ An upsampling layer with an optional convolution.
94
+ :param channels: channels in the inputs and outputs.
95
+ :param use_conv: a bool determining if a convolution is applied.
96
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
97
+ upsampling occurs in the inner-two dimensions.
98
+ """
99
+
100
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
101
+ super().__init__()
102
+ self.channels = channels
103
+ self.out_channels = out_channels or channels
104
+ self.use_conv = use_conv
105
+ self.dims = dims
106
+ if use_conv:
107
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
108
+
109
+ def forward(self, x):
110
+ assert x.shape[1] == self.channels
111
+ if self.dims == 3:
112
+ x = F.interpolate(
113
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
114
+ )
115
+ else:
116
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
117
+ if self.use_conv:
118
+ x = self.conv(x)
119
+ return x
120
+
121
+ class TransposedUpsample(nn.Module):
122
+ 'Learned 2x upsampling without padding'
123
+ def __init__(self, channels, out_channels=None, ks=5):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+
128
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
129
+
130
+ def forward(self,x):
131
+ return self.up(x)
132
+
133
+
134
+ class Downsample(nn.Module):
135
+ """
136
+ A downsampling layer with an optional convolution.
137
+ :param channels: channels in the inputs and outputs.
138
+ :param use_conv: a bool determining if a convolution is applied.
139
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
140
+ downsampling occurs in the inner-two dimensions.
141
+ """
142
+
143
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.dims = dims
149
+ stride = 2 if dims != 3 else (1, 2, 2)
150
+ if use_conv:
151
+ self.op = conv_nd(
152
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
153
+ )
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
157
+
158
+ def forward(self, x):
159
+ assert x.shape[1] == self.channels
160
+ return self.op(x)
161
+
162
+
163
+ class ResBlock(TimestepBlock):
164
+ """
165
+ A residual block that can optionally change the number of channels.
166
+ :param channels: the number of input channels.
167
+ :param emb_channels: the number of timestep embedding channels.
168
+ :param dropout: the rate of dropout.
169
+ :param out_channels: if specified, the number of out channels.
170
+ :param use_conv: if True and out_channels is specified, use a spatial
171
+ convolution instead of a smaller 1x1 convolution to change the
172
+ channels in the skip connection.
173
+ :param dims: determines if the signal is 1D, 2D, or 3D.
174
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
175
+ :param up: if True, use this block for upsampling.
176
+ :param down: if True, use this block for downsampling.
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ emb_channels,
183
+ dropout,
184
+ out_channels=None,
185
+ use_conv=False,
186
+ use_scale_shift_norm=False,
187
+ dims=2,
188
+ use_checkpoint=False,
189
+ up=False,
190
+ down=False,
191
+ ):
192
+ super().__init__()
193
+ self.channels = channels
194
+ self.emb_channels = emb_channels
195
+ self.dropout = dropout
196
+ self.out_channels = out_channels or channels
197
+ self.use_conv = use_conv
198
+ self.use_checkpoint = use_checkpoint
199
+ self.use_scale_shift_norm = use_scale_shift_norm
200
+
201
+ self.in_layers = nn.Sequential(
202
+ normalization(channels),
203
+ nn.SiLU(),
204
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
205
+ )
206
+
207
+ self.updown = up or down
208
+
209
+ if up:
210
+ self.h_upd = Upsample(channels, False, dims)
211
+ self.x_upd = Upsample(channels, False, dims)
212
+ elif down:
213
+ self.h_upd = Downsample(channels, False, dims)
214
+ self.x_upd = Downsample(channels, False, dims)
215
+ else:
216
+ self.h_upd = self.x_upd = nn.Identity()
217
+
218
+ self.emb_layers = nn.Sequential(
219
+ nn.SiLU(),
220
+ linear(
221
+ emb_channels,
222
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
223
+ ),
224
+ )
225
+ self.out_layers = nn.Sequential(
226
+ normalization(self.out_channels),
227
+ nn.SiLU(),
228
+ nn.Dropout(p=dropout),
229
+ zero_module(
230
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
231
+ ),
232
+ )
233
+
234
+ if self.out_channels == channels:
235
+ self.skip_connection = nn.Identity()
236
+ elif use_conv:
237
+ self.skip_connection = conv_nd(
238
+ dims, channels, self.out_channels, 3, padding=1
239
+ )
240
+ else:
241
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
242
+
243
+ def forward(self, x, emb):
244
+ """
245
+ Apply the block to a Tensor, conditioned on a timestep embedding.
246
+ :param x: an [N x C x ...] Tensor of features.
247
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
248
+ :return: an [N x C x ...] Tensor of outputs.
249
+ """
250
+ return checkpoint(
251
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
252
+ )
253
+
254
+
255
+ def _forward(self, x, emb):
256
+ if self.updown:
257
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
258
+ h = in_rest(x)
259
+ h = self.h_upd(h)
260
+ x = self.x_upd(x)
261
+ h = in_conv(h)
262
+ else:
263
+ h = self.in_layers(x)
264
+ emb_out = self.emb_layers(emb).type(h.dtype)
265
+ while len(emb_out.shape) < len(h.shape):
266
+ emb_out = emb_out[..., None]
267
+ if self.use_scale_shift_norm:
268
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
269
+ scale, shift = th.chunk(emb_out, 2, dim=1)
270
+ h = out_norm(h) * (1 + scale) + shift
271
+ h = out_rest(h)
272
+ else:
273
+ h = h + emb_out
274
+ h = self.out_layers(h)
275
+ return self.skip_connection(x) + h
276
+
277
+
278
+ class AttentionBlock(nn.Module):
279
+ """
280
+ An attention block that allows spatial positions to attend to each other.
281
+ Originally ported from here, but adapted to the N-d case.
282
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
283
+ """
284
+
285
+ def __init__(
286
+ self,
287
+ channels,
288
+ num_heads=1,
289
+ num_head_channels=-1,
290
+ use_checkpoint=False,
291
+ use_new_attention_order=False,
292
+ ):
293
+ super().__init__()
294
+ self.channels = channels
295
+ if num_head_channels == -1:
296
+ self.num_heads = num_heads
297
+ else:
298
+ assert (
299
+ channels % num_head_channels == 0
300
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
301
+ self.num_heads = channels // num_head_channels
302
+ self.use_checkpoint = use_checkpoint
303
+ self.norm = normalization(channels)
304
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
305
+ if use_new_attention_order:
306
+ # split qkv before split heads
307
+ self.attention = QKVAttention(self.num_heads)
308
+ else:
309
+ # split heads before split qkv
310
+ self.attention = QKVAttentionLegacy(self.num_heads)
311
+
312
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
313
+
314
+ def forward(self, x):
315
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
316
+ #return pt_checkpoint(self._forward, x) # pytorch
317
+
318
+ def _forward(self, x):
319
+ b, c, *spatial = x.shape
320
+ x = x.reshape(b, c, -1)
321
+ qkv = self.qkv(self.norm(x))
322
+ h = self.attention(qkv)
323
+ h = self.proj_out(h)
324
+ return (x + h).reshape(b, c, *spatial)
325
+
326
+
327
+ def count_flops_attn(model, _x, y):
328
+ """
329
+ A counter for the `thop` package to count the operations in an
330
+ attention operation.
331
+ Meant to be used like:
332
+ macs, params = thop.profile(
333
+ model,
334
+ inputs=(inputs, timestamps),
335
+ custom_ops={QKVAttention: QKVAttention.count_flops},
336
+ )
337
+ """
338
+ b, c, *spatial = y[0].shape
339
+ num_spatial = int(np.prod(spatial))
340
+ # We perform two matmuls with the same number of ops.
341
+ # The first computes the weight matrix, the second computes
342
+ # the combination of the value vectors.
343
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
344
+ model.total_ops += th.DoubleTensor([matmul_ops])
345
+
346
+
347
+ class QKVAttentionLegacy(nn.Module):
348
+ """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size,
446
+ in_channels,
447
+ model_channels,
448
+ out_channels,
449
+ num_res_blocks,
450
+ attention_resolutions,
451
+ dropout=0,
452
+ channel_mult=(1, 2, 4, 8),
453
+ conv_resample=True,
454
+ dims=2,
455
+ num_classes=None,
456
+ use_checkpoint=False,
457
+ use_fp16=False,
458
+ num_heads=-1,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
+ ):
470
+ super().__init__()
471
+ if use_spatial_transformer:
472
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
+
474
+ if context_dim is not None:
475
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
+ from omegaconf.listconfig import ListConfig
477
+ if type(context_dim) == ListConfig:
478
+ context_dim = list(context_dim)
479
+
480
+ if num_heads_upsample == -1:
481
+ num_heads_upsample = num_heads
482
+
483
+ if num_heads == -1:
484
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
485
+
486
+ if num_head_channels == -1:
487
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ self.image_size = image_size
490
+ self.in_channels = in_channels
491
+ self.model_channels = model_channels
492
+ self.out_channels = out_channels
493
+ self.num_res_blocks = num_res_blocks
494
+ self.attention_resolutions = attention_resolutions
495
+ self.dropout = dropout
496
+ self.channel_mult = channel_mult
497
+ self.conv_resample = conv_resample
498
+ self.num_classes = num_classes
499
+ self.use_checkpoint = use_checkpoint
500
+ self.dtype = th.float16 if use_fp16 else th.float32
501
+ self.num_heads = num_heads
502
+ self.num_head_channels = num_head_channels
503
+ self.num_heads_upsample = num_heads_upsample
504
+ self.predict_codebook_ids = n_embed is not None
505
+
506
+ time_embed_dim = model_channels * 4
507
+ self.time_embed = nn.Sequential(
508
+ linear(model_channels, time_embed_dim),
509
+ nn.SiLU(),
510
+ linear(time_embed_dim, time_embed_dim),
511
+ )
512
+
513
+ if self.num_classes is not None:
514
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
515
+
516
+ self.input_blocks = nn.ModuleList(
517
+ [
518
+ TimestepEmbedSequential(
519
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
520
+ )
521
+ ]
522
+ )
523
+ self._feature_size = model_channels
524
+ input_block_chans = [model_channels]
525
+ ch = model_channels
526
+ ds = 1
527
+ for level, mult in enumerate(channel_mult):
528
+ for _ in range(num_res_blocks):
529
+ layers = [
530
+ ResBlock(
531
+ ch,
532
+ time_embed_dim,
533
+ dropout,
534
+ out_channels=mult * model_channels,
535
+ dims=dims,
536
+ use_checkpoint=use_checkpoint,
537
+ use_scale_shift_norm=use_scale_shift_norm,
538
+ )
539
+ ]
540
+ ch = mult * model_channels
541
+ if ds in attention_resolutions:
542
+ if num_head_channels == -1:
543
+ dim_head = ch // num_heads
544
+ else:
545
+ num_heads = ch // num_head_channels
546
+ dim_head = num_head_channels
547
+ if legacy:
548
+ num_heads = 1
549
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
550
+ layers.append(
551
+ AttentionBlock(
552
+ ch,
553
+ use_checkpoint=use_checkpoint,
554
+ num_heads=num_heads,
555
+ num_head_channels=dim_head,
556
+ use_new_attention_order=use_new_attention_order,
557
+ ) if not use_spatial_transformer else SpatialTransformer(
558
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
559
+ )
560
+ )
561
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
562
+ self._feature_size += ch
563
+ input_block_chans.append(ch)
564
+ if level != len(channel_mult) - 1:
565
+ out_ch = ch
566
+ self.input_blocks.append(
567
+ TimestepEmbedSequential(
568
+ ResBlock(
569
+ ch,
570
+ time_embed_dim,
571
+ dropout,
572
+ out_channels=out_ch,
573
+ dims=dims,
574
+ use_checkpoint=use_checkpoint,
575
+ use_scale_shift_norm=use_scale_shift_norm,
576
+ down=True,
577
+ )
578
+ if resblock_updown
579
+ else Downsample(
580
+ ch, conv_resample, dims=dims, out_channels=out_ch
581
+ )
582
+ )
583
+ )
584
+ ch = out_ch
585
+ input_block_chans.append(ch)
586
+ ds *= 2
587
+ self._feature_size += ch
588
+
589
+ if num_head_channels == -1:
590
+ dim_head = ch // num_heads
591
+ else:
592
+ num_heads = ch // num_head_channels
593
+ dim_head = num_head_channels
594
+ if legacy:
595
+ num_heads = 1
596
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
597
+ self.middle_block = TimestepEmbedSequential(
598
+ ResBlock(
599
+ ch,
600
+ time_embed_dim,
601
+ dropout,
602
+ dims=dims,
603
+ use_checkpoint=use_checkpoint,
604
+ use_scale_shift_norm=use_scale_shift_norm,
605
+ ),
606
+ AttentionBlock(
607
+ ch,
608
+ use_checkpoint=use_checkpoint,
609
+ num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
+ use_new_attention_order=use_new_attention_order,
612
+ ) if not use_spatial_transformer else SpatialTransformer(
613
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
614
+ ),
615
+ ResBlock(
616
+ ch,
617
+ time_embed_dim,
618
+ dropout,
619
+ dims=dims,
620
+ use_checkpoint=use_checkpoint,
621
+ use_scale_shift_norm=use_scale_shift_norm,
622
+ ),
623
+ )
624
+ self._feature_size += ch
625
+
626
+ self.output_blocks = nn.ModuleList([])
627
+ for level, mult in list(enumerate(channel_mult))[::-1]:
628
+ for i in range(num_res_blocks + 1):
629
+ ich = input_block_chans.pop()
630
+ layers = [
631
+ ResBlock(
632
+ ch + ich,
633
+ time_embed_dim,
634
+ dropout,
635
+ out_channels=model_channels * mult,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ )
640
+ ]
641
+ ch = model_channels * mult
642
+ if ds in attention_resolutions:
643
+ if num_head_channels == -1:
644
+ dim_head = ch // num_heads
645
+ else:
646
+ num_heads = ch // num_head_channels
647
+ dim_head = num_head_channels
648
+ if legacy:
649
+ num_heads = 1
650
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
651
+ layers.append(
652
+ AttentionBlock(
653
+ ch,
654
+ use_checkpoint=use_checkpoint,
655
+ num_heads=num_heads_upsample,
656
+ num_head_channels=dim_head,
657
+ use_new_attention_order=use_new_attention_order,
658
+ ) if not use_spatial_transformer else SpatialTransformer(
659
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
660
+ )
661
+ )
662
+ if level and i == num_res_blocks:
663
+ out_ch = ch
664
+ layers.append(
665
+ ResBlock(
666
+ ch,
667
+ time_embed_dim,
668
+ dropout,
669
+ out_channels=out_ch,
670
+ dims=dims,
671
+ use_checkpoint=use_checkpoint,
672
+ use_scale_shift_norm=use_scale_shift_norm,
673
+ up=True,
674
+ )
675
+ if resblock_updown
676
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
677
+ )
678
+ ds //= 2
679
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
680
+ self._feature_size += ch
681
+
682
+ self.out = nn.Sequential(
683
+ normalization(ch),
684
+ nn.SiLU(),
685
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
686
+ )
687
+ if self.predict_codebook_ids:
688
+ self.id_predictor = nn.Sequential(
689
+ normalization(ch),
690
+ conv_nd(dims, model_channels, n_embed, 1),
691
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
692
+ )
693
+
694
+ def convert_to_fp16(self):
695
+ """
696
+ Convert the torso of the model to float16.
697
+ """
698
+ self.input_blocks.apply(convert_module_to_f16)
699
+ self.middle_block.apply(convert_module_to_f16)
700
+ self.output_blocks.apply(convert_module_to_f16)
701
+
702
+ def convert_to_fp32(self):
703
+ """
704
+ Convert the torso of the model to float32.
705
+ """
706
+ self.input_blocks.apply(convert_module_to_f32)
707
+ self.middle_block.apply(convert_module_to_f32)
708
+ self.output_blocks.apply(convert_module_to_f32)
709
+
710
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
711
+ """
712
+ Apply the model to an input batch.
713
+ :param x: an [N x C x ...] Tensor of inputs.
714
+ :param timesteps: a 1-D batch of timesteps.
715
+ :param context: conditioning plugged in via crossattn
716
+ :param y: an [N] Tensor of labels, if class-conditional.
717
+ :return: an [N x C x ...] Tensor of outputs.
718
+ """
719
+ assert (y is not None) == (
720
+ self.num_classes is not None
721
+ ), "must specify y if and only if the model is class-conditional"
722
+ hs = []
723
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
724
+ emb = self.time_embed(t_emb)
725
+
726
+ if self.num_classes is not None:
727
+ assert y.shape == (x.shape[0],)
728
+ emb = emb + self.label_emb(y)
729
+
730
+ h = x.type(self.dtype)
731
+ for module in self.input_blocks:
732
+ h = module(h, emb, context)
733
+ hs.append(h)
734
+ h = self.middle_block(h, emb, context)
735
+ for module in self.output_blocks:
736
+ h = th.cat([h, hs.pop()], dim=1)
737
+ h = module(h, emb, context)
738
+ h = h.type(x.dtype)
739
+ if self.predict_codebook_ids:
740
+ return self.id_predictor(h)
741
+ else:
742
+ return self.out(h)
743
+
744
+
745
+ class EncoderUNetModel(nn.Module):
746
+ """
747
+ The half UNet model with attention and timestep embedding.
748
+ For usage, see UNet.
749
+ """
750
+
751
+ def __init__(
752
+ self,
753
+ image_size,
754
+ in_channels,
755
+ model_channels,
756
+ out_channels,
757
+ num_res_blocks,
758
+ attention_resolutions,
759
+ dropout=0,
760
+ channel_mult=(1, 2, 4, 8),
761
+ conv_resample=True,
762
+ dims=2,
763
+ use_checkpoint=False,
764
+ use_fp16=False,
765
+ num_heads=1,
766
+ num_head_channels=-1,
767
+ num_heads_upsample=-1,
768
+ use_scale_shift_norm=False,
769
+ resblock_updown=False,
770
+ use_new_attention_order=False,
771
+ pool="adaptive",
772
+ *args,
773
+ **kwargs
774
+ ):
775
+ super().__init__()
776
+
777
+ if num_heads_upsample == -1:
778
+ num_heads_upsample = num_heads
779
+
780
+ self.in_channels = in_channels
781
+ self.model_channels = model_channels
782
+ self.out_channels = out_channels
783
+ self.num_res_blocks = num_res_blocks
784
+ self.attention_resolutions = attention_resolutions
785
+ self.dropout = dropout
786
+ self.channel_mult = channel_mult
787
+ self.conv_resample = conv_resample
788
+ self.use_checkpoint = use_checkpoint
789
+ self.dtype = th.float16 if use_fp16 else th.float32
790
+ self.num_heads = num_heads
791
+ self.num_head_channels = num_head_channels
792
+ self.num_heads_upsample = num_heads_upsample
793
+
794
+ time_embed_dim = model_channels * 4
795
+ self.time_embed = nn.Sequential(
796
+ linear(model_channels, time_embed_dim),
797
+ nn.SiLU(),
798
+ linear(time_embed_dim, time_embed_dim),
799
+ )
800
+
801
+ self.input_blocks = nn.ModuleList(
802
+ [
803
+ TimestepEmbedSequential(
804
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
805
+ )
806
+ ]
807
+ )
808
+ self._feature_size = model_channels
809
+ input_block_chans = [model_channels]
810
+ ch = model_channels
811
+ ds = 1
812
+ for level, mult in enumerate(channel_mult):
813
+ for _ in range(num_res_blocks):
814
+ layers = [
815
+ ResBlock(
816
+ ch,
817
+ time_embed_dim,
818
+ dropout,
819
+ out_channels=mult * model_channels,
820
+ dims=dims,
821
+ use_checkpoint=use_checkpoint,
822
+ use_scale_shift_norm=use_scale_shift_norm,
823
+ )
824
+ ]
825
+ ch = mult * model_channels
826
+ if ds in attention_resolutions:
827
+ layers.append(
828
+ AttentionBlock(
829
+ ch,
830
+ use_checkpoint=use_checkpoint,
831
+ num_heads=num_heads,
832
+ num_head_channels=num_head_channels,
833
+ use_new_attention_order=use_new_attention_order,
834
+ )
835
+ )
836
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
837
+ self._feature_size += ch
838
+ input_block_chans.append(ch)
839
+ if level != len(channel_mult) - 1:
840
+ out_ch = ch
841
+ self.input_blocks.append(
842
+ TimestepEmbedSequential(
843
+ ResBlock(
844
+ ch,
845
+ time_embed_dim,
846
+ dropout,
847
+ out_channels=out_ch,
848
+ dims=dims,
849
+ use_checkpoint=use_checkpoint,
850
+ use_scale_shift_norm=use_scale_shift_norm,
851
+ down=True,
852
+ )
853
+ if resblock_updown
854
+ else Downsample(
855
+ ch, conv_resample, dims=dims, out_channels=out_ch
856
+ )
857
+ )
858
+ )
859
+ ch = out_ch
860
+ input_block_chans.append(ch)
861
+ ds *= 2
862
+ self._feature_size += ch
863
+
864
+ self.middle_block = TimestepEmbedSequential(
865
+ ResBlock(
866
+ ch,
867
+ time_embed_dim,
868
+ dropout,
869
+ dims=dims,
870
+ use_checkpoint=use_checkpoint,
871
+ use_scale_shift_norm=use_scale_shift_norm,
872
+ ),
873
+ AttentionBlock(
874
+ ch,
875
+ use_checkpoint=use_checkpoint,
876
+ num_heads=num_heads,
877
+ num_head_channels=num_head_channels,
878
+ use_new_attention_order=use_new_attention_order,
879
+ ),
880
+ ResBlock(
881
+ ch,
882
+ time_embed_dim,
883
+ dropout,
884
+ dims=dims,
885
+ use_checkpoint=use_checkpoint,
886
+ use_scale_shift_norm=use_scale_shift_norm,
887
+ ),
888
+ )
889
+ self._feature_size += ch
890
+ self.pool = pool
891
+ if pool == "adaptive":
892
+ self.out = nn.Sequential(
893
+ normalization(ch),
894
+ nn.SiLU(),
895
+ nn.AdaptiveAvgPool2d((1, 1)),
896
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
897
+ nn.Flatten(),
898
+ )
899
+ elif pool == "attention":
900
+ assert num_head_channels != -1
901
+ self.out = nn.Sequential(
902
+ normalization(ch),
903
+ nn.SiLU(),
904
+ AttentionPool2d(
905
+ (image_size // ds), ch, num_head_channels, out_channels
906
+ ),
907
+ )
908
+ elif pool == "spatial":
909
+ self.out = nn.Sequential(
910
+ nn.Linear(self._feature_size, 2048),
911
+ nn.ReLU(),
912
+ nn.Linear(2048, self.out_channels),
913
+ )
914
+ elif pool == "spatial_v2":
915
+ self.out = nn.Sequential(
916
+ nn.Linear(self._feature_size, 2048),
917
+ normalization(2048),
918
+ nn.SiLU(),
919
+ nn.Linear(2048, self.out_channels),
920
+ )
921
+ else:
922
+ raise NotImplementedError(f"Unexpected {pool} pooling")
923
+
924
+ def convert_to_fp16(self):
925
+ """
926
+ Convert the torso of the model to float16.
927
+ """
928
+ self.input_blocks.apply(convert_module_to_f16)
929
+ self.middle_block.apply(convert_module_to_f16)
930
+
931
+ def convert_to_fp32(self):
932
+ """
933
+ Convert the torso of the model to float32.
934
+ """
935
+ self.input_blocks.apply(convert_module_to_f32)
936
+ self.middle_block.apply(convert_module_to_f32)
937
+
938
+ def forward(self, x, timesteps):
939
+ """
940
+ Apply the model to an input batch.
941
+ :param x: an [N x C x ...] Tensor of inputs.
942
+ :param timesteps: a 1-D batch of timesteps.
943
+ :return: an [N x K] Tensor of outputs.
944
+ """
945
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
946
+
947
+ results = []
948
+ h = x.type(self.dtype)
949
+ for module in self.input_blocks:
950
+ h = module(h, emb)
951
+ if self.pool.startswith("spatial"):
952
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
953
+ h = self.middle_block(h, emb)
954
+ if self.pool.startswith("spatial"):
955
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
956
+ h = th.cat(results, axis=-1)
957
+ return self.out(h)
958
+ else:
959
+ h = h.type(x.dtype)
960
+ return self.out(h)
961
+
latent-diffusion/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+
126
+ with torch.no_grad():
127
+ output_tensors = ctx.run_function(*ctx.input_tensors)
128
+ return output_tensors
129
+
130
+ @staticmethod
131
+ def backward(ctx, *output_grads):
132
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133
+ with torch.enable_grad():
134
+ # Fixes a bug where the first op in run_function modifies the
135
+ # Tensor storage in place, which is not allowed for detach()'d
136
+ # Tensors.
137
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138
+ output_tensors = ctx.run_function(*shallow_copies)
139
+ input_grads = torch.autograd.grad(
140
+ output_tensors,
141
+ ctx.input_tensors + ctx.input_params,
142
+ output_grads,
143
+ allow_unused=True,
144
+ )
145
+ del ctx.input_tensors
146
+ del ctx.input_params
147
+ del output_tensors
148
+ return (None, None) + input_grads
149
+
150
+
151
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152
+ """
153
+ Create sinusoidal timestep embeddings.
154
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
155
+ These may be fractional.
156
+ :param dim: the dimension of the output.
157
+ :param max_period: controls the minimum frequency of the embeddings.
158
+ :return: an [N x dim] Tensor of positional embeddings.
159
+ """
160
+ if not repeat_only:
161
+ half = dim // 2
162
+ freqs = torch.exp(
163
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164
+ ).to(device=timesteps.device)
165
+ args = timesteps[:, None].float() * freqs[None]
166
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167
+ if dim % 2:
168
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169
+ else:
170
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
171
+ return embedding
172
+
173
+
174
+ def zero_module(module):
175
+ """
176
+ Zero out the parameters of a module and return it.
177
+ """
178
+ for p in module.parameters():
179
+ p.detach().zero_()
180
+ return module
181
+
182
+
183
+ def scale_module(module, scale):
184
+ """
185
+ Scale the parameters of a module and return it.
186
+ """
187
+ for p in module.parameters():
188
+ p.detach().mul_(scale)
189
+ return module
190
+
191
+
192
+ def mean_flat(tensor):
193
+ """
194
+ Take the mean over all non-batch dimensions.
195
+ """
196
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
197
+
198
+
199
+ def normalization(channels):
200
+ """
201
+ Make a standard normalization layer.
202
+ :param channels: number of input channels.
203
+ :return: an nn.Module for normalization.
204
+ """
205
+ return GroupNorm32(32, channels)
206
+
207
+
208
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209
+ class SiLU(nn.Module):
210
+ def forward(self, x):
211
+ return x * torch.sigmoid(x)
212
+
213
+
214
+ class GroupNorm32(nn.GroupNorm):
215
+ def forward(self, x):
216
+ return super().forward(x.float()).type(x.dtype)
217
+
218
+ def conv_nd(dims, *args, **kwargs):
219
+ """
220
+ Create a 1D, 2D, or 3D convolution module.
221
+ """
222
+ if dims == 1:
223
+ return nn.Conv1d(*args, **kwargs)
224
+ elif dims == 2:
225
+ return nn.Conv2d(*args, **kwargs)
226
+ elif dims == 3:
227
+ return nn.Conv3d(*args, **kwargs)
228
+ raise ValueError(f"unsupported dimensions: {dims}")
229
+
230
+
231
+ def linear(*args, **kwargs):
232
+ """
233
+ Create a linear module.
234
+ """
235
+ return nn.Linear(*args, **kwargs)
236
+
237
+
238
+ def avg_pool_nd(dims, *args, **kwargs):
239
+ """
240
+ Create a 1D, 2D, or 3D average pooling module.
241
+ """
242
+ if dims == 1:
243
+ return nn.AvgPool1d(*args, **kwargs)
244
+ elif dims == 2:
245
+ return nn.AvgPool2d(*args, **kwargs)
246
+ elif dims == 3:
247
+ return nn.AvgPool3d(*args, **kwargs)
248
+ raise ValueError(f"unsupported dimensions: {dims}")
249
+
250
+
251
+ class HybridConditioner(nn.Module):
252
+
253
+ def __init__(self, c_concat_config, c_crossattn_config):
254
+ super().__init__()
255
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
256
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257
+
258
+ def forward(self, c_concat, c_crossattn):
259
+ c_concat = self.concat_conditioner(c_concat)
260
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
261
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262
+
263
+
264
+ def noise_like(shape, device, repeat=False):
265
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266
+ noise = lambda: torch.randn(shape, device=device)
267
+ return repeat_noise() if repeat else noise()
latent-diffusion/ldm/modules/distributions/__init__.py ADDED
File without changes
latent-diffusion/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (198 Bytes). View file
latent-diffusion/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc ADDED
Binary file (3.81 kB). View file
latent-diffusion/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
latent-diffusion/ldm/modules/ema.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1,dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ #remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.','')
20
+ self.m_name2s_name.update({name:s_name})
21
+ self.register_buffer(s_name,p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self,model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)