常舒宁 commited on
Commit
1dc89cf
1 Parent(s): 6d96b0b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. 00058.jpeg +0 -0
  3. app.py +15 -0
  4. configs/autoencoder/autoencoder_kl_16x16x16.yaml +54 -0
  5. configs/autoencoder/autoencoder_kl_32x32x4.yaml +53 -0
  6. configs/autoencoder/autoencoder_kl_64x64x3.yaml +54 -0
  7. configs/autoencoder/autoencoder_kl_8x8x64.yaml +53 -0
  8. configs/latent-diffusion/celebahq-ldm-vq-4.yaml +86 -0
  9. configs/latent-diffusion/cin-ldm-vq-f8.yaml +98 -0
  10. configs/latent-diffusion/cin256-v2.yaml +68 -0
  11. configs/latent-diffusion/ffhq-ldm-vq-4.yaml +85 -0
  12. configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml +85 -0
  13. configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml +91 -0
  14. configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
  15. configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml +77 -0
  16. configs/latent-diffusion/txt2img-1p4B-finetune.yaml +119 -0
  17. configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml +117 -0
  18. configs/stable-diffusion/v1-finetune.yaml +110 -0
  19. configs/stable-diffusion/v1-finetune_unfrozen.yaml +120 -0
  20. configs/stable-diffusion/v1-inference.yaml +70 -0
  21. environment.yaml +31 -0
  22. evaluation/__pycache__/clip_eval.cpython-36.pyc +0 -0
  23. evaluation/__pycache__/clip_eval.cpython-38.pyc +0 -0
  24. evaluation/clip_eval.py +113 -0
  25. ldm/__pycache__/util.cpython-36.pyc +0 -0
  26. ldm/__pycache__/util.cpython-38.pyc +0 -0
  27. ldm/data/__init__.py +0 -0
  28. ldm/data/__pycache__/__init__.cpython-36.pyc +0 -0
  29. ldm/data/__pycache__/__init__.cpython-38.pyc +0 -0
  30. ldm/data/__pycache__/base.cpython-36.pyc +0 -0
  31. ldm/data/__pycache__/base.cpython-38.pyc +0 -0
  32. ldm/data/__pycache__/personalized.cpython-36.pyc +0 -0
  33. ldm/data/__pycache__/personalized.cpython-38.pyc +0 -0
  34. ldm/data/__pycache__/personalized_compose.cpython-38.pyc +0 -0
  35. ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc +0 -0
  36. ldm/data/__pycache__/personalized_style.cpython-36.pyc +0 -0
  37. ldm/data/__pycache__/personalized_style.cpython-38.pyc +0 -0
  38. ldm/data/base.py +23 -0
  39. ldm/data/imagenet.py +394 -0
  40. ldm/data/lsun.py +92 -0
  41. ldm/data/personalized.py +220 -0
  42. ldm/data/personalized_style.py +129 -0
  43. ldm/lr_scheduler.py +98 -0
  44. ldm/models/__pycache__/autoencoder.cpython-36.pyc +0 -0
  45. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  46. ldm/models/autoencoder.py +443 -0
  47. ldm/models/diffusion/__init__.py +0 -0
  48. ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc +0 -0
  49. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  50. ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc +0 -0
.DS_Store ADDED
Binary file (8.2 kB). View file
 
00058.jpeg ADDED
app.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ os.system('pip install -e src/taming-transformers/.')
4
+ os.system('pip install -e src/clip/.')
5
+ os.system('pip install -e .')
6
+ from scripts.stable_txt2img import main
7
+ # st.title('AI Gen for SG')
8
+ st.markdown("<h1 style='text-align: center; color: orange;'>AI Gen for SG</h1>", unsafe_allow_html=True)
9
+ st.markdown("<h3 style='text-align: center; color: black;'>ShowLab</h3>", unsafe_allow_html=True)
10
+ st.write('Contributors: Mike Zheng Shou, Shuning Chang, Yufei Shi, Zihan Fan, Xiangdong Zhou')
11
+ text = st.text_input('Enter your prompt', value='', key=None)
12
+ img = main(text)
13
+ # st.write(text)
14
+ if text:
15
+ st.image(img, caption=text)
configs/autoencoder/autoencoder_kl_16x16x16.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 16
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 16
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [16]
24
+ dropout: 0.0
25
+
26
+
27
+ data:
28
+ target: main.DataModuleFromConfig
29
+ params:
30
+ batch_size: 12
31
+ wrap: True
32
+ train:
33
+ target: ldm.data.imagenet.ImageNetSRTrain
34
+ params:
35
+ size: 256
36
+ degradation: pil_nearest
37
+ validation:
38
+ target: ldm.data.imagenet.ImageNetSRValidation
39
+ params:
40
+ size: 256
41
+ degradation: pil_nearest
42
+
43
+ lightning:
44
+ callbacks:
45
+ image_logger:
46
+ target: main.ImageLogger
47
+ params:
48
+ batch_frequency: 1000
49
+ max_images: 8
50
+ increase_log_steps: True
51
+
52
+ trainer:
53
+ benchmark: True
54
+ accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_32x32x4.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 4
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 4
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [ ]
24
+ dropout: 0.0
25
+
26
+ data:
27
+ target: main.DataModuleFromConfig
28
+ params:
29
+ batch_size: 12
30
+ wrap: True
31
+ train:
32
+ target: ldm.data.imagenet.ImageNetSRTrain
33
+ params:
34
+ size: 256
35
+ degradation: pil_nearest
36
+ validation:
37
+ target: ldm.data.imagenet.ImageNetSRValidation
38
+ params:
39
+ size: 256
40
+ degradation: pil_nearest
41
+
42
+ lightning:
43
+ callbacks:
44
+ image_logger:
45
+ target: main.ImageLogger
46
+ params:
47
+ batch_frequency: 1000
48
+ max_images: 8
49
+ increase_log_steps: True
50
+
51
+ trainer:
52
+ benchmark: True
53
+ accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_64x64x3.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 3
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 3
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [ ]
24
+ dropout: 0.0
25
+
26
+
27
+ data:
28
+ target: main.DataModuleFromConfig
29
+ params:
30
+ batch_size: 12
31
+ wrap: True
32
+ train:
33
+ target: ldm.data.imagenet.ImageNetSRTrain
34
+ params:
35
+ size: 256
36
+ degradation: pil_nearest
37
+ validation:
38
+ target: ldm.data.imagenet.ImageNetSRValidation
39
+ params:
40
+ size: 256
41
+ degradation: pil_nearest
42
+
43
+ lightning:
44
+ callbacks:
45
+ image_logger:
46
+ target: main.ImageLogger
47
+ params:
48
+ batch_frequency: 1000
49
+ max_images: 8
50
+ increase_log_steps: True
51
+
52
+ trainer:
53
+ benchmark: True
54
+ accumulate_grad_batches: 2
configs/autoencoder/autoencoder_kl_8x8x64.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: ldm.models.autoencoder.AutoencoderKL
4
+ params:
5
+ monitor: "val/rec_loss"
6
+ embed_dim: 64
7
+ lossconfig:
8
+ target: ldm.modules.losses.LPIPSWithDiscriminator
9
+ params:
10
+ disc_start: 50001
11
+ kl_weight: 0.000001
12
+ disc_weight: 0.5
13
+
14
+ ddconfig:
15
+ double_z: True
16
+ z_channels: 64
17
+ resolution: 256
18
+ in_channels: 3
19
+ out_ch: 3
20
+ ch: 128
21
+ ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22
+ num_res_blocks: 2
23
+ attn_resolutions: [16,8]
24
+ dropout: 0.0
25
+
26
+ data:
27
+ target: main.DataModuleFromConfig
28
+ params:
29
+ batch_size: 12
30
+ wrap: True
31
+ train:
32
+ target: ldm.data.imagenet.ImageNetSRTrain
33
+ params:
34
+ size: 256
35
+ degradation: pil_nearest
36
+ validation:
37
+ target: ldm.data.imagenet.ImageNetSRValidation
38
+ params:
39
+ size: 256
40
+ degradation: pil_nearest
41
+
42
+ lightning:
43
+ callbacks:
44
+ image_logger:
45
+ target: main.ImageLogger
46
+ params:
47
+ batch_frequency: 1000
48
+ max_images: 8
49
+ increase_log_steps: True
50
+
51
+ trainer:
52
+ benchmark: True
53
+ accumulate_grad_batches: 2
configs/latent-diffusion/celebahq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+
15
+ unet_config:
16
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ image_size: 64
19
+ in_channels: 3
20
+ out_channels: 3
21
+ model_channels: 224
22
+ attention_resolutions:
23
+ # note: this isn\t actually the resolution but
24
+ # the downsampling factor, i.e. this corresnponds to
25
+ # attention on spatial resolution 8,16,32, as the
26
+ # spatial reolution of the latents is 64 for f4
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 4
36
+ num_head_channels: 32
37
+ first_stage_config:
38
+ target: ldm.models.autoencoder.VQModelInterface
39
+ params:
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43
+ ddconfig:
44
+ double_z: false
45
+ z_channels: 3
46
+ resolution: 256
47
+ in_channels: 3
48
+ out_ch: 3
49
+ ch: 128
50
+ ch_mult:
51
+ - 1
52
+ - 2
53
+ - 4
54
+ num_res_blocks: 2
55
+ attn_resolutions: []
56
+ dropout: 0.0
57
+ lossconfig:
58
+ target: torch.nn.Identity
59
+ cond_stage_config: __is_unconditional__
60
+ data:
61
+ target: main.DataModuleFromConfig
62
+ params:
63
+ batch_size: 48
64
+ num_workers: 5
65
+ wrap: false
66
+ train:
67
+ target: taming.data.faceshq.CelebAHQTrain
68
+ params:
69
+ size: 256
70
+ validation:
71
+ target: taming.data.faceshq.CelebAHQValidation
72
+ params:
73
+ size: 256
74
+
75
+
76
+ lightning:
77
+ callbacks:
78
+ image_logger:
79
+ target: main.ImageLogger
80
+ params:
81
+ batch_frequency: 5000
82
+ max_images: 8
83
+ increase_log_steps: False
84
+
85
+ trainer:
86
+ benchmark: True
configs/latent-diffusion/cin-ldm-vq-f8.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ unet_config:
18
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19
+ params:
20
+ image_size: 32
21
+ in_channels: 4
22
+ out_channels: 4
23
+ model_channels: 256
24
+ attention_resolutions:
25
+ #note: this isn\t actually the resolution but
26
+ # the downsampling factor, i.e. this corresnponds to
27
+ # attention on spatial resolution 8,16,32, as the
28
+ # spatial reolution of the latents is 32 for f8
29
+ - 4
30
+ - 2
31
+ - 1
32
+ num_res_blocks: 2
33
+ channel_mult:
34
+ - 1
35
+ - 2
36
+ - 4
37
+ num_head_channels: 32
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 512
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 4
45
+ n_embed: 16384
46
+ ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47
+ ddconfig:
48
+ double_z: false
49
+ z_channels: 4
50
+ resolution: 256
51
+ in_channels: 3
52
+ out_ch: 3
53
+ ch: 128
54
+ ch_mult:
55
+ - 1
56
+ - 2
57
+ - 2
58
+ - 4
59
+ num_res_blocks: 2
60
+ attn_resolutions:
61
+ - 32
62
+ dropout: 0.0
63
+ lossconfig:
64
+ target: torch.nn.Identity
65
+ cond_stage_config:
66
+ target: ldm.modules.encoders.modules.ClassEmbedder
67
+ params:
68
+ embed_dim: 512
69
+ key: class_label
70
+ data:
71
+ target: main.DataModuleFromConfig
72
+ params:
73
+ batch_size: 64
74
+ num_workers: 12
75
+ wrap: false
76
+ train:
77
+ target: ldm.data.imagenet.ImageNetTrain
78
+ params:
79
+ config:
80
+ size: 256
81
+ validation:
82
+ target: ldm.data.imagenet.ImageNetValidation
83
+ params:
84
+ config:
85
+ size: 256
86
+
87
+
88
+ lightning:
89
+ callbacks:
90
+ image_logger:
91
+ target: main.ImageLogger
92
+ params:
93
+ batch_frequency: 5000
94
+ max_images: 8
95
+ increase_log_steps: False
96
+
97
+ trainer:
98
+ benchmark: True
configs/latent-diffusion/cin256-v2.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 64
13
+ channels: 3
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss
17
+ use_ema: False
18
+
19
+ unet_config:
20
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ image_size: 64
23
+ in_channels: 3
24
+ out_channels: 3
25
+ model_channels: 192
26
+ attention_resolutions:
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 5
36
+ num_heads: 1
37
+ use_spatial_transformer: true
38
+ transformer_depth: 1
39
+ context_dim: 512
40
+
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 3
45
+ n_embed: 8192
46
+ ddconfig:
47
+ double_z: false
48
+ z_channels: 3
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: []
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.ClassEmbedder
65
+ params:
66
+ n_classes: 1001
67
+ embed_dim: 512
68
+ key: class_label
configs/latent-diffusion/ffhq-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ embed_dim: 3
40
+ n_embed: 8192
41
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 42
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: taming.data.faceshq.FFHQTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: taming.data.faceshq.FFHQValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 2.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ image_size: 64
12
+ channels: 3
13
+ monitor: val/loss_simple_ema
14
+ unet_config:
15
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16
+ params:
17
+ image_size: 64
18
+ in_channels: 3
19
+ out_channels: 3
20
+ model_channels: 224
21
+ attention_resolutions:
22
+ # note: this isn\t actually the resolution but
23
+ # the downsampling factor, i.e. this corresnponds to
24
+ # attention on spatial resolution 8,16,32, as the
25
+ # spatial reolution of the latents is 64 for f4
26
+ - 8
27
+ - 4
28
+ - 2
29
+ num_res_blocks: 2
30
+ channel_mult:
31
+ - 1
32
+ - 2
33
+ - 3
34
+ - 4
35
+ num_head_channels: 32
36
+ first_stage_config:
37
+ target: ldm.models.autoencoder.VQModelInterface
38
+ params:
39
+ ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40
+ embed_dim: 3
41
+ n_embed: 8192
42
+ ddconfig:
43
+ double_z: false
44
+ z_channels: 3
45
+ resolution: 256
46
+ in_channels: 3
47
+ out_ch: 3
48
+ ch: 128
49
+ ch_mult:
50
+ - 1
51
+ - 2
52
+ - 4
53
+ num_res_blocks: 2
54
+ attn_resolutions: []
55
+ dropout: 0.0
56
+ lossconfig:
57
+ target: torch.nn.Identity
58
+ cond_stage_config: __is_unconditional__
59
+ data:
60
+ target: main.DataModuleFromConfig
61
+ params:
62
+ batch_size: 48
63
+ num_workers: 5
64
+ wrap: false
65
+ train:
66
+ target: ldm.data.lsun.LSUNBedroomsTrain
67
+ params:
68
+ size: 256
69
+ validation:
70
+ target: ldm.data.lsun.LSUNBedroomsValidation
71
+ params:
72
+ size: 256
73
+
74
+
75
+ lightning:
76
+ callbacks:
77
+ image_logger:
78
+ target: main.ImageLogger
79
+ params:
80
+ batch_frequency: 5000
81
+ max_images: 8
82
+ increase_log_steps: False
83
+
84
+ trainer:
85
+ benchmark: True
configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0155
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ loss_type: l1
11
+ first_stage_key: "image"
12
+ cond_stage_key: "image"
13
+ image_size: 32
14
+ channels: 4
15
+ cond_stage_trainable: False
16
+ concat_mode: False
17
+ scale_by_std: True
18
+ monitor: 'val/loss_simple_ema'
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [10000]
24
+ cycle_lengths: [10000000000000]
25
+ f_start: [1.e-6]
26
+ f_max: [1.]
27
+ f_min: [ 1.]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 192
36
+ attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39
+ num_heads: 8
40
+ use_scale_shift_norm: True
41
+ resblock_updown: True
42
+
43
+ first_stage_config:
44
+ target: ldm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ embed_dim: 4
47
+ monitor: "val/rec_loss"
48
+ ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49
+ ddconfig:
50
+ double_z: True
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57
+ num_res_blocks: 2
58
+ attn_resolutions: [ ]
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config: "__is_unconditional__"
64
+
65
+ data:
66
+ target: main.DataModuleFromConfig
67
+ params:
68
+ batch_size: 96
69
+ num_workers: 5
70
+ wrap: False
71
+ train:
72
+ target: ldm.data.lsun.LSUNChurchesTrain
73
+ params:
74
+ size: 256
75
+ validation:
76
+ target: ldm.data.lsun.LSUNChurchesValidation
77
+ params:
78
+ size: 256
79
+
80
+ lightning:
81
+ callbacks:
82
+ image_logger:
83
+ target: main.ImageLogger
84
+ params:
85
+ batch_frequency: 5000
86
+ max_images: 8
87
+ increase_log_steps: False
88
+
89
+
90
+ trainer:
91
+ benchmark: True
configs/latent-diffusion/txt2img-1p4B-eval.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ image_size: 32
24
+ in_channels: 4
25
+ out_channels: 4
26
+ model_channels: 320
27
+ attention_resolutions:
28
+ - 4
29
+ - 2
30
+ - 1
31
+ num_res_blocks: 2
32
+ channel_mult:
33
+ - 1
34
+ - 2
35
+ - 4
36
+ - 4
37
+ num_heads: 8
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 1280
41
+ use_checkpoint: true
42
+ legacy: False
43
+
44
+ first_stage_config:
45
+ target: ldm.models.autoencoder.AutoencoderKL
46
+ params:
47
+ embed_dim: 4
48
+ monitor: val/rec_loss
49
+ ddconfig:
50
+ double_z: true
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult:
57
+ - 1
58
+ - 2
59
+ - 4
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions: []
63
+ dropout: 0.0
64
+ lossconfig:
65
+ target: torch.nn.Identity
66
+
67
+ cond_stage_config:
68
+ target: ldm.modules.encoders.modules.BERTEmbedder
69
+ params:
70
+ n_embed: 1280
71
+ n_layer: 32
configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ personalization_config:
21
+ target: ldm.modules.embedding_manager.EmbeddingManager
22
+ params:
23
+ placeholder_strings: ["*"]
24
+ initializer_words: []
25
+
26
+ unet_config:
27
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
28
+ params:
29
+ image_size: 32
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ - 1
37
+ num_res_blocks: 2
38
+ channel_mult:
39
+ - 1
40
+ - 2
41
+ - 4
42
+ - 4
43
+ num_heads: 8
44
+ use_spatial_transformer: true
45
+ transformer_depth: 1
46
+ context_dim: 1280
47
+ use_checkpoint: true
48
+ legacy: False
49
+
50
+ first_stage_config:
51
+ target: ldm.models.autoencoder.AutoencoderKL
52
+ params:
53
+ embed_dim: 4
54
+ monitor: val/rec_loss
55
+ ddconfig:
56
+ double_z: true
57
+ z_channels: 4
58
+ resolution: 256
59
+ in_channels: 3
60
+ out_ch: 3
61
+ ch: 128
62
+ ch_mult:
63
+ - 1
64
+ - 2
65
+ - 4
66
+ - 4
67
+ num_res_blocks: 2
68
+ attn_resolutions: []
69
+ dropout: 0.0
70
+ lossconfig:
71
+ target: torch.nn.Identity
72
+
73
+ cond_stage_config:
74
+ target: ldm.modules.encoders.modules.BERTEmbedder
75
+ params:
76
+ n_embed: 1280
77
+ n_layer: 32
configs/latent-diffusion/txt2img-1p4B-finetune.yaml ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-3
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ embedding_reg_weight: 0.0
20
+
21
+ personalization_config:
22
+ target: ldm.modules.embedding_manager.EmbeddingManager
23
+ params:
24
+ placeholder_strings: ["*"]
25
+ initializer_words: ["sculpture"]
26
+ per_image_tokens: false
27
+ num_vectors_per_token: 1
28
+ progressive_words: False
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions:
38
+ - 4
39
+ - 2
40
+ - 1
41
+ num_res_blocks: 2
42
+ channel_mult:
43
+ - 1
44
+ - 2
45
+ - 4
46
+ - 4
47
+ num_heads: 8
48
+ use_spatial_transformer: true
49
+ transformer_depth: 1
50
+ context_dim: 1280
51
+ use_checkpoint: true
52
+ legacy: False
53
+
54
+ first_stage_config:
55
+ target: ldm.models.autoencoder.AutoencoderKL
56
+ params:
57
+ embed_dim: 4
58
+ monitor: val/rec_loss
59
+ ddconfig:
60
+ double_z: true
61
+ z_channels: 4
62
+ resolution: 256
63
+ in_channels: 3
64
+ out_ch: 3
65
+ ch: 128
66
+ ch_mult:
67
+ - 1
68
+ - 2
69
+ - 4
70
+ - 4
71
+ num_res_blocks: 2
72
+ attn_resolutions: []
73
+ dropout: 0.0
74
+ lossconfig:
75
+ target: torch.nn.Identity
76
+
77
+ cond_stage_config:
78
+ target: ldm.modules.encoders.modules.BERTEmbedder
79
+ params:
80
+ n_embed: 1280
81
+ n_layer: 32
82
+
83
+
84
+ data:
85
+ target: main.DataModuleFromConfig
86
+ params:
87
+ batch_size: 4
88
+ num_workers: 2
89
+ wrap: false
90
+ train:
91
+ target: ldm.data.personalized.PersonalizedBase
92
+ params:
93
+ size: 256
94
+ set: train
95
+ per_image_tokens: false
96
+ repeats: 100
97
+ validation:
98
+ target: ldm.data.personalized.PersonalizedBase
99
+ params:
100
+ size: 256
101
+ set: val
102
+ per_image_tokens: false
103
+ repeats: 10
104
+
105
+ lightning:
106
+ modelcheckpoint:
107
+ params:
108
+ every_n_train_steps: 500
109
+ callbacks:
110
+ image_logger:
111
+ target: main.ImageLogger
112
+ params:
113
+ batch_frequency: 500
114
+ max_images: 8
115
+ increase_log_steps: False
116
+
117
+ trainer:
118
+ benchmark: True
119
+ max_steps: 6100
configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-3
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ embedding_reg_weight: 0.0
20
+
21
+ personalization_config:
22
+ target: ldm.modules.embedding_manager.EmbeddingManager
23
+ params:
24
+ placeholder_strings: ["*"]
25
+ initializer_words: ["painting"]
26
+ per_image_tokens: false
27
+ num_vectors_per_token: 1
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions:
37
+ - 4
38
+ - 2
39
+ - 1
40
+ num_res_blocks: 2
41
+ channel_mult:
42
+ - 1
43
+ - 2
44
+ - 4
45
+ - 4
46
+ num_heads: 8
47
+ use_spatial_transformer: true
48
+ transformer_depth: 1
49
+ context_dim: 1280
50
+ use_checkpoint: true
51
+ legacy: False
52
+
53
+ first_stage_config:
54
+ target: ldm.models.autoencoder.AutoencoderKL
55
+ params:
56
+ embed_dim: 4
57
+ monitor: val/rec_loss
58
+ ddconfig:
59
+ double_z: true
60
+ z_channels: 4
61
+ resolution: 256
62
+ in_channels: 3
63
+ out_ch: 3
64
+ ch: 128
65
+ ch_mult:
66
+ - 1
67
+ - 2
68
+ - 4
69
+ - 4
70
+ num_res_blocks: 2
71
+ attn_resolutions: []
72
+ dropout: 0.0
73
+ lossconfig:
74
+ target: torch.nn.Identity
75
+
76
+ cond_stage_config:
77
+ target: ldm.modules.encoders.modules.BERTEmbedder
78
+ params:
79
+ n_embed: 1280
80
+ n_layer: 32
81
+
82
+
83
+ data:
84
+ target: main.DataModuleFromConfig
85
+ params:
86
+ batch_size: 4
87
+ num_workers: 4
88
+ wrap: false
89
+ train:
90
+ target: ldm.data.personalized_style.PersonalizedBase
91
+ params:
92
+ size: 256
93
+ set: train
94
+ per_image_tokens: false
95
+ repeats: 100
96
+ validation:
97
+ target: ldm.data.personalized_style.PersonalizedBase
98
+ params:
99
+ size: 256
100
+ set: val
101
+ per_image_tokens: false
102
+ repeats: 10
103
+
104
+ lightning:
105
+ modelcheckpoint:
106
+ params:
107
+ every_n_train_steps: 500
108
+ callbacks:
109
+ image_logger:
110
+ target: main.ImageLogger
111
+ params:
112
+ batch_frequency: 500
113
+ max_images: 8
114
+ increase_log_steps: False
115
+
116
+ trainer:
117
+ benchmark: True
configs/stable-diffusion/v1-finetune.yaml ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-03
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: true # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+ embedding_reg_weight: 0.0
20
+ unfreeze_model: False
21
+ model_lr: 0.0
22
+
23
+ personalization_config:
24
+ target: ldm.modules.embedding_manager.EmbeddingManager
25
+ params:
26
+ placeholder_strings: ["*"]
27
+ initializer_words: ["sculpture"]
28
+ per_image_tokens: false
29
+ num_vectors_per_token: 1
30
+ progressive_words: False
31
+
32
+ unet_config:
33
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34
+ params:
35
+ image_size: 32 # unused
36
+ in_channels: 4
37
+ out_channels: 4
38
+ model_channels: 320
39
+ attention_resolutions: [ 4, 2, 1 ]
40
+ num_res_blocks: 2
41
+ channel_mult: [ 1, 2, 4, 4 ]
42
+ num_heads: 8
43
+ use_spatial_transformer: True
44
+ transformer_depth: 1
45
+ context_dim: 768
46
+ use_checkpoint: True
47
+ legacy: False
48
+
49
+ first_stage_config:
50
+ target: ldm.models.autoencoder.AutoencoderKL
51
+ params:
52
+ embed_dim: 4
53
+ monitor: val/rec_loss
54
+ ddconfig:
55
+ double_z: true
56
+ z_channels: 4
57
+ resolution: 512
58
+ in_channels: 3
59
+ out_ch: 3
60
+ ch: 128
61
+ ch_mult:
62
+ - 1
63
+ - 2
64
+ - 4
65
+ - 4
66
+ num_res_blocks: 2
67
+ attn_resolutions: []
68
+ dropout: 0.0
69
+ lossconfig:
70
+ target: torch.nn.Identity
71
+
72
+ cond_stage_config:
73
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
74
+
75
+ data:
76
+ target: main.DataModuleFromConfig
77
+ params:
78
+ batch_size: 2
79
+ num_workers: 2
80
+ wrap: false
81
+ train:
82
+ target: ldm.data.personalized.PersonalizedBase
83
+ params:
84
+ size: 512
85
+ set: train
86
+ per_image_tokens: false
87
+ repeats: 100
88
+ validation:
89
+ target: ldm.data.personalized.PersonalizedBase
90
+ params:
91
+ size: 512
92
+ set: val
93
+ per_image_tokens: false
94
+ repeats: 10
95
+
96
+ lightning:
97
+ modelcheckpoint:
98
+ params:
99
+ every_n_train_steps: 500
100
+ callbacks:
101
+ image_logger:
102
+ target: main.ImageLogger
103
+ params:
104
+ batch_frequency: 500
105
+ max_images: 8
106
+ increase_log_steps: False
107
+
108
+ trainer:
109
+ benchmark: True
110
+ max_steps: 6100
configs/stable-diffusion/v1-finetune_unfrozen.yaml ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-06
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ reg_weight: 1.0
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: image
12
+ cond_stage_key: caption
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: true # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+ embedding_reg_weight: 0.0
21
+ unfreeze_model: True
22
+ model_lr: 1.0e-6
23
+
24
+ personalization_config:
25
+ target: ldm.modules.embedding_manager.EmbeddingManager
26
+ params:
27
+ placeholder_strings: ["*"]
28
+ initializer_words: ["sculpture"]
29
+ per_image_tokens: false
30
+ num_vectors_per_token: 1
31
+ progressive_words: False
32
+
33
+ unet_config:
34
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
35
+ params:
36
+ image_size: 32 # unused
37
+ in_channels: 4
38
+ out_channels: 4
39
+ model_channels: 320
40
+ attention_resolutions: [ 4, 2, 1 ]
41
+ num_res_blocks: 2
42
+ channel_mult: [ 1, 2, 4, 4 ]
43
+ num_heads: 8
44
+ use_spatial_transformer: True
45
+ transformer_depth: 1
46
+ context_dim: 768
47
+ use_checkpoint: True
48
+ legacy: False
49
+
50
+ first_stage_config:
51
+ target: ldm.models.autoencoder.AutoencoderKL
52
+ params:
53
+ embed_dim: 4
54
+ monitor: val/rec_loss
55
+ ddconfig:
56
+ double_z: true
57
+ z_channels: 4
58
+ resolution: 512
59
+ in_channels: 3
60
+ out_ch: 3
61
+ ch: 128
62
+ ch_mult:
63
+ - 1
64
+ - 2
65
+ - 4
66
+ - 4
67
+ num_res_blocks: 2
68
+ attn_resolutions: []
69
+ dropout: 0.0
70
+ lossconfig:
71
+ target: torch.nn.Identity
72
+
73
+ cond_stage_config:
74
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
75
+
76
+ data:
77
+ target: main.DataModuleFromConfig
78
+ params:
79
+ batch_size: 1
80
+ num_workers: 2
81
+ wrap: false
82
+ train:
83
+ target: ldm.data.personalized.PersonalizedBase
84
+ params:
85
+ size: 512
86
+ set: train
87
+ per_image_tokens: false
88
+ repeats: 100
89
+ reg:
90
+ target: ldm.data.personalized.PersonalizedBase
91
+ params:
92
+ size: 512
93
+ set: train
94
+ reg: true
95
+ per_image_tokens: false
96
+ repeats: 10
97
+
98
+ validation:
99
+ target: ldm.data.personalized.PersonalizedBase
100
+ params:
101
+ size: 512
102
+ set: val
103
+ per_image_tokens: false
104
+ repeats: 10
105
+
106
+ lightning:
107
+ modelcheckpoint:
108
+ params:
109
+ every_n_train_steps: 500
110
+ callbacks:
111
+ image_logger:
112
+ target: main.ImageLogger
113
+ params:
114
+ batch_frequency: 500
115
+ max_images: 8
116
+ increase_log_steps: False
117
+
118
+ trainer:
119
+ benchmark: True
120
+ max_steps: 800
configs/stable-diffusion/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ personalization_config:
21
+ target: ldm.modules.embedding_manager.EmbeddingManager
22
+ params:
23
+ placeholder_strings: ["*"]
24
+ initializer_words: ["sculpture"]
25
+ per_image_tokens: false
26
+ num_vectors_per_token: 1
27
+ progressive_words: False
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
environment.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ldm
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.10
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.10.2
10
+ - torchvision=0.11.3
11
+ - numpy=1.22.3
12
+ - pip:
13
+ - albumentations==1.1.0
14
+ - opencv-python==4.2.0.34
15
+ - pudb==2019.2
16
+ - imageio==2.14.1
17
+ - imageio-ffmpeg==0.4.7
18
+ - pytorch-lightning==1.5.9
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - streamlit>=0.73.1
22
+ - setuptools==59.5.0
23
+ - pillow==9.0.1
24
+ - einops==0.4.1
25
+ - torch-fidelity==0.3.0
26
+ - transformers==4.18.0
27
+ - torchmetrics==0.6.0
28
+ - kornia==0.6
29
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31
+ - -e .
evaluation/__pycache__/clip_eval.cpython-36.pyc ADDED
Binary file (3.94 kB). View file
 
evaluation/__pycache__/clip_eval.cpython-38.pyc ADDED
Binary file (3.94 kB). View file
 
evaluation/clip_eval.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import clip
2
+ import torch
3
+ from torchvision import transforms
4
+
5
+ from ldm.models.diffusion.ddim import DDIMSampler
6
+
7
+ class CLIPEvaluator(object):
8
+ def __init__(self, device, clip_model='ViT-B/32') -> None:
9
+ self.device = device
10
+ self.model, clip_preprocess = clip.load(clip_model, device=self.device)
11
+
12
+ self.clip_preprocess = clip_preprocess
13
+
14
+ self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1].
15
+ clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
16
+ clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
17
+
18
+ def tokenize(self, strings: list):
19
+ return clip.tokenize(strings).to(self.device)
20
+
21
+ @torch.no_grad()
22
+ def encode_text(self, tokens: list) -> torch.Tensor:
23
+ return self.model.encode_text(tokens)
24
+
25
+ @torch.no_grad()
26
+ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
27
+ images = self.preprocess(images).to(self.device)
28
+ return self.model.encode_image(images)
29
+
30
+ def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:
31
+
32
+ tokens = clip.tokenize(text).to(self.device)
33
+
34
+ text_features = self.encode_text(tokens).detach()
35
+
36
+ if norm:
37
+ text_features /= text_features.norm(dim=-1, keepdim=True)
38
+
39
+ return text_features
40
+
41
+ def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
42
+ image_features = self.encode_images(img)
43
+
44
+ if norm:
45
+ image_features /= image_features.clone().norm(dim=-1, keepdim=True)
46
+
47
+ return image_features
48
+
49
+ def img_to_img_similarity(self, src_images, generated_images):
50
+ src_img_features = self.get_image_features(src_images)
51
+ gen_img_features = self.get_image_features(generated_images)
52
+
53
+ return (src_img_features @ gen_img_features.T).mean()
54
+
55
+ def txt_to_img_similarity(self, text, generated_images):
56
+ text_features = self.get_text_features(text)
57
+ gen_img_features = self.get_image_features(generated_images)
58
+
59
+ return (text_features @ gen_img_features.T).mean()
60
+
61
+
62
+ class LDMCLIPEvaluator(CLIPEvaluator):
63
+ def __init__(self, device, clip_model='ViT-B/32') -> None:
64
+ super().__init__(device, clip_model)
65
+
66
+ def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50):
67
+
68
+ sampler = DDIMSampler(ldm_model)
69
+
70
+ samples_per_batch = 8
71
+ n_batches = n_samples // samples_per_batch
72
+
73
+ # generate samples
74
+ all_samples=list()
75
+ with torch.no_grad():
76
+ with ldm_model.ema_scope():
77
+ uc = ldm_model.get_learned_conditioning(samples_per_batch * [""])
78
+
79
+ for batch in range(n_batches):
80
+ c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text])
81
+ shape = [4, 256//8, 256//8]
82
+ samples_ddim, _ = sampler.sample(S=n_steps,
83
+ conditioning=c,
84
+ batch_size=samples_per_batch,
85
+ shape=shape,
86
+ verbose=False,
87
+ unconditional_guidance_scale=5.0,
88
+ unconditional_conditioning=uc,
89
+ eta=0.0)
90
+
91
+ x_samples_ddim = ldm_model.decode_first_stage(samples_ddim)
92
+ x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
93
+
94
+ all_samples.append(x_samples_ddim)
95
+
96
+ all_samples = torch.cat(all_samples, axis=0)
97
+
98
+ sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples)
99
+ sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples)
100
+
101
+ return sim_samples_to_img, sim_samples_to_text
102
+
103
+
104
+ class ImageDirEvaluator(CLIPEvaluator):
105
+ def __init__(self, device, clip_model='ViT-B/32') -> None:
106
+ super().__init__(device, clip_model)
107
+
108
+ def evaluate(self, gen_samples, src_images, target_text):
109
+
110
+ sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples)
111
+ sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples)
112
+
113
+ return sim_samples_to_img, sim_samples_to_text
ldm/__pycache__/util.cpython-36.pyc ADDED
Binary file (3.72 kB). View file
 
ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (6.12 kB). View file
 
ldm/data/__init__.py ADDED
File without changes
ldm/data/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (177 Bytes). View file
 
ldm/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (181 Bytes). View file
 
ldm/data/__pycache__/base.cpython-36.pyc ADDED
Binary file (1.26 kB). View file
 
ldm/data/__pycache__/base.cpython-38.pyc ADDED
Binary file (1.29 kB). View file
 
ldm/data/__pycache__/personalized.cpython-36.pyc ADDED
Binary file (4.92 kB). View file
 
ldm/data/__pycache__/personalized.cpython-38.pyc ADDED
Binary file (5.82 kB). View file
 
ldm/data/__pycache__/personalized_compose.cpython-38.pyc ADDED
Binary file (6.06 kB). View file
 
ldm/data/__pycache__/personalized_detailed_text.cpython-36.pyc ADDED
Binary file (5.52 kB). View file
 
ldm/data/__pycache__/personalized_style.cpython-36.pyc ADDED
Binary file (4.69 kB). View file
 
ldm/data/__pycache__/personalized_style.cpython-38.pyc ADDED
Binary file (4.78 kB). View file
 
ldm/data/base.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
+
4
+
5
+ class Txt2ImgIterableBaseDataset(IterableDataset):
6
+ '''
7
+ Define an interface to make the IterableDatasets for text2img data chainable
8
+ '''
9
+ def __init__(self, num_records=0, valid_ids=None, size=256):
10
+ super().__init__()
11
+ self.num_records = num_records
12
+ self.valid_ids = valid_ids
13
+ self.sample_ids = valid_ids
14
+ self.size = size
15
+
16
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17
+
18
+ def __len__(self):
19
+ return self.num_records
20
+
21
+ @abstractmethod
22
+ def __iter__(self):
23
+ pass
ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.load(f)
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
+ class LSUNBase(Dataset):
10
+ def __init__(self,
11
+ txt_file,
12
+ data_root,
13
+ size=None,
14
+ interpolation="bicubic",
15
+ flip_p=0.5
16
+ ):
17
+ self.data_paths = txt_file
18
+ self.data_root = data_root
19
+ with open(self.data_paths, "r") as f:
20
+ self.image_paths = f.read().splitlines()
21
+ self._length = len(self.image_paths)
22
+ self.labels = {
23
+ "relative_file_path_": [l for l in self.image_paths],
24
+ "file_path_": [os.path.join(self.data_root, l)
25
+ for l in self.image_paths],
26
+ }
27
+
28
+ self.size = size
29
+ self.interpolation = {"linear": PIL.Image.LINEAR,
30
+ "bilinear": PIL.Image.BILINEAR,
31
+ "bicubic": PIL.Image.BICUBIC,
32
+ "lanczos": PIL.Image.LANCZOS,
33
+ }[interpolation]
34
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
+
36
+ def __len__(self):
37
+ return self._length
38
+
39
+ def __getitem__(self, i):
40
+ example = dict((k, self.labels[k][i]) for k in self.labels)
41
+ image = Image.open(example["file_path_"])
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # default to score-sde preprocessing
46
+ img = np.array(image).astype(np.uint8)
47
+ crop = min(img.shape[0], img.shape[1])
48
+ h, w, = img.shape[0], img.shape[1]
49
+ img = img[(h - crop) // 2:(h + crop) // 2,
50
+ (w - crop) // 2:(w + crop) // 2]
51
+
52
+ image = Image.fromarray(img)
53
+ if self.size is not None:
54
+ image = image.resize((self.size, self.size), resample=self.interpolation)
55
+
56
+ image = self.flip(image)
57
+ image = np.array(image).astype(np.uint8)
58
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
+ return example
60
+
61
+
62
+ class LSUNChurchesTrain(LSUNBase):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
+ class LSUNChurchesValidation(LSUNBase):
68
+ def __init__(self, flip_p=0., **kwargs):
69
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
+ flip_p=flip_p, **kwargs)
71
+
72
+
73
+ class LSUNBedroomsTrain(LSUNBase):
74
+ def __init__(self, **kwargs):
75
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
+ class LSUNBedroomsValidation(LSUNBase):
79
+ def __init__(self, flip_p=0.0, **kwargs):
80
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
+ flip_p=flip_p, **kwargs)
82
+
83
+
84
+ class LSUNCatsTrain(LSUNBase):
85
+ def __init__(self, **kwargs):
86
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
+ class LSUNCatsValidation(LSUNBase):
90
+ def __init__(self, flip_p=0., **kwargs):
91
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
+ flip_p=flip_p, **kwargs)
ldm/data/personalized.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+ import random
9
+
10
+ training_templates_smallest = [
11
+ 'photo of a sks {}',
12
+ ]
13
+
14
+ reg_templates_smallest = [
15
+ 'photo of a {}',
16
+ ]
17
+
18
+ imagenet_templates_small = [
19
+ 'a photo of a {}',
20
+ 'a rendering of a {}',
21
+ 'a cropped photo of the {}',
22
+ 'the photo of a {}',
23
+ 'a photo of a clean {}',
24
+ 'a photo of a dirty {}',
25
+ 'a dark photo of the {}',
26
+ 'a photo of my {}',
27
+ 'a photo of the cool {}',
28
+ 'a close-up photo of a {}',
29
+ 'a bright photo of the {}',
30
+ 'a cropped photo of a {}',
31
+ 'a photo of the {}',
32
+ 'a good photo of the {}',
33
+ 'a photo of one {}',
34
+ 'a close-up photo of the {}',
35
+ 'a rendition of the {}',
36
+ 'a photo of the clean {}',
37
+ 'a rendition of a {}',
38
+ 'a photo of a nice {}',
39
+ 'a good photo of a {}',
40
+ 'a photo of the nice {}',
41
+ 'a photo of the small {}',
42
+ 'a photo of the weird {}',
43
+ 'a photo of the large {}',
44
+ 'a photo of a cool {}',
45
+ 'a photo of a small {}',
46
+ 'an illustration of a {}',
47
+ 'a rendering of a {}',
48
+ 'a cropped photo of the {}',
49
+ 'the photo of a {}',
50
+ 'an illustration of a clean {}',
51
+ 'an illustration of a dirty {}',
52
+ 'a dark photo of the {}',
53
+ 'an illustration of my {}',
54
+ 'an illustration of the cool {}',
55
+ 'a close-up photo of a {}',
56
+ 'a bright photo of the {}',
57
+ 'a cropped photo of a {}',
58
+ 'an illustration of the {}',
59
+ 'a good photo of the {}',
60
+ 'an illustration of one {}',
61
+ 'a close-up photo of the {}',
62
+ 'a rendition of the {}',
63
+ 'an illustration of the clean {}',
64
+ 'a rendition of a {}',
65
+ 'an illustration of a nice {}',
66
+ 'a good photo of a {}',
67
+ 'an illustration of the nice {}',
68
+ 'an illustration of the small {}',
69
+ 'an illustration of the weird {}',
70
+ 'an illustration of the large {}',
71
+ 'an illustration of a cool {}',
72
+ 'an illustration of a small {}',
73
+ 'a depiction of a {}',
74
+ 'a rendering of a {}',
75
+ 'a cropped photo of the {}',
76
+ 'the photo of a {}',
77
+ 'a depiction of a clean {}',
78
+ 'a depiction of a dirty {}',
79
+ 'a dark photo of the {}',
80
+ 'a depiction of my {}',
81
+ 'a depiction of the cool {}',
82
+ 'a close-up photo of a {}',
83
+ 'a bright photo of the {}',
84
+ 'a cropped photo of a {}',
85
+ 'a depiction of the {}',
86
+ 'a good photo of the {}',
87
+ 'a depiction of one {}',
88
+ 'a close-up photo of the {}',
89
+ 'a rendition of the {}',
90
+ 'a depiction of the clean {}',
91
+ 'a rendition of a {}',
92
+ 'a depiction of a nice {}',
93
+ 'a good photo of a {}',
94
+ 'a depiction of the nice {}',
95
+ 'a depiction of the small {}',
96
+ 'a depiction of the weird {}',
97
+ 'a depiction of the large {}',
98
+ 'a depiction of a cool {}',
99
+ 'a depiction of a small {}',
100
+ ]
101
+
102
+ imagenet_dual_templates_small = [
103
+ 'a photo of a {} with {}',
104
+ 'a rendering of a {} with {}',
105
+ 'a cropped photo of the {} with {}',
106
+ 'the photo of a {} with {}',
107
+ 'a photo of a clean {} with {}',
108
+ 'a photo of a dirty {} with {}',
109
+ 'a dark photo of the {} with {}',
110
+ 'a photo of my {} with {}',
111
+ 'a photo of the cool {} with {}',
112
+ 'a close-up photo of a {} with {}',
113
+ 'a bright photo of the {} with {}',
114
+ 'a cropped photo of a {} with {}',
115
+ 'a photo of the {} with {}',
116
+ 'a good photo of the {} with {}',
117
+ 'a photo of one {} with {}',
118
+ 'a close-up photo of the {} with {}',
119
+ 'a rendition of the {} with {}',
120
+ 'a photo of the clean {} with {}',
121
+ 'a rendition of a {} with {}',
122
+ 'a photo of a nice {} with {}',
123
+ 'a good photo of a {} with {}',
124
+ 'a photo of the nice {} with {}',
125
+ 'a photo of the small {} with {}',
126
+ 'a photo of the weird {} with {}',
127
+ 'a photo of the large {} with {}',
128
+ 'a photo of a cool {} with {}',
129
+ 'a photo of a small {} with {}',
130
+ ]
131
+
132
+ per_img_token_list = [
133
+ 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
134
+ ]
135
+
136
+ class PersonalizedBase(Dataset):
137
+ def __init__(self,
138
+ data_root,
139
+ size=None,
140
+ repeats=100,
141
+ interpolation="bicubic",
142
+ flip_p=0.5,
143
+ set="train",
144
+ placeholder_token="dog",
145
+ per_image_tokens=False,
146
+ center_crop=False,
147
+ mixing_prob=0.25,
148
+ coarse_class_text=None,
149
+ reg = False
150
+ ):
151
+
152
+ self.data_root = data_root
153
+
154
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
155
+
156
+ # self._length = len(self.image_paths)
157
+ self.num_images = len(self.image_paths)
158
+ self._length = self.num_images
159
+
160
+ self.placeholder_token = placeholder_token
161
+
162
+ self.per_image_tokens = per_image_tokens
163
+ self.center_crop = center_crop
164
+ self.mixing_prob = mixing_prob
165
+
166
+ self.coarse_class_text = coarse_class_text
167
+
168
+ if per_image_tokens:
169
+ assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
170
+
171
+ if set == "train":
172
+ self._length = self.num_images * repeats
173
+
174
+ self.size = size
175
+ self.interpolation = {"linear": PIL.Image.LINEAR,
176
+ "bilinear": PIL.Image.BILINEAR,
177
+ "bicubic": PIL.Image.BICUBIC,
178
+ "lanczos": PIL.Image.LANCZOS,
179
+ }[interpolation]
180
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
181
+ self.reg = reg
182
+
183
+ def __len__(self):
184
+ return self._length
185
+
186
+ def __getitem__(self, i):
187
+ example = {}
188
+ image = Image.open(self.image_paths[i % self.num_images])
189
+
190
+ if not image.mode == "RGB":
191
+ image = image.convert("RGB")
192
+
193
+ placeholder_string = self.placeholder_token
194
+ if self.coarse_class_text:
195
+ placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
196
+
197
+ if not self.reg:
198
+ text = random.choice(training_templates_smallest).format(placeholder_string)
199
+ else:
200
+ text = random.choice(reg_templates_smallest).format(placeholder_string)
201
+
202
+ example["caption"] = text
203
+
204
+ # default to score-sde preprocessing
205
+ img = np.array(image).astype(np.uint8)
206
+
207
+ if self.center_crop:
208
+ crop = min(img.shape[0], img.shape[1])
209
+ h, w, = img.shape[0], img.shape[1]
210
+ img = img[(h - crop) // 2:(h + crop) // 2,
211
+ (w - crop) // 2:(w + crop) // 2]
212
+
213
+ image = Image.fromarray(img)
214
+ if self.size is not None:
215
+ image = image.resize((self.size, self.size), resample=self.interpolation)
216
+
217
+ image = self.flip(image)
218
+ image = np.array(image).astype(np.uint8)
219
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
220
+ return example
ldm/data/personalized_style.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+ import random
9
+
10
+ imagenet_templates_small = [
11
+ 'a painting in the style of {}',
12
+ 'a rendering in the style of {}',
13
+ 'a cropped painting in the style of {}',
14
+ 'the painting in the style of {}',
15
+ 'a clean painting in the style of {}',
16
+ 'a dirty painting in the style of {}',
17
+ 'a dark painting in the style of {}',
18
+ 'a picture in the style of {}',
19
+ 'a cool painting in the style of {}',
20
+ 'a close-up painting in the style of {}',
21
+ 'a bright painting in the style of {}',
22
+ 'a cropped painting in the style of {}',
23
+ 'a good painting in the style of {}',
24
+ 'a close-up painting in the style of {}',
25
+ 'a rendition in the style of {}',
26
+ 'a nice painting in the style of {}',
27
+ 'a small painting in the style of {}',
28
+ 'a weird painting in the style of {}',
29
+ 'a large painting in the style of {}',
30
+ ]
31
+
32
+ imagenet_dual_templates_small = [
33
+ 'a painting in the style of {} with {}',
34
+ 'a rendering in the style of {} with {}',
35
+ 'a cropped painting in the style of {} with {}',
36
+ 'the painting in the style of {} with {}',
37
+ 'a clean painting in the style of {} with {}',
38
+ 'a dirty painting in the style of {} with {}',
39
+ 'a dark painting in the style of {} with {}',
40
+ 'a cool painting in the style of {} with {}',
41
+ 'a close-up painting in the style of {} with {}',
42
+ 'a bright painting in the style of {} with {}',
43
+ 'a cropped painting in the style of {} with {}',
44
+ 'a good painting in the style of {} with {}',
45
+ 'a painting of one {} in the style of {}',
46
+ 'a nice painting in the style of {} with {}',
47
+ 'a small painting in the style of {} with {}',
48
+ 'a weird painting in the style of {} with {}',
49
+ 'a large painting in the style of {} with {}',
50
+ ]
51
+
52
+ per_img_token_list = [
53
+ 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
54
+ ]
55
+
56
+ class PersonalizedBase(Dataset):
57
+ def __init__(self,
58
+ data_root,
59
+ size=None,
60
+ repeats=100,
61
+ interpolation="bicubic",
62
+ flip_p=0.5,
63
+ set="train",
64
+ placeholder_token="*",
65
+ per_image_tokens=False,
66
+ center_crop=False,
67
+ ):
68
+
69
+ self.data_root = data_root
70
+
71
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
72
+
73
+ # self._length = len(self.image_paths)
74
+ self.num_images = len(self.image_paths)
75
+ self._length = self.num_images
76
+
77
+ self.placeholder_token = placeholder_token
78
+
79
+ self.per_image_tokens = per_image_tokens
80
+ self.center_crop = center_crop
81
+
82
+ if per_image_tokens:
83
+ assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
84
+
85
+ if set == "train":
86
+ self._length = self.num_images * repeats
87
+
88
+ self.size = size
89
+ self.interpolation = {"linear": PIL.Image.LINEAR,
90
+ "bilinear": PIL.Image.BILINEAR,
91
+ "bicubic": PIL.Image.BICUBIC,
92
+ "lanczos": PIL.Image.LANCZOS,
93
+ }[interpolation]
94
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
95
+
96
+ def __len__(self):
97
+ return self._length
98
+
99
+ def __getitem__(self, i):
100
+ example = {}
101
+ image = Image.open(self.image_paths[i % self.num_images])
102
+
103
+ if not image.mode == "RGB":
104
+ image = image.convert("RGB")
105
+
106
+ if self.per_image_tokens and np.random.uniform() < 0.25:
107
+ text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
108
+ else:
109
+ text = random.choice(imagenet_templates_small).format(self.placeholder_token)
110
+
111
+ example["caption"] = text
112
+
113
+ # default to score-sde preprocessing
114
+ img = np.array(image).astype(np.uint8)
115
+
116
+ if self.center_crop:
117
+ crop = min(img.shape[0], img.shape[1])
118
+ h, w, = img.shape[0], img.shape[1]
119
+ img = img[(h - crop) // 2:(h + crop) // 2,
120
+ (w - crop) // 2:(w + crop) // 2]
121
+
122
+ image = Image.fromarray(img)
123
+ if self.size is not None:
124
+ image = image.resize((self.size, self.size), resample=self.interpolation)
125
+
126
+ image = self.flip(image)
127
+ image = np.array(image).astype(np.uint8)
128
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
129
+ return example
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
ldm/models/__pycache__/autoencoder.cpython-36.pyc ADDED
Binary file (13.7 kB). View file
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (13.6 kB). View file
 
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+
264
+ class VQModelInterface(VQModel):
265
+ def __init__(self, embed_dim, *args, **kwargs):
266
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
+ self.embed_dim = embed_dim
268
+
269
+ def encode(self, x):
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ return h
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,
295
+ ):
296
+ super().__init__()
297
+ self.image_key = image_key
298
+ self.encoder = Encoder(**ddconfig)
299
+ self.decoder = Decoder(**ddconfig)
300
+ self.loss = instantiate_from_config(lossconfig)
301
+ assert ddconfig["double_z"]
302
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
+ self.embed_dim = embed_dim
305
+ if colorize_nlabels is not None:
306
+ assert type(colorize_nlabels)==int
307
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
+ if monitor is not None:
309
+ self.monitor = monitor
310
+ if ckpt_path is not None:
311
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
+
313
+ def init_from_ckpt(self, path, ignore_keys=list()):
314
+ sd = torch.load(path, map_location="cpu")["state_dict"]
315
+ keys = list(sd.keys())
316
+ for k in keys:
317
+ for ik in ignore_keys:
318
+ if k.startswith(ik):
319
+ print("Deleting key {} from state_dict.".format(k))
320
+ del sd[k]
321
+ self.load_state_dict(sd, strict=False)
322
+ print(f"Restored from {path}")
323
+
324
+ def encode(self, x):
325
+ h = self.encoder(x)
326
+ moments = self.quant_conv(h)
327
+ posterior = DiagonalGaussianDistribution(moments)
328
+ return posterior
329
+
330
+ def decode(self, z):
331
+ z = self.post_quant_conv(z)
332
+ dec = self.decoder(z)
333
+ return dec
334
+
335
+ def forward(self, input, sample_posterior=True):
336
+ posterior = self.encode(input)
337
+ if sample_posterior:
338
+ z = posterior.sample()
339
+ else:
340
+ z = posterior.mode()
341
+ dec = self.decode(z)
342
+ return dec, posterior
343
+
344
+ def get_input(self, batch, k):
345
+ x = batch[k]
346
+ if len(x.shape) == 3:
347
+ x = x[..., None]
348
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
+ return x
350
+
351
+ def training_step(self, batch, batch_idx, optimizer_idx):
352
+ inputs = self.get_input(batch, self.image_key)
353
+ reconstructions, posterior = self(inputs)
354
+
355
+ if optimizer_idx == 0:
356
+ # train encoder+decoder+logvar
357
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
+ last_layer=self.get_last_layer(), split="train")
359
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
+ return aeloss
362
+
363
+ if optimizer_idx == 1:
364
+ # train the discriminator
365
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
+ last_layer=self.get_last_layer(), split="train")
367
+
368
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
+ return discloss
371
+
372
+ def validation_step(self, batch, batch_idx):
373
+ inputs = self.get_input(batch, self.image_key)
374
+ reconstructions, posterior = self(inputs)
375
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
+ last_layer=self.get_last_layer(), split="val")
377
+
378
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
+ last_layer=self.get_last_layer(), split="val")
380
+
381
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
+ self.log_dict(log_dict_ae)
383
+ self.log_dict(log_dict_disc)
384
+ return self.log_dict
385
+
386
+ def configure_optimizers(self):
387
+ lr = self.learning_rate
388
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
+ list(self.decoder.parameters())+
390
+ list(self.quant_conv.parameters())+
391
+ list(self.post_quant_conv.parameters()),
392
+ lr=lr, betas=(0.5, 0.9))
393
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
+ lr=lr, betas=(0.5, 0.9))
395
+ return [opt_ae, opt_disc], []
396
+
397
+ def get_last_layer(self):
398
+ return self.decoder.conv_out.weight
399
+
400
+ @torch.no_grad()
401
+ def log_images(self, batch, only_inputs=False, **kwargs):
402
+ log = dict()
403
+ x = self.get_input(batch, self.image_key)
404
+ x = x.to(self.device)
405
+ if not only_inputs:
406
+ xrec, posterior = self(x)
407
+ if x.shape[1] > 3:
408
+ # colorize with random projection
409
+ assert xrec.shape[1] > 3
410
+ x = self.to_rgb(x)
411
+ xrec = self.to_rgb(xrec)
412
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
+ log["reconstructions"] = xrec
414
+ log["inputs"] = x
415
+ return log
416
+
417
+ def to_rgb(self, x):
418
+ assert self.image_key == "segmentation"
419
+ if not hasattr(self, "colorize"):
420
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
+ x = F.conv2d(x, weight=self.colorize)
422
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
+ return x
424
+
425
+
426
+ class IdentityFirstStage(torch.nn.Module):
427
+ def __init__(self, *args, vq_interface=False, **kwargs):
428
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
+ super().__init__()
430
+
431
+ def encode(self, x, *args, **kwargs):
432
+ return x
433
+
434
+ def decode(self, x, *args, **kwargs):
435
+ return x
436
+
437
+ def quantize(self, x, *args, **kwargs):
438
+ if self.vq_interface:
439
+ return x, None, [None, None, None]
440
+ return x
441
+
442
+ def forward(self, x, *args, **kwargs):
443
+ return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (189 Bytes). View file
 
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (193 Bytes). View file
 
ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc ADDED
Binary file (6.19 kB). View file