ablattmann committed on
Commit
ec3a273
1 Parent(s): ebcf159

add configs for training unconditional/class-conditional ldms

README.md CHANGED
@@ -55,18 +55,7 @@ bash scripts/download_first_stages.sh
55
  ```
56
 
57
  The first stage models can then be found in `models/first_stage_models/<model_spec>`
58
- ### Training autoencoder models
59
 
60
- Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
61
- Training can be started by running
62
- ```
63
- CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec> -t --gpus 0,
64
- ```
65
- where `config_spec` is one of {`autoencoder_kl_8x8x64.yaml`(f=32, d=64), `autoencoder_kl_16x16x16.yaml`(f=16, d=16),
66
- `autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
67
-
68
- For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
69
- repository.
70
 
71
 
72
  ## Pretrained LDMs
@@ -78,9 +67,10 @@ repository.
78
  | LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
79
  | ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
80
  | Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
81
- | OpenImages | Super-resolution | N/A | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
82
  | OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
83
- | Landscapes (finetuned 512) | Semantic Image Synthesis | LDM-VQ-4 (100 DDIM steps, eta=1) | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | |
 
84
 
85
 
86
  ### Get the models
@@ -116,10 +106,90 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp
116
  `indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
117
  the examples provided in `data/inpainting_examples`.
118
119
  ## Coming Soon...
120
 
121
- * Code for training LDMs and the corresponding compression models.
122
- * Inference scripts for conditional LDMs for various conditioning modalities.
123
  * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
124
  * We will also release some further pretrained models.
125
 
 
55
  ```
56
 
57
  The first stage models can then be found in `models/first_stage_models/<model_spec>`
 
58
 
59
 
60
 
61
  ## Pretrained LDMs
 
67
  | LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
68
  | ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
69
  | Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
70
+ | OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
71
  | OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
72
+ | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
73
+ | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
74
 
75
 
76
  ### Get the models
 
106
  `indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
107
  the examples provided in `data/inpainting_examples`.
108
 
109
+
110
+ # Train your own LDMs
111
+
112
+ ## Data preparation
113
+
114
+ ### Faces
115
+ To download the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
116
+ repository.
117
+
118
+ ### LSUN
119
+
120
+ The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
121
+ We performed a custom split into training and validation images, and provide the corresponding filenames
122
+ at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
123
+ After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
124
+ also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
125
+
126
+ ### ImageNet
127
+ The code will try to download (through [Academic
128
+ Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
129
+ is used. However, since ImageNet is quite large, this requires a lot of disk
130
+ space and time. If you already have ImageNet on your disk, you can speed things
131
+ up by putting the data into
132
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
133
+ `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
134
+ of `train`/`validation`. It should have the following structure:
135
+
136
+ ```
137
+ ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
138
+ ├── n01440764
139
+ │ ├── n01440764_10026.JPEG
140
+ │ ├── n01440764_10027.JPEG
141
+ │ ├── ...
142
+ ├── n01443537
143
+ │ ├── n01443537_10007.JPEG
144
+ │ ├── n01443537_10014.JPEG
145
+ │ ├── ...
146
+ ├── ...
147
+ ```
148
+
149
+ If you haven't extracted the data, you can also place
150
+ `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
151
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
152
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
153
+ extracted into the above structure without downloading it again. Note that this
154
+ will only happen if neither a folder
155
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
156
+ `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exists. Remove them
157
+ if you want to force running the dataset preparation again.
158
+
159
+
160
+ ## Model Training
161
+
162
+ Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
163
+
164
+ ### Training autoencoder models
165
+
166
+ Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
167
+ Training can be started by running
168
+ ```
169
+ CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
170
+ ```
171
+ where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
172
+ `autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
173
+
174
+ For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
175
+ repository.
176
+
177
+ ### Training LDMs
178
+
179
+ In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN, CelebA-HQ, FFHQ, and ImageNet datasets.
180
+ Training can be started by running
181
+
182
+ ```shell script
183
+ CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
184
+ ```
185
+
186
+ where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3), `ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
187
+ `lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
188
+ `lsun_churches-ldm-kl-8`(f=8, KL-reg. autoencoder, spatial size 32x32x4), `cin-ldm-vq-f8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
189
+
190
  ## Coming Soon...
191
 
192
+ * More inference scripts for conditional LDMs.
 
193
  * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
194
  * We will also release some further pretrained models.
195
 
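For reference, the training configs added below can be sanity-checked outside of a full `main.py` run with a minimal sketch like the following. It assumes the repository root is on the Python path, `omegaconf` is installed, and the first-stage checkpoint referenced by the config's `ckpt_path` has already been fetched via `scripts/download_first_stages.sh`.

```python
# Minimal sketch (not part of this commit): load one of the new training
# configs and instantiate the model the same way main.py does internally.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")

# instantiate_from_config builds the class named under `target` with `params`;
# this also loads the first-stage weights from the config's ckpt_path.
model = instantiate_from_config(config.model)
print(type(model).__name__)              # LatentDiffusion
print(config.model.base_learning_rate)   # 2e-06
```
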
configs/latent-diffusion/celebahq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,86 @@
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+
15
+ unet_config:
16
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ image_size: 64
19
+ in_channels: 3
20
+ out_channels: 3
21
+ model_channels: 224
22
+ attention_resolutions:
23
+ # note: this isn't actually the resolution but
24
+ # the downsampling factor, i.e. this corresponds to
25
+ # attention on spatial resolution 8,16,32, as the
26
+ # spatial resolution of the latents is 64 for f4
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 4
36
+ num_head_channels: 32
37
+ first_stage_config:
38
+ target: ldm.models.autoencoder.VQModelInterface
39
+ params:
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43
+ ddconfig:
44
+ double_z: false
45
+ z_channels: 3
46
+ resolution: 256
47
+ in_channels: 3
48
+ out_ch: 3
49
+ ch: 128
50
+ ch_mult:
51
+ - 1
52
+ - 2
53
+ - 4
54
+ num_res_blocks: 2
55
+ attn_resolutions: []
56
+ dropout: 0.0
57
+ lossconfig:
58
+ target: torch.nn.Identity
59
+ cond_stage_config: __is_unconditional__
60
+ data:
61
+ target: main.DataModuleFromConfig
62
+ params:
63
+ batch_size: 48
64
+ num_workers: 5
65
+ wrap: false
66
+ train:
67
+ target: taming.data.faceshq.CelebAHQTrain
68
+ params:
69
+ size: 256
70
+ validation:
71
+ target: taming.data.faceshq.CelebAHQValidation
72
+ params:
73
+ size: 256
74
+
75
+
76
+ lightning:
77
+ callbacks:
78
+ image_logger:
79
+ target: main.ImageLogger
80
+ params:
81
+ batch_frequency: 5000
82
+ max_images: 8
83
+ increase_log_steps: False
84
+
85
+ trainer:
86
+ benchmark: True
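
The `attention_resolutions` comment in this config is worth a quick worked check; as an illustration only, the listed values are downsampling factors relative to the 64x64 latent of the f=4 first stage:

```python
# Illustration of the attention_resolutions comment above: the values are
# downsampling factors of the latent, not absolute resolutions.
latent_size = 256 // 4      # 256px training images, f=4 first stage -> 64x64x3 latents
for factor in (8, 4, 2):    # attention_resolutions in this config
    print(f"factor {factor} -> attention at {latent_size // factor}x{latent_size // factor}")
# factor 8 -> 8x8, factor 4 -> 16x16, factor 2 -> 32x32
```
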
configs/latent-diffusion/cin-ldm-vq-f8.yaml ADDED
@@ -0,0 +1,98 @@
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ unet_config:
18
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ image_size: 32
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 256
24
+ attention_resolutions:
25
+ # note: this isn't actually the resolution but
26
+ # the downsampling factor, i.e. this corresponds to
27
+ # attention on spatial resolution 8,16,32, as the
28
+ # spatial resolution of the latents is 32 for f8
29
+ - 4
30
+ - 2
31
+ - 1
32
+ num_res_blocks: 2
33
+ channel_mult:
34
+ - 1
35
+ - 2
36
+ - 4
37
+ num_head_channels: 32
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 512
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 4
45
+ n_embed: 16384
46
+ ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47
+ ddconfig:
48
+ double_z: false
49
+ z_channels: 4
50
+ resolution: 256
51
+ in_channels: 3
52
+ out_ch: 3
53
+ ch: 128
54
+ ch_mult:
55
+ - 1
56
+ - 2
57
+ - 2
58
+ - 4
59
+ num_res_blocks: 2
60
+ attn_resolutions:
61
+ - 32
62
+ dropout: 0.0
63
+ lossconfig:
64
+ target: torch.nn.Identity
65
+ cond_stage_config:
66
+ target: ldm.modules.encoders.modules.ClassEmbedder
67
+ params:
68
+ embed_dim: 512
69
+ key: class_label
70
+ data:
71
+ target: main.DataModuleFromConfig
72
+ params:
73
+ batch_size: 64
74
+ num_workers: 12
75
+ wrap: false
76
+ train:
77
+ target: ldm.data.imagenet.ImageNetTrain
78
+ params:
79
+ config:
80
+ size: 256
81
+ validation:
82
+ target: ldm.data.imagenet.ImageNetValidation
83
+ params:
84
+ config:
85
+ size: 256
86
+
87
+
88
+ lightning:
89
+ callbacks:
90
+ image_logger:
91
+ target: main.ImageLogger
92
+ params:
93
+ batch_frequency: 5000
94
+ max_images: 8
95
+ increase_log_steps: False
96
+
97
+ trainer:
98
+ benchmark: True
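
The class-conditional ImageNet config above routes integer class labels through a learned embedding (`ClassEmbedder`, `embed_dim: 512`) that the UNet consumes via cross-attention (`use_spatial_transformer: true`, `context_dim: 512`). A toy stand-in, not the repository's actual `ClassEmbedder`, to make the tensor shapes concrete:

```python
# Rough sketch (assumptions, not the repo's exact ClassEmbedder code): class
# labels become a 512-dim context sequence for the UNet's cross-attention.
import torch
import torch.nn as nn

class ToyClassEmbedder(nn.Module):
    def __init__(self, embed_dim=512, n_classes=1000):
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, class_label):
        # (B,) int64 labels -> (B, 1, embed_dim) context sequence
        return self.embedding(class_label)[:, None, :]

labels = torch.randint(0, 1000, (4,))
context = ToyClassEmbedder()(labels)
print(context.shape)  # torch.Size([4, 1, 512])
```
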
configs/latent-diffusion/ffhq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn't actually the resolution but
23
+ # the downsampling factor, i.e. this corresponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial resolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ embed_dim: 3
40
+ n_embed: 8192
41
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 42
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: taming.data.faceshq.FFHQTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: taming.data.faceshq.FFHQValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
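
For reference, `main.py` scales the configured `base_learning_rate` by the number of GPUs, the batch size, and `accumulate_grad_batches` when learning-rate scaling is enabled (see the `accumulate_grad_batches` handling in the `main.py` hunk further down). A worked example with this FFHQ config's values, where the GPU count and accumulation factor are assumptions:

```python
# Worked example (illustrative): effective learning rate when main.py's
# learning-rate scaling is active. base_lr and batch_size come from this
# config; ngpu and accumulate_grad_batches are assumptions.
base_lr = 2.0e-06              # base_learning_rate from the config
batch_size = 42                # data.params.batch_size from the config
ngpu = 1                       # assumption: single-GPU run
accumulate_grad_batches = 1    # assumption: the default handled in main.py

learning_rate = accumulate_grad_batches * ngpu * batch_size * base_lr
print(learning_rate)           # 8.4e-05
```
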
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn't actually the resolution but
23
+ # the downsampling factor, i.e. this corresponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial resolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 48
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: ldm.data.lsun.LSUNBedroomsTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: ldm.data.lsun.LSUNBedroomsValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
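
All of these training configs describe their datasets through `main.DataModuleFromConfig`. As a minimal sketch, assuming the LSUN files have been extracted to `./data/lsun/bedrooms` as described in the README changes above, the data side can be instantiated on its own:

```python
# Minimal sketch (not part of this commit): instantiate only the data module
# from the LSUN-Bedrooms config to check that the dataset paths resolve.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml")
data = instantiate_from_config(config.data)   # main.DataModuleFromConfig
data.prepare_data()
data.setup()
# DataModuleFromConfig keeps the instantiated datasets in .datasets after setup()
print(len(data.datasets["train"]), len(data.datasets["validation"]))
```
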
configs/latent-diffusion/{lsun_churches_f8-autoencoder-ldm.yaml → lsun_churches-ldm-kl-8.yaml} RENAMED
@@ -45,7 +45,7 @@ model:
45
  params:
46
  embed_dim: 4
47
  monitor: "val/rec_loss"
48
- ckpt_path: "/export/compvis-nfs/user/ablattma/logs/braket/2021-11-26T11-25-56_lsun_churches-convae-f8-ft_from_oi/checkpoints/step=000180071-fidfrechet_inception_distance=2.335.ckpt"
49
  ddconfig:
50
  double_z: True
51
  z_channels: 4
@@ -65,7 +65,7 @@ model:
65
  data:
66
  target: main.DataModuleFromConfig
67
  params:
68
- batch_size: 24 # TODO: was 96 in our experiments
69
  num_workers: 5
70
  wrap: False
71
  train:
@@ -82,14 +82,10 @@ lightning:
82
  image_logger:
83
  target: main.ImageLogger
84
  params:
85
- batch_frequency: 1000 # TODO 5000
86
  max_images: 8
87
  increase_log_steps: False
88
 
89
- metrics_over_trainsteps_checkpoint:
90
- target: pytorch_lightning.callbacks.ModelCheckpoint
91
- params:
92
- every_n_train_steps: 20000
93
 
94
  trainer:
95
  benchmark: True
 
45
  params:
46
  embed_dim: 4
47
  monitor: "val/rec_loss"
48
+ ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49
  ddconfig:
50
  double_z: True
51
  z_channels: 4
 
65
  data:
66
  target: main.DataModuleFromConfig
67
  params:
68
+ batch_size: 96
69
  num_workers: 5
70
  wrap: False
71
  train:
 
82
  image_logger:
83
  target: main.ImageLogger
84
  params:
85
+ batch_frequency: 5000
86
  max_images: 8
87
  increase_log_steps: False
88
 
 
89
 
90
  trainer:
91
  benchmark: True
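
The `ckpt_path` change above points the LSUN-Churches config at the kl-f8 first stage fetched by `scripts/download_first_stages.sh` instead of an internal path. A small sketch, under the assumption that the checkpoint follows the usual PyTorch Lightning layout with a `state_dict` key, to confirm it is in place before launching training:

```python
# Quick check (illustrative): verify the kl-f8 first-stage checkpoint that
# lsun_churches-ldm-kl-8.yaml now expects has been downloaded and unpacked.
import os
import torch

ckpt = "models/first_stage_models/kl-f8/model.ckpt"
assert os.path.exists(ckpt), "run scripts/download_first_stages.sh first"

sd = torch.load(ckpt, map_location="cpu")
# assumption: Lightning-style checkpoint storing the weights under "state_dict"
print(len(sd["state_dict"]), "tensors in the first-stage state_dict")
```
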
ldm/models/diffusion/ddim.py CHANGED
@@ -5,8 +5,7 @@ import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
7
 
8
- from ldm.models.diffusion.ddpm import noise_like
9
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps
10
 
11
 
12
  class DDIMSampler(object):
@@ -27,8 +26,7 @@ class DDIMSampler(object):
27
  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
  alphas_cumprod = self.model.alphas_cumprod
29
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
-
31
- to_torch = partial(torch.tensor, dtype=torch.float32, device=self.model.device)
32
 
33
  self.register_buffer('betas', to_torch(self.model.betas))
34
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
@@ -73,7 +71,8 @@ class DDIMSampler(object):
73
  corrector_kwargs=None,
74
  verbose=True,
75
  x_T=None,
76
- log_every_t=100
 
77
  ):
78
  if conditioning is not None:
79
  if isinstance(conditioning, dict):
 
5
  from tqdm import tqdm
6
  from functools import partial
7
 
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
9
 
10
 
11
  class DDIMSampler(object):
 
26
  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
27
  alphas_cumprod = self.model.alphas_cumprod
28
  assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
 
30
 
31
  self.register_buffer('betas', to_torch(self.model.betas))
32
  self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
 
71
  corrector_kwargs=None,
72
  verbose=True,
73
  x_T=None,
74
+ log_every_t=100,
75
+ **kwargs
76
  ):
77
  if conditioning is not None:
78
  if isinstance(conditioning, dict):
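
The `to_torch` change above replaces `partial(torch.tensor, ...)` with a `clone().detach().to(...)` lambda; constructing a tensor from an existing tensor via `torch.tensor(...)` triggers PyTorch's copy-construct warning, which the new form avoids. A standalone illustration:

```python
# Standalone illustration of the to_torch change above: torch.tensor() on an
# existing tensor warns about copy-construction, clone().detach() does not.
import warnings
import torch

betas = torch.linspace(1e-4, 2e-2, 10, dtype=torch.float64)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = torch.tensor(betas, dtype=torch.float32)           # old behaviour: warns
print([str(w.message) for w in caught])                     # copy-construct warning

to_torch = lambda x: x.clone().detach().to(torch.float32)   # new behaviour: silent
print(to_torch(betas).dtype)                                # torch.float32
```
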
ldm/models/diffusion/ddpm.py CHANGED
@@ -16,14 +16,14 @@ from contextlib import contextmanager
16
  from functools import partial
17
  from tqdm import tqdm
18
  from torchvision.utils import make_grid
19
- from PIL import Image
20
  from pytorch_lightning.utilities.distributed import rank_zero_only
21
 
22
  from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
23
  from ldm.modules.ema import LitEma
24
  from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
25
  from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
26
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
 
27
 
28
 
29
  __conditioning_keys__ = {'concat': 'c_concat',
@@ -37,12 +37,6 @@ def disabled_train(self, mode=True):
37
  return self
38
 
39
 
40
- def noise_like(shape, device, repeat=False):
41
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
42
- noise = lambda: torch.randn(shape, device=device)
43
- return repeat_noise() if repeat else noise()
44
-
45
-
46
  def uniform_on_device(r1, r2, shape, device):
47
  return (r1 - r2) * torch.rand(*shape, device=device) + r2
48
 
@@ -119,6 +113,7 @@ class DDPM(pl.LightningModule):
119
  if self.learn_logvar:
120
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
121
 
 
122
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
123
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
124
  if exists(given_betas):
@@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM):
1188
 
1189
  if start_T is not None:
1190
  timesteps = min(timesteps, start_T)
1191
- print(timesteps, start_T)
1192
  iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1193
  range(0, timesteps))
1194
 
@@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM):
1222
  @torch.no_grad()
1223
  def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1224
  verbose=True, timesteps=None, quantize_denoised=False,
1225
- mask=None, x0=None, shape=None):
1226
  if shape is None:
1227
  shape = (batch_size, self.channels, self.image_size, self.image_size)
1228
  if cond is not None:
@@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM):
1238
  mask=mask, x0=x0)
1239
 
1240
  @torch.no_grad()
1241
- def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None,
1242
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1243
  plot_diffusion_rows=True, **kwargs):
1244
- # TODO: maybe add option for ddim sampling via DDIMSampler class
1245
  log = dict()
1246
  z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1247
  return_first_stage_outputs=True,
@@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM):
1288
  if sample:
1289
  # get denoise row
1290
  with self.ema_scope("Plotting"):
1291
- samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1292
  x_samples = self.decode_first_stage(samples)
1293
  log["samples"] = x_samples
1294
  if plot_denoise_rows:
@@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM):
1299
  self.first_stage_model, IdentityFirstStage):
1300
  # also display when quantizing x0 while sampling
1301
  with self.ema_scope("Plotting Quantized Denoised"):
1302
- samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1303
- quantize_denoised=True)
1304
  x_samples = self.decode_first_stage(samples.to(self.device))
1305
  log["samples_x0_quantized"] = x_samples
1306
 
@@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM):
1312
  mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1313
  mask = mask[:, None, ...]
1314
  with self.ema_scope("Plotting Inpaint"):
1315
- samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1316
- quantize_denoised=False, x0=z[:N], mask=mask)
 
1317
  x_samples = self.decode_first_stage(samples.to(self.device))
1318
  log["samples_inpainting"] = x_samples
1319
  log["mask"] = mask
1320
- if plot_denoise_rows:
1321
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1322
- log["denoise_row_inpainting"] = denoise_grid
1323
 
1324
  # outpaint
1325
  with self.ema_scope("Plotting Outpaint"):
1326
- samples = self.sample(cond=c, batch_size=N, return_intermediates=False,
1327
- quantize_denoised=False, x0=z[:N], mask=1. - mask)
1328
  x_samples = self.decode_first_stage(samples.to(self.device))
1329
  log["samples_outpainting"] = x_samples
1330
 
 
16
  from functools import partial
17
  from tqdm import tqdm
18
  from torchvision.utils import make_grid
 
19
  from pytorch_lightning.utilities.distributed import rank_zero_only
20
 
21
  from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
  from ldm.modules.ema import LitEma
23
  from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
  from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
 
28
 
29
  __conditioning_keys__ = {'concat': 'c_concat',
 
37
  return self
38
 
39
 
40
  def uniform_on_device(r1, r2, shape, device):
41
  return (r1 - r2) * torch.rand(*shape, device=device) + r2
42
 
 
113
  if self.learn_logvar:
114
  self.logvar = nn.Parameter(self.logvar, requires_grad=True)
115
 
116
+
117
  def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
118
  linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
119
  if exists(given_betas):
 
1183
 
1184
  if start_T is not None:
1185
  timesteps = min(timesteps, start_T)
 
1186
  iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1187
  range(0, timesteps))
1188
 
 
1216
  @torch.no_grad()
1217
  def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1218
  verbose=True, timesteps=None, quantize_denoised=False,
1219
+ mask=None, x0=None, shape=None,**kwargs):
1220
  if shape is None:
1221
  shape = (batch_size, self.channels, self.image_size, self.image_size)
1222
  if cond is not None:
 
1232
  mask=mask, x0=x0)
1233
 
1234
  @torch.no_grad()
1235
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1236
+
1237
+ if ddim:
1238
+ ddim_sampler = DDIMSampler(self)
1239
+ shape = (self.channels, self.image_size, self.image_size)
1240
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1241
+ shape,cond,verbose=False,**kwargs)
1242
+
1243
+ else:
1244
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1245
+ return_intermediates=True,**kwargs)
1246
+
1247
+ return samples, intermediates
1248
+
1249
+
1250
+ @torch.no_grad()
1251
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1252
  quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1253
  plot_diffusion_rows=True, **kwargs):
1254
+
1255
+ use_ddim = ddim_steps is not None
1256
+
1257
  log = dict()
1258
  z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1259
  return_first_stage_outputs=True,
 
1300
  if sample:
1301
  # get denoise row
1302
  with self.ema_scope("Plotting"):
1303
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1304
+ ddim_steps=ddim_steps,eta=ddim_eta)
1305
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1306
  x_samples = self.decode_first_stage(samples)
1307
  log["samples"] = x_samples
1308
  if plot_denoise_rows:
 
1313
  self.first_stage_model, IdentityFirstStage):
1314
  # also display when quantizing x0 while sampling
1315
  with self.ema_scope("Plotting Quantized Denoised"):
1316
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1317
+ ddim_steps=ddim_steps,eta=ddim_eta,
1318
+ quantize_denoised=True)
1319
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1320
+ # quantize_denoised=True)
1321
  x_samples = self.decode_first_stage(samples.to(self.device))
1322
  log["samples_x0_quantized"] = x_samples
1323
 
 
1329
  mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1330
  mask = mask[:, None, ...]
1331
  with self.ema_scope("Plotting Inpaint"):
1332
+
1333
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1334
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1335
  x_samples = self.decode_first_stage(samples.to(self.device))
1336
  log["samples_inpainting"] = x_samples
1337
  log["mask"] = mask
1338
 
1339
  # outpaint
1340
  with self.ema_scope("Plotting Outpaint"):
1341
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1342
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1343
  x_samples = self.decode_first_stage(samples.to(self.device))
1344
  log["samples_outpainting"] = x_samples
1345
 
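`log_images` now draws its samples through the new `sample_log` helper, using the `DDIMSampler` whenever `ddim_steps` is set and falling back to ancestral DDPM sampling otherwise. A usage sketch, assuming the config and the first-stage checkpoint it references are available locally; the weights here are untrained, so this only exercises the new code path:

```python
# Usage sketch for the new sample_log helper (assumes the config and the
# first-stage checkpoint it references exist locally; weights are untrained).
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/latent-diffusion/celebahq-ldm-vq-4.yaml")
model = instantiate_from_config(config.model).eval()

with torch.no_grad():
    # DDIM path, as used by log_images when ddim_steps is not None
    samples, _ = model.sample_log(cond=None, batch_size=4,
                                  ddim=True, ddim_steps=200, eta=1.0)
    # ancestral path: model.sample_log(cond=None, batch_size=4, ddim=False, ddim_steps=None)
    images = model.decode_first_stage(samples)
print(images.shape)   # torch.Size([4, 3, 256, 256])
```
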
ldm/modules/diffusionmodules/util.py CHANGED
@@ -259,3 +259,9 @@ class HybridConditioner(nn.Module):
259
  c_concat = self.concat_conditioner(c_concat)
260
  c_crossattn = self.crossattn_conditioner(c_crossattn)
261
  return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
 
259
  c_concat = self.concat_conditioner(c_concat)
260
  c_crossattn = self.crossattn_conditioner(c_crossattn)
261
  return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262
+
263
+
264
+ def noise_like(shape, device, repeat=False):
265
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266
+ noise = lambda: torch.randn(shape, device=device)
267
+ return repeat_noise() if repeat else noise()
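
`noise_like` moves here from `ddpm.py` unchanged; with `repeat=True` it draws a single noise sample and tiles it over the batch dimension. A short check:

```python
# Illustration of noise_like's repeat flag: with repeat=True every batch
# element receives the same noise realisation.
import torch
from ldm.modules.diffusionmodules.util import noise_like

n = noise_like((4, 3, 64, 64), device="cpu", repeat=True)
print(n.shape, bool(torch.equal(n[0], n[1])))   # torch.Size([4, 3, 64, 64]) True
```
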
main.py CHANGED
@@ -676,7 +676,10 @@ if __name__ == "__main__":
676
  ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
677
  else:
678
  ngpu = 1
679
- accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
 
  print(f"accumulate_grad_batches = {accumulate_grad_batches}")
681
  lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
682
  if opt.scale_lr:
 
676
  ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
677
  else:
678
  ngpu = 1
679
+ if 'accumulate_grad_batches' in lightning_config.trainer:
680
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
681
+ else:
682
+ accumulate_grad_batches = 1
683
  print(f"accumulate_grad_batches = {accumulate_grad_batches}")
684
  lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
685
  if opt.scale_lr:
models/ldm/semantic_synthesis256/config.yaml ADDED
@@ -0,0 +1,59 @@
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0205
7
+ log_every_t: 100
8
+ timesteps: 1000
9
+ loss_type: l1
10
+ first_stage_key: image
11
+ cond_stage_key: segmentation
12
+ image_size: 64
13
+ channels: 3
14
+ concat_mode: true
15
+ cond_stage_trainable: true
16
+ unet_config:
17
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18
+ params:
19
+ image_size: 64
20
+ in_channels: 6
21
+ out_channels: 3
22
+ model_channels: 128
23
+ attention_resolutions:
24
+ - 32
25
+ - 16
26
+ - 8
27
+ num_res_blocks: 2
28
+ channel_mult:
29
+ - 1
30
+ - 4
31
+ - 8
32
+ num_heads: 8
33
+ first_stage_config:
34
+ target: ldm.models.autoencoder.VQModelInterface
35
+ params:
36
+ embed_dim: 3
37
+ n_embed: 8192
38
+ ddconfig:
39
+ double_z: false
40
+ z_channels: 3
41
+ resolution: 256
42
+ in_channels: 3
43
+ out_ch: 3
44
+ ch: 128
45
+ ch_mult:
46
+ - 1
47
+ - 2
48
+ - 4
49
+ num_res_blocks: 2
50
+ attn_resolutions: []
51
+ dropout: 0.0
52
+ lossconfig:
53
+ target: torch.nn.Identity
54
+ cond_stage_config:
55
+ target: ldm.modules.encoders.modules.SpatialRescaler
56
+ params:
57
+ n_stages: 2
58
+ in_channels: 182
59
+ out_channels: 3
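
The semantic-synthesis config above conditions by concatenation (`concat_mode: true`): the UNet's `in_channels: 6` is the 3-channel f=4 latent stacked with the segmentation map after the `SpatialRescaler` has reduced its 182 classes to 3 channels and downsampled it by two stages. A shape-only sketch with a simple stand-in for the rescaler (not the repository's implementation):

```python
# Shape-only sketch (illustrative, not the repo's SpatialRescaler): why the
# UNet in this config takes in_channels: 6 when concat_mode is used.
import torch
import torch.nn as nn

B = 2
z = torch.randn(B, 3, 64, 64)          # f=4 latent of a 256x256 image
seg = torch.randn(B, 182, 256, 256)    # one-hot-style segmentation, 182 classes

rescale = nn.Sequential(               # stand-in for SpatialRescaler:
    nn.Conv2d(182, 3, kernel_size=1),  # 182 -> 3 channels
    nn.AvgPool2d(4),                   # n_stages: 2 ~ downsample by a factor of 4
)
cond = rescale(seg)                    # (B, 3, 64, 64)

unet_input = torch.cat([z, cond], dim=1)   # concat_mode
print(unet_input.shape)                    # torch.Size([2, 6, 64, 64])
```
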
scripts/download_first_stages.sh CHANGED
@@ -4,10 +4,10 @@ wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/la
4
  wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5
  wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6
  wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7
- wget -O models/first_stage_models/vq-f4-noattn/model.zip https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1
8
  wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9
  wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10
- wget -O models/first_stage_models/vq-f16/model.zip https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1
11
 
12
 
13
 
 
4
  wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5
  wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6
  wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7
+ wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8
  wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9
  wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10
+ wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11
 
12
 
13
 
scripts/download_models.sh CHANGED
@@ -6,9 +6,10 @@ wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/la
6
  wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7
  wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8
  wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
 
9
  wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
10
  wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
11
- wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
12
 
13
 
14
 
@@ -33,10 +34,16 @@ unzip -o model.zip
33
  cd ../semantic_synthesis512
34
  unzip -o model.zip
35
 
36
  cd ../bsr_sr
37
  unzip -o model.zip
38
 
39
  cd ../layout2img-openimages256
40
  unzip -o model.zip
41
 
42
  cd ../..
 
6
  wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7
  wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8
  wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9
+ wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10
  wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11
  wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12
+ wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13
 
14
 
15
 
 
34
  cd ../semantic_synthesis512
35
  unzip -o model.zip
36
 
37
+ cd ../semantic_synthesis256
38
+ unzip -o model.zip
39
+
40
  cd ../bsr_sr
41
  unzip -o model.zip
42
 
43
  cd ../layout2img-openimages256
44
  unzip -o model.zip
45
 
46
+ cd ../inpainting_big
47
+ unzip -o model.zip
48
+
49
  cd ../..