toto10 committed on
Commit
8ebef76
1 Parent(s): ca18358

16c4b8a41720fbb7a5cf8580e62ad4d463b9237ded74f3ae85ffe22c93c3d34b

Files changed (50)
  1. repositories/generative-models/configs/example_training/toy/mnist_cond_discrete_eps.yaml +104 -0
  2. repositories/generative-models/configs/example_training/toy/mnist_cond_l1_loss.yaml +104 -0
  3. repositories/generative-models/configs/example_training/toy/mnist_cond_with_ema.yaml +101 -0
  4. repositories/generative-models/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +185 -0
  5. repositories/generative-models/configs/example_training/txt2img-clipl.yaml +186 -0
  6. repositories/generative-models/configs/inference/sd_2_1.yaml +66 -0
  7. repositories/generative-models/configs/inference/sd_2_1_768.yaml +66 -0
  8. repositories/generative-models/configs/inference/sd_xl_base.yaml +98 -0
  9. repositories/generative-models/configs/inference/sd_xl_refiner.yaml +91 -0
  10. repositories/generative-models/data/DejaVuSans.ttf +0 -0
  11. repositories/generative-models/main.py +946 -0
  12. repositories/generative-models/requirements_pt13.txt +41 -0
  13. repositories/generative-models/requirements_pt2.txt +41 -0
  14. repositories/generative-models/scripts/__init__.py +0 -0
  15. repositories/generative-models/scripts/demo/__init__.py +0 -0
  16. repositories/generative-models/scripts/demo/detect.py +156 -0
  17. repositories/generative-models/scripts/demo/sampling.py +329 -0
  18. repositories/generative-models/scripts/demo/streamlit_helpers.py +666 -0
  19. repositories/generative-models/scripts/util/__init__.py +0 -0
  20. repositories/generative-models/scripts/util/detection/__init__.py +0 -0
  21. repositories/generative-models/scripts/util/detection/nsfw_and_watermark_dectection.py +104 -0
  22. repositories/generative-models/scripts/util/detection/p_head_v1.npz +3 -0
  23. repositories/generative-models/scripts/util/detection/w_head_v1.npz +3 -0
  24. repositories/generative-models/setup.py +13 -0
  25. repositories/generative-models/sgm/__init__.py +3 -0
  26. repositories/generative-models/sgm/__pycache__/__init__.cpython-310.pyc +0 -0
  27. repositories/generative-models/sgm/__pycache__/util.cpython-310.pyc +0 -0
  28. repositories/generative-models/sgm/data/__init__.py +1 -0
  29. repositories/generative-models/sgm/data/cifar10.py +67 -0
  30. repositories/generative-models/sgm/data/dataset.py +80 -0
  31. repositories/generative-models/sgm/data/mnist.py +85 -0
  32. repositories/generative-models/sgm/lr_scheduler.py +135 -0
  33. repositories/generative-models/sgm/models/__init__.py +2 -0
  34. repositories/generative-models/sgm/models/__pycache__/__init__.cpython-310.pyc +0 -0
  35. repositories/generative-models/sgm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  36. repositories/generative-models/sgm/models/__pycache__/diffusion.cpython-310.pyc +0 -0
  37. repositories/generative-models/sgm/models/autoencoder.py +335 -0
  38. repositories/generative-models/sgm/models/diffusion.py +320 -0
  39. repositories/generative-models/sgm/modules/__init__.py +6 -0
  40. repositories/generative-models/sgm/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  41. repositories/generative-models/sgm/modules/__pycache__/attention.cpython-310.pyc +0 -0
  42. repositories/generative-models/sgm/modules/__pycache__/ema.cpython-310.pyc +0 -0
  43. repositories/generative-models/sgm/modules/attention.py +947 -0
  44. repositories/generative-models/sgm/modules/autoencoding/__init__.py +0 -0
  45. repositories/generative-models/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc +0 -0
  46. repositories/generative-models/sgm/modules/autoencoding/losses/__init__.py +246 -0
  47. repositories/generative-models/sgm/modules/autoencoding/regularizers/__init__.py +53 -0
  48. repositories/generative-models/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc +0 -0
  49. repositories/generative-models/sgm/modules/diffusionmodules/__init__.py +7 -0
  50. repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc +0 -0
repositories/generative-models/configs/example_training/toy/mnist_cond_discrete_eps.yaml ADDED
@@ -0,0 +1,104 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: [ ]
        num_res_blocks: 4
        channel_mult: [ 1, 2, 2 ]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: "cls"
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 16
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 16
          n_rows: 4

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
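Every block in these configs follows the same `target`/`params` convention that `main.py` resolves through `sgm.util.instantiate_from_config` (imported near the top of `main.py`). The sketch below shows how such a block is typically turned into an object; it is an illustration reproduced from memory of this codebase family, not necessarily the repository's exact helper code.

import importlib

def get_obj_from_str(string: str):
    # Resolve a dotted path such as "sgm.models.diffusion.DiffusionEngine".
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config: dict):
    # Build the object named by `target`, passing `params` as keyword arguments.
    return get_obj_from_str(config["target"])(**config.get("params", dict()))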
repositories/generative-models/configs/example_training/toy/mnist_cond_l1_loss.yaml ADDED
@@ -0,0 +1,104 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: "sequential"
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: "cls"
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

    loss_config:
      target: sgm.modules.diffusionmodules.StandardDiffusionLoss
      params:
        type: l1

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
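The distinguishing setting of this config is `type: l1` on the loss. Conceptually the switch replaces the squared error with an absolute error inside the weighted diffusion loss; the following is a hedged sketch of that idea, not the repository's `StandardDiffusionLoss` implementation.

import torch

def diffusion_loss(model_output: torch.Tensor, target: torch.Tensor,
                   w: torch.Tensor, loss_type: str = "l2") -> torch.Tensor:
    # Weighted per-element reconstruction error, averaged over all elements.
    if loss_type == "l2":
        err = (model_output - target) ** 2
    elif loss_type == "l1":
        err = (model_output - target).abs()
    else:
        raise NotImplementedError(loss_type)
    return (w * err).mean()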
repositories/generative-models/configs/example_training/toy/mnist_cond_with_ema.yaml ADDED
@@ -0,0 +1,101 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    use_ema: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
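All three MNIST configs set `ucg_rate: 0.2` on the class embedder: the class conditioning is dropped for roughly 20% of training samples so that the `VanillaCFG` guider has an unconditional branch to contrast against at sampling time. A rough sketch of that dropout is given below, under the assumption that zeroing the embedding stands in for the unconditional case; the actual `GeneralConditioner`/`ClassEmbedder` logic in the repository may differ in detail.

import torch

def apply_ucg(emb: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # Keep each sample's conditioning with probability (1 - ucg_rate),
    # otherwise zero it out to emulate the unconditional input.
    keep = torch.bernoulli((1.0 - ucg_rate) * torch.ones(emb.shape[0], device=emb.device))
    return keep.view(-1, *([1] * (emb.dim() - 1))) * emb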
repositories/generative-models/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml ADDED
@@ -0,0 +1,185 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    log_keys:
      - txt

    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 1, 2, 4 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1792
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: True
            input_key: txt
            ucg_rate: 0.1
            legacy_ucg_value: ""
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              always_return_pooled: True
          # vector cond
          - is_trainable: False
            ucg_rate: 0.1
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1, 2, 4, 4 ]
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 7.5

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path the root of your custom dataset
          - "DATA_PATH"
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM

        decoders:
          - "pil"

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: 'jpg' # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            # USER: you might wanna use non-default parameters due to your custom dataset

      loader:
        batch_size: 64
        num_workers: 6

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        enable_autocast: False
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 8
          n_rows: 2

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
repositories/generative-models/configs/example_training/txt2img-clipl.yaml ADDED
@@ -0,0 +1,186 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    log_keys:
      - txt

    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 1, 2, 4 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1792
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: True
            input_key: txt
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              always_return_pooled: True
          # vector cond
          - is_trainable: False
            ucg_rate: 0.1
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1, 2, 4, 4 ]
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 7.5

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path the root of your custom dataset
          - "DATA_PATH"
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000


        decoders:
          - "pil"

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: 'jpg' # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
            # USER: you might wanna use non-default parameters due to your custom dataset
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            # USER: you might wanna use non-default parameters due to your custom dataset

      loader:
        batch_size: 64
        num_workers: 6

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        enable_autocast: False
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 8
          n_rows: 2

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
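A quick consistency check on `adm_in_channels: 1792` in both txt2img-clipl configs: the vector conditioning concatenates the pooled CLIP-L embedding with the two `ConcatTimestepEmbedderND` outputs, each `outdim: 256` "multiplied by two" because the size and crop inputs are 2-tuples. The 768-dim pooled width of CLIP ViT-L is an assumption not stated in the config itself.

pooled_clip_l = 768     # always_return_pooled: True on FrozenCLIPEmbedder (assumed ViT-L width)
size_emb = 256 * 2      # original_size_as_tuple -> (h, w)
crop_emb = 256 * 2      # crop_coords_top_left -> (top, left)
assert pooled_clip_l + size_emb + crop_emb == 1792  # matches adm_in_channels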
repositories/generative-models/configs/inference/sd_2_1.yaml ADDED
@@ -0,0 +1,66 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
repositories/generative-models/configs/inference/sd_2_1_768.yaml ADDED
@@ -0,0 +1,66 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.VWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        use_fp16: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
repositories/generative-models/configs/inference/sd_xl_base.yaml ADDED
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
repositories/generative-models/configs/inference/sd_xl_refiner.yaml ADDED
@@ -0,0 +1,91 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2560
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 4
        context_dim: [1280, 1280, 1280, 1280] # 1280
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              legacy: False
              freeze: True
              layer: penultimate
              always_return_pooled: True
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: aesthetic_score
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by one

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
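The same bookkeeping explains the `adm_in_channels` values in the two SDXL inference configs above: the pooled OpenCLIP ViT-bigG embedding is 1280-wide (an assumption about the encoder, not stated in the YAML), each 2-tuple conditioner contributes 512 (`outdim: 256`, "multiplied by two"), and the scalar `aesthetic_score` in the refiner contributes 256 ("multiplied by one").

pooled_bigG = 1280
two_tuple = 256 * 2    # per size/crop conditioner
scalar = 256           # aesthetic_score (refiner only)

assert pooled_bigG + 3 * two_tuple == 2816            # sd_xl_base adm_in_channels
assert pooled_bigG + 2 * two_tuple + scalar == 2560   # sd_xl_refiner adm_in_channels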
repositories/generative-models/data/DejaVuSans.ttf ADDED
Binary file (757 kB).
 
repositories/generative-models/main.py ADDED
@@ -0,0 +1,946 @@
import argparse
import datetime
import glob
import inspect
import os
import sys
from inspect import Parameter
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import (
    exists,
    instantiate_from_config,
    isheatmap,
)

MULTINODE_HACKS = True


def default_trainer_args():
    argspec = dict(inspect.signature(Trainer.__init__).parameters)
    argspec.pop("self")
    default_args = {
        param: argspec[param].default
        for param in argspec
        if argspec[param] != Parameter.empty
    }
    return default_args


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "--no_date",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p", "--project", help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "--projectname",
        type=str,
        default="stablediffusion",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging dat shit",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )
    parser.add_argument(
        "--legacy_naming",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="name run based on config file name if true, else by whole path",
    )
    parser.add_argument(
        "--enable_tf32",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enables the TensorFloat32 format both for matmuls and cuDNN for pytorch 1.12",
    )
    parser.add_argument(
        "--startup",
        type=str,
        default=None,
        help="Startuptime from distributed script",
    )
    parser.add_argument(
        "--wandb",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="log to wandb",
    )
    parser.add_argument(
        "--no_base_name",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,  # TODO: later default to True
        help="log to wandb",
    )
    if version.parse(torch.__version__) >= version.parse("2.0.0"):
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="single checkpoint file to resume from",
        )
    default_args = default_trainer_args()
    for key in default_args:
        parser.add_argument("--" + key, default=default_args[key])
    return parser


def get_checkpoint_name(logdir):
    ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
    ckpt = natsorted(glob.glob(ckpt))
    print('available "last" checkpoints:')
    print(ckpt)
    if len(ckpt) > 1:
        print("got most recent checkpoint")
        ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        try:
            version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
        except Exception as e:
            print("version confusion but not bad")
            print(e)
            version = 1
        # version = last_version + 1
    else:
        # in this case, we only have one "last.ckpt"
        ckpt = ckpt[0]
        version = 1
    melk_ckpt_name = f"last-v{version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name


class SetupCallback(Callback):
    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        if not self.debug and trainer.global_rank == 0:
            print("Summoning checkpoint.")
            if self.ckpt_name is None:
                ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            else:
                ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
            trainer.save_checkpoint(ckpt_path)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if (
                    "metrics_over_trainsteps_checkpoint"
                    in self.lightning_config["callbacks"]
                ):
                    os.makedirs(
                        os.path.join(self.ckptdir, "trainstep_checkpoints"),
                        exist_ok=True,
                    )
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            if MULTINODE_HACKS:
                import time

                time.sleep(5)
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_before_first_step=False,
        enable_autocast=True,
    ):
        super().__init__()
        self.enable_autocast = enable_autocast
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_before_first_step = log_before_first_step

    @rank_zero_only
    def log_local(
        self,
        save_dir,
        split,
        images,
        global_step,
        current_epoch,
        batch_idx,
        pl_module: Union[None, pl.LightningModule] = None,
    ):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            if isheatmap(images[k]):
                fig, ax = plt.subplots()
                ax = ax.matshow(
                    images[k].cpu().numpy(), cmap="hot", interpolation="lanczos"
                )
                plt.colorbar(ax)
                plt.axis("off")

                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                os.makedirs(root, exist_ok=True)
                path = os.path.join(root, filename)
                plt.savefig(path)
                plt.close()
                # TODO: support wandb
            else:
                grid = torchvision.utils.make_grid(images[k], nrow=4)
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
                grid = grid.numpy()
                grid = (grid * 255).astype(np.uint8)
                filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                    k, global_step, current_epoch, batch_idx
                )
                path = os.path.join(root, filename)
                os.makedirs(os.path.split(path)[0], exist_ok=True)
                img = Image.fromarray(grid)
                img.save(path)
                if exists(pl_module):
                    assert isinstance(
                        pl_module.logger, WandbLogger
                    ), "logger_log_image only supports WandbLogger currently"
                    pl_module.logger.log_image(
                        key=f"{split}/{k}",
                        images=[
                            img,
                        ],
                        step=pl_module.global_step,
                    )

    @rank_zero_only
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
            and
            # batch_idx > 5 and
            self.max_images > 0
        ):
            logger = type(pl_module.logger)
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            gpu_autocast_kwargs = {
                "enabled": self.enable_autocast,  # torch.is_autocast_enabled(),
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled(),
            }
            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                if not isheatmap(images[k]):
                    images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().float().cpu()
                    if self.clamp and not isheatmap(images[k]):
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
                pl_module=pl_module
                if isinstance(pl_module.logger, WandbLogger)
                else None,
            )

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
            check_idx > 0 or self.log_first_step
        ):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.log_before_first_step and pl_module.global_step == 0:
            print(f"{self.__class__.__name__}: logging before training")
            self.log_img(pl_module, batch, batch_idx, split="train")

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, "calibrate_grad_norm"):
            if (
                pl_module.calibrate_grad_norm and batch_idx % 25 == 0
            ) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


@rank_zero_only
def init_wandb(save_dir, opt, config, group_name, name_str):
    print(f"setting WANDB_DIR to {save_dir}")
    os.makedirs(save_dir, exist_ok=True)

    os.environ["WANDB_DIR"] = save_dir
    if opt.debug:
        wandb.init(project=opt.projectname, mode="offline", group=group_name)
    else:
        wandb.init(
            project=opt.projectname,
            config=config,
            settings=wandb.Settings(code_dir="./sgm"),
            group=group_name,
            name=name_str,
        )


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    melk_ckpt_name = None
    name = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)

        print("#" * 100)
        print(f'Resuming from checkpoint "{ckpt}"')
        print("#" * 100)

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            if opt.no_base_name:
                name = ""
            else:
                if opt.legacy_naming:
                    cfg_fname = os.path.split(opt.base[0])[-1]
                    cfg_name = os.path.splitext(cfg_fname)[0]
                else:
                    assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
                        opt.base[0]
                    )[0]
                    cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
                        os.path.split(opt.base[0])[0].split(os.sep).index("configs")
                        + 1 :
                    ]  # cut away the first one (we assert all configs are in "configs")
                    cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
                    cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
                name = "_" + cfg_name
        else:
            name = ""
        if not opt.no_date:
            nowname = now + name + opt.postfix
        else:
            nowname = name + opt.postfix
            if nowname.startswith("_"):
                nowname = nowname[1:]
        logdir = os.path.join(opt.logdir, nowname)
        print(f"LOGDIR: {logdir}")

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed, workers=True)

    # move before model init, in case a torch.compile(...) is called somewhere
    if opt.enable_tf32:
        # pt_version = version.parse(torch.__version__)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"Enabling TF32 for PyTorch {torch.__version__}")
    else:
        print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
        print(
            f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}"
        )
        print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        # default to gpu
        trainer_config["accelerator"] = "gpu"
        #
        standard_args = default_trainer_args()
        for k in standard_args:
            if getattr(opt, k) != standard_args[k]:
                trainer_config[k] = getattr(opt, k)

        ckpt_resume_path = opt.resume_from_checkpoint

        if not "devices" in trainer_config and trainer_config["accelerator"] != "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["devices"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    # "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                    "project": opt.projectname,
                    "log_model": False,
                    # "dir": logdir,
                },
            },
            "csv": {
                "target": "pytorch_lightning.loggers.CSVLogger",
                "params": {
                    "name": "testtube",  # hack for sbord fanatics
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["wandb" if opt.wandb else "csv"]
        if opt.wandb:
            # TODO change once leaving "swiffer" config directory
            try:
                group_name = nowname.split(now)[-1].split("-")[1]
            except:
                group_name = nowname
            default_logger_cfg["params"]["group"] = group_name
            init_wandb(
                os.path.join(os.getcwd(), logdir),
                opt=opt,
                group_name=group_name,
                config=config,
                name_str=nowname,
            )
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            },
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

        # https://pytorch-lightning.readthedocs.io/en/stable/extensions/strategy.html
        # default to ddp if not further specified
        default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}

        if "strategy" in lightning_config:
            strategy_cfg = lightning_config.strategy
        else:
            strategy_cfg = OmegaConf.create()
            default_strategy_config["params"] = {
                "find_unused_parameters": False,
                # "static_graph": True,
                # "ddp_comm_hook": default.fp16_compress_hook # TODO: experiment with this, also for DDPSharded
            }
        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
        print(
            f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
        )
        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                    "debug": opt.debug,
                    "ckpt_name": melk_ckpt_name,
                },
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {"batch_frequency": 1000, "max_images": 4, "clamp": True},
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                },
            },
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            print(
                "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
            )
            default_metrics_over_trainsteps_ckpt_dict = {
                "metrics_over_trainsteps_checkpoint": {
                    "target": "pytorch_lightning.callbacks.ModelCheckpoint",
800
+ "params": {
801
+ "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
802
+ "filename": "{epoch:06}-{step:09}",
803
+ "verbose": True,
804
+ "save_top_k": -1,
805
+ "every_n_train_steps": 10000,
806
+ "save_weights_only": True,
807
+ },
808
+ }
809
+ }
810
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
811
+
812
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
813
+ if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
814
+ callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
815
+ elif "ignore_keys_callback" in callbacks_cfg:
816
+ del callbacks_cfg["ignore_keys_callback"]
817
+
818
+ trainer_kwargs["callbacks"] = [
819
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
820
+ ]
821
+ if not "plugins" in trainer_kwargs:
822
+ trainer_kwargs["plugins"] = list()
823
+
824
+ # cmd line trainer args (which are in trainer_opt) always take priority over config-trainer-args (which are in trainer_kwargs)
825
+ trainer_opt = vars(trainer_opt)
826
+ trainer_kwargs = {
827
+ key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
828
+ }
829
+ trainer = Trainer(**trainer_opt, **trainer_kwargs)
830
+
831
+ trainer.logdir = logdir ###
832
+
833
+ # data
834
+ data = instantiate_from_config(config.data)
835
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
836
+ # calling these ourselves should not be necessary, but it is.
837
+ # lightning still takes care of proper multiprocessing though
838
+ data.prepare_data()
839
+ # data.setup()
840
+ print("#### Data #####")
841
+ try:
842
+ for k in data.datasets:
843
+ print(
844
+ f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
845
+ )
846
+ except:
847
+ print("datasets not yet initialized.")
848
+
849
+ # configure learning rate
850
+ if "batch_size" in config.data.params:
851
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
852
+ else:
853
+ bs, base_lr = (
854
+ config.data.params.train.loader.batch_size,
855
+ config.model.base_learning_rate,
856
+ )
857
+ if not cpu:
858
+ ngpu = len(lightning_config.trainer.devices.strip(",").split(","))
859
+ else:
860
+ ngpu = 1
861
+ if "accumulate_grad_batches" in lightning_config.trainer:
862
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
863
+ else:
864
+ accumulate_grad_batches = 1
865
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
866
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
867
+ if opt.scale_lr:
868
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
869
+ print(
870
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
871
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
872
+ )
873
+ )
874
+ else:
875
+ model.learning_rate = base_lr
876
+ print("++++ NOT USING LR SCALING ++++")
877
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
878
+
879
+ # allow checkpointing via USR1
880
+ def melk(*args, **kwargs):
881
+ # run all checkpoint hooks
882
+ if trainer.global_rank == 0:
883
+ print("Summoning checkpoint.")
884
+ if melk_ckpt_name is None:
885
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
886
+ else:
887
+ ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
888
+ trainer.save_checkpoint(ckpt_path)
889
+
890
+ def divein(*args, **kwargs):
891
+ if trainer.global_rank == 0:
892
+ import pudb
893
+
894
+ pudb.set_trace()
895
+
896
+ import signal
897
+
898
+ signal.signal(signal.SIGUSR1, melk)
899
+ signal.signal(signal.SIGUSR2, divein)
900
+
901
+ # run
902
+ if opt.train:
903
+ try:
904
+ trainer.fit(model, data, ckpt_path=ckpt_resume_path)
905
+ except Exception:
906
+ if not opt.debug:
907
+ melk()
908
+ raise
909
+ if not opt.no_test and not trainer.interrupted:
910
+ trainer.test(model, data)
911
+ except RuntimeError as err:
912
+ if MULTINODE_HACKS:
913
+ import requests
914
+ import datetime
915
+ import os
916
+ import socket
917
+
918
+ device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
919
+ hostname = socket.gethostname()
920
+ ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
921
+ resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
922
+ print(
923
+ f"ERROR at {ts} on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
924
+ flush=True,
925
+ )
926
+ raise err
927
+ except Exception:
928
+ if opt.debug and trainer.global_rank == 0:
929
+ try:
930
+ import pudb as debugger
931
+ except ImportError:
932
+ import pdb as debugger
933
+ debugger.post_mortem()
934
+ raise
935
+ finally:
936
+ # move newly created debug project to debug_runs
937
+ if opt.debug and not opt.resume and trainer.global_rank == 0:
938
+ dst, name = os.path.split(logdir)
939
+ dst = os.path.join(dst, "debug_runs", name)
940
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
941
+ os.rename(logdir, dst)
942
+
943
+ if opt.wandb:
944
+ wandb.finish()
945
+ # if trainer.global_rank == 0:
946
+ # print(trainer.profiler.summary())
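
Note on the learning-rate scaling above: when `--scale_lr` is set, main.py multiplies the base learning rate by the gradient-accumulation steps, the number of devices, and the per-device batch size. A minimal sketch of that rule (the example numbers are illustrative assumptions, not values from this commit):

def scaled_lr(base_lr, accumulate_grad_batches, ngpu, batch_size):
    # Mirrors main.py: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
    return accumulate_grad_batches * ngpu * batch_size * base_lr

# Hypothetical run: 2 accumulation steps, 8 GPUs, per-GPU batch size 4, base_learning_rate 1e-4
print(f"{scaled_lr(1e-4, 2, 8, 4):.2e}")  # 6.40e-03
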
repositories/generative-models/requirements_pt13.txt ADDED
@@ -0,0 +1,41 @@
1
+ omegaconf
2
+ einops
3
+ fire
4
+ tqdm
5
+ pillow
6
+ numpy
7
+ webdataset>=0.2.33
8
+ --extra-index-url https://download.pytorch.org/whl/cu117
9
+ torch==1.13.1+cu117
10
+ xformers==0.0.16
11
+ torchaudio==0.13.1
12
+ torchvision==0.14.1+cu117
13
+ torchmetrics
14
+ opencv-python==4.6.0.66
15
+ fairscale
16
+ pytorch-lightning==1.8.5
17
+ fsspec
18
+ kornia==0.6.9
19
+ matplotlib
20
+ natsort
21
+ tensorboardx==2.5.1
22
+ open-clip-torch
23
+ chardet
24
+ scipy
25
+ pandas
26
+ pudb
27
+ pyyaml
28
+ urllib3<1.27,>=1.25.4
29
+ streamlit>=0.73.1
30
+ timm
31
+ tokenizers==0.12.1
32
+ torchdata==0.5.1
33
+ transformers==4.19.1
34
+ onnx<=1.12.0
35
+ triton
36
+ wandb
37
+ invisible-watermark
38
+ -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
39
+ -e git+https://github.com/openai/CLIP.git@main#egg=clip
40
+ -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
41
+ -e .
repositories/generative-models/requirements_pt2.txt ADDED
@@ -0,0 +1,41 @@
1
+ omegaconf
2
+ einops
3
+ fire
4
+ tqdm
5
+ pillow
6
+ numpy
7
+ webdataset>=0.2.33
8
+ ninja
9
+ torch
10
+ matplotlib
11
+ torchaudio>=2.0.2
12
+ torchmetrics
13
+ torchvision>=0.15.2
14
+ opencv-python==4.6.0.66
15
+ fairscale
16
+ pytorch-lightning==2.0.1
17
+ fire
18
+ fsspec
19
+ kornia==0.6.9
20
+ natsort
21
+ open-clip-torch
22
+ chardet==5.1.0
23
+ tensorboardx==2.6
24
+ pandas
25
+ pudb
26
+ pyyaml
27
+ urllib3<1.27,>=1.25.4
28
+ scipy
29
+ streamlit>=0.73.1
30
+ timm
31
+ tokenizers==0.12.1
32
+ transformers==4.19.1
33
+ triton==2.0.0
34
+ torchdata==0.6.1
35
+ wandb
36
+ invisible-watermark
37
+ xformers
38
+ -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
39
+ -e git+https://github.com/openai/CLIP.git@main#egg=clip
40
+ -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
41
+ -e .
repositories/generative-models/scripts/__init__.py ADDED
File without changes
repositories/generative-models/scripts/demo/__init__.py ADDED
File without changes
repositories/generative-models/scripts/demo/detect.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ try:
7
+ from imwatermark import WatermarkDecoder
8
+ except ImportError as e:
9
+ try:
10
+ # Assume some of the other dependencies such as torch are not fulfilled
11
+ # import file without loading unnecessary libraries.
12
+ import importlib.util
13
+ import sys
14
+
15
+ spec = importlib.util.find_spec("imwatermark.maxDct")
16
+ assert spec is not None
17
+ maxDct = importlib.util.module_from_spec(spec)
18
+ sys.modules["maxDct"] = maxDct
19
+ spec.loader.exec_module(maxDct)
20
+
21
+ class WatermarkDecoder(object):
22
+ """A minimal version of
23
+ https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
24
+ to only reconstruct bits using dwtDct"""
25
+
26
+ def __init__(self, wm_type="bytes", length=0):
27
+ assert wm_type == "bits", "Only bits defined in minimal import"
28
+ self._wmType = wm_type
29
+ self._wmLen = length
30
+
31
+ def reconstruct(self, bits):
32
+ if len(bits) != self._wmLen:
33
+ raise RuntimeError("bits are not matched with watermark length")
34
+
35
+ return bits
36
+
37
+ def decode(self, cv2Image, method="dwtDct", **configs):
38
+ (r, c, channels) = cv2Image.shape
39
+ if r * c < 256 * 256:
40
+ raise RuntimeError("image too small, should be larger than 256x256")
41
+
42
+ bits = []
43
+ assert method == "dwtDct"
44
+ embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
45
+ bits = embed.decode(cv2Image)
46
+ return self.reconstruct(bits)
47
+
48
+ except:
49
+ raise e
50
+
51
+
52
+ # A fixed 48-bit message that was chosen at random
53
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
54
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
55
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
56
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
57
+ MATCH_VALUES = [
58
+ [27, "No watermark detected"],
59
+ [33, "Partial watermark match. Cannot determine with certainty."],
60
+ [
61
+ 35,
62
+ (
63
+ "Likely watermarked. In our test 0.02% of real images were "
64
+ 'falsely detected as "Likely watermarked"'
65
+ ),
66
+ ],
67
+ [
68
+ 49,
69
+ (
70
+ "Very likely watermarked. In our test no real images were "
71
+ 'falsely detected as "Very likely watermarked"'
72
+ ),
73
+ ],
74
+ ]
75
+
76
+
77
+ class GetWatermarkMatch:
78
+ def __init__(self, watermark):
79
+ self.watermark = watermark
80
+ self.num_bits = len(self.watermark)
81
+ self.decoder = WatermarkDecoder("bits", self.num_bits)
82
+
83
+ def __call__(self, x: np.ndarray) -> np.ndarray:
84
+ """
85
+ Detects the number of bits of the predefined watermark that match one
86
+ or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
87
+
88
+ Args:
89
+ x: ([B], h w, c) in range [0, 255]
90
+
91
+ Returns:
92
+ number of matched bits ([B],)
93
+ """
94
+ squeeze = len(x.shape) == 3
95
+ if squeeze:
96
+ x = x[None, ...]
97
+
98
+ bs = x.shape[0]
99
+ detected = np.empty((bs, self.num_bits), dtype=bool)
100
+ for k in range(bs):
101
+ detected[k] = self.decoder.decode(x[k], "dwtDct")
102
+ result = np.sum(detected == self.watermark, axis=-1)
103
+ if squeeze:
104
+ return result[0]
105
+ else:
106
+ return result
107
+
108
+
109
+ get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument(
115
+ "filename",
116
+ nargs="+",
117
+ type=str,
118
+ help="Image files to check for watermarks",
119
+ )
120
+ opts = parser.parse_args()
121
+
122
+ print(
123
+ """
124
+ This script tries to detect watermarked images. Please be aware of
125
+ the following:
126
+ - As the watermark is supposed to be invisible, there is the risk that
127
+ watermarked images may not be detected.
128
+ - To maximize the chance of detection make sure that the image has the same
129
+ dimensions as when the watermark was applied (most likely 1024x1024
130
+ or 512x512).
131
+ - Specific image manipulation may drastically decrease the chance that
132
+ watermarks can be detected.
133
+ - There is also the chance that an image has the characteristics of the
134
+ watermark by chance.
135
+ - The watermark script is public, anybody may watermark any images, and
136
+ could therefore claim it to be generated.
137
+ - All numbers below are based on a test using 10,000 images without any
138
+ modifications after applying the watermark.
139
+ """
140
+ )
141
+
142
+ for fn in opts.filename:
143
+ image = cv2.imread(fn)
144
+ if image is None:
145
+ print(f"Couldn't read {fn}. Skipping")
146
+ continue
147
+
148
+ num_bits = get_watermark_match(image)
149
+ k = 0
150
+ while num_bits > MATCH_VALUES[k][0]:
151
+ k += 1
152
+ print(
153
+ f"{fn}: {MATCH_VALUES[k][1]}",
154
+ f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
155
+ sep="\n\t",
156
+ )
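
For reference, the loop at the bottom of detect.py walks the MATCH_VALUES thresholds until the matched-bit count no longer exceeds them. A standalone sketch of that lookup (thresholds copied from the table above, messages abbreviated):

MATCH_THRESHOLDS = [
    (27, "No watermark detected"),
    (33, "Partial watermark match. Cannot determine with certainty."),
    (35, "Likely watermarked."),
    (49, "Very likely watermarked."),
]

def interpret(num_matched_bits):
    # Same threshold walk as in detect.py's __main__ block.
    k = 0
    while num_matched_bits > MATCH_THRESHOLDS[k][0]:
        k += 1
    return MATCH_THRESHOLDS[k][1]

print(interpret(20))  # No watermark detected
print(interpret(48))  # Very likely watermarked. (48 of 48 bits matched)
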
repositories/generative-models/scripts/demo/sampling.py ADDED
@@ -0,0 +1,329 @@
1
+ from pytorch_lightning import seed_everything
2
+ from scripts.demo.streamlit_helpers import *
3
+ from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
4
+
5
+ SAVE_PATH = "outputs/demo/txt2img/"
6
+
7
+ SD_XL_BASE_RATIOS = {
8
+ "0.5": (704, 1408),
9
+ "0.52": (704, 1344),
10
+ "0.57": (768, 1344),
11
+ "0.6": (768, 1280),
12
+ "0.68": (832, 1216),
13
+ "0.72": (832, 1152),
14
+ "0.78": (896, 1152),
15
+ "0.82": (896, 1088),
16
+ "0.88": (960, 1088),
17
+ "0.94": (960, 1024),
18
+ "1.0": (1024, 1024),
19
+ "1.07": (1024, 960),
20
+ "1.13": (1088, 960),
21
+ "1.21": (1088, 896),
22
+ "1.29": (1152, 896),
23
+ "1.38": (1152, 832),
24
+ "1.46": (1216, 832),
25
+ "1.67": (1280, 768),
26
+ "1.75": (1344, 768),
27
+ "1.91": (1344, 704),
28
+ "2.0": (1408, 704),
29
+ "2.09": (1472, 704),
30
+ "2.4": (1536, 640),
31
+ "2.5": (1600, 640),
32
+ "2.89": (1664, 576),
33
+ "3.0": (1728, 576),
34
+ }
35
+
36
+ VERSION2SPECS = {
37
+ "SD-XL base": {
38
+ "H": 1024,
39
+ "W": 1024,
40
+ "C": 4,
41
+ "f": 8,
42
+ "is_legacy": False,
43
+ "config": "configs/inference/sd_xl_base.yaml",
44
+ "ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
45
+ "is_guided": True,
46
+ },
47
+ "sd-2.1": {
48
+ "H": 512,
49
+ "W": 512,
50
+ "C": 4,
51
+ "f": 8,
52
+ "is_legacy": True,
53
+ "config": "configs/inference/sd_2_1.yaml",
54
+ "ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
55
+ "is_guided": True,
56
+ },
57
+ "sd-2.1-768": {
58
+ "H": 768,
59
+ "W": 768,
60
+ "C": 4,
61
+ "f": 8,
62
+ "is_legacy": True,
63
+ "config": "configs/inference/sd_2_1_768.yaml",
64
+ "ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
65
+ },
66
+ "SDXL-Refiner": {
67
+ "H": 1024,
68
+ "W": 1024,
69
+ "C": 4,
70
+ "f": 8,
71
+ "is_legacy": True,
72
+ "config": "configs/inference/sd_xl_refiner.yaml",
73
+ "ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
74
+ "is_guided": True,
75
+ },
76
+ }
77
+
78
+
79
+ def load_img(display=True, key=None, device="cuda"):
80
+ image = get_interactive_image(key=key)
81
+ if image is None:
82
+ return None
83
+ if display:
84
+ st.image(image)
85
+ w, h = image.size
86
+ print(f"loaded input image of size ({w}, {h})")
87
+ width, height = map(
88
+ lambda x: x - x % 64, (w, h)
89
+ ) # resize to integer multiple of 64
90
+ image = image.resize((width, height))
91
+ image = np.array(image.convert("RGB"))
92
+ image = image[None].transpose(0, 3, 1, 2)
93
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
94
+ return image.to(device)
95
+
96
+
97
+ def run_txt2img(
98
+ state, version, version_dict, is_legacy=False, return_latents=False, filter=None
99
+ ):
100
+ if version == "SD-XL base":
101
+ ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
102
+ W, H = SD_XL_BASE_RATIOS[ratio]
103
+ else:
104
+ H = st.sidebar.number_input(
105
+ "H", value=version_dict["H"], min_value=64, max_value=2048
106
+ )
107
+ W = st.sidebar.number_input(
108
+ "W", value=version_dict["W"], min_value=64, max_value=2048
109
+ )
110
+ C = version_dict["C"]
111
+ F = version_dict["f"]
112
+
113
+ init_dict = {
114
+ "orig_width": W,
115
+ "orig_height": H,
116
+ "target_width": W,
117
+ "target_height": H,
118
+ }
119
+ value_dict = init_embedder_options(
120
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
121
+ init_dict,
122
+ prompt=prompt,
123
+ negative_prompt=negative_prompt,
124
+ )
125
+ num_rows, num_cols, sampler = init_sampling(
126
+ use_identity_guider=not version_dict["is_guided"]
127
+ )
128
+
129
+ num_samples = num_rows * num_cols
130
+
131
+ if st.button("Sample"):
132
+ st.write(f"**Model I:** {version}")
133
+ out = do_sample(
134
+ state["model"],
135
+ sampler,
136
+ value_dict,
137
+ num_samples,
138
+ H,
139
+ W,
140
+ C,
141
+ F,
142
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
143
+ return_latents=return_latents,
144
+ filter=filter,
145
+ )
146
+ return out
147
+
148
+
149
+ def run_img2img(
150
+ state, version_dict, is_legacy=False, return_latents=False, filter=None
151
+ ):
152
+ img = load_img()
153
+ if img is None:
154
+ return None
155
+ H, W = img.shape[2], img.shape[3]
156
+
157
+ init_dict = {
158
+ "orig_width": W,
159
+ "orig_height": H,
160
+ "target_width": W,
161
+ "target_height": H,
162
+ }
163
+ value_dict = init_embedder_options(
164
+ get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
165
+ init_dict,
166
+ )
167
+ strength = st.number_input(
168
+ "**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
169
+ )
170
+ num_rows, num_cols, sampler = init_sampling(
171
+ img2img_strength=strength,
172
+ use_identity_guider=not version_dict["is_guided"],
173
+ )
174
+ num_samples = num_rows * num_cols
175
+
176
+ if st.button("Sample"):
177
+ out = do_img2img(
178
+ repeat(img, "1 ... -> n ...", n=num_samples),
179
+ state["model"],
180
+ sampler,
181
+ value_dict,
182
+ num_samples,
183
+ force_uc_zero_embeddings=["txt"] if not is_legacy else [],
184
+ return_latents=return_latents,
185
+ filter=filter,
186
+ )
187
+ return out
188
+
189
+
190
+ def apply_refiner(
191
+ input,
192
+ state,
193
+ sampler,
194
+ num_samples,
195
+ prompt,
196
+ negative_prompt,
197
+ filter=None,
198
+ ):
199
+ init_dict = {
200
+ "orig_width": input.shape[3] * 8,
201
+ "orig_height": input.shape[2] * 8,
202
+ "target_width": input.shape[3] * 8,
203
+ "target_height": input.shape[2] * 8,
204
+ }
205
+
206
+ value_dict = init_dict
207
+ value_dict["prompt"] = prompt
208
+ value_dict["negative_prompt"] = negative_prompt
209
+
210
+ value_dict["crop_coords_top"] = 0
211
+ value_dict["crop_coords_left"] = 0
212
+
213
+ value_dict["aesthetic_score"] = 6.0
214
+ value_dict["negative_aesthetic_score"] = 2.5
215
+
216
+ st.warning(f"refiner input shape: {input.shape}")
217
+ samples = do_img2img(
218
+ input,
219
+ state["model"],
220
+ sampler,
221
+ value_dict,
222
+ num_samples,
223
+ skip_encode=True,
224
+ filter=filter,
225
+ )
226
+
227
+ return samples
228
+
229
+
230
+ if __name__ == "__main__":
231
+ st.title("Stable Diffusion")
232
+ version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
233
+ version_dict = VERSION2SPECS[version]
234
+ mode = st.radio("Mode", ("txt2img", "img2img"), 0)
235
+ st.write("__________________________")
236
+
237
+ if version == "SD-XL base":
238
+ add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
239
+ st.write("__________________________")
240
+ else:
241
+ add_pipeline = False
242
+
243
+ filter = DeepFloydDataFiltering(verbose=False)
244
+
245
+ seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
246
+ seed_everything(seed)
247
+
248
+ save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
249
+
250
+ state = init_st(version_dict)
251
+ if state["msg"]:
252
+ st.info(state["msg"])
253
+ model = state["model"]
254
+
255
+ is_legacy = version_dict["is_legacy"]
256
+
257
+ prompt = st.text_input(
258
+ "prompt",
259
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
260
+ )
261
+ if is_legacy:
262
+ negative_prompt = st.text_input("negative prompt", "")
263
+ else:
264
+ negative_prompt = "" # which is unused
265
+
266
+ if add_pipeline:
267
+ st.write("__________________________")
268
+
269
+ version2 = "SDXL-Refiner"
270
+ st.warning(
271
+ f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
272
+ )
273
+ st.write("**Refiner Options:**")
274
+
275
+ version_dict2 = VERSION2SPECS[version2]
276
+ state2 = init_st(version_dict2)
277
+ st.info(state2["msg"])
278
+
279
+ stage2strength = st.number_input(
280
+ "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
281
+ )
282
+
283
+ sampler2 = init_sampling(
284
+ key=2,
285
+ img2img_strength=stage2strength,
286
+ use_identity_guider=not version_dict2["is_guided"],
287
+ get_num_samples=False,
288
+ )
289
+ st.write("__________________________")
290
+
291
+ if mode == "txt2img":
292
+ out = run_txt2img(
293
+ state,
294
+ version,
295
+ version_dict,
296
+ is_legacy=is_legacy,
297
+ return_latents=add_pipeline,
298
+ filter=filter,
299
+ )
300
+ elif mode == "img2img":
301
+ out = run_img2img(
302
+ state,
303
+ version_dict,
304
+ is_legacy=is_legacy,
305
+ return_latents=add_pipeline,
306
+ filter=filter,
307
+ )
308
+ else:
309
+ raise ValueError(f"unknown mode {mode}")
310
+ if isinstance(out, (tuple, list)):
311
+ samples, samples_z = out
312
+ else:
313
+ samples = out
314
+ samples_z = None
315
+
316
+ if add_pipeline and samples_z is not None:
317
+ st.write("**Running Refinement Stage**")
318
+ samples = apply_refiner(
319
+ samples_z,
320
+ state2,
321
+ sampler2,
322
+ samples_z.shape[0],
323
+ prompt=prompt,
324
+ negative_prompt=negative_prompt if is_legacy else "",
325
+ filter=filter,
326
+ )
327
+
328
+ if save_locally and samples is not None:
329
+ perform_save_locally(save_path, samples)
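
The SD_XL_BASE_RATIOS table at the top of sampling.py enumerates the (width, height) buckets offered for the SD-XL base model. A hypothetical helper (not part of this commit) that snaps an arbitrary resolution to the closest bucket by aspect ratio could look like this:

SDXL_RATIO_BUCKETS = {  # small subset of SD_XL_BASE_RATIOS above, (W, H) per ratio key
    "0.78": (896, 1152),
    "1.0": (1024, 1024),
    "1.29": (1152, 896),
}

def closest_bucket(width, height, ratios=SDXL_RATIO_BUCKETS):
    # Pick the bucket whose aspect-ratio key is numerically closest to width / height.
    ratio = width / height
    key = min(ratios, key=lambda k: abs(float(k) - ratio))
    return ratios[key]

print(closest_bucket(1280, 1024))  # 1.25 is closest to "1.29" -> (1152, 896)
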
repositories/generative-models/scripts/demo/streamlit_helpers.py ADDED
@@ -0,0 +1,666 @@
1
+ import os
2
+ from typing import Union, List
3
+
4
+ import math
5
+ import numpy as np
6
+ import streamlit as st
7
+ import torch
8
+ from PIL import Image
9
+ from einops import rearrange, repeat
10
+ from imwatermark import WatermarkEncoder
11
+ from omegaconf import OmegaConf, ListConfig
12
+ from torch import autocast
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from safetensors.torch import load_file as load_safetensors
16
+
17
+ from sgm.modules.diffusionmodules.sampling import (
18
+ EulerEDMSampler,
19
+ HeunEDMSampler,
20
+ EulerAncestralSampler,
21
+ DPMPP2SAncestralSampler,
22
+ DPMPP2MSampler,
23
+ LinearMultistepSampler,
24
+ )
25
+ from sgm.util import append_dims
26
+ from sgm.util import instantiate_from_config
27
+
28
+
29
+ class WatermarkEmbedder:
30
+ def __init__(self, watermark):
31
+ self.watermark = watermark
32
+ self.num_bits = len(WATERMARK_BITS)
33
+ self.encoder = WatermarkEncoder()
34
+ self.encoder.set_watermark("bits", self.watermark)
35
+
36
+ def __call__(self, image: torch.Tensor):
37
+ """
38
+ Adds a predefined watermark to the input image
39
+
40
+ Args:
41
+ image: ([N,] B, C, H, W) in range [0, 1]
42
+
43
+ Returns:
44
+ same as input but watermarked
45
+ """
46
+ # watermarking library expects input as cv2 BGR format
47
+ squeeze = len(image.shape) == 4
48
+ if squeeze:
49
+ image = image[None, ...]
50
+ n = image.shape[0]
51
+ image_np = rearrange(
52
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
53
+ ).numpy()[:, :, :, ::-1]
54
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
55
+ for k in range(image_np.shape[0]):
56
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
57
+ image = torch.from_numpy(
58
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
59
+ ).to(image.device)
60
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
61
+ if squeeze:
62
+ image = image[0]
63
+ return image
64
+
65
+
66
+ # A fixed 48-bit message that was chosen at random
67
+ # WATERMARK_MESSAGE = 0xB3EC907BB19E
68
+ WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
69
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
70
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
71
+ embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
72
+
73
+
74
+ @st.cache_resource()
75
+ def init_st(version_dict, load_ckpt=True):
76
+ state = dict()
77
+ if not "model" in state:
78
+ config = version_dict["config"]
79
+ ckpt = version_dict["ckpt"]
80
+
81
+ config = OmegaConf.load(config)
82
+ model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
83
+
84
+ state["msg"] = msg
85
+ state["model"] = model
86
+ state["ckpt"] = ckpt if load_ckpt else None
87
+ state["config"] = config
88
+ return state
89
+
90
+
91
+ def load_model_from_config(config, ckpt=None, verbose=True):
92
+ model = instantiate_from_config(config.model)
93
+
94
+ if ckpt is not None:
95
+ print(f"Loading model from {ckpt}")
96
+ if ckpt.endswith("ckpt"):
97
+ pl_sd = torch.load(ckpt, map_location="cpu")
98
+ if "global_step" in pl_sd:
99
+ global_step = pl_sd["global_step"]
100
+ st.info(f"loaded ckpt from global step {global_step}")
101
+ print(f"Global Step: {pl_sd['global_step']}")
102
+ sd = pl_sd["state_dict"]
103
+ elif ckpt.endswith("safetensors"):
104
+ sd = load_safetensors(ckpt)
105
+ else:
106
+ raise NotImplementedError
107
+
108
+ msg = None
109
+
110
+ m, u = model.load_state_dict(sd, strict=False)
111
+
112
+ if len(m) > 0 and verbose:
113
+ print("missing keys:")
114
+ print(m)
115
+ if len(u) > 0 and verbose:
116
+ print("unexpected keys:")
117
+ print(u)
118
+ else:
119
+ msg = None
120
+
121
+ model.cuda()
122
+ model.eval()
123
+ return model, msg
124
+
125
+
126
+ def get_unique_embedder_keys_from_conditioner(conditioner):
127
+ return list(set([x.input_key for x in conditioner.embedders]))
128
+
129
+
130
+ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
131
+ # Hardcoded demo settings; might undergo some changes in the future
132
+
133
+ value_dict = {}
134
+ for key in keys:
135
+ if key == "txt":
136
+ if prompt is None:
137
+ prompt = st.text_input(
138
+ "Prompt", "A professional photograph of an astronaut riding a pig"
139
+ )
140
+ if negative_prompt is None:
141
+ negative_prompt = st.text_input("Negative prompt", "")
142
+
143
+ value_dict["prompt"] = prompt
144
+ value_dict["negative_prompt"] = negative_prompt
145
+
146
+ if key == "original_size_as_tuple":
147
+ orig_width = st.number_input(
148
+ "orig_width",
149
+ value=init_dict["orig_width"],
150
+ min_value=16,
151
+ )
152
+ orig_height = st.number_input(
153
+ "orig_height",
154
+ value=init_dict["orig_height"],
155
+ min_value=16,
156
+ )
157
+
158
+ value_dict["orig_width"] = orig_width
159
+ value_dict["orig_height"] = orig_height
160
+
161
+ if key == "crop_coords_top_left":
162
+ crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
163
+ crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
164
+
165
+ value_dict["crop_coords_top"] = crop_coord_top
166
+ value_dict["crop_coords_left"] = crop_coord_left
167
+
168
+ if key == "aesthetic_score":
169
+ value_dict["aesthetic_score"] = 6.0
170
+ value_dict["negative_aesthetic_score"] = 2.5
171
+
172
+ if key == "target_size_as_tuple":
173
+ target_width = st.number_input(
174
+ "target_width",
175
+ value=init_dict["target_width"],
176
+ min_value=16,
177
+ )
178
+ target_height = st.number_input(
179
+ "target_height",
180
+ value=init_dict["target_height"],
181
+ min_value=16,
182
+ )
183
+
184
+ value_dict["target_width"] = target_width
185
+ value_dict["target_height"] = target_height
186
+
187
+ return value_dict
188
+
189
+
190
+ def perform_save_locally(save_path, samples):
191
+ os.makedirs(os.path.join(save_path), exist_ok=True)
192
+ base_count = len(os.listdir(os.path.join(save_path)))
193
+ samples = embed_watemark(samples)
194
+ for sample in samples:
195
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
196
+ Image.fromarray(sample.astype(np.uint8)).save(
197
+ os.path.join(save_path, f"{base_count:09}.png")
198
+ )
199
+ base_count += 1
200
+
201
+
202
+ def init_save_locally(_dir, init_value: bool = False):
203
+ save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
204
+ if save_locally:
205
+ save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
206
+ else:
207
+ save_path = None
208
+
209
+ return save_locally, save_path
210
+
211
+
212
+ class Img2ImgDiscretizationWrapper:
213
+ """
214
+ wraps a discretizer, and prunes the sigmas
215
+ params:
216
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
217
+ """
218
+
219
+ def __init__(self, discretization, strength: float = 1.0):
220
+ self.discretization = discretization
221
+ self.strength = strength
222
+ assert 0.0 <= self.strength <= 1.0
223
+
224
+ def __call__(self, *args, **kwargs):
225
+ # sigmas start large first, and decrease then
226
+ sigmas = self.discretization(*args, **kwargs)
227
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
228
+ sigmas = torch.flip(sigmas, (0,))
229
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
230
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
231
+ sigmas = torch.flip(sigmas, (0,))
232
+ print(f"sigmas after pruning: ", sigmas)
233
+ return sigmas
234
+
235
+
236
+ def get_guider(key):
237
+ guider = st.sidebar.selectbox(
238
+ f"Discretization #{key}",
239
+ [
240
+ "VanillaCFG",
241
+ "IdentityGuider",
242
+ ],
243
+ )
244
+
245
+ if guider == "IdentityGuider":
246
+ guider_config = {
247
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
248
+ }
249
+ elif guider == "VanillaCFG":
250
+ scale = st.number_input(
251
+ f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
252
+ )
253
+
254
+ thresholder = st.sidebar.selectbox(
255
+ f"Thresholder #{key}",
256
+ [
257
+ "None",
258
+ ],
259
+ )
260
+
261
+ if thresholder == "None":
262
+ dyn_thresh_config = {
263
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
264
+ }
265
+ else:
266
+ raise NotImplementedError
267
+
268
+ guider_config = {
269
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
270
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
271
+ }
272
+ else:
273
+ raise NotImplementedError
274
+ return guider_config
275
+
276
+
277
+ def init_sampling(
278
+ key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
279
+ ):
280
+ if get_num_samples:
281
+ num_rows = 1
282
+ num_cols = st.number_input(
283
+ f"num cols #{key}", value=2, min_value=1, max_value=10
284
+ )
285
+
286
+ steps = st.sidebar.number_input(
287
+ f"steps #{key}", value=50, min_value=1, max_value=1000
288
+ )
289
+ sampler = st.sidebar.selectbox(
290
+ f"Sampler #{key}",
291
+ [
292
+ "EulerEDMSampler",
293
+ "HeunEDMSampler",
294
+ "EulerAncestralSampler",
295
+ "DPMPP2SAncestralSampler",
296
+ "DPMPP2MSampler",
297
+ "LinearMultistepSampler",
298
+ ],
299
+ 0,
300
+ )
301
+ discretization = st.sidebar.selectbox(
302
+ f"Discretization #{key}",
303
+ [
304
+ "LegacyDDPMDiscretization",
305
+ "EDMDiscretization",
306
+ ],
307
+ )
308
+
309
+ discretization_config = get_discretization(discretization, key=key)
310
+
311
+ guider_config = get_guider(key=key)
312
+
313
+ sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
314
+ if img2img_strength < 1.0:
315
+ st.warning(
316
+ f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
317
+ )
318
+ sampler.discretization = Img2ImgDiscretizationWrapper(
319
+ sampler.discretization, strength=img2img_strength
320
+ )
321
+ if get_num_samples:
322
+ return num_rows, num_cols, sampler
323
+ return sampler
324
+
325
+
326
+ def get_discretization(discretization, key=1):
327
+ if discretization == "LegacyDDPMDiscretization":
328
+ discretization_config = {
329
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
330
+ }
331
+ elif discretization == "EDMDiscretization":
332
+ sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
333
+ sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
334
+ rho = st.number_input(f"rho #{key}", value=3.0)
335
+ discretization_config = {
336
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
337
+ "params": {
338
+ "sigma_min": sigma_min,
339
+ "sigma_max": sigma_max,
340
+ "rho": rho,
341
+ },
342
+ }
343
+
344
+ return discretization_config
345
+
346
+
347
+ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
348
+ if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
349
+ s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
350
+ s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
351
+ s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
352
+ s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
353
+
354
+ if sampler_name == "EulerEDMSampler":
355
+ sampler = EulerEDMSampler(
356
+ num_steps=steps,
357
+ discretization_config=discretization_config,
358
+ guider_config=guider_config,
359
+ s_churn=s_churn,
360
+ s_tmin=s_tmin,
361
+ s_tmax=s_tmax,
362
+ s_noise=s_noise,
363
+ verbose=True,
364
+ )
365
+ elif sampler_name == "HeunEDMSampler":
366
+ sampler = HeunEDMSampler(
367
+ num_steps=steps,
368
+ discretization_config=discretization_config,
369
+ guider_config=guider_config,
370
+ s_churn=s_churn,
371
+ s_tmin=s_tmin,
372
+ s_tmax=s_tmax,
373
+ s_noise=s_noise,
374
+ verbose=True,
375
+ )
376
+ elif (
377
+ sampler_name == "EulerAncestralSampler"
378
+ or sampler_name == "DPMPP2SAncestralSampler"
379
+ ):
380
+ s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
381
+ eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
382
+
383
+ if sampler_name == "EulerAncestralSampler":
384
+ sampler = EulerAncestralSampler(
385
+ num_steps=steps,
386
+ discretization_config=discretization_config,
387
+ guider_config=guider_config,
388
+ eta=eta,
389
+ s_noise=s_noise,
390
+ verbose=True,
391
+ )
392
+ elif sampler_name == "DPMPP2SAncestralSampler":
393
+ sampler = DPMPP2SAncestralSampler(
394
+ num_steps=steps,
395
+ discretization_config=discretization_config,
396
+ guider_config=guider_config,
397
+ eta=eta,
398
+ s_noise=s_noise,
399
+ verbose=True,
400
+ )
401
+ elif sampler_name == "DPMPP2MSampler":
402
+ sampler = DPMPP2MSampler(
403
+ num_steps=steps,
404
+ discretization_config=discretization_config,
405
+ guider_config=guider_config,
406
+ verbose=True,
407
+ )
408
+ elif sampler_name == "LinearMultistepSampler":
409
+ order = st.sidebar.number_input("order", value=4, min_value=1)
410
+ sampler = LinearMultistepSampler(
411
+ num_steps=steps,
412
+ discretization_config=discretization_config,
413
+ guider_config=guider_config,
414
+ order=order,
415
+ verbose=True,
416
+ )
417
+ else:
418
+ raise ValueError(f"unknown sampler {sampler_name}!")
419
+
420
+ return sampler
421
+
422
+
423
+ def get_interactive_image(key=None) -> Image.Image:
424
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
425
+ if image is not None:
426
+ image = Image.open(image)
427
+ if not image.mode == "RGB":
428
+ image = image.convert("RGB")
429
+ return image
430
+
431
+
432
+ def load_img(display=True, key=None):
433
+ image = get_interactive_image(key=key)
434
+ if image is None:
435
+ return None
436
+ if display:
437
+ st.image(image)
438
+ w, h = image.size
439
+ print(f"loaded input image of size ({w}, {h})")
440
+
441
+ transform = transforms.Compose(
442
+ [
443
+ transforms.ToTensor(),
444
+ transforms.Lambda(lambda x: x * 2.0 - 1.0),
445
+ ]
446
+ )
447
+ img = transform(image)[None, ...]
448
+ st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
449
+ return img
450
+
451
+
452
+ def get_init_img(batch_size=1, key=None):
453
+ init_image = load_img(key=key).cuda()
454
+ init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
455
+ return init_image
456
+
457
+
458
+ def do_sample(
459
+ model,
460
+ sampler,
461
+ value_dict,
462
+ num_samples,
463
+ H,
464
+ W,
465
+ C,
466
+ F,
467
+ force_uc_zero_embeddings: List = None,
468
+ batch2model_input: List = None,
469
+ return_latents=False,
470
+ filter=None,
471
+ ):
472
+ if force_uc_zero_embeddings is None:
473
+ force_uc_zero_embeddings = []
474
+ if batch2model_input is None:
475
+ batch2model_input = []
476
+
477
+ st.text("Sampling")
478
+
479
+ outputs = st.empty()
480
+ precision_scope = autocast
481
+ with torch.no_grad():
482
+ with precision_scope("cuda"):
483
+ with model.ema_scope():
484
+ num_samples = [num_samples]
485
+ batch, batch_uc = get_batch(
486
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
487
+ value_dict,
488
+ num_samples,
489
+ )
490
+ for key in batch:
491
+ if isinstance(batch[key], torch.Tensor):
492
+ print(key, batch[key].shape)
493
+ elif isinstance(batch[key], list):
494
+ print(key, [len(l) for l in batch[key]])
495
+ else:
496
+ print(key, batch[key])
497
+ c, uc = model.conditioner.get_unconditional_conditioning(
498
+ batch,
499
+ batch_uc=batch_uc,
500
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
501
+ )
502
+
503
+ for k in c:
504
+ if not k == "crossattn":
505
+ c[k], uc[k] = map(
506
+ lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
507
+ )
508
+
509
+ additional_model_inputs = {}
510
+ for k in batch2model_input:
511
+ additional_model_inputs[k] = batch[k]
512
+
513
+ shape = (math.prod(num_samples), C, H // F, W // F)
514
+ randn = torch.randn(shape).to("cuda")
515
+
516
+ def denoiser(input, sigma, c):
517
+ return model.denoiser(
518
+ model.model, input, sigma, c, **additional_model_inputs
519
+ )
520
+
521
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
522
+ samples_x = model.decode_first_stage(samples_z)
523
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
524
+
525
+ if filter is not None:
526
+ samples = filter(samples)
527
+
528
+ grid = torch.stack([samples])
529
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
530
+ outputs.image(grid.cpu().numpy())
531
+
532
+ if return_latents:
533
+ return samples, samples_z
534
+ return samples
535
+
536
+
537
+ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
538
+ # Hardcoded demo setups; might undergo some changes in the future
539
+
540
+ batch = {}
541
+ batch_uc = {}
542
+
543
+ for key in keys:
544
+ if key == "txt":
545
+ batch["txt"] = (
546
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
547
+ .reshape(N)
548
+ .tolist()
549
+ )
550
+ batch_uc["txt"] = (
551
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
552
+ .reshape(N)
553
+ .tolist()
554
+ )
555
+ elif key == "original_size_as_tuple":
556
+ batch["original_size_as_tuple"] = (
557
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
558
+ .to(device)
559
+ .repeat(*N, 1)
560
+ )
561
+ elif key == "crop_coords_top_left":
562
+ batch["crop_coords_top_left"] = (
563
+ torch.tensor(
564
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
565
+ )
566
+ .to(device)
567
+ .repeat(*N, 1)
568
+ )
569
+ elif key == "aesthetic_score":
570
+ batch["aesthetic_score"] = (
571
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
572
+ )
573
+ batch_uc["aesthetic_score"] = (
574
+ torch.tensor([value_dict["negative_aesthetic_score"]])
575
+ .to(device)
576
+ .repeat(*N, 1)
577
+ )
578
+
579
+ elif key == "target_size_as_tuple":
580
+ batch["target_size_as_tuple"] = (
581
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
582
+ .to(device)
583
+ .repeat(*N, 1)
584
+ )
585
+ else:
586
+ batch[key] = value_dict[key]
587
+
588
+ for key in batch.keys():
589
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
590
+ batch_uc[key] = torch.clone(batch[key])
591
+ return batch, batch_uc
592
+
593
+
594
+ @torch.no_grad()
595
+ def do_img2img(
596
+ img,
597
+ model,
598
+ sampler,
599
+ value_dict,
600
+ num_samples,
601
+ force_uc_zero_embeddings=[],
602
+ additional_kwargs={},
603
+ offset_noise_level: float = 0.0,
604
+ return_latents=False,
605
+ skip_encode=False,
606
+ filter=None,
607
+ ):
608
+ st.text("Sampling")
609
+
610
+ outputs = st.empty()
611
+ precision_scope = autocast
612
+ with torch.no_grad():
613
+ with precision_scope("cuda"):
614
+ with model.ema_scope():
615
+ batch, batch_uc = get_batch(
616
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
617
+ value_dict,
618
+ [num_samples],
619
+ )
620
+ c, uc = model.conditioner.get_unconditional_conditioning(
621
+ batch,
622
+ batch_uc=batch_uc,
623
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
624
+ )
625
+
626
+ for k in c:
627
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
628
+
629
+ for k in additional_kwargs:
630
+ c[k] = uc[k] = additional_kwargs[k]
631
+ if skip_encode:
632
+ z = img
633
+ else:
634
+ z = model.encode_first_stage(img)
635
+ noise = torch.randn_like(z)
636
+ sigmas = sampler.discretization(sampler.num_steps)
637
+ sigma = sigmas[0]
638
+
639
+ st.info(f"all sigmas: {sigmas}")
640
+ st.info(f"noising sigma: {sigma}")
641
+
642
+ if offset_noise_level > 0.0:
643
+ noise = noise + offset_noise_level * append_dims(
644
+ torch.randn(z.shape[0], device=z.device), z.ndim
645
+ )
646
+ noised_z = z + noise * append_dims(sigma, z.ndim)
647
+ noised_z = noised_z / torch.sqrt(
648
+ 1.0 + sigmas[0] ** 2.0
649
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
650
+
651
+ def denoiser(x, sigma, c):
652
+ return model.denoiser(model.model, x, sigma, c)
653
+
654
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
655
+ samples_x = model.decode_first_stage(samples_z)
656
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
657
+
658
+ if filter is not None:
659
+ samples = filter(samples)
660
+
661
+ grid = embed_watemark(torch.stack([samples]))
662
+ grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
663
+ outputs.image(grid.cpu().numpy())
664
+ if return_latents:
665
+ return samples, samples_z
666
+ return samples
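
Img2ImgDiscretizationWrapper above keeps only the last `strength` fraction of the sigma schedule, so img2img starts from a partially noised latent rather than from pure noise. A minimal sketch of the same pruning rule using plain Python lists (illustrative, not part of the commit):

def prune_sigmas(sigmas, strength=1.0):
    # sigmas are ordered large -> small; keep the smallest `strength` fraction
    # (at least one value), then restore the large -> small ordering,
    # matching the torch.flip logic in Img2ImgDiscretizationWrapper.
    assert 0.0 <= strength <= 1.0
    ascending = sigmas[::-1]
    kept = ascending[: max(int(strength * len(ascending)), 1)]
    return kept[::-1]

print(prune_sigmas([14.6, 7.0, 3.1, 1.2, 0.3], strength=0.5))  # [1.2, 0.3]
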
repositories/generative-models/scripts/util/__init__.py ADDED
File without changes
repositories/generative-models/scripts/util/detection/__init__.py ADDED
File without changes
repositories/generative-models/scripts/util/detection/nsfw_and_watermark_dectection.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+ import clip
7
+
8
+ RESOURCES_ROOT = "scripts/util/detection/"
9
+
10
+
11
+ def predict_proba(X, weights, biases):
12
+ logits = X @ weights.T + biases
13
+ proba = np.where(
14
+ logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
15
+ )
16
+ return proba.T
17
+
18
+
19
+ def load_model_weights(path: str):
20
+ model_weights = np.load(path)
21
+ return model_weights["weights"], model_weights["biases"]
22
+
23
+
24
+ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
25
+ min_size = min(images.shape[-2:])
26
+ return T.Compose(
27
+ [
28
+ T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
29
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
30
+ T.Normalize(
31
+ (0.48145466, 0.4578275, 0.40821073),
32
+ (0.26862954, 0.26130258, 0.27577711),
33
+ ),
34
+ ]
35
+ )(images)
36
+
37
+
38
+ class DeepFloydDataFiltering(object):
39
+ def __init__(self, verbose: bool = False):
40
+ super().__init__()
41
+ self.verbose = verbose
42
+ self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
43
+ self.clip_model.eval()
44
+
45
+ self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
46
+ os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
47
+ )
48
+ self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
49
+ os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
50
+ )
51
+ self.w_threshold, self.p_threshold = 0.5, 0.5
52
+
53
+ @torch.inference_mode()
54
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
55
+ imgs = clip_process_images(images)
56
+ image_features = self.clip_model.encode_image(imgs.to("cpu"))
57
+ image_features = image_features.detach().cpu().numpy().astype(np.float16)
58
+ p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
59
+ w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
60
+ print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
61
+ query = p_pred > self.p_threshold
62
+ if query.sum() > 0:
63
+ print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
64
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
65
+ query = w_pred > self.w_threshold
66
+ if query.sum() > 0:
67
+ print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
68
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
69
+ return images
70
+
71
+
72
+ def load_img(path: str) -> torch.Tensor:
73
+ image = Image.open(path)
74
+ if not image.mode == "RGB":
75
+ image = image.convert("RGB")
76
+ image_transforms = T.Compose(
77
+ [
78
+ T.ToTensor(),
79
+ ]
80
+ )
81
+ return image_transforms(image)[None, ...]
82
+
83
+
84
+ def test(root):
85
+ from einops import rearrange
86
+
87
+ filter = DeepFloydDataFiltering(verbose=True)
88
+ for p in os.listdir((root)):
89
+ print(f"running on {p}...")
90
+ img = load_img(os.path.join(root, p))
91
+ filtered_img = filter(img)
92
+ filtered_img = rearrange(
93
+ 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
94
+ ).astype(np.uint8)
95
+ Image.fromarray(filtered_img).save(
96
+ os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
97
+ )
98
+
99
+
100
+ if __name__ == "__main__":
101
+ import fire
102
+
103
+ fire.Fire(test)
104
+ print("done.")
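
predict_proba above is a numerically stable logistic head: the two np.where branches avoid overflowing exp for large-magnitude logits, but both reduce to sigmoid(X @ weights.T + biases). A quick illustrative check of that equivalence on random data (not part of the commit):

import numpy as np

def stable_sigmoid(logits):
    # Same branch structure as predict_proba in the file above.
    return np.where(
        logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
    )

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 1))
assert np.allclose(stable_sigmoid(logits), 1 / (1 + np.exp(-logits)))
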
repositories/generative-models/scripts/util/detection/p_head_v1.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76
3
+ size 3588
repositories/generative-models/scripts/util/detection/w_head_v1.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1
3
+ size 3588
repositories/generative-models/setup.py ADDED
@@ -0,0 +1,13 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name="sgm",
5
+ version="0.0.1",
6
+ packages=find_packages(),
7
+ python_requires=">=3.8",
8
+ py_modules=["sgm"],
9
+ description="Stability Generative Models",
10
+ long_description=open("README.md", "r", encoding="utf-8").read(),
11
+ long_description_content_type="text/markdown",
12
+ url="https://github.com/Stability-AI/generative-models",
13
+ )
repositories/generative-models/sgm/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .data import StableDataModuleFromConfig
2
+ from .models import AutoencodingEngine, DiffusionEngine
3
+ from .util import instantiate_from_config
repositories/generative-models/sgm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (345 Bytes). View file
 
repositories/generative-models/sgm/__pycache__/util.cpython-310.pyc ADDED
Binary file (8.09 kB). View file
 
repositories/generative-models/sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dataset import StableDataModuleFromConfig
repositories/generative-models/sgm/data/cifar10.py ADDED
@@ -0,0 +1,67 @@
1
+ import torchvision
2
+ import pytorch_lightning as pl
3
+ from torchvision import transforms
4
+ from torch.utils.data import DataLoader, Dataset
5
+
6
+
7
+ class CIFAR10DataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class CIFAR10Loader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.shuffle = shuffle
31
+ self.train_dataset = CIFAR10DataDictWrapper(
32
+ torchvision.datasets.CIFAR10(
33
+ root=".data/", train=True, download=True, transform=transform
34
+ )
35
+ )
36
+ self.test_dataset = CIFAR10DataDictWrapper(
37
+ torchvision.datasets.CIFAR10(
38
+ root=".data/", train=False, download=True, transform=transform
39
+ )
40
+ )
41
+
42
+ def prepare_data(self):
43
+ pass
44
+
45
+ def train_dataloader(self):
46
+ return DataLoader(
47
+ self.train_dataset,
48
+ batch_size=self.batch_size,
49
+ shuffle=self.shuffle,
50
+ num_workers=self.num_workers,
51
+ )
52
+
53
+ def test_dataloader(self):
54
+ return DataLoader(
55
+ self.test_dataset,
56
+ batch_size=self.batch_size,
57
+ shuffle=self.shuffle,
58
+ num_workers=self.num_workers,
59
+ )
60
+
61
+ def val_dataloader(self):
62
+ return DataLoader(
63
+ self.test_dataset,
64
+ batch_size=self.batch_size,
65
+ shuffle=self.shuffle,
66
+ num_workers=self.num_workers,
67
+ )
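
The data modules in sgm/data wrap torchvision datasets so that every batch is a dict with "jpg" and "cls" keys, the format the training configs expect. A hedged usage sketch (assumes the package is installed, e.g. via `pip install -e .`, and that importing `sgm.data` succeeds, since its __init__ pulls in the optional stable-datasets dependency; CIFAR-10 is downloaded to `.data/` on first use):

from sgm.data.cifar10 import CIFAR10Loader

loader = CIFAR10Loader(batch_size=8, num_workers=0, shuffle=True)
batch = next(iter(loader.train_dataloader()))
print(batch["jpg"].shape, batch["cls"].shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
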
repositories/generative-models/sgm/data/dataset.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Optional
2
+
3
+ import torchdata.datapipes.iter
4
+ import webdataset as wds
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import LightningDataModule
7
+
8
+ try:
9
+ from sdata import create_dataset, create_dummy_dataset, create_loader
10
+ except ImportError as e:
11
+ print("#" * 100)
12
+ print("Datasets not yet available")
13
+ print("to enable, we need to add stable-datasets as a submodule")
14
+ print("please use ``git submodule update --init --recursive``")
15
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
16
+ print("#" * 100)
17
+ exit(1)
18
+
19
+
20
+ class StableDataModuleFromConfig(LightningDataModule):
21
+ def __init__(
22
+ self,
23
+ train: DictConfig,
24
+ validation: Optional[DictConfig] = None,
25
+ test: Optional[DictConfig] = None,
26
+ skip_val_loader: bool = False,
27
+ dummy: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.train_config = train
31
+ assert (
32
+ "datapipeline" in self.train_config and "loader" in self.train_config
33
+ ), "train config requires the fields `datapipeline` and `loader`"
34
+
35
+ self.val_config = validation
36
+ if not skip_val_loader:
37
+ if self.val_config is not None:
38
+ assert (
39
+ "datapipeline" in self.val_config and "loader" in self.val_config
40
+ ), "validation config requires the fields `datapipeline` and `loader`"
41
+ else:
42
+ print(
43
+ "Warning: No Validation datapipeline defined, using that one from training"
44
+ )
45
+ self.val_config = train
46
+
47
+ self.test_config = test
48
+ if self.test_config is not None:
49
+ assert (
50
+ "datapipeline" in self.test_config and "loader" in self.test_config
51
+ ), "test config requires the fields `datapipeline` and `loader`"
52
+
53
+ self.dummy = dummy
54
+ if self.dummy:
55
+ print("#" * 100)
56
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
57
+ print("#" * 100)
58
+
59
+ def setup(self, stage: str) -> None:
60
+ print("Preparing datasets")
61
+ if self.dummy:
62
+ data_fn = create_dummy_dataset
63
+ else:
64
+ data_fn = create_dataset
65
+
66
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
67
+ if self.val_config:
68
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
69
+ if self.test_config:
70
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
71
+
72
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
73
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
74
+ return loader
75
+
76
+ def val_dataloader(self) -> wds.DataPipeline:
77
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
78
+
79
+ def test_dataloader(self) -> wds.DataPipeline:
80
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
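A minimal sketch of the config layout StableDataModuleFromConfig asserts on; the empty dicts are placeholders, since the real datapipeline/loader arguments depend on the stable-datasets submodule (sdata.create_dataset / sdata.create_loader):

    from omegaconf import OmegaConf

    # train (and optional validation/test) configs must each carry both keys,
    # mirroring the asserts in StableDataModuleFromConfig.__init__
    train = OmegaConf.create(
        {
            "datapipeline": {},  # kwargs forwarded to create_dataset(...)
            "loader": {},  # kwargs forwarded to create_loader(...)
        }
    )
    assert "datapipeline" in train and "loader" in train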
repositories/generative-models/sgm/data/mnist.py ADDED
@@ -0,0 +1,85 @@
1
+ import torchvision
2
+ import pytorch_lightning as pl
3
+ from torchvision import transforms
4
+ from torch.utils.data import DataLoader, Dataset
5
+
6
+
7
+ class MNISTDataDictWrapper(Dataset):
8
+ def __init__(self, dset):
9
+ super().__init__()
10
+ self.dset = dset
11
+
12
+ def __getitem__(self, i):
13
+ x, y = self.dset[i]
14
+ return {"jpg": x, "cls": y}
15
+
16
+ def __len__(self):
17
+ return len(self.dset)
18
+
19
+
20
+ class MNISTLoader(pl.LightningDataModule):
21
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
22
+ super().__init__()
23
+
24
+ transform = transforms.Compose(
25
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
26
+ )
27
+
28
+ self.batch_size = batch_size
29
+ self.num_workers = num_workers
30
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
31
+ self.shuffle = shuffle
32
+ self.train_dataset = MNISTDataDictWrapper(
33
+ torchvision.datasets.MNIST(
34
+ root=".data/", train=True, download=True, transform=transform
35
+ )
36
+ )
37
+ self.test_dataset = MNISTDataDictWrapper(
38
+ torchvision.datasets.MNIST(
39
+ root=".data/", train=False, download=True, transform=transform
40
+ )
41
+ )
42
+
43
+ def prepare_data(self):
44
+ pass
45
+
46
+ def train_dataloader(self):
47
+ return DataLoader(
48
+ self.train_dataset,
49
+ batch_size=self.batch_size,
50
+ shuffle=self.shuffle,
51
+ num_workers=self.num_workers,
52
+ prefetch_factor=self.prefetch_factor,
53
+ )
54
+
55
+ def test_dataloader(self):
56
+ return DataLoader(
57
+ self.test_dataset,
58
+ batch_size=self.batch_size,
59
+ shuffle=self.shuffle,
60
+ num_workers=self.num_workers,
61
+ prefetch_factor=self.prefetch_factor,
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.test_dataset,
67
+ batch_size=self.batch_size,
68
+ shuffle=self.shuffle,
69
+ num_workers=self.num_workers,
70
+ prefetch_factor=self.prefetch_factor,
71
+ )
72
+
73
+
74
+ if __name__ == "__main__":
75
+ dset = MNISTDataDictWrapper(
76
+ torchvision.datasets.MNIST(
77
+ root=".data/",
78
+ train=False,
79
+ download=True,
80
+ transform=transforms.Compose(
81
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
82
+ ),
83
+ )
84
+ )
85
+ ex = dset[0]
repositories/generative-models/sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
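An illustrative sketch of how these multiplier schedules plug into torch.optim.lr_scheduler.LambdaLR (the step counts and learning rate are made up; DiffusionEngine.configure_optimizers further down wires LambdaLR up the same way):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    from sgm.lr_scheduler import LambdaWarmUpCosineScheduler2

    # one cycle of 1000 steps: linear warm-up for 100 steps, then cosine decay to 0.1
    multiplier = LambdaWarmUpCosineScheduler2(
        warm_up_steps=[100], f_min=[0.1], f_max=[1.0], f_start=[0.0], cycle_lengths=[1000]
    )

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([param], lr=1e-4)  # base lr is multiplied by the schedule
    sched = LambdaLR(opt, lr_lambda=multiplier)
    for _ in range(150):
        opt.step()
        sched.step()
    print(opt.param_groups[0]["lr"])  # 1e-4 * multiplier(150)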
repositories/generative-models/sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoencodingEngine
2
+ from .diffusion import DiffusionEngine
repositories/generative-models/sgm/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (273 Bytes).
 
repositories/generative-models/sgm/models/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (11.5 kB).
 
repositories/generative-models/sgm/models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.2 kB).
 
repositories/generative-models/sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
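A hedged sketch of the "target"/"params" optimizer-config convention consumed by instantiate_optimizer_from_config; resolve below is a simplified stand-in for sgm.util.get_obj_from_str, and the lr/betas values are purely illustrative:

    import importlib

    import torch

    def resolve(target: str):
        # minimal equivalent of sgm.util.get_obj_from_str: import "a.b.C" and return C
        module, name = target.rsplit(".", 1)
        return getattr(importlib.import_module(module), name)

    optimizer_config = {"target": "torch.optim.Adam", "params": {"betas": (0.5, 0.9)}}
    params = [torch.nn.Parameter(torch.zeros(3))]
    opt = resolve(optimizer_config["target"])(
        params, lr=4.5e-6, **optimizer_config.get("params", {})
    )
    print(type(opt).__name__)  # Adam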
repositories/generative-models/sgm/models/diffusion.py ADDED
@@ -0,0 +1,320 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+
10
+ from ..modules import UNCONDITIONAL_CONFIG
11
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
12
+ from ..modules.ema import LitEma
13
+ from ..util import (
14
+ default,
15
+ disabled_train,
16
+ get_obj_from_str,
17
+ instantiate_from_config,
18
+ log_txt_as_img,
19
+ )
20
+
21
+
22
+ class DiffusionEngine(pl.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network_config,
26
+ denoiser_config,
27
+ first_stage_config,
28
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
29
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
30
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
31
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
32
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
33
+ network_wrapper: Union[None, str] = None,
34
+ ckpt_path: Union[None, str] = None,
35
+ use_ema: bool = False,
36
+ ema_decay_rate: float = 0.9999,
37
+ scale_factor: float = 1.0,
38
+ disable_first_stage_autocast=False,
39
+ input_key: str = "jpg",
40
+ log_keys: Union[List, None] = None,
41
+ no_cond_log: bool = False,
42
+ compile_model: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.log_keys = log_keys
46
+ self.input_key = input_key
47
+ self.optimizer_config = default(
48
+ optimizer_config, {"target": "torch.optim.AdamW"}
49
+ )
50
+ model = instantiate_from_config(network_config)
51
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
52
+ model, compile_model=compile_model
53
+ )
54
+
55
+ self.denoiser = instantiate_from_config(denoiser_config)
56
+ self.sampler = (
57
+ instantiate_from_config(sampler_config)
58
+ if sampler_config is not None
59
+ else None
60
+ )
61
+ self.conditioner = instantiate_from_config(
62
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
63
+ )
64
+ self.scheduler_config = scheduler_config
65
+ self._init_first_stage(first_stage_config)
66
+
67
+ self.loss_fn = (
68
+ instantiate_from_config(loss_fn_config)
69
+ if loss_fn_config is not None
70
+ else None
71
+ )
72
+
73
+ self.use_ema = use_ema
74
+ if self.use_ema:
75
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
76
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
77
+
78
+ self.scale_factor = scale_factor
79
+ self.disable_first_stage_autocast = disable_first_stage_autocast
80
+ self.no_cond_log = no_cond_log
81
+
82
+ if ckpt_path is not None:
83
+ self.init_from_ckpt(ckpt_path)
84
+
85
+ def init_from_ckpt(
86
+ self,
87
+ path: str,
88
+ ) -> None:
89
+ if path.endswith("ckpt"):
90
+ sd = torch.load(path, map_location="cpu")["state_dict"]
91
+ elif path.endswith("safetensors"):
92
+ sd = load_safetensors(path)
93
+ else:
94
+ raise NotImplementedError
95
+
96
+ missing, unexpected = self.load_state_dict(sd, strict=False)
97
+ print(
98
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
99
+ )
100
+ if len(missing) > 0:
101
+ print(f"Missing Keys: {missing}")
102
+ if len(unexpected) > 0:
103
+ print(f"Unexpected Keys: {unexpected}")
104
+
105
+ def _init_first_stage(self, config):
106
+ model = instantiate_from_config(config).eval()
107
+ model.train = disabled_train
108
+ for param in model.parameters():
109
+ param.requires_grad = False
110
+ self.first_stage_model = model
111
+
112
+ def get_input(self, batch):
113
+ # assuming unified data format, dataloader returns a dict.
114
+ # image tensors should be scaled to -1 ... 1 and in bchw format
115
+ return batch[self.input_key]
116
+
117
+ @torch.no_grad()
118
+ def decode_first_stage(self, z):
119
+ z = 1.0 / self.scale_factor * z
120
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
121
+ out = self.first_stage_model.decode(z)
122
+ return out
123
+
124
+ @torch.no_grad()
125
+ def encode_first_stage(self, x):
126
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
127
+ z = self.first_stage_model.encode(x)
128
+ z = self.scale_factor * z
129
+ return z
130
+
131
+ def forward(self, x, batch):
132
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
133
+ loss_mean = loss.mean()
134
+ loss_dict = {"loss": loss_mean}
135
+ return loss_mean, loss_dict
136
+
137
+ def shared_step(self, batch: Dict) -> Any:
138
+ x = self.get_input(batch)
139
+ x = self.encode_first_stage(x)
140
+ batch["global_step"] = self.global_step
141
+ loss, loss_dict = self(x, batch)
142
+ return loss, loss_dict
143
+
144
+ def training_step(self, batch, batch_idx):
145
+ loss, loss_dict = self.shared_step(batch)
146
+
147
+ self.log_dict(
148
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
149
+ )
150
+
151
+ self.log(
152
+ "global_step",
153
+ self.global_step,
154
+ prog_bar=True,
155
+ logger=True,
156
+ on_step=True,
157
+ on_epoch=False,
158
+ )
159
+
160
+ if self.scheduler_config is not None:
161
+ lr = self.optimizers().param_groups[0]["lr"]
162
+ self.log(
163
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
164
+ )
165
+
166
+ return loss
167
+
168
+ def on_train_start(self, *args, **kwargs):
169
+ if self.sampler is None or self.loss_fn is None:
170
+ raise ValueError("Sampler and loss function need to be set for training.")
171
+
172
+ def on_train_batch_end(self, *args, **kwargs):
173
+ if self.use_ema:
174
+ self.model_ema(self.model)
175
+
176
+ @contextmanager
177
+ def ema_scope(self, context=None):
178
+ if self.use_ema:
179
+ self.model_ema.store(self.model.parameters())
180
+ self.model_ema.copy_to(self.model)
181
+ if context is not None:
182
+ print(f"{context}: Switched to EMA weights")
183
+ try:
184
+ yield None
185
+ finally:
186
+ if self.use_ema:
187
+ self.model_ema.restore(self.model.parameters())
188
+ if context is not None:
189
+ print(f"{context}: Restored training weights")
190
+
191
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
192
+ return get_obj_from_str(cfg["target"])(
193
+ params, lr=lr, **cfg.get("params", dict())
194
+ )
195
+
196
+ def configure_optimizers(self):
197
+ lr = self.learning_rate
198
+ params = list(self.model.parameters())
199
+ for embedder in self.conditioner.embedders:
200
+ if embedder.is_trainable:
201
+ params = params + list(embedder.parameters())
202
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
203
+ if self.scheduler_config is not None:
204
+ scheduler = instantiate_from_config(self.scheduler_config)
205
+ print("Setting up LambdaLR scheduler...")
206
+ scheduler = [
207
+ {
208
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
209
+ "interval": "step",
210
+ "frequency": 1,
211
+ }
212
+ ]
213
+ return [opt], scheduler
214
+ return opt
215
+
216
+ @torch.no_grad()
217
+ def sample(
218
+ self,
219
+ cond: Dict,
220
+ uc: Union[Dict, None] = None,
221
+ batch_size: int = 16,
222
+ shape: Union[None, Tuple, List] = None,
223
+ **kwargs,
224
+ ):
225
+ randn = torch.randn(batch_size, *shape).to(self.device)
226
+
227
+ denoiser = lambda input, sigma, c: self.denoiser(
228
+ self.model, input, sigma, c, **kwargs
229
+ )
230
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
231
+ return samples
232
+
233
+ @torch.no_grad()
234
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
235
+ """
236
+ Defines heuristics to log different conditionings.
237
+ These can be lists of strings (text-to-image), tensors, ints, ...
238
+ """
239
+ image_h, image_w = batch[self.input_key].shape[2:]
240
+ log = dict()
241
+
242
+ for embedder in self.conditioner.embedders:
243
+ if (
244
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
245
+ ) and not self.no_cond_log:
246
+ x = batch[embedder.input_key][:n]
247
+ if isinstance(x, torch.Tensor):
248
+ if x.dim() == 1:
249
+ # class-conditional, convert integer to string
250
+ x = [str(x[i].item()) for i in range(x.shape[0])]
251
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
252
+ elif x.dim() == 2:
253
+ # size and crop cond and the like
254
+ x = [
255
+ "x".join([str(xx) for xx in x[i].tolist()])
256
+ for i in range(x.shape[0])
257
+ ]
258
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
259
+ else:
260
+ raise NotImplementedError()
261
+ elif isinstance(x, (List, ListConfig)):
262
+ if isinstance(x[0], str):
263
+ # strings
264
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
265
+ else:
266
+ raise NotImplementedError()
267
+ else:
268
+ raise NotImplementedError()
269
+ log[embedder.input_key] = xc
270
+ return log
271
+
272
+ @torch.no_grad()
273
+ def log_images(
274
+ self,
275
+ batch: Dict,
276
+ N: int = 8,
277
+ sample: bool = True,
278
+ ucg_keys: List[str] = None,
279
+ **kwargs,
280
+ ) -> Dict:
281
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
282
+ if ucg_keys:
283
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
284
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
285
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
286
+ )
287
+ else:
288
+ ucg_keys = conditioner_input_keys
289
+ log = dict()
290
+
291
+ x = self.get_input(batch)
292
+
293
+ c, uc = self.conditioner.get_unconditional_conditioning(
294
+ batch,
295
+ force_uc_zero_embeddings=ucg_keys
296
+ if len(self.conditioner.embedders) > 0
297
+ else [],
298
+ )
299
+
300
+ sampling_kwargs = {}
301
+
302
+ N = min(x.shape[0], N)
303
+ x = x.to(self.device)[:N]
304
+ log["inputs"] = x
305
+ z = self.encode_first_stage(x)
306
+ log["reconstructions"] = self.decode_first_stage(z)
307
+ log.update(self.log_conditionings(batch, N))
308
+
309
+ for k in c:
310
+ if isinstance(c[k], torch.Tensor):
311
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
312
+
313
+ if sample:
314
+ with self.ema_scope("Plotting"):
315
+ samples = self.sample(
316
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
317
+ )
318
+ samples = self.decode_first_stage(samples)
319
+ log["samples"] = samples
320
+ return log
repositories/generative-models/sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .encoders.modules import GeneralConditioner
2
+
3
+ UNCONDITIONAL_CONFIG = {
4
+ "target": "sgm.modules.GeneralConditioner",
5
+ "params": {"emb_models": []},
6
+ }
repositories/generative-models/sgm/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (334 Bytes).
 
repositories/generative-models/sgm/modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (21.6 kB).
 
repositories/generative-models/sgm/modules/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.23 kB).
 
repositories/generative-models/sgm/modules/attention.py ADDED
@@ -0,0 +1,947 @@
1
+ import math
2
+ from inspect import isfunction
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, repeat
8
+ from packaging import version
9
+ from torch import nn
10
+
11
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
12
+ SDP_IS_AVAILABLE = True
13
+ from torch.backends.cuda import SDPBackend, sdp_kernel
14
+
15
+ BACKEND_MAP = {
16
+ SDPBackend.MATH: {
17
+ "enable_math": True,
18
+ "enable_flash": False,
19
+ "enable_mem_efficient": False,
20
+ },
21
+ SDPBackend.FLASH_ATTENTION: {
22
+ "enable_math": False,
23
+ "enable_flash": True,
24
+ "enable_mem_efficient": False,
25
+ },
26
+ SDPBackend.EFFICIENT_ATTENTION: {
27
+ "enable_math": False,
28
+ "enable_flash": False,
29
+ "enable_mem_efficient": True,
30
+ },
31
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
32
+ }
33
+ else:
34
+ from contextlib import nullcontext
35
+
36
+ SDP_IS_AVAILABLE = False
37
+ sdp_kernel = nullcontext
38
+ BACKEND_MAP = {}
39
+ print(
40
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
41
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
42
+ )
43
+
44
+ try:
45
+ import xformers
46
+ import xformers.ops
47
+
48
+ XFORMERS_IS_AVAILABLE = True
49
+ except:
50
+ XFORMERS_IS_AVAILABLE = False
51
+ print("no module 'xformers'. Processing without...")
52
+
53
+ from .diffusionmodules.util import checkpoint
54
+
55
+
56
+ def exists(val):
57
+ return val is not None
58
+
59
+
60
+ def uniq(arr):
61
+ return {el: True for el in arr}.keys()
62
+
63
+
64
+ def default(val, d):
65
+ if exists(val):
66
+ return val
67
+ return d() if isfunction(d) else d
68
+
69
+
70
+ def max_neg_value(t):
71
+ return -torch.finfo(t.dtype).max
72
+
73
+
74
+ def init_(tensor):
75
+ dim = tensor.shape[-1]
76
+ std = 1 / math.sqrt(dim)
77
+ tensor.uniform_(-std, std)
78
+ return tensor
79
+
80
+
81
+ # feedforward
82
+ class GEGLU(nn.Module):
83
+ def __init__(self, dim_in, dim_out):
84
+ super().__init__()
85
+ self.proj = nn.Linear(dim_in, dim_out * 2)
86
+
87
+ def forward(self, x):
88
+ x, gate = self.proj(x).chunk(2, dim=-1)
89
+ return x * F.gelu(gate)
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
94
+ super().__init__()
95
+ inner_dim = int(dim * mult)
96
+ dim_out = default(dim_out, dim)
97
+ project_in = (
98
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
99
+ if not glu
100
+ else GEGLU(dim, inner_dim)
101
+ )
102
+
103
+ self.net = nn.Sequential(
104
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
105
+ )
106
+
107
+ def forward(self, x):
108
+ return self.net(x)
109
+
110
+
111
+ def zero_module(module):
112
+ """
113
+ Zero out the parameters of a module and return it.
114
+ """
115
+ for p in module.parameters():
116
+ p.detach().zero_()
117
+ return module
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class LinearAttention(nn.Module):
127
+ def __init__(self, dim, heads=4, dim_head=32):
128
+ super().__init__()
129
+ self.heads = heads
130
+ hidden_dim = dim_head * heads
131
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
+
134
+ def forward(self, x):
135
+ b, c, h, w = x.shape
136
+ qkv = self.to_qkv(x)
137
+ q, k, v = rearrange(
138
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
+ )
140
+ k = k.softmax(dim=-1)
141
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
142
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
143
+ out = rearrange(
144
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
+ )
146
+ return self.to_out(out)
147
+
148
+
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ class CrossAttention(nn.Module):
195
+ def __init__(
196
+ self,
197
+ query_dim,
198
+ context_dim=None,
199
+ heads=8,
200
+ dim_head=64,
201
+ dropout=0.0,
202
+ backend=None,
203
+ ):
204
+ super().__init__()
205
+ inner_dim = dim_head * heads
206
+ context_dim = default(context_dim, query_dim)
207
+
208
+ self.scale = dim_head**-0.5
209
+ self.heads = heads
210
+
211
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
212
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
213
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
214
+
215
+ self.to_out = nn.Sequential(
216
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
217
+ )
218
+ self.backend = backend
219
+
220
+ def forward(
221
+ self,
222
+ x,
223
+ context=None,
224
+ mask=None,
225
+ additional_tokens=None,
226
+ n_times_crossframe_attn_in_self=0,
227
+ ):
228
+ h = self.heads
229
+
230
+ if additional_tokens is not None:
231
+ # get the number of masked tokens at the beginning of the output sequence
232
+ n_tokens_to_mask = additional_tokens.shape[1]
233
+ # add additional token
234
+ x = torch.cat([additional_tokens, x], dim=1)
235
+
236
+ q = self.to_q(x)
237
+ context = default(context, x)
238
+ k = self.to_k(context)
239
+ v = self.to_v(context)
240
+
241
+ if n_times_crossframe_attn_in_self:
242
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
243
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
244
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
245
+ k = repeat(
246
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
247
+ )
248
+ v = repeat(
249
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
250
+ )
251
+
252
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
253
+
254
+ ## old
255
+ """
256
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
257
+ del q, k
258
+
259
+ if exists(mask):
260
+ mask = rearrange(mask, 'b ... -> b (...)')
261
+ max_neg_value = -torch.finfo(sim.dtype).max
262
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
263
+ sim.masked_fill_(~mask, max_neg_value)
264
+
265
+ # attention, what we cannot get enough of
266
+ sim = sim.softmax(dim=-1)
267
+
268
+ out = einsum('b i j, b j d -> b i d', sim, v)
269
+ """
270
+ ## new
271
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
272
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
273
+ out = F.scaled_dot_product_attention(
274
+ q, k, v, attn_mask=mask
275
+ ) # scale is dim_head ** -0.5 per default
276
+
277
+ del q, k, v
278
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
279
+
280
+ if additional_tokens is not None:
281
+ # remove additional token
282
+ out = out[:, n_tokens_to_mask:]
283
+ return self.to_out(out)
284
+
285
+
286
+ class MemoryEfficientCrossAttention(nn.Module):
287
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
288
+ def __init__(
289
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
290
+ ):
291
+ super().__init__()
292
+ print(
293
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
294
+ f"{heads} heads with a dimension of {dim_head}."
295
+ )
296
+ inner_dim = dim_head * heads
297
+ context_dim = default(context_dim, query_dim)
298
+
299
+ self.heads = heads
300
+ self.dim_head = dim_head
301
+
302
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
303
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
304
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
305
+
306
+ self.to_out = nn.Sequential(
307
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
308
+ )
309
+ self.attention_op: Optional[Any] = None
310
+
311
+ def forward(
312
+ self,
313
+ x,
314
+ context=None,
315
+ mask=None,
316
+ additional_tokens=None,
317
+ n_times_crossframe_attn_in_self=0,
318
+ ):
319
+ if additional_tokens is not None:
320
+ # get the number of masked tokens at the beginning of the output sequence
321
+ n_tokens_to_mask = additional_tokens.shape[1]
322
+ # add additional token
323
+ x = torch.cat([additional_tokens, x], dim=1)
324
+ q = self.to_q(x)
325
+ context = default(context, x)
326
+ k = self.to_k(context)
327
+ v = self.to_v(context)
328
+
329
+ if n_times_crossframe_attn_in_self:
330
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
331
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
332
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
333
+ k = repeat(
334
+ k[::n_times_crossframe_attn_in_self],
335
+ "b ... -> (b n) ...",
336
+ n=n_times_crossframe_attn_in_self,
337
+ )
338
+ v = repeat(
339
+ v[::n_times_crossframe_attn_in_self],
340
+ "b ... -> (b n) ...",
341
+ n=n_times_crossframe_attn_in_self,
342
+ )
343
+
344
+ b, _, _ = q.shape
345
+ q, k, v = map(
346
+ lambda t: t.unsqueeze(3)
347
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
348
+ .permute(0, 2, 1, 3)
349
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
350
+ .contiguous(),
351
+ (q, k, v),
352
+ )
353
+
354
+ # actually compute the attention, what we cannot get enough of
355
+ out = xformers.ops.memory_efficient_attention(
356
+ q, k, v, attn_bias=None, op=self.attention_op
357
+ )
358
+
359
+ # TODO: Use this directly in the attention operation, as a bias
360
+ if exists(mask):
361
+ raise NotImplementedError
362
+ out = (
363
+ out.unsqueeze(0)
364
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
365
+ .permute(0, 2, 1, 3)
366
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
367
+ )
368
+ if additional_tokens is not None:
369
+ # remove additional token
370
+ out = out[:, n_tokens_to_mask:]
371
+ return self.to_out(out)
372
+
373
+
374
+ class BasicTransformerBlock(nn.Module):
375
+ ATTENTION_MODES = {
376
+ "softmax": CrossAttention, # vanilla attention
377
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
378
+ }
379
+
380
+ def __init__(
381
+ self,
382
+ dim,
383
+ n_heads,
384
+ d_head,
385
+ dropout=0.0,
386
+ context_dim=None,
387
+ gated_ff=True,
388
+ checkpoint=True,
389
+ disable_self_attn=False,
390
+ attn_mode="softmax",
391
+ sdp_backend=None,
392
+ ):
393
+ super().__init__()
394
+ assert attn_mode in self.ATTENTION_MODES
395
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
396
+ print(
397
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
398
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
399
+ )
400
+ attn_mode = "softmax"
401
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
402
+ print(
403
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
404
+ )
405
+ if not XFORMERS_IS_AVAILABLE:
406
+ assert (
407
+ False
408
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
409
+ else:
410
+ print("Falling back to xformers efficient attention.")
411
+ attn_mode = "softmax-xformers"
412
+ attn_cls = self.ATTENTION_MODES[attn_mode]
413
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
414
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
415
+ else:
416
+ assert sdp_backend is None
417
+ self.disable_self_attn = disable_self_attn
418
+ self.attn1 = attn_cls(
419
+ query_dim=dim,
420
+ heads=n_heads,
421
+ dim_head=d_head,
422
+ dropout=dropout,
423
+ context_dim=context_dim if self.disable_self_attn else None,
424
+ backend=sdp_backend,
425
+ ) # is a self-attention if not self.disable_self_attn
426
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
427
+ self.attn2 = attn_cls(
428
+ query_dim=dim,
429
+ context_dim=context_dim,
430
+ heads=n_heads,
431
+ dim_head=d_head,
432
+ dropout=dropout,
433
+ backend=sdp_backend,
434
+ ) # is self-attn if context is none
435
+ self.norm1 = nn.LayerNorm(dim)
436
+ self.norm2 = nn.LayerNorm(dim)
437
+ self.norm3 = nn.LayerNorm(dim)
438
+ self.checkpoint = checkpoint
439
+ if self.checkpoint:
440
+ print(f"{self.__class__.__name__} is using checkpointing")
441
+
442
+ def forward(
443
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
444
+ ):
445
+ kwargs = {"x": x}
446
+
447
+ if context is not None:
448
+ kwargs.update({"context": context})
449
+
450
+ if additional_tokens is not None:
451
+ kwargs.update({"additional_tokens": additional_tokens})
452
+
453
+ if n_times_crossframe_attn_in_self:
454
+ kwargs.update(
455
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
456
+ )
457
+
458
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
459
+ return checkpoint(
460
+ self._forward, (x, context), self.parameters(), self.checkpoint
461
+ )
462
+
463
+ def _forward(
464
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
465
+ ):
466
+ x = (
467
+ self.attn1(
468
+ self.norm1(x),
469
+ context=context if self.disable_self_attn else None,
470
+ additional_tokens=additional_tokens,
471
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
472
+ if not self.disable_self_attn
473
+ else 0,
474
+ )
475
+ + x
476
+ )
477
+ x = (
478
+ self.attn2(
479
+ self.norm2(x), context=context, additional_tokens=additional_tokens
480
+ )
481
+ + x
482
+ )
483
+ x = self.ff(self.norm3(x)) + x
484
+ return x
485
+
486
+
487
+ class BasicTransformerSingleLayerBlock(nn.Module):
488
+ ATTENTION_MODES = {
489
+ "softmax": CrossAttention, # vanilla attention
490
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
491
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
492
+ }
493
+
494
+ def __init__(
495
+ self,
496
+ dim,
497
+ n_heads,
498
+ d_head,
499
+ dropout=0.0,
500
+ context_dim=None,
501
+ gated_ff=True,
502
+ checkpoint=True,
503
+ attn_mode="softmax",
504
+ ):
505
+ super().__init__()
506
+ assert attn_mode in self.ATTENTION_MODES
507
+ attn_cls = self.ATTENTION_MODES[attn_mode]
508
+ self.attn1 = attn_cls(
509
+ query_dim=dim,
510
+ heads=n_heads,
511
+ dim_head=d_head,
512
+ dropout=dropout,
513
+ context_dim=context_dim,
514
+ )
515
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
516
+ self.norm1 = nn.LayerNorm(dim)
517
+ self.norm2 = nn.LayerNorm(dim)
518
+ self.checkpoint = checkpoint
519
+
520
+ def forward(self, x, context=None):
521
+ return checkpoint(
522
+ self._forward, (x, context), self.parameters(), self.checkpoint
523
+ )
524
+
525
+ def _forward(self, x, context=None):
526
+ x = self.attn1(self.norm1(x), context=context) + x
527
+ x = self.ff(self.norm2(x)) + x
528
+ return x
529
+
530
+
531
+ class SpatialTransformer(nn.Module):
532
+ """
533
+ Transformer block for image-like data.
534
+ First, project the input (aka embedding)
535
+ and reshape to b, t, d.
536
+ Then apply standard transformer action.
537
+ Finally, reshape to image
538
+ NEW: use_linear for more efficiency instead of the 1x1 convs
539
+ """
540
+
541
+ def __init__(
542
+ self,
543
+ in_channels,
544
+ n_heads,
545
+ d_head,
546
+ depth=1,
547
+ dropout=0.0,
548
+ context_dim=None,
549
+ disable_self_attn=False,
550
+ use_linear=False,
551
+ attn_type="softmax",
552
+ use_checkpoint=True,
553
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
554
+ sdp_backend=None,
555
+ ):
556
+ super().__init__()
557
+ print(
558
+ f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
559
+ )
560
+ from omegaconf import ListConfig
561
+
562
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
563
+ context_dim = [context_dim]
564
+ if exists(context_dim) and isinstance(context_dim, list):
565
+ if depth != len(context_dim):
566
+ print(
567
+ f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
568
+ f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
569
+ )
570
+ # depth does not match context dims.
571
+ assert all(
572
+ map(lambda x: x == context_dim[0], context_dim)
573
+ ), "need homogenous context_dim to match depth automatically"
574
+ context_dim = depth * [context_dim[0]]
575
+ elif context_dim is None:
576
+ context_dim = [None] * depth
577
+ self.in_channels = in_channels
578
+ inner_dim = n_heads * d_head
579
+ self.norm = Normalize(in_channels)
580
+ if not use_linear:
581
+ self.proj_in = nn.Conv2d(
582
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
583
+ )
584
+ else:
585
+ self.proj_in = nn.Linear(in_channels, inner_dim)
586
+
587
+ self.transformer_blocks = nn.ModuleList(
588
+ [
589
+ BasicTransformerBlock(
590
+ inner_dim,
591
+ n_heads,
592
+ d_head,
593
+ dropout=dropout,
594
+ context_dim=context_dim[d],
595
+ disable_self_attn=disable_self_attn,
596
+ attn_mode=attn_type,
597
+ checkpoint=use_checkpoint,
598
+ sdp_backend=sdp_backend,
599
+ )
600
+ for d in range(depth)
601
+ ]
602
+ )
603
+ if not use_linear:
604
+ self.proj_out = zero_module(
605
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
606
+ )
607
+ else:
608
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
609
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
610
+ self.use_linear = use_linear
611
+
612
+ def forward(self, x, context=None):
613
+ # note: if no context is given, cross-attention defaults to self-attention
614
+ if not isinstance(context, list):
615
+ context = [context]
616
+ b, c, h, w = x.shape
617
+ x_in = x
618
+ x = self.norm(x)
619
+ if not self.use_linear:
620
+ x = self.proj_in(x)
621
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
622
+ if self.use_linear:
623
+ x = self.proj_in(x)
624
+ for i, block in enumerate(self.transformer_blocks):
625
+ if i > 0 and len(context) == 1:
626
+ i = 0 # use same context for each block
627
+ x = block(x, context=context[i])
628
+ if self.use_linear:
629
+ x = self.proj_out(x)
630
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
631
+ if not self.use_linear:
632
+ x = self.proj_out(x)
633
+ return x + x_in
634
+
635
+
636
+ def benchmark_attn():
637
+ # Lets define a helpful benchmarking function:
638
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
639
+ device = "cuda" if torch.cuda.is_available() else "cpu"
640
+ import torch.nn.functional as F
641
+ import torch.utils.benchmark as benchmark
642
+
643
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
644
+ t0 = benchmark.Timer(
645
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
646
+ )
647
+ return t0.blocked_autorange().mean * 1e6
648
+
649
+ # Lets define the hyper-parameters of our input
650
+ batch_size = 32
651
+ max_sequence_len = 1024
652
+ num_heads = 32
653
+ embed_dimension = 32
654
+
655
+ dtype = torch.float16
656
+
657
+ query = torch.rand(
658
+ batch_size,
659
+ num_heads,
660
+ max_sequence_len,
661
+ embed_dimension,
662
+ device=device,
663
+ dtype=dtype,
664
+ )
665
+ key = torch.rand(
666
+ batch_size,
667
+ num_heads,
668
+ max_sequence_len,
669
+ embed_dimension,
670
+ device=device,
671
+ dtype=dtype,
672
+ )
673
+ value = torch.rand(
674
+ batch_size,
675
+ num_heads,
676
+ max_sequence_len,
677
+ embed_dimension,
678
+ device=device,
679
+ dtype=dtype,
680
+ )
681
+
682
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
683
+
684
+ # Lets explore the speed of each of the 3 implementations
685
+ from torch.backends.cuda import SDPBackend, sdp_kernel
686
+
687
+ # Helpful arguments mapper
688
+ backend_map = {
689
+ SDPBackend.MATH: {
690
+ "enable_math": True,
691
+ "enable_flash": False,
692
+ "enable_mem_efficient": False,
693
+ },
694
+ SDPBackend.FLASH_ATTENTION: {
695
+ "enable_math": False,
696
+ "enable_flash": True,
697
+ "enable_mem_efficient": False,
698
+ },
699
+ SDPBackend.EFFICIENT_ATTENTION: {
700
+ "enable_math": False,
701
+ "enable_flash": False,
702
+ "enable_mem_efficient": True,
703
+ },
704
+ }
705
+
706
+ from torch.profiler import ProfilerActivity, profile, record_function
707
+
708
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
709
+
710
+ print(
711
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
712
+ )
713
+ with profile(
714
+ activities=activities, record_shapes=False, profile_memory=True
715
+ ) as prof:
716
+ with record_function("Default detailed stats"):
717
+ for _ in range(25):
718
+ o = F.scaled_dot_product_attention(query, key, value)
719
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
720
+
721
+ print(
722
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
723
+ )
724
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
725
+ with profile(
726
+ activities=activities, record_shapes=False, profile_memory=True
727
+ ) as prof:
728
+ with record_function("Math implmentation stats"):
729
+ for _ in range(25):
730
+ o = F.scaled_dot_product_attention(query, key, value)
731
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
732
+
733
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
734
+ try:
735
+ print(
736
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
737
+ )
738
+ except RuntimeError:
739
+ print("FlashAttention is not supported. See warnings for reasons.")
740
+ with profile(
741
+ activities=activities, record_shapes=False, profile_memory=True
742
+ ) as prof:
743
+ with record_function("FlashAttention stats"):
744
+ for _ in range(25):
745
+ o = F.scaled_dot_product_attention(query, key, value)
746
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
747
+
748
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
749
+ try:
750
+ print(
751
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
752
+ )
753
+ except RuntimeError:
754
+ print("EfficientAttention is not supported. See warnings for reasons.")
755
+ with profile(
756
+ activities=activities, record_shapes=False, profile_memory=True
757
+ ) as prof:
758
+ with record_function("EfficientAttention stats"):
759
+ for _ in range(25):
760
+ o = F.scaled_dot_product_attention(query, key, value)
761
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
762
+
763
+
764
+ def run_model(model, x, context):
765
+ return model(x, context)
766
+
767
+
768
+ def benchmark_transformer_blocks():
769
+ device = "cuda" if torch.cuda.is_available() else "cpu"
770
+ import torch.utils.benchmark as benchmark
771
+
772
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
773
+ t0 = benchmark.Timer(
774
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
775
+ )
776
+ return t0.blocked_autorange().mean * 1e6
777
+
778
+ checkpoint = True
779
+ compile = False
780
+
781
+ batch_size = 32
782
+ h, w = 64, 64
783
+ context_len = 77
784
+ embed_dimension = 1024
785
+ context_dim = 1024
786
+ d_head = 64
787
+
788
+ transformer_depth = 4
789
+
790
+ n_heads = embed_dimension // d_head
791
+
792
+ dtype = torch.float16
793
+
794
+ model_native = SpatialTransformer(
795
+ embed_dimension,
796
+ n_heads,
797
+ d_head,
798
+ context_dim=context_dim,
799
+ use_linear=True,
800
+ use_checkpoint=checkpoint,
801
+ attn_type="softmax",
802
+ depth=transformer_depth,
803
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
804
+ ).to(device)
805
+ model_efficient_attn = SpatialTransformer(
806
+ embed_dimension,
807
+ n_heads,
808
+ d_head,
809
+ context_dim=context_dim,
810
+ use_linear=True,
811
+ depth=transformer_depth,
812
+ use_checkpoint=checkpoint,
813
+ attn_type="softmax-xformers",
814
+ ).to(device)
815
+ if not checkpoint and compile:
816
+ print("compiling models")
817
+ model_native = torch.compile(model_native)
818
+ model_efficient_attn = torch.compile(model_efficient_attn)
819
+
820
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
821
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
822
+
823
+ from torch.profiler import ProfilerActivity, profile, record_function
824
+
825
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
826
+
827
+ with torch.autocast("cuda"):
828
+ print(
829
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
830
+ )
831
+ print(
832
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
833
+ )
834
+
835
+ print(75 * "+")
836
+ print("NATIVE")
837
+ print(75 * "+")
838
+ torch.cuda.reset_peak_memory_stats()
839
+ with profile(
840
+ activities=activities, record_shapes=False, profile_memory=True
841
+ ) as prof:
842
+ with record_function("NativeAttention stats"):
843
+ for _ in range(25):
844
+ model_native(x, c)
845
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
846
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
847
+
848
+ print(75 * "+")
849
+ print("Xformers")
850
+ print(75 * "+")
851
+ torch.cuda.reset_peak_memory_stats()
852
+ with profile(
853
+ activities=activities, record_shapes=False, profile_memory=True
854
+ ) as prof:
855
+ with record_function("xformers stats"):
856
+ for _ in range(25):
857
+ model_efficient_attn(x, c)
858
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
859
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
860
+
861
+
862
+ def test01():
863
+ # conv1x1 vs linear
864
+ from ..util import count_params
865
+
866
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
867
+ print(count_params(conv))
868
+ linear = torch.nn.Linear(3, 32).cuda()
869
+ print(count_params(linear))
870
+
871
+ print(conv.weight.shape)
872
+
873
+ # use same initialization
874
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
875
+ linear.bias = torch.nn.Parameter(conv.bias)
876
+
877
+ print(linear.weight.shape)
878
+
879
+ x = torch.randn(11, 3, 64, 64).cuda()
880
+
881
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
882
+ print(xr.shape)
883
+ out_linear = linear(xr)
884
+ print(out_linear.mean(), out_linear.shape)
885
+
886
+ out_conv = conv(x)
887
+ print(out_conv.mean(), out_conv.shape)
888
+ print("done with test01.\n")
889
+
890
+
891
+ def test02():
892
+ # try cosine flash attention
893
+ import time
894
+
895
+ torch.backends.cuda.matmul.allow_tf32 = True
896
+ torch.backends.cudnn.allow_tf32 = True
897
+ torch.backends.cudnn.benchmark = True
898
+ print("testing cosine flash attention...")
899
+ DIM = 1024
900
+ SEQLEN = 4096
901
+ BS = 16
902
+
903
+ print(" softmax (vanilla) first...")
904
+ model = BasicTransformerBlock(
905
+ dim=DIM,
906
+ n_heads=16,
907
+ d_head=64,
908
+ dropout=0.0,
909
+ context_dim=None,
910
+ attn_mode="softmax",
911
+ ).cuda()
912
+ try:
913
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
914
+ tic = time.time()
915
+ y = model(x)
916
+ toc = time.time()
917
+ print(y.shape, toc - tic)
918
+ except RuntimeError as e:
919
+ # likely oom
920
+ print(str(e))
921
+
922
+ print("\n now flash-cosine...")
923
+ model = BasicTransformerBlock(
924
+ dim=DIM,
925
+ n_heads=16,
926
+ d_head=64,
927
+ dropout=0.0,
928
+ context_dim=None,
929
+ attn_mode="flash-cosine",
930
+ ).cuda()
931
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
932
+ tic = time.time()
933
+ y = model(x)
934
+ toc = time.time()
935
+ print(y.shape, toc - tic)
936
+ print("done with test02.\n")
937
+
938
+
939
+ if __name__ == "__main__":
940
+ # test01()
941
+ # test02()
942
+ # test03()
943
+
944
+ # benchmark_attn()
945
+ benchmark_transformer_blocks()
946
+
947
+ print("done.")
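The `__main__` block above runs only `benchmark_transformer_blocks()`; the other entry points stay commented out. A minimal hedged driver for the same helpers, assuming the `sgm` package from this repository and its dependencies are importable and a CUDA device is available (the import path and the CUDA guard are assumptions, not part of this diff):

    # Hypothetical driver for the benchmark helpers defined above.
    import torch
    from sgm.modules.attention import benchmark_attn, benchmark_transformer_blocks

    if torch.cuda.is_available():
        benchmark_attn()                 # math vs. flash vs. memory-efficient SDPA kernels
        benchmark_transformer_blocks()   # softmax vs. softmax-xformers SpatialTransformer
    else:
        print("CUDA not available; skipping attention benchmarks.")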
repositories/generative-models/sgm/modules/autoencoding/__init__.py ADDED
File without changes
repositories/generative-models/sgm/modules/autoencoding/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (184 Bytes).
 
repositories/generative-models/sgm/modules/autoencoding/losses/__init__.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from taming.modules.losses.lpips import LPIPS
8
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+ from ....util import default, instantiate_from_config
11
+
12
+
13
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
14
+ if global_step < threshold:
15
+ weight = value
16
+ return weight
17
+
18
+
19
+ class LatentLPIPS(nn.Module):
20
+ def __init__(
21
+ self,
22
+ decoder_config,
23
+ perceptual_weight=1.0,
24
+ latent_weight=1.0,
25
+ scale_input_to_tgt_size=False,
26
+ scale_tgt_to_input_size=False,
27
+ perceptual_weight_on_inputs=0.0,
28
+ ):
29
+ super().__init__()
30
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
31
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
32
+ self.init_decoder(decoder_config)
33
+ self.perceptual_loss = LPIPS().eval()
34
+ self.perceptual_weight = perceptual_weight
35
+ self.latent_weight = latent_weight
36
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
37
+
38
+ def init_decoder(self, config):
39
+ self.decoder = instantiate_from_config(config)
40
+ if hasattr(self.decoder, "encoder"):
41
+ del self.decoder.encoder
42
+
43
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
44
+ log = dict()
45
+ loss = (latent_inputs - latent_predictions) ** 2
46
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
47
+ image_reconstructions = None
48
+ if self.perceptual_weight > 0.0:
49
+ image_reconstructions = self.decoder.decode(latent_predictions)
50
+ image_targets = self.decoder.decode(latent_inputs)
51
+ perceptual_loss = self.perceptual_loss(
52
+ image_targets.contiguous(), image_reconstructions.contiguous()
53
+ )
54
+ loss = (
55
+ self.latent_weight * loss.mean()
56
+ + self.perceptual_weight * perceptual_loss.mean()
57
+ )
58
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
59
+
60
+ if self.perceptual_weight_on_inputs > 0.0:
61
+ image_reconstructions = default(
62
+ image_reconstructions, self.decoder.decode(latent_predictions)
63
+ )
64
+ if self.scale_input_to_tgt_size:
65
+ image_inputs = torch.nn.functional.interpolate(
66
+ image_inputs,
67
+ image_reconstructions.shape[2:],
68
+ mode="bicubic",
69
+ antialias=True,
70
+ )
71
+ elif self.scale_tgt_to_input_size:
72
+ image_reconstructions = torch.nn.functional.interpolate(
73
+ image_reconstructions,
74
+ image_inputs.shape[2:],
75
+ mode="bicubic",
76
+ antialias=True,
77
+ )
78
+
79
+ perceptual_loss2 = self.perceptual_loss(
80
+ image_inputs.contiguous(), image_reconstructions.contiguous()
81
+ )
82
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
83
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
84
+ return loss, log
85
+
86
+
87
+ class GeneralLPIPSWithDiscriminator(nn.Module):
88
+ def __init__(
89
+ self,
90
+ disc_start: int,
91
+ logvar_init: float = 0.0,
92
+ pixelloss_weight=1.0,
93
+ disc_num_layers: int = 3,
94
+ disc_in_channels: int = 3,
95
+ disc_factor: float = 1.0,
96
+ disc_weight: float = 1.0,
97
+ perceptual_weight: float = 1.0,
98
+ disc_loss: str = "hinge",
99
+ scale_input_to_tgt_size: bool = False,
100
+ dims: int = 2,
101
+ learn_logvar: bool = False,
102
+ regularization_weights: Union[None, dict] = None,
103
+ ):
104
+ super().__init__()
105
+ self.dims = dims
106
+ if self.dims > 2:
107
+ print(
108
+ f"running with dims={dims}. This means that for perceptual loss calculation, "
109
+ f"the LPIPS loss will be applied to each frame independently. "
110
+ )
111
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
112
+ assert disc_loss in ["hinge", "vanilla"]
113
+ self.pixel_weight = pixelloss_weight
114
+ self.perceptual_loss = LPIPS().eval()
115
+ self.perceptual_weight = perceptual_weight
116
+ # output log variance
117
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
118
+ self.learn_logvar = learn_logvar
119
+
120
+ self.discriminator = NLayerDiscriminator(
121
+ input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=False
122
+ ).apply(weights_init)
123
+ self.discriminator_iter_start = disc_start
124
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
125
+ self.disc_factor = disc_factor
126
+ self.discriminator_weight = disc_weight
127
+ self.regularization_weights = default(regularization_weights, {})
128
+
129
+ def get_trainable_parameters(self) -> Any:
130
+ return self.discriminator.parameters()
131
+
132
+ def get_trainable_autoencoder_parameters(self) -> Any:
133
+ if self.learn_logvar:
134
+ yield self.logvar
135
+ yield from ()
136
+
137
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
138
+ if last_layer is not None:
139
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
140
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
141
+ else:
142
+ nll_grads = torch.autograd.grad(
143
+ nll_loss, self.last_layer[0], retain_graph=True
144
+ )[0]
145
+ g_grads = torch.autograd.grad(
146
+ g_loss, self.last_layer[0], retain_graph=True
147
+ )[0]
148
+
149
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
150
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
151
+ d_weight = d_weight * self.discriminator_weight
152
+ return d_weight
153
+
154
+ def forward(
155
+ self,
156
+ regularization_log,
157
+ inputs,
158
+ reconstructions,
159
+ optimizer_idx,
160
+ global_step,
161
+ last_layer=None,
162
+ split="train",
163
+ weights=None,
164
+ ):
165
+ if self.scale_input_to_tgt_size:
166
+ inputs = torch.nn.functional.interpolate(
167
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
168
+ )
169
+
170
+ if self.dims > 2:
171
+ inputs, reconstructions = map(
172
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
173
+ (inputs, reconstructions),
174
+ )
175
+
176
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
177
+ if self.perceptual_weight > 0:
178
+ p_loss = self.perceptual_loss(
179
+ inputs.contiguous(), reconstructions.contiguous()
180
+ )
181
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
182
+
183
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
184
+ weighted_nll_loss = nll_loss
185
+ if weights is not None:
186
+ weighted_nll_loss = weights * nll_loss
187
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
188
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
189
+
190
+ # now the GAN part
191
+ if optimizer_idx == 0:
192
+ # generator update
193
+ logits_fake = self.discriminator(reconstructions.contiguous())
194
+ g_loss = -torch.mean(logits_fake)
195
+
196
+ if self.disc_factor > 0.0:
197
+ try:
198
+ d_weight = self.calculate_adaptive_weight(
199
+ nll_loss, g_loss, last_layer=last_layer
200
+ )
201
+ except RuntimeError:
202
+ assert not self.training
203
+ d_weight = torch.tensor(0.0)
204
+ else:
205
+ d_weight = torch.tensor(0.0)
206
+
207
+ disc_factor = adopt_weight(
208
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
209
+ )
210
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
211
+ log = dict()
212
+ for k in regularization_log:
213
+ if k in self.regularization_weights:
214
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
215
+ log[f"{split}/{k}"] = regularization_log[k].detach().mean()
216
+
217
+ log.update(
218
+ {
219
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
220
+ "{}/logvar".format(split): self.logvar.detach(),
221
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
222
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
223
+ "{}/d_weight".format(split): d_weight.detach(),
224
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
225
+ "{}/g_loss".format(split): g_loss.detach().mean(),
226
+ }
227
+ )
228
+
229
+ return loss, log
230
+
231
+ if optimizer_idx == 1:
232
+ # second pass for discriminator update
233
+ logits_real = self.discriminator(inputs.contiguous().detach())
234
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
235
+
236
+ disc_factor = adopt_weight(
237
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
238
+ )
239
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
240
+
241
+ log = {
242
+ "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
243
+ "{}/logits_real".format(split): logits_real.detach().mean(),
244
+ "{}/logits_fake".format(split): logits_fake.detach().mean(),
245
+ }
246
+ return d_loss, log
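`GeneralLPIPSWithDiscriminator.forward` is written for a two-optimizer setup: `optimizer_idx == 0` returns the autoencoder/generator loss (pixel + LPIPS NLL term plus the adaptively weighted GAN term and any regularization terms), while `optimizer_idx == 1` returns the discriminator loss on detached inputs and reconstructions. A hedged sketch of how one training step might drive both passes; it is not the repository's actual Lightning training loop, and `ae`, `encode`, `decode`, `get_last_layer`, `opt_ae`, and `opt_disc` are hypothetical stand-ins:

    # Sketch only: two-pass update with a GeneralLPIPSWithDiscriminator instance `loss_fn`.
    def training_step_sketch(loss_fn, ae, batch, global_step, opt_ae, opt_disc):
        z, reg_log = ae.encode(batch)          # reg_log e.g. {"kl_loss": ...}
        reconstructions = ae.decode(z)

        # optimizer_idx == 0: autoencoder / generator update
        ae_loss, log_ae = loss_fn(
            reg_log, batch, reconstructions,
            optimizer_idx=0, global_step=global_step,
            last_layer=ae.get_last_layer(), split="train",
        )
        opt_ae.zero_grad()
        ae_loss.backward()
        opt_ae.step()

        # optimizer_idx == 1: discriminator update (forward detaches its inputs,
        # so only discriminator weights receive gradients here)
        d_loss, log_disc = loss_fn(
            reg_log, batch, reconstructions.detach(),
            optimizer_idx=1, global_step=global_step, split="train",
        )
        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()
        return {**log_ae, **log_disc}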
repositories/generative-models/sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import DiagonalGaussianDistribution
9
+
10
+
11
+ class AbstractRegularizer(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
16
+ raise NotImplementedError()
17
+
18
+ @abstractmethod
19
+ def get_trainable_parameters(self) -> Any:
20
+ raise NotImplementedError()
21
+
22
+
23
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
24
+ def __init__(self, sample: bool = True):
25
+ super().__init__()
26
+ self.sample = sample
27
+
28
+ def get_trainable_parameters(self) -> Any:
29
+ yield from ()
30
+
31
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
32
+ log = dict()
33
+ posterior = DiagonalGaussianDistribution(z)
34
+ if self.sample:
35
+ z = posterior.sample()
36
+ else:
37
+ z = posterior.mode()
38
+ kl_loss = posterior.kl()
39
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
40
+ log["kl_loss"] = kl_loss
41
+ return z, log
42
+
43
+
44
+ def measure_perplexity(predicted_indices, num_centroids):
45
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
46
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
47
+ encodings = (
48
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
49
+ )
50
+ avg_probs = encodings.mean(0)
51
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
52
+ cluster_use = torch.sum(avg_probs > 0)
53
+ return perplexity, cluster_use
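`DiagonalGaussianRegularizer` expects the encoder output to carry mean and log-variance stacked along the channel axis (`DiagonalGaussianDistribution` chunks it in half), samples the latent (or takes its mode when `sample=False`), and logs the batch-averaged KL term. A minimal hedged example with toy shapes; the tensor sizes and the import path are illustrative assumptions:

    # Toy example: 8 input channels = 2 * 4 latent channels.
    import torch
    from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

    reg = DiagonalGaussianRegularizer(sample=True)
    moments = torch.randn(4, 8, 32, 32)    # [mean, logvar] concatenated along dim=1
    z, log = reg(moments)
    print(z.shape)         # torch.Size([4, 4, 32, 32])
    print(log["kl_loss"])  # KL summed over latent dims, averaged over the batch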
repositories/generative-models/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.48 kB).
 
repositories/generative-models/sgm/modules/diffusionmodules/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .denoiser import Denoiser
2
+ from .discretizer import Discretization
3
+ from .loss import StandardDiffusionLoss
4
+ from .model import Model, Encoder, Decoder
5
+ from .openaimodel import UNetModel
6
+ from .sampling import BaseDiffusionSampler
7
+ from .wrappers import OpenAIWrapper
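This `__init__.py` simply re-exports the main diffusion building blocks at the package level. A one-line hedged example of the resulting import surface, assuming `sgm` and its dependencies are installed:

    # Symbols become importable from the package root rather than the individual modules.
    from sgm.modules.diffusionmodules import Denoiser, Discretization, UNetModel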
repositories/generative-models/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (544 Bytes).