常舒宁
committed
Commit: 1dc89cf
Parent(s): 6d96b0b
add files
This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- 00058.jpeg +0 -0
- app.py +15 -0
- configs/autoencoder/autoencoder_kl_16x16x16.yaml +54 -0
- configs/autoencoder/autoencoder_kl_32x32x4.yaml +53 -0
- configs/autoencoder/autoencoder_kl_64x64x3.yaml +54 -0
- configs/autoencoder/autoencoder_kl_8x8x64.yaml +53 -0
- configs/latent-diffusion/celebahq-ldm-vq-4.yaml +86 -0
- configs/latent-diffusion/cin-ldm-vq-f8.yaml +98 -0
- configs/latent-diffusion/cin256-v2.yaml +68 -0
- configs/latent-diffusion/ffhq-ldm-vq-4.yaml +85 -0
- configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml +85 -0
- configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml +91 -0
- configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
- configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml +77 -0
- configs/latent-diffusion/txt2img-1p4B-finetune.yaml +119 -0
- configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml +117 -0
- configs/stable-diffusion/v1-finetune.yaml +110 -0
- configs/stable-diffusion/v1-finetune_unfrozen.yaml +120 -0
- configs/stable-diffusion/v1-inference.yaml +70 -0
- environment.yaml +31 -0
- evaluation/__pycache__/clip_eval.cpython-36.pyc +0 -0
- evaluation/__pycache__/clip_eval.cpython-38.pyc +0 -0
- evaluation/clip_eval.py +113 -0
- ldm/__pycache__/util.cpython-36.pyc +0 -0
- ldm/__pycache__/util.cpython-38.pyc +0 -0
- ldm/data/__init__.py +0 -0
- ldm/data/__pycache__/__init__.cpython-36.pyc +0 -0
- ldm/data/__pycache__/__init__.cpython-38.pyc +0 -0
- ldm/data/__pycache__/base.cpython-36.pyc +0 -0
- ldm/data/__pycache__/base.cpython-38.pyc +0 -0
- ldm/data/__pycache__/personalized.cpython-36.pyc +0 -0
- ldm/data/__pycache__/personalized.cpython-38.pyc +0 -0
- ldm/data/__pycache__/personalized_compose.cpython-38.pyc +0 -0
- ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc +0 -0
- ldm/data/__pycache__/personalized_style.cpython-36.pyc +0 -0
- ldm/data/__pycache__/personalized_style.cpython-38.pyc +0 -0
- ldm/data/base.py +23 -0
- ldm/data/imagenet.py +394 -0
- ldm/data/lsun.py +92 -0
- ldm/data/personalized.py +220 -0
- ldm/data/personalized_style.py +129 -0
- ldm/lr_scheduler.py +98 -0
- ldm/models/__pycache__/autoencoder.cpython-36.pyc +0 -0
- ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
- ldm/models/autoencoder.py +443 -0
- ldm/models/diffusion/__init__.py +0 -0
- ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc +0 -0
- ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc +0 -0
.DS_Store
ADDED
Binary file (8.2 kB).
00058.jpeg
ADDED
app.py
ADDED
@@ -0,0 +1,15 @@
import streamlit as st
import os
os.system('pip install -e src/taming-transformers/.')
os.system('pip install -e src/clip/.')
os.system('pip install -e .')
from scripts.stable_txt2img import main
# st.title('AI Gen for SG')
st.markdown("<h1 style='text-align: center; color: orange;'>AI Gen for SG</h1>", unsafe_allow_html=True)
st.markdown("<h3 style='text-align: center; color: black;'>ShowLab</h3>", unsafe_allow_html=True)
st.write('Contributors: Mike Zheng Shou, Shuning Chang, Yufei Shi, Zihan Fan, Xiangdong Zhou')
text = st.text_input('Enter your prompt', value='', key=None)
img = main(text)
# st.write(text)
if text:
    st.image(img, caption=text)
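Editorial note, not part of the commit: as committed, app.py calls main(text) before checking whether a prompt was entered, so generation is attempted with an empty string on the first page load. A minimal sketch of the more usual ordering, assuming main returns something st.image can display:

import streamlit as st
from scripts.stable_txt2img import main  # entry point added by this commit

text = st.text_input('Enter your prompt', value='')
if text:  # only generate once a non-empty prompt is available
    img = main(text)  # assumed to return an image (PIL image, array, or path)
    st.image(img, caption=text)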
configs/autoencoder/autoencoder_kl_16x16x16.yaml
ADDED
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
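For orientation, not part of the commit: these configs follow the target/params convention used throughout the ldm codebase, where target names a class and params its constructor arguments. A hedged sketch of how such a block is typically turned into an object, assuming the standard ldm.util.instantiate_from_config helper from the latent-diffusion codebase:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # assumed helper from the latent-diffusion codebase

config = OmegaConf.load("configs/autoencoder/autoencoder_kl_16x16x16.yaml")
# Builds AutoencoderKL(**config.model.params); extra keys such as
# base_learning_rate are read separately by the training script.
model = instantiate_from_config(config.model)
print(type(model).__name__)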
configs/autoencoder/autoencoder_kl_32x32x4.yaml
ADDED
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_64x64x3.yaml
ADDED
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_8x8x64.yaml
ADDED
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16,8]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2
configs/latent-diffusion/celebahq-ldm-vq-4.yaml
ADDED
@@ -0,0 +1,86 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn\t actually the resolution but
        # the downsampling factor, i.e. this corresnponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial reolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.CelebAHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/cin-ldm-vq-f8.yaml
ADDED
@@ -0,0 +1,98 @@
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        #note: this isn\t actually the resolution but
        # the downsampling factor, i.e. this corresnponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial reolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ckpt_path: configs/first_stage_models/vq-f8/model.yaml
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/cin256-v2.yaml
ADDED
@@ -0,0 +1,68 @@
model:
  base_learning_rate: 0.0001
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 192
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 5
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512

    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        n_classes: 1001
        embed_dim: 512
        key: class_label
configs/latent-diffusion/ffhq-ldm-vq-4.yaml
ADDED
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn\t actually the resolution but
        # the downsampling factor, i.e. this corresnponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial reolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 42
    num_workers: 5
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
ADDED
@@ -0,0 +1,85 @@
model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    image_size: 64
    channels: 3
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        # note: this isn\t actually the resolution but
        # the downsampling factor, i.e. this corresnponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial reolution of the latents is 64 for f4
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        ckpt_path: configs/first_stage_models/vq-f4/model.yaml
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: ldm.data.lsun.LSUNBedroomsTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNBedroomsValidation
      params:
        size: 256


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
ADDED
@@ -0,0 +1,91 @@
model:
  base_learning_rate: 5.0e-5   # set to target_lr by starting main.py with '--scale_lr False'
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    loss_type: l1
    first_stage_key: "image"
    cond_stage_key: "image"
    image_size: 32
    channels: 4
    cond_stage_trainable: False
    concat_mode: False
    scale_by_std: True
    monitor: 'val/loss_simple_ema'

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [ 1.]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 192
        attention_resolutions: [ 1, 2, 4, 8 ]   # 32, 16, 8, 4
        num_res_blocks: 2
        channel_mult: [ 1,2,2,4,4 ]  # 32, 16, 8, 4, 2
        num_heads: 8
        use_scale_shift_norm: True
        resblock_updown: True

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: "val/rec_loss"
        ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
        ddconfig:
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config: "__is_unconditional__"

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 96
    num_workers: 5
    wrap: False
    train:
      target: ldm.data.lsun.LSUNChurchesTrain
      params:
        size: 256
    validation:
      target: ldm.data.lsun.LSUNChurchesValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False


  trainer:
    benchmark: True
configs/latent-diffusion/txt2img-1p4B-eval.yaml
ADDED
@@ -0,0 +1,71 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml
ADDED
@@ -0,0 +1,77 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: []

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
configs/latent-diffusion/txt2img-1p4B-finetune.yaml
ADDED
@@ -0,0 +1,119 @@
model:
  base_learning_rate: 5.0e-3
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 256
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 256
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 6100
configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml
ADDED
@@ -0,0 +1,117 @@
model:
  base_learning_rate: 5.0e-3
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["painting"]
        per_image_tokens: false
        num_vectors_per_token: 1

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    wrap: false
    train:
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 256
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized_style.PersonalizedBase
      params:
        size: 256
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
configs/stable-diffusion/v1-finetune.yaml
ADDED
@@ -0,0 +1,110 @@
model:
  base_learning_rate: 5.0e-03
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: true   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0
    unfreeze_model: False
    model_lr: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 6100
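A small sketch, not part of the commit, of inspecting the personalization settings in this fine-tuning config with OmegaConf (which environment.yaml pins); useful for checking which placeholder token a run will learn before launching it:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/stable-diffusion/v1-finetune.yaml")
pers = cfg.model.params.personalization_config.params
print(pers.placeholder_strings)  # ['*']  - token that receives the learned embedding
print(pers.initializer_words)    # ['sculpture'] - words used to initialise that embedding
print(cfg.model.params.unfreeze_model, cfg.model.params.model_lr)  # False 0.0 - only the embedding is trained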
configs/stable-diffusion/v1-finetune_unfrozen.yaml
ADDED
@@ -0,0 +1,120 @@
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    reg_weight: 1.0
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: true   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0
    unfreeze_model: True
    model_lr: 1.0e-6

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    reg:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        reg: true
        per_image_tokens: false
        repeats: 10

    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 8
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 800
configs/stable-diffusion/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
environment.yaml
ADDED
@@ -0,0 +1,31 @@
name: ldm
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.10
  - pip=20.3
  - cudatoolkit=11.3
  - pytorch=1.10.2
  - torchvision=0.11.3
  - numpy=1.22.3
  - pip:
    - albumentations==1.1.0
    - opencv-python==4.2.0.34
    - pudb==2019.2
    - imageio==2.14.1
    - imageio-ffmpeg==0.4.7
    - pytorch-lightning==1.5.9
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - streamlit>=0.73.1
    - setuptools==59.5.0
    - pillow==9.0.1
    - einops==0.4.1
    - torch-fidelity==0.3.0
    - transformers==4.18.0
    - torchmetrics==0.6.0
    - kornia==0.6
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e .
evaluation/__pycache__/clip_eval.cpython-36.pyc
ADDED
Binary file (3.94 kB).
evaluation/__pycache__/clip_eval.cpython-38.pyc
ADDED
Binary file (3.94 kB).
evaluation/clip_eval.py
ADDED
@@ -0,0 +1,113 @@
import clip
import torch
from torchvision import transforms

from ldm.models.diffusion.ddim import DDIMSampler

class CLIPEvaluator(object):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)

        self.clip_preprocess = clip_preprocess

        self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1].
                                             clip_preprocess.transforms[:2] +                                        # to match CLIP input scale assumptions
                                             clip_preprocess.transforms[4:])                                         # + skip convert PIL to tensor

    def tokenize(self, strings: list):
        return clip.tokenize(strings).to(self.device)

    @torch.no_grad()
    def encode_text(self, tokens: list) -> torch.Tensor:
        return self.model.encode_text(tokens)

    @torch.no_grad()
    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:

        tokens = clip.tokenize(text).to(self.device)

        text_features = self.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def img_to_img_similarity(self, src_images, generated_images):
        src_img_features = self.get_image_features(src_images)
        gen_img_features = self.get_image_features(generated_images)

        return (src_img_features @ gen_img_features.T).mean()

    def txt_to_img_similarity(self, text, generated_images):
        text_features = self.get_text_features(text)
        gen_img_features = self.get_image_features(generated_images)

        return (text_features @ gen_img_features.T).mean()


class LDMCLIPEvaluator(CLIPEvaluator):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        super().__init__(device, clip_model)

    def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50):

        sampler = DDIMSampler(ldm_model)

        samples_per_batch = 8
        n_batches = n_samples // samples_per_batch

        # generate samples
        all_samples = list()
        with torch.no_grad():
            with ldm_model.ema_scope():
                uc = ldm_model.get_learned_conditioning(samples_per_batch * [""])

                for batch in range(n_batches):
                    c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text])
                    shape = [4, 256//8, 256//8]
                    samples_ddim, _ = sampler.sample(S=n_steps,
                                                     conditioning=c,
                                                     batch_size=samples_per_batch,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=5.0,
                                                     unconditional_conditioning=uc,
                                                     eta=0.0)

                    x_samples_ddim = ldm_model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)

                    all_samples.append(x_samples_ddim)

        all_samples = torch.cat(all_samples, axis=0)

        sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples)
        sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples)

        return sim_samples_to_img, sim_samples_to_text


class ImageDirEvaluator(CLIPEvaluator):
    def __init__(self, device, clip_model='ViT-B/32') -> None:
        super().__init__(device, clip_model)

    def evaluate(self, gen_samples, src_images, target_text):

        sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples)
        sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples)

        return sim_samples_to_img, sim_samples_to_text
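A hedged usage sketch, not part of the commit, of the ImageDirEvaluator defined above; the tensors are placeholders in [-1, 1] (the value range CLIPEvaluator.preprocess expects) and the prompt is illustrative:

import torch
from evaluation.clip_eval import ImageDirEvaluator

device = "cuda" if torch.cuda.is_available() else "cpu"
evaluator = ImageDirEvaluator(device)

src_images = torch.rand(4, 3, 256, 256, device=device) * 2 - 1    # stand-in for source/training images
gen_samples = torch.rand(8, 3, 256, 256, device=device) * 2 - 1   # stand-in for generated samples

img_sim, txt_sim = evaluator.evaluate(gen_samples, src_images, "a photo of *")
print(float(img_sim), float(txt_sim))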
ldm/__pycache__/util.cpython-36.pyc
ADDED
Binary file (3.72 kB).
ldm/__pycache__/util.cpython-38.pyc
ADDED
Binary file (6.12 kB).
ldm/data/__init__.py
ADDED
File without changes
ldm/data/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (177 Bytes).
ldm/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (181 Bytes).
ldm/data/__pycache__/base.cpython-36.pyc
ADDED
Binary file (1.26 kB).
ldm/data/__pycache__/base.cpython-38.pyc
ADDED
Binary file (1.29 kB).
ldm/data/__pycache__/personalized.cpython-36.pyc
ADDED
Binary file (4.92 kB).
ldm/data/__pycache__/personalized.cpython-38.pyc
ADDED
Binary file (5.82 kB).
ldm/data/__pycache__/personalized_compose.cpython-38.pyc
ADDED
Binary file (6.06 kB).
ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc
ADDED
Binary file (5.52 kB).
ldm/data/__pycache__/personalized_style.cpython-36.pyc
ADDED
Binary file (4.69 kB).
ldm/data/__pycache__/personalized_style.cpython-38.pyc
ADDED
Binary file (4.78 kB).
ldm/data/base.py
ADDED
@@ -0,0 +1,23 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
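A minimal sketch, not part of the commit, of how the abstract __iter__ of Txt2ImgIterableBaseDataset might be filled in by a subclass; the yielded keys and values here are placeholders, not taken from the repo:

from ldm.data.base import Txt2ImgIterableBaseDataset

class ToyTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    def __iter__(self):
        # Yield one record per entry; the "image"/"caption" keys are assumed, not specified by the base class.
        for i in range(self.num_records):
            yield {"image": None, "caption": f"placeholder caption {i}"}

ds = ToyTxt2ImgDataset(num_records=3)
for sample in ds:
    print(sample["caption"])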
ldm/data/imagenet.py
ADDED
@@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f)
    return dict((v,k) for k,v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config)==dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict)==SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist)+"\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # gets reset later if incase interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image/127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
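A rough usage sketch for the super-resolution datasets above (not part of the commit). It assumes ILSVRC2012 has already been downloaded and prepared and that data/imagenet_val_hr_indices.p exists, so treat it as an illustration of the interface rather than something that runs out of the box.

from ldm.data.imagenet import ImageNetSRValidation

# With size=256 and downscale_f=4, each example carries a 256x256 HR crop and a 64x64 LR version.
dataset = ImageNetSRValidation(size=256, degradation="bsrgan_light", downscale_f=4)
example = dataset[0]
print(example["image"].shape, example["LR_image"].shape)  # (256, 256, 3) and (64, 64, 3), values in [-1, 1]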
ldm/data/lsun.py
ADDED
@@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w, = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example


class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
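A small usage sketch (hypothetical, not part of the commit): the LSUN classes only need the filelist txt and image folders under data/lsun/ to exist locally.

from torch.utils.data import DataLoader
from ldm.data.lsun import LSUNChurchesTrain

dataset = LSUNChurchesTrain(size=256)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch["image"].shape)  # (4, 256, 256, 3), float32 in [-1, 1]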
ldm/data/personalized.py
ADDED
@@ -0,0 +1,220 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random

training_templates_smallest = [
    'photo of a sks {}',
]

reg_templates_smallest = [
    'photo of a {}',
]

imagenet_templates_small = [
    'a photo of a {}',
    'a rendering of a {}',
    'a cropped photo of the {}',
    'the photo of a {}',
    'a photo of a clean {}',
    'a photo of a dirty {}',
    'a dark photo of the {}',
    'a photo of my {}',
    'a photo of the cool {}',
    'a close-up photo of a {}',
    'a bright photo of the {}',
    'a cropped photo of a {}',
    'a photo of the {}',
    'a good photo of the {}',
    'a photo of one {}',
    'a close-up photo of the {}',
    'a rendition of the {}',
    'a photo of the clean {}',
    'a rendition of a {}',
    'a photo of a nice {}',
    'a good photo of a {}',
    'a photo of the nice {}',
    'a photo of the small {}',
    'a photo of the weird {}',
    'a photo of the large {}',
    'a photo of a cool {}',
    'a photo of a small {}',
    'an illustration of a {}',
    'a rendering of a {}',
    'a cropped photo of the {}',
    'the photo of a {}',
    'an illustration of a clean {}',
    'an illustration of a dirty {}',
    'a dark photo of the {}',
    'an illustration of my {}',
    'an illustration of the cool {}',
    'a close-up photo of a {}',
    'a bright photo of the {}',
    'a cropped photo of a {}',
    'an illustration of the {}',
    'a good photo of the {}',
    'an illustration of one {}',
    'a close-up photo of the {}',
    'a rendition of the {}',
    'an illustration of the clean {}',
    'a rendition of a {}',
    'an illustration of a nice {}',
    'a good photo of a {}',
    'an illustration of the nice {}',
    'an illustration of the small {}',
    'an illustration of the weird {}',
    'an illustration of the large {}',
    'an illustration of a cool {}',
    'an illustration of a small {}',
    'a depiction of a {}',
    'a rendering of a {}',
    'a cropped photo of the {}',
    'the photo of a {}',
    'a depiction of a clean {}',
    'a depiction of a dirty {}',
    'a dark photo of the {}',
    'a depiction of my {}',
    'a depiction of the cool {}',
    'a close-up photo of a {}',
    'a bright photo of the {}',
    'a cropped photo of a {}',
    'a depiction of the {}',
    'a good photo of the {}',
    'a depiction of one {}',
    'a close-up photo of the {}',
    'a rendition of the {}',
    'a depiction of the clean {}',
    'a rendition of a {}',
    'a depiction of a nice {}',
    'a good photo of a {}',
    'a depiction of the nice {}',
    'a depiction of the small {}',
    'a depiction of the weird {}',
    'a depiction of the large {}',
    'a depiction of a cool {}',
    'a depiction of a small {}',
]

imagenet_dual_templates_small = [
    'a photo of a {} with {}',
    'a rendering of a {} with {}',
    'a cropped photo of the {} with {}',
    'the photo of a {} with {}',
    'a photo of a clean {} with {}',
    'a photo of a dirty {} with {}',
    'a dark photo of the {} with {}',
    'a photo of my {} with {}',
    'a photo of the cool {} with {}',
    'a close-up photo of a {} with {}',
    'a bright photo of the {} with {}',
    'a cropped photo of a {} with {}',
    'a photo of the {} with {}',
    'a good photo of the {} with {}',
    'a photo of one {} with {}',
    'a close-up photo of the {} with {}',
    'a rendition of the {} with {}',
    'a photo of the clean {} with {}',
    'a rendition of a {} with {}',
    'a photo of a nice {} with {}',
    'a good photo of a {} with {}',
    'a photo of the nice {} with {}',
    'a photo of the small {} with {}',
    'a photo of the weird {} with {}',
    'a photo of the large {} with {}',
    'a photo of a cool {} with {}',
    'a photo of a small {} with {}',
]

per_img_token_list = [
    'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]

class PersonalizedBase(Dataset):
    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="dog",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 reg = False
                 ):

        self.data_root = data_root

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob

        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.reg = reg

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

        if not self.reg:
            text = random.choice(training_templates_smallest).format(placeholder_string)
        else:
            text = random.choice(reg_templates_smallest).format(placeholder_string)

        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w, = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
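A short sketch of how this dataset is typically used for DreamBooth-style fine-tuning (the directory names below are hypothetical): instance images receive the 'photo of a sks {}' template, while a second instance with reg=True serves the regularization images with the plain 'photo of a {}' template.

from ldm.data.personalized import PersonalizedBase

instance_set = PersonalizedBase(data_root="training_images", size=512,
                                placeholder_token="dog", repeats=100)
reg_set = PersonalizedBase(data_root="reg_images", size=512,
                           placeholder_token="dog", reg=True)

print(instance_set[0]["caption"])  # e.g. 'photo of a sks dog'
print(reg_set[0]["caption"])       # e.g. 'photo of a dog'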
ldm/data/personalized_style.py
ADDED
@@ -0,0 +1,129 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random

imagenet_templates_small = [
    'a painting in the style of {}',
    'a rendering in the style of {}',
    'a cropped painting in the style of {}',
    'the painting in the style of {}',
    'a clean painting in the style of {}',
    'a dirty painting in the style of {}',
    'a dark painting in the style of {}',
    'a picture in the style of {}',
    'a cool painting in the style of {}',
    'a close-up painting in the style of {}',
    'a bright painting in the style of {}',
    'a cropped painting in the style of {}',
    'a good painting in the style of {}',
    'a close-up painting in the style of {}',
    'a rendition in the style of {}',
    'a nice painting in the style of {}',
    'a small painting in the style of {}',
    'a weird painting in the style of {}',
    'a large painting in the style of {}',
]

imagenet_dual_templates_small = [
    'a painting in the style of {} with {}',
    'a rendering in the style of {} with {}',
    'a cropped painting in the style of {} with {}',
    'the painting in the style of {} with {}',
    'a clean painting in the style of {} with {}',
    'a dirty painting in the style of {} with {}',
    'a dark painting in the style of {} with {}',
    'a cool painting in the style of {} with {}',
    'a close-up painting in the style of {} with {}',
    'a bright painting in the style of {} with {}',
    'a cropped painting in the style of {} with {}',
    'a good painting in the style of {} with {}',
    'a painting of one {} in the style of {}',
    'a nice painting in the style of {} with {}',
    'a small painting in the style of {} with {}',
    'a weird painting in the style of {} with {}',
    'a large painting in the style of {} with {}',
]

per_img_token_list = [
    'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]

class PersonalizedBase(Dataset):
    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="*",
                 per_image_tokens=False,
                 center_crop=False,
                 ):

        self.data_root = data_root

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        # self._length = len(self.image_paths)
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token

        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop

        if per_image_tokens:
            assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        if self.per_image_tokens and np.random.uniform() < 0.25:
            text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
        else:
            text = random.choice(imagenet_templates_small).format(self.placeholder_token)

        example["caption"] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w, = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example
ldm/lr_scheduler.py
ADDED
@@ -0,0 +1,98 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n,**kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
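A minimal sketch of wiring one of these schedulers into a torch LambdaLR (not part of the commit; the numbers are illustrative, loosely modeled on the finetune configs). The schedule returns a learning-rate multiplier, which is why the docstrings say to use it with a base_lr of 1.0 relative to the optimizer's real learning rate.

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Illustrative values: linear warm-up over 100 steps, then a flat multiplier of 1.0.
schedule = LambdaLinearScheduler(warm_up_steps=[100], f_start=[1e-6],
                                 f_max=[1.0], f_min=[1.0], cycle_lengths=[10000])
scheduler = LambdaLR(optimizer, lr_lambda=schedule.schedule)

for step in range(5):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]["lr"])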
ldm/models/__pycache__/autoencoder.cpython-36.pyc
ADDED
Binary file (13.7 kB). View file

ldm/models/__pycache__/autoencoder.cpython-38.pyc
ADDED
Binary file (13.6 kB). View file
ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,443 @@
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ldm.util import instantiate_from_config


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class AutoencoderKL(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
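A self-contained sketch of an encode/decode round trip with AutoencoderKL (not part of the commit). The ddconfig is a deliberately tiny, illustrative architecture rather than a config shipped with this repo, and the loss is stubbed out with torch.nn.Identity since only the autoencoder path is exercised here.

import torch
from ldm.models.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=64, in_channels=3,
                out_ch=3, ch=32, ch_mult=(1, 2), num_res_blocks=1,
                attn_resolutions=[], dropout=0.0)
lossconfig = {"target": "torch.nn.Identity"}  # dummy loss, fine for a forward pass

ae = AutoencoderKL(ddconfig=ddconfig, lossconfig=lossconfig, embed_dim=4)
x = torch.randn(1, 3, 64, 64)
posterior = ae.encode(x)           # DiagonalGaussianDistribution over the latent
z = posterior.sample()             # (1, 4, 32, 32)
recon = ae.decode(z)               # (1, 3, 64, 64)
print(z.shape, recon.shape)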
ldm/models/diffusion/__init__.py
ADDED
File without changes

ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (189 Bytes). View file

ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (193 Bytes). View file

ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc
ADDED
Binary file (6.19 kB). View file