mta122 committed on
Commit
a8e58c8
1 Parent(s): 40845cc
Files changed (32)
  1. __pycache__/categories.cpython-38.pyc +0 -0
  2. configs/finetune/finetune_bert.yaml +0 -128
  3. configs/finetune/finetune_clip.yaml +0 -118
  4. configs/finetune/finetune_generic.yaml +7 -7
  5. configs/finetune/finetune_multi_bert.yaml +0 -127
  6. configs/finetune/finetune_multi_clip.yaml +0 -118
  7. ldm/__pycache__/util.cpython-38.pyc +0 -0
  8. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  9. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  10. ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  11. ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  12. ldm/models/diffusion/__pycache__/plms.cpython-38.pyc +0 -0
  13. ldm/models/diffusion/ddpm.py +1 -105
  14. ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  15. ldm/modules/__pycache__/discriminator.cpython-38.pyc +0 -0
  16. ldm/modules/__pycache__/ema.cpython-38.pyc +0 -0
  17. ldm/modules/__pycache__/x_transformer.cpython-38.pyc +0 -0
  18. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  19. ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  20. ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc +0 -0
  21. ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc +0 -0
  22. ldm/modules/discriminator.py +0 -97
  23. ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc +0 -0
  24. ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc +0 -0
  25. ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc +0 -0
  26. ldm/modules/encoders/__pycache__/modules.cpython-38.pyc +0 -0
  27. out/express/DRAGON-R.jpg +0 -0
  28. out/express/samples/0000.png +0 -0
  29. out/express/samples/0001.png +0 -0
  30. out/express/samples/0002.png +0 -0
  31. out/express/samples/0003.png +0 -0
  32. txt2img.py +0 -4
__pycache__/categories.cpython-38.pyc CHANGED
Binary files a/__pycache__/categories.cpython-38.pyc and b/__pycache__/categories.cpython-38.pyc differ
 
configs/finetune/finetune_bert.yaml DELETED
@@ -1,128 +0,0 @@
-model:
-  base_learning_rate: 1.0e-5
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
-  params:
-    linear_start: 0.00085
-    linear_end: 0.0120
-    num_timesteps_cond: 1
-    log_every_t: 200
-    timesteps: 1000
-    first_stage_key: "image"
-    cond_stage_key: "caption"
-    image_size: 32
-    channels: 4
-    cond_stage_trainable: False
-    conditioning_key: crossattn
-    monitor: val/loss_simple_ema
-    scale_factor: 0.18215
-    use_ema: False
-    weight_disc: 0.01
-
-    unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
-        image_size: 32
-        in_channels: 4
-        out_channels: 4
-        model_channels: 320
-        attention_resolutions: [ 4, 2, 1 ]
-        num_res_blocks: 2
-        channel_mult: [ 1, 2, 4, 4 ]
-        num_heads: 8
-        use_spatial_transformer: True
-        transformer_depth: 1
-        context_dim: 1280
-        use_checkpoint: True
-        legacy: False
-
-    first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
-        embed_dim: 4
-        monitor: val/rec_loss
-        ddconfig:
-          double_z: true
-          z_channels: 4
-          resolution: 256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 2
-          - 4
-          - 4
-          num_res_blocks: 2
-          attn_resolutions: []
-          dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
-
-    cond_stage_config:
-      target: ldm.modules.encoders.modules.BERTEmbedder
-      params:
-        n_embed: 1280
-        n_layer: 32
-        device: "cuda"
-
-    discriminator_config:
-      target: ldm.modules.discriminator.Discriminator
-      params:
-        bnorm: True
-        leakyparam: 0.2
-        bias: False
-        generic: False
-
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 1
-    num_workers: 32
-    wrap: false
-    train:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 2001
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-    validation:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 5
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-
-lightning:
-  modelcheckpoint:
-    params:
-      every_n_train_steps: 5000
-  callbacks:
-    image_logger:
-      target: main.ImageLogger
-      params:
-        batch_frequency: 1000
-        max_images: 1
-        increase_log_steps: False
-
-  trainer:
-    benchmark: True
-    max_steps: 500
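For orientation, configs like the one deleted above follow the LDM target/params convention and are consumed via OmegaConf plus the repo's instantiate_from_config helper. A minimal loading sketch (the checkpoint path is hypothetical, not from this commit):

    import torch
    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load("configs/finetune/finetune_generic.yaml")  # the surviving config
    model = instantiate_from_config(config.model)  # builds LatentDiffusion from target/params
    state = torch.load("models/ldm/model.ckpt", map_location="cpu")   # hypothetical path
    model.load_state_dict(state["state_dict"], strict=False)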
configs/finetune/finetune_clip.yaml DELETED
@@ -1,118 +0,0 @@
-model:
-  base_learning_rate: 1.0e-5 #1e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
-  params:
-    linear_start: 0.00085
-    linear_end: 0.0120
-    num_timesteps_cond: 1
-    log_every_t: 200
-    timesteps: 1000
-    first_stage_key: "image"
-    cond_stage_key: "caption"
-    image_size: 64 # 32
-    channels: 4
-    cond_stage_trainable: False # Note: different from the one we trained before
-    conditioning_key: crossattn
-    monitor: val/loss_simple_ema
-    scale_factor: 0.18215
-    use_ema: False
-    weight_disc: 0.01
-
-    unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
-        image_size: 64 # unused
-        in_channels: 4
-        out_channels: 4
-        model_channels: 320
-        attention_resolutions: [ 4, 2, 1 ]
-        num_res_blocks: 2
-        channel_mult: [ 1, 2, 4, 4 ]
-        num_heads: 8
-        use_spatial_transformer: True
-        transformer_depth: 1
-        context_dim: 768 # 1280
-        use_checkpoint: True
-        legacy: False
-
-    first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
-        embed_dim: 4
-        monitor: val/rec_loss
-        ddconfig:
-          double_z: true
-          z_channels: 4
-          resolution: 512 #256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 2
-          - 4
-          - 4
-          num_res_blocks: 2
-          attn_resolutions: []
-          dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
-
-    cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
-
-    discriminator_config:
-      target: ldm.modules.discriminator.Discriminator64
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 1
-    num_workers: 32
-    wrap: false
-    train:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 2001
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-    validation:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 5
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-
-lightning:
-  modelcheckpoint:
-    params:
-      every_n_train_steps: 200
-  callbacks:
-    image_logger:
-      target: main.ImageLogger
-      params:
-        batch_frequency: 100
-        max_images: 1
-        increase_log_steps: False
-
-  trainer:
-    benchmark: True
-    max_steps: 1001
configs/finetune/finetune_generic.yaml CHANGED
@@ -65,13 +65,13 @@ model:
         n_layer: 32
         device: "cuda"
 
-    discriminator_config:
-      target: ldm.modules.discriminator.Discriminator
-      params:
-        bnorm: True
-        leakyparam: 0.2
-        bias: False
-        generic: True
+    # discriminator_config:
+    #   target: ldm.modules.discriminator.Discriminator
+    #   params:
+    #     bnorm: True
+    #     leakyparam: 0.2
+    #     bias: False
+    #     generic: True
 
 
 data:
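This commit disables the discriminator by commenting the block out, and ddpm.py below comments out the matching instantiation. A hypothetical alternative, not what the commit does, would be to treat the config as optional:

    from ldm.util import instantiate_from_config

    def build_discriminator(discriminator_config=None):
        # Hypothetical helper: skip the GAN branch cleanly when no
        # discriminator_config block is present in the YAML.
        if discriminator_config is None:
            return None
        return instantiate_from_config(discriminator_config)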
configs/finetune/finetune_multi_bert.yaml DELETED
@@ -1,127 +0,0 @@
-model:
-  base_learning_rate: 1.0e-5 #1e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
-  params:
-    linear_start: 0.00085
-    linear_end: 0.0120
-    num_timesteps_cond: 1
-    log_every_t: 200
-    timesteps: 1000
-    first_stage_key: "image"
-    cond_stage_key: "caption"
-    image_size: 32 # 32
-    channels: 4
-    cond_stage_trainable: False # Note: different from the one we trained before
-    conditioning_key: crossattn
-    monitor: val/loss_simple_ema
-    scale_factor: 0.18215
-    use_ema: False
-    weight_disc: 0.01
-
-    unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
-        image_size: 32 # unused
-        in_channels: 4
-        out_channels: 4
-        model_channels: 320
-        attention_resolutions: [ 4, 2, 1 ]
-        num_res_blocks: 2
-        channel_mult: [ 1, 2, 4, 4 ]
-        num_heads: 8
-        use_spatial_transformer: True
-        transformer_depth: 1
-        context_dim: 1280 # 1280
-        use_checkpoint: True
-        legacy: False
-
-    first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
-        embed_dim: 4
-        monitor: val/rec_loss
-        ddconfig:
-          double_z: true
-          z_channels: 4
-          resolution: 256 #256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 2
-          - 4
-          - 4
-          num_res_blocks: 2
-          attn_resolutions: []
-          dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
-
-    cond_stage_config:
-      target: ldm.modules.encoders.modules.BERTEmbedder
-      params:
-        n_embed: 1280
-        n_layer: 32
-
-    discriminator_config:
-      target: ldm.modules.discriminator.Discriminator
-      params:
-        bnorm: True
-        leakyparam: 0.2
-        bias: False
-        generic: False
-
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 1
-    num_workers: 32
-    wrap: false
-    train:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 2001
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-    validation:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 5
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-
-lightning:
-  modelcheckpoint:
-    params:
-      every_n_train_steps: 2000
-  callbacks:
-    image_logger:
-      target: main.ImageLogger
-      params:
-        batch_frequency: 5000
-        max_images: 1
-        increase_log_steps: False
-
-  trainer:
-    benchmark: True
-    max_steps: 800
configs/finetune/finetune_multi_clip.yaml DELETED
@@ -1,118 +0,0 @@
-model:
-  base_learning_rate: 1.0e-5 #1e-4
-  target: ldm.models.diffusion.ddpm.LatentDiffusion
-  params:
-    linear_start: 0.00085
-    linear_end: 0.0120
-    num_timesteps_cond: 1
-    log_every_t: 200
-    timesteps: 1000
-    first_stage_key: "image"
-    cond_stage_key: "caption"
-    image_size: 64 # 32
-    channels: 4
-    cond_stage_trainable: False # Note: different from the one we trained before
-    conditioning_key: crossattn
-    monitor: val/loss_simple_ema
-    scale_factor: 0.18215
-    use_ema: False
-    weight_disc: 0.01
-
-    unet_config:
-      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
-      params:
-        image_size: 64 # unused
-        in_channels: 4
-        out_channels: 4
-        model_channels: 320
-        attention_resolutions: [ 4, 2, 1 ]
-        num_res_blocks: 2
-        channel_mult: [ 1, 2, 4, 4 ]
-        num_heads: 8
-        use_spatial_transformer: True
-        transformer_depth: 1
-        context_dim: 768 # 1280
-        use_checkpoint: True
-        legacy: False
-
-    first_stage_config:
-      target: ldm.models.autoencoder.AutoencoderKL
-      params:
-        embed_dim: 4
-        monitor: val/rec_loss
-        ddconfig:
-          double_z: true
-          z_channels: 4
-          resolution: 512 #256
-          in_channels: 3
-          out_ch: 3
-          ch: 128
-          ch_mult:
-          - 1
-          - 2
-          - 4
-          - 4
-          num_res_blocks: 2
-          attn_resolutions: []
-          dropout: 0.0
-        lossconfig:
-          target: torch.nn.Identity
-
-    cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
-
-    discriminator_config:
-      target: ldm.modules.discriminator.Discriminator64
-
-data:
-  target: main.DataModuleFromConfig
-  params:
-    batch_size: 1
-    num_workers: 32
-    wrap: false
-    train:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 2001
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-    validation:
-      target: ldm.data.rasterizer.Rasterizer
-      params:
-        img_size: 256
-        text: "R"
-        style_word: "DRAGON"
-        data_path: "data/cat"
-        alternate_glyph: None
-        num_samples: 5
-        make_black: False
-        one_font: False
-        full_word: False
-        font_name: "Garuda-Bold.ttf"
-        just_use_style: false
-        use_alt: False
-
-lightning:
-  modelcheckpoint:
-    params:
-      every_n_train_steps: 200
-  callbacks:
-    image_logger:
-      target: main.ImageLogger
-      params:
-        batch_frequency: 100
-        max_images: 1
-        increase_log_steps: False
-
-  trainer:
-    benchmark: True
-    max_steps: 1501
ldm/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/ldm/__pycache__/util.cpython-38.pyc and b/ldm/__pycache__/util.cpython-38.pyc differ
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc CHANGED
Binary files a/ldm/models/__pycache__/autoencoder.cpython-38.pyc and b/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ
 
ldm/models/diffusion/__pycache__/plms.cpython-38.pyc CHANGED
Binary files a/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc and b/ldm/models/diffusion/__pycache__/plms.cpython-38.pyc differ
 
ldm/models/diffusion/ddpm.py CHANGED
@@ -485,7 +485,7 @@ class LatentDiffusion(DDPM):
         self.init_from_ckpt(ckpt_path, ignore_keys)
         self.restarted_from_ckpt = True
 
-        self.discriminator = instantiate_from_config(discriminator_config)
+        # self.discriminator = instantiate_from_config(discriminator_config)
         self.weight_disc = weight_disc
         self.iter = 0
 
@@ -919,84 +919,6 @@ class LatentDiffusion(DDPM):
         return z_C, alpha
 
 
-    def discriminator_loss(self, batch, optimizer_idx=0):
-        #
-        criterion = nn.BCELoss()
-
-        real_label = 1.
-        fake_label = 0.
-
-        caption = batch["cond"]
-        with torch.no_grad():
-            cond = self.get_learned_conditioning(caption)
-
-        img1 = rearrange(batch["style"]["image"], 'b h w c -> b c h w')
-        save_image(img1, "img_style.png")
-        img1_base = img1.to(memory_format=torch.contiguous_format).float()
-        img1 = self.encode_first_stage(img1_base)
-        z_S = self.get_first_stage_encoding(img1).detach()
-
-        img2 = rearrange(batch["base"]["image"], 'b h w c -> b c h w')
-        save_image(img2, "img_base.png")
-        img2_base = img2.to(memory_format=torch.contiguous_format).float()
-        img2 = self.encode_first_stage(img2_base)
-        z_R = self.get_first_stage_encoding(img2).detach()
-
-        x_start = z_S
-        real_x = z_R
-
-        t = torch.randint(0, self.num_timesteps, (z_S.shape[0],), device=self.device).long()
-        logvar_t = self.logvar[t.cpu()].to(self.device)
-
-        noise = default(None, lambda: torch.randn_like(z_S))
-
-        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
-
-        letter = batch["number"][0].cpu().detach().numpy()
-
-        # update generator
-        if optimizer_idx == 0:
-
-            noise1 = self.apply_model(x_noisy, t, cond)
-            z_theta = self.predict_start_from_noise(x_noisy, t, noise1)
-            fake_x = z_theta
-
-            loss_diff = self.get_loss(noise1, noise, mean=False).mean([1, 2, 3])
-            loss_diff = loss_diff / torch.exp(logvar_t) + logvar_t
-            loss_diff = self.l_simple_weight * loss_diff.mean()
-
-            label = torch.full((1,), real_label, dtype=torch.float, device=self.device)
-
-            output = self.discriminator(fake_x, letter).view(-1)
-            loss_disc = criterion(output, label)
-
-            return loss_diff, loss_disc
-
-        # update discriminator
-        if optimizer_idx == 1:
-
-            noise1 = self.apply_model(x_noisy, t, cond)
-            z_theta = self.predict_start_from_noise(x_noisy, t, noise1)
-            fake_x = z_theta
-
-            loss_diff = self.get_loss(noise1, noise, mean=False).mean([1, 2, 3])
-            loss_diff = loss_diff / torch.exp(logvar_t) + logvar_t
-            loss_diff = self.l_simple_weight * loss_diff.mean()
-
-            label = torch.full((1,), real_label, dtype=torch.float, device=self.device)
-            output = self.discriminator(real_x, letter).view(-1)
-            loss1 = criterion(output, label)
-
-
-            label = torch.full((1,), fake_label, dtype=torch.float, device=self.device)
-            output = self.discriminator(fake_x, letter).view(-1)
-            loss2 = criterion(output, label)
-
-            loss_disc = (loss1 + loss2) / 2
-
-            return loss_diff, loss_disc
-
-
     def make_images(self, batch):
         batch = batch["base"]
         use_ddim = 50
@@ -1073,32 +995,6 @@ class LatentDiffusion(DDPM):
         Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join("out_cur/", f"{base_count:04}.png"))
         base_count += 1
 
-    def training_step(self, batch, batch_idx, optimizer_idx=None):
-
-        loss_diff, loss_disc = self.discriminator_loss(batch, optimizer_idx=optimizer_idx)
-        loss = loss_diff + self.weight_disc * loss_disc
-        self.iter += 1
-
-        if (self.iter - 1) % 100 == 0:
-            self.log_view(batch)
-
-        # if self.iter == batch["epochs"]:
-        #     self.last_step_run(batch)
-
-
-        return loss
-
-    @torch.no_grad()
-    def validation_step(self, batch, optimizer_idx):
-
-        return None
-        cap = batch["cond"]
-        batch = batch["base"]
-        batch["caption"] = cap
-
-        loss, loss_dict_no_ema = self.shared_step(batch)
-        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
-
     def forward(self, x, c, *args, **kwargs):
         t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
         if self.model.conditioning_key is not None:
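The deleted discriminator_loss fed the discriminator with predict_start_from_noise(x_noisy, t, noise1), i.e. the model's estimate of the clean latent. That estimate follows from the DDPM forward identity; a minimal standalone sketch (tensor names here are illustrative, not the repo's):

    import torch

    def predict_start_from_noise(x_t, a_bar_t, eps):
        # Forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.
        # Solving for x_0 given the UNet's noise estimate eps yields the "fake"
        # latent that the (now removed) discriminator scored.
        return (x_t - torch.sqrt(1.0 - a_bar_t) * eps) / torch.sqrt(a_bar_t)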
ldm/modules/__pycache__/attention.cpython-38.pyc CHANGED
Binary files a/ldm/modules/__pycache__/attention.cpython-38.pyc and b/ldm/modules/__pycache__/attention.cpython-38.pyc differ
 
ldm/modules/__pycache__/discriminator.cpython-38.pyc CHANGED
Binary files a/ldm/modules/__pycache__/discriminator.cpython-38.pyc and b/ldm/modules/__pycache__/discriminator.cpython-38.pyc differ
 
ldm/modules/__pycache__/ema.cpython-38.pyc CHANGED
Binary files a/ldm/modules/__pycache__/ema.cpython-38.pyc and b/ldm/modules/__pycache__/ema.cpython-38.pyc differ
 
ldm/modules/__pycache__/x_transformer.cpython-38.pyc CHANGED
Binary files a/ldm/modules/__pycache__/x_transformer.cpython-38.pyc and b/ldm/modules/__pycache__/x_transformer.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ
 
ldm/modules/discriminator.py DELETED
@@ -1,97 +0,0 @@
-from torch import nn
-import pdb
-import torch
-
-# to use with clip
-class Discriminator64(nn.Module):
-    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False):
-        super(Discriminator64, self).__init__()
-
-        self.bnorm = bnorm
-        self.generic = generic
-
-        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
-
-        self.bn2 = nn.BatchNorm2d(128)
-        self.bn3 = nn.BatchNorm2d(256)
-        self.bn4 = nn.BatchNorm2d(512)
-
-        self.layer1 = nn.Conv2d(4, 64, 4, 2, 1, bias=bias)
-        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
-        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
-        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
-        if generic:
-            self.layer5 = nn.Conv2d(512, 26, 4, 1, 0, bias=bias)
-        else:
-            self.layer5 = nn.Conv2d(512, 1, 4, 1, 0, bias=bias)
-        self.sig = nn.Sigmoid()
-
-
-    def forward(self, input, letter):
-        out1 = self.relu(self.layer1(input))
-
-        if self.bnorm:
-            out2 = self.relu(self.bn2(self.layer2(out1)))
-            out3 = self.relu(self.bn3(self.layer3(out2)))
-            out4 = self.relu(self.bn4(self.layer4(out3)))
-        else:
-            out2 = self.relu(self.layer2(out1))
-            out3 = self.relu(self.layer3(out2))
-            out4 = self.relu(self.layer4(out3))
-
-        out5 = self.sig(self.layer5(out4))
-        out5 = out5.flatten()
-
-        if self.generic:
-            out5 = out5[letter].mean()
-        else:
-            out5 = out5.mean()
-
-        return out5
-
-
-# to use with bert
-class Discriminator(nn.Module):
-    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False):
-        super(Discriminator, self).__init__()
-
-        self.bnorm = bnorm
-        self.generic = generic
-
-        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
-        self.sig = nn.Sigmoid()
-        self.bn2 = nn.BatchNorm2d(128)
-        self.bn3 = nn.BatchNorm2d(256)
-        self.bn4 = nn.BatchNorm2d(512)
-
-        self.layer1 = nn.Conv2d(4, 64, 4, 2, 1, bias=bias)
-        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
-        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
-        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
-        if generic:
-            self.layer5 = nn.Conv2d(512, 26, 2, 1, 0, bias=bias)
-        else:
-            self.layer5 = nn.Conv2d(512, 1, 2, 1, 0, bias=bias)
-
-    def forward(self, input, letter):
-
-        out1 = self.relu(self.layer1(input))
-
-        if self.bnorm:
-            out2 = self.relu(self.bn2(self.layer2(out1)))
-            out3 = self.relu(self.bn3(self.layer3(out2)))
-            out4 = self.relu(self.bn4(self.layer4(out3)))
-        else:
-            out2 = self.relu(self.layer2(out1))
-            out3 = self.relu(self.layer3(out2))
-            out4 = self.relu(self.layer4(out3))
-
-        out5 = self.sig(self.layer5(out4))
-        out5 = out5.flatten()
-
-        if self.generic:
-            out5 = out5[letter].mean()
-        else:
-            out5 = out5.mean()
-
-        return out5
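As a sanity check on the deleted module's geometry: four stride-2 convolutions plus a kernel-2 head collapse a 4x32x32 latent to a single logit, while Discriminator64's kernel-4 head expects 4x64x64 latents. A runnable sketch of just that shape arithmetic:

    import torch
    from torch import nn

    # Reproduces the deleted Discriminator's conv stack: 32 -> 16 -> 8 -> 4 -> 2 -> 1.
    x = torch.randn(1, 4, 32, 32)  # a latent from the KL autoencoder (z_channels: 4)
    stack = [nn.Conv2d(4, 64, 4, 2, 1), nn.Conv2d(64, 128, 4, 2, 1),
             nn.Conv2d(128, 256, 4, 2, 1), nn.Conv2d(256, 512, 4, 2, 1),
             nn.Conv2d(512, 1, 2, 1, 0)]
    for conv in stack:
        x = conv(x)
    print(x.shape)  # torch.Size([1, 1, 1, 1]); sigmoid + mean then yields the scalar score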
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc CHANGED
Binary files a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc and b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ
 
ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc and b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ
 
ldm/modules/encoders/__pycache__/modules.cpython-38.pyc CHANGED
Binary files a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc and b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
 
out/express/DRAGON-R.jpg CHANGED
out/express/samples/0000.png CHANGED
out/express/samples/0001.png CHANGED
out/express/samples/0002.png CHANGED
out/express/samples/0003.png CHANGED
txt2img.py CHANGED
@@ -132,10 +132,6 @@ if __name__ == "__main__":
     seed_everything(seed)
 
     # config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml") # TODO: Optionally download from same location as ckpt and chnage this logic
-    if opt.H == 512:
-        config = OmegaConf.load("configs/finetune/finetune_clip.yaml")
-    else:
-        config = OmegaConf.load("configs/finetune/finetune_bert.yaml")
 
     config = OmegaConf.load("configs/finetune/finetune_generic.yaml")
     # config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-finetune2.yaml")
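With the branch above removed, txt2img.py unconditionally loads finetune_generic.yaml. For reference, the deleted logic chose a config by output height; a sketch reconstructing it (both target files were deleted in this same commit):

    from omegaconf import OmegaConf

    def load_finetune_config(height):
        # Reconstruction of the removed selection: the CLIP config drove 512px
        # generation (64x64 latents), the BERT config everything else.
        if height == 512:
            return OmegaConf.load("configs/finetune/finetune_clip.yaml")
        return OmegaConf.load("configs/finetune/finetune_bert.yaml")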