multimodalart (HF staff) committed on
Commit 7e93a0e • 1 Parent(s): 3ddcbca

Upload 81 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +292 -12
  2. configs/.DS_Store +0 -0
  3. configs/inference/sd_2_1.yaml +60 -0
  4. configs/inference/sd_2_1_768.yaml +60 -0
  5. configs/inference/sd_xl_base.yaml +93 -0
  6. configs/inference/sd_xl_refiner.yaml +86 -0
  7. configs/inference/svd.yaml +131 -0
  8. configs/inference/svd_image_decoder.yaml +114 -0
  9. requirements/pt2.txt +39 -0
  10. scripts/.DS_Store +0 -0
  11. scripts/__init__.py +0 -0
  12. scripts/demo/__init__.py +0 -0
  13. scripts/demo/detect.py +156 -0
  14. scripts/demo/discretization.py +59 -0
  15. scripts/demo/sampling.py +364 -0
  16. scripts/demo/streamlit_helpers.py +928 -0
  17. scripts/demo/video_sampling.py +200 -0
  18. scripts/sampling/configs/svd.yaml +146 -0
  19. scripts/sampling/configs/svd_image_decoder.yaml +129 -0
  20. scripts/sampling/configs/svd_xt.yaml +146 -0
  21. scripts/sampling/configs/svd_xt_image_decoder.yaml +129 -0
  22. scripts/sampling/simple_video_sample.py +278 -0
  23. scripts/tests/attention.py +319 -0
  24. scripts/util/__init__.py +0 -0
  25. scripts/util/detection/__init__.py +0 -0
  26. scripts/util/detection/nsfw_and_watermark_dectection.py +110 -0
  27. scripts/util/detection/p_head_v1.npz +3 -0
  28. scripts/util/detection/w_head_v1.npz +3 -0
  29. sgm/__init__.py +4 -0
  30. sgm/data/__init__.py +1 -0
  31. sgm/data/cifar10.py +67 -0
  32. sgm/data/dataset.py +80 -0
  33. sgm/data/mnist.py +85 -0
  34. sgm/inference/api.py +386 -0
  35. sgm/inference/helpers.py +305 -0
  36. sgm/lr_scheduler.py +135 -0
  37. sgm/models/__init__.py +2 -0
  38. sgm/models/autoencoder.py +619 -0
  39. sgm/models/diffusion.py +346 -0
  40. sgm/modules/__init__.py +6 -0
  41. sgm/modules/attention.py +759 -0
  42. sgm/modules/autoencoding/__init__.py +0 -0
  43. sgm/modules/autoencoding/losses/__init__.py +7 -0
  44. sgm/modules/autoencoding/losses/discriminator_loss.py +306 -0
  45. sgm/modules/autoencoding/losses/lpips.py +73 -0
  46. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  47. sgm/modules/autoencoding/lpips/loss/.gitignore +1 -0
  48. sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
  49. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  50. sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
README.md CHANGED
@@ -1,12 +1,292 @@
1
- ---
2
- title: Stable Video Diffusion
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 4.5.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Generative Models by Stability AI
2
+
3
+ ![sample1](assets/000.jpg)
4
+
5
+ ## News
6
+
7
+ **November 21, 2023**
8
+
9
+ - We are releasing Stable Video Diffusion, an image-to-video model, for research purposes:
10
+ - [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid): This model was trained to generate 14
11
+ frames at resolution 576x1024 given a context frame of the same size.
12
+ We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
13
+ - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned
14
+ for 25-frame generation.
15
+ - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
16
+ - Alongside the model, we will release a technical report shortly. Stay tuned.
17
+
18
+ ![tile](assets/tile.gif)
19
+
20
+ **July 26, 2023**
21
+
22
+ - We are releasing two new open models with a
23
+ permissive [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0) (see [Inference](#inference) for file
24
+ hashes):
25
+ - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0): An improved version
26
+ over `SDXL-base-0.9`.
27
+ - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0): An improved version
28
+ over `SDXL-refiner-0.9`.
29
+
30
+ ![sample2](assets/001_with_eval.png)
31
+
32
+ **July 4, 2023**
33
+
34
+ - A technical report on SDXL is now available [here](https://arxiv.org/abs/2307.01952).
35
+
36
+ **June 22, 2023**
37
+
38
+ - We are releasing two new diffusion models for research purposes:
39
+ - `SDXL-base-0.9`: The base model was trained on a variety of aspect ratios on images with resolution 1024^2. The
40
+ base model uses [OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip)
41
+ and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main) for text encoding whereas the refiner model only uses
42
+ the OpenCLIP model.
43
+ - `SDXL-refiner-0.9`: The refiner has been trained to denoise small noise levels of high quality data and as such is
44
+ not expected to work as a text-to-image model; instead, it should only be used as an image-to-image model.
45
+
46
+ If you would like to access these models for your research, please apply using one of the following links:
47
+ [SDXL-0.9-Base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
48
+ and [SDXL-0.9-Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
49
+ This means that you can apply via either of the two links, and if you are granted access, you can access both.
50
+ Please log in to your Hugging Face Account with your organization email to request access.
51
+ **We plan to do a full release soon (July).**
52
+
53
+ ## The codebase
54
+
55
+ ### General Philosophy
56
+
57
+ Modularity is king. This repo implements a config-driven approach where we build and combine submodules by
58
+ calling `instantiate_from_config()` on objects defined in yaml configs. See `configs/` for many examples.
59
+
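+ For illustration, a minimal sketch of this pattern (mirroring what `scripts/demo/streamlit_helpers.py` does
+ with one of the provided inference configs):
+
+ ```python
+ # Load a yaml config and build the full model graph from it.
+ from omegaconf import OmegaConf
+ from sgm.util import instantiate_from_config
+
+ config = OmegaConf.load("configs/inference/sd_2_1.yaml")
+ # Recursively builds the DiffusionEngine and every submodule named by a `target:` entry.
+ model = instantiate_from_config(config.model)
+ ```
+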
60
+ ### Changelog from the old `ldm` codebase
61
+
62
+ For training, we use [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), but it should be easy to use other
63
+ training wrappers around the base modules. The core diffusion model class (formerly `LatentDiffusion`,
64
+ now `DiffusionEngine`) has been cleaned up:
65
+
66
+ - No more extensive subclassing! We now handle all types of conditioning inputs (vectors, sequences and spatial
67
+ conditionings, and all combinations thereof) in a single class: `GeneralConditioner`,
68
+ see `sgm/modules/encoders/modules.py`.
69
+ - We separate guiders (such as classifier-free guidance, see `sgm/modules/diffusionmodules/guiders.py`) from the
70
+ samplers (`sgm/modules/diffusionmodules/sampling.py`), and the samplers are independent of the model.
71
+ - We adopt the ["denoiser framework"](https://arxiv.org/abs/2206.00364) for both training and inference (the most notable
72
+ change is probably the option to train continuous-time models):
73
+ * Discrete-time models (denoisers) are simply a special case of continuous-time models (denoisers);
74
+ see `sgm/modules/diffusionmodules/denoiser.py`.
75
+ * The following features are now independent: weighting of the diffusion loss
76
+ function (`sgm/modules/diffusionmodules/denoiser_weighting.py`), preconditioning of the
77
+ network (`sgm/modules/diffusionmodules/denoiser_scaling.py`), and sampling of noise levels during
78
+ training (`sgm/modules/diffusionmodules/sigma_sampling.py`); see the sketch after this list.
79
+ - Autoencoding models have also been cleaned up.
80
+
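+ For illustration, here is the `denoiser_config` from the provided `configs/inference/sd_2_1.yaml`, written out
+ as a plain Python dict: preconditioning and discretization are separate sub-configs, and loss weighting and
+ sigma sampling are configured analogously for training.
+
+ ```python
+ # Sketch taken from configs/inference/sd_2_1.yaml, shown as a dict for readability.
+ denoiser_config = {
+     "target": "sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser",
+     "params": {
+         "num_idx": 1000,  # number of discrete noise levels
+         # preconditioning of the network (denoiser_scaling.py)
+         "scaling_config": {
+             "target": "sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling"
+         },
+         # mapping from discrete indices to continuous noise levels
+         "discretization_config": {
+             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
+         },
+     },
+ }
+ ```
+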
81
+ ## Installation:
82
+
83
+ <a name="installation"></a>
84
+
85
+ #### 1. Clone the repo
86
+
87
+ ```shell
88
+ git clone git@github.com:Stability-AI/generative-models.git
89
+ cd generative-models
90
+ ```
91
+
92
+ #### 2. Setting up the virtualenv
93
+
94
+ This is assuming you have navigated to the `generative-models` root after cloning it.
95
+
96
+ **NOTE:** This is tested under `python3.10`. For other python versions, you might encounter version conflicts.
97
+
98
+ **PyTorch 2.0**
99
+
100
+ ```shell
101
+ # install required packages from pypi
102
+ python3 -m venv .pt2
103
+ source .pt2/bin/activate
104
+ pip3 install -r requirements/pt2.txt
105
+ ```
106
+
107
+ #### 3. Install `sgm`
108
+
109
+ ```shell
110
+ pip3 install .
111
+ ```
112
+
113
+ #### 4. Install `sdata` for training
114
+
115
+ ```shell
116
+ pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
117
+ ```
118
+
119
+ ## Packaging
120
+
121
+ This repository uses PEP 517 compliant packaging using [Hatch](https://hatch.pypa.io/latest/).
122
+
123
+ To build a distributable wheel, install `hatch` and run `hatch build`
124
+ (specifying `-t wheel` will skip building an sdist, which is not necessary).
125
+
126
+ ```
127
+ pip install hatch
128
+ hatch build -t wheel
129
+ ```
130
+
131
+ You will find the built package in `dist/`. You can install the wheel with `pip install dist/*.whl`.
132
+
133
+ Note that the package does **not** currently specify dependencies; you will need to install the required packages
134
+ manually, depending on your use case and PyTorch version.
135
+
136
+ ## Inference
137
+
138
+ We provide a [streamlit](https://streamlit.io/) demo for text-to-image and image-to-image sampling
139
+ in `scripts/demo/sampling.py`.
140
+ We provide file hashes both for the complete file and for only the saved tensors in the file
141
+ (see [Model Spec](https://github.com/Stability-AI/ModelSpec) for a script to compute them).
142
+ The following models are currently supported:
143
+
144
+ - [SDXL-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
145
+ ```
146
+ File Hash (sha256): 31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b
147
+ Tensordata Hash (sha256): 0xd7a9105a900fd52748f20725fe52fe52b507fd36bee4fc107b1550a26e6ee1d7
148
+ ```
149
+ - [SDXL-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
150
+ ```
151
+ File Hash (sha256): 7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f
152
+ Tensordata Hash (sha256): 0x1a77d21bebc4b4de78c474a90cb74dc0d2217caf4061971dbfa75ad406b75d81
153
+ ```
154
+ - [SDXL-base-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9)
155
+ - [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9)
156
+ - [SD-2.1-512](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned.safetensors)
157
+ - [SD-2.1-768](https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors)
158
+
159
+ **Weights for SDXL**:
160
+
161
+ **SDXL-1.0:**
162
+ The weights of SDXL-1.0 are available (subject to
163
+ a [`CreativeML Open RAIL++-M` license](model_licenses/LICENSE-SDXL1.0)) here:
164
+
165
+ - base model: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/
166
+ - refiner model: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/
167
+
168
+ **SDXL-0.9:**
169
+ The weights of SDXL-0.9 are available and subject to a [research license](model_licenses/LICENSE-SDXL0.9).
170
+ If you would like to access these models for your research, please apply using one of the following links:
171
+ [SDXL-base-0.9 model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
172
+ and [SDXL-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9).
173
+ This means that you can apply via either of the two links, and if you are granted access, you can access both.
174
+ Please log in to your Hugging Face Account with your organization email to request access.
175
+
176
+ After obtaining the weights, place them into `checkpoints/`.
177
+ Next, start the demo using
178
+
179
+ ```
180
+ streamlit run scripts/demo/sampling.py --server.port <your_port>
181
+ ```
182
+
183
+ ### Invisible Watermark Detection
184
+
185
+ Images generated with our code use the
186
+ [invisible-watermark](https://github.com/ShieldMnt/invisible-watermark/)
187
+ library to embed an invisible watermark into the model output. We also provide
188
+ a script to easily detect that watermark. Please note that this watermark is
189
+ not the same as in previous Stable Diffusion 1.x/2.x versions.
190
+
191
+ To run the script you need to either have a working installation as above or
192
+ try an _experimental_ import using only a minimal set of packages:
193
+
194
+ ```bash
195
+ python -m venv .detect
196
+ source .detect/bin/activate
197
+
198
+ pip install "numpy>=1.17" "PyWavelets>=1.1.1" "opencv-python>=4.1.0.25"
199
+ pip install --no-deps invisible-watermark
200
+ ```
201
+
202
+ With either installation in place, the script
203
+ can be used in the following ways (don't forget to activate your
204
+ virtual environment beforehand, e.g. `source .pt2/bin/activate`):
205
+
206
+ ```bash
207
+ # test a single file
208
+ python scripts/demo/detect.py <your filename here>
209
+ # test multiple files at once
210
+ python scripts/demo/detect.py <filename 1> <filename 2> ... <filename n>
211
+ # test all files in a specific folder
212
+ python scripts/demo/detect.py <your folder name here>/*
213
+ ```
214
+
215
+ ## Training:
216
+
217
+ We are providing example training configs in `configs/example_training`. To launch a training, run
218
+
219
+ ```
220
+ python main.py --base configs/<config1.yaml> configs/<config2.yaml>
221
+ ```
222
+
223
+ where configs are merged from left to right (later configs overwrite the same values).
224
+ This can be used to combine model, training and data configs. However, all of them can also be
225
+ defined in a single config. For example, to run a class-conditional pixel-based diffusion model training on MNIST,
226
+ run
227
+
228
+ ```bash
229
+ python main.py --base configs/example_training/toy/mnist_cond.yaml
230
+ ```
231
+
232
+ **NOTE 1:** Using the non-toy-dataset
233
+ configs `configs/example_training/imagenet-f8_cond.yaml`, `configs/example_training/txt2img-clipl.yaml`
234
+ and `configs/example_training/txt2img-clipl-legacy-ucg-training.yaml` for training will require edits depending on the
235
+ dataset used (which is expected to be stored as tar files in
236
+ the [webdataset format](https://github.com/webdataset/webdataset)). To find the parts that have to be adapted, search
237
+ for comments containing `USER:` in the respective config.
238
+
239
+ **NOTE 2:** This repository supports both `pytorch1.13` and `pytorch2` for training generative models. However, for
240
+ autoencoder training as e.g. in `configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml`,
241
+ only `pytorch1.13` is supported.
242
+
243
+ **NOTE 3:** Training latent generative models (as e.g. in `configs/example_training/imagenet-f8_cond.yaml`) requires
244
+ retrieving the checkpoint from [Hugging Face](https://huggingface.co/stabilityai/sdxl-vae/tree/main) and replacing
245
+ the `CKPT_PATH` placeholder in [this line](configs/example_training/imagenet-f8_cond.yaml#81). The same is to be done
246
+ for the provided text-to-image configs.
247
+
248
+ ### Building New Diffusion Models
249
+
250
+ #### Conditioner
251
+
252
+ The `GeneralConditioner` is configured through the `conditioner_config`. Its only attribute is `emb_models`, a list of
253
+ different embedders (all inherited from `AbstractEmbModel`) that are used to condition the generative model.
254
+ All embedders should define whether or not they are trainable (`is_trainable`, default `False`), the classifier-free
255
+ guidance dropout rate to use (`ucg_rate`, default `0`), and an input key (`input_key`), for example, `txt` for
256
+ text-conditioning or `cls` for class-conditioning.
257
+ When computing conditionings, the embedder will get `batch[input_key]` as input.
258
+ We currently support two- to four-dimensional conditionings, and conditionings from different embedders are concatenated
259
+ appropriately.
260
+ Note that the order of the embedders in the `conditioner_config` is important.
261
+
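+ For illustration, a minimal sketch of a `conditioner_config` with a single text embedder (class and parameter
+ names follow the provided inference configs; the `ucg_rate` value is purely illustrative):
+
+ ```python
+ from omegaconf import OmegaConf
+ from sgm.util import instantiate_from_config
+
+ conditioner_config = OmegaConf.create({
+     "target": "sgm.modules.GeneralConditioner",
+     "params": {
+         "emb_models": [
+             {
+                 "is_trainable": False,  # keep the text encoder frozen
+                 "ucg_rate": 0.1,        # classifier-free guidance dropout (illustrative value)
+                 "input_key": "txt",     # this embedder receives batch["txt"]
+                 "target": "sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder",
+                 "params": {"freeze": True, "layer": "penultimate"},
+             },
+         ]
+     },
+ })
+ conditioner = instantiate_from_config(conditioner_config)
+ ```
+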
262
+ #### Network
263
+
264
+ The neural network is set through the `network_config`. This used to be called `unet_config`, which is not general
265
+ enough as we plan to experiment with transformer-based diffusion backbones.
266
+
267
+ #### Loss
268
+
269
+ The loss is configured through `loss_config`. For standard diffusion model training, you will have to
270
+ set `sigma_sampler_config`.
271
+
272
+ #### Sampler config
273
+
274
+ As discussed above, the sampler is independent of the model. In the `sampler_config`, we set the type of numerical
275
+ solver, the number of steps, the type of discretization, and, for example, guidance wrappers for classifier-free
276
+ guidance.
277
+
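+ For illustration, a hedged sketch of a `sampler_config` (the class names appear in this repository, while the
+ exact parameter names are modeled on the demo configs and may differ):
+
+ ```python
+ # Euler EDM solver with a fixed number of steps, legacy DDPM discretization,
+ # and a vanilla classifier-free guidance wrapper. Parameter names are assumptions.
+ sampler_config = {
+     "target": "sgm.modules.diffusionmodules.sampling.EulerEDMSampler",
+     "params": {
+         "num_steps": 40,
+         "discretization_config": {
+             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
+         },
+         "guider_config": {
+             "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+             "params": {"scale": 7.5},
+         },
+     },
+ }
+ ```
+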
278
+ ### Dataset Handling
279
+
280
+ For large-scale training we recommend using the data pipelines from
281
+ our [data pipelines](https://github.com/Stability-AI/datapipelines) project. The project is included in the requirements
282
+ and automatically included when following the steps from the [Installation section](#installation).
283
+ Small map-style datasets should be defined here in the repository (e.g., MNIST, CIFAR-10, ...), and return a dict of
284
+ data keys/values,
285
+ e.g.,
286
+
287
+ ```python
288
+ example = {"jpg": x, # this is a tensor -1...1 chw
289
+ "txt": "a beautiful image"}
290
+ ```
291
+
292
+ where we expect images in -1...1, channel-first format.
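+
+ For illustration, a minimal sketch of such a map-style dataset (the class name and constructor arguments are
+ hypothetical):
+
+ ```python
+ from typing import List
+
+ import torchvision.transforms as T
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ class ImageCaptionDataset(Dataset):
+     """Hypothetical map-style dataset returning the expected dict of data keys/values."""
+
+     def __init__(self, image_paths: List[str], captions: List[str]):
+         self.image_paths = image_paths
+         self.captions = captions
+         # map images to [-1, 1], channel-first (c, h, w)
+         self.transform = T.Compose([T.ToTensor(), T.Normalize([0.5] * 3, [0.5] * 3)])
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         img = Image.open(self.image_paths[idx]).convert("RGB")
+         return {"jpg": self.transform(img), "txt": self.captions[idx]}
+ ```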
configs/.DS_Store ADDED
Binary file (6.15 kB).
 
configs/inference/sd_2_1.yaml ADDED
@@ -0,0 +1,60 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 320
24
+ attention_resolutions: [4, 2, 1]
25
+ num_res_blocks: 2
26
+ channel_mult: [1, 2, 4, 4]
27
+ num_head_channels: 64
28
+ use_linear_in_transformer: True
29
+ transformer_depth: 1
30
+ context_dim: 1024
31
+
32
+ conditioner_config:
33
+ target: sgm.modules.GeneralConditioner
34
+ params:
35
+ emb_models:
36
+ - is_trainable: False
37
+ input_key: txt
38
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39
+ params:
40
+ freeze: true
41
+ layer: penultimate
42
+
43
+ first_stage_config:
44
+ target: sgm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: val/rec_loss
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 4
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult: [1, 2, 4, 4]
56
+ num_res_blocks: 2
57
+ attn_resolutions: []
58
+ dropout: 0.0
59
+ lossconfig:
60
+ target: torch.nn.Identity
configs/inference/sd_2_1_768.yaml ADDED
@@ -0,0 +1,60 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ use_checkpoint: True
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 320
24
+ attention_resolutions: [4, 2, 1]
25
+ num_res_blocks: 2
26
+ channel_mult: [1, 2, 4, 4]
27
+ num_head_channels: 64
28
+ use_linear_in_transformer: True
29
+ transformer_depth: 1
30
+ context_dim: 1024
31
+
32
+ conditioner_config:
33
+ target: sgm.modules.GeneralConditioner
34
+ params:
35
+ emb_models:
36
+ - is_trainable: False
37
+ input_key: txt
38
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
39
+ params:
40
+ freeze: true
41
+ layer: penultimate
42
+
43
+ first_stage_config:
44
+ target: sgm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: val/rec_loss
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 4
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult: [1, 2, 4, 4]
56
+ num_res_blocks: 2
57
+ attn_resolutions: []
58
+ dropout: 0.0
59
+ lossconfig:
60
+ target: torch.nn.Identity
configs/inference/sd_xl_base.yaml ADDED
@@ -0,0 +1,93 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ adm_in_channels: 2816
21
+ num_classes: sequential
22
+ use_checkpoint: True
23
+ in_channels: 4
24
+ out_channels: 4
25
+ model_channels: 320
26
+ attention_resolutions: [4, 2]
27
+ num_res_blocks: 2
28
+ channel_mult: [1, 2, 4]
29
+ num_head_channels: 64
30
+ use_linear_in_transformer: True
31
+ transformer_depth: [1, 2, 10]
32
+ context_dim: 2048
33
+ spatial_transformer_attn_type: softmax-xformers
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: txt
41
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
42
+ params:
43
+ layer: hidden
44
+ layer_idx: 11
45
+
46
+ - is_trainable: False
47
+ input_key: txt
48
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
49
+ params:
50
+ arch: ViT-bigG-14
51
+ version: laion2b_s39b_b160k
52
+ freeze: True
53
+ layer: penultimate
54
+ always_return_pooled: True
55
+ legacy: False
56
+
57
+ - is_trainable: False
58
+ input_key: original_size_as_tuple
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - is_trainable: False
64
+ input_key: crop_coords_top_left
65
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
66
+ params:
67
+ outdim: 256
68
+
69
+ - is_trainable: False
70
+ input_key: target_size_as_tuple
71
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
72
+ params:
73
+ outdim: 256
74
+
75
+ first_stage_config:
76
+ target: sgm.models.autoencoder.AutoencoderKL
77
+ params:
78
+ embed_dim: 4
79
+ monitor: val/rec_loss
80
+ ddconfig:
81
+ attn_type: vanilla-xformers
82
+ double_z: true
83
+ z_channels: 4
84
+ resolution: 256
85
+ in_channels: 3
86
+ out_ch: 3
87
+ ch: 128
88
+ ch_mult: [1, 2, 4, 4]
89
+ num_res_blocks: 2
90
+ attn_resolutions: []
91
+ dropout: 0.0
92
+ lossconfig:
93
+ target: torch.nn.Identity
configs/inference/sd_xl_refiner.yaml ADDED
@@ -0,0 +1,86 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.13025
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
9
+ params:
10
+ num_idx: 1000
11
+
12
+ scaling_config:
13
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
14
+ discretization_config:
15
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
16
+
17
+ network_config:
18
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ adm_in_channels: 2560
21
+ num_classes: sequential
22
+ use_checkpoint: True
23
+ in_channels: 4
24
+ out_channels: 4
25
+ model_channels: 384
26
+ attention_resolutions: [4, 2]
27
+ num_res_blocks: 2
28
+ channel_mult: [1, 2, 4, 4]
29
+ num_head_channels: 64
30
+ use_linear_in_transformer: True
31
+ transformer_depth: 4
32
+ context_dim: [1280, 1280, 1280, 1280]
33
+ spatial_transformer_attn_type: softmax-xformers
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: txt
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
42
+ params:
43
+ arch: ViT-bigG-14
44
+ version: laion2b_s39b_b160k
45
+ legacy: False
46
+ freeze: True
47
+ layer: penultimate
48
+ always_return_pooled: True
49
+
50
+ - is_trainable: False
51
+ input_key: original_size_as_tuple
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - is_trainable: False
57
+ input_key: crop_coords_top_left
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - is_trainable: False
63
+ input_key: aesthetic_score
64
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
65
+ params:
66
+ outdim: 256
67
+
68
+ first_stage_config:
69
+ target: sgm.models.autoencoder.AutoencoderKL
70
+ params:
71
+ embed_dim: 4
72
+ monitor: val/rec_loss
73
+ ddconfig:
74
+ attn_type: vanilla-xformers
75
+ double_z: true
76
+ z_channels: 4
77
+ resolution: 256
78
+ in_channels: 3
79
+ out_ch: 3
80
+ ch: 128
81
+ ch_mult: [1, 2, 4, 4]
82
+ num_res_blocks: 2
83
+ attn_resolutions: []
84
+ dropout: 0.0
85
+ lossconfig:
86
+ target: torch.nn.Identity
configs/inference/svd.yaml ADDED
@@ -0,0 +1,131 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
15
+ params:
16
+ adm_in_channels: 768
17
+ num_classes: sequential
18
+ use_checkpoint: True
19
+ in_channels: 8
20
+ out_channels: 4
21
+ model_channels: 320
22
+ attention_resolutions: [4, 2, 1]
23
+ num_res_blocks: 2
24
+ channel_mult: [1, 2, 4, 4]
25
+ num_head_channels: 64
26
+ use_linear_in_transformer: True
27
+ transformer_depth: 1
28
+ context_dim: 1024
29
+ spatial_transformer_attn_type: softmax-xformers
30
+ extra_ff_mix_layer: True
31
+ use_spatial_context: True
32
+ merge_strategy: learned_with_images
33
+ video_kernel_size: [3, 1, 1]
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: cond_frames_without_noise
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42
+ params:
43
+ n_cond_frames: 1
44
+ n_copies: 1
45
+ open_clip_embedding_config:
46
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47
+ params:
48
+ freeze: True
49
+
50
+ - input_key: fps_id
51
+ is_trainable: False
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - input_key: motion_bucket_id
57
+ is_trainable: False
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - input_key: cond_frames
63
+ is_trainable: False
64
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65
+ params:
66
+ disable_encoder_autocast: True
67
+ n_cond_frames: 1
68
+ n_copies: 1
69
+ is_ae: True
70
+ encoder_config:
71
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
72
+ params:
73
+ embed_dim: 4
74
+ monitor: val/rec_loss
75
+ ddconfig:
76
+ attn_type: vanilla-xformers
77
+ double_z: True
78
+ z_channels: 4
79
+ resolution: 256
80
+ in_channels: 3
81
+ out_ch: 3
82
+ ch: 128
83
+ ch_mult: [1, 2, 4, 4]
84
+ num_res_blocks: 2
85
+ attn_resolutions: []
86
+ dropout: 0.0
87
+ lossconfig:
88
+ target: torch.nn.Identity
89
+
90
+ - input_key: cond_aug
91
+ is_trainable: False
92
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93
+ params:
94
+ outdim: 256
95
+
96
+ first_stage_config:
97
+ target: sgm.models.autoencoder.AutoencodingEngine
98
+ params:
99
+ loss_config:
100
+ target: torch.nn.Identity
101
+ regularizer_config:
102
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
103
+ encoder_config:
104
+ target: sgm.modules.diffusionmodules.model.Encoder
105
+ params:
106
+ attn_type: vanilla
107
+ double_z: True
108
+ z_channels: 4
109
+ resolution: 256
110
+ in_channels: 3
111
+ out_ch: 3
112
+ ch: 128
113
+ ch_mult: [1, 2, 4, 4]
114
+ num_res_blocks: 2
115
+ attn_resolutions: []
116
+ dropout: 0.0
117
+ decoder_config:
118
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
119
+ params:
120
+ attn_type: vanilla
121
+ double_z: True
122
+ z_channels: 4
123
+ resolution: 256
124
+ in_channels: 3
125
+ out_ch: 3
126
+ ch: 128
127
+ ch_mult: [1, 2, 4, 4]
128
+ num_res_blocks: 2
129
+ attn_resolutions: []
130
+ dropout: 0.0
131
+ video_kernel_size: [3, 1, 1]
configs/inference/svd_image_decoder.yaml ADDED
@@ -0,0 +1,114 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
15
+ params:
16
+ adm_in_channels: 768
17
+ num_classes: sequential
18
+ use_checkpoint: True
19
+ in_channels: 8
20
+ out_channels: 4
21
+ model_channels: 320
22
+ attention_resolutions: [4, 2, 1]
23
+ num_res_blocks: 2
24
+ channel_mult: [1, 2, 4, 4]
25
+ num_head_channels: 64
26
+ use_linear_in_transformer: True
27
+ transformer_depth: 1
28
+ context_dim: 1024
29
+ spatial_transformer_attn_type: softmax-xformers
30
+ extra_ff_mix_layer: True
31
+ use_spatial_context: True
32
+ merge_strategy: learned_with_images
33
+ video_kernel_size: [3, 1, 1]
34
+
35
+ conditioner_config:
36
+ target: sgm.modules.GeneralConditioner
37
+ params:
38
+ emb_models:
39
+ - is_trainable: False
40
+ input_key: cond_frames_without_noise
41
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
42
+ params:
43
+ n_cond_frames: 1
44
+ n_copies: 1
45
+ open_clip_embedding_config:
46
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
47
+ params:
48
+ freeze: True
49
+
50
+ - input_key: fps_id
51
+ is_trainable: False
52
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
53
+ params:
54
+ outdim: 256
55
+
56
+ - input_key: motion_bucket_id
57
+ is_trainable: False
58
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
59
+ params:
60
+ outdim: 256
61
+
62
+ - input_key: cond_frames
63
+ is_trainable: False
64
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
65
+ params:
66
+ disable_encoder_autocast: True
67
+ n_cond_frames: 1
68
+ n_copies: 1
69
+ is_ae: True
70
+ encoder_config:
71
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
72
+ params:
73
+ embed_dim: 4
74
+ monitor: val/rec_loss
75
+ ddconfig:
76
+ attn_type: vanilla-xformers
77
+ double_z: True
78
+ z_channels: 4
79
+ resolution: 256
80
+ in_channels: 3
81
+ out_ch: 3
82
+ ch: 128
83
+ ch_mult: [1, 2, 4, 4]
84
+ num_res_blocks: 2
85
+ attn_resolutions: []
86
+ dropout: 0.0
87
+ lossconfig:
88
+ target: torch.nn.Identity
89
+
90
+ - input_key: cond_aug
91
+ is_trainable: False
92
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
93
+ params:
94
+ outdim: 256
95
+
96
+ first_stage_config:
97
+ target: sgm.models.autoencoder.AutoencoderKL
98
+ params:
99
+ embed_dim: 4
100
+ monitor: val/rec_loss
101
+ ddconfig:
102
+ attn_type: vanilla-xformers
103
+ double_z: True
104
+ z_channels: 4
105
+ resolution: 256
106
+ in_channels: 3
107
+ out_ch: 3
108
+ ch: 128
109
+ ch_mult: [1, 2, 4, 4]
110
+ num_res_blocks: 2
111
+ attn_resolutions: []
112
+ dropout: 0.0
113
+ lossconfig:
114
+ target: torch.nn.Identity
requirements/pt2.txt ADDED
@@ -0,0 +1,39 @@
1
+ black==23.7.0
2
+ chardet==5.1.0
3
+ clip @ git+https://github.com/openai/CLIP.git
4
+ einops>=0.6.1
5
+ fairscale>=0.4.13
6
+ fire>=0.5.0
7
+ fsspec>=2023.6.0
8
+ invisible-watermark>=0.2.0
9
+ kornia==0.6.9
10
+ matplotlib>=3.7.2
11
+ natsort>=8.4.0
12
+ ninja>=1.11.1
13
+ numpy>=1.24.4
14
+ omegaconf>=2.3.0
15
+ open-clip-torch>=2.20.0
16
+ opencv-python==4.6.0.66
17
+ pandas>=2.0.3
18
+ pillow>=9.5.0
19
+ pudb>=2022.1.3
20
+ pytorch-lightning==2.0.1
21
+ pyyaml>=6.0.1
22
+ scipy>=1.10.1
23
+ streamlit>=0.73.1
24
+ tensorboardx==2.6
25
+ timm>=0.9.2
26
+ tokenizers==0.12.1
27
+ torch>=2.0.1
28
+ torchaudio>=2.0.2
29
+ torchdata==0.6.1
30
+ torchmetrics>=1.0.1
31
+ torchvision>=0.15.2
32
+ tqdm>=4.65.0
33
+ transformers==4.19.1
34
+ triton==2.0.0
35
+ urllib3<1.27,>=1.25.4
36
+ wandb>=0.15.6
37
+ webdataset>=0.2.33
38
+ wheel>=0.41.0
39
+ xformers>=0.0.20
scripts/.DS_Store ADDED
Binary file (6.15 kB).
 
scripts/__init__.py ADDED
File without changes
scripts/demo/__init__.py ADDED
File without changes
scripts/demo/detect.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ try:
7
+ from imwatermark import WatermarkDecoder
8
+ except ImportError as e:
9
+ try:
10
+ # Assume some of the other dependencies such as torch are not fulfilled
11
+ # import file without loading unnecessary libraries.
12
+ import importlib.util
13
+ import sys
14
+
15
+ spec = importlib.util.find_spec("imwatermark.maxDct")
16
+ assert spec is not None
17
+ maxDct = importlib.util.module_from_spec(spec)
18
+ sys.modules["maxDct"] = maxDct
19
+ spec.loader.exec_module(maxDct)
20
+
21
+ class WatermarkDecoder(object):
22
+ """A minimal version of
23
+ https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
24
+ to only reconstruct bits using dwtDct"""
25
+
26
+ def __init__(self, wm_type="bytes", length=0):
27
+ assert wm_type == "bits", "Only bits defined in minimal import"
28
+ self._wmType = wm_type
29
+ self._wmLen = length
30
+
31
+ def reconstruct(self, bits):
32
+ if len(bits) != self._wmLen:
33
+ raise RuntimeError("bits are not matched with watermark length")
34
+
35
+ return bits
36
+
37
+ def decode(self, cv2Image, method="dwtDct", **configs):
38
+ (r, c, channels) = cv2Image.shape
39
+ if r * c < 256 * 256:
40
+ raise RuntimeError("image too small, should be larger than 256x256")
41
+
42
+ bits = []
43
+ assert method == "dwtDct"
44
+ embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
45
+ bits = embed.decode(cv2Image)
46
+ return self.reconstruct(bits)
47
+
48
+ except:
49
+ raise e
50
+
51
+
52
+ # A fixed 48-bit message that was choosen at random
53
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
54
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
55
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
56
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
57
+ MATCH_VALUES = [
58
+ [27, "No watermark detected"],
59
+ [33, "Partial watermark match. Cannot determine with certainty."],
60
+ [
61
+ 35,
62
+ (
63
+ "Likely watermarked. In our test 0.02% of real images were "
64
+ 'falsely detected as "Likely watermarked"'
65
+ ),
66
+ ],
67
+ [
68
+ 49,
69
+ (
70
+ "Very likely watermarked. In our test no real images were "
71
+ 'falsely detected as "Very likely watermarked"'
72
+ ),
73
+ ],
74
+ ]
75
+
76
+
77
+ class GetWatermarkMatch:
78
+ def __init__(self, watermark):
79
+ self.watermark = watermark
80
+ self.num_bits = len(self.watermark)
81
+ self.decoder = WatermarkDecoder("bits", self.num_bits)
82
+
83
+ def __call__(self, x: np.ndarray) -> np.ndarray:
84
+ """
85
+ Detects the number of matching bits the predefined watermark with one
86
+ or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
87
+
88
+ Args:
89
+ x: ([B], h w, c) in range [0, 255]
90
+
91
+ Returns:
92
+ number of matched bits ([B],)
93
+ """
94
+ squeeze = len(x.shape) == 3
95
+ if squeeze:
96
+ x = x[None, ...]
97
+
98
+ bs = x.shape[0]
99
+ detected = np.empty((bs, self.num_bits), dtype=bool)
100
+ for k in range(bs):
101
+ detected[k] = self.decoder.decode(x[k], "dwtDct")
102
+ result = np.sum(detected == self.watermark, axis=-1)
103
+ if squeeze:
104
+ return result[0]
105
+ else:
106
+ return result
107
+
108
+
109
+ get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument(
115
+ "filename",
116
+ nargs="+",
117
+ type=str,
118
+ help="Image files to check for watermarks",
119
+ )
120
+ opts = parser.parse_args()
121
+
122
+ print(
123
+ """
124
+ This script tries to detect watermarked images. Please be aware of
125
+ the following:
126
+ - As the watermark is supposed to be invisible, there is the risk that
127
+ watermarked images may not be detected.
128
+ - To maximize the chance of detection make sure that the image has the same
129
+ dimensions as when the watermark was applied (most likely 1024x1024
130
+ or 512x512).
131
+ - Specific image manipulation may drastically decrease the chance that
132
+ watermarks can be detected.
133
+ - There is also the chance that an image has the characteristics of the
134
+ watermark by chance.
135
+ - The watermark script is public, anybody may watermark any images, and
136
+ could therefore claim it to be generated.
137
+ - All numbers below are based on a test using 10,000 images without any
138
+ modifications after applying the watermark.
139
+ """
140
+ )
141
+
142
+ for fn in opts.filename:
143
+ image = cv2.imread(fn)
144
+ if image is None:
145
+ print(f"Couldn't read {fn}. Skipping")
146
+ continue
147
+
148
+ num_bits = get_watermark_match(image)
149
+ k = 0
150
+ while num_bits > MATCH_VALUES[k][0]:
151
+ k += 1
152
+ print(
153
+ f"{fn}: {MATCH_VALUES[k][1]}",
154
+ f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
155
+ sep="\n\t",
156
+ )
scripts/demo/discretization.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+
3
+ from sgm.modules.diffusionmodules.discretizer import Discretization
4
+
5
+
6
+ class Img2ImgDiscretizationWrapper:
7
+ """
8
+ wraps a discretizer, and prunes the sigmas
9
+ params:
10
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
11
+ """
12
+
13
+ def __init__(self, discretization: Discretization, strength: float = 1.0):
14
+ self.discretization = discretization
15
+ self.strength = strength
16
+ assert 0.0 <= self.strength <= 1.0
17
+
18
+ def __call__(self, *args, **kwargs):
19
+ # sigmas start large first, and decrease then
20
+ sigmas = self.discretization(*args, **kwargs)
21
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
22
+ sigmas = torch.flip(sigmas, (0,))
23
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
24
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
25
+ sigmas = torch.flip(sigmas, (0,))
26
+ print(f"sigmas after pruning: ", sigmas)
27
+ return sigmas
28
+
29
+
30
+ class Txt2NoisyDiscretizationWrapper:
31
+ """
32
+ wraps a discretizer, and prunes the sigmas
33
+ params:
34
+ strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
35
+ """
36
+
37
+ def __init__(
38
+ self, discretization: Discretization, strength: float = 0.0, original_steps=None
39
+ ):
40
+ self.discretization = discretization
41
+ self.strength = strength
42
+ self.original_steps = original_steps
43
+ assert 0.0 <= self.strength <= 1.0
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ # sigmas start large first, and decrease then
47
+ sigmas = self.discretization(*args, **kwargs)
48
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
49
+ sigmas = torch.flip(sigmas, (0,))
50
+ if self.original_steps is None:
51
+ steps = len(sigmas)
52
+ else:
53
+ steps = self.original_steps + 1
54
+ prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
55
+ sigmas = sigmas[prune_index:]
56
+ print("prune index:", prune_index)
57
+ sigmas = torch.flip(sigmas, (0,))
58
+ print(f"sigmas after pruning: ", sigmas)
59
+ return sigmas
scripts/demo/sampling.py ADDED
@@ -0,0 +1,364 @@
1
+ from pytorch_lightning import seed_everything
2
+
3
+ from scripts.demo.streamlit_helpers import *
4
+
5
+ SAVE_PATH = "outputs/demo/txt2img/"
6
+
7
+ SD_XL_BASE_RATIOS = {
8
+ "0.5": (704, 1408),
9
+ "0.52": (704, 1344),
10
+ "0.57": (768, 1344),
11
+ "0.6": (768, 1280),
12
+ "0.68": (832, 1216),
13
+ "0.72": (832, 1152),
14
+ "0.78": (896, 1152),
15
+ "0.82": (896, 1088),
16
+ "0.88": (960, 1088),
17
+ "0.94": (960, 1024),
18
+ "1.0": (1024, 1024),
19
+ "1.07": (1024, 960),
20
+ "1.13": (1088, 960),
21
+ "1.21": (1088, 896),
22
+ "1.29": (1152, 896),
23
+ "1.38": (1152, 832),
24
+ "1.46": (1216, 832),
25
+ "1.67": (1280, 768),
26
+ "1.75": (1344, 768),
27
+ "1.91": (1344, 704),
28
+ "2.0": (1408, 704),
29
+ "2.09": (1472, 704),
30
+ "2.4": (1536, 640),
31
+ "2.5": (1600, 640),
32
+ "2.89": (1664, 576),
33
+ "3.0": (1728, 576),
34
+ }
35
+
36
+ VERSION2SPECS = {
37
+ "SDXL-base-1.0": {
38
+ "H": 1024,
39
+ "W": 1024,
40
+ "C": 4,
41
+ "f": 8,
42
+ "is_legacy": False,
43
+ "config": "configs/inference/sd_xl_base.yaml",
44
+ "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
45
+ },
46
+ "SDXL-base-0.9": {
47
+ "H": 1024,
48
+ "W": 1024,
49
+ "C": 4,
50
+ "f": 8,
51
+ "is_legacy": False,
52
+ "config": "configs/inference/sd_xl_base.yaml",
53
+ "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
54
+ },
55
+ "SD-2.1": {
56
+ "H": 512,
57
+ "W": 512,
58
+ "C": 4,
59
+ "f": 8,
60
+ "is_legacy": True,
61
+ "config": "configs/inference/sd_2_1.yaml",
62
+ "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
63
+ },
64
+ "SD-2.1-768": {
65
+ "H": 768,
66
+ "W": 768,
67
+ "C": 4,
68
+ "f": 8,
69
+ "is_legacy": True,
70
+ "config": "configs/inference/sd_2_1_768.yaml",
71
+ "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
72
+ },
73
+ "SDXL-refiner-0.9": {
74
+ "H": 1024,
75
+ "W": 1024,
76
+ "C": 4,
77
+ "f": 8,
78
+ "is_legacy": True,
79
+ "config": "configs/inference/sd_xl_refiner.yaml",
80
+ "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
81
+ },
82
+ "SDXL-refiner-1.0": {
83
+ "H": 1024,
84
+ "W": 1024,
85
+ "C": 4,
86
+ "f": 8,
87
+ "is_legacy": True,
88
+ "config": "configs/inference/sd_xl_refiner.yaml",
89
+ "ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
90
+ },
91
+ }
92
+
93
+
94
+ def load_img(display=True, key=None, device="cuda"):
95
+ image = get_interactive_image(key=key)
96
+ if image is None:
97
+ return None
98
+ if display:
99
+ st.image(image)
100
+ w, h = image.size
101
+ print(f"loaded input image of size ({w}, {h})")
102
+ width, height = map(
103
+ lambda x: x - x % 64, (w, h)
104
+ ) # resize to integer multiple of 64
105
+ image = image.resize((width, height))
106
+ image = np.array(image.convert("RGB"))
107
+ image = image[None].transpose(0, 3, 1, 2)
108
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
109
+ return image.to(device)
110
+
111
+
112
+ def run_txt2img(
113
+ state,
114
+ version,
115
+ version_dict,
116
+ is_legacy=False,
117
+ return_latents=False,
118
+ filter=None,
119
+ stage2strength=None,
120
+ ):
121
+ if version.startswith("SDXL-base"):
122
+ W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
123
+ else:
124
+ H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
125
+ W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
126
+ C = version_dict["C"]
127
+ F = version_dict["f"]
128
+
129
+ init_dict = {
130
+ "orig_width": W,
131
+ "orig_height": H,
132
+ "target_width": W,
133
+ "target_height": H,
134
+ }
135
+ value_dict = init_embedder_options(
136
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
137
+ init_dict,
138
+ prompt=prompt,
139
+ negative_prompt=negative_prompt,
140
+ )
141
+ sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
142
+ num_samples = num_rows * num_cols
143
+
144
+ if st.button("Sample"):
145
+ st.write(f"**Model I:** {version}")
146
+ out = do_sample(
147
+ state["model"],
148
+ sampler,
149
+ value_dict,
150
+ num_samples,
151
+ H,
152
+ W,
153
+ C,
154
+ F,
155
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
156
+ return_latents=return_latents,
157
+ filter=filter,
158
+ )
159
+ return out
160
+
161
+
162
+ def run_img2img(
163
+ state,
164
+ version_dict,
165
+ is_legacy=False,
166
+ return_latents=False,
167
+ filter=None,
168
+ stage2strength=None,
169
+ ):
170
+ img = load_img()
171
+ if img is None:
172
+ return None
173
+ H, W = img.shape[2], img.shape[3]
174
+
175
+ init_dict = {
176
+ "orig_width": W,
177
+ "orig_height": H,
178
+ "target_width": W,
179
+ "target_height": H,
180
+ }
181
+ value_dict = init_embedder_options(
182
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
183
+ init_dict,
184
+ prompt=prompt,
185
+ negative_prompt=negative_prompt,
186
+ )
187
+ strength = st.number_input(
188
+ "**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
189
+ )
190
+ sampler, num_rows, num_cols = init_sampling(
191
+ img2img_strength=strength,
192
+ stage2strength=stage2strength,
193
+ )
194
+ num_samples = num_rows * num_cols
195
+
196
+ if st.button("Sample"):
197
+ out = do_img2img(
198
+ repeat(img, "1 ... -> n ...", n=num_samples),
199
+ state["model"],
200
+ sampler,
201
+ value_dict,
202
+ num_samples,
203
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
204
+ return_latents=return_latents,
205
+ filter=filter,
206
+ )
207
+ return out
208
+
209
+
210
+ def apply_refiner(
211
+ input,
212
+ state,
213
+ sampler,
214
+ num_samples,
215
+ prompt,
216
+ negative_prompt,
217
+ filter=None,
218
+ finish_denoising=False,
219
+ ):
220
+ init_dict = {
221
+ "orig_width": input.shape[3] * 8,
222
+ "orig_height": input.shape[2] * 8,
223
+ "target_width": input.shape[3] * 8,
224
+ "target_height": input.shape[2] * 8,
225
+ }
226
+
227
+ value_dict = init_dict
228
+ value_dict["prompt"] = prompt
229
+ value_dict["negative_prompt"] = negative_prompt
230
+
231
+ value_dict["crop_coords_top"] = 0
232
+ value_dict["crop_coords_left"] = 0
233
+
234
+ value_dict["aesthetic_score"] = 6.0
235
+ value_dict["negative_aesthetic_score"] = 2.5
236
+
237
+ st.warning(f"refiner input shape: {input.shape}")
238
+ samples = do_img2img(
239
+ input,
240
+ state["model"],
241
+ sampler,
242
+ value_dict,
243
+ num_samples,
244
+ skip_encode=True,
245
+ filter=filter,
246
+ add_noise=not finish_denoising,
247
+ )
248
+
249
+ return samples
250
+
251
+
252
+ if __name__ == "__main__":
253
+ st.title("Stable Diffusion")
254
+ version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
255
+ version_dict = VERSION2SPECS[version]
256
+ if st.checkbox("Load Model"):
257
+ mode = st.radio("Mode", ("txt2img", "img2img"), 0)
258
+ else:
259
+ mode = "skip"
260
+ st.write("__________________________")
261
+
262
+ set_lowvram_mode(st.checkbox("Low vram mode", True))
263
+
264
+ if version.startswith("SDXL-base"):
265
+ add_pipeline = st.checkbox("Load SDXL-refiner?", False)
266
+ st.write("__________________________")
267
+ else:
268
+ add_pipeline = False
269
+
270
+ seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
271
+ seed_everything(seed)
272
+
273
+ save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
274
+
275
+ if mode != "skip":
276
+ state = init_st(version_dict, load_filter=True)
277
+ if state["msg"]:
278
+ st.info(state["msg"])
279
+ model = state["model"]
280
+
281
+ is_legacy = version_dict["is_legacy"]
282
+
283
+ prompt = st.text_input(
284
+ "prompt",
285
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
286
+ )
287
+ if is_legacy:
288
+ negative_prompt = st.text_input("negative prompt", "")
289
+ else:
290
+ negative_prompt = "" # which is unused
291
+
292
+ stage2strength = None
293
+ finish_denoising = False
294
+
295
+ if add_pipeline:
296
+ st.write("__________________________")
297
+ version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
298
+ st.warning(
299
+ f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
300
+ )
301
+ st.write("**Refiner Options:**")
302
+
303
+ version_dict2 = VERSION2SPECS[version2]
304
+ state2 = init_st(version_dict2, load_filter=False)
305
+ st.info(state2["msg"])
306
+
307
+ stage2strength = st.number_input(
308
+ "**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
309
+ )
310
+
311
+ sampler2, *_ = init_sampling(
312
+ key=2,
313
+ img2img_strength=stage2strength,
314
+ specify_num_samples=False,
315
+ )
316
+ st.write("__________________________")
317
+ finish_denoising = st.checkbox("Finish denoising with refiner.", True)
318
+ if not finish_denoising:
319
+ stage2strength = None
320
+
321
+ if mode == "txt2img":
322
+ out = run_txt2img(
323
+ state,
324
+ version,
325
+ version_dict,
326
+ is_legacy=is_legacy,
327
+ return_latents=add_pipeline,
328
+ filter=state.get("filter"),
329
+ stage2strength=stage2strength,
330
+ )
331
+ elif mode == "img2img":
332
+ out = run_img2img(
333
+ state,
334
+ version_dict,
335
+ is_legacy=is_legacy,
336
+ return_latents=add_pipeline,
337
+ filter=state.get("filter"),
338
+ stage2strength=stage2strength,
339
+ )
340
+ elif mode == "skip":
341
+ out = None
342
+ else:
343
+ raise ValueError(f"unknown mode {mode}")
344
+ if isinstance(out, (tuple, list)):
345
+ samples, samples_z = out
346
+ else:
347
+ samples = out
348
+ samples_z = None
349
+
350
+ if add_pipeline and samples_z is not None:
351
+ st.write("**Running Refinement Stage**")
352
+ samples = apply_refiner(
353
+ samples_z,
354
+ state2,
355
+ sampler2,
356
+ samples_z.shape[0],
357
+ prompt=prompt,
358
+ negative_prompt=negative_prompt if is_legacy else "",
359
+ filter=state.get("filter"),
360
+ finish_denoising=finish_denoising,
361
+ )
362
+
363
+ if save_locally and samples is not None:
364
+ perform_save_locally(save_path, samples)
scripts/demo/streamlit_helpers.py ADDED
@@ -0,0 +1,928 @@
1
+ import copy
2
+ import math
3
+ import os
4
+ from glob import glob
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ import torch.nn as nn
12
+ import torchvision.transforms as TT
13
+ from einops import rearrange, repeat
14
+ from imwatermark import WatermarkEncoder
15
+ from omegaconf import ListConfig, OmegaConf
16
+ from PIL import Image
17
+ from safetensors.torch import load_file as load_safetensors
18
+ from torch import autocast
19
+ from torchvision import transforms
20
+ from torchvision.utils import make_grid, save_image
21
+
22
+ from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
23
+ Txt2NoisyDiscretizationWrapper)
24
+ from scripts.util.detection.nsfw_and_watermark_dectection import \
25
+ DeepFloydDataFiltering
26
+ from sgm.inference.helpers import embed_watermark
27
+ from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
28
+ VanillaCFG)
29
+ from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
30
+ DPMPP2SAncestralSampler,
31
+ EulerAncestralSampler,
32
+ EulerEDMSampler,
33
+ HeunEDMSampler,
34
+ LinearMultistepSampler)
35
+ from sgm.util import append_dims, default, instantiate_from_config
36
+
37
+
38
+ @st.cache_resource()
39
+ def init_st(version_dict, load_ckpt=True, load_filter=True):
40
+ state = dict()
41
+ if not "model" in state:
42
+ config = version_dict["config"]
43
+ ckpt = version_dict["ckpt"]
44
+
45
+ config = OmegaConf.load(config)
46
+ model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
47
+
48
+ state["msg"] = msg
49
+ state["model"] = model
50
+ state["ckpt"] = ckpt if load_ckpt else None
51
+ state["config"] = config
52
+ if load_filter:
53
+ state["filter"] = DeepFloydDataFiltering(verbose=False)
54
+ return state
55
+
56
+
57
+ def load_model(model):
58
+ model.cuda()
59
+
60
+
61
+ lowvram_mode = False
62
+
63
+
64
+ def set_lowvram_mode(mode):
65
+ global lowvram_mode
66
+ lowvram_mode = mode
67
+
68
+
69
+ def initial_model_load(model):
70
+ global lowvram_mode
71
+ if lowvram_mode:
72
+ model.model.half()
73
+ else:
74
+ model.cuda()
75
+ return model
76
+
77
+
78
+ def unload_model(model):
79
+ global lowvram_mode
80
+ if lowvram_mode:
81
+ model.cpu()
82
+ torch.cuda.empty_cache()
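+ # In low-VRAM mode the UNet weights are kept in half precision and each submodule is moved to
+ # the GPU only while it is needed (load_model), then moved back to the CPU and the CUDA cache
+ # cleared again (unload_model).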
83
+
84
+
85
+ def load_model_from_config(config, ckpt=None, verbose=True):
86
+ model = instantiate_from_config(config.model)
87
+
88
+ if ckpt is not None:
89
+ print(f"Loading model from {ckpt}")
90
+ if ckpt.endswith("ckpt"):
91
+ pl_sd = torch.load(ckpt, map_location="cpu")
92
+ if "global_step" in pl_sd:
93
+ global_step = pl_sd["global_step"]
94
+ st.info(f"loaded ckpt from global step {global_step}")
95
+ print(f"Global Step: {pl_sd['global_step']}")
96
+ sd = pl_sd["state_dict"]
97
+ elif ckpt.endswith("safetensors"):
98
+ sd = load_safetensors(ckpt)
99
+ else:
100
+ raise NotImplementedError
101
+
102
+ msg = None
103
+
104
+ m, u = model.load_state_dict(sd, strict=False)
105
+
106
+ if len(m) > 0 and verbose:
107
+ print("missing keys:")
108
+ print(m)
109
+ if len(u) > 0 and verbose:
110
+ print("unexpected keys:")
111
+ print(u)
112
+ else:
113
+ msg = None
114
+
115
+ model = initial_model_load(model)
116
+ model.eval()
117
+ return model, msg
118
+
119
+
120
+ def get_unique_embedder_keys_from_conditioner(conditioner):
121
+ return list(set([x.input_key for x in conditioner.embedders]))
122
+
123
+
124
+ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
125
+ # Hardcoded demo settings; might undergo some changes in the future
126
+
127
+ value_dict = {}
128
+ for key in keys:
129
+ if key == "txt":
130
+ if prompt is None:
131
+ prompt = "A professional photograph of an astronaut riding a pig"
132
+ if negative_prompt is None:
133
+ negative_prompt = ""
134
+
135
+ prompt = st.text_input("Prompt", prompt)
136
+ negative_prompt = st.text_input("Negative prompt", negative_prompt)
137
+
138
+ value_dict["prompt"] = prompt
139
+ value_dict["negative_prompt"] = negative_prompt
140
+
141
+ if key == "original_size_as_tuple":
142
+ orig_width = st.number_input(
143
+ "orig_width",
144
+ value=init_dict["orig_width"],
145
+ min_value=16,
146
+ )
147
+ orig_height = st.number_input(
148
+ "orig_height",
149
+ value=init_dict["orig_height"],
150
+ min_value=16,
151
+ )
152
+
153
+ value_dict["orig_width"] = orig_width
154
+ value_dict["orig_height"] = orig_height
155
+
156
+ if key == "crop_coords_top_left":
157
+ crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
158
+ crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
159
+
160
+ value_dict["crop_coords_top"] = crop_coord_top
161
+ value_dict["crop_coords_left"] = crop_coord_left
162
+
163
+ if key == "aesthetic_score":
164
+ value_dict["aesthetic_score"] = 6.0
165
+ value_dict["negative_aesthetic_score"] = 2.5
166
+
167
+ if key == "target_size_as_tuple":
168
+ value_dict["target_width"] = init_dict["target_width"]
169
+ value_dict["target_height"] = init_dict["target_height"]
170
+
171
+ if key in ["fps_id", "fps"]:
172
+ fps = st.number_input("fps", value=6, min_value=1)
173
+
174
+ value_dict["fps"] = fps
175
+ value_dict["fps_id"] = fps - 1
176
+
177
+ if key == "motion_bucket_id":
178
+ mb_id = st.number_input("motion bucket id", 0, 511, value=127)
179
+ value_dict["motion_bucket_id"] = mb_id
180
+
181
+ if key == "pool_image":
182
+ st.text("Image for pool conditioning")
183
+ image = load_img(
184
+ key="pool_image_input",
185
+ size=224,
186
+ center_crop=True,
187
+ )
188
+ if image is None:
189
+ st.info("Need an image here")
190
+ image = torch.zeros(1, 3, 224, 224)
191
+ value_dict["pool_image"] = image
192
+
193
+ return value_dict
194
+
195
+
196
+ def perform_save_locally(save_path, samples):
197
+ os.makedirs(os.path.join(save_path), exist_ok=True)
198
+ base_count = len(os.listdir(os.path.join(save_path)))
199
+ samples = embed_watermark(samples)
200
+ for sample in samples:
201
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
202
+ Image.fromarray(sample.astype(np.uint8)).save(
203
+ os.path.join(save_path, f"{base_count:09}.png")
204
+ )
205
+ base_count += 1
206
+
207
+
208
+ def init_save_locally(_dir, init_value: bool = False):
209
+ save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
210
+ if save_locally:
211
+ save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
212
+ else:
213
+ save_path = None
214
+
215
+ return save_locally, save_path
216
+
217
+
218
+ def get_guider(options, key):
219
+ guider = st.sidebar.selectbox(
220
+ f"Guider #{key}",
221
+ [
222
+ "VanillaCFG",
223
+ "IdentityGuider",
224
+ "LinearPredictionGuider",
225
+ ],
226
+ options.get("guider", 0),
227
+ )
228
+
229
+ additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
230
+
231
+ if guider == "IdentityGuider":
232
+ guider_config = {
233
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
234
+ }
235
+ elif guider == "VanillaCFG":
236
+ scale_schedule = st.sidebar.selectbox(
237
+ f"Scale schedule #{key}",
238
+ ["Identity", "Oscillating"],
239
+ )
240
+
241
+ if scale_schedule == "Identity":
242
+ scale = st.number_input(
243
+ f"cfg-scale #{key}",
244
+ value=options.get("cfg", 5.0),
245
+ min_value=0.0,
246
+ )
247
+
248
+ scale_schedule_config = {
249
+ "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
250
+ "params": {"scale": scale},
251
+ }
252
+
253
+ elif scale_schedule == "Oscillating":
254
+ small_scale = st.number_input(
255
+ f"small cfg-scale #{key}",
256
+ value=4.0,
257
+ min_value=0.0,
258
+ )
259
+
260
+ large_scale = st.number_input(
261
+ f"large cfg-scale #{key}",
262
+ value=16.0,
263
+ min_value=0.0,
264
+ )
265
+
266
+ sigma_cutoff = st.number_input(
267
+ f"sigma cutoff #{key}",
268
+ value=1.0,
269
+ min_value=0.0,
270
+ )
271
+
272
+ scale_schedule_config = {
273
+ "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
274
+ "params": {
275
+ "small_scale": small_scale,
276
+ "large_scale": large_scale,
277
+ "sigma_cutoff": sigma_cutoff,
278
+ },
279
+ }
280
+ else:
281
+ raise NotImplementedError
282
+
283
+ guider_config = {
284
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
285
+ "params": {
286
+ "scale_schedule_config": scale_schedule_config,
287
+ **additional_guider_kwargs,
288
+ },
289
+ }
290
+ elif guider == "LinearPredictionGuider":
291
+ max_scale = st.number_input(
292
+ f"max-cfg-scale #{key}",
293
+ value=options.get("cfg", 1.5),
294
+ min_value=1.0,
295
+ )
296
+ min_scale = st.number_input(
297
+ f"min guidance scale #{key}",
298
+ value=options.get("min_cfg", 1.0),
299
+ min_value=1.0,
300
+ max_value=10.0,
301
+ )
302
+
303
+ guider_config = {
304
+ "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
305
+ "params": {
306
+ "max_scale": max_scale,
307
+ "min_scale": min_scale,
308
+ "num_frames": options["num_frames"],
309
+ **additional_guider_kwargs,
310
+ },
311
+ }
312
+ else:
313
+ raise NotImplementedError
314
+ return guider_config
315
+
316
+
317
+ def init_sampling(
318
+ key=1,
319
+ img2img_strength: Optional[float] = None,
320
+ specify_num_samples: bool = True,
321
+ stage2strength: Optional[float] = None,
322
+ options: Optional[Dict[str, int]] = None,
323
+ ):
324
+ options = {} if options is None else options
325
+
326
+ num_rows, num_cols = 1, 1
327
+ if specify_num_samples:
328
+ num_cols = st.number_input(
329
+ f"num cols #{key}", value=num_cols, min_value=1, max_value=10
330
+ )
331
+
332
+ steps = st.sidebar.number_input(
333
+ f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
334
+ )
335
+ sampler = st.sidebar.selectbox(
336
+ f"Sampler #{key}",
337
+ [
338
+ "EulerEDMSampler",
339
+ "HeunEDMSampler",
340
+ "EulerAncestralSampler",
341
+ "DPMPP2SAncestralSampler",
342
+ "DPMPP2MSampler",
343
+ "LinearMultistepSampler",
344
+ ],
345
+ options.get("sampler", 0),
346
+ )
347
+ discretization = st.sidebar.selectbox(
348
+ f"Discretization #{key}",
349
+ [
350
+ "LegacyDDPMDiscretization",
351
+ "EDMDiscretization",
352
+ ],
353
+ options.get("discretization", 0),
354
+ )
355
+
356
+ discretization_config = get_discretization(discretization, options=options, key=key)
357
+
358
+ guider_config = get_guider(options=options, key=key)
359
+
360
+ sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
361
+ if img2img_strength is not None:
362
+ st.warning(
363
+ f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
364
+ )
365
+ sampler.discretization = Img2ImgDiscretizationWrapper(
366
+ sampler.discretization, strength=img2img_strength
367
+ )
368
+ if stage2strength is not None:
369
+ sampler.discretization = Txt2NoisyDiscretizationWrapper(
370
+ sampler.discretization, strength=stage2strength, original_steps=steps
371
+ )
372
+ return sampler, num_rows, num_cols
373
+
374
+
375
+ def get_discretization(discretization, options, key=1):
376
+ if discretization == "LegacyDDPMDiscretization":
377
+ discretization_config = {
378
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
379
+ }
380
+ elif discretization == "EDMDiscretization":
381
+ sigma_min = st.number_input(
382
+ f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
383
+ ) # 0.0292
384
+ sigma_max = st.number_input(
385
+ f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
386
+ ) # 14.6146
387
+ rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
388
+ discretization_config = {
389
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
390
+ "params": {
391
+ "sigma_min": sigma_min,
392
+ "sigma_max": sigma_max,
393
+ "rho": rho,
394
+ },
395
+ }
396
+
397
+ return discretization_config
398
+
399
+
400
+ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
401
+ if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
402
+ s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
403
+ s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
404
+ s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
405
+ s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
406
+
407
+ if sampler_name == "EulerEDMSampler":
408
+ sampler = EulerEDMSampler(
409
+ num_steps=steps,
410
+ discretization_config=discretization_config,
411
+ guider_config=guider_config,
412
+ s_churn=s_churn,
413
+ s_tmin=s_tmin,
414
+ s_tmax=s_tmax,
415
+ s_noise=s_noise,
416
+ verbose=True,
417
+ )
418
+ elif sampler_name == "HeunEDMSampler":
419
+ sampler = HeunEDMSampler(
420
+ num_steps=steps,
421
+ discretization_config=discretization_config,
422
+ guider_config=guider_config,
423
+ s_churn=s_churn,
424
+ s_tmin=s_tmin,
425
+ s_tmax=s_tmax,
426
+ s_noise=s_noise,
427
+ verbose=True,
428
+ )
429
+ elif (
430
+ sampler_name == "EulerAncestralSampler"
431
+ or sampler_name == "DPMPP2SAncestralSampler"
432
+ ):
433
+ s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
434
+ eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
435
+
436
+ if sampler_name == "EulerAncestralSampler":
437
+ sampler = EulerAncestralSampler(
438
+ num_steps=steps,
439
+ discretization_config=discretization_config,
440
+ guider_config=guider_config,
441
+ eta=eta,
442
+ s_noise=s_noise,
443
+ verbose=True,
444
+ )
445
+ elif sampler_name == "DPMPP2SAncestralSampler":
446
+ sampler = DPMPP2SAncestralSampler(
447
+ num_steps=steps,
448
+ discretization_config=discretization_config,
449
+ guider_config=guider_config,
450
+ eta=eta,
451
+ s_noise=s_noise,
452
+ verbose=True,
453
+ )
454
+ elif sampler_name == "DPMPP2MSampler":
455
+ sampler = DPMPP2MSampler(
456
+ num_steps=steps,
457
+ discretization_config=discretization_config,
458
+ guider_config=guider_config,
459
+ verbose=True,
460
+ )
461
+ elif sampler_name == "LinearMultistepSampler":
462
+ order = st.sidebar.number_input("order", value=4, min_value=1)
463
+ sampler = LinearMultistepSampler(
464
+ num_steps=steps,
465
+ discretization_config=discretization_config,
466
+ guider_config=guider_config,
467
+ order=order,
468
+ verbose=True,
469
+ )
470
+ else:
471
+ raise ValueError(f"unknown sampler {sampler_name}!")
472
+
473
+ return sampler
474
+
475
+
476
+ def get_interactive_image() -> Image.Image:
477
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
478
+ if image is not None:
479
+ image = Image.open(image)
480
+ if not image.mode == "RGB":
481
+ image = image.convert("RGB")
482
+ return image
483
+
484
+
485
+ def load_img(
486
+ display: bool = True,
487
+ size: Union[None, int, Tuple[int, int]] = None,
488
+ center_crop: bool = False,
+ key: Optional[str] = None,
489
+ ):
490
+ image = get_interactive_image(key=key)
491
+ if image is None:
492
+ return None
493
+ if display:
494
+ st.image(image)
495
+ w, h = image.size
496
+ print(f"loaded input image of size ({w}, {h})")
497
+
498
+ transform = []
499
+ if size is not None:
500
+ transform.append(transforms.Resize(size))
501
+ if center_crop:
502
+ transform.append(transforms.CenterCrop(size))
503
+ transform.append(transforms.ToTensor())
504
+ transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
505
+
506
+ transform = transforms.Compose(transform)
507
+ img = transform(image)[None, ...]
508
+ st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
509
+ return img
510
+
511
+
512
+ def get_init_img(batch_size=1, key=None):
513
+ init_image = load_img(key=key).cuda()
514
+ init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
515
+ return init_image
516
+
517
+
518
+ def do_sample(
519
+ model,
520
+ sampler,
521
+ value_dict,
522
+ num_samples,
523
+ H,
524
+ W,
525
+ C,
526
+ F,
527
+ force_uc_zero_embeddings: Optional[List] = None,
528
+ force_cond_zero_embeddings: Optional[List] = None,
529
+ batch2model_input: List = None,
530
+ return_latents=False,
531
+ filter=None,
532
+ T=None,
533
+ additional_batch_uc_fields=None,
534
+ decoding_t=None,
535
+ ):
536
+ force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
537
+ batch2model_input = default(batch2model_input, [])
538
+ additional_batch_uc_fields = default(additional_batch_uc_fields, [])
539
+
540
+ st.text("Sampling")
541
+
542
+ outputs = st.empty()
543
+ precision_scope = autocast
544
+ with torch.no_grad():
545
+ with precision_scope("cuda"):
546
+ with model.ema_scope():
547
+ if T is not None:
548
+ num_samples = [num_samples, T]
549
+ else:
550
+ num_samples = [num_samples]
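+ # num_samples now holds the leading batch shape: [batch] for images, [batch, T] for videos,
+ # so math.prod(num_samples) below gives the total number of latent frames to sample.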
551
+
552
+ load_model(model.conditioner)
553
+ batch, batch_uc = get_batch(
554
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
555
+ value_dict,
556
+ num_samples,
557
+ T=T,
558
+ additional_batch_uc_fields=additional_batch_uc_fields,
559
+ )
560
+
561
+ c, uc = model.conditioner.get_unconditional_conditioning(
562
+ batch,
563
+ batch_uc=batch_uc,
564
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
565
+ force_cond_zero_embeddings=force_cond_zero_embeddings,
566
+ )
567
+ unload_model(model.conditioner)
568
+
569
+ for k in c:
570
+ if not k == "crossattn":
571
+ c[k], uc[k] = map(
572
+ lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
573
+ )
574
+ if k in ["crossattn", "concat"] and T is not None:
575
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
576
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
577
+ c[k] = repeat(c[k], "b ... -> b t ...", t=T)
578
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
579
+
580
+ additional_model_inputs = {}
581
+ for k in batch2model_input:
582
+ if k == "image_only_indicator":
583
+ assert T is not None
584
+
585
+ if isinstance(
586
+ sampler.guider, (VanillaCFG, LinearPredictionGuider)
587
+ ):
588
+ additional_model_inputs[k] = torch.zeros(
589
+ num_samples[0] * 2, num_samples[1]
590
+ ).to("cuda")
591
+ else:
592
+ additional_model_inputs[k] = torch.zeros(num_samples).to(
593
+ "cuda"
594
+ )
595
+ else:
596
+ additional_model_inputs[k] = batch[k]
597
+
598
+ shape = (math.prod(num_samples), C, H // F, W // F)
599
+ randn = torch.randn(shape).to("cuda")
600
+
601
+ def denoiser(input, sigma, c):
602
+ return model.denoiser(
603
+ model.model, input, sigma, c, **additional_model_inputs
604
+ )
605
+
606
+ load_model(model.denoiser)
607
+ load_model(model.model)
608
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
609
+ unload_model(model.model)
610
+ unload_model(model.denoiser)
611
+
612
+ load_model(model.first_stage_model)
613
+ model.en_and_decode_n_samples_a_time = (
614
+ decoding_t # Decode n frames at a time
615
+ )
616
+ samples_x = model.decode_first_stage(samples_z)
617
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
618
+ unload_model(model.first_stage_model)
619
+
620
+ if filter is not None:
621
+ samples = filter(samples)
622
+
623
+ if T is None:
624
+ grid = torch.stack([samples])
625
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
626
+ outputs.image(grid.cpu().numpy())
627
+ else:
628
+ as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
629
+ for i, vid in enumerate(as_vids):
630
+ grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
631
+ st.image(
632
+ grid.cpu().numpy(),
633
+ f"Sample #{i} as image",
634
+ )
635
+
636
+ if return_latents:
637
+ return samples, samples_z
638
+ return samples
639
+
640
+
641
+ def get_batch(
642
+ keys,
643
+ value_dict: dict,
644
+ N: Union[List, ListConfig],
645
+ device: str = "cuda",
646
+ T: int = None,
647
+ additional_batch_uc_fields: List[str] = [],
648
+ ):
649
+ # Hardcoded demo setups; might undergo some changes in the future
650
+
651
+ batch = {}
652
+ batch_uc = {}
653
+
654
+ for key in keys:
655
+ if key == "txt":
656
+ batch["txt"] = [value_dict["prompt"]] * math.prod(N)
657
+
658
+ batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
659
+
660
+ elif key == "original_size_as_tuple":
661
+ batch["original_size_as_tuple"] = (
662
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
663
+ .to(device)
664
+ .repeat(math.prod(N), 1)
665
+ )
666
+ elif key == "crop_coords_top_left":
667
+ batch["crop_coords_top_left"] = (
668
+ torch.tensor(
669
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
670
+ )
671
+ .to(device)
672
+ .repeat(math.prod(N), 1)
673
+ )
674
+ elif key == "aesthetic_score":
675
+ batch["aesthetic_score"] = (
676
+ torch.tensor([value_dict["aesthetic_score"]])
677
+ .to(device)
678
+ .repeat(math.prod(N), 1)
679
+ )
680
+ batch_uc["aesthetic_score"] = (
681
+ torch.tensor([value_dict["negative_aesthetic_score"]])
682
+ .to(device)
683
+ .repeat(math.prod(N), 1)
684
+ )
685
+
686
+ elif key == "target_size_as_tuple":
687
+ batch["target_size_as_tuple"] = (
688
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
689
+ .to(device)
690
+ .repeat(math.prod(N), 1)
691
+ )
692
+ elif key == "fps":
693
+ batch[key] = (
694
+ torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
695
+ )
696
+ elif key == "fps_id":
697
+ batch[key] = (
698
+ torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
699
+ )
700
+ elif key == "motion_bucket_id":
701
+ batch[key] = (
702
+ torch.tensor([value_dict["motion_bucket_id"]])
703
+ .to(device)
704
+ .repeat(math.prod(N))
705
+ )
706
+ elif key == "pool_image":
707
+ batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
708
+ device, dtype=torch.half
709
+ )
710
+ elif key == "cond_aug":
711
+ batch[key] = repeat(
712
+ torch.tensor([value_dict["cond_aug"]]).to("cuda"),
713
+ "1 -> b",
714
+ b=math.prod(N),
715
+ )
716
+ elif key == "cond_frames":
717
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
718
+ elif key == "cond_frames_without_noise":
719
+ batch[key] = repeat(
720
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
721
+ )
722
+ else:
723
+ batch[key] = value_dict[key]
724
+
725
+ if T is not None:
726
+ batch["num_video_frames"] = T
727
+
728
+ for key in batch.keys():
729
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
730
+ batch_uc[key] = torch.clone(batch[key])
731
+ elif key in additional_batch_uc_fields and key not in batch_uc:
732
+ batch_uc[key] = copy.copy(batch[key])
733
+ return batch, batch_uc
734
+
735
+
736
+ @torch.no_grad()
737
+ def do_img2img(
738
+ img,
739
+ model,
740
+ sampler,
741
+ value_dict,
742
+ num_samples,
743
+ force_uc_zero_embeddings: Optional[List] = None,
744
+ force_cond_zero_embeddings: Optional[List] = None,
745
+ additional_kwargs={},
746
+ offset_noise_level: float = 0.0,
747
+ return_latents=False,
748
+ skip_encode=False,
749
+ filter=None,
750
+ add_noise=True,
751
+ ):
752
+ st.text("Sampling")
753
+
754
+ outputs = st.empty()
755
+ precision_scope = autocast
756
+ with torch.no_grad():
757
+ with precision_scope("cuda"):
758
+ with model.ema_scope():
759
+ load_model(model.conditioner)
760
+ batch, batch_uc = get_batch(
761
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
762
+ value_dict,
763
+ [num_samples],
764
+ )
765
+ c, uc = model.conditioner.get_unconditional_conditioning(
766
+ batch,
767
+ batch_uc=batch_uc,
768
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
769
+ force_cond_zero_embeddings=force_cond_zero_embeddings,
770
+ )
771
+ unload_model(model.conditioner)
772
+ for k in c:
773
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
774
+
775
+ for k in additional_kwargs:
776
+ c[k] = uc[k] = additional_kwargs[k]
777
+ if skip_encode:
778
+ z = img
779
+ else:
780
+ load_model(model.first_stage_model)
781
+ z = model.encode_first_stage(img)
782
+ unload_model(model.first_stage_model)
783
+
784
+ noise = torch.randn_like(z)
785
+
786
+ sigmas = sampler.discretization(sampler.num_steps).cuda()
787
+ sigma = sigmas[0]
788
+
789
+ st.info(f"all sigmas: {sigmas}")
790
+ st.info(f"noising sigma: {sigma}")
791
+ if offset_noise_level > 0.0:
792
+ noise = noise + offset_noise_level * append_dims(
793
+ torch.randn(z.shape[0], device=z.device), z.ndim
794
+ )
795
+ if add_noise:
796
+ noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
797
+ noised_z = noised_z / torch.sqrt(
798
+ 1.0 + sigmas[0] ** 2.0
799
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
800
+ else:
801
+ noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
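+ # The 1/sqrt(1 + sigma_0^2) factor rescales the (noised) latent back to roughly unit variance,
+ # matching the variance-preserving (DDPM-style) convention mentioned in the note above.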
802
+
803
+ def denoiser(x, sigma, c):
804
+ return model.denoiser(model.model, x, sigma, c)
805
+
806
+ load_model(model.denoiser)
807
+ load_model(model.model)
808
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
809
+ unload_model(model.model)
810
+ unload_model(model.denoiser)
811
+
812
+ load_model(model.first_stage_model)
813
+ samples_x = model.decode_first_stage(samples_z)
814
+ unload_model(model.first_stage_model)
815
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
816
+
817
+ if filter is not None:
818
+ samples = filter(samples)
819
+
820
+ grid = torch.stack([samples])
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
821
+ outputs.image(grid.cpu().numpy())
822
+ if return_latents:
823
+ return samples, samples_z
824
+ return samples
825
+
826
+
827
+ def get_resizing_factor(
828
+ desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
829
+ ) -> float:
830
+ r_bound = desired_shape[1] / desired_shape[0]
831
+ aspect_r = current_shape[1] / current_shape[0]
832
+ if r_bound >= 1.0:
833
+ if aspect_r >= r_bound:
834
+ factor = min(desired_shape) / min(current_shape)
835
+ else:
836
+ if aspect_r < 1.0:
837
+ factor = max(desired_shape) / min(current_shape)
838
+ else:
839
+ factor = max(desired_shape) / max(current_shape)
840
+ else:
841
+ if aspect_r <= r_bound:
842
+ factor = min(desired_shape) / min(current_shape)
843
+ else:
844
+ if aspect_r > 1:
845
+ factor = max(desired_shape) / min(current_shape)
846
+ else:
847
+ factor = max(desired_shape) / max(current_shape)
848
+
849
+ return factor
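+ # Example: get_resizing_factor((576, 1024), (720, 1280)) returns 576/720 = 0.8, so a 720x1280
+ # input is resized to 576x1024 before the center crop in load_img_for_prediction below.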
850
+
851
+
852
+ def get_interactive_image(key=None) -> Image.Image:
853
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
854
+ if image is not None:
855
+ image = Image.open(image)
856
+ if not image.mode == "RGB":
857
+ image = image.convert("RGB")
858
+ return image
859
+
860
+
861
+ def load_img_for_prediction(
862
+ W: int, H: int, display=True, key=None, device="cuda"
863
+ ) -> torch.Tensor:
864
+ image = get_interactive_image(key=key)
865
+ if image is None:
866
+ return None
867
+ if display:
868
+ st.image(image)
869
+ w, h = image.size
870
+
871
+ image = np.array(image).transpose(2, 0, 1)
872
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
873
+ image = image.unsqueeze(0)
874
+
875
+ rfs = get_resizing_factor((H, W), (h, w))
876
+ resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
877
+ top = (resize_size[0] - H) // 2
878
+ left = (resize_size[1] - W) // 2
879
+
880
+ image = torch.nn.functional.interpolate(
881
+ image, resize_size, mode="area", antialias=False
882
+ )
883
+ image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
884
+
885
+ if display:
886
+ numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
887
+ pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
888
+ st.image(pil_image)
889
+ return image.to(device) * 2.0 - 1.0
890
+
891
+
892
+ def save_video_as_grid_and_mp4(
893
+ video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
894
+ ):
895
+ os.makedirs(save_path, exist_ok=True)
896
+ base_count = len(glob(os.path.join(save_path, "*.mp4")))
897
+
898
+ video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
899
+ video_batch = embed_watermark(video_batch)
900
+ for vid in video_batch:
901
+ save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
902
+
903
+ video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
904
+
905
+ writer = cv2.VideoWriter(
906
+ video_path,
907
+ cv2.VideoWriter_fourcc(*"MP4V"),
908
+ fps,
909
+ (vid.shape[-1], vid.shape[-2]),
910
+ )
911
+
912
+ vid = (
913
+ (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
914
+ )
915
+ for frame in vid:
916
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
917
+ writer.write(frame)
918
+
919
+ writer.release()
920
+
921
+ video_path_h264 = video_path[:-4] + "_h264.mp4"
922
+ os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
923
+
924
+ with open(video_path_h264, "rb") as f:
925
+ video_bytes = f.read()
926
+ st.video(video_bytes)
927
+
928
+ base_count += 1
scripts/demo/video_sampling.py ADDED
@@ -0,0 +1,200 @@
1
+ import os
2
+
3
+ from pytorch_lightning import seed_everything
4
+
5
+ from scripts.demo.streamlit_helpers import *
6
+
7
+ SAVE_PATH = "outputs/demo/vid/"
8
+
9
+ VERSION2SPECS = {
10
+ "svd": {
11
+ "T": 14,
12
+ "H": 576,
13
+ "W": 1024,
14
+ "C": 4,
15
+ "f": 8,
16
+ "config": "configs/inference/svd.yaml",
17
+ "ckpt": "checkpoints/svd.safetensors",
18
+ "options": {
19
+ "discretization": 1,
20
+ "cfg": 2.5,
21
+ "sigma_min": 0.002,
22
+ "sigma_max": 700.0,
23
+ "rho": 7.0,
24
+ "guider": 2,
25
+ "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
26
+ "num_steps": 25,
27
+ },
28
+ },
29
+ "svd_image_decoder": {
30
+ "T": 14,
31
+ "H": 576,
32
+ "W": 1024,
33
+ "C": 4,
34
+ "f": 8,
35
+ "config": "configs/inference/svd_image_decoder.yaml",
36
+ "ckpt": "checkpoints/svd_image_decoder.safetensors",
37
+ "options": {
38
+ "discretization": 1,
39
+ "cfg": 2.5,
40
+ "sigma_min": 0.002,
41
+ "sigma_max": 700.0,
42
+ "rho": 7.0,
43
+ "guider": 2,
44
+ "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
45
+ "num_steps": 25,
46
+ },
47
+ },
48
+ "svd_xt": {
49
+ "T": 25,
50
+ "H": 576,
51
+ "W": 1024,
52
+ "C": 4,
53
+ "f": 8,
54
+ "config": "configs/inference/svd.yaml",
55
+ "ckpt": "checkpoints/svd_xt.safetensors",
56
+ "options": {
57
+ "discretization": 1,
58
+ "cfg": 3.0,
59
+ "min_cfg": 1.5,
60
+ "sigma_min": 0.002,
61
+ "sigma_max": 700.0,
62
+ "rho": 7.0,
63
+ "guider": 2,
64
+ "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
65
+ "num_steps": 30,
66
+ "decoding_t": 14,
67
+ },
68
+ },
69
+ "svd_xt_image_decoder": {
70
+ "T": 25,
71
+ "H": 576,
72
+ "W": 1024,
73
+ "C": 4,
74
+ "f": 8,
75
+ "config": "configs/inference/svd_image_decoder.yaml",
76
+ "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
77
+ "options": {
78
+ "discretization": 1,
79
+ "cfg": 3.0,
80
+ "min_cfg": 1.5,
81
+ "sigma_min": 0.002,
82
+ "sigma_max": 700.0,
83
+ "rho": 7.0,
84
+ "guider": 2,
85
+ "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
86
+ "num_steps": 30,
87
+ "decoding_t": 14,
88
+ },
89
+ },
90
+ }
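+ # T: frames per clip, H/W: output resolution, C: latent channels, f: spatial downsampling factor
+ # of the autoencoder. The "options" dict provides demo defaults and is forwarded to init_sampling below.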
91
+
92
+
93
+ if __name__ == "__main__":
94
+ st.title("Stable Video Diffusion")
95
+ version = st.selectbox(
96
+ "Model Version",
97
+ [k for k in VERSION2SPECS.keys()],
98
+ 0,
99
+ )
100
+ version_dict = VERSION2SPECS[version]
101
+ if st.checkbox("Load Model"):
102
+ mode = "img2vid"
103
+ else:
104
+ mode = "skip"
105
+
106
+ H = st.sidebar.number_input(
107
+ "H", value=version_dict["H"], min_value=64, max_value=2048
108
+ )
109
+ W = st.sidebar.number_input(
110
+ "W", value=version_dict["W"], min_value=64, max_value=2048
111
+ )
112
+ T = st.sidebar.number_input(
113
+ "T", value=version_dict["T"], min_value=0, max_value=128
114
+ )
115
+ C = version_dict["C"]
116
+ F = version_dict["f"]
117
+ options = version_dict["options"]
118
+
119
+ if mode != "skip":
120
+ state = init_st(version_dict, load_filter=True)
121
+ if state["msg"]:
122
+ st.info(state["msg"])
123
+ model = state["model"]
124
+
125
+ ukeys = set(
126
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
127
+ )
128
+
129
+ value_dict = init_embedder_options(
130
+ ukeys,
131
+ {},
132
+ )
133
+
134
+ value_dict["image_only_indicator"] = 0
135
+
136
+ if mode == "img2vid":
137
+ img = load_img_for_prediction(W, H)
138
+ cond_aug = st.number_input(
139
+ "Conditioning augmentation:", value=0.02, min_value=0.0
140
+ )
141
+ value_dict["cond_frames_without_noise"] = img
142
+ value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
143
+ value_dict["cond_aug"] = cond_aug
144
+
145
+ seed = st.sidebar.number_input(
146
+ "seed", value=23, min_value=0, max_value=int(1e9)
147
+ )
148
+ seed_everything(seed)
149
+
150
+ save_locally, save_path = init_save_locally(
151
+ os.path.join(SAVE_PATH, version), init_value=True
152
+ )
153
+
154
+ options["num_frames"] = T
155
+
156
+ sampler, num_rows, num_cols = init_sampling(options=options)
157
+ num_samples = num_rows * num_cols
158
+
159
+ decoding_t = st.number_input(
160
+ "Decode t frames at a time (set small if you are low on VRAM)",
161
+ value=options.get("decoding_t", T),
162
+ min_value=1,
163
+ max_value=int(1e9),
164
+ )
165
+
166
+ if st.checkbox("Overwrite fps in mp4 generator", False):
167
+ saving_fps = st.number_input(
168
+ f"saving video at fps:", value=value_dict["fps"], min_value=1
169
+ )
170
+ else:
171
+ saving_fps = value_dict["fps"]
172
+
173
+ if st.button("Sample"):
174
+ out = do_sample(
175
+ model,
176
+ sampler,
177
+ value_dict,
178
+ num_samples,
179
+ H,
180
+ W,
181
+ C,
182
+ F,
183
+ T=T,
184
+ batch2model_input=["num_video_frames", "image_only_indicator"],
185
+ force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
186
+ force_cond_zero_embeddings=options.get(
187
+ "force_cond_zero_embeddings", None
188
+ ),
189
+ return_latents=False,
190
+ decoding_t=decoding_t,
191
+ )
192
+
193
+ if isinstance(out, (tuple, list)):
194
+ samples, samples_z = out
195
+ else:
196
+ samples = out
197
+ samples_z = None
198
+
199
+ if save_locally:
200
+ save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
scripts/sampling/configs/svd.yaml ADDED
@@ -0,0 +1,146 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+ ckpt_path: checkpoints/svd.safetensors
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
+ params:
11
+ scaling_config:
12
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
+
14
+ network_config:
15
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
+ params:
17
+ adm_in_channels: 768
18
+ num_classes: sequential
19
+ use_checkpoint: True
20
+ in_channels: 8
21
+ out_channels: 4
22
+ model_channels: 320
23
+ attention_resolutions: [4, 2, 1]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4, 4]
26
+ num_head_channels: 64
27
+ use_linear_in_transformer: True
28
+ transformer_depth: 1
29
+ context_dim: 1024
30
+ spatial_transformer_attn_type: softmax-xformers
31
+ extra_ff_mix_layer: True
32
+ use_spatial_context: True
33
+ merge_strategy: learned_with_images
34
+ video_kernel_size: [3, 1, 1]
35
+
36
+ conditioner_config:
37
+ target: sgm.modules.GeneralConditioner
38
+ params:
39
+ emb_models:
40
+ - is_trainable: False
41
+ input_key: cond_frames_without_noise
42
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
+ params:
44
+ n_cond_frames: 1
45
+ n_copies: 1
46
+ open_clip_embedding_config:
47
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
+ params:
49
+ freeze: True
50
+
51
+ - input_key: fps_id
52
+ is_trainable: False
53
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
+ params:
55
+ outdim: 256
56
+
57
+ - input_key: motion_bucket_id
58
+ is_trainable: False
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - input_key: cond_frames
64
+ is_trainable: False
65
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
+ params:
67
+ disable_encoder_autocast: True
68
+ n_cond_frames: 1
69
+ n_copies: 1
70
+ is_ae: True
71
+ encoder_config:
72
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
+ params:
74
+ embed_dim: 4
75
+ monitor: val/rec_loss
76
+ ddconfig:
77
+ attn_type: vanilla-xformers
78
+ double_z: True
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult: [1, 2, 4, 4]
85
+ num_res_blocks: 2
86
+ attn_resolutions: []
87
+ dropout: 0.0
88
+ lossconfig:
89
+ target: torch.nn.Identity
90
+
91
+ - input_key: cond_aug
92
+ is_trainable: False
93
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
+ params:
95
+ outdim: 256
96
+
97
+ first_stage_config:
98
+ target: sgm.models.autoencoder.AutoencodingEngine
99
+ params:
100
+ loss_config:
101
+ target: torch.nn.Identity
102
+ regularizer_config:
103
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
104
+ encoder_config:
105
+ target: sgm.modules.diffusionmodules.model.Encoder
106
+ params:
107
+ attn_type: vanilla
108
+ double_z: True
109
+ z_channels: 4
110
+ resolution: 256
111
+ in_channels: 3
112
+ out_ch: 3
113
+ ch: 128
114
+ ch_mult: [1, 2, 4, 4]
115
+ num_res_blocks: 2
116
+ attn_resolutions: []
117
+ dropout: 0.0
118
+ decoder_config:
119
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
120
+ params:
121
+ attn_type: vanilla
122
+ double_z: True
123
+ z_channels: 4
124
+ resolution: 256
125
+ in_channels: 3
126
+ out_ch: 3
127
+ ch: 128
128
+ ch_mult: [1, 2, 4, 4]
129
+ num_res_blocks: 2
130
+ attn_resolutions: []
131
+ dropout: 0.0
132
+ video_kernel_size: [3, 1, 1]
133
+
134
+ sampler_config:
135
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
136
+ params:
137
+ discretization_config:
138
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
139
+ params:
140
+ sigma_max: 700.0
141
+
142
+ guider_config:
143
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
144
+ params:
145
+ max_scale: 2.5
146
+ min_scale: 1.0
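+ # num_frames for the guider is injected at load time (see load_model in
+ # scripts/sampling/simple_video_sample.py); the guidance scale then ramps from min_scale on the
+ # first frame to max_scale on the last frame.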
scripts/sampling/configs/svd_image_decoder.yaml ADDED
@@ -0,0 +1,129 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+ ckpt_path: checkpoints/svd_image_decoder.safetensors
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
+ params:
11
+ scaling_config:
12
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
+
14
+ network_config:
15
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
+ params:
17
+ adm_in_channels: 768
18
+ num_classes: sequential
19
+ use_checkpoint: True
20
+ in_channels: 8
21
+ out_channels: 4
22
+ model_channels: 320
23
+ attention_resolutions: [4, 2, 1]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4, 4]
26
+ num_head_channels: 64
27
+ use_linear_in_transformer: True
28
+ transformer_depth: 1
29
+ context_dim: 1024
30
+ spatial_transformer_attn_type: softmax-xformers
31
+ extra_ff_mix_layer: True
32
+ use_spatial_context: True
33
+ merge_strategy: learned_with_images
34
+ video_kernel_size: [3, 1, 1]
35
+
36
+ conditioner_config:
37
+ target: sgm.modules.GeneralConditioner
38
+ params:
39
+ emb_models:
40
+ - is_trainable: False
41
+ input_key: cond_frames_without_noise
42
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
+ params:
44
+ n_cond_frames: 1
45
+ n_copies: 1
46
+ open_clip_embedding_config:
47
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
+ params:
49
+ freeze: True
50
+
51
+ - input_key: fps_id
52
+ is_trainable: False
53
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
+ params:
55
+ outdim: 256
56
+
57
+ - input_key: motion_bucket_id
58
+ is_trainable: False
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - input_key: cond_frames
64
+ is_trainable: False
65
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
+ params:
67
+ disable_encoder_autocast: True
68
+ n_cond_frames: 1
69
+ n_copies: 1
70
+ is_ae: True
71
+ encoder_config:
72
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
+ params:
74
+ embed_dim: 4
75
+ monitor: val/rec_loss
76
+ ddconfig:
77
+ attn_type: vanilla-xformers
78
+ double_z: True
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult: [1, 2, 4, 4]
85
+ num_res_blocks: 2
86
+ attn_resolutions: []
87
+ dropout: 0.0
88
+ lossconfig:
89
+ target: torch.nn.Identity
90
+
91
+ - input_key: cond_aug
92
+ is_trainable: False
93
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
+ params:
95
+ outdim: 256
96
+
97
+ first_stage_config:
98
+ target: sgm.models.autoencoder.AutoencoderKL
99
+ params:
100
+ embed_dim: 4
101
+ monitor: val/rec_loss
102
+ ddconfig:
103
+ attn_type: vanilla-xformers
104
+ double_z: True
105
+ z_channels: 4
106
+ resolution: 256
107
+ in_channels: 3
108
+ out_ch: 3
109
+ ch: 128
110
+ ch_mult: [1, 2, 4, 4]
111
+ num_res_blocks: 2
112
+ attn_resolutions: []
113
+ dropout: 0.0
114
+ lossconfig:
115
+ target: torch.nn.Identity
116
+
117
+ sampler_config:
118
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119
+ params:
120
+ discretization_config:
121
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122
+ params:
123
+ sigma_max: 700.0
124
+
125
+ guider_config:
126
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127
+ params:
128
+ max_scale: 2.5
129
+ min_scale: 1.0
scripts/sampling/configs/svd_xt.yaml ADDED
@@ -0,0 +1,146 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+ ckpt_path: checkpoints/svd_xt.safetensors
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
+ params:
11
+ scaling_config:
12
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
+
14
+ network_config:
15
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
+ params:
17
+ adm_in_channels: 768
18
+ num_classes: sequential
19
+ use_checkpoint: True
20
+ in_channels: 8
21
+ out_channels: 4
22
+ model_channels: 320
23
+ attention_resolutions: [4, 2, 1]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4, 4]
26
+ num_head_channels: 64
27
+ use_linear_in_transformer: True
28
+ transformer_depth: 1
29
+ context_dim: 1024
30
+ spatial_transformer_attn_type: softmax-xformers
31
+ extra_ff_mix_layer: True
32
+ use_spatial_context: True
33
+ merge_strategy: learned_with_images
34
+ video_kernel_size: [3, 1, 1]
35
+
36
+ conditioner_config:
37
+ target: sgm.modules.GeneralConditioner
38
+ params:
39
+ emb_models:
40
+ - is_trainable: False
41
+ input_key: cond_frames_without_noise
42
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
+ params:
44
+ n_cond_frames: 1
45
+ n_copies: 1
46
+ open_clip_embedding_config:
47
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
+ params:
49
+ freeze: True
50
+
51
+ - input_key: fps_id
52
+ is_trainable: False
53
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
+ params:
55
+ outdim: 256
56
+
57
+ - input_key: motion_bucket_id
58
+ is_trainable: False
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - input_key: cond_frames
64
+ is_trainable: False
65
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
+ params:
67
+ disable_encoder_autocast: True
68
+ n_cond_frames: 1
69
+ n_copies: 1
70
+ is_ae: True
71
+ encoder_config:
72
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
+ params:
74
+ embed_dim: 4
75
+ monitor: val/rec_loss
76
+ ddconfig:
77
+ attn_type: vanilla-xformers
78
+ double_z: True
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult: [1, 2, 4, 4]
85
+ num_res_blocks: 2
86
+ attn_resolutions: []
87
+ dropout: 0.0
88
+ lossconfig:
89
+ target: torch.nn.Identity
90
+
91
+ - input_key: cond_aug
92
+ is_trainable: False
93
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
+ params:
95
+ outdim: 256
96
+
97
+ first_stage_config:
98
+ target: sgm.models.autoencoder.AutoencodingEngine
99
+ params:
100
+ loss_config:
101
+ target: torch.nn.Identity
102
+ regularizer_config:
103
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
104
+ encoder_config:
105
+ target: sgm.modules.diffusionmodules.model.Encoder
106
+ params:
107
+ attn_type: vanilla
108
+ double_z: True
109
+ z_channels: 4
110
+ resolution: 256
111
+ in_channels: 3
112
+ out_ch: 3
113
+ ch: 128
114
+ ch_mult: [1, 2, 4, 4]
115
+ num_res_blocks: 2
116
+ attn_resolutions: []
117
+ dropout: 0.0
118
+ decoder_config:
119
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
120
+ params:
121
+ attn_type: vanilla
122
+ double_z: True
123
+ z_channels: 4
124
+ resolution: 256
125
+ in_channels: 3
126
+ out_ch: 3
127
+ ch: 128
128
+ ch_mult: [1, 2, 4, 4]
129
+ num_res_blocks: 2
130
+ attn_resolutions: []
131
+ dropout: 0.0
132
+ video_kernel_size: [3, 1, 1]
133
+
134
+ sampler_config:
135
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
136
+ params:
137
+ discretization_config:
138
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
139
+ params:
140
+ sigma_max: 700.0
141
+
142
+ guider_config:
143
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
144
+ params:
145
+ max_scale: 3.0
146
+ min_scale: 1.5
scripts/sampling/configs/svd_xt_image_decoder.yaml ADDED
@@ -0,0 +1,129 @@
1
+ model:
2
+ target: sgm.models.diffusion.DiffusionEngine
3
+ params:
4
+ scale_factor: 0.18215
5
+ disable_first_stage_autocast: True
6
+ ckpt_path: checkpoints/svd_xt_image_decoder.safetensors
7
+
8
+ denoiser_config:
9
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
10
+ params:
11
+ scaling_config:
12
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
13
+
14
+ network_config:
15
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
16
+ params:
17
+ adm_in_channels: 768
18
+ num_classes: sequential
19
+ use_checkpoint: True
20
+ in_channels: 8
21
+ out_channels: 4
22
+ model_channels: 320
23
+ attention_resolutions: [4, 2, 1]
24
+ num_res_blocks: 2
25
+ channel_mult: [1, 2, 4, 4]
26
+ num_head_channels: 64
27
+ use_linear_in_transformer: True
28
+ transformer_depth: 1
29
+ context_dim: 1024
30
+ spatial_transformer_attn_type: softmax-xformers
31
+ extra_ff_mix_layer: True
32
+ use_spatial_context: True
33
+ merge_strategy: learned_with_images
34
+ video_kernel_size: [3, 1, 1]
35
+
36
+ conditioner_config:
37
+ target: sgm.modules.GeneralConditioner
38
+ params:
39
+ emb_models:
40
+ - is_trainable: False
41
+ input_key: cond_frames_without_noise
42
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
43
+ params:
44
+ n_cond_frames: 1
45
+ n_copies: 1
46
+ open_clip_embedding_config:
47
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
48
+ params:
49
+ freeze: True
50
+
51
+ - input_key: fps_id
52
+ is_trainable: False
53
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
54
+ params:
55
+ outdim: 256
56
+
57
+ - input_key: motion_bucket_id
58
+ is_trainable: False
59
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
60
+ params:
61
+ outdim: 256
62
+
63
+ - input_key: cond_frames
64
+ is_trainable: False
65
+ target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
66
+ params:
67
+ disable_encoder_autocast: True
68
+ n_cond_frames: 1
69
+ n_copies: 1
70
+ is_ae: True
71
+ encoder_config:
72
+ target: sgm.models.autoencoder.AutoencoderKLModeOnly
73
+ params:
74
+ embed_dim: 4
75
+ monitor: val/rec_loss
76
+ ddconfig:
77
+ attn_type: vanilla-xformers
78
+ double_z: True
79
+ z_channels: 4
80
+ resolution: 256
81
+ in_channels: 3
82
+ out_ch: 3
83
+ ch: 128
84
+ ch_mult: [1, 2, 4, 4]
85
+ num_res_blocks: 2
86
+ attn_resolutions: []
87
+ dropout: 0.0
88
+ lossconfig:
89
+ target: torch.nn.Identity
90
+
91
+ - input_key: cond_aug
92
+ is_trainable: False
93
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
94
+ params:
95
+ outdim: 256
96
+
97
+ first_stage_config:
98
+ target: sgm.models.autoencoder.AutoencoderKL
99
+ params:
100
+ embed_dim: 4
101
+ monitor: val/rec_loss
102
+ ddconfig:
103
+ attn_type: vanilla-xformers
104
+ double_z: True
105
+ z_channels: 4
106
+ resolution: 256
107
+ in_channels: 3
108
+ out_ch: 3
109
+ ch: 128
110
+ ch_mult: [1, 2, 4, 4]
111
+ num_res_blocks: 2
112
+ attn_resolutions: []
113
+ dropout: 0.0
114
+ lossconfig:
115
+ target: torch.nn.Identity
116
+
117
+ sampler_config:
118
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
119
+ params:
120
+ discretization_config:
121
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
122
+ params:
123
+ sigma_max: 700.0
124
+
125
+ guider_config:
126
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
127
+ params:
128
+ max_scale: 3.0
129
+ min_scale: 1.5
scripts/sampling/simple_video_sample.py ADDED
@@ -0,0 +1,278 @@
1
+ import math
2
+ import os
3
+ from glob import glob
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from einops import rearrange, repeat
11
+ from fire import Fire
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ from torchvision.transforms import ToTensor
15
+
16
+ from scripts.util.detection.nsfw_and_watermark_dectection import \
17
+ DeepFloydDataFiltering
18
+ from sgm.inference.helpers import embed_watermark
19
+ from sgm.util import default, instantiate_from_config
20
+
21
+
22
+ def sample(
23
+ input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
24
+ num_frames: Optional[int] = None,
25
+ num_steps: Optional[int] = None,
26
+ version: str = "svd",
27
+ fps_id: int = 6,
28
+ motion_bucket_id: int = 127,
29
+ cond_aug: float = 0.02,
30
+ seed: int = 23,
31
+ decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
32
+ device: str = "cuda",
33
+ output_folder: Optional[str] = None,
34
+ ):
35
+ """
36
+ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
37
+ image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
38
+ """
39
+
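+ # Example invocation (assuming the checkpoint referenced in the chosen config has been placed
+ # under checkpoints/), run from the repository root:
+ #   python scripts/sampling/simple_video_sample.py --input_path assets/test_image.png \
+ #       --version svd_xt --decoding_t 5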
40
+ if version == "svd":
41
+ num_frames = default(num_frames, 14)
42
+ num_steps = default(num_steps, 25)
43
+ output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
44
+ model_config = "scripts/sampling/configs/svd.yaml"
45
+ elif version == "svd_xt":
46
+ num_frames = default(num_frames, 25)
47
+ num_steps = default(num_steps, 30)
48
+ output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
49
+ model_config = "scripts/sampling/configs/svd_xt.yaml"
50
+ elif version == "svd_image_decoder":
51
+ num_frames = default(num_frames, 14)
52
+ num_steps = default(num_steps, 25)
53
+ output_folder = default(
54
+ output_folder, "outputs/simple_video_sample/svd_image_decoder/"
55
+ )
56
+ model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
57
+ elif version == "svd_xt_image_decoder":
58
+ num_frames = default(num_frames, 25)
59
+ num_steps = default(num_steps, 30)
60
+ output_folder = default(
61
+ output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
62
+ )
63
+ model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
64
+ else:
65
+ raise ValueError(f"Version {version} does not exist.")
66
+
67
+ model, filter = load_model(
68
+ model_config,
69
+ device,
70
+ num_frames,
71
+ num_steps,
72
+ )
73
+ torch.manual_seed(seed)
74
+
75
+ path = Path(input_path)
76
+ all_img_paths = []
77
+ if path.is_file():
78
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
79
+ all_img_paths = [input_path]
80
+ else:
81
+ raise ValueError("Path is not a valid image file.")
82
+ elif path.is_dir():
83
+ all_img_paths = sorted(
84
+ [
85
+ f
86
+ for f in path.iterdir()
87
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
88
+ ]
89
+ )
90
+ if len(all_img_paths) == 0:
91
+ raise ValueError("Folder does not contain any images.")
92
+ else:
93
+ raise ValueError
94
+
95
+ for input_img_path in all_img_paths:
96
+ with Image.open(input_img_path) as image:
97
+ if image.mode == "RGBA":
98
+ image = image.convert("RGB")
99
+ w, h = image.size
100
+
101
+ if h % 64 != 0 or w % 64 != 0:
102
+ width, height = map(lambda x: x - x % 64, (w, h))
103
+ image = image.resize((width, height))
104
+ print(
105
+ f"WARNING: Your image is of size {h}x{w}, which is not divisible by 64. Resizing it to {height}x{width}!"
106
+ )
107
+
108
+ image = ToTensor()(image)
109
+ image = image * 2.0 - 1.0
110
+
111
+ image = image.unsqueeze(0).to(device)
112
+ H, W = image.shape[2:]
113
+ assert image.shape[1] == 3
114
+ F = 8
115
+ C = 4
116
+ shape = (num_frames, C, H // F, W // F)
117
+ if (H, W) != (576, 1024):
118
+ print(
119
+ "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance, as the model was only trained on 576x1024. Consider increasing `cond_aug`."
120
+ )
121
+ if motion_bucket_id > 255:
122
+ print(
123
+ "WARNING: High motion bucket! This may lead to suboptimal performance."
124
+ )
125
+
126
+ if fps_id < 5:
127
+ print("WARNING: Small fps value! This may lead to suboptimal performance.")
128
+
129
+ if fps_id > 30:
130
+ print("WARNING: Large fps value! This may lead to suboptimal performance.")
131
+
132
+ value_dict = {}
133
+ value_dict["motion_bucket_id"] = motion_bucket_id
134
+ value_dict["fps_id"] = fps_id
135
+ value_dict["cond_aug"] = cond_aug
136
+ value_dict["cond_frames_without_noise"] = image
137
+ value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
138
+ value_dict["cond_aug"] = cond_aug
139
+
140
+ with torch.no_grad():
141
+ with torch.autocast(device):
142
+ batch, batch_uc = get_batch(
143
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
144
+ value_dict,
145
+ [1, num_frames],
146
+ T=num_frames,
147
+ device=device,
148
+ )
149
+ c, uc = model.conditioner.get_unconditional_conditioning(
150
+ batch,
151
+ batch_uc=batch_uc,
152
+ force_uc_zero_embeddings=[
153
+ "cond_frames",
154
+ "cond_frames_without_noise",
155
+ ],
156
+ )
157
+
158
+ for k in ["crossattn", "concat"]:
159
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
160
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
161
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
162
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
163
+
164
+ randn = torch.randn(shape, device=device)
165
+
166
+ additional_model_inputs = {}
167
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
168
+ 2, num_frames
169
+ ).to(device)
170
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
171
+
172
+ def denoiser(input, sigma, c):
173
+ return model.denoiser(
174
+ model.model, input, sigma, c, **additional_model_inputs
175
+ )
176
+
177
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
178
+ model.en_and_decode_n_samples_a_time = decoding_t
179
+ samples_x = model.decode_first_stage(samples_z)
180
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
181
+
182
+ os.makedirs(output_folder, exist_ok=True)
183
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
184
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
185
+ writer = cv2.VideoWriter(
186
+ video_path,
187
+ cv2.VideoWriter_fourcc(*"MP4V"),
188
+ fps_id + 1,
189
+ (samples.shape[-1], samples.shape[-2]),
190
+ )
191
+
192
+ samples = embed_watermark(samples)
193
+ samples = filter(samples)
194
+ vid = (
195
+ (rearrange(samples, "t c h w -> t h w c") * 255)
196
+ .cpu()
197
+ .numpy()
198
+ .astype(np.uint8)
199
+ )
200
+ for frame in vid:
201
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
202
+ writer.write(frame)
203
+ writer.release()
204
+
205
+
206
+ def get_unique_embedder_keys_from_conditioner(conditioner):
207
+ return list(set([x.input_key for x in conditioner.embedders]))
208
+
209
+
210
+ def get_batch(keys, value_dict, N, T, device):
211
+ batch = {}
212
+ batch_uc = {}
213
+
214
+ for key in keys:
215
+ if key == "fps_id":
216
+ batch[key] = (
217
+ torch.tensor([value_dict["fps_id"]])
218
+ .to(device)
219
+ .repeat(int(math.prod(N)))
220
+ )
221
+ elif key == "motion_bucket_id":
222
+ batch[key] = (
223
+ torch.tensor([value_dict["motion_bucket_id"]])
224
+ .to(device)
225
+ .repeat(int(math.prod(N)))
226
+ )
227
+ elif key == "cond_aug":
228
+ batch[key] = repeat(
229
+ torch.tensor([value_dict["cond_aug"]]).to(device),
230
+ "1 -> b",
231
+ b=math.prod(N),
232
+ )
233
+ elif key == "cond_frames":
234
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
235
+ elif key == "cond_frames_without_noise":
236
+ batch[key] = repeat(
237
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
238
+ )
239
+ else:
240
+ batch[key] = value_dict[key]
241
+
242
+ if T is not None:
243
+ batch["num_video_frames"] = T
244
+
245
+ for key in batch.keys():
246
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
247
+ batch_uc[key] = torch.clone(batch[key])
248
+ return batch, batch_uc
249
+
250
+
251
+ def load_model(
252
+ config: str,
253
+ device: str,
254
+ num_frames: int,
255
+ num_steps: int,
256
+ ):
257
+ config = OmegaConf.load(config)
258
+ if device == "cuda":
259
+ config.model.params.conditioner_config.params.emb_models[
260
+ 0
261
+ ].params.open_clip_embedding_config.params.init_device = device
262
+
263
+ config.model.params.sampler_config.params.num_steps = num_steps
264
+ config.model.params.sampler_config.params.guider_config.params.num_frames = (
265
+ num_frames
266
+ )
267
+ if device == "cuda":
268
+ with torch.device(device):
269
+ model = instantiate_from_config(config.model).to(device).eval()
270
+ else:
271
+ model = instantiate_from_config(config.model).to(device).eval()
272
+
273
+ filter = DeepFloydDataFiltering(verbose=False, device=device)
274
+ return model, filter
275
+
276
+
277
+ if __name__ == "__main__":
278
+ Fire(sample)
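
For orientation, a minimal sketch of using `load_model` directly; the import path, config location, and step count below are assumptions, and in normal use the script is driven end-to-end through `Fire(sample)` instead:

# Hypothetical usage sketch (not part of the script above).
import torch

from scripts.sampling.simple_video_sample import load_model  # assumes the repo root is on PYTHONPATH

model, filter = load_model(
    config="scripts/sampling/configs/svd.yaml",  # assumed config path
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_frames=14,  # number of frames the sampler/guider are configured for
    num_steps=25,   # assumed number of denoising steps
)
# `model` is the instantiated DiffusionEngine in eval mode; `filter` is the
# DeepFloydDataFiltering instance that sample() applies to decoded frames.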
scripts/tests/attention.py ADDED
@@ -0,0 +1,319 @@
1
+ import einops
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.utils.benchmark as benchmark
5
+ from torch.backends.cuda import SDPBackend
6
+
7
+ from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
8
+
9
+
10
+ def benchmark_attn():
11
+ # Let's define a helpful benchmarking function:
12
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
16
+ t0 = benchmark.Timer(
17
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
18
+ )
19
+ return t0.blocked_autorange().mean * 1e6
20
+
21
+ # Let's define the hyper-parameters of our input
22
+ batch_size = 32
23
+ max_sequence_len = 1024
24
+ num_heads = 32
25
+ embed_dimension = 32
26
+
27
+ dtype = torch.float16
28
+
29
+ query = torch.rand(
30
+ batch_size,
31
+ num_heads,
32
+ max_sequence_len,
33
+ embed_dimension,
34
+ device=device,
35
+ dtype=dtype,
36
+ )
37
+ key = torch.rand(
38
+ batch_size,
39
+ num_heads,
40
+ max_sequence_len,
41
+ embed_dimension,
42
+ device=device,
43
+ dtype=dtype,
44
+ )
45
+ value = torch.rand(
46
+ batch_size,
47
+ num_heads,
48
+ max_sequence_len,
49
+ embed_dimension,
50
+ device=device,
51
+ dtype=dtype,
52
+ )
53
+
54
+ print("q/k/v shape:", query.shape, key.shape, value.shape)
55
+
56
+ # Let's explore the speed of each of the 3 implementations
57
+ from torch.backends.cuda import SDPBackend, sdp_kernel
58
+
59
+ # Helpful arguments mapper
60
+ backend_map = {
61
+ SDPBackend.MATH: {
62
+ "enable_math": True,
63
+ "enable_flash": False,
64
+ "enable_mem_efficient": False,
65
+ },
66
+ SDPBackend.FLASH_ATTENTION: {
67
+ "enable_math": False,
68
+ "enable_flash": True,
69
+ "enable_mem_efficient": False,
70
+ },
71
+ SDPBackend.EFFICIENT_ATTENTION: {
72
+ "enable_math": False,
73
+ "enable_flash": False,
74
+ "enable_mem_efficient": True,
75
+ },
76
+ }
77
+
78
+ from torch.profiler import ProfilerActivity, profile, record_function
79
+
80
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
81
+
82
+ print(
83
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
84
+ )
85
+ with profile(
86
+ activities=activities, record_shapes=False, profile_memory=True
87
+ ) as prof:
88
+ with record_function("Default detailed stats"):
89
+ for _ in range(25):
90
+ o = F.scaled_dot_product_attention(query, key, value)
91
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
92
+
93
+ print(
94
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
95
+ )
96
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
97
+ with profile(
98
+ activities=activities, record_shapes=False, profile_memory=True
99
+ ) as prof:
100
+ with record_function("Math implementation stats"):
101
+ for _ in range(25):
102
+ o = F.scaled_dot_product_attention(query, key, value)
103
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
104
+
105
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
106
+ try:
107
+ print(
108
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
109
+ )
110
+ except RuntimeError:
111
+ print("FlashAttention is not supported. See warnings for reasons.")
112
+ with profile(
113
+ activities=activities, record_shapes=False, profile_memory=True
114
+ ) as prof:
115
+ with record_function("FlashAttention stats"):
116
+ for _ in range(25):
117
+ o = F.scaled_dot_product_attention(query, key, value)
118
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
119
+
120
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
121
+ try:
122
+ print(
123
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
124
+ )
125
+ except RuntimeError:
126
+ print("EfficientAttention is not supported. See warnings for reasons.")
127
+ with profile(
128
+ activities=activities, record_shapes=False, profile_memory=True
129
+ ) as prof:
130
+ with record_function("EfficientAttention stats"):
131
+ for _ in range(25):
132
+ o = F.scaled_dot_product_attention(query, key, value)
133
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
134
+
135
+
136
+ def run_model(model, x, context):
137
+ return model(x, context)
138
+
139
+
140
+ def benchmark_transformer_blocks():
141
+ device = "cuda" if torch.cuda.is_available() else "cpu"
142
+ import torch.utils.benchmark as benchmark
143
+
144
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
145
+ t0 = benchmark.Timer(
146
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
147
+ )
148
+ return t0.blocked_autorange().mean * 1e6
149
+
150
+ checkpoint = True
151
+ compile = False
152
+
153
+ batch_size = 32
154
+ h, w = 64, 64
155
+ context_len = 77
156
+ embed_dimension = 1024
157
+ context_dim = 1024
158
+ d_head = 64
159
+
160
+ transformer_depth = 4
161
+
162
+ n_heads = embed_dimension // d_head
163
+
164
+ dtype = torch.float16
165
+
166
+ model_native = SpatialTransformer(
167
+ embed_dimension,
168
+ n_heads,
169
+ d_head,
170
+ context_dim=context_dim,
171
+ use_linear=True,
172
+ use_checkpoint=checkpoint,
173
+ attn_type="softmax",
174
+ depth=transformer_depth,
175
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
176
+ ).to(device)
177
+ model_efficient_attn = SpatialTransformer(
178
+ embed_dimension,
179
+ n_heads,
180
+ d_head,
181
+ context_dim=context_dim,
182
+ use_linear=True,
183
+ depth=transformer_depth,
184
+ use_checkpoint=checkpoint,
185
+ attn_type="softmax-xformers",
186
+ ).to(device)
187
+ if not checkpoint and compile:
188
+ print("compiling models")
189
+ model_native = torch.compile(model_native)
190
+ model_efficient_attn = torch.compile(model_efficient_attn)
191
+
192
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
193
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
194
+
195
+ from torch.profiler import ProfilerActivity, profile, record_function
196
+
197
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
198
+
199
+ with torch.autocast("cuda"):
200
+ print(
201
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
202
+ )
203
+ print(
204
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
205
+ )
206
+
207
+ print(75 * "+")
208
+ print("NATIVE")
209
+ print(75 * "+")
210
+ torch.cuda.reset_peak_memory_stats()
211
+ with profile(
212
+ activities=activities, record_shapes=False, profile_memory=True
213
+ ) as prof:
214
+ with record_function("NativeAttention stats"):
215
+ for _ in range(25):
216
+ model_native(x, c)
217
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
218
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
219
+
220
+ print(75 * "+")
221
+ print("Xformers")
222
+ print(75 * "+")
223
+ torch.cuda.reset_peak_memory_stats()
224
+ with profile(
225
+ activities=activities, record_shapes=False, profile_memory=True
226
+ ) as prof:
227
+ with record_function("xformers stats"):
228
+ for _ in range(25):
229
+ model_efficient_attn(x, c)
230
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
231
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
232
+
233
+
234
+ def test01():
235
+ # conv1x1 vs linear
236
+ from sgm.util import count_params
237
+
238
+ conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
239
+ print(count_params(conv))
240
+ linear = torch.nn.Linear(3, 32).cuda()
241
+ print(count_params(linear))
242
+
243
+ print(conv.weight.shape)
244
+
245
+ # use same initialization
246
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
247
+ linear.bias = torch.nn.Parameter(conv.bias)
248
+
249
+ print(linear.weight.shape)
250
+
251
+ x = torch.randn(11, 3, 64, 64).cuda()
252
+
253
+ xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
254
+ print(xr.shape)
255
+ out_linear = linear(xr)
256
+ print(out_linear.mean(), out_linear.shape)
257
+
258
+ out_conv = conv(x)
259
+ print(out_conv.mean(), out_conv.shape)
260
+ print("done with test01.\n")
261
+
262
+
263
+ def test02():
264
+ # try cosine flash attention
265
+ import time
266
+
267
+ torch.backends.cuda.matmul.allow_tf32 = True
268
+ torch.backends.cudnn.allow_tf32 = True
269
+ torch.backends.cudnn.benchmark = True
270
+ print("testing cosine flash attention...")
271
+ DIM = 1024
272
+ SEQLEN = 4096
273
+ BS = 16
274
+
275
+ print(" softmax (vanilla) first...")
276
+ model = BasicTransformerBlock(
277
+ dim=DIM,
278
+ n_heads=16,
279
+ d_head=64,
280
+ dropout=0.0,
281
+ context_dim=None,
282
+ attn_mode="softmax",
283
+ ).cuda()
284
+ try:
285
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
286
+ tic = time.time()
287
+ y = model(x)
288
+ toc = time.time()
289
+ print(y.shape, toc - tic)
290
+ except RuntimeError as e:
291
+ # likely oom
292
+ print(str(e))
293
+
294
+ print("\n now flash-cosine...")
295
+ model = BasicTransformerBlock(
296
+ dim=DIM,
297
+ n_heads=16,
298
+ d_head=64,
299
+ dropout=0.0,
300
+ context_dim=None,
301
+ attn_mode="flash-cosine",
302
+ ).cuda()
303
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
304
+ tic = time.time()
305
+ y = model(x)
306
+ toc = time.time()
307
+ print(y.shape, toc - tic)
308
+ print("done with test02.\n")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ # test01()
313
+ # test02()
314
+ # test03()
315
+
316
+ # benchmark_attn()
317
+ benchmark_transformer_blocks()
318
+
319
+ print("done.")
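
Outside of the profiling harness above, a hedged one-off sketch of pinning scaled_dot_product_attention to a single backend; it uses the same torch.backends.cuda.sdp_kernel context manager imported in the benchmark (available in the PyTorch 2.0/2.1 line this repo targets; newer releases expose torch.nn.attention.sdpa_kernel instead):

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# half precision, (batch, heads, seq, head_dim) layout as in the benchmark above
q = torch.rand(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.rand_like(q)
v = torch.rand_like(q)

# allow only the flash kernel; an unsupported configuration raises RuntimeError
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 1024, 64])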
scripts/util/__init__.py ADDED
File without changes
scripts/util/detection/__init__.py ADDED
File without changes
scripts/util/detection/nsfw_and_watermark_dectection.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+
3
+ import clip
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+
9
+ RESOURCES_ROOT = "scripts/util/detection/"
10
+
11
+
12
+ def predict_proba(X, weights, biases):
13
+ logits = X @ weights.T + biases
14
+ proba = np.where(
15
+ logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
16
+ )
17
+ return proba.T
18
+
19
+
20
+ def load_model_weights(path: str):
21
+ model_weights = np.load(path)
22
+ return model_weights["weights"], model_weights["biases"]
23
+
24
+
25
+ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
26
+ min_size = min(images.shape[-2:])
27
+ return T.Compose(
28
+ [
29
+ T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
30
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
31
+ T.Normalize(
32
+ (0.48145466, 0.4578275, 0.40821073),
33
+ (0.26862954, 0.26130258, 0.27577711),
34
+ ),
35
+ ]
36
+ )(images)
37
+
38
+
39
+ class DeepFloydDataFiltering(object):
40
+ def __init__(
41
+ self, verbose: bool = False, device: torch.device = torch.device("cpu")
42
+ ):
43
+ super().__init__()
44
+ self.verbose = verbose
45
+ self._device = None
46
+ self.clip_model, _ = clip.load("ViT-L/14", device=device)
47
+ self.clip_model.eval()
48
+
49
+ self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
50
+ os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
51
+ )
52
+ self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
53
+ os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
54
+ )
55
+ self.w_threshold, self.p_threshold = 0.5, 0.5
56
+
57
+ @torch.inference_mode()
58
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
59
+ imgs = clip_process_images(images)
60
+ if self._device is None:
61
+ self._device = next(p for p in self.clip_model.parameters()).device
62
+ image_features = self.clip_model.encode_image(imgs.to(self._device))
63
+ image_features = image_features.detach().cpu().numpy().astype(np.float16)
64
+ p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
65
+ w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
66
+ print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
67
+ query = p_pred > self.p_threshold
68
+ if query.sum() > 0:
69
+ print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
70
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
71
+ query = w_pred > self.w_threshold
72
+ if query.sum() > 0:
73
+ print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
74
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
75
+ return images
76
+
77
+
78
+ def load_img(path: str) -> torch.Tensor:
79
+ image = Image.open(path)
80
+ if not image.mode == "RGB":
81
+ image = image.convert("RGB")
82
+ image_transforms = T.Compose(
83
+ [
84
+ T.ToTensor(),
85
+ ]
86
+ )
87
+ return image_transforms(image)[None, ...]
88
+
89
+
90
+ def test(root):
91
+ from einops import rearrange
92
+
93
+ filter = DeepFloydDataFiltering(verbose=True)
94
+ for p in os.listdir((root)):
95
+ print(f"running on {p}...")
96
+ img = load_img(os.path.join(root, p))
97
+ filtered_img = filter(img)
98
+ filtered_img = rearrange(
99
+ 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
100
+ ).astype(np.uint8)
101
+ Image.fromarray(filtered_img).save(
102
+ os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
103
+ )
104
+
105
+
106
+ if __name__ == "__main__":
107
+ import fire
108
+
109
+ fire.Fire(test)
110
+ print("done.")
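
A minimal usage sketch of the filter as it is applied in the sampling scripts above; the input is expected as a float tensor in [0, 1], and frames that trip either head come back heavily blurred (shapes and values below are illustrative):

import torch

# hypothetical batch of two RGB frames in [0, 1]
frames = torch.rand(2, 3, 576, 1024)

nsfw_and_watermark_filter = DeepFloydDataFiltering(verbose=True, device=torch.device("cpu"))
filtered = nsfw_and_watermark_filter(frames)  # same shape; offending frames are Gaussian-blurred
print(filtered.shape)  # torch.Size([2, 3, 576, 1024])

Note that the class downloads the CLIP ViT-L/14 weights on first use and loads the *.npz heads relative to RESOURCES_ROOT, so this sketch assumes it is run from the repository root.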
scripts/util/detection/p_head_v1.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
3
+ size 3588
scripts/util/detection/w_head_v1.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
3
+ size 3588
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dataset import StableDataModuleFromConfig
sgm/data/cifar10.py ADDED
@@ -0,0 +1,67 @@
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class CIFAR10DataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class CIFAR10Loader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.shuffle = shuffle
31
+ self.train_dataset = CIFAR10DataDictWrapper(
32
+ torchvision.datasets.CIFAR10(
33
+ root=".data/", train=True, download=True, transform=transform
34
+ )
35
+ )
36
+ self.test_dataset = CIFAR10DataDictWrapper(
37
+ torchvision.datasets.CIFAR10(
38
+ root=".data/", train=False, download=True, transform=transform
39
+ )
40
+ )
41
+
42
+ def prepare_data(self):
43
+ pass
44
+
45
+ def train_dataloader(self):
46
+ return DataLoader(
47
+ self.train_dataset,
48
+ batch_size=self.batch_size,
49
+ shuffle=self.shuffle,
50
+ num_workers=self.num_workers,
51
+ )
52
+
53
+ def test_dataloader(self):
54
+ return DataLoader(
55
+ self.test_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=self.shuffle,
58
+ num_workers=self.num_workers,
59
+ )
60
+
61
+ def val_dataloader(self):
62
+ return DataLoader(
63
+ self.test_dataset,
64
+ batch_size=self.batch_size,
65
+ shuffle=self.shuffle,
66
+ num_workers=self.num_workers,
67
+ )
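
A quick usage sketch; batches follow the {"jpg": image, "cls": label} dict convention used throughout sgm, with images rescaled to [-1, 1] (the loader downloads CIFAR-10 to .data/ on first use):

loader = CIFAR10Loader(batch_size=64, num_workers=2)
batch = next(iter(loader.train_dataloader()))
print(batch["jpg"].shape)                      # torch.Size([64, 3, 32, 32])
print(batch["jpg"].min(), batch["jpg"].max())  # approximately -1 and 1
print(batch["cls"].shape)                      # torch.Size([64])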
sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
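
For clarity, a sketch of the config shape the assertions above expect; the keys inside `datapipeline` and `loader` are forwarded verbatim to `create_dataset`/`create_loader` from the optional stable-datasets submodule, so the inner values are left as placeholders:

from omegaconf import OmegaConf

train_cfg = OmegaConf.create(
    {
        "datapipeline": {},  # kwargs for create_dataset(...) / create_dummy_dataset(...)
        "loader": {},        # kwargs for create_loader(...)
    }
)
# data = StableDataModuleFromConfig(train=train_cfg, dummy=True)
# data.setup("fit"); loader = data.train_dataloader()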
sgm/data/mnist.py ADDED
@@ -0,0 +1,85 @@
1
+ import pytorch_lightning as pl
2
+ import torchvision
3
+ from torch.utils.data import DataLoader, Dataset
4
+ from torchvision import transforms
5
+
6
+
7
+ class MNISTDataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class MNISTLoader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else None  # DataLoader expects None when num_workers == 0 (torch >= 2.0)
31
+ self.shuffle = shuffle
32
+ self.train_dataset = MNISTDataDictWrapper(
33
+ torchvision.datasets.MNIST(
34
+ root=".data/", train=True, download=True, transform=transform
35
+ )
36
+ )
37
+ self.test_dataset = MNISTDataDictWrapper(
38
+ torchvision.datasets.MNIST(
39
+ root=".data/", train=False, download=True, transform=transform
40
+ )
41
+ )
42
+
43
+ def prepare_data(self):
44
+ pass
45
+
46
+ def train_dataloader(self):
47
+ return DataLoader(
48
+ self.train_dataset,
49
+ batch_size=self.batch_size,
50
+ shuffle=self.shuffle,
51
+ num_workers=self.num_workers,
52
+ prefetch_factor=self.prefetch_factor,
53
+ )
54
+
55
+ def test_dataloader(self):
56
+ return DataLoader(
57
+ self.test_dataset,
58
+ batch_size=self.batch_size,
59
+ shuffle=self.shuffle,
60
+ num_workers=self.num_workers,
61
+ prefetch_factor=self.prefetch_factor,
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.test_dataset,
67
+ batch_size=self.batch_size,
68
+ shuffle=self.shuffle,
69
+ num_workers=self.num_workers,
70
+ prefetch_factor=self.prefetch_factor,
71
+ )
72
+
73
+
74
+ if __name__ == "__main__":
75
+ dset = MNISTDataDictWrapper(
76
+ torchvision.datasets.MNIST(
77
+ root=".data/",
78
+ train=False,
79
+ download=True,
80
+ transform=transforms.Compose(
81
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82
+ ),
83
+ )
84
+ )
85
+ ex = dset[0]
sgm/inference/api.py ADDED
@@ -0,0 +1,386 @@
1
+ import pathlib
2
+ from dataclasses import asdict, dataclass
3
+ from enum import Enum
4
+ from typing import Optional
5
+
6
+ from omegaconf import OmegaConf
7
+
8
+ from sgm.inference.helpers import Img2ImgDiscretizationWrapper, do_img2img, do_sample
9
+ from sgm.modules.diffusionmodules.sampling import (
10
+ DPMPP2MSampler,
11
+ DPMPP2SAncestralSampler,
12
+ EulerAncestralSampler,
13
+ EulerEDMSampler,
14
+ HeunEDMSampler,
15
+ LinearMultistepSampler,
16
+ )
17
+ from sgm.util import load_model_from_config
18
+
19
+
20
+ class ModelArchitecture(str, Enum):
21
+ SD_2_1 = "stable-diffusion-v2-1"
22
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
23
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
24
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
25
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
26
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
27
+
28
+
29
+ class Sampler(str, Enum):
30
+ EULER_EDM = "EulerEDMSampler"
31
+ HEUN_EDM = "HeunEDMSampler"
32
+ EULER_ANCESTRAL = "EulerAncestralSampler"
33
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
34
+ DPMPP2M = "DPMPP2MSampler"
35
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
36
+
37
+
38
+ class Discretization(str, Enum):
39
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
40
+ EDM = "EDMDiscretization"
41
+
42
+
43
+ class Guider(str, Enum):
44
+ VANILLA = "VanillaCFG"
45
+ IDENTITY = "IdentityGuider"
46
+
47
+
48
+ class Thresholder(str, Enum):
49
+ NONE = "None"
50
+
51
+
52
+ @dataclass
53
+ class SamplingParams:
54
+ width: int = 1024
55
+ height: int = 1024
56
+ steps: int = 50
57
+ sampler: Sampler = Sampler.DPMPP2M
58
+ discretization: Discretization = Discretization.LEGACY_DDPM
59
+ guider: Guider = Guider.VANILLA
60
+ thresholder: Thresholder = Thresholder.NONE
61
+ scale: float = 6.0
62
+ aesthetic_score: float = 5.0
63
+ negative_aesthetic_score: float = 5.0
64
+ img2img_strength: float = 1.0
65
+ orig_width: int = 1024
66
+ orig_height: int = 1024
67
+ crop_coords_top: int = 0
68
+ crop_coords_left: int = 0
69
+ sigma_min: float = 0.0292
70
+ sigma_max: float = 14.6146
71
+ rho: float = 3.0
72
+ s_churn: float = 0.0
73
+ s_tmin: float = 0.0
74
+ s_tmax: float = 999.0
75
+ s_noise: float = 1.0
76
+ eta: float = 1.0
77
+ order: int = 4
78
+
79
+
80
+ @dataclass
81
+ class SamplingSpec:
82
+ width: int
83
+ height: int
84
+ channels: int
85
+ factor: int
86
+ is_legacy: bool
87
+ config: str
88
+ ckpt: str
89
+ is_guided: bool
90
+
91
+
92
+ model_specs = {
93
+ ModelArchitecture.SD_2_1: SamplingSpec(
94
+ height=512,
95
+ width=512,
96
+ channels=4,
97
+ factor=8,
98
+ is_legacy=True,
99
+ config="sd_2_1.yaml",
100
+ ckpt="v2-1_512-ema-pruned.safetensors",
101
+ is_guided=True,
102
+ ),
103
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
104
+ height=768,
105
+ width=768,
106
+ channels=4,
107
+ factor=8,
108
+ is_legacy=True,
109
+ config="sd_2_1_768.yaml",
110
+ ckpt="v2-1_768-ema-pruned.safetensors",
111
+ is_guided=True,
112
+ ),
113
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
114
+ height=1024,
115
+ width=1024,
116
+ channels=4,
117
+ factor=8,
118
+ is_legacy=False,
119
+ config="sd_xl_base.yaml",
120
+ ckpt="sd_xl_base_0.9.safetensors",
121
+ is_guided=True,
122
+ ),
123
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
124
+ height=1024,
125
+ width=1024,
126
+ channels=4,
127
+ factor=8,
128
+ is_legacy=True,
129
+ config="sd_xl_refiner.yaml",
130
+ ckpt="sd_xl_refiner_0.9.safetensors",
131
+ is_guided=True,
132
+ ),
133
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
134
+ height=1024,
135
+ width=1024,
136
+ channels=4,
137
+ factor=8,
138
+ is_legacy=False,
139
+ config="sd_xl_base.yaml",
140
+ ckpt="sd_xl_base_1.0.safetensors",
141
+ is_guided=True,
142
+ ),
143
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
144
+ height=1024,
145
+ width=1024,
146
+ channels=4,
147
+ factor=8,
148
+ is_legacy=True,
149
+ config="sd_xl_refiner.yaml",
150
+ ckpt="sd_xl_refiner_1.0.safetensors",
151
+ is_guided=True,
152
+ ),
153
+ }
154
+
155
+
156
+ class SamplingPipeline:
157
+ def __init__(
158
+ self,
159
+ model_id: ModelArchitecture,
160
+ model_path="checkpoints",
161
+ config_path="configs/inference",
162
+ device="cuda",
163
+ use_fp16=True,
164
+ ) -> None:
165
+ if model_id not in model_specs:
166
+ raise ValueError(f"Model {model_id} not supported")
167
+ self.model_id = model_id
168
+ self.specs = model_specs[self.model_id]
169
+ self.config = str(pathlib.Path(config_path, self.specs.config))
170
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
171
+ self.device = device
172
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
173
+
174
+ def _load_model(self, device="cuda", use_fp16=True):
175
+ config = OmegaConf.load(self.config)
176
+ model = load_model_from_config(config, self.ckpt)
177
+ if model is None:
178
+ raise ValueError(f"Model {self.model_id} could not be loaded")
179
+ model.to(device)
180
+ if use_fp16:
181
+ model.conditioner.half()
182
+ model.model.half()
183
+ return model
184
+
185
+ def text_to_image(
186
+ self,
187
+ params: SamplingParams,
188
+ prompt: str,
189
+ negative_prompt: str = "",
190
+ samples: int = 1,
191
+ return_latents: bool = False,
192
+ ):
193
+ sampler = get_sampler_config(params)
194
+ value_dict = asdict(params)
195
+ value_dict["prompt"] = prompt
196
+ value_dict["negative_prompt"] = negative_prompt
197
+ value_dict["target_width"] = params.width
198
+ value_dict["target_height"] = params.height
199
+ return do_sample(
200
+ self.model,
201
+ sampler,
202
+ value_dict,
203
+ samples,
204
+ params.height,
205
+ params.width,
206
+ self.specs.channels,
207
+ self.specs.factor,
208
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
209
+ return_latents=return_latents,
210
+ filter=None,
211
+ )
212
+
213
+ def image_to_image(
214
+ self,
215
+ params: SamplingParams,
216
+ image,
217
+ prompt: str,
218
+ negative_prompt: str = "",
219
+ samples: int = 1,
220
+ return_latents: bool = False,
221
+ ):
222
+ sampler = get_sampler_config(params)
223
+
224
+ if params.img2img_strength < 1.0:
225
+ sampler.discretization = Img2ImgDiscretizationWrapper(
226
+ sampler.discretization,
227
+ strength=params.img2img_strength,
228
+ )
229
+ height, width = image.shape[2], image.shape[3]
230
+ value_dict = asdict(params)
231
+ value_dict["prompt"] = prompt
232
+ value_dict["negative_prompt"] = negative_prompt
233
+ value_dict["target_width"] = width
234
+ value_dict["target_height"] = height
235
+ return do_img2img(
236
+ image,
237
+ self.model,
238
+ sampler,
239
+ value_dict,
240
+ samples,
241
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
242
+ return_latents=return_latents,
243
+ filter=None,
244
+ )
245
+
246
+ def refiner(
247
+ self,
248
+ params: SamplingParams,
249
+ image,
250
+ prompt: str,
251
+ negative_prompt: Optional[str] = None,
252
+ samples: int = 1,
253
+ return_latents: bool = False,
254
+ ):
255
+ sampler = get_sampler_config(params)
256
+ value_dict = {
257
+ "orig_width": image.shape[3] * 8,
258
+ "orig_height": image.shape[2] * 8,
259
+ "target_width": image.shape[3] * 8,
260
+ "target_height": image.shape[2] * 8,
261
+ "prompt": prompt,
262
+ "negative_prompt": negative_prompt,
263
+ "crop_coords_top": 0,
264
+ "crop_coords_left": 0,
265
+ "aesthetic_score": 6.0,
266
+ "negative_aesthetic_score": 2.5,
267
+ }
268
+
269
+ return do_img2img(
270
+ image,
271
+ self.model,
272
+ sampler,
273
+ value_dict,
274
+ samples,
275
+ skip_encode=True,
276
+ return_latents=return_latents,
277
+ filter=None,
278
+ )
279
+
280
+
281
+ def get_guider_config(params: SamplingParams):
282
+ if params.guider == Guider.IDENTITY:
283
+ guider_config = {
284
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
285
+ }
286
+ elif params.guider == Guider.VANILLA:
287
+ scale = params.scale
288
+
289
+ thresholder = params.thresholder
290
+
291
+ if thresholder == Thresholder.NONE:
292
+ dyn_thresh_config = {
293
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
294
+ }
295
+ else:
296
+ raise NotImplementedError
297
+
298
+ guider_config = {
299
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
300
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
301
+ }
302
+ else:
303
+ raise NotImplementedError
304
+ return guider_config
305
+
306
+
307
+ def get_discretization_config(params: SamplingParams):
308
+ if params.discretization == Discretization.LEGACY_DDPM:
309
+ discretization_config = {
310
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
311
+ }
312
+ elif params.discretization == Discretization.EDM:
313
+ discretization_config = {
314
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
315
+ "params": {
316
+ "sigma_min": params.sigma_min,
317
+ "sigma_max": params.sigma_max,
318
+ "rho": params.rho,
319
+ },
320
+ }
321
+ else:
322
+ raise ValueError(f"unknown discretization {params.discretization}")
323
+ return discretization_config
324
+
325
+
326
+ def get_sampler_config(params: SamplingParams):
327
+ discretization_config = get_discretization_config(params)
328
+ guider_config = get_guider_config(params)
329
+ sampler = None
330
+ if params.sampler == Sampler.EULER_EDM:
331
+ return EulerEDMSampler(
332
+ num_steps=params.steps,
333
+ discretization_config=discretization_config,
334
+ guider_config=guider_config,
335
+ s_churn=params.s_churn,
336
+ s_tmin=params.s_tmin,
337
+ s_tmax=params.s_tmax,
338
+ s_noise=params.s_noise,
339
+ verbose=True,
340
+ )
341
+ if params.sampler == Sampler.HEUN_EDM:
342
+ return HeunEDMSampler(
343
+ num_steps=params.steps,
344
+ discretization_config=discretization_config,
345
+ guider_config=guider_config,
346
+ s_churn=params.s_churn,
347
+ s_tmin=params.s_tmin,
348
+ s_tmax=params.s_tmax,
349
+ s_noise=params.s_noise,
350
+ verbose=True,
351
+ )
352
+ if params.sampler == Sampler.EULER_ANCESTRAL:
353
+ return EulerAncestralSampler(
354
+ num_steps=params.steps,
355
+ discretization_config=discretization_config,
356
+ guider_config=guider_config,
357
+ eta=params.eta,
358
+ s_noise=params.s_noise,
359
+ verbose=True,
360
+ )
361
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
362
+ return DPMPP2SAncestralSampler(
363
+ num_steps=params.steps,
364
+ discretization_config=discretization_config,
365
+ guider_config=guider_config,
366
+ eta=params.eta,
367
+ s_noise=params.s_noise,
368
+ verbose=True,
369
+ )
370
+ if params.sampler == Sampler.DPMPP2M:
371
+ return DPMPP2MSampler(
372
+ num_steps=params.steps,
373
+ discretization_config=discretization_config,
374
+ guider_config=guider_config,
375
+ verbose=True,
376
+ )
377
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
378
+ return LinearMultistepSampler(
379
+ num_steps=params.steps,
380
+ discretization_config=discretization_config,
381
+ guider_config=guider_config,
382
+ order=params.order,
383
+ verbose=True,
384
+ )
385
+
386
+ raise ValueError(f"unknown sampler {params.sampler}!")
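
Putting the pieces together, a minimal text-to-image sketch; it assumes the checkpoint and config files listed in `model_specs` are present under the default `checkpoints/` and `configs/inference/` locations:

from sgm.inference.api import ModelArchitecture, SamplingParams, SamplingPipeline

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
params = SamplingParams(width=1024, height=1024, steps=30, scale=6.0)
samples = pipeline.text_to_image(params, prompt="a photograph of an astronaut riding a horse")
# `samples` is a float tensor in [0, 1] with shape (num_samples, 3, H, W),
# produced by do_sample() in sgm/inference/helpers.py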
sgm/inference/helpers.py ADDED
@@ -0,0 +1,305 @@
1
+ import math
2
+ import os
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from einops import rearrange
8
+ from imwatermark import WatermarkEncoder
9
+ from omegaconf import ListConfig
10
+ from PIL import Image
11
+ from torch import autocast
12
+
13
+ from sgm.util import append_dims
14
+
15
+
16
+ class WatermarkEmbedder:
17
+ def __init__(self, watermark):
18
+ self.watermark = watermark
19
+ self.num_bits = len(WATERMARK_BITS)
20
+ self.encoder = WatermarkEncoder()
21
+ self.encoder.set_watermark("bits", self.watermark)
22
+
23
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Adds a predefined watermark to the input image
26
+
27
+ Args:
28
+ image: ([N,] B, RGB, H, W) in range [0, 1]
29
+
30
+ Returns:
31
+ same as input but watermarked
32
+ """
33
+ squeeze = len(image.shape) == 4
34
+ if squeeze:
35
+ image = image[None, ...]
36
+ n = image.shape[0]
37
+ image_np = rearrange(
38
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
39
+ ).numpy()[:, :, :, ::-1]
40
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
41
+ # the watermarking library expects input in cv2 BGR format
42
+ for k in range(image_np.shape[0]):
43
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
44
+ image = torch.from_numpy(
45
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
46
+ ).to(image.device)
47
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
48
+ if squeeze:
49
+ image = image[0]
50
+ return image
51
+
52
+
53
+ # A fixed 48-bit message that was chosen at random
54
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
55
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
56
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
57
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
58
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
59
+
60
+
61
+ def get_unique_embedder_keys_from_conditioner(conditioner):
62
+ return list({x.input_key for x in conditioner.embedders})
63
+
64
+
65
+ def perform_save_locally(save_path, samples):
66
+ os.makedirs(os.path.join(save_path), exist_ok=True)
67
+ base_count = len(os.listdir(os.path.join(save_path)))
68
+ samples = embed_watermark(samples)
69
+ for sample in samples:
70
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
71
+ Image.fromarray(sample.astype(np.uint8)).save(
72
+ os.path.join(save_path, f"{base_count:09}.png")
73
+ )
74
+ base_count += 1
75
+
76
+
77
+ class Img2ImgDiscretizationWrapper:
78
+ """
79
+ wraps a discretizer, and prunes the sigmas
80
+ params:
81
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
82
+ """
83
+
84
+ def __init__(self, discretization, strength: float = 1.0):
85
+ self.discretization = discretization
86
+ self.strength = strength
87
+ assert 0.0 <= self.strength <= 1.0
88
+
89
+ def __call__(self, *args, **kwargs):
90
+ # sigmas start large first, and decrease then
91
+ sigmas = self.discretization(*args, **kwargs)
92
+ print("sigmas after discretization, before pruning img2img:", sigmas)
93
+ sigmas = torch.flip(sigmas, (0,))
94
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
95
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
96
+ sigmas = torch.flip(sigmas, (0,))
97
+ print("sigmas after pruning:", sigmas)
98
+ return sigmas
99
+
100
+
101
+ def do_sample(
102
+ model,
103
+ sampler,
104
+ value_dict,
105
+ num_samples,
106
+ H,
107
+ W,
108
+ C,
109
+ F,
110
+ force_uc_zero_embeddings: Optional[List] = None,
111
+ batch2model_input: Optional[List] = None,
112
+ return_latents=False,
113
+ filter=None,
114
+ device="cuda",
115
+ ):
116
+ if force_uc_zero_embeddings is None:
117
+ force_uc_zero_embeddings = []
118
+ if batch2model_input is None:
119
+ batch2model_input = []
120
+
121
+ with torch.no_grad():
122
+ with autocast(device) as precision_scope:
123
+ with model.ema_scope():
124
+ num_samples = [num_samples]
125
+ batch, batch_uc = get_batch(
126
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
127
+ value_dict,
128
+ num_samples,
129
+ )
130
+ for key in batch:
131
+ if isinstance(batch[key], torch.Tensor):
132
+ print(key, batch[key].shape)
133
+ elif isinstance(batch[key], list):
134
+ print(key, [len(l) for l in batch[key]])
135
+ else:
136
+ print(key, batch[key])
137
+ c, uc = model.conditioner.get_unconditional_conditioning(
138
+ batch,
139
+ batch_uc=batch_uc,
140
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
141
+ )
142
+
143
+ for k in c:
144
+ if not k == "crossattn":
145
+ c[k], uc[k] = map(
146
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
147
+ )
148
+
149
+ additional_model_inputs = {}
150
+ for k in batch2model_input:
151
+ additional_model_inputs[k] = batch[k]
152
+
153
+ shape = (math.prod(num_samples), C, H // F, W // F)
154
+ randn = torch.randn(shape).to(device)
155
+
156
+ def denoiser(input, sigma, c):
157
+ return model.denoiser(
158
+ model.model, input, sigma, c, **additional_model_inputs
159
+ )
160
+
161
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
162
+ samples_x = model.decode_first_stage(samples_z)
163
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
164
+
165
+ if filter is not None:
166
+ samples = filter(samples)
167
+
168
+ if return_latents:
169
+ return samples, samples_z
170
+ return samples
171
+
172
+
173
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
174
+ # Hardcoded demo setups; might undergo some changes in the future
175
+
176
+ batch = {}
177
+ batch_uc = {}
178
+
179
+ for key in keys:
180
+ if key == "txt":
181
+ batch["txt"] = (
182
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
183
+ .reshape(N)
184
+ .tolist()
185
+ )
186
+ batch_uc["txt"] = (
187
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
188
+ .reshape(N)
189
+ .tolist()
190
+ )
191
+ elif key == "original_size_as_tuple":
192
+ batch["original_size_as_tuple"] = (
193
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
194
+ .to(device)
195
+ .repeat(*N, 1)
196
+ )
197
+ elif key == "crop_coords_top_left":
198
+ batch["crop_coords_top_left"] = (
199
+ torch.tensor(
200
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
201
+ )
202
+ .to(device)
203
+ .repeat(*N, 1)
204
+ )
205
+ elif key == "aesthetic_score":
206
+ batch["aesthetic_score"] = (
207
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
208
+ )
209
+ batch_uc["aesthetic_score"] = (
210
+ torch.tensor([value_dict["negative_aesthetic_score"]])
211
+ .to(device)
212
+ .repeat(*N, 1)
213
+ )
214
+
215
+ elif key == "target_size_as_tuple":
216
+ batch["target_size_as_tuple"] = (
217
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
218
+ .to(device)
219
+ .repeat(*N, 1)
220
+ )
221
+ else:
222
+ batch[key] = value_dict[key]
223
+
224
+ for key in batch.keys():
225
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
226
+ batch_uc[key] = torch.clone(batch[key])
227
+ return batch, batch_uc
228
+
229
+
230
+ def get_input_image_tensor(image: Image.Image, device="cuda"):
231
+ w, h = image.size
232
+ print(f"loaded input image of size ({w}, {h})")
233
+ width, height = map(
234
+ lambda x: x - x % 64, (w, h)
235
+ ) # resize to integer multiple of 64
236
+ image = image.resize((width, height))
237
+ image_array = np.array(image.convert("RGB"))
238
+ image_array = image_array[None].transpose(0, 3, 1, 2)
239
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
240
+ return image_tensor.to(device)
241
+
242
+
243
+ def do_img2img(
244
+ img,
245
+ model,
246
+ sampler,
247
+ value_dict,
248
+ num_samples,
249
+ force_uc_zero_embeddings=[],
250
+ additional_kwargs={},
251
+ offset_noise_level: float = 0.0,
252
+ return_latents=False,
253
+ skip_encode=False,
254
+ filter=None,
255
+ device="cuda",
256
+ ):
257
+ with torch.no_grad():
258
+ with autocast(device) as precision_scope:
259
+ with model.ema_scope():
260
+ batch, batch_uc = get_batch(
261
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
262
+ value_dict,
263
+ [num_samples],
264
+ )
265
+ c, uc = model.conditioner.get_unconditional_conditioning(
266
+ batch,
267
+ batch_uc=batch_uc,
268
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
269
+ )
270
+
271
+ for k in c:
272
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
+
274
+ for k in additional_kwargs:
275
+ c[k] = uc[k] = additional_kwargs[k]
276
+ if skip_encode:
277
+ z = img
278
+ else:
279
+ z = model.encode_first_stage(img)
280
+ noise = torch.randn_like(z)
281
+ sigmas = sampler.discretization(sampler.num_steps)
282
+ sigma = sigmas[0].to(z.device)
283
+
284
+ if offset_noise_level > 0.0:
285
+ noise = noise + offset_noise_level * append_dims(
286
+ torch.randn(z.shape[0], device=z.device), z.ndim
287
+ )
288
+ noised_z = z + noise * append_dims(sigma, z.ndim)
289
+ noised_z = noised_z / torch.sqrt(
290
+ 1.0 + sigmas[0] ** 2.0
291
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
+
293
+ def denoiser(x, sigma, c):
294
+ return model.denoiser(model.model, x, sigma, c)
295
+
296
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297
+ samples_x = model.decode_first_stage(samples_z)
298
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
+
300
+ if filter is not None:
301
+ samples = filter(samples)
302
+
303
+ if return_latents:
304
+ return samples, samples_z
305
+ return samples
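
To make the `strength` semantics of Img2ImgDiscretizationWrapper concrete, a toy sketch in which a fake discretization stands in for the real EDM/DDPM discretizers:

import torch

def fake_discretization(n_steps):
    # sigmas are returned largest-first, matching the real discretizers
    return torch.linspace(10.0, 0.1, n_steps)

wrapped = Img2ImgDiscretizationWrapper(fake_discretization, strength=0.5)
sigmas = wrapped(10)
print(len(sigmas))  # 5 -- only the lowest noise levels survive the pruning

Sampling therefore starts part-way down the noise schedule, which is what lets img2img preserve more of the input image as strength decreases.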
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
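
A short sketch of how these schedules are meant to be consumed, following the "base_lr of 1.0" note: the schedule returns a multiplicative factor that LambdaLR applies to the optimizer's learning rate (the parameters and step counts below are placeholders):

import torch

sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=1_000,
    lr_min=0.01,
    lr_max=1.0,
    lr_start=0.0,
    max_decay_steps=100_000,
)

params = [torch.nn.Parameter(torch.zeros(1))]   # placeholder parameters
optimizer = torch.optim.AdamW(params, lr=1.0)   # base_lr of 1.0, as the docstring suggests
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)
# call scheduler.step() once per training step; the effective lr is sched_fn(step) * 1.0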
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,619 @@
1
+ import logging
2
+ import math
3
+ import re
4
+ from abc import abstractmethod
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from packaging import version
13
+
14
+ from ..modules.autoencoding.regularizers import AbstractRegularizer
15
+ from ..modules.ema import LitEma
16
+ from ..util import (
17
+ default,
18
+ get_nested_attribute,
19
+ get_obj_from_str,
20
+ instantiate_from_config,
21
+ )
22
+
23
+ logpy = logging.getLogger(__name__)
24
+
25
+
26
+ class AbstractAutoencoder(pl.LightningModule):
27
+ """
28
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
29
+ unCLIP models, etc. Hence, it is fairly general, and specific features
30
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ ema_decay: Union[None, float] = None,
36
+ monitor: Union[None, str] = None,
37
+ input_key: str = "jpg",
38
+ ):
39
+ super().__init__()
40
+
41
+ self.input_key = input_key
42
+ self.use_ema = ema_decay is not None
43
+ if monitor is not None:
44
+ self.monitor = monitor
45
+
46
+ if self.use_ema:
47
+ self.model_ema = LitEma(self, decay=ema_decay)
48
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
49
+
50
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
51
+ self.automatic_optimization = False
52
+
53
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
54
+ if ckpt is None:
55
+ return
56
+ if isinstance(ckpt, str):
57
+ ckpt = {
58
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
59
+ "params": {"ckpt_path": ckpt},
60
+ }
61
+ engine = instantiate_from_config(ckpt)
62
+ engine(self)
63
+
64
+ @abstractmethod
65
+ def get_input(self, batch) -> Any:
66
+ raise NotImplementedError()
67
+
68
+ def on_train_batch_end(self, *args, **kwargs):
69
+ # for EMA computation
70
+ if self.use_ema:
71
+ self.model_ema(self)
72
+
73
+ @contextmanager
74
+ def ema_scope(self, context=None):
75
+ if self.use_ema:
76
+ self.model_ema.store(self.parameters())
77
+ self.model_ema.copy_to(self)
78
+ if context is not None:
79
+ logpy.info(f"{context}: Switched to EMA weights")
80
+ try:
81
+ yield None
82
+ finally:
83
+ if self.use_ema:
84
+ self.model_ema.restore(self.parameters())
85
+ if context is not None:
86
+ logpy.info(f"{context}: Restored training weights")
87
+
88
+ @abstractmethod
89
+ def encode(self, *args, **kwargs) -> torch.Tensor:
90
+ raise NotImplementedError("encode()-method of abstract base class called")
91
+
92
+ @abstractmethod
93
+ def decode(self, *args, **kwargs) -> torch.Tensor:
94
+ raise NotImplementedError("decode()-method of abstract base class called")
95
+
96
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
97
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
98
+ return get_obj_from_str(cfg["target"])(
99
+ params, lr=lr, **cfg.get("params", dict())
100
+ )
101
+
102
+ def configure_optimizers(self) -> Any:
103
+ raise NotImplementedError()
104
+
105
+
106
+ class AutoencodingEngine(AbstractAutoencoder):
107
+ """
108
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
109
+ (we also restore them explicitly as special cases for legacy reasons).
110
+ Regularizations such as KL or VQ are moved to the regularizer class.
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ *args,
116
+ encoder_config: Dict,
117
+ decoder_config: Dict,
118
+ loss_config: Dict,
119
+ regularizer_config: Dict,
120
+ optimizer_config: Union[Dict, None] = None,
121
+ lr_g_factor: float = 1.0,
122
+ trainable_ae_params: Optional[List[List[str]]] = None,
123
+ ae_optimizer_args: Optional[List[dict]] = None,
124
+ trainable_disc_params: Optional[List[List[str]]] = None,
125
+ disc_optimizer_args: Optional[List[dict]] = None,
126
+ disc_start_iter: int = 0,
127
+ diff_boost_factor: float = 3.0,
128
+ ckpt_engine: Union[None, str, dict] = None,
129
+ ckpt_path: Optional[str] = None,
130
+ additional_decode_keys: Optional[List[str]] = None,
131
+ **kwargs,
132
+ ):
133
+ super().__init__(*args, **kwargs)
134
+ self.automatic_optimization = False # pytorch lightning
135
+
136
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
137
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
138
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
139
+ self.regularization: AbstractRegularizer = instantiate_from_config(
140
+ regularizer_config
141
+ )
142
+ self.optimizer_config = default(
143
+ optimizer_config, {"target": "torch.optim.Adam"}
144
+ )
145
+ self.diff_boost_factor = diff_boost_factor
146
+ self.disc_start_iter = disc_start_iter
147
+ self.lr_g_factor = lr_g_factor
148
+ self.trainable_ae_params = trainable_ae_params
149
+ if self.trainable_ae_params is not None:
150
+ self.ae_optimizer_args = default(
151
+ ae_optimizer_args,
152
+ [{} for _ in range(len(self.trainable_ae_params))],
153
+ )
154
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
155
+ else:
156
+ self.ae_optimizer_args = [{}] # makes type consistent
157
+
158
+ self.trainable_disc_params = trainable_disc_params
159
+ if self.trainable_disc_params is not None:
160
+ self.disc_optimizer_args = default(
161
+ disc_optimizer_args,
162
+ [{} for _ in range(len(self.trainable_disc_params))],
163
+ )
164
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
165
+ else:
166
+ self.disc_optimizer_args = [{}] # makes type consistent
167
+
168
+ if ckpt_path is not None:
169
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
170
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
171
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
172
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
173
+
174
+ def get_input(self, batch: Dict) -> torch.Tensor:
175
+ # assuming unified data format, dataloader returns a dict.
176
+ # image tensors should be scaled to -1 ... 1 and in channels-first
177
+ # format (e.g., bchw instead of bhwc)
178
+ return batch[self.input_key]
179
+
180
+ def get_autoencoder_params(self) -> list:
181
+ params = []
182
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
183
+ params += list(self.loss.get_trainable_autoencoder_parameters())
184
+ if hasattr(self.regularization, "get_trainable_parameters"):
185
+ params += list(self.regularization.get_trainable_parameters())
186
+ params = params + list(self.encoder.parameters())
187
+ params = params + list(self.decoder.parameters())
188
+ return params
189
+
190
+ def get_discriminator_params(self) -> list:
191
+ if hasattr(self.loss, "get_trainable_parameters"):
192
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
193
+ else:
194
+ params = []
195
+ return params
196
+
197
+ def get_last_layer(self):
198
+ return self.decoder.get_last_layer()
199
+
200
+ def encode(
201
+ self,
202
+ x: torch.Tensor,
203
+ return_reg_log: bool = False,
204
+ unregularized: bool = False,
205
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
206
+ z = self.encoder(x)
207
+ if unregularized:
208
+ return z, dict()
209
+ z, reg_log = self.regularization(z)
210
+ if return_reg_log:
211
+ return z, reg_log
212
+ return z
213
+
214
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
215
+ x = self.decoder(z, **kwargs)
216
+ return x
217
+
218
+ def forward(
219
+ self, x: torch.Tensor, **additional_decode_kwargs
220
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
221
+ z, reg_log = self.encode(x, return_reg_log=True)
222
+ dec = self.decode(z, **additional_decode_kwargs)
223
+ return z, dec, reg_log
224
+
225
+ def inner_training_step(
226
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
227
+ ) -> torch.Tensor:
228
+ x = self.get_input(batch)
229
+ additional_decode_kwargs = {
230
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
231
+ }
232
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
233
+ if hasattr(self.loss, "forward_keys"):
234
+ extra_info = {
235
+ "z": z,
236
+ "optimizer_idx": optimizer_idx,
237
+ "global_step": self.global_step,
238
+ "last_layer": self.get_last_layer(),
239
+ "split": "train",
240
+ "regularization_log": regularization_log,
241
+ "autoencoder": self,
242
+ }
243
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
244
+ else:
245
+ extra_info = dict()
246
+
247
+ if optimizer_idx == 0:
248
+ # autoencode
249
+ out_loss = self.loss(x, xrec, **extra_info)
250
+ if isinstance(out_loss, tuple):
251
+ aeloss, log_dict_ae = out_loss
252
+ else:
253
+ # simple loss function
254
+ aeloss = out_loss
255
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
256
+
257
+ self.log_dict(
258
+ log_dict_ae,
259
+ prog_bar=False,
260
+ logger=True,
261
+ on_step=True,
262
+ on_epoch=True,
263
+ sync_dist=False,
264
+ )
265
+ self.log(
266
+ "loss",
267
+ aeloss.mean().detach(),
268
+ prog_bar=True,
269
+ logger=False,
270
+ on_epoch=False,
271
+ on_step=True,
272
+ )
273
+ return aeloss
274
+ elif optimizer_idx == 1:
275
+ # discriminator
276
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
277
+ # -> discriminator always needs to return a tuple
278
+ self.log_dict(
279
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
280
+ )
281
+ return discloss
282
+ else:
283
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
284
+
285
+ def training_step(self, batch: dict, batch_idx: int):
286
+ opts = self.optimizers()
287
+ if not isinstance(opts, list):
288
+ # Non-adversarial case
289
+ opts = [opts]
290
+ optimizer_idx = batch_idx % len(opts)
291
+ if self.global_step < self.disc_start_iter:
292
+ optimizer_idx = 0
293
+ opt = opts[optimizer_idx]
294
+ opt.zero_grad()
295
+ with opt.toggle_model():
296
+ loss = self.inner_training_step(
297
+ batch, batch_idx, optimizer_idx=optimizer_idx
298
+ )
299
+ self.manual_backward(loss)
300
+ opt.step()
301
+
302
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
303
+ log_dict = self._validation_step(batch, batch_idx)
304
+ with self.ema_scope():
305
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
306
+ log_dict.update(log_dict_ema)
307
+ return log_dict
308
+
309
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
310
+ x = self.get_input(batch)
311
+
312
+ z, xrec, regularization_log = self(x)
313
+ if hasattr(self.loss, "forward_keys"):
314
+ extra_info = {
315
+ "z": z,
316
+ "optimizer_idx": 0,
317
+ "global_step": self.global_step,
318
+ "last_layer": self.get_last_layer(),
319
+ "split": "val" + postfix,
320
+ "regularization_log": regularization_log,
321
+ "autoencoder": self,
322
+ }
323
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
324
+ else:
325
+ extra_info = dict()
326
+ out_loss = self.loss(x, xrec, **extra_info)
327
+ if isinstance(out_loss, tuple):
328
+ aeloss, log_dict_ae = out_loss
329
+ else:
330
+ # simple loss function
331
+ aeloss = out_loss
332
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
333
+ full_log_dict = log_dict_ae
334
+
335
+ if "optimizer_idx" in extra_info:
336
+ extra_info["optimizer_idx"] = 1
337
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
338
+ full_log_dict.update(log_dict_disc)
339
+ self.log(
340
+ f"val{postfix}/loss/rec",
341
+ log_dict_ae[f"val{postfix}/loss/rec"],
342
+ sync_dist=True,
343
+ )
344
+ self.log_dict(full_log_dict, sync_dist=True)
345
+ return full_log_dict
346
+
347
+ def get_param_groups(
348
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
349
+ ) -> Tuple[List[Dict[str, Any]], int]:
350
+ groups = []
351
+ num_params = 0
352
+ for names, args in zip(parameter_names, optimizer_args):
353
+ params = []
354
+ for pattern_ in names:
355
+ pattern_params = []
356
+ pattern = re.compile(pattern_)
357
+ for p_name, param in self.named_parameters():
358
+ if re.match(pattern, p_name):
359
+ pattern_params.append(param)
360
+ num_params += param.numel()
361
+ if len(pattern_params) == 0:
362
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
363
+ params.extend(pattern_params)
364
+ groups.append({"params": params, **args})
365
+ return groups, num_params
366
+
367
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
368
+ if self.trainable_ae_params is None:
369
+ ae_params = self.get_autoencoder_params()
370
+ else:
371
+ ae_params, num_ae_params = self.get_param_groups(
372
+ self.trainable_ae_params, self.ae_optimizer_args
373
+ )
374
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
375
+ if self.trainable_disc_params is None:
376
+ disc_params = self.get_discriminator_params()
377
+ else:
378
+ disc_params, num_disc_params = self.get_param_groups(
379
+ self.trainable_disc_params, self.disc_optimizer_args
380
+ )
381
+ logpy.info(
382
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
383
+ )
384
+ opt_ae = self.instantiate_optimizer_from_config(
385
+ ae_params,
386
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
387
+ self.optimizer_config,
388
+ )
389
+ opts = [opt_ae]
390
+ if len(disc_params) > 0:
391
+ opt_disc = self.instantiate_optimizer_from_config(
392
+ disc_params, self.learning_rate, self.optimizer_config
393
+ )
394
+ opts.append(opt_disc)
395
+
396
+ return opts
397
+
398
+ @torch.no_grad()
399
+ def log_images(
400
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
401
+ ) -> dict:
402
+ log = dict()
403
+ additional_decode_kwargs = {}
404
+ x = self.get_input(batch)
405
+ additional_decode_kwargs.update(
406
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
407
+ )
408
+
409
+ _, xrec, _ = self(x, **additional_decode_kwargs)
410
+ log["inputs"] = x
411
+ log["reconstructions"] = xrec
412
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
413
+ diff.clamp_(0, 1.0)
414
+ log["diff"] = 2.0 * diff - 1.0
415
+ # diff_boost shows location of small errors, by boosting their
416
+ # brightness.
417
+ log["diff_boost"] = (
418
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
419
+ )
420
+ if hasattr(self.loss, "log_images"):
421
+ log.update(self.loss.log_images(x, xrec))
422
+ with self.ema_scope():
423
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
424
+ log["reconstructions_ema"] = xrec_ema
425
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
426
+ diff_ema.clamp_(0, 1.0)
427
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
428
+ log["diff_boost_ema"] = (
429
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
430
+ )
431
+ if additional_log_kwargs:
432
+ additional_decode_kwargs.update(additional_log_kwargs)
433
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
434
+ log_str = "reconstructions-" + "-".join(
435
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
436
+ )
437
+ log[log_str] = xrec_add
438
+ return log
439
+
440
+
441
+ class AutoencodingEngineLegacy(AutoencodingEngine):
442
+ def __init__(self, embed_dim: int, **kwargs):
443
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
444
+ ddconfig = kwargs.pop("ddconfig")
445
+ ckpt_path = kwargs.pop("ckpt_path", None)
446
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
447
+ super().__init__(
448
+ encoder_config={
449
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
450
+ "params": ddconfig,
451
+ },
452
+ decoder_config={
453
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
454
+ "params": ddconfig,
455
+ },
456
+ **kwargs,
457
+ )
458
+ self.quant_conv = torch.nn.Conv2d(
459
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
460
+ (1 + ddconfig["double_z"]) * embed_dim,
461
+ 1,
462
+ )
463
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
464
+ self.embed_dim = embed_dim
465
+
466
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
467
+
468
+ def get_autoencoder_params(self) -> list:
469
+ params = super().get_autoencoder_params()
470
+ return params
471
+
472
+ def encode(
473
+ self, x: torch.Tensor, return_reg_log: bool = False
474
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
475
+ if self.max_batch_size is None:
476
+ z = self.encoder(x)
477
+ z = self.quant_conv(z)
478
+ else:
479
+ N = x.shape[0]
480
+ bs = self.max_batch_size
481
+ n_batches = int(math.ceil(N / bs))
482
+ z = list()
483
+ for i_batch in range(n_batches):
484
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
485
+ z_batch = self.quant_conv(z_batch)
486
+ z.append(z_batch)
487
+ z = torch.cat(z, 0)
488
+
489
+ z, reg_log = self.regularization(z)
490
+ if return_reg_log:
491
+ return z, reg_log
492
+ return z
493
+
494
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
495
+ if self.max_batch_size is None:
496
+ dec = self.post_quant_conv(z)
497
+ dec = self.decoder(dec, **decoder_kwargs)
498
+ else:
499
+ N = z.shape[0]
500
+ bs = self.max_batch_size
501
+ n_batches = int(math.ceil(N / bs))
502
+ dec = list()
503
+ for i_batch in range(n_batches):
504
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
505
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
506
+ dec.append(dec_batch)
507
+ dec = torch.cat(dec, 0)
508
+
509
+ return dec
510
+
511
+
512
+ class AutoencoderKL(AutoencodingEngineLegacy):
513
+ def __init__(self, **kwargs):
514
+ if "lossconfig" in kwargs:
515
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
516
+ super().__init__(
517
+ regularizer_config={
518
+ "target": (
519
+ "sgm.modules.autoencoding.regularizers"
520
+ ".DiagonalGaussianRegularizer"
521
+ )
522
+ },
523
+ **kwargs,
524
+ )
525
+
526
+
527
+ class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
528
+ def __init__(
529
+ self,
530
+ embed_dim: int,
531
+ n_embed: int,
532
+ sane_index_shape: bool = False,
533
+ **kwargs,
534
+ ):
535
+ if "lossconfig" in kwargs:
536
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
537
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
538
+ super().__init__(
539
+ regularizer_config={
540
+ "target": (
541
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
542
+ ),
543
+ "params": {
544
+ "n_e": n_embed,
545
+ "e_dim": embed_dim,
546
+ "sane_index_shape": sane_index_shape,
547
+ },
548
+ },
549
+ **kwargs,
550
+ )
551
+
552
+
553
+ class IdentityFirstStage(AbstractAutoencoder):
554
+ def __init__(self, *args, **kwargs):
555
+ super().__init__(*args, **kwargs)
556
+
557
+ def get_input(self, x: Any) -> Any:
558
+ return x
559
+
560
+ def encode(self, x: Any, *args, **kwargs) -> Any:
561
+ return x
562
+
563
+ def decode(self, x: Any, *args, **kwargs) -> Any:
564
+ return x
565
+
566
+
567
+ class AEIntegerWrapper(nn.Module):
568
+ def __init__(
569
+ self,
570
+ model: nn.Module,
571
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
572
+ regularization_key: str = "regularization",
573
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
574
+ ):
575
+ super().__init__()
576
+ self.model = model
577
+ assert hasattr(model, "encode") and hasattr(
578
+ model, "decode"
579
+ ), "Need AE interface"
580
+ self.regularization = get_nested_attribute(model, regularization_key)
581
+ self.shape = shape
582
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
583
+
584
+ def encode(self, x) -> torch.Tensor:
585
+ assert (
586
+ not self.training
587
+ ), f"{self.__class__.__name__} only supports inference currently"
588
+ _, log = self.model.encode(x, **self.encoder_kwargs)
589
+ assert isinstance(log, dict)
590
+ inds = log["min_encoding_indices"]
591
+ return rearrange(inds, "b ... -> b (...)")
592
+
593
+ def decode(
594
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
595
+ ) -> torch.Tensor:
596
+ # expect inds shape (b, s) with s = h*w
597
+ shape = default(shape, self.shape) # Optional[(h, w)]
598
+ if shape is not None:
599
+ assert len(shape) == 2, f"Unhandled shape {shape}"
600
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
601
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
602
+ h = rearrange(h, "b h w c -> b c h w")
603
+ return self.model.decode(h)
604
+
605
+
606
+ class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
607
+ def __init__(self, **kwargs):
608
+ if "lossconfig" in kwargs:
609
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
610
+ super().__init__(
611
+ regularizer_config={
612
+ "target": (
613
+ "sgm.modules.autoencoding.regularizers"
614
+ ".DiagonalGaussianRegularizer"
615
+ ),
616
+ "params": {"sample": False},
617
+ },
618
+ **kwargs,
619
+ )
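
As a point of reference (illustrative only, not part of the uploaded files): the legacy `AutoencoderKL` above is normally built via `instantiate_from_config` from a config dict rather than constructed by hand. The sketch below uses assumed toy `ddconfig` values and an identity loss as a stand-in for the real LPIPS/discriminator loss, just to show the encode/decode round trip.

```python
import torch
from sgm.util import instantiate_from_config

config = {
    "target": "sgm.models.autoencoder.AutoencoderKL",
    "params": {
        "embed_dim": 4,
        "lossconfig": {"target": "torch.nn.Identity"},  # stand-in for the real LPIPS/GAN loss
        "ddconfig": {  # assumed toy settings, not the released SD/SVD values
            "double_z": True,
            "z_channels": 4,
            "resolution": 64,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 32,
            "ch_mult": [1, 2],
            "num_res_blocks": 1,
            "attn_resolutions": [],
            "dropout": 0.0,
        },
    },
}

ae = instantiate_from_config(config).eval()
x = torch.randn(1, 3, 64, 64)  # images scaled to [-1, 1], channels-first
with torch.no_grad():
    z = ae.encode(x)      # (1, 4, 32, 32): latent sampled by the KL regularizer
    x_rec = ae.decode(z)  # (1, 3, 64, 64): reconstruction in image space
```

With a real checkpoint, `ckpt_path` or `ckpt_engine` would additionally be set so that `apply_ckpt` restores the pretrained weights.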
sgm/models/diffusion.py ADDED
@@ -0,0 +1,346 @@
1
+ import math
2
+ from contextlib import contextmanager
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from omegaconf import ListConfig, OmegaConf
8
+ from safetensors.torch import load_file as load_safetensors
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+ from ..modules import UNCONDITIONAL_CONFIG
12
+ from ..modules.autoencoding.temporal_ae import VideoDecoder
13
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
14
+ from ..modules.ema import LitEma
15
+ from ..util import (
16
+ default,
17
+ disabled_train,
18
+ get_obj_from_str,
19
+ instantiate_from_config,
20
+ log_txt_as_img,
21
+ )
22
+
23
+
24
+ class DiffusionEngine(pl.LightningModule):
25
+ def __init__(
26
+ self,
27
+ network_config,
28
+ denoiser_config,
29
+ first_stage_config,
30
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
34
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
35
+ network_wrapper: Union[None, str] = None,
36
+ ckpt_path: Union[None, str] = None,
37
+ use_ema: bool = False,
38
+ ema_decay_rate: float = 0.9999,
39
+ scale_factor: float = 1.0,
40
+ disable_first_stage_autocast=False,
41
+ input_key: str = "jpg",
42
+ log_keys: Union[List, None] = None,
43
+ no_cond_log: bool = False,
44
+ compile_model: bool = False,
45
+ en_and_decode_n_samples_a_time: Optional[int] = None,
46
+ ):
47
+ super().__init__()
48
+ self.log_keys = log_keys
49
+ self.input_key = input_key
50
+ self.optimizer_config = default(
51
+ optimizer_config, {"target": "torch.optim.AdamW"}
52
+ )
53
+ model = instantiate_from_config(network_config)
54
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
55
+ model, compile_model=compile_model
56
+ )
57
+
58
+ self.denoiser = instantiate_from_config(denoiser_config)
59
+ self.sampler = (
60
+ instantiate_from_config(sampler_config)
61
+ if sampler_config is not None
62
+ else None
63
+ )
64
+ self.conditioner = instantiate_from_config(
65
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
66
+ )
67
+ self.scheduler_config = scheduler_config
68
+ self._init_first_stage(first_stage_config)
69
+
70
+ self.loss_fn = (
71
+ instantiate_from_config(loss_fn_config)
72
+ if loss_fn_config is not None
73
+ else None
74
+ )
75
+
76
+ self.use_ema = use_ema
77
+ if self.use_ema:
78
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
79
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
80
+
81
+ self.scale_factor = scale_factor
82
+ self.disable_first_stage_autocast = disable_first_stage_autocast
83
+ self.no_cond_log = no_cond_log
84
+
85
+ if ckpt_path is not None:
86
+ self.init_from_ckpt(ckpt_path)
87
+
88
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
89
+
90
+ def init_from_ckpt(
91
+ self,
92
+ path: str,
93
+ ) -> None:
94
+ if path.endswith("ckpt"):
95
+ sd = torch.load(path, map_location="cpu")["state_dict"]
96
+ elif path.endswith("safetensors"):
97
+ sd = load_safetensors(path)
98
+ else:
99
+ raise NotImplementedError
100
+
101
+ missing, unexpected = self.load_state_dict(sd, strict=False)
102
+ print(
103
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
104
+ )
105
+ if len(missing) > 0:
106
+ print(f"Missing Keys: {missing}")
107
+ if len(unexpected) > 0:
108
+ print(f"Unexpected Keys: {unexpected}")
109
+
110
+ def _init_first_stage(self, config):
111
+ model = instantiate_from_config(config).eval()
112
+ model.train = disabled_train
113
+ for param in model.parameters():
114
+ param.requires_grad = False
115
+ self.first_stage_model = model
116
+
117
+ def get_input(self, batch):
118
+ # assuming unified data format, dataloader returns a dict.
119
+ # image tensors should be scaled to -1 ... 1 and in bchw format
120
+ return batch[self.input_key]
121
+
122
+ @torch.no_grad()
123
+ def decode_first_stage(self, z):
124
+ z = 1.0 / self.scale_factor * z
125
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
126
+
127
+ n_rounds = math.ceil(z.shape[0] / n_samples)
128
+ all_out = []
129
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
130
+ for n in range(n_rounds):
131
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
132
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
133
+ else:
134
+ kwargs = {}
135
+ out = self.first_stage_model.decode(
136
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
137
+ )
138
+ all_out.append(out)
139
+ out = torch.cat(all_out, dim=0)
140
+ return out
141
+
142
+ @torch.no_grad()
143
+ def encode_first_stage(self, x):
144
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
145
+ n_rounds = math.ceil(x.shape[0] / n_samples)
146
+ all_out = []
147
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
148
+ for n in range(n_rounds):
149
+ out = self.first_stage_model.encode(
150
+ x[n * n_samples : (n + 1) * n_samples]
151
+ )
152
+ all_out.append(out)
153
+ z = torch.cat(all_out, dim=0)
154
+ z = self.scale_factor * z
155
+ return z
156
+
157
+ def forward(self, x, batch):
158
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
159
+ loss_mean = loss.mean()
160
+ loss_dict = {"loss": loss_mean}
161
+ return loss_mean, loss_dict
162
+
163
+ def shared_step(self, batch: Dict) -> Any:
164
+ x = self.get_input(batch)
165
+ x = self.encode_first_stage(x)
166
+ batch["global_step"] = self.global_step
167
+ loss, loss_dict = self(x, batch)
168
+ return loss, loss_dict
169
+
170
+ def training_step(self, batch, batch_idx):
171
+ loss, loss_dict = self.shared_step(batch)
172
+
173
+ self.log_dict(
174
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
175
+ )
176
+
177
+ self.log(
178
+ "global_step",
179
+ self.global_step,
180
+ prog_bar=True,
181
+ logger=True,
182
+ on_step=True,
183
+ on_epoch=False,
184
+ )
185
+
186
+ if self.scheduler_config is not None:
187
+ lr = self.optimizers().param_groups[0]["lr"]
188
+ self.log(
189
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
190
+ )
191
+
192
+ return loss
193
+
194
+ def on_train_start(self, *args, **kwargs):
195
+ if self.sampler is None or self.loss_fn is None:
196
+ raise ValueError("Sampler and loss function need to be set for training.")
197
+
198
+ def on_train_batch_end(self, *args, **kwargs):
199
+ if self.use_ema:
200
+ self.model_ema(self.model)
201
+
202
+ @contextmanager
203
+ def ema_scope(self, context=None):
204
+ if self.use_ema:
205
+ self.model_ema.store(self.model.parameters())
206
+ self.model_ema.copy_to(self.model)
207
+ if context is not None:
208
+ print(f"{context}: Switched to EMA weights")
209
+ try:
210
+ yield None
211
+ finally:
212
+ if self.use_ema:
213
+ self.model_ema.restore(self.model.parameters())
214
+ if context is not None:
215
+ print(f"{context}: Restored training weights")
216
+
217
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
218
+ return get_obj_from_str(cfg["target"])(
219
+ params, lr=lr, **cfg.get("params", dict())
220
+ )
221
+
222
+ def configure_optimizers(self):
223
+ lr = self.learning_rate
224
+ params = list(self.model.parameters())
225
+ for embedder in self.conditioner.embedders:
226
+ if embedder.is_trainable:
227
+ params = params + list(embedder.parameters())
228
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
229
+ if self.scheduler_config is not None:
230
+ scheduler = instantiate_from_config(self.scheduler_config)
231
+ print("Setting up LambdaLR scheduler...")
232
+ scheduler = [
233
+ {
234
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
235
+ "interval": "step",
236
+ "frequency": 1,
237
+ }
238
+ ]
239
+ return [opt], scheduler
240
+ return opt
241
+
242
+ @torch.no_grad()
243
+ def sample(
244
+ self,
245
+ cond: Dict,
246
+ uc: Union[Dict, None] = None,
247
+ batch_size: int = 16,
248
+ shape: Union[None, Tuple, List] = None,
249
+ **kwargs,
250
+ ):
251
+ randn = torch.randn(batch_size, *shape).to(self.device)
252
+
253
+ denoiser = lambda input, sigma, c: self.denoiser(
254
+ self.model, input, sigma, c, **kwargs
255
+ )
256
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
257
+ return samples
258
+
259
+ @torch.no_grad()
260
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
261
+ """
262
+ Defines heuristics to log different conditionings.
263
+ These can be lists of strings (text-to-image), tensors, ints, ...
264
+ """
265
+ image_h, image_w = batch[self.input_key].shape[2:]
266
+ log = dict()
267
+
268
+ for embedder in self.conditioner.embedders:
269
+ if (
270
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
271
+ ) and not self.no_cond_log:
272
+ x = batch[embedder.input_key][:n]
273
+ if isinstance(x, torch.Tensor):
274
+ if x.dim() == 1:
275
+ # class-conditional, convert integer to string
276
+ x = [str(x[i].item()) for i in range(x.shape[0])]
277
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
278
+ elif x.dim() == 2:
279
+ # size and crop cond and the like
280
+ x = [
281
+ "x".join([str(xx) for xx in x[i].tolist()])
282
+ for i in range(x.shape[0])
283
+ ]
284
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
285
+ else:
286
+ raise NotImplementedError()
287
+ elif isinstance(x, (List, ListConfig)):
288
+ if isinstance(x[0], str):
289
+ # strings
290
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
291
+ else:
292
+ raise NotImplementedError()
293
+ else:
294
+ raise NotImplementedError()
295
+ log[embedder.input_key] = xc
296
+ return log
297
+
298
+ @torch.no_grad()
299
+ def log_images(
300
+ self,
301
+ batch: Dict,
302
+ N: int = 8,
303
+ sample: bool = True,
304
+ ucg_keys: List[str] = None,
305
+ **kwargs,
306
+ ) -> Dict:
307
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
308
+ if ucg_keys:
309
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
310
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
311
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
312
+ )
313
+ else:
314
+ ucg_keys = conditioner_input_keys
315
+ log = dict()
316
+
317
+ x = self.get_input(batch)
318
+
319
+ c, uc = self.conditioner.get_unconditional_conditioning(
320
+ batch,
321
+ force_uc_zero_embeddings=ucg_keys
322
+ if len(self.conditioner.embedders) > 0
323
+ else [],
324
+ )
325
+
326
+ sampling_kwargs = {}
327
+
328
+ N = min(x.shape[0], N)
329
+ x = x.to(self.device)[:N]
330
+ log["inputs"] = x
331
+ z = self.encode_first_stage(x)
332
+ log["reconstructions"] = self.decode_first_stage(z)
333
+ log.update(self.log_conditionings(batch, N))
334
+
335
+ for k in c:
336
+ if isinstance(c[k], torch.Tensor):
337
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
338
+
339
+ if sample:
340
+ with self.ema_scope("Plotting"):
341
+ samples = self.sample(
342
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
343
+ )
344
+ samples = self.decode_first_stage(samples)
345
+ log["samples"] = samples
346
+ return log
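
For orientation (illustrative only, not part of the upload): `sample()` above wraps the denoiser into a closure over `(input, sigma, cond)` and hands it, together with Gaussian noise, to the configured sampler. The standalone sketch below mirrors that contract with a toy denoiser and a plain Euler loop; it is not the repo's EDM sampler or its noise schedule.

```python
import torch

def toy_denoiser(x: torch.Tensor, sigma: torch.Tensor, cond: dict) -> torch.Tensor:
    # stands in for `lambda input, sigma, c: self.denoiser(self.model, input, sigma, c)`
    return x / (1.0 + sigma.reshape(-1, 1, 1, 1) ** 2)

def toy_euler_sampler(denoiser, randn: torch.Tensor, cond: dict, uc=None) -> torch.Tensor:
    sigmas = torch.linspace(14.6, 0.03, steps=10)   # assumed toy schedule
    x = randn * sigmas[0]
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i] * torch.ones(x.shape[0])
        denoised = denoiser(x, sigma, cond)
        d = (x - denoised) / sigmas[i]              # derivative in sigma space
        x = x + d * (sigmas[i + 1] - sigmas[i])     # Euler step towards smaller sigma
    return x

randn = torch.randn(2, 4, 32, 32)                   # latent-shaped Gaussian noise
samples = toy_euler_sampler(toy_denoiser, randn, cond={})
print(samples.shape)                                # torch.Size([2, 4, 32, 32])
```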
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
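
As an aside, this `UNCONDITIONAL_CONFIG` is the fallback used by `DiffusionEngine` when no `conditioner_config` is given: it yields a `GeneralConditioner` with an empty embedder list. A minimal sketch, assuming the `sgm` package is importable:

```python
from sgm.util import instantiate_from_config
from sgm.modules import UNCONDITIONAL_CONFIG

conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)
print(len(conditioner.embedders))  # 0 -> no conditioning signal is embedded
```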
sgm/modules/attention.py ADDED
@@ -0,0 +1,759 @@
1
+ import logging
2
+ import math
3
+ from inspect import isfunction
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from packaging import version
10
+ from torch import nn
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ logpy = logging.getLogger(__name__)
14
+
15
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
16
+ SDP_IS_AVAILABLE = True
17
+ from torch.backends.cuda import SDPBackend, sdp_kernel
18
+
19
+ BACKEND_MAP = {
20
+ SDPBackend.MATH: {
21
+ "enable_math": True,
22
+ "enable_flash": False,
23
+ "enable_mem_efficient": False,
24
+ },
25
+ SDPBackend.FLASH_ATTENTION: {
26
+ "enable_math": False,
27
+ "enable_flash": True,
28
+ "enable_mem_efficient": False,
29
+ },
30
+ SDPBackend.EFFICIENT_ATTENTION: {
31
+ "enable_math": False,
32
+ "enable_flash": False,
33
+ "enable_mem_efficient": True,
34
+ },
35
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
36
+ }
37
+ else:
38
+ from contextlib import nullcontext
39
+
40
+ SDP_IS_AVAILABLE = False
41
+ sdp_kernel = nullcontext
42
+ BACKEND_MAP = {}
43
+ logpy.warn(
44
+ f"No SDP backend available, likely because you are running in pytorch "
45
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
46
+ f"You might want to consider upgrading."
47
+ )
48
+
49
+ try:
50
+ import xformers
51
+ import xformers.ops
52
+
53
+ XFORMERS_IS_AVAILABLE = True
54
+ except:
55
+ XFORMERS_IS_AVAILABLE = False
56
+ logpy.warn("no module 'xformers'. Processing without...")
57
+
58
+ # from .diffusionmodules.util import mixed_checkpoint as checkpoint
59
+
60
+
61
+ def exists(val):
62
+ return val is not None
63
+
64
+
65
+ def uniq(arr):
66
+ return {el: True for el in arr}.keys()
67
+
68
+
69
+ def default(val, d):
70
+ if exists(val):
71
+ return val
72
+ return d() if isfunction(d) else d
73
+
74
+
75
+ def max_neg_value(t):
76
+ return -torch.finfo(t.dtype).max
77
+
78
+
79
+ def init_(tensor):
80
+ dim = tensor.shape[-1]
81
+ std = 1 / math.sqrt(dim)
82
+ tensor.uniform_(-std, std)
83
+ return tensor
84
+
85
+
86
+ # feedforward
87
+ class GEGLU(nn.Module):
88
+ def __init__(self, dim_in, dim_out):
89
+ super().__init__()
90
+ self.proj = nn.Linear(dim_in, dim_out * 2)
91
+
92
+ def forward(self, x):
93
+ x, gate = self.proj(x).chunk(2, dim=-1)
94
+ return x * F.gelu(gate)
95
+
96
+
97
+ class FeedForward(nn.Module):
98
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = int(dim * mult)
101
+ dim_out = default(dim_out, dim)
102
+ project_in = (
103
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
104
+ if not glu
105
+ else GEGLU(dim, inner_dim)
106
+ )
107
+
108
+ self.net = nn.Sequential(
109
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)
114
+
115
+
116
+ def zero_module(module):
117
+ """
118
+ Zero out the parameters of a module and return it.
119
+ """
120
+ for p in module.parameters():
121
+ p.detach().zero_()
122
+ return module
123
+
124
+
125
+ def Normalize(in_channels):
126
+ return torch.nn.GroupNorm(
127
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
128
+ )
129
+
130
+
131
+ class LinearAttention(nn.Module):
132
+ def __init__(self, dim, heads=4, dim_head=32):
133
+ super().__init__()
134
+ self.heads = heads
135
+ hidden_dim = dim_head * heads
136
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
137
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
138
+
139
+ def forward(self, x):
140
+ b, c, h, w = x.shape
141
+ qkv = self.to_qkv(x)
142
+ q, k, v = rearrange(
143
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
144
+ )
145
+ k = k.softmax(dim=-1)
146
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
147
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
148
+ out = rearrange(
149
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
150
+ )
151
+ return self.to_out(out)
152
+
153
+
154
+ class SelfAttention(nn.Module):
155
+ ATTENTION_MODES = ("xformers", "torch", "math")
156
+
157
+ def __init__(
158
+ self,
159
+ dim: int,
160
+ num_heads: int = 8,
161
+ qkv_bias: bool = False,
162
+ qk_scale: Optional[float] = None,
163
+ attn_drop: float = 0.0,
164
+ proj_drop: float = 0.0,
165
+ attn_mode: str = "xformers",
166
+ ):
167
+ super().__init__()
168
+ self.num_heads = num_heads
169
+ head_dim = dim // num_heads
170
+ self.scale = qk_scale or head_dim**-0.5
171
+
172
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ assert attn_mode in self.ATTENTION_MODES
177
+ self.attn_mode = attn_mode
178
+
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ B, L, C = x.shape
181
+
182
+ qkv = self.qkv(x)
183
+ if self.attn_mode == "torch":
184
+ qkv = rearrange(
185
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
186
+ ).float()
187
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
188
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
189
+ x = rearrange(x, "B H L D -> B L (H D)")
190
+ elif self.attn_mode == "xformers":
191
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
192
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
193
+ x = xformers.ops.memory_efficient_attention(q, k, v)
194
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
195
+ elif self.attn_mode == "math":
196
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
197
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
198
+ attn = (q @ k.transpose(-2, -1)) * self.scale
199
+ attn = attn.softmax(dim=-1)
200
+ attn = self.attn_drop(attn)
201
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
202
+ else:
203
+ raise NotImplementedError(f"Unknown attention mode: {self.attn_mode}")
204
+
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ class SpatialSelfAttention(nn.Module):
211
+ def __init__(self, in_channels):
212
+ super().__init__()
213
+ self.in_channels = in_channels
214
+
215
+ self.norm = Normalize(in_channels)
216
+ self.q = torch.nn.Conv2d(
217
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
218
+ )
219
+ self.k = torch.nn.Conv2d(
220
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
221
+ )
222
+ self.v = torch.nn.Conv2d(
223
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
224
+ )
225
+ self.proj_out = torch.nn.Conv2d(
226
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
227
+ )
228
+
229
+ def forward(self, x):
230
+ h_ = x
231
+ h_ = self.norm(h_)
232
+ q = self.q(h_)
233
+ k = self.k(h_)
234
+ v = self.v(h_)
235
+
236
+ # compute attention
237
+ b, c, h, w = q.shape
238
+ q = rearrange(q, "b c h w -> b (h w) c")
239
+ k = rearrange(k, "b c h w -> b c (h w)")
240
+ w_ = torch.einsum("bij,bjk->bik", q, k)
241
+
242
+ w_ = w_ * (int(c) ** (-0.5))
243
+ w_ = torch.nn.functional.softmax(w_, dim=2)
244
+
245
+ # attend to values
246
+ v = rearrange(v, "b c h w -> b c (h w)")
247
+ w_ = rearrange(w_, "b i j -> b j i")
248
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
249
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
250
+ h_ = self.proj_out(h_)
251
+
252
+ return x + h_
253
+
254
+
255
+ class CrossAttention(nn.Module):
256
+ def __init__(
257
+ self,
258
+ query_dim,
259
+ context_dim=None,
260
+ heads=8,
261
+ dim_head=64,
262
+ dropout=0.0,
263
+ backend=None,
264
+ ):
265
+ super().__init__()
266
+ inner_dim = dim_head * heads
267
+ context_dim = default(context_dim, query_dim)
268
+
269
+ self.scale = dim_head**-0.5
270
+ self.heads = heads
271
+
272
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
273
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
274
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
275
+
276
+ self.to_out = nn.Sequential(
277
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
278
+ )
279
+ self.backend = backend
280
+
281
+ def forward(
282
+ self,
283
+ x,
284
+ context=None,
285
+ mask=None,
286
+ additional_tokens=None,
287
+ n_times_crossframe_attn_in_self=0,
288
+ ):
289
+ h = self.heads
290
+
291
+ if additional_tokens is not None:
292
+ # get the number of masked tokens at the beginning of the output sequence
293
+ n_tokens_to_mask = additional_tokens.shape[1]
294
+ # add additional token
295
+ x = torch.cat([additional_tokens, x], dim=1)
296
+
297
+ q = self.to_q(x)
298
+ context = default(context, x)
299
+ k = self.to_k(context)
300
+ v = self.to_v(context)
301
+
302
+ if n_times_crossframe_attn_in_self:
303
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
304
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
305
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
306
+ k = repeat(
307
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
308
+ )
309
+ v = repeat(
310
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
311
+ )
312
+
313
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
314
+
315
+ ## old
316
+ """
317
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
318
+ del q, k
319
+
320
+ if exists(mask):
321
+ mask = rearrange(mask, 'b ... -> b (...)')
322
+ max_neg_value = -torch.finfo(sim.dtype).max
323
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
324
+ sim.masked_fill_(~mask, max_neg_value)
325
+
326
+ # attention, what we cannot get enough of
327
+ sim = sim.softmax(dim=-1)
328
+
329
+ out = einsum('b i j, b j d -> b i d', sim, v)
330
+ """
331
+ ## new
332
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
333
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
334
+ out = F.scaled_dot_product_attention(
335
+ q, k, v, attn_mask=mask
336
+ ) # scale is dim_head ** -0.5 per default
337
+
338
+ del q, k, v
339
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
340
+
341
+ if additional_tokens is not None:
342
+ # remove additional token
343
+ out = out[:, n_tokens_to_mask:]
344
+ return self.to_out(out)
345
+
346
+
347
+ class MemoryEfficientCrossAttention(nn.Module):
348
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
349
+ def __init__(
350
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
351
+ ):
352
+ super().__init__()
353
+ logpy.debug(
354
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
355
+ f"context_dim is {context_dim} and using {heads} heads with a "
356
+ f"dimension of {dim_head}."
357
+ )
358
+ inner_dim = dim_head * heads
359
+ context_dim = default(context_dim, query_dim)
360
+
361
+ self.heads = heads
362
+ self.dim_head = dim_head
363
+
364
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
365
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
366
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
367
+
368
+ self.to_out = nn.Sequential(
369
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
370
+ )
371
+ self.attention_op: Optional[Any] = None
372
+
373
+ def forward(
374
+ self,
375
+ x,
376
+ context=None,
377
+ mask=None,
378
+ additional_tokens=None,
379
+ n_times_crossframe_attn_in_self=0,
380
+ ):
381
+ if additional_tokens is not None:
382
+ # get the number of masked tokens at the beginning of the output sequence
383
+ n_tokens_to_mask = additional_tokens.shape[1]
384
+ # add additional token
385
+ x = torch.cat([additional_tokens, x], dim=1)
386
+ q = self.to_q(x)
387
+ context = default(context, x)
388
+ k = self.to_k(context)
389
+ v = self.to_v(context)
390
+
391
+ if n_times_crossframe_attn_in_self:
392
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
393
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
394
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
395
+ k = repeat(
396
+ k[::n_times_crossframe_attn_in_self],
397
+ "b ... -> (b n) ...",
398
+ n=n_times_crossframe_attn_in_self,
399
+ )
400
+ v = repeat(
401
+ v[::n_times_crossframe_attn_in_self],
402
+ "b ... -> (b n) ...",
403
+ n=n_times_crossframe_attn_in_self,
404
+ )
405
+
406
+ b, _, _ = q.shape
407
+ q, k, v = map(
408
+ lambda t: t.unsqueeze(3)
409
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
410
+ .permute(0, 2, 1, 3)
411
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
412
+ .contiguous(),
413
+ (q, k, v),
414
+ )
415
+
416
+ # actually compute the attention, what we cannot get enough of
417
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
418
+ # NOTE: workaround for
419
+ # https://github.com/facebookresearch/xformers/issues/845
420
+ max_bs = 32768
421
+ N = q.shape[0]
422
+ n_batches = math.ceil(N / max_bs)
423
+ out = list()
424
+ for i_batch in range(n_batches):
425
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
426
+ out.append(
427
+ xformers.ops.memory_efficient_attention(
428
+ q[batch],
429
+ k[batch],
430
+ v[batch],
431
+ attn_bias=None,
432
+ op=self.attention_op,
433
+ )
434
+ )
435
+ out = torch.cat(out, 0)
436
+ else:
437
+ out = xformers.ops.memory_efficient_attention(
438
+ q, k, v, attn_bias=None, op=self.attention_op
439
+ )
440
+
441
+ # TODO: Use this directly in the attention operation, as a bias
442
+ if exists(mask):
443
+ raise NotImplementedError
444
+ out = (
445
+ out.unsqueeze(0)
446
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
447
+ .permute(0, 2, 1, 3)
448
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
449
+ )
450
+ if additional_tokens is not None:
451
+ # remove additional token
452
+ out = out[:, n_tokens_to_mask:]
453
+ return self.to_out(out)
454
+
455
+
456
+ class BasicTransformerBlock(nn.Module):
457
+ ATTENTION_MODES = {
458
+ "softmax": CrossAttention, # vanilla attention
459
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
460
+ }
461
+
462
+ def __init__(
463
+ self,
464
+ dim,
465
+ n_heads,
466
+ d_head,
467
+ dropout=0.0,
468
+ context_dim=None,
469
+ gated_ff=True,
470
+ checkpoint=True,
471
+ disable_self_attn=False,
472
+ attn_mode="softmax",
473
+ sdp_backend=None,
474
+ ):
475
+ super().__init__()
476
+ assert attn_mode in self.ATTENTION_MODES
477
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
478
+ logpy.warn(
479
+ f"Attention mode '{attn_mode}' is not available. Falling "
480
+ f"back to native attention. This is not a problem in "
481
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
482
+ f"version {torch.__version__}."
483
+ )
484
+ attn_mode = "softmax"
485
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
486
+ logpy.warn(
487
+ "We do not support vanilla attention anymore, as it is too "
488
+ "expensive. Sorry."
489
+ )
490
+ if not XFORMERS_IS_AVAILABLE:
491
+ assert (
492
+ False
493
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
494
+ else:
495
+ logpy.info("Falling back to xformers efficient attention.")
496
+ attn_mode = "softmax-xformers"
497
+ attn_cls = self.ATTENTION_MODES[attn_mode]
498
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
499
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
500
+ else:
501
+ assert sdp_backend is None
502
+ self.disable_self_attn = disable_self_attn
503
+ self.attn1 = attn_cls(
504
+ query_dim=dim,
505
+ heads=n_heads,
506
+ dim_head=d_head,
507
+ dropout=dropout,
508
+ context_dim=context_dim if self.disable_self_attn else None,
509
+ backend=sdp_backend,
510
+ ) # is a self-attention if not self.disable_self_attn
511
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
512
+ self.attn2 = attn_cls(
513
+ query_dim=dim,
514
+ context_dim=context_dim,
515
+ heads=n_heads,
516
+ dim_head=d_head,
517
+ dropout=dropout,
518
+ backend=sdp_backend,
519
+ ) # is self-attn if context is none
520
+ self.norm1 = nn.LayerNorm(dim)
521
+ self.norm2 = nn.LayerNorm(dim)
522
+ self.norm3 = nn.LayerNorm(dim)
523
+ self.checkpoint = checkpoint
524
+ if self.checkpoint:
525
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
526
+
527
+ def forward(
528
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
529
+ ):
530
+ kwargs = {"x": x}
531
+
532
+ if context is not None:
533
+ kwargs.update({"context": context})
534
+
535
+ if additional_tokens is not None:
536
+ kwargs.update({"additional_tokens": additional_tokens})
537
+
538
+ if n_times_crossframe_attn_in_self:
539
+ kwargs.update(
540
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
541
+ )
542
+
543
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
544
+ if self.checkpoint:
545
+ # inputs = {"x": x, "context": context}
546
+ return checkpoint(self._forward, x, context)
547
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
548
+ else:
549
+ return self._forward(**kwargs)
550
+
551
+ def _forward(
552
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
553
+ ):
554
+ x = (
555
+ self.attn1(
556
+ self.norm1(x),
557
+ context=context if self.disable_self_attn else None,
558
+ additional_tokens=additional_tokens,
559
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
560
+ if not self.disable_self_attn
561
+ else 0,
562
+ )
563
+ + x
564
+ )
565
+ x = (
566
+ self.attn2(
567
+ self.norm2(x), context=context, additional_tokens=additional_tokens
568
+ )
569
+ + x
570
+ )
571
+ x = self.ff(self.norm3(x)) + x
572
+ return x
573
+
574
+
575
+ class BasicTransformerSingleLayerBlock(nn.Module):
576
+ ATTENTION_MODES = {
577
+ "softmax": CrossAttention, # vanilla attention
578
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
579
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
580
+ }
581
+
582
+ def __init__(
583
+ self,
584
+ dim,
585
+ n_heads,
586
+ d_head,
587
+ dropout=0.0,
588
+ context_dim=None,
589
+ gated_ff=True,
590
+ checkpoint=True,
591
+ attn_mode="softmax",
592
+ ):
593
+ super().__init__()
594
+ assert attn_mode in self.ATTENTION_MODES
595
+ attn_cls = self.ATTENTION_MODES[attn_mode]
596
+ self.attn1 = attn_cls(
597
+ query_dim=dim,
598
+ heads=n_heads,
599
+ dim_head=d_head,
600
+ dropout=dropout,
601
+ context_dim=context_dim,
602
+ )
603
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
604
+ self.norm1 = nn.LayerNorm(dim)
605
+ self.norm2 = nn.LayerNorm(dim)
606
+ self.checkpoint = checkpoint
607
+
608
+ def forward(self, x, context=None):
609
+ # inputs = {"x": x, "context": context}
610
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
611
+ return checkpoint(self._forward, x, context)
612
+
613
+ def _forward(self, x, context=None):
614
+ x = self.attn1(self.norm1(x), context=context) + x
615
+ x = self.ff(self.norm2(x)) + x
616
+ return x
617
+
618
+
619
+ class SpatialTransformer(nn.Module):
620
+ """
621
+ Transformer block for image-like data.
622
+ First, project the input (aka embedding)
623
+ and reshape to b, t, d.
624
+ Then apply standard transformer action.
625
+ Finally, reshape to image
626
+ NEW: use_linear for more efficiency instead of the 1x1 convs
627
+ """
628
+
629
+ def __init__(
630
+ self,
631
+ in_channels,
632
+ n_heads,
633
+ d_head,
634
+ depth=1,
635
+ dropout=0.0,
636
+ context_dim=None,
637
+ disable_self_attn=False,
638
+ use_linear=False,
639
+ attn_type="softmax",
640
+ use_checkpoint=True,
641
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
642
+ sdp_backend=None,
643
+ ):
644
+ super().__init__()
645
+ logpy.debug(
646
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
647
+ f"{in_channels} channels and {n_heads} heads."
648
+ )
649
+
650
+ if exists(context_dim) and not isinstance(context_dim, list):
651
+ context_dim = [context_dim]
652
+ if exists(context_dim) and isinstance(context_dim, list):
653
+ if depth != len(context_dim):
654
+ logpy.warn(
655
+ f"{self.__class__.__name__}: Found context dims "
656
+ f"{context_dim} of depth {len(context_dim)}, which does not "
657
+ f"match the specified 'depth' of {depth}. Setting context_dim "
658
+ f"to {depth * [context_dim[0]]} now."
659
+ )
660
+ # depth does not match context dims.
661
+ assert all(
662
+ map(lambda x: x == context_dim[0], context_dim)
663
+ ), "need homogenous context_dim to match depth automatically"
664
+ context_dim = depth * [context_dim[0]]
665
+ elif context_dim is None:
666
+ context_dim = [None] * depth
667
+ self.in_channels = in_channels
668
+ inner_dim = n_heads * d_head
669
+ self.norm = Normalize(in_channels)
670
+ if not use_linear:
671
+ self.proj_in = nn.Conv2d(
672
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
673
+ )
674
+ else:
675
+ self.proj_in = nn.Linear(in_channels, inner_dim)
676
+
677
+ self.transformer_blocks = nn.ModuleList(
678
+ [
679
+ BasicTransformerBlock(
680
+ inner_dim,
681
+ n_heads,
682
+ d_head,
683
+ dropout=dropout,
684
+ context_dim=context_dim[d],
685
+ disable_self_attn=disable_self_attn,
686
+ attn_mode=attn_type,
687
+ checkpoint=use_checkpoint,
688
+ sdp_backend=sdp_backend,
689
+ )
690
+ for d in range(depth)
691
+ ]
692
+ )
693
+ if not use_linear:
694
+ self.proj_out = zero_module(
695
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
696
+ )
697
+ else:
698
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
699
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
700
+ self.use_linear = use_linear
701
+
702
+ def forward(self, x, context=None):
703
+ # note: if no context is given, cross-attention defaults to self-attention
704
+ if not isinstance(context, list):
705
+ context = [context]
706
+ b, c, h, w = x.shape
707
+ x_in = x
708
+ x = self.norm(x)
709
+ if not self.use_linear:
710
+ x = self.proj_in(x)
711
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
712
+ if self.use_linear:
713
+ x = self.proj_in(x)
714
+ for i, block in enumerate(self.transformer_blocks):
715
+ if i > 0 and len(context) == 1:
716
+ i = 0 # use same context for each block
717
+ x = block(x, context=context[i])
718
+ if self.use_linear:
719
+ x = self.proj_out(x)
720
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
721
+ if not self.use_linear:
722
+ x = self.proj_out(x)
723
+ return x + x_in
724
+
725
+
726
+ class SimpleTransformer(nn.Module):
727
+ def __init__(
728
+ self,
729
+ dim: int,
730
+ depth: int,
731
+ heads: int,
732
+ dim_head: int,
733
+ context_dim: Optional[int] = None,
734
+ dropout: float = 0.0,
735
+ checkpoint: bool = True,
736
+ ):
737
+ super().__init__()
738
+ self.layers = nn.ModuleList([])
739
+ for _ in range(depth):
740
+ self.layers.append(
741
+ BasicTransformerBlock(
742
+ dim,
743
+ heads,
744
+ dim_head,
745
+ dropout=dropout,
746
+ context_dim=context_dim,
747
+ attn_mode="softmax-xformers",
748
+ checkpoint=checkpoint,
749
+ )
750
+ )
751
+
752
+ def forward(
753
+ self,
754
+ x: torch.Tensor,
755
+ context: Optional[torch.Tensor] = None,
756
+ ) -> torch.Tensor:
757
+ for layer in self.layers:
758
+ x = layer(x, context)
759
+ return x
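
To make the tensor layouts above concrete, the following sketch (not part of the upload, and assuming PyTorch >= 2.0 so the scaled-dot-product path is taken) runs a `CrossAttention` layer and a `SpatialTransformer` block on dummy data:

```python
import torch
from sgm.modules.attention import CrossAttention, SpatialTransformer

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 64, 320)        # (batch, tokens, query_dim)
ctx = torch.randn(2, 77, 768)      # e.g. text-encoder tokens used as context
print(attn(x, context=ctx).shape)  # torch.Size([2, 64, 320])

block = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40, depth=1,
    context_dim=768, use_linear=True, attn_type="softmax",
    use_checkpoint=False,              # skip gradient checkpointing for this demo
)
feat = torch.randn(2, 320, 8, 8)       # image-like feature map (b, c, h, w)
print(block(feat, context=ctx).shape)  # torch.Size([2, 320, 8, 8])
```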
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ __all__ = [
2
+ "GeneralLPIPSWithDiscriminator",
3
+ "LatentLPIPS",
4
+ ]
5
+
6
+ from .discriminator_loss import GeneralLPIPSWithDiscriminator
7
+ from .lpips import LatentLPIPS
sgm/modules/autoencoding/losses/discriminator_loss.py ADDED
@@ -0,0 +1,306 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from einops import rearrange
8
+ from matplotlib import colormaps
9
+ from matplotlib import pyplot as plt
10
+
11
+ from ....util import default, instantiate_from_config
12
+ from ..lpips.loss.lpips import LPIPS
13
+ from ..lpips.model.model import weights_init
14
+ from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
15
+
16
+
17
+ class GeneralLPIPSWithDiscriminator(nn.Module):
18
+ def __init__(
19
+ self,
20
+ disc_start: int,
21
+ logvar_init: float = 0.0,
22
+ disc_num_layers: int = 3,
23
+ disc_in_channels: int = 3,
24
+ disc_factor: float = 1.0,
25
+ disc_weight: float = 1.0,
26
+ perceptual_weight: float = 1.0,
27
+ disc_loss: str = "hinge",
28
+ scale_input_to_tgt_size: bool = False,
29
+ dims: int = 2,
30
+ learn_logvar: bool = False,
31
+ regularization_weights: Union[None, Dict[str, float]] = None,
32
+ additional_log_keys: Optional[List[str]] = None,
33
+ discriminator_config: Optional[Dict] = None,
34
+ ):
35
+ super().__init__()
36
+ self.dims = dims
37
+ if self.dims > 2:
38
+ print(
39
+ f"running with dims={dims}. This means that for perceptual loss "
40
+ f"calculation, the LPIPS loss will be applied to each frame "
41
+ f"independently."
42
+ )
43
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
44
+ assert disc_loss in ["hinge", "vanilla"]
45
+ self.perceptual_loss = LPIPS().eval()
46
+ self.perceptual_weight = perceptual_weight
47
+ # output log variance
48
+ self.logvar = nn.Parameter(
49
+ torch.full((), logvar_init), requires_grad=learn_logvar
50
+ )
51
+ self.learn_logvar = learn_logvar
52
+
53
+ discriminator_config = default(
54
+ discriminator_config,
55
+ {
56
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
57
+ "params": {
58
+ "input_nc": disc_in_channels,
59
+ "n_layers": disc_num_layers,
60
+ "use_actnorm": False,
61
+ },
62
+ },
63
+ )
64
+
65
+ self.discriminator = instantiate_from_config(discriminator_config).apply(
66
+ weights_init
67
+ )
68
+ self.discriminator_iter_start = disc_start
69
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
70
+ self.disc_factor = disc_factor
71
+ self.discriminator_weight = disc_weight
72
+ self.regularization_weights = default(regularization_weights, {})
73
+
74
+ self.forward_keys = [
75
+ "optimizer_idx",
76
+ "global_step",
77
+ "last_layer",
78
+ "split",
79
+ "regularization_log",
80
+ ]
81
+
82
+ self.additional_log_keys = set(default(additional_log_keys, []))
83
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
84
+
85
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
86
+ return self.discriminator.parameters()
87
+
88
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
89
+ if self.learn_logvar:
90
+ yield self.logvar
91
+ yield from ()
92
+
93
+ @torch.no_grad()
94
+ def log_images(
95
+ self, inputs: torch.Tensor, reconstructions: torch.Tensor
96
+ ) -> Dict[str, torch.Tensor]:
97
+ # calc logits of real/fake
98
+ logits_real = self.discriminator(inputs.contiguous().detach())
99
+ if len(logits_real.shape) < 4:
100
+ # Non patch-discriminator
101
+ return dict()
102
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
103
+ # -> (b, 1, h, w)
104
+
105
+ # parameters for colormapping
106
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
107
+ cmap = colormaps["PiYG"] # diverging colormap
108
+
109
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
110
+ """(b, 1, ...) -> (b, 3, ...)"""
111
+ logits = (logits + high) / (2 * high)
112
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
113
+ # -> (b, 1, ..., 3)
114
+ logits = torch.from_numpy(logits_np).to(logits.device)
115
+ return rearrange(logits, "b 1 ... c -> b c ...")
116
+
117
+ logits_real = torch.nn.functional.interpolate(
118
+ logits_real,
119
+ size=inputs.shape[-2:],
120
+ mode="nearest",
121
+ antialias=False,
122
+ )
123
+ logits_fake = torch.nn.functional.interpolate(
124
+ logits_fake,
125
+ size=reconstructions.shape[-2:],
126
+ mode="nearest",
127
+ antialias=False,
128
+ )
129
+
130
+ # alpha value of logits for overlay
131
+ alpha_real = torch.abs(logits_real) / high
132
+ alpha_fake = torch.abs(logits_fake) / high
133
+ # -> (b, 1, h, w) in range [0, 0.5]
134
+ # alpha value of lines don't really matter, since the values are the same
135
+ # for both images and logits anyway
136
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
137
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
138
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
139
+ # -> (1, h, w)
140
+ # blend logits and images together
141
+
142
+ # prepare logits for plotting
143
+ logits_real = to_colormap(logits_real)
144
+ logits_fake = to_colormap(logits_fake)
145
+ # resize logits
146
+ # -> (b, 3, h, w)
147
+
148
+ # make some grids
149
+ # add all logits to one plot
150
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
151
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
152
+ # I just love how torchvision calls the number of columns `nrow`
153
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
154
+ # -> (3, h, w)
155
+
156
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
157
+ grid_images_fake = torchvision.utils.make_grid(
158
+ 0.5 * reconstructions + 0.5, nrow=4
159
+ )
160
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
161
+ # -> (3, h, w) in range [0, 1]
162
+
163
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
164
+
165
+ # Create labeled colorbar
166
+ dpi = 100
167
+ height = 128 / dpi
168
+ width = grid_logits.shape[2] / dpi
169
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
170
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
171
+ plt.colorbar(
172
+ img,
173
+ cax=ax,
174
+ orientation="horizontal",
175
+ fraction=0.9,
176
+ aspect=width / height,
177
+ pad=0.0,
178
+ )
179
+ img.set_visible(False)
180
+ fig.tight_layout()
181
+ fig.canvas.draw()
182
+ # manually convert figure to numpy
183
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
184
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
185
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
186
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
187
+
188
+ # Add colorbar to plot
189
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
190
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
191
+ return {
192
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
193
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
194
+ }
195
+
196
+ def calculate_adaptive_weight(
197
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
198
+ ) -> torch.Tensor:
199
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
200
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
201
+
202
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
203
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
204
+ d_weight = d_weight * self.discriminator_weight
205
+ return d_weight
206
+
207
+ def forward(
208
+ self,
209
+ inputs: torch.Tensor,
210
+ reconstructions: torch.Tensor,
211
+ *, # added because I changed the order here
212
+ regularization_log: Dict[str, torch.Tensor],
213
+ optimizer_idx: int,
214
+ global_step: int,
215
+ last_layer: torch.Tensor,
216
+ split: str = "train",
217
+ weights: Union[None, float, torch.Tensor] = None,
218
+ ) -> Tuple[torch.Tensor, dict]:
219
+ if self.scale_input_to_tgt_size:
220
+ inputs = torch.nn.functional.interpolate(
221
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
222
+ )
223
+
224
+ if self.dims > 2:
225
+ inputs, reconstructions = map(
226
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
227
+ (inputs, reconstructions),
228
+ )
229
+
230
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
231
+ if self.perceptual_weight > 0:
232
+ p_loss = self.perceptual_loss(
233
+ inputs.contiguous(), reconstructions.contiguous()
234
+ )
235
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
236
+
237
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
238
+
239
+ # now the GAN part
240
+ if optimizer_idx == 0:
241
+ # generator update
242
+ if global_step >= self.discriminator_iter_start or not self.training:
243
+ logits_fake = self.discriminator(reconstructions.contiguous())
244
+ g_loss = -torch.mean(logits_fake)
245
+ if self.training:
246
+ d_weight = self.calculate_adaptive_weight(
247
+ nll_loss, g_loss, last_layer=last_layer
248
+ )
249
+ else:
250
+ d_weight = torch.tensor(1.0)
251
+ else:
252
+ d_weight = torch.tensor(0.0)
253
+ g_loss = torch.tensor(0.0, requires_grad=True)
254
+
255
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
256
+ log = dict()
257
+ for k in regularization_log:
258
+ if k in self.regularization_weights:
259
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
260
+ if k in self.additional_log_keys:
261
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
262
+
263
+ log.update(
264
+ {
265
+ f"{split}/loss/total": loss.clone().detach().mean(),
266
+ f"{split}/loss/nll": nll_loss.detach().mean(),
267
+ f"{split}/loss/rec": rec_loss.detach().mean(),
268
+ f"{split}/loss/g": g_loss.detach().mean(),
269
+ f"{split}/scalars/logvar": self.logvar.detach(),
270
+ f"{split}/scalars/d_weight": d_weight.detach(),
271
+ }
272
+ )
273
+
274
+ return loss, log
275
+ elif optimizer_idx == 1:
276
+ # second pass for discriminator update
277
+ logits_real = self.discriminator(inputs.contiguous().detach())
278
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
279
+
280
+ if global_step >= self.discriminator_iter_start or not self.training:
281
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
282
+ else:
283
+ d_loss = torch.tensor(0.0, requires_grad=True)
284
+
285
+ log = {
286
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
287
+ f"{split}/logits/real": logits_real.detach().mean(),
288
+ f"{split}/logits/fake": logits_fake.detach().mean(),
289
+ }
290
+ return d_loss, log
291
+ else:
292
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
293
+
294
+ def get_nll_loss(
295
+ self,
296
+ rec_loss: torch.Tensor,
297
+ weights: Optional[Union[float, torch.Tensor]] = None,
298
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
299
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
300
+ weighted_nll_loss = nll_loss
301
+ if weights is not None:
302
+ weighted_nll_loss = weights * nll_loss
303
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
304
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
305
+
306
+ return nll_loss, weighted_nll_loss
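For orientation, here is a minimal sketch of how a loss like `GeneralLPIPSWithDiscriminator` is typically driven from a two-optimizer autoencoder training step. The `ae`, `opt_ae`, `opt_disc`, and `ae.get_last_layer()` names below are hypothetical stand-ins and are not defined in this commit; only the keyword-only arguments mirror the `forward` signature above.

```python
# Sketch only: `ae`, `opt_ae`, `opt_disc` and `ae.get_last_layer()` are
# hypothetical stand-ins for an autoencoder with encode/decode and two optimizers.
from sgm.modules.autoencoding.losses import GeneralLPIPSWithDiscriminator

loss_fn = GeneralLPIPSWithDiscriminator(disc_start=50_001, disc_weight=0.5)


def training_step(ae, x, global_step, opt_ae, opt_disc):
    z = ae.encode(x)
    reconstructions = ae.decode(z)
    reg_log = {}  # regularization terms (e.g. a KL term) would be logged here

    # pass 1: autoencoder / generator update (optimizer_idx=0)
    aeloss, log = loss_fn(
        x,
        reconstructions,
        regularization_log=reg_log,
        optimizer_idx=0,
        global_step=global_step,
        last_layer=ae.get_last_layer(),  # e.g. the final decoder conv weight
        split="train",
    )
    opt_ae.zero_grad()
    aeloss.backward()
    opt_ae.step()

    # pass 2: discriminator update (optimizer_idx=1); the loss detaches
    # inputs and reconstructions internally before scoring them
    discloss, log_disc = loss_fn(
        x,
        reconstructions.detach(),
        regularization_log=reg_log,
        optimizer_idx=1,
        global_step=global_step,
        last_layer=ae.get_last_layer(),
        split="train",
    )
    opt_disc.zero_grad()
    discloss.backward()
    opt_disc.step()
    return {**log, **log_disc}
```

The adaptive `d_weight` computed in pass 1 balances the gradient magnitudes of the reconstruction and adversarial terms at the given `last_layer`, which is why a real parameter from the decoder has to be passed in.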
sgm/modules/autoencoding/losses/lpips.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+
+ from ....util import default, instantiate_from_config
+ from ..lpips.loss.lpips import LPIPS
+
+
+ class LatentLPIPS(nn.Module):
+     def __init__(
+         self,
+         decoder_config,
+         perceptual_weight=1.0,
+         latent_weight=1.0,
+         scale_input_to_tgt_size=False,
+         scale_tgt_to_input_size=False,
+         perceptual_weight_on_inputs=0.0,
+     ):
+         super().__init__()
+         self.scale_input_to_tgt_size = scale_input_to_tgt_size
+         self.scale_tgt_to_input_size = scale_tgt_to_input_size
+         self.init_decoder(decoder_config)
+         self.perceptual_loss = LPIPS().eval()
+         self.perceptual_weight = perceptual_weight
+         self.latent_weight = latent_weight
+         self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+     def init_decoder(self, config):
+         self.decoder = instantiate_from_config(config)
+         if hasattr(self.decoder, "encoder"):
+             del self.decoder.encoder
+
+     def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+         log = dict()
+         loss = (latent_inputs - latent_predictions) ** 2
+         log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+         image_reconstructions = None
+         if self.perceptual_weight > 0.0:
+             image_reconstructions = self.decoder.decode(latent_predictions)
+             image_targets = self.decoder.decode(latent_inputs)
+             perceptual_loss = self.perceptual_loss(
+                 image_targets.contiguous(), image_reconstructions.contiguous()
+             )
+             loss = (
+                 self.latent_weight * loss.mean()
+                 + self.perceptual_weight * perceptual_loss.mean()
+             )
+             log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+         if self.perceptual_weight_on_inputs > 0.0:
+             image_reconstructions = default(
+                 image_reconstructions, self.decoder.decode(latent_predictions)
+             )
+             if self.scale_input_to_tgt_size:
+                 image_inputs = torch.nn.functional.interpolate(
+                     image_inputs,
+                     image_reconstructions.shape[2:],
+                     mode="bicubic",
+                     antialias=True,
+                 )
+             elif self.scale_tgt_to_input_size:
+                 image_reconstructions = torch.nn.functional.interpolate(
+                     image_reconstructions,
+                     image_inputs.shape[2:],
+                     mode="bicubic",
+                     antialias=True,
+                 )
+
+             perceptual_loss2 = self.perceptual_loss(
+                 image_inputs.contiguous(), image_reconstructions.contiguous()
+             )
+             loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+             log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+         return loss, log
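`LatentLPIPS` compares two batches of latents directly (squared error) and, when `perceptual_weight > 0`, additionally decodes both through the configured decoder and scores the decodings with LPIPS. The sketch below exercises only the latent L2 term: the decoder is a `torch.nn.Identity` stand-in that is never asked to decode because both perceptual weights are zero; in real use `decoder_config` would point at the frozen first-stage decoder.

```python
# Sketch only: exercises just the latent L2 term of LatentLPIPS.
# torch.nn.Identity is a stand-in decoder, not a real first-stage model.
import torch

from sgm.modules.autoencoding.losses import LatentLPIPS

loss_fn = LatentLPIPS(
    decoder_config={"target": "torch.nn.Identity"},
    perceptual_weight=0.0,
    perceptual_weight_on_inputs=0.0,
)

teacher_latents = torch.randn(2, 4, 32, 32)
student_latents = teacher_latents + 0.1 * torch.randn_like(teacher_latents)
images = torch.randn(2, 3, 256, 256)  # unused while both perceptual weights are 0

loss, log = loss_fn(teacher_latents, student_latents, images, split="train")
# with perceptual_weight == 0 the returned loss is element-wise, so reduce it:
print(loss.mean().item(), log["train/latent_l2_loss"].item())
```

Note that the constructor always builds an `LPIPS()` instance, so the VGG weights are still fetched even when the perceptual terms are disabled.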
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/.gitignore ADDED
@@ -0,0 +1 @@
+ vgg.pth
sgm/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+ from collections import namedtuple
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+ from ..util import get_ckpt_path
+
+
+ class LPIPS(nn.Module):
+     # Learned perceptual metric
+     def __init__(self, use_dropout=True):
+         super().__init__()
+         self.scaling_layer = ScalingLayer()
+         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
+         self.net = vgg16(pretrained=True, requires_grad=False)
+         self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+         self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+         self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+         self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+         self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+         self.load_from_pretrained()
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def load_from_pretrained(self, name="vgg_lpips"):
+         ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
+         self.load_state_dict(
+             torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+         )
+         print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+     @classmethod
+     def from_pretrained(cls, name="vgg_lpips"):
+         if name != "vgg_lpips":
+             raise NotImplementedError
+         model = cls()
+         ckpt = get_ckpt_path(name)
+         model.load_state_dict(
+             torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+         )
+         return model
+
+     def forward(self, input, target):
+         in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+         outs0, outs1 = self.net(in0_input), self.net(in1_input)
+         feats0, feats1, diffs = {}, {}, {}
+         lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+         for kk in range(len(self.chns)):
+             feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
+                 outs1[kk]
+             )
+             diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+         res = [
+             spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
+             for kk in range(len(self.chns))
+         ]
+         val = res[0]
+         for l in range(1, len(self.chns)):
+             val += res[l]
+         return val
+
+
+ class ScalingLayer(nn.Module):
+     def __init__(self):
+         super(ScalingLayer, self).__init__()
+         self.register_buffer(
+             "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
+         )
+         self.register_buffer(
+             "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
+         )
+
+     def forward(self, inp):
+         return (inp - self.shift) / self.scale
+
+
+ class NetLinLayer(nn.Module):
+     """A single linear layer which does a 1x1 conv"""
+
+     def __init__(self, chn_in, chn_out=1, use_dropout=False):
+         super(NetLinLayer, self).__init__()
+         layers = (
+             [
+                 nn.Dropout(),
+             ]
+             if (use_dropout)
+             else []
+         )
+         layers += [
+             nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+         ]
+         self.model = nn.Sequential(*layers)
+
+
+ class vgg16(torch.nn.Module):
+     def __init__(self, requires_grad=False, pretrained=True):
+         super(vgg16, self).__init__()
+         vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         self.N_slices = 5
+         for x in range(4):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(4, 9):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(9, 16):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(16, 23):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(23, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h = self.slice1(X)
+         h_relu1_2 = h
+         h = self.slice2(h)
+         h_relu2_2 = h
+         h = self.slice3(h)
+         h_relu3_3 = h
+         h = self.slice4(h)
+         h_relu4_3 = h
+         h = self.slice5(h)
+         h_relu5_3 = h
+         vgg_outputs = namedtuple(
+             "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
+         )
+         out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+         return out
+
+
+ def normalize_tensor(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def spatial_average(x, keepdim=True):
+     return x.mean([2, 3], keepdim=keepdim)
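This LPIPS module expects image batches scaled to `[-1, 1]` (the range the `ScalingLayer` statistics are calibrated for) and returns one distance per sample with shape `(b, 1, 1, 1)`. A small usage sketch follows; constructing `LPIPS()` fetches the torchvision VGG16 weights and the `vgg.pth` linear heads via `get_ckpt_path` on first run.

```python
# Sketch: perceptual distance between a batch and a lightly perturbed copy.
import torch

from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()  # downloads VGG16 + vgg.pth linear heads if not cached

x = torch.rand(4, 3, 256, 256) * 2 - 1                 # images in [-1, 1]
y = (x + 0.05 * torch.randn_like(x)).clamp(-1.0, 1.0)  # perturbed copy

with torch.no_grad():
    dist = lpips(x, y)  # -> shape (4, 1, 1, 1), one distance per image
print(dist.flatten().tolist())
```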