Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- Reward_sana_idealized/open_clip/__pycache__/__init__.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/constants.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/hf_configs.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/hf_model.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/loss.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/openai.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/transform.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/utils.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/__pycache__/version.cpython-311.pyc +0 -0
- Reward_sana_idealized/open_clip/model_configs/RN50.json +21 -0
- Reward_sana_idealized/open_clip/model_configs/RN50x64.json +21 -0
- Reward_sana_idealized/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
- Reward_sana_idealized/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
- Reward_sana_idealized/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json +15 -0
- Reward_sdxl_idealized/README.md +1336 -0
- Reward_sdxl_idealized/config_analysis_tuning.ipynb +218 -0
- Reward_sdxl_idealized/eval.py +1143 -0
- Reward_sdxl_idealized/examples.sh +19 -0
- Reward_sdxl_idealized/gradient_ascent_utils.py +339 -0
- Reward_sdxl_idealized/lr_scheduler.py +233 -0
- Reward_sdxl_idealized/models/__init__.py +3 -0
- Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-310.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-313.pyc +0 -0
- Reward_sdxl_idealized/models/__pycache__/unet_2d_condition_reward.cpython-313.pyc +0 -0
- Reward_sdxl_idealized/models/unet_2d_condition_reward.py +1334 -0
- Reward_sdxl_idealized/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- Reward_sdxl_idealized/pipelines/sdxl_gradient_ascent_pipeline.py +375 -0
- Reward_sdxl_idealized/timestep_convergence_analysis.ipynb +1105 -0
- Reward_sdxl_idealized/tune_hyperparams.py +514 -0
- Reward_sdxl_idealized/tune_parallel.sh +253 -0
- __pycache__/upload.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/__init__.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/coca_model.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/hf_model.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/loss.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/model.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/pretrained.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/tokenizer.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/transform.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/transformer.cpython-311.pyc +0 -0
- evaluation/open_clip/__pycache__/utils.cpython-311.pyc +0 -0
- evaluation/open_clip/model_configs/RN50.json +21 -0
- evaluation/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
- evaluation/open_clip/model_configs/ViT-S-32.json +16 -0
- evaluation/open_clip/model_configs/convnext_large_d_320.json +19 -0
- lrm/flux/.hydra/config.yaml +124 -0
- lrm/flux/.hydra/hydra.yaml +166 -0
- lrm/flux/.hydra/overrides.yaml +5 -0
- lrm/flux/README.md +4 -0
Reward_sana_idealized/open_clip/__pycache__/__init__.cpython-311.pyc
ADDED: Binary file (2 kB)

Reward_sana_idealized/open_clip/__pycache__/constants.cpython-311.pyc
ADDED: Binary file (282 Bytes)

Reward_sana_idealized/open_clip/__pycache__/hf_configs.cpython-311.pyc
ADDED: Binary file (717 Bytes)

Reward_sana_idealized/open_clip/__pycache__/hf_model.cpython-311.pyc
ADDED: Binary file (10.6 kB)

Reward_sana_idealized/open_clip/__pycache__/loss.cpython-311.pyc
ADDED: Binary file (14.3 kB)

Reward_sana_idealized/open_clip/__pycache__/openai.cpython-311.pyc
ADDED: Binary file (8.74 kB)

Reward_sana_idealized/open_clip/__pycache__/transform.cpython-311.pyc
ADDED: Binary file (8.57 kB)

Reward_sana_idealized/open_clip/__pycache__/utils.cpython-311.pyc
ADDED: Binary file (3.57 kB)

Reward_sana_idealized/open_clip/__pycache__/version.cpython-311.pyc
ADDED: Binary file (189 Bytes)
Reward_sana_idealized/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 448,
        "layers": [
            3,
            15,
            36,
            10
        ],
        "width": 128,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 896,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/ViT-B-32-quickgelu.json
ADDED
@@ -0,0 +1,17 @@
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json
ADDED
@@ -0,0 +1,15 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-base",
        "hf_tokenizer_name": "xlm-roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
Reward_sdxl_idealized/README.md
ADDED
@@ -0,0 +1,1336 @@

# Reward-Guided Gradient Ascent for Stable Diffusion

A comprehensive system for improving Stable Diffusion image generation quality using gradient ascent optimization on Latent Reward Model (LRM) scores during inference.

## Table of Contents

- [Overview](#overview)
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Architecture](#architecture)
- [Understanding Reward Calculation](#understanding-reward-calculation)
- [Learning Rate Scheduling](#learning-rate-scheduling)
- [Configuration Presets](#configuration-presets)
- [Evaluation Metrics](#evaluation-metrics)
- [Model Variants](#model-variants)
- [Datasets](#datasets)
- [Usage Examples](#usage-examples)
- [API Reference](#api-reference)
- [Command-Line Options](#command-line-options)
- [Output Files](#output-files)
- [Troubleshooting](#troubleshooting)
- [Best Practices](#best-practices)
- [Changelog](#changelog)

---

## Overview

This project implements **test-time optimization** for Stable Diffusion using gradient ascent on the LRM reward model. Unlike the main LPO training, which uses the reward model to fine-tune the diffusion model, this approach applies the reward model during inference to improve generation quality without retraining.

### Key Capabilities

- **Gradient Ascent Optimization**: Iteratively improve latents using reward gradients
- **Learning Rate Scheduling**: Multiple strategies (constant, linear, cosine, exponential, step)
- **Momentum Optimization**: Standard and Nesterov momentum for better convergence
- **Multiple Metrics**: FID, CLIP, Aesthetic, PickScore, HPSv2, ImageReward
- **Model Variants**: Support for Origin, SPO, DPO, and LPO SD1.5 models
- **Dataset Flexibility**: COCO and Pick-a-Pic validation datasets
- **Configuration Presets**: 15 pre-tuned configurations for various use cases

---

## Features

### 1. **Advanced Optimization**
- **5 LR Schedulers**: Constant, Linear, Cosine, Exponential, Step-wise
- **Momentum Support**: Standard momentum and Nesterov momentum
- **Configurable Timestep Ranges**: Apply gradients at specific denoising steps
- **Dynamic Learning Rates**: LR changes during optimization for better convergence

### 2. **Comprehensive Evaluation**
- **6 Quality Metrics**: FID, CLIP, Aesthetic, PickScore, HPSv2, ImageReward
- **Baseline Comparison**: Compare results with and without gradient ascent
- **Detailed Statistics**: Track reward improvements, gradient norms, LR history
- **Batch Processing**: Efficient evaluation on large datasets
- **Reward Visualization**: Automatic plotting of reward progression across timesteps
- **Timestep-Aware Tracking**: Monitor rewards at every denoising step; the reward of the final t=0 latent is reported

### 3. **Model Flexibility**
- **4 SD1.5 Variants**: Origin, SPO, DPO, LPO
- **Auto-Configuration**: CFG scale auto-adjusted per model variant
- **Easy Switching**: Change models with a single flag

### 4. **Dataset Support**
- **COCO Validation**: Standard benchmark with reference images
- **Pick-a-Pic Validation**: Large-scale human preference dataset
- **Streaming Support**: Handle large datasets efficiently

---

## Installation

### Requirements

```bash
# Core dependencies
pip install torch diffusers transformers torchmetrics datasets huggingface-hub

# For evaluation metrics
pip install pillow numpy scipy tqdm

# Optional: for better performance
pip install xformers  # For memory-efficient attention
```

### Setup

```bash
cd /path/to/LPO/Reward

# Verify installation
python -c "from lr_scheduler import create_lr_scheduler; print('✓ LR Scheduler OK')"
python -c "from grad_ascent_configs import list_configs; print('✓ Configs:', len(list_configs()))"
python -c "from gradient_ascent_utils import RewardGuidedDiffusion; print('✓ Gradient Utils OK')"
```

---

## Quick Start

### 1. Basic COCO Evaluation (test_grad_sd1.5.py)

```bash
# Edit Config in test_grad_sd1.5.py:
# - Set device: "cuda:0" or "cuda:6"
# - Set max_samples: 10 for a quick test, None for the full dataset
# - Configure gradient ascent parameters

python test_grad_sd1.5.py
```

**Output:**
- Creates `RESULTS/SD1.5_GradAscent/run_1/` (auto-incremented)
- Generates `eval.log` with detailed metrics
- Saves `reward_curve.png` showing reward progression

### 2. Basic Evaluation with Preset Config (eval.py)

```bash
python eval.py \
    --grad_config cosine_nesterov \
    --metrics clip aesthetic \
    --max_samples 10
```

### 3. High-Quality Evaluation

```bash
python eval.py \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 100 \
    --save_images \
    --output_dir results/high_quality
```

### 4. Pick-a-Pic Benchmark

```bash
python eval.py \
    --dataset_type pickapic \
    --grad_config cosine_nesterov \
    --metrics pickscore hpsv2 imagereward \
    --max_samples 500 \
    --output_dir results/pickapic
```

---

## Architecture

### System Components

```
Reward/
├── models/
│   ├── reward_model.py                  # LRM reward model wrapper
│   └── unet_2d_condition_reward.py      # Custom UNet with reward tracking
├── pipelines/
│   ├── sd15_reward_pipeline.py          # Base pipeline with reward tracking
│   └── sd15_gradient_ascent_pipeline.py # Pipeline with gradient ascent
├── lr_scheduler.py                      # Learning rate schedulers
├── gradient_ascent_utils.py             # Core gradient ascent implementation
├── grad_ascent_configs.py               # Configuration presets
├── eval.py                              # Comprehensive evaluation script
└── examples.sh                          # Example commands
```

### Gradient Ascent Flow

```
1. Load Stable Diffusion + LRM Reward Model
2. Start denoising process (T → 0)
3. At each timestep t:
   a. Standard denoising step (predict noise, remove it)
   b. Compute reward R(latents, prompt, t) and store in history
   c. If t in gradient range:
      - Enable gradients on latents
      - Compute ∇R w.r.t. latents
      - For each gradient step:
        * Get current LR from scheduler
        * Apply momentum (if enabled)
        * Update: latents += lr * momentum(∇R)
      - Track statistics (grad norms, reward improvement)
4. At final timestep (t=0):
   - Final reward computed on clean latent
   - This reward is reported in logs
5. Decode final latent (x₀) to image via VAE
6. Compute quality metrics on image
```
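To make step 3(c) concrete, here is a minimal sketch of the inner update loop. The names and the `reward_model(latents, prompt_embeds, t)` call signature are illustrative assumptions, not the exact identifiers used in `gradient_ascent_utils.py`.

```python
import torch

def gradient_ascent_update(latents, prompt_embeds, t, reward_model,
                           lr_scheduler, num_grad_steps=15,
                           momentum=0.9, use_nesterov=True):
    """Sketch of step 3(c): maximize R(latents, prompt, t) via gradient ascent."""
    velocity = torch.zeros_like(latents)
    for _ in range(num_grad_steps):
        lr = lr_scheduler.get_lr()
        latents = latents.detach().requires_grad_(True)
        # Nesterov momentum evaluates the reward at the look-ahead point.
        probe = latents + momentum * velocity if use_nesterov else latents
        reward = reward_model(probe, prompt_embeds, t).sum()
        grad = torch.autograd.grad(reward, latents)[0]
        velocity = momentum * velocity + grad
        latents = (latents + lr * velocity).detach()  # ascent step (maximize reward)
        lr_scheduler.step()
    return latents
```

Note the sign: this is ascent on the reward, so the update adds the (momentum-smoothed) gradient rather than subtracting it.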

### Understanding Reward Calculation

**Key Concepts:**

- **Timestep-Aware Rewards**: The LRM reward model computes preference scores at ANY noise level (timestep t)
- **Progressive Tracking**: Rewards are calculated at every denoising step throughout generation
- **Final Latent Reward**: The reported metric is the reward at t=0 (the clean latent before decoding)
- **Not Averaged**: The final reward is specifically from the last timestep, NOT an average across all timesteps

**What gets reported:**
```python
# During generation: rewards are computed at each t (1000 → 0)
Step 0:   t=1000, reward=3.2
Step 1:   t=990,  reward=3.5
...
Step 99:  t=10,   reward=5.1
Step 100: t=0,    reward=5.4  ← This is what gets logged!
```

The `Reward (t=0)` in logs represents the preference score of the final clean latent that was decoded into your output image.
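As a small sketch of that bookkeeping (an illustrative helper, not the repo's actual data structure):

```python
def final_reward(reward_history):
    """Return the t=0 reward that gets logged, given a list of
    (timestep, reward) pairs recorded in denoising order (illustrative)."""
    assert reward_history and reward_history[-1][0] == 0, "history must end at t=0"
    return reward_history[-1][1]

# The last entry, not the mean of the history, is what appears as `Reward (t=0)`.
history = [(1000, 3.2), (990, 3.5), (10, 5.1), (0, 5.4)]
print(final_reward(history))  # 5.4
```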

---

## Learning Rate Scheduling

### Available Schedulers

#### 1. **Constant LR**
```python
lr_scheduler_type="constant"
```
- Fixed learning rate throughout optimization
- Simple and stable
- Good for quick experiments

#### 2. **Linear Decay**
```python
lr_scheduler_type="linear"
lr_scheduler_kwargs={
    "end_lr": 0.01,   # End LR (10% of initial)
    "start_step": 0   # When to start decay
}
```
- Linear decrease from initial to end LR
- Smooth convergence
- Configurable warmup period

#### 3. **Cosine Annealing** (Recommended)
```python
lr_scheduler_type="cosine"
lr_scheduler_kwargs={
    "min_lr": 0.001,    # Minimum LR
    "warmup_steps": 3   # Linear warmup steps
}
```
- Smooth cosine decay
- Optional warmup phase
- Widely used in deep learning
- **Best for most use cases**
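The schedule itself is standard warmup-plus-cosine annealing. A minimal sketch of the formula, assuming `create_lr_scheduler` follows the usual convention (the repo's implementation may differ in details):

```python
import math

def cosine_lr(step, num_steps, initial_lr=0.1, min_lr=0.001, warmup_steps=3):
    """Linear warmup to initial_lr, then cosine decay to min_lr (illustrative)."""
    if step < warmup_steps:
        return initial_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * progress))

# LR ramps up for 3 steps, then decays smoothly toward min_lr.
print([round(cosine_lr(s, 15), 4) for s in range(15)])
```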

#### 4. **Exponential Decay**
```python
lr_scheduler_type="exponential"
lr_scheduler_kwargs={
    "gamma": 0.9   # Decay factor per step
}
```
- Exponential decrease
- Fast initial decay
- Good for aggressive optimization

#### 5. **Step Decay**
```python
lr_scheduler_type="step"
lr_scheduler_kwargs={
    "step_size": 5,   # Steps between decays
    "gamma": 0.5      # Multiplicative factor
}
```
- Step-wise LR reduction
- Periodic decay
- Good for scheduled changes

### Usage Example

```python
from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline

pipeline.enable_gradient_ascent(
    grad_timestep_range=(0, 700),
    num_grad_steps=15,
    grad_step_size=0.1,  # Initial LR
    lr_scheduler_type="cosine",
    lr_scheduler_kwargs={
        "min_lr": 0.001,
        "warmup_steps": 3
    }
)
```

---

## Configuration Presets

We provide 15 pre-configured optimization strategies. Use them with `--grad_config <name>`.

### Basic Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `constant` | Constant | No | 5 | Simple baseline |
| `linear` | Linear decay | No | 10 | Smooth decay |
| `linear_warmstart` | Linear w/ warmup | No | 10 | Stable start |
| `cosine` | Cosine | No | 10 | Smooth convergence |
| `cosine_warmup` | Cosine w/ warmup | No | 20 | Best convergence |
| `exponential` | Exponential | No | 15 | Fast decay |
| `step` | Step-wise | No | 20 | Periodic decay |

### Momentum Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `momentum` | Constant | Standard | 10 | Faster convergence |
| `nesterov` | Constant | Nesterov | 10 | Better convergence |

### Advanced Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `cosine_momentum` | Cosine | Standard | 15 | High quality |
| `cosine_nesterov` | Cosine | Nesterov | 15 | **Recommended** |
| `linear_nesterov` | Linear | Nesterov | 15 | Stable + fast |

### Quality Presets

| Config | LR Schedule | Momentum | Steps | Use Case |
|--------|-------------|----------|-------|----------|
| `high_quality` | Cosine | Nesterov | 20 | **Best quality** |
| `aggressive` | Exponential | Standard | 8 | Fast results |
| `conservative` | Cosine | Nesterov | 25 | Most stable |

### Config Details

#### `high_quality` (Recommended for Research)
```python
{
    "grad_timestep_range": (200, 800),  # Focus on middle timesteps
    "num_grad_steps": 20,
    "grad_step_size": 0.08,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.005, "warmup_steps": 5},
    "use_momentum": True,
    "momentum": 0.95,
    "use_nesterov": True
}
```

#### `cosine_nesterov` (Recommended for General Use)
```python
{
    "grad_timestep_range": (0, 700),
    "num_grad_steps": 15,
    "grad_step_size": 0.12,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.001, "warmup_steps": 3},
    "use_momentum": True,
    "momentum": 0.9,
    "use_nesterov": True
}
```

#### `aggressive` (Fast Experimentation)
```python
{
    "grad_timestep_range": (0, 900),
    "num_grad_steps": 8,
    "grad_step_size": 0.15,
    "grad_scale": 1.2,
    "lr_scheduler_type": "exponential",
    "lr_scheduler_kwargs": {"gamma": 0.85},
    "use_momentum": True,
    "momentum": 0.85,
    "use_nesterov": False
}
```

### Listing Configs

```python
from grad_ascent_configs import list_configs, print_config, get_config

# List all available configs
print(list_configs())
# Output: ['aggressive', 'conservative', 'constant', 'cosine', ...]

# Print config details
print_config("cosine_nesterov")

# Get config dictionary
config = get_config("high_quality")
pipeline.enable_gradient_ascent(**config)
```

---

## Evaluation Metrics

### 1. **FID (Fréchet Inception Distance)**
- Measures distribution similarity between real and generated images
- **Lower is better**
- Requires reference images (COCO dataset only)
- Computationally expensive

```bash
--metrics fid
```
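For reference, FID can be computed with torchmetrics, which this project lists as a dependency. A minimal sketch in which random tensors stand in for real and generated image batches (the evaluation script's own FID wiring may differ):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048)
# torchmetrics expects uint8 image batches of shape (N, 3, H, W) by default.
real = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 256, (8, 3, 299, 299), dtype=torch.uint8)
fid.update(real, real=True)
fid.update(fake, real=False)
print(float(fid.compute()))  # lower is better
```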

### 2. **CLIP Score**
- Evaluates text-image alignment using CLIP embeddings
- **Higher is better**
- Fast and reliable
- Good for general quality assessment

```bash
--metrics clip
```
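A corresponding torchmetrics sketch for CLIP score (the model name shown is torchmetrics' default backbone; the evaluation script may use a different CLIP variant):

```python
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
images = torch.randint(0, 256, (2, 3, 224, 224), dtype=torch.uint8)
prompts = ["a beautiful mountain landscape at sunset", "a red sports car"]
print(float(metric(images, prompts)))  # higher is better
```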

### 3. **Aesthetic Score**
- Predicts aesthetic quality using CLIP + MLP
- **Higher is better**
- Trained on human aesthetic ratings
- Good for visual appeal

```bash
--metrics aesthetic
```

### 4. **PickScore** (New)
- Human preference predictor from the Pick-a-Pic dataset
- **Higher is better**
- Trained on large-scale human comparisons
- State-of-the-art preference metric

```bash
--metrics pickscore
```

### 5. **HPSv2** (New)
- Human Preference Score version 2
- **Higher is better**
- Trained on human aesthetic evaluations
- Complementary to PickScore

```bash
--metrics hpsv2
```

### 6. **ImageReward** (New)
- Reward model trained with RLHF (Reinforcement Learning from Human Feedback)
- **Higher is better**
- Comprehensive quality assessment
- Trained on diverse human feedback

```bash
--metrics imagereward
```

### Metric Recommendations

| Use Case | Recommended Metrics | Reason |
|----------|---------------------|--------|
| Research/Papers | `fid clip aesthetic pickscore hpsv2` | Comprehensive evaluation |
| Quick Iteration | `clip aesthetic` | Fast and reliable |
| Human Alignment | `pickscore hpsv2 imagereward` | Preference-based |
| Text Alignment | `clip imagereward` | Focus on prompt adherence |
| Visual Quality | `aesthetic pickscore` | Focus on aesthetics |

---

## Model Variants

Support for multiple SD1.5 model variants trained with different methods.

### Available Variants

#### 1. **Origin** (Default)
```bash
--model_variant origin
```
- Original Stable Diffusion v1.5 from RunwayML
- No additional training
- CFG scale: 7.5 (default)
- Good baseline

#### 2. **SPO** (Step-aware Preference Optimization)
```bash
--model_variant spo
```
- Trained with the SPO method
- Model: `SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep`
- **CFG scale: 5.0** (auto-adjusted)
- Better prompt adherence

#### 3. **Diffusion-DPO** (Direct Preference Optimization)
```bash
--model_variant diffusion_dpo
```
- Trained with DPO on human preferences
- Model: `mhdang/dpo-sd1.5-text2image-v1`
- CFG scale: 7.5
- Improved human alignment

#### 4. **LPO** (Latent Preference Optimization)
```bash
--model_variant lpo
```
- Trained with LPO (this project's main method)
- Model: `casiatao/LPO` (lpo_sd15_merge)
- **CFG scale: 5.0** (auto-adjusted)
- **Highest quality baseline**
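Internally, selecting a variant amounts to picking a checkpoint and a CFG scale. A hypothetical mapping mirroring the values documented above (the evaluation script's actual lookup table may differ):

```python
# Hypothetical variant table reflecting the documented checkpoints and CFG scales.
MODEL_VARIANTS = {
    "origin":        ("runwayml/stable-diffusion-v1-5",             7.5),
    "spo":           ("SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep", 5.0),
    "diffusion_dpo": ("mhdang/dpo-sd1.5-text2image-v1",             7.5),
    "lpo":           ("casiatao/LPO",                               5.0),
}

def resolve_variant(name: str) -> tuple[str, float]:
    """Return (model_path, cfg_scale) for a --model_variant value."""
    model_path, cfg_scale = MODEL_VARIANTS[name]
    return model_path, cfg_scale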

### Comparison

| Variant | Training Method | Quality | Speed | Best For |
|---------|----------------|---------|-------|----------|
| Origin | Pre-training only | Good | Fast | Baseline |
| SPO | Step-aware preference | Better | Fast | Prompt adherence |
| Diffusion-DPO | Preference learning | Better | Fast | Human preferences |
| LPO | Latent preference | **Best** | Fast | Overall quality |

### Usage Example

```bash
# Compare all variants
for variant in origin spo diffusion_dpo lpo; do
    python eval.py \
        --model_variant $variant \
        --grad_config high_quality \
        --metrics clip aesthetic pickscore \
        --max_samples 100 \
        --output_dir results/${variant}
done
```

---

## Datasets

### 1. **COCO Validation** (Default)

```bash
--dataset_type coco
--data_dir ./data
```

**Features:**
- Standard benchmark dataset
- Reference images available (for FID)
- ~5,000 validation samples
- Diverse prompts

**Structure:**
```
data/coco/
├── caption_val.json
└── images/val/
    ├── 000000000139.jpg
    ├── 000000000285.jpg
    └── ...
```

### 2. **Pick-a-Pic Validation**

```bash
--dataset_type pickapic
```

**Features:**
- Large-scale human preference dataset
- Streaming (no download needed)
- ~500,000 validation samples
- Real user prompts
- No reference images (FID not available)

**Advantages:**
- More diverse prompts
- Real-world use cases
- Human preference focus
- Large-scale evaluation
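Streaming here means prompts are read lazily through the `datasets` library rather than downloaded in full. A sketch of that pattern, where the dataset id `yuvalkirstain/pickapic_v2` and the `caption` column are assumptions about what the script uses:

```python
from datasets import load_dataset

# Stream validation prompts lazily instead of downloading the full dataset.
ds = load_dataset("yuvalkirstain/pickapic_v2", split="validation", streaming=True)
for i, example in enumerate(ds):
    prompt = example["caption"]  # assumed column name
    print(prompt)
    if i >= 4:
        break
```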

### Dataset Recommendations

| Use Case | Dataset | Reason |
|----------|---------|--------|
| Academic Research | COCO | Standard benchmark, reproducible |
| FID Evaluation | COCO | Requires reference images |
| Human Preference | Pick-a-Pic | Trained on human comparisons |
| Large-scale Tests | Pick-a-Pic | 500K+ samples available |
| Quick Tests | COCO | Smaller, faster |

---

## Usage Examples

### Example 1: Quick Test
```bash
python eval.py \
    --grad_config cosine_nesterov \
    --metrics clip aesthetic \
    --max_samples 10 \
    --output_dir examples/quick_test
```

### Example 2: High-Quality Research Evaluation
```bash
python eval.py \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 200 \
    --save_images \
    --output_dir examples/research
```

### Example 3: Pick-a-Pic Benchmark
```bash
python eval.py \
    --dataset_type pickapic \
    --grad_config cosine_nesterov \
    --metrics pickscore hpsv2 imagereward \
    --max_samples 500 \
    --output_dir examples/pickapic
```

### Example 4: LPO Model Evaluation
```bash
python eval.py \
    --model_variant lpo \
    --grad_config high_quality \
    --metrics clip aesthetic pickscore \
    --max_samples 100 \
    --save_images \
    --output_dir examples/lpo_model
```

### Example 5: Baseline Only (No Gradient Ascent)
```bash
python eval.py \
    --mode baseline \
    --model_variant origin \
    --metrics clip aesthetic pickscore \
    --max_samples 50 \
    --output_dir examples/baseline_only
```

### Example 6: Manual Configuration
```bash
python eval.py \
    --grad_range_start 200 \
    --grad_range_end 800 \
    --grad_steps 15 \
    --grad_step_size 0.08 \
    --metrics clip aesthetic \
    --max_samples 50 \
    --output_dir examples/manual_config
```

### Example 7: Model Comparison
```bash
# Evaluate all model variants
for variant in origin spo diffusion_dpo lpo; do
    python eval.py \
        --model_variant $variant \
        --grad_config high_quality \
        --metrics clip aesthetic pickscore \
        --max_samples 100 \
        --save_images \
        --output_dir results/comparison/${variant}
done
```

### Example 8: Conservative Optimization
```bash
python eval.py \
    --grad_config conservative \
    --metrics clip aesthetic pickscore hpsv2 \
    --max_samples 100 \
    --save_images \
    --output_dir examples/conservative
```

---

## API Reference

### Pipeline Usage

```python
import torch
from diffusers import StableDiffusionPipeline
from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline
from models import LRMRewardModel

# Load base pipeline
base_pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16
)

# Create gradient ascent pipeline
pipeline = StableDiffusionGradientAscentPipeline(**base_pipeline.components)

# Load reward model
reward_model = LRMRewardModel(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    lrm_model_path="casiatao/LRM",
    guidance_scale=7.5,
    device="cuda"
)
pipeline.set_reward_model(reward_model)

# Enable gradient ascent with a preset
from grad_ascent_configs import get_config
config = get_config("cosine_nesterov")
pipeline.enable_gradient_ascent(**config)

# Or configure manually
pipeline.enable_gradient_ascent(
    grad_timestep_range=(200, 800),
    num_grad_steps=15,
    grad_step_size=0.1,
    lr_scheduler_type="cosine",
    lr_scheduler_kwargs={"min_lr": 0.001, "warmup_steps": 3},
    use_momentum=True,
    momentum=0.9,
    use_nesterov=True
)

# Generate with gradient ascent
output = pipeline(
    prompt="a beautiful mountain landscape at sunset",
    num_inference_steps=50,
    guidance_scale=7.5,
)

# Get gradient statistics
stats = pipeline.grad_guidance.get_statistics()
print(f"Reward improvement: {stats['avg_reward_improvement']:.4f}")
```

### Custom LR Scheduler

```python
from lr_scheduler import create_lr_scheduler

# Create a cosine scheduler with warmup
scheduler = create_lr_scheduler(
    scheduler_type="cosine",
    initial_lr=0.1,
    num_steps=20,
    min_lr=0.001,
    warmup_steps=5
)

# Use in an optimization loop
for step in range(20):
    current_lr = scheduler.get_lr()
    # ... apply gradient with current_lr ...
    scheduler.step()
```

### Configuration Management

```python
from grad_ascent_configs import get_config, list_configs, print_config

# List all available configs
all_configs = list_configs()
print(f"Available configs: {all_configs}")

# Get a specific config
config = get_config("high_quality")

# Print config details
print_config("cosine_nesterov")

# Create a custom config
custom_config = {
    "grad_timestep_range": (300, 700),
    "num_grad_steps": 12,
    "grad_step_size": 0.09,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.002, "warmup_steps": 4},
    "use_momentum": True,
    "momentum": 0.92,
    "use_nesterov": True
}
pipeline.enable_gradient_ascent(**custom_config)
```

---

## Command-Line Options

### Essential Options

```bash
--data_dir PATH          # Path to data directory (default: ./data)
--dataset_type TYPE      # Dataset: coco or pickapic (default: coco)
--model_variant VARIANT  # Model: origin, spo, diffusion_dpo, lpo (default: origin)
--max_samples N          # Max samples to evaluate (default: all)
--output_dir PATH        # Output directory (default: eval_outputs)
--save_images            # Save generated images
```

### Gradient Ascent Options

```bash
--grad_config NAME       # Use a preset config (recommended)
--grad_range_start N     # Gradient timestep start (default: 0)
--grad_range_end N       # Gradient timestep end (default: 700)
--grad_steps N           # Gradient steps per timestep (default: 5)
--grad_step_size FLOAT   # Initial learning rate (default: 0.1)
```

### Evaluation Options

```bash
--metrics METRIC [METRIC...]  # Metrics to evaluate (default: clip aesthetic)
                              # Options: fid, clip, aesthetic, pickscore, hpsv2, imagereward
--mode MODE              # baseline, gradient_ascent, or both (default: both)
--num_steps N            # Diffusion inference steps (default: 50)
--cfg_scale FLOAT        # CFG scale (default: 7.5, auto-adjusted for some models)
--batch_size N           # Batch size (default: 1)
--log_interval N         # Log every N batches (default: 10)
```

### Other Options

```bash
--lrm_model PATH         # LRM model path (default: casiatao/LRM)
--seed N                 # Random seed (default: 42)
--cuda N                 # CUDA device ID (default: 0)
```

### Complete Example

```bash
python eval.py \
    --data_dir ./data \
    --dataset_type coco \
    --model_variant lpo \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 200 \
    --num_steps 50 \
    --save_images \
    --output_dir results/comprehensive \
    --cuda 0
```

---

## Output Files

After running an evaluation, the following files are created in **auto-incremented run folders**:

```
RESULTS/SD1.5_GradAscent/
├── run_1/                  # First run
│   ├── eval.log            # Complete execution log
│   └── reward_curve.png    # Reward progression plot
├── run_2/                  # Second run
│   ├── eval.log
│   └── reward_curve.png
└── run_3/                  # Third run
    ├── eval.log
    └── reward_curve.png
```

### Auto-Incrementing Run Folders

Each execution automatically creates a new `run_<N>/` folder, preventing accidental overwrites and maintaining a complete experiment history. No manual folder management is needed.
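A minimal sketch of this kind of auto-increment logic (an illustrative helper; the script's own implementation may differ):

```python
from pathlib import Path

def next_run_dir(base="RESULTS/SD1.5_GradAscent"):
    """Create and return the next free run_<N>/ folder (illustrative)."""
    base_path = Path(base)
    base_path.mkdir(parents=True, exist_ok=True)
    n = 1
    while (base_path / f"run_{n}").exists():
        n += 1
    run_dir = base_path / f"run_{n}"
    run_dir.mkdir()
    return run_dir

print(next_run_dir())  # e.g. RESULTS/SD1.5_GradAscent/run_1
```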

### eval.log Structure

The log contains detailed information for each batch:

```
======================================================================
COCO GRADIENT ASCENT EVALUATION (BATCHED)
======================================================================
Logging to: ./RESULTS/SD1.5_GradAscent/run_1/eval.log
Device: cuda:6
Batch size: 1
Metrics: fid, clip, reward, aesthetic
Gradient Ascent: Range=[0, 900], Steps=1, StepSize=0.01
======================================================================

[Batch 1/5000] Samples: 1/5000 | FID: 2.5432 | CLIP: 0.8234 | Reward (t=0): 5.2341 | Reward (Avg): 5.2341 | Aesthetic: 6.456
[Batch 161/5000] Samples: 161/5000 | FID: 2.3821 | CLIP: 0.8412 | Reward (t=0): 5.4123 | Reward (Avg): 5.3215 | Aesthetic: 6.523
...

======================================================================
FINAL RESULTS
======================================================================
FID: 2.3456
CLIP avg: 0.8378
Reward avg: 5.3421
Aesthetic: 6.489
======================================================================
```

### reward_curve.png Visualization

The reward curve plot shows two panels for the **first generated image**:

**Left Panel: Reward vs Timestep**
- X-axis: Denoising timestep (t)
- Y-axis: Reward score
- Green shaded region: Where gradient ascent is applied
- Shows how reward evolves as noise is removed

**Right Panel: Reward vs Denoising Step**
- X-axis: Sequential denoising step (0 to num_inference_steps)
- Y-axis: Reward score
- Same data, different perspective for easier interpretation

**Key Insights from the Plot:**
- **Upward trend**: Reward generally increases as denoising progresses
- **Sharp improvements**: Visible spikes where gradient ascent is effective
- **Final reward**: The last point corresponds to t=0 (decoded image reward)
- **Learning dynamics**: Shows whether optimization is working at different noise levels
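A two-panel plot of this kind can be reproduced from a recorded `(timestep, reward)` history with a few lines of matplotlib. A sketch with dummy data, not the script's actual plotting code:

```python
import matplotlib.pyplot as plt

# Dummy history: (timestep, reward) pairs recorded during denoising.
history = [(1000 - 20 * i, 3.0 + 0.025 * i) for i in range(51)]
timesteps = [t for t, _ in history]
rewards = [r for _, r in history]

fig, (ax_t, ax_s) = plt.subplots(1, 2, figsize=(10, 4))
ax_t.plot(timesteps, rewards)
ax_t.axvspan(0, 700, color="green", alpha=0.15)  # gradient ascent range
ax_t.invert_xaxis()                              # denoising runs 1000 -> 0
ax_t.set(xlabel="Timestep t", ylabel="Reward", title="Reward vs Timestep")
ax_s.plot(range(len(rewards)), rewards)
ax_s.set(xlabel="Denoising step", ylabel="Reward", title="Reward vs Step")
fig.tight_layout()
fig.savefig("reward_curve.png")
```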
|
| 924 |
+
|
| 925 |
+
### Reward Tracking Details
|
| 926 |
+
|
| 927 |
+
The script now explicitly tracks:
|
| 928 |
+
|
| 929 |
+
1. **Timestep-specific rewards**: Computed at every denoising step
|
| 930 |
+
2. **Final latent reward**: The reward for t=0 (the latent that gets decoded)
|
| 931 |
+
3. **Running average**: Mean reward across all processed samples
|
| 932 |
+
4. **Current batch reward**: Immediate feedback per batch
|
| 933 |
+
|
| 934 |
+
Example log output:
|
| 935 |
+
```
|
| 936 |
+
Reward (t=0): 5.4123 # Reward for the final decoded latent
|
| 937 |
+
Reward (Avg): 5.3215 # Running average across all samples
|
| 938 |
+
```
|
| 939 |
+
|
| 940 |
+
### evaluation_results.json Structure
|
| 941 |
+
|
| 942 |
+
(Legacy format from eval.py - test_grad_sd1.5.py uses simplified logging)
|
| 943 |
+
|
| 944 |
+
```json
|
| 945 |
+
{
|
| 946 |
+
"mode": "both",
|
| 947 |
+
"metrics": ["clip", "aesthetic", "pickscore"],
|
| 948 |
+
"config": {
|
| 949 |
+
"num_samples": 100,
|
| 950 |
+
"num_steps": 50,
|
| 951 |
+
"cfg_scale": 7.5,
|
| 952 |
+
"grad_range": [0, 700],
|
| 953 |
+
"grad_steps": 15,
|
| 954 |
+
"grad_step_size": 0.12
|
| 955 |
+
},
|
| 956 |
+
"baseline": {
|
| 957 |
+
"avg_reward": 0.7234,
|
| 958 |
+
"clip_score": 0.8123,
|
| 959 |
+
"aesthetic_score": 6.234,
|
| 960 |
+
"pickscore": 21.45
|
| 961 |
+
},
|
| 962 |
+
"gradient_ascent": {
|
| 963 |
+
"avg_reward": 0.7891,
|
| 964 |
+
"clip_score": 0.8345,
|
| 965 |
+
"aesthetic_score": 6.456,
|
| 966 |
+
"pickscore": 22.13,
|
| 967 |
+
"stats": {
|
| 968 |
+
"num_applications": 45,
|
| 969 |
+
"total_reward_improvement": 2.956,
|
| 970 |
+
"avg_reward_improvement": 0.0657
|
| 971 |
+
}
|
| 972 |
+
},
|
| 973 |
+
"comparison": {
|
| 974 |
+
"reward_difference": 0.0657,
|
| 975 |
+
"clip_difference": 0.0222,
|
| 976 |
+
"aesthetic_difference": 0.222,
|
| 977 |
+
"pickscore_difference": 0.68
|
| 978 |
+
}
|
| 979 |
+
}
|
| 980 |
+
```
|
| 981 |
+
|
| 982 |
+
---
|
| 983 |
+
|
| 984 |
+
## Troubleshooting
|
| 985 |
+
|
| 986 |
+
### Common Issues
|
| 987 |
+
|
| 988 |
+
#### 1. Out of Memory (OOM)
|
| 989 |
+
|
| 990 |
+
**Symptoms:**
|
| 991 |
+
```
|
| 992 |
+
RuntimeError: CUDA out of memory
|
| 993 |
+
```
|
| 994 |
+
|
| 995 |
+
**Solutions:**
|
| 996 |
+
```bash
|
| 997 |
+
# Reduce batch size
|
| 998 |
+
--batch_size 1
|
| 999 |
+
|
| 1000 |
+
# Reduce max samples
|
| 1001 |
+
--max_samples 50
|
| 1002 |
+
|
| 1003 |
+
# Reduce gradient steps
|
| 1004 |
+
--grad_steps 5
|
| 1005 |
+
|
| 1006 |
+
# Use smaller config
|
| 1007 |
+
--grad_config aggressive # Only 8 steps
|
| 1008 |
+
```
|
| 1009 |
+
|
| 1010 |
+
#### 2. Slow Evaluation
|
| 1011 |
+
|
| 1012 |
+
**Symptoms:**
|
| 1013 |
+
- Takes too long to complete
|
| 1014 |
+
- Hanging on metric computation
|
| 1015 |
+
|
| 1016 |
+
**Solutions:**
|
| 1017 |
+
```bash
|
| 1018 |
+
# Skip expensive metrics
|
| 1019 |
+
--metrics clip aesthetic # Skip FID
|
| 1020 |
+
|
| 1021 |
+
# Reduce samples
|
| 1022 |
+
--max_samples 50
|
| 1023 |
+
|
| 1024 |
+
# Reduce diffusion steps
|
| 1025 |
+
--num_steps 20
|
| 1026 |
+
|
| 1027 |
+
# Use faster dataset
|
| 1028 |
+
--dataset_type pickapic # No FID computation
|
| 1029 |
+
```
|
| 1030 |
+
|
| 1031 |
+
#### 3. Poor Results / No Improvement
|
| 1032 |
+
|
| 1033 |
+
**Symptoms:**
|
| 1034 |
+
- Reward doesn't increase
|
| 1035 |
+
- Quality worse after gradient ascent
|
| 1036 |
+
|
| 1037 |
+
**Solutions:**
|
| 1038 |
+
```bash
|
| 1039 |
+
# Try better configs
|
| 1040 |
+
--grad_config high_quality
|
| 1041 |
+
--grad_config conservative
|
| 1042 |
+
|
| 1043 |
+
# Increase gradient steps
|
| 1044 |
+
--grad_steps 20
|
| 1045 |
+
|
| 1046 |
+
# Adjust timestep range (focus on middle)
|
| 1047 |
+
--grad_range_start 200 --grad_range_end 800
|
| 1048 |
+
|
| 1049 |
+
# Try different model variant
|
| 1050 |
+
--model_variant lpo
|
| 1051 |
+
```
|
| 1052 |
+
|
| 1053 |
+
#### 4. Config Not Found
|
| 1054 |
+
|
| 1055 |
+
**Symptoms:**
|
| 1056 |
+
```
|
| 1057 |
+
ValueError: Unknown config: my_config
|
| 1058 |
+
```
|
| 1059 |
+
|
| 1060 |
+
**Solutions:**
|
| 1061 |
+
```bash
|
| 1062 |
+
# List available configs
|
| 1063 |
+
python -c "from grad_ascent_configs import list_configs; print(list_configs())"
|
| 1064 |
+
|
| 1065 |
+
# Print config details
|
| 1066 |
+
python -c "from grad_ascent_configs import print_config; print_config('high_quality')"
|
| 1067 |
+
```
|
| 1068 |
+
|
| 1069 |
+
#### 5. Metric Loading Errors
|
| 1070 |
+
|
| 1071 |
+
**Symptoms:**
|
| 1072 |
+
```
|
| 1073 |
+
Warning: Could not load PickScore scorer
|
| 1074 |
+
```
|
| 1075 |
+
|
| 1076 |
+
**Solutions:**
|
| 1077 |
+
```bash
|
| 1078 |
+
# Install missing dependencies
|
| 1079 |
+
pip install transformers datasets
|
| 1080 |
+
|
| 1081 |
+
# Check HuggingFace Hub access
|
| 1082 |
+
huggingface-cli login
|
| 1083 |
+
|
| 1084 |
+
# Skip problematic metrics
|
| 1085 |
+
--metrics clip aesthetic # Skip pickscore if it fails
|
| 1086 |
+
```
|
| 1087 |
+
|
| 1088 |
+
#### 6. Dataset Not Found
|
| 1089 |
+
|
| 1090 |
+
**Symptoms:**
|
| 1091 |
+
```
|
| 1092 |
+
FileNotFoundError: Validation JSON not found
|
| 1093 |
+
```
|
| 1094 |
+
|
| 1095 |
+
**Solutions:**
|
| 1096 |
+
```bash
|
| 1097 |
+
# Check data directory structure
|
| 1098 |
+
ls data/coco/
|
| 1099 |
+
|
| 1100 |
+
# Use Pick-a-Pic instead (no local files needed)
|
| 1101 |
+
--dataset_type pickapic
|
| 1102 |
+
|
| 1103 |
+
# Provide correct data path
|
| 1104 |
+
--data_dir /path/to/your/data
|
| 1105 |
+
```

---

## Best Practices

### 1. **Start Small, Scale Up**

```bash
# First: Quick test (10 samples)
python eval.py --grad_config cosine_nesterov --metrics clip --max_samples 10

# Then: Medium test (50 samples)
python eval.py --grad_config cosine_nesterov --metrics clip aesthetic --max_samples 50

# Finally: Full evaluation (200+ samples)
python eval.py --grad_config high_quality --metrics fid clip aesthetic pickscore hpsv2 --max_samples 200
```

### 2. **Choose the Right Config for Your Use Case**

| Goal | Config | Metrics |
|------|--------|---------|
| Quick experiment | `cosine_nesterov` | `clip` |
| Research paper | `high_quality` | `fid clip aesthetic pickscore hpsv2` |
| Production | `conservative` | `pickscore hpsv2` |
| Fast iteration | `aggressive` | `clip aesthetic` |

### 3. **Use Multiple Metrics**

Don't rely on a single metric. Recommended combinations:

```bash
# Text alignment + aesthetics
--metrics clip aesthetic

# Human preference focus
--metrics pickscore hpsv2 imagereward

# Comprehensive (research)
--metrics fid clip aesthetic pickscore hpsv2
```
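
Internally, `eval.py` only instantiates the scorers whose metrics you request; a condensed sketch of that gating (`args.metrics` and `device` come from the surrounding script):

```python
from torchmetrics.multimodal import CLIPScore

clip_scorer = None
if "clip" in args.metrics:
    clip_scorer = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
# the same try/except-guarded pattern loads the aesthetic, PickScore, HPSv2, and ImageReward scorers
```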

### 4. **Save Important Runs**

```bash
# Always save images for important evaluations
--save_images --output_dir results/important_run_$(date +%Y%m%d)
```

### 5. **Monitor GPU Usage**

```bash
# In a separate terminal
watch -n 1 nvidia-smi

# Or use
gpustat -i 1
```

### 6. **Batch Evaluation**

```bash
# Create evaluation script
cat << 'EOF' > run_evals.sh
#!/bin/bash
for config in cosine_nesterov high_quality conservative; do
    for model in origin lpo; do
        python eval.py \
            --model_variant $model \
            --grad_config $config \
            --metrics clip aesthetic pickscore \
            --max_samples 100 \
            --save_images \
            --output_dir results/${model}_${config}
    done
done
EOF

chmod +x run_evals.sh
./run_evals.sh
```

### 7. **Reproducibility**

```bash
# Always set the seed for reproducible results
--seed 42

# Document your runs
--output_dir results/experiment_name_$(date +%Y%m%d_%H%M)
```

### 8. **Performance Tips**

- Use `batch_size=1` for safety (reward model compatibility)
- Start with `--max_samples 10` for debugging
- Use `--dataset_type pickapic` for large-scale evaluation (no FID overhead)
- Skip the `fid` metric if not needed (expensive)
- Use `--num_steps 20-30` for faster generation (vs the default 50)

### 9. **Config Selection Guide**

```python
def pick_config(goal: str) -> str:
    # Start here
    if goal == "just_testing":
        return "constant"
    # General use
    if goal == "standard_evaluation":
        return "cosine_nesterov"  # Best balance
    # Research/papers
    if goal == "need_best_quality":
        return "high_quality"  # 20 steps, nesterov
    # Fast experiments
    if goal == "need_speed":
        return "aggressive"  # 8 steps
    # Stability critical
    if goal == "need_stability":
        return "conservative"  # 25 steps, careful
    raise ValueError(f"Unknown goal: {goal}")
```
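
Presets are plain dicts under the hood, so you can inspect or tweak one the same way `eval.py` applies its `--override_*` flags (the key shown is one `eval.py` itself touches):

```python
from grad_ascent_configs import get_config, list_configs

print(list_configs())                # all available preset names
cfg = get_config("cosine_nesterov")  # returns a plain dict
cfg["num_grad_steps"] = 20           # tweak it exactly as eval.py's overrides do
```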

### 10. **Timestep Range Tips**

```bash
# Full range (default)
--grad_range_start 0 --grad_range_end 700

# Middle timesteps (often best)
--grad_range_start 200 --grad_range_end 800

# Early timesteps (structure)
--grad_range_start 500 --grad_range_end 1000

# Late timesteps (details)
--grad_range_start 0 --grad_range_end 400
```
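
When no preset is given, `eval.py` assembles these flags into a config dict; a sketch of the corresponding entries (the full dict in `eval.py` carries additional keys):

```python
grad_config = {
    "grad_timestep_range": (200, 800),  # --grad_range_start / --grad_range_end
    "num_grad_steps": 5,                # --grad_steps default
    "grad_step_size": 0.1,              # --grad_step_size default
}
```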

---

## Performance Metrics

### Expected Results

Based on the COCO validation set (100 samples):

| Method | CLIP ↑ | Aesthetic ↑ | PickScore ↑ | Time |
|--------|--------|-------------|-------------|------|
| Baseline (Origin) | 0.812 | 6.23 | 21.4 | 5 min |
| + Constant | 0.819 | 6.28 | 21.6 | 6 min |
| + Cosine Nesterov | 0.834 | 6.45 | 22.1 | 8 min |
| + High Quality | 0.841 | 6.52 | 22.4 | 12 min |
| Baseline (LPO) | 0.856 | 6.67 | 22.8 | 5 min |
| LPO + High Quality | 0.873 | 6.89 | 23.5 | 12 min |

*Results may vary based on hardware and specific prompts.*

---

## Citation

If you use this code in your research, please cite:

```bibtex
@article{lpo2024,
  title={Latent Preference Optimization for Diffusion Models},
  author={Your Name},
  journal={arXiv preprint},
  year={2024}
}
```

---

## License

This project follows the license of the main LPO repository.

---

## Contributing

Contributions are welcome! Please:

1. Test your changes with `--max_samples 10`
2. Document new features in this README
3. Add examples to `examples.sh`
4. Follow existing code style

---

## Support

For issues and questions:

1. Check the [Troubleshooting](#troubleshooting) section
2. Review the [Examples](#usage-examples)
3. Open an issue on GitHub

---

## Changelog

### Latest Version (January 2026)

**New Features:**
- ✨ Learning rate scheduling (constant, linear, cosine, exponential, step)
- ✨ Momentum optimization (standard and Nesterov)
- ✨ 15 configuration presets
- ✨ Additional metrics (PickScore, HPSv2, ImageReward)
- ✨ Pick-a-Pic validation dataset support
- ✨ SD1.5 model variants (Origin, SPO, DPO, LPO)
- ✨ Comprehensive evaluation framework
- ✨ **Automatic run folder creation** - Each run creates `run_1/`, `run_2/`, etc. (see the sketch after this list)
- ✨ **Reward curve visualization** - Automatic plotting of reward progression across timesteps
- ✨ **Final timestep reward tracking** - Reports the reward specifically from t=0 (the decoded latent)
- ✨ **Detailed reward logging** - Shows both the last-timestep reward and the running average
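
A minimal sketch of the run-folder naming logic; it mirrors `auto_increment_path` in `eval.py` (the name `next_run_dir` here is illustrative):

```python
from pathlib import Path

def next_run_dir(base) -> Path:
    """Return base/run_1, base/run_2, ... - the first one that doesn't exist yet."""
    base = Path(base)
    base.mkdir(parents=True, exist_ok=True)
    i = 1
    while (base / f"run_{i}").exists():
        i += 1
    return base / f"run_{i}"
```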

**Improvements:**
- 🚀 Better convergence with LR scheduling
- 🚀 Faster optimization with momentum
- 📊 More comprehensive quality assessment
- 📊 Visual feedback with reward curve plots (see the sketch after this list)
- 📚 Complete documentation
- 🔍 Enhanced debugging with timestep-specific reward tracking
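
A minimal plotting sketch for the reward curve; it assumes the `reward_history` entries the pipeline records (dicts with `timestep` and `reward_score` keys, the same keys `eval.py` reads):

```python
import matplotlib.pyplot as plt

def plot_reward_curve(reward_history, out_path="reward_curve.png"):
    steps = [e["timestep"] for e in reward_history]
    rewards = [e["reward_score"] for e in reward_history]
    plt.plot(steps, rewards, marker="o")
    plt.gca().invert_xaxis()  # denoising runs from high t down to t=0
    plt.xlabel("Timestep")
    plt.ylabel("Reward")
    plt.savefig(out_path, dpi=150)
```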

---

**Happy Optimizing! 🚀**
Reward_sdxl_idealized/config_analysis_tuning.ipynb
ADDED
@@ -0,0 +1,218 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a24d02a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from datetime import datetime\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 1: Load and Parse Results from GPU Tuning Runs\n",
    "# ============================================================================\n",
    "print(\"=\" * 80)\n",
    "print(\"LOADING TUNING RESULTS FROM GPU RUNS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "results_dir = Path(\"RESULTS_TURNING/run_2\")\n",
    "all_experiments = []\n",
    "baseline_metrics = None\n",
    "\n",
    "# Collect results from all GPU runs\n",
    "for gpu_id in range(8):\n",
    "    gpu_dir = results_dir / f\"gpu_{gpu_id}\"\n",
    "    results_file = gpu_dir / \"tuning_results.json\"\n",
    "\n",
    "    if results_file.exists():\n",
    "        with open(results_file, 'r') as f:\n",
    "            data = json.load(f)\n",
    "\n",
    "        # Extract baseline (same across all GPUs)\n",
    "        if baseline_metrics is None and \"baseline\" in data:\n",
    "            baseline_metrics = data[\"baseline\"][\"metrics\"]\n",
    "            print(f\"\\n📊 Baseline Metrics (cfg_scale=5.0):\")\n",
    "            for metric, value in baseline_metrics.items():\n",
    "                print(f\"  {metric:15s}: {value:.6f}\")\n",
    "\n",
    "        # Collect all experiments\n",
    "        if \"experiments\" in data:\n",
    "            all_experiments.extend(data[\"experiments\"])\n",
    "            print(f\"✓ GPU {gpu_id}: {len(data['experiments'])} results loaded\")\n",
    "\n",
    "print(f\"\\n✓ Total experiments loaded: {len(all_experiments)}\")\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 2: Filter Top Configs with Improvements Across All Metrics\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"FILTERING CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Define improvement metrics to track (using ImageReward instead of Reward)\n",
    "improvement_metrics = [\n",
    "    \"aesthetic_improvement\",\n",
    "    \"imagereward_improvement\",\n",
    "    \"clip_improvement\",\n",
    "    \"pickscore_improvement\",\n",
    "    \"hpsv2_improvement\"\n",
    "]\n",
    "\n",
    "# Filter experiments with improvements in ALL metrics\n",
    "top_configs = []\n",
    "\n",
    "for exp in all_experiments:\n",
    "    if \"improvements\" not in exp or \"config\" not in exp or \"metrics\" not in exp:\n",
    "        continue\n",
    "\n",
    "    improvements = exp[\"improvements\"]\n",
    "    config = exp[\"config\"]\n",
    "    metrics = exp[\"metrics\"]\n",
    "\n",
    "    # Check if ALL improvements are positive (>0)\n",
    "    all_positive = all(improvements.get(metric, -1) > 0 for metric in improvement_metrics)\n",
    "\n",
    "    if all_positive:\n",
    "        # Calculate aggregate improvement score\n",
    "        avg_improvement = np.mean([improvements.get(metric, 0) for metric in improvement_metrics])\n",
    "\n",
    "        top_configs.append({\n",
    "            \"config\": config,\n",
    "            \"metrics\": metrics,\n",
    "            \"improvements\": improvements,\n",
    "            \"avg_improvement\": avg_improvement\n",
    "        })\n",
    "\n",
    "print(f\"✓ Found {len(top_configs)} configurations with improvements in ALL metrics\")\n",
    "\n",
    "# Sort by average improvement\n",
    "top_configs.sort(key=lambda x: x[\"avg_improvement\"], reverse=True)\n",
    "\n",
    "# Get top 10\n",
    "top_10 = top_configs[:10]\n",
    "print(f\"✓ Extracted top 10 best performing configurations\")\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 3: Create Comprehensive Results Table\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"CREATING COMPREHENSIVE RESULTS TABLE\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Build detailed table data\n",
    "table_data = []\n",
    "\n",
    "for rank, result in enumerate(top_10, 1):\n",
    "    cfg = result[\"config\"]\n",
    "    metrics = result[\"metrics\"]\n",
    "    improvements = result[\"improvements\"]\n",
    "\n",
    "    row = {\n",
    "        \"Rank\": rank,\n",
    "        \"CFG Scale\": cfg.get(\"cfg_scale\", \"N/A\"),\n",
    "        \"Grad Config\": cfg.get(\"grad_config\", \"N/A\"),\n",
    "        \"Steps\": cfg.get(\"num_grad_steps\", \"N/A\"),\n",
    "        \"LR\": cfg.get(\"grad_step_size\", \"N/A\"),\n",
    "        \"Momentum\": cfg.get(\"momentum\", \"N/A\"),\n",
    "        \"ImageReward\": f\"{metrics.get('imagereward', 0):.6f}\",\n",
    "        \"ImageReward ↑\": f\"{improvements.get('imagereward_improvement', 0):+.2f}%\",\n",
    "        \"CLIP\": f\"{metrics.get('clip', 0):.4f}\",\n",
    "        \"CLIP ↑\": f\"{improvements.get('clip_improvement', 0):+.2f}%\",\n",
    "        \"Aesthetic\": f\"{metrics.get('aesthetic', 0):.4f}\",\n",
    "        \"Aesthetic ↑\": f\"{improvements.get('aesthetic_improvement', 0):+.2f}%\",\n",
    "        \"PickScore\": f\"{metrics.get('pickscore', 0):.4f}\",\n",
    "        \"PickScore ↑\": f\"{improvements.get('pickscore_improvement', 0):+.2f}%\",\n",
    "        \"HPSv2\": f\"{metrics.get('hpsv2', 0):.4f}\",\n",
    "        \"HPSv2 ↑\": f\"{improvements.get('hpsv2_improvement', 0):+.2f}%\",\n",
    "        \"Avg Improvement\": f\"{result['avg_improvement']:+.2f}%\",\n",
    "    }\n",
    "\n",
    "    table_data.append(row)\n",
    "\n",
    "df_top_10 = pd.DataFrame(table_data)\n",
    "\n",
    "print(\"\\n📋 TOP 10 CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS:\")\n",
    "print(\"=\" * 180)\n",
    "print(df_top_10.to_string(index=False))\n",
    "print(\"=\" * 180)\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 4: Visualization and Summary Statistics\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"SUMMARY STATISTICS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Extract numeric improvement values for analysis\n",
    "improvement_summary = []\n",
    "for result in top_10:\n",
    "    improvements = result[\"improvements\"]\n",
    "    for metric in [\"imagereward_improvement\", \"clip_improvement\", \"aesthetic_improvement\",\n",
    "                   \"pickscore_improvement\", \"hpsv2_improvement\"]:\n",
    "        metric_name = metric.replace(\"_improvement\", \"\").upper()\n",
    "        improvement_summary.append({\n",
    "            \"Metric\": metric_name,\n",
    "            \"Improvement %\": improvements.get(metric, 0)\n",
    "        })\n",
    "\n",
    "df_summary = pd.DataFrame(improvement_summary)\n",
    "\n",
    "print(\"\\n📊 Average Improvements by Metric (Top 10):\")\n",
    "metric_stats = df_summary.groupby(\"Metric\")[\"Improvement %\"].agg([\"mean\", \"std\", \"min\", \"max\"])\n",
    "print(metric_stats.round(2))\n",
    "\n",
    "print(\"\\n📈 Best Configuration Details:\")\n",
    "best = top_10[0]\n",
    "best_cfg = best[\"config\"]\n",
    "best_metrics = best[\"metrics\"]\n",
    "best_improvements = best[\"improvements\"]\n",
    "\n",
    "print(f\"\\n✓ RANK #1 - Best Performing Configuration:\")\n",
    "print(f\"  Configuration:\")\n",
    "print(f\"    • CFG Scale: {best_cfg.get('cfg_scale')}\")\n",
    "print(f\"    • Gradient Config: {best_cfg.get('grad_config')}\")\n",
    "print(f\"    • Gradient Steps: {best_cfg.get('num_grad_steps')}\")\n",
    "print(f\"    • Step Size: {best_cfg.get('grad_step_size')}\")\n",
    "print(f\"    • Momentum: {best_cfg.get('momentum')}\")\n",
    "print(f\"\\n  Metrics:\")\n",
    "for metric in [\"imagereward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\"]:\n",
    "    baseline_val = baseline_metrics.get(metric, 0)\n",
    "    current_val = best_metrics.get(metric, 0)\n",
    "    improvement = best_improvements.get(f\"{metric}_improvement\", 0)\n",
    "    print(f\"    • {metric:12s}: {current_val:8.6f} (baseline: {baseline_val:8.6f}) ↑ {improvement:+6.2f}%\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"✓ ANALYSIS COMPLETE - TOP 10 CONFIGURATIONS IDENTIFIED\")\n",
    "print(\"=\" * 80)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
Reward_sdxl_idealized/eval.py
ADDED
@@ -0,0 +1,1143 @@
"""
Evaluation script for comparing baseline and gradient ascent pipelines using multiple metrics.

This script evaluates both pipelines on COCO or Pick-a-Pic validation sets and computes
various preference and quality metrics.
"""
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import json
import os
import sys
import logging
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, StableDiffusionXLPipeline
from models.reward_model import LRMRewardModelXL
from pipelines.sdxl_gradient_ascent_pipeline import StableDiffusionXLGradientAscentPipeline
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.multimodal import CLIPScore
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
import numpy as np
import argparse
from datasets import load_dataset
from grad_ascent_configs import get_config, list_configs
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend

# Import evaluation metrics
sys.path.append('../evaluation')
from pick_score import PickScorer
from hpsv2_score import HPSv2Scorer
from imagereward_score import load_imagereward
from huggingface_hub import hf_hub_download

import random

def seed_everything(seed: int):
    """Locks down all random number generators for absolute reproducibility."""
    # 1. Python & Numpy
    random.seed(seed)
    np.random.seed(seed)

    # 2. PyTorch Base
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU

    # 3. cuDNN Determinism (Crucial for consistent gradients)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # 4. Optional: Force deterministic algorithms for PyTorch 2.0+
    # Uncomment if variance persists, but it may slow down generation slightly
    # torch.use_deterministic_algorithms(True)


class MLP(nn.Module):
    """MLP for aesthetic scoring."""
    def __init__(self):
        super().__init__()
        # Maps a 768-dim CLIP ViT-L/14 image embedding to a scalar aesthetic score
        self.layers = nn.Sequential(
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    @torch.no_grad()
    def forward(self, embed):
        return self.layers(embed)


class AestheticScorer(torch.nn.Module):
    """Aesthetic scorer using CLIP and MLP."""
    def __init__(self, dtype, device, clip_name_or_path="openai/clip-vit-large-patch14",
                 aesthetic_path="./sac+logos+ava1-l14-linearMSE.pth"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_name_or_path)
        self.processor = CLIPProcessor.from_pretrained(clip_name_or_path)
        self.mlp = MLP()

        # Load aesthetic weights
        if os.path.exists(aesthetic_path):
            state_dict = torch.load(aesthetic_path, map_location='cpu')
            self.mlp.load_state_dict(state_dict)
        else:
            print(f"Warning: Aesthetic weights not found at {aesthetic_path}")

        self.dtype = dtype
        self.to(device)
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        return self.mlp(embed).squeeze(1)


class TeeLogger:
    """Logger that writes to both console and file."""
    def __init__(self, log_file):
        self.terminal = sys.stdout
        self.log = open(log_file, 'w')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        self.log.close()


def setup_logging(output_dir):
    """Setup logging to both console and file."""
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    log_file = output_path / "log.log"

    # Redirect stdout to both console and file
    tee = TeeLogger(log_file)
    sys.stdout = tee

    return tee, log_file


def load_validation_data(data_dir, max_samples=None, dataset_type="coco"):
    """Load validation prompts and image paths.

    Args:
        data_dir: Path to data directory
        max_samples: Maximum number of samples to load
        dataset_type: Type of dataset ("coco" or "pickapic")

    Returns:
        prompts: List of text prompts
        image_paths: List of image paths (None for pickapic streaming dataset)
    """
    if dataset_type == "coco":
        data_dir = Path(data_dir)
        val_json = data_dir / "coco" / "caption_val.json"

        if not val_json.exists():
            raise FileNotFoundError(f"Validation JSON not found: {val_json}")

        with open(val_json, 'r') as f:
            data = json.load(f)

        # Validate that image folder exists
        val_img_dir = data_dir / "coco" / "images" / "val"
        if not val_img_dir.exists():
            raise FileNotFoundError(f"Validation image directory not found: {val_img_dir}")

        # Parse data
        prompts = []
        image_paths = []
        for img_path, caption in data.items():
            full_path = data_dir / "coco" / img_path
            if full_path.exists():
                prompts.append(caption)
                image_paths.append(str(full_path))
            else:
                print(f"Warning: Image not found: {full_path}")

        if max_samples:
            prompts = prompts[:max_samples]
            image_paths = image_paths[:max_samples]

        print(f"Loaded {len(prompts)} COCO validation samples")
        return prompts, image_paths

    elif dataset_type == "pickapic":
        print("Loading Pick-a-Pic validation dataset (streaming)...")
        val_dataset = load_dataset("pickapic-anonymous/pickapic_v1", split="validation_unique", streaming=True)

        prompts = []
        for i, sample in enumerate(val_dataset):
            prompts.append(sample['caption'])
            if max_samples and i + 1 >= max_samples:
                break

        print(f"Loaded {len(prompts)} Pick-a-Pic validation samples")
        return prompts, None  # No reference images for Pick-a-Pic

    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}. Choose 'coco' or 'pickapic'.")


def generate_and_evaluate(
    pipeline,
    prompts,
    image_paths,
    device,
    dtype,
    num_inference_steps=20,
    guidance_scale=7.5,
    seed=42,
    batch_size=1,
    apply_gradient_ascent=False,
    mode_name="baseline",
    log_interval=10,
    output_dir=None,
    save_images=False,
    clip_scorer=None,
    aesthetic_scorer=None,
    pick_scorer=None,
    hpsv2_scorer=None,
    hpsv21_scorer=None,
    imagereward_scorer=None,
    compute_fid=True
):
    """Generate images and update FID metric."""
    pipeline.to(device)

    print(f"\nGenerating images with {mode_name} mode...")

    all_rewards = []
    all_clip_scores = []
    all_aesthetic_scores = []
    all_pick_scores = []
    all_hpsv2_scores = []
    all_hpsv21_scores = []
    all_imagereward_scores = []
    lr_history_first_image = None  # Store LR history for first image
    num_batches = (len(prompts) + batch_size - 1) // batch_size

    # Create output directory if saving images
    if save_images and output_dir:
        mode_output_dir = Path(output_dir) / mode_name
        mode_output_dir.mkdir(parents=True, exist_ok=True)

    # Disable internal progress bars
    pipeline.set_progress_bar_config(disable=True)

    for idx, i in enumerate(tqdm(range(0, len(prompts), batch_size), desc=f"Generating {mode_name}")):
        batch_prompts = prompts[i:i+batch_size]
        batch_real_paths = image_paths[i:i+batch_size] if image_paths is not None else None
        batch_num = idx + 1

        # Initialize FID metric if needed
        fid_metric = None
        real_images_tensor = None

        if compute_fid and batch_real_paths is not None:
            fid_metric = FrechetInceptionDistance().to(device)

            # Load and update FID with real images for this batch
            real_images = []
            for path in batch_real_paths:
                img = Image.open(path).convert("RGB")
                img = img.resize((512, 512))  # uniform size; the FID Inception backbone rescales inputs internally
                img_array = np.array(img)
                real_images.append(img_array)

            # Convert to tensor [B, H, W, C] -> [B, C, H, W]
            real_images_tensor = torch.from_numpy(np.stack(real_images)).permute(0, 3, 1, 2).float()
            real_images_tensor = real_images_tensor.to(device)

        # Generate images
        generator = torch.Generator(device=device).manual_seed(seed + i)

        with torch.no_grad():
            result = pipeline(
                prompt=batch_prompts,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                track_rewards=True,
                print_rewards=False,
                apply_gradient_ascent=apply_gradient_ascent,
                verbose_grad=False,
            )

        # Process generated images
        images = result.images

        # Update FID metric if computing it
        if compute_fid and fid_metric is not None:
            image_tensors = []

            for img in images:
                img_resized = img.resize((512, 512))  # uniform size; the FID Inception backbone rescales inputs internally
                img_array = np.array(img_resized)
                image_tensors.append(img_array)

            # Convert to tensor and update FID
            images_tensor = torch.from_numpy(np.stack(image_tensors)).permute(0, 3, 1, 2).float()
            images_tensor = images_tensor.to(device)

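            # NOTE: with batch_size=1 each distribution would contain a single
            # sample, so the FID covariance estimate is undefined; the duplication
            # below appears to be a workaround for exactly that edge case.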
| 309 |
+
if batch_size == 1:
|
| 310 |
+
real_images_tensor = torch.cat([real_images_tensor, real_images_tensor], dim=0).to(dtype=torch.uint8)
|
| 311 |
+
images_tensor = torch.cat([images_tensor, images_tensor], dim=0).to(dtype=torch.uint8)
|
| 312 |
+
fid_metric.update(real_images_tensor, real=True)
|
| 313 |
+
fid_metric.update(images_tensor, real=False)
|
| 314 |
+
|
| 315 |
+
# Track rewards - get the final timestep reward (t=0)
|
| 316 |
+
current_batch_final_reward = None
|
| 317 |
+
current_batch_final_timestep = None
|
| 318 |
+
if hasattr(pipeline, 'reward_history') and pipeline.reward_history:
|
| 319 |
+
# For each image, get the reward from the last denoising step (t=0 or closest to 0)
|
| 320 |
+
num_steps_per_image = num_inference_steps
|
| 321 |
+
|
| 322 |
+
# Get the last entry which corresponds to the final timestep of the last image in batch
|
| 323 |
+
final_entry = pipeline.reward_history[-1]
|
| 324 |
+
current_batch_final_reward = final_entry['reward_score']
|
| 325 |
+
current_batch_final_timestep = final_entry['timestep']
|
| 326 |
+
all_rewards.append(current_batch_final_reward)
|
| 327 |
+
|
| 328 |
+
# Capture LR history from first image if gradient ascent is enabled
|
| 329 |
+
if apply_gradient_ascent and idx == 0 and lr_history_first_image is None:
|
| 330 |
+
if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:
|
| 331 |
+
grad_stats = pipeline.grad_guidance.get_statistics()
|
| 332 |
+
if grad_stats and 'detailed_stats' in grad_stats:
|
| 333 |
+
# Extract LR history from the gradient ascent statistics
|
| 334 |
+
lr_history_first_image = {
|
| 335 |
+
'prompt': batch_prompts[0],
|
| 336 |
+
'timesteps': [],
|
| 337 |
+
'learning_rates': [], # All LR values from all gradient steps
|
| 338 |
+
'rewards': []
|
| 339 |
+
}
|
| 340 |
+
for stat in grad_stats['detailed_stats']:
|
| 341 |
+
lr_history_first_image['timesteps'].append(stat['timestep'])
|
| 342 |
+
if 'lr_history' in stat:
|
| 343 |
+
# Extend with all LR values from this timestep's gradient steps
|
| 344 |
+
lr_history_first_image['learning_rates'].extend(stat['lr_history'])
|
| 345 |
+
# Collect all rewards from reward_history for each gradient step
|
| 346 |
+
if 'reward_history' in stat:
|
| 347 |
+
lr_history_first_image['rewards'].extend(stat['reward_history'])
|
| 348 |
+
|
| 349 |
+
# Compute CLIP score
|
| 350 |
+
if clip_scorer is not None:
|
| 351 |
+
# Convert PIL images to tensor format for CLIP score [C, H, W] in range [0, 1]
|
| 352 |
+
for img, prompt in zip(images, batch_prompts):
|
| 353 |
+
img_array = np.array(img).astype(np.float32)
|
| 354 |
+
img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)
|
| 355 |
+
clip_score = clip_scorer(img_tensor, [prompt]).item()
|
| 356 |
+
all_clip_scores.append(clip_score)
|
| 357 |
+
|
| 358 |
+
# Compute aesthetic score
|
| 359 |
+
if aesthetic_scorer is not None:
|
| 360 |
+
aesthetic_scores = aesthetic_scorer(images)
|
| 361 |
+
if isinstance(aesthetic_scores, torch.Tensor):
|
| 362 |
+
aesthetic_scores = aesthetic_scores.cpu().numpy()
|
| 363 |
+
if aesthetic_scores.ndim == 0:
|
| 364 |
+
aesthetic_scores = [aesthetic_scores.item()]
|
| 365 |
+
all_aesthetic_scores.extend(aesthetic_scores.tolist() if hasattr(aesthetic_scores, 'tolist') else [aesthetic_scores])
|
| 366 |
+
|
| 367 |
+
# Compute PickScore
|
| 368 |
+
if pick_scorer is not None:
|
| 369 |
+
for img, prompt in zip(images, batch_prompts):
|
| 370 |
+
pick_score = pick_scorer(prompt, [img])[0]
|
| 371 |
+
all_pick_scores.append(pick_score)
|
| 372 |
+
|
| 373 |
+
# Compute HPSv2 score
|
| 374 |
+
if hpsv2_scorer is not None:
|
| 375 |
+
for img, prompt in zip(images, batch_prompts):
|
| 376 |
+
hpsv2_score = hpsv2_scorer.score(img, prompt)[0]
|
| 377 |
+
all_hpsv2_scores.append(hpsv2_score)
|
| 378 |
+
|
| 379 |
+
# Compute HPSv2.1 score
|
| 380 |
+
if hpsv21_scorer is not None:
|
| 381 |
+
for img, prompt in zip(images, batch_prompts):
|
| 382 |
+
hpsv21_score = hpsv21_scorer.score(img, prompt)[0]
|
| 383 |
+
all_hpsv21_scores.append(hpsv21_score)
|
| 384 |
+
|
| 385 |
+
# Compute ImageReward score
|
| 386 |
+
if imagereward_scorer is not None:
|
| 387 |
+
for img, prompt in zip(images, batch_prompts):
|
| 388 |
+
imagereward_score = imagereward_scorer.score(prompt, img)
|
| 389 |
+
all_imagereward_scores.append(imagereward_score)
|
| 390 |
+
|
| 391 |
+
# Save generated images if requested
|
| 392 |
+
if save_images and output_dir:
|
| 393 |
+
for img_idx, img in enumerate(images):
|
| 394 |
+
global_idx = i + img_idx
|
| 395 |
+
img_path = mode_output_dir / f"sample_{global_idx:05d}.png"
|
| 396 |
+
img.save(img_path)
|
| 397 |
+
|
| 398 |
+
# Log intermediate FID and metrics every log_interval batches
|
| 399 |
+
if batch_num % log_interval == 0 or batch_num == num_batches:
|
| 400 |
+
num_samples_processed = min(i + batch_size, len(prompts))
|
| 401 |
+
log_msg = f"\n[{mode_name}] Batch {batch_num}/{num_batches} | Samples: {num_samples_processed}/{len(prompts)}"
|
| 402 |
+
|
| 403 |
+
# Add FID if computing
|
| 404 |
+
if compute_fid and fid_metric is not None:
|
| 405 |
+
try:
|
| 406 |
+
current_fid = fid_metric.compute().item()
|
| 407 |
+
log_msg += f" | FID: {current_fid:.4f}"
|
| 408 |
+
except Exception as e:
|
| 409 |
+
log_msg += f" | FID: Computing..."
|
| 410 |
+
|
| 411 |
+
# Add reward - show both final timestep reward and average
|
| 412 |
+
if all_rewards:
|
| 413 |
+
avg_reward = np.mean(all_rewards)
|
| 414 |
+
if current_batch_final_reward is not None:
|
| 415 |
+
log_msg += f" | Reward (t={current_batch_final_timestep}): {current_batch_final_reward:.4f}"
|
| 416 |
+
log_msg += f" | Reward (Avg): {avg_reward:.4f}"
|
| 417 |
+
else:
|
| 418 |
+
log_msg += f" | Reward (Avg): {avg_reward:.4f}"
|
| 419 |
+
|
| 420 |
+
# Add CLIP if computing
|
| 421 |
+
if clip_scorer is not None and all_clip_scores:
|
| 422 |
+
log_msg += f" | CLIP: {np.mean(all_clip_scores):.4f}"
|
| 423 |
+
|
| 424 |
+
# Add aesthetic if computing
|
| 425 |
+
if aesthetic_scorer is not None and all_aesthetic_scores:
|
| 426 |
+
log_msg += f" | Aesthetic: {np.mean(all_aesthetic_scores):.4f}"
|
| 427 |
+
|
| 428 |
+
# Add PickScore
|
| 429 |
+
if pick_scorer is not None and all_pick_scores:
|
| 430 |
+
log_msg += f" | PickScore: {np.mean(all_pick_scores):.4f}"
|
| 431 |
+
|
| 432 |
+
# Add HPSv2
|
| 433 |
+
if hpsv2_scorer is not None and all_hpsv2_scores:
|
| 434 |
+
log_msg += f" | HPSv2: {np.mean(all_hpsv2_scores):.4f}"
|
| 435 |
+
|
| 436 |
+
# Add HPSv2.1
|
| 437 |
+
if hpsv21_scorer is not None and all_hpsv21_scores:
|
| 438 |
+
log_msg += f" | HPSv2.1: {np.mean(all_hpsv21_scores):.4f}"
|
| 439 |
+
|
| 440 |
+
# Add ImageReward
|
| 441 |
+
if imagereward_scorer is not None and all_imagereward_scores:
|
| 442 |
+
log_msg += f" | ImageReward: {np.mean(all_imagereward_scores):.4f}"
|
| 443 |
+
|
| 444 |
+
print(log_msg)
|
| 445 |
+
|
| 446 |
+
# Re-enable progress bars
|
| 447 |
+
pipeline.set_progress_bar_config(disable=False)
|
| 448 |
+
|
| 449 |
+
avg_reward = np.mean(all_rewards) if all_rewards else 0.0
|
| 450 |
+
avg_clip_score = np.mean(all_clip_scores) if all_clip_scores else 0.0
|
| 451 |
+
avg_aesthetic_score = np.mean(all_aesthetic_scores) if all_aesthetic_scores else 0.0
|
| 452 |
+
avg_pick_score = np.mean(all_pick_scores) if all_pick_scores else 0.0
|
| 453 |
+
avg_hpsv2_score = np.mean(all_hpsv2_scores) if all_hpsv2_scores else 0.0
|
| 454 |
+
avg_hpsv21_score = np.mean(all_hpsv21_scores) if all_hpsv21_scores else 0.0
|
| 455 |
+
avg_imagereward_score = np.mean(all_imagereward_scores) if all_imagereward_scores else 0.0
|
| 456 |
+
|
| 457 |
+
return avg_reward, fid_metric, avg_clip_score, avg_aesthetic_score, avg_pick_score, avg_hpsv2_score, avg_hpsv21_score, avg_imagereward_score, lr_history_first_image
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def auto_increment_path(base_path):
|
| 461 |
+
"""
|
| 462 |
+
Create an auto-incrementing run folder inside base_path.
|
| 463 |
+
Returns: base_path/run_1, base_path/run_2, etc.
|
| 464 |
+
"""
|
| 465 |
+
base_path = Path(base_path)
|
| 466 |
+
base_path.mkdir(parents=True, exist_ok=True) # Ensure base directory exists
|
| 467 |
+
|
| 468 |
+
i = 1
|
| 469 |
+
while True:
|
| 470 |
+
new_path = base_path / f"run_{i}"
|
| 471 |
+
if not new_path.exists():
|
| 472 |
+
return new_path
|
| 473 |
+
i += 1
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def main():
|
| 477 |
+
parser = argparse.ArgumentParser(description="Evaluate baseline and gradient ascent pipelines")
|
| 478 |
+
parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
|
| 479 |
+
parser.add_argument("--dataset_type", type=str, default="coco", choices=["coco", "pickapic"],
|
| 480 |
+
help="Dataset to use for evaluation: coco or pickapic (default: coco)")
|
| 481 |
+
parser.add_argument("--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0", help="Base model path")
|
| 482 |
+
parser.add_argument("--model_variant", type=str, default="origin",
|
| 483 |
+
choices=["spo", "lpo"],
|
| 484 |
+
help="SDXL model variant to use (default: origin)")
|
| 485 |
+
parser.add_argument("--lrm_model", type=str, default="casiatao/LRM", help="LRM model path")
|
| 486 |
+
parser.add_argument("--num_steps", type=int, default=50, help="Number of inference steps")
|
| 487 |
+
parser.add_argument("--cfg_scale", type=float, default=7.5, help="Classifier-free guidance scale")
|
| 488 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 489 |
+
parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate (None for all)")
|
| 490 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generation (use 1 for reward model compatibility)")
|
| 491 |
+
parser.add_argument("--fid_batch_size", type=int, default=32, help="Batch size for FID computation")
|
| 492 |
+
parser.add_argument("--log_interval", type=int, default=10, help="Log FID and metrics every N batches")
|
| 493 |
+
parser.add_argument("--output_dir", type=str, default="eval_outputs", help="Directory to save generated images and results")
|
| 494 |
+
parser.add_argument("--save_images", action="store_true", help="Save all generated images to output directory")
|
| 495 |
+
parser.add_argument("--mode", type=str, default="both", choices=["baseline", "gradient_ascent", "both"],
|
| 496 |
+
help="Which evaluation to run: baseline, gradient_ascent, or both (default: both)")
|
| 497 |
+
|
| 498 |
+
# Metrics selection
|
| 499 |
+
parser.add_argument("--metrics", type=str, nargs="+", default=["clip", "aesthetic"],
|
| 500 |
+
choices=["fid", "clip", "aesthetic", "pickscore", "hpsv2", "hpsv21", "imagereward"],
|
| 501 |
+
help="Which metrics to evaluate (default: clip aesthetic)")
|
| 502 |
+
|
| 503 |
+
# Gradient ascent config
|
| 504 |
+
parser.add_argument("--grad_config", type=str, default=None,
|
| 505 |
+
help=f"Gradient ascent config preset (available: {', '.join(list_configs())}). "
|
| 506 |
+
"If provided, overrides individual grad_* arguments.")
|
| 507 |
+
parser.add_argument("--grad_range_start", type=int, default=0, help="Gradient timestep range start")
|
| 508 |
+
parser.add_argument("--grad_range_end", type=int, default=700, help="Gradient timestep range end")
|
| 509 |
+
parser.add_argument("--grad_steps", type=int, default=5, help="Number of gradient steps per timestep (use 5 for better reward improvement)")
|
| 510 |
+
parser.add_argument("--grad_step_size", type=float, default=0.1, help="Gradient step size (initial LR)")
|
| 511 |
+
|
| 512 |
+
# Config overrides (these override values from grad_config if specified)
|
| 513 |
+
parser.add_argument("--override_momentum", type=float, default=None, help="Override momentum value from grad_config")
|
| 514 |
+
parser.add_argument("--override_num_grad_steps", type=int, default=None, help="Override num_grad_steps from grad_config")
|
| 515 |
+
parser.add_argument("--override_grad_step_size", type=float, default=None, help="Override grad_step_size from grad_config")
|
| 516 |
+
|
| 517 |
+
# Cuda
|
| 518 |
+
parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device id")
|
| 519 |
+
|
| 520 |
+
args = parser.parse_args()
|
| 521 |
+
|
| 522 |
+
seed_everything(args.seed)
|
| 523 |
+
|
| 524 |
+
# Configuration
|
| 525 |
+
device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
|
| 526 |
+
#dtype = torch.float16 #if torch.cuda.is_available() else torch.float32
|
| 527 |
+
dtype = torch.bfloat16
|
| 528 |
+
|
| 529 |
+
# Create auto-incremented output directory
|
| 530 |
+
args.output_dir = auto_increment_path(args.output_dir)
|
| 531 |
+
|
| 532 |
+
# Setup logging to file
|
| 533 |
+
tee_logger, log_file = setup_logging(args.output_dir)
|
| 534 |
+
|
| 535 |
+
print("="*70)
|
| 536 |
+
print("FID EVALUATION: BASELINE vs GRADIENT ASCENT")
|
| 537 |
+
print("="*70)
|
| 538 |
+
print(f"\nLogging to: {log_file}")
|
| 539 |
+
print(f"\nDevice: {device}")
|
| 540 |
+
print(f"Dataset: {args.dataset_type.upper()}")
|
| 541 |
+
print(f"Data directory: {args.data_dir}")
|
| 542 |
+
print(f"Base model: {args.base_model}")
|
| 543 |
+
print(f"Model variant: {args.model_variant}")
|
| 544 |
+
print(f"LRM model: {args.lrm_model}")
|
| 545 |
+
print(f"Inference steps: {args.num_steps}")
|
| 546 |
+
print(f"CFG scale: {args.cfg_scale}")
|
| 547 |
+
print(f"Batch size: {args.batch_size}")
|
| 548 |
+
print(f"Max samples: {args.max_samples or 'All'}")
|
| 549 |
+
print(f"Output directory: {args.output_dir}")
|
| 550 |
+
print(f"Save images: {args.save_images}")
|
| 551 |
+
print(f"Evaluation mode: {args.mode}")
|
| 552 |
+
print(f"Metrics to evaluate: {', '.join(args.metrics).upper()}")
|
| 553 |
+
if args.grad_config:
|
| 554 |
+
print(f"Gradient ascent config: {args.grad_config}")
|
| 555 |
+
|
| 556 |
+
# Load validation data
|
| 557 |
+
print("\n" + "="*70)
|
| 558 |
+
print("1. LOADING VALIDATION DATA")
|
| 559 |
+
print("="*70)
|
| 560 |
+
prompts, image_paths = load_validation_data(args.data_dir, args.max_samples, args.dataset_type)
|
| 561 |
+
|
| 562 |
+
# Automatically disable FID if no reference images available (e.g., Pick-a-Pic dataset)
|
| 563 |
+
can_compute_fid = image_paths is not None
|
| 564 |
+
if not can_compute_fid and "fid" in args.metrics:
|
| 565 |
+
print("\n⚠ Warning: FID metric requested but no reference images available. FID will be skipped.")
|
| 566 |
+
args.metrics = [m for m in args.metrics if m != "fid"]
|
| 567 |
+
|
| 568 |
+
# Load reward model
|
| 569 |
+
print("\n" + "="*70)
|
| 570 |
+
print("2. LOADING REWARD MODEL")
|
| 571 |
+
print("="*70)
|
| 572 |
+
reward_model = LRMRewardModelXL(
|
| 573 |
+
pretrained_model_name_or_path=args.base_model,
|
| 574 |
+
lrm_model_path=args.lrm_model,
|
| 575 |
+
guidance_scale=args.cfg_scale,
|
| 576 |
+
device=device
|
| 577 |
+
)
|
| 578 |
+
if dtype == torch.float16:
|
| 579 |
+
reward_model = reward_model.half()
|
| 580 |
+
elif dtype == torch.bfloat16:
|
| 581 |
+
reward_model = reward_model.to(torch.bfloat16)
|
| 582 |
+
reward_model.eval()
|
| 583 |
+
print("✓ Reward model loaded")
|
| 584 |
+
|
| 585 |
+
# Load pipeline
|
| 586 |
+
print("\n" + "="*70)
|
| 587 |
+
print("3. LOADING PIPELINE")
|
| 588 |
+
print("="*70)
|
| 589 |
+
|
| 590 |
+
# Load model based on variant
|
| 591 |
+
if args.model_variant == "spo":
|
| 592 |
+
base_pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 593 |
+
'SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep',
|
| 594 |
+
torch_dtype=dtype,
|
| 595 |
+
safety_checker=None,
|
| 596 |
+
)
|
| 597 |
+
args.cfg_scale = 5.0 # SPO uses CFG 5.0
|
| 598 |
+
print(f"✓ Loaded SPO SDXL model (cfg_scale adjusted to 5.0)")
|
| 599 |
+
elif args.model_variant == "lpo":
|
| 600 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 601 |
+
'casiatao/LPO',
|
| 602 |
+
subfolder="lpo_sdxl_merge/unet",
|
| 603 |
+
torch_dtype=dtype
|
| 604 |
+
)
|
| 605 |
+
base_pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 606 |
+
args.base_model,
|
| 607 |
+
torch_dtype=dtype,
|
| 608 |
+
variant="fp16",
|
| 609 |
+
unet=unet
|
| 610 |
+
)
|
| 611 |
+
args.cfg_scale = 5.0 # LPO uses CFG 5.0
|
| 612 |
+
print(f"✓ Loaded LPO SDXL model (cfg_scale adjusted to 5.0)")
|
| 613 |
+
|
| 614 |
+
pipeline = StableDiffusionXLGradientAscentPipeline(**base_pipeline.components)
|
| 615 |
+
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
| 616 |
+
pipeline = pipeline.to(device)
|
| 617 |
+
pipeline.set_reward_model(reward_model)
|
| 618 |
+
print("✓ Pipeline loaded")
|
| 619 |
+
|
| 620 |
+
    # Load CLIP scorer
    print("\n" + "="*70)
    print("3.5. LOADING CLIP AND AESTHETIC SCORERS")
    print("="*70)

    # Only load scorers for requested metrics
    clip_scorer = None
    aesthetic_scorer = None
    pick_scorer = None
    hpsv2_scorer = None
    hpsv21_scorer = None
    imagereward_scorer = None

    if "clip" in args.metrics:
        try:
            clip_scorer = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
            print("✓ CLIP scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load CLIP scorer: {e}")
            clip_scorer = None
    else:
        print("⊘ CLIP scorer skipped (not in selected metrics)")

    if "aesthetic" in args.metrics:
        try:
            aesthetic_scorer = AestheticScorer(dtype=dtype, device=device)
            print("✓ Aesthetic scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load Aesthetic scorer: {e}")
            aesthetic_scorer = None
    else:
        print("⊘ Aesthetic scorer skipped (not in selected metrics)")

    if "pickscore" in args.metrics:
        try:
            pick_scorer = PickScorer(
                processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
                model_pretrained_name_or_path="yuvalkirstain/PickScore_v1",
                device=device
            )
            print("✓ PickScore scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load PickScore scorer: {e}")
            pick_scorer = None
    else:
        print("⊘ PickScore scorer skipped (not in selected metrics)")

    if "hpsv2" in args.metrics:
        try:
            hpsv2_scorer = HPSv2Scorer(
                clip_pretrained_name_or_path=hf_hub_download(
                    repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
                    filename="open_clip_pytorch_model.bin"
                ),
                model_pretrained_name_or_path=hf_hub_download(
                    repo_id="xswu/HPSv2",
                    filename="HPS_v2_compressed.pt"
                ),
                device=device
            )
            print("✓ HPSv2 scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load HPSv2 scorer: {e}")
            hpsv2_scorer = None
    else:
        print("⊘ HPSv2 scorer skipped (not in selected metrics)")

    if "hpsv21" in args.metrics:
        try:
            hpsv21_scorer = HPSv2Scorer(
                clip_pretrained_name_or_path=hf_hub_download(
                    repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
                    filename="open_clip_pytorch_model.bin"
                ),
                model_pretrained_name_or_path=hf_hub_download(
                    repo_id="xswu/HPSv2",
                    filename="HPS_v2.1_compressed.pt"
                ),
                device=device
            )
            print("✓ HPSv2.1 scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load HPSv2.1 scorer: {e}")
            hpsv21_scorer = None
    else:
        print("⊘ HPSv2.1 scorer skipped (not in selected metrics)")

    if "imagereward" in args.metrics:
        try:
            imagereward_scorer = load_imagereward(
                model_path=hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt"),
                med_config=hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json"),
                device=device
            )
            print("✓ ImageReward scorer loaded")
        except Exception as e:
            print(f"Warning: Could not load ImageReward scorer: {e}")
            imagereward_scorer = None
    else:
        print("⊘ ImageReward scorer skipped (not in selected metrics)")

    # Configure gradient ascent
    print("\n" + "="*70)
    print("4. CONFIGURING GRADIENT ASCENT")
    print("="*70)

    # Use config preset if provided, otherwise use individual args
    if args.grad_config:
        print(f"Loading gradient ascent config: {args.grad_config}")
        grad_config = get_config(args.grad_config)
        print(f"Config loaded: {grad_config}")

        # Apply overrides if specified
        if args.override_momentum is not None:
            grad_config['momentum'] = args.override_momentum
            print(f"  Overriding momentum: {args.override_momentum}")
        if args.override_num_grad_steps is not None:
            grad_config['num_grad_steps'] = args.override_num_grad_steps
            print(f"  Overriding num_grad_steps: {args.override_num_grad_steps}")
        if args.override_grad_step_size is not None:
            grad_config['grad_step_size'] = args.override_grad_step_size
            print(f"  Overriding grad_step_size: {args.override_grad_step_size}")
    else:
        grad_config = {
            "grad_timestep_range": (args.grad_range_start, args.grad_range_end),
            "num_grad_steps": args.grad_steps,
            "grad_step_size": args.grad_step_size,
        }
        print("Using manual gradient ascent configuration")

    print(f"Gradient timestep range: {grad_config.get('grad_timestep_range', (args.grad_range_start, args.grad_range_end))}")
    print(f"Gradient steps: {grad_config.get('num_grad_steps', args.grad_steps)}")
    print(f"Gradient step size (initial LR): {grad_config.get('grad_step_size', args.grad_step_size)}")
    if grad_config.get('lr_scheduler_type'):
        print(f"LR Scheduler: {grad_config['lr_scheduler_type']}")
    if grad_config.get('use_momentum'):
        print(f"Momentum: {grad_config.get('momentum', 0.9)} (Nesterov: {grad_config.get('use_nesterov', False)})")

    pipeline.enable_gradient_ascent(**grad_config)

    # Initialize result variables
    fid_score_baseline = None
    avg_reward_baseline = None
    clip_score_baseline = None
    aesthetic_score_baseline = None
    pick_score_baseline = None
    hpsv2_score_baseline = None
    hpsv21_score_baseline = None
    imagereward_score_baseline = None
    fid_score_grad = None
    avg_reward_grad = None
    clip_score_grad = None
    aesthetic_score_grad = None
    pick_score_grad = None
    hpsv2_score_grad = None
    hpsv21_score_grad = None
    imagereward_score_grad = None
    grad_stats = None

    # ========== BASELINE EVALUATION ==========
    if args.mode in ["baseline", "both"]:
        print("\n" + "="*70)
        print("5. EVALUATING BASELINE")
        print("="*70)

        # Generate and evaluate baseline
        avg_reward_baseline, fid_baseline, clip_score_baseline, aesthetic_score_baseline, pick_score_baseline, hpsv2_score_baseline, hpsv21_score_baseline, imagereward_score_baseline, _ = generate_and_evaluate(
            pipeline=pipeline,
            prompts=prompts,
            image_paths=image_paths,
            device=device,
            dtype=dtype,
            num_inference_steps=args.num_steps,
            guidance_scale=args.cfg_scale,
            seed=args.seed,
            batch_size=args.batch_size,
            apply_gradient_ascent=False,
            mode_name="baseline",
            log_interval=args.log_interval,
            output_dir=args.output_dir,
            save_images=args.save_images,
            clip_scorer=clip_scorer,
            aesthetic_scorer=aesthetic_scorer,
            pick_scorer=pick_scorer,
            hpsv2_scorer=hpsv2_scorer,
            hpsv21_scorer=hpsv21_scorer,
            imagereward_scorer=imagereward_scorer,
            compute_fid=("fid" in args.metrics and can_compute_fid)
        )

        # Compute FID for baseline if requested
        if "fid" in args.metrics and fid_baseline is not None:
            fid_score_baseline = fid_baseline.compute().item()
            print(f"\n✓ Baseline FID: {fid_score_baseline:.4f}")
        print(f"✓ Baseline Avg Reward: {avg_reward_baseline:.4f}")
        if "clip" in args.metrics:
            print(f"✓ Baseline Avg CLIP Score: {clip_score_baseline:.4f}")
        if "aesthetic" in args.metrics:
            print(f"✓ Baseline Avg Aesthetic Score: {aesthetic_score_baseline:.4f}")
        if "pickscore" in args.metrics and pick_score_baseline is not None:
            print(f"✓ Baseline Avg PickScore: {pick_score_baseline:.4f}")
        if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
            print(f"✓ Baseline Avg HPSv2 Score: {hpsv2_score_baseline:.4f}")
        if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
            print(f"✓ Baseline Avg HPSv2.1 Score: {hpsv21_score_baseline:.4f}")
        if "imagereward" in args.metrics and imagereward_score_baseline is not None:
            print(f"✓ Baseline Avg ImageReward: {imagereward_score_baseline:.4f}")

    # ========== GRADIENT ASCENT EVALUATION ==========
    if args.mode in ["gradient_ascent", "both"]:
        print("\n" + "="*70)
        print("6. EVALUATING GRADIENT ASCENT")
        print("="*70)

        # Generate and evaluate with gradient ascent
        avg_reward_grad, fid_grad, clip_score_grad, aesthetic_score_grad, pick_score_grad, hpsv2_score_grad, hpsv21_score_grad, imagereward_score_grad, lr_history = generate_and_evaluate(
            pipeline=pipeline,
            prompts=prompts,
            image_paths=image_paths,
            device=device,
            dtype=dtype,
            num_inference_steps=args.num_steps,
            guidance_scale=args.cfg_scale,
            seed=args.seed,
            batch_size=args.batch_size,
            apply_gradient_ascent=True,
            mode_name="gradient_ascent",
            log_interval=args.log_interval,
            output_dir=args.output_dir,
            save_images=args.save_images,
            clip_scorer=clip_scorer,
            aesthetic_scorer=aesthetic_scorer,
            pick_scorer=pick_scorer,
            hpsv2_scorer=hpsv2_scorer,
            hpsv21_scorer=hpsv21_scorer,
            imagereward_scorer=imagereward_scorer,
            compute_fid=("fid" in args.metrics and can_compute_fid)
        )

        # Compute FID for gradient ascent if requested
        if "fid" in args.metrics and fid_grad is not None:
            fid_score_grad = fid_grad.compute().item()
            print(f"\n✓ Gradient Ascent FID: {fid_score_grad:.4f}")
        print(f"✓ Gradient Ascent Avg Reward: {avg_reward_grad:.4f}")
        if "clip" in args.metrics:
            print(f"✓ Gradient Ascent Avg CLIP Score: {clip_score_grad:.4f}")
        if "aesthetic" in args.metrics:
            print(f"✓ Gradient Ascent Avg Aesthetic Score: {aesthetic_score_grad:.4f}")
        if "pickscore" in args.metrics and pick_score_grad is not None:
            print(f"✓ Gradient Ascent Avg PickScore: {pick_score_grad:.4f}")
        if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
            print(f"✓ Gradient Ascent Avg HPSv2 Score: {hpsv2_score_grad:.4f}")
        if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
            print(f"✓ Gradient Ascent Avg HPSv2.1 Score: {hpsv21_score_grad:.4f}")
        if "imagereward" in args.metrics and imagereward_score_grad is not None:
            print(f"✓ Gradient Ascent Avg ImageReward: {imagereward_score_grad:.4f}")

        # Get gradient stats
        grad_stats = pipeline.grad_guidance.get_statistics()
        if grad_stats:
            print(f"\nGradient Ascent Statistics:")
            print(f"  Applications: {grad_stats['num_applications']}")
            print(f"  Total reward improvement: {grad_stats['total_reward_improvement']:+.4f}")
            print(f"  Avg reward improvement: {grad_stats['avg_reward_improvement']:+.4f}")

        # Plot LR curve if we captured it
        if lr_history is not None and lr_history['learning_rates']:
            plot_path = Path(args.output_dir) / "lr_curve.png"

            # LR values are continuous across all gradient steps
            lrs = lr_history['learning_rates']
            steps = list(range(len(lrs)))  # Step indices (0 to total_steps - 1)

            plt.figure(figsize=(12, 6))
            plt.plot(steps, lrs, linewidth=2, color='blue', alpha=0.8)

            # Mark the first step with a star
            plt.plot(steps[0], lrs[0], marker='*', markersize=20, color='gold',
                     markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)

            # Mark timestep boundaries
            num_timesteps = len(lr_history['timesteps'])
            num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
            if num_grad_steps_per_timestep > 0:
                for i in range(num_timesteps + 1):
                    step_idx = i * num_grad_steps_per_timestep
                    if step_idx <= len(lrs):
                        plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
                        if i < num_timesteps:
                            plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
                                     fontsize=8, color='red', alpha=0.7, ha='left')

            plt.xlabel('Global Gradient Step', fontsize=12)
            plt.ylabel('Learning Rate', fontsize=12)
            plt.title(f'Learning Rate Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
                      fontsize=12, fontweight='bold')
            plt.grid(True, alpha=0.3)

            # Add info text
            plt.text(0.02, 0.98,
                     f'Total timesteps: {num_timesteps}\nGrad steps/timestep: {num_grad_steps_per_timestep}\nTotal grad steps: {len(lrs)}',
                     transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            plt.tight_layout()
            plt.savefig(plot_path, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"\n✓ Saved LR curve plot to: {plot_path}")
            print(f"  Total gradient steps: {len(lrs)}")
            print(f"  LR range: {min(lrs):.6f} → {max(lrs):.6f}")

        # Plot rewards curve if we captured it
        if lr_history is not None and lr_history['rewards']:
            plot_path = Path(args.output_dir) / "rewards_curve.png"

            # Reward values are continuous across all gradient steps
            rewards = lr_history['rewards']
            steps = list(range(len(rewards)))  # Step indices (0 to total_steps - 1)

            plt.figure(figsize=(12, 6))
            plt.plot(steps, rewards, linewidth=2, color='green', alpha=0.8)

            # Mark the first step with a star
            plt.plot(steps[0], rewards[0], marker='*', markersize=20, color='gold',
                     markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)

            # Mark timestep boundaries
            num_timesteps = len(lr_history['timesteps'])
            # rewards holds one extra value per timestep (the initial reward) compared to gradient steps
            num_grad_steps_per_timestep = (len(rewards) - num_timesteps) // num_timesteps if num_timesteps > 0 else 0
            if num_grad_steps_per_timestep > 0:
                for i in range(num_timesteps + 1):
                    step_idx = i * (num_grad_steps_per_timestep + 1)  # +1 because reward_history includes the initial value
                    if step_idx <= len(rewards):
                        plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
                        if i < num_timesteps:
                            plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
                                     fontsize=8, color='red', alpha=0.7, ha='left')

            plt.xlabel('Global Gradient Step', fontsize=12)
            plt.ylabel('Reward Score', fontsize=12)
            plt.title(f'Reward Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
                      fontsize=12, fontweight='bold')
            plt.grid(True, alpha=0.3)

            # Add info text
            reward_improvement = rewards[-1] - rewards[0] if len(rewards) > 1 else 0
            plt.text(0.02, 0.98,
                     f'Total timesteps: {num_timesteps}\nTotal grad steps: {len(rewards)}\n'
                     f'Initial reward: {rewards[0]:.4f}\nFinal reward: {rewards[-1]:.4f}\n'
                     f'Improvement: {reward_improvement:+.4f}',
                     transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

            plt.tight_layout()
            plt.savefig(plot_path, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"\n✓ Saved rewards curve plot to: {plot_path}")
            print(f"  Total gradient steps: {len(rewards)}")
            print(f"  Reward range: {min(rewards):.4f} → {max(rewards):.4f}")
            print(f"  Total improvement: {reward_improvement:+.4f}")

    # ========== FINAL RESULTS ==========
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)

    if avg_reward_baseline is not None:
        print(f"\nBaseline:")
        if fid_score_baseline is not None:
            print(f"  FID Score: {fid_score_baseline:.4f}")
        print(f"  Avg Reward: {avg_reward_baseline:.4f}")
        if "clip" in args.metrics and clip_score_baseline is not None:
            print(f"  Avg CLIP Score: {clip_score_baseline:.4f}")
        if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
            print(f"  Avg Aesthetic: {aesthetic_score_baseline:.4f}")
        if "pickscore" in args.metrics and pick_score_baseline is not None:
            print(f"  Avg PickScore: {pick_score_baseline:.4f}")
        if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
            print(f"  Avg HPSv2: {hpsv2_score_baseline:.4f}")
        if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
            print(f"  Avg HPSv2.1: {hpsv21_score_baseline:.4f}")
        if "imagereward" in args.metrics and imagereward_score_baseline is not None:
            print(f"  Avg ImageReward: {imagereward_score_baseline:.4f}")

    if avg_reward_grad is not None:
        print(f"\nGradient Ascent:")
        if fid_score_grad is not None:
            print(f"  FID Score: {fid_score_grad:.4f}")
        print(f"  Avg Reward: {avg_reward_grad:.4f}")
        if "clip" in args.metrics and clip_score_grad is not None:
            print(f"  Avg CLIP Score: {clip_score_grad:.4f}")
        if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
            print(f"  Avg Aesthetic: {aesthetic_score_grad:.4f}")
        if "pickscore" in args.metrics and pick_score_grad is not None:
            print(f"  Avg PickScore: {pick_score_grad:.4f}")
        if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
            print(f"  Avg HPSv2: {hpsv2_score_grad:.4f}")
        if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
            print(f"  Avg HPSv2.1: {hpsv21_score_grad:.4f}")
        if "imagereward" in args.metrics and imagereward_score_grad is not None:
            print(f"  Avg ImageReward: {imagereward_score_grad:.4f}")

    if avg_reward_baseline is not None and avg_reward_grad is not None:
        print(f"\nComparison:")
        if fid_score_baseline is not None and fid_score_grad is not None:
            fid_diff = fid_score_grad - fid_score_baseline
            print(f"  FID Change: {fid_diff:+.4f} ({'worse' if fid_diff > 0 else 'better'}, lower is better)")
        reward_diff = avg_reward_grad - avg_reward_baseline
        print(f"  Reward Change: {reward_diff:+.4f} ({'better' if reward_diff > 0 else 'worse'}, higher is better)")
        if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
            clip_diff = clip_score_grad - clip_score_baseline
            print(f"  CLIP Change: {clip_diff:+.4f} ({'better' if clip_diff > 0 else 'worse'}, higher is better)")
        if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
            aesthetic_diff = aesthetic_score_grad - aesthetic_score_baseline
            print(f"  Aesthetic Change: {aesthetic_diff:+.4f} ({'better' if aesthetic_diff > 0 else 'worse'}, higher is better)")
        if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
            pick_diff = pick_score_grad - pick_score_baseline
            print(f"  PickScore Change: {pick_diff:+.4f} ({'better' if pick_diff > 0 else 'worse'}, higher is better)")
        if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
            hpsv2_diff = hpsv2_score_grad - hpsv2_score_baseline
            print(f"  HPSv2 Change: {hpsv2_diff:+.4f} ({'better' if hpsv2_diff > 0 else 'worse'}, higher is better)")
        if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
            hpsv21_diff = hpsv21_score_grad - hpsv21_score_baseline
            print(f"  HPSv2.1 Change: {hpsv21_diff:+.4f} ({'better' if hpsv21_diff > 0 else 'worse'}, higher is better)")
        if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
            imagereward_diff = imagereward_score_grad - imagereward_score_baseline
            print(f"  ImageReward Chg: {imagereward_diff:+.4f} ({'better' if imagereward_diff > 0 else 'worse'}, higher is better)")

    # Save results to file
    results = {
        "mode": args.mode,
        "metrics": args.metrics,
        "config": {
            "num_samples": len(prompts),
            "num_steps": args.num_steps,
            "cfg_scale": args.cfg_scale,
            "grad_range": [args.grad_range_start, args.grad_range_end],
            "grad_steps": args.grad_steps,
            "grad_step_size": args.grad_step_size
        }
    }

    if avg_reward_baseline is not None:
        results["baseline"] = {"avg_reward": avg_reward_baseline}
        if fid_score_baseline is not None:
            results["baseline"]["fid"] = fid_score_baseline
        if "clip" in args.metrics and clip_score_baseline is not None:
            results["baseline"]["clip_score"] = clip_score_baseline
        if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
            results["baseline"]["aesthetic_score"] = aesthetic_score_baseline
        if "pickscore" in args.metrics and pick_score_baseline is not None:
            results["baseline"]["pickscore"] = pick_score_baseline
        if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
            results["baseline"]["hpsv2_score"] = hpsv2_score_baseline
        if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
            results["baseline"]["hpsv21_score"] = hpsv21_score_baseline
        if "imagereward" in args.metrics and imagereward_score_baseline is not None:
            results["baseline"]["imagereward_score"] = imagereward_score_baseline

    if avg_reward_grad is not None:
        results["gradient_ascent"] = {"avg_reward": avg_reward_grad}
        if fid_score_grad is not None:
            results["gradient_ascent"]["fid"] = fid_score_grad
        if "clip" in args.metrics and clip_score_grad is not None:
            results["gradient_ascent"]["clip_score"] = clip_score_grad
        if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
            results["gradient_ascent"]["aesthetic_score"] = aesthetic_score_grad
        if "pickscore" in args.metrics and pick_score_grad is not None:
            results["gradient_ascent"]["pickscore"] = pick_score_grad
        if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
            results["gradient_ascent"]["hpsv2_score"] = hpsv2_score_grad
        if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
            results["gradient_ascent"]["hpsv21_score"] = hpsv21_score_grad
        if "imagereward" in args.metrics and imagereward_score_grad is not None:
            results["gradient_ascent"]["imagereward_score"] = imagereward_score_grad
        if grad_stats:
            results["gradient_ascent"]["stats"] = grad_stats

    if avg_reward_baseline is not None and avg_reward_grad is not None:
        results["comparison"] = {
            "reward_difference": avg_reward_grad - avg_reward_baseline
        }
        if fid_score_baseline is not None and fid_score_grad is not None:
            results["comparison"]["fid_difference"] = fid_score_grad - fid_score_baseline
        if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
            results["comparison"]["clip_difference"] = clip_score_grad - clip_score_baseline
        if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
            results["comparison"]["aesthetic_difference"] = aesthetic_score_grad - aesthetic_score_baseline
        if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
            results["comparison"]["pickscore_difference"] = pick_score_grad - pick_score_baseline
        if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
            results["comparison"]["hpsv2_difference"] = hpsv2_score_grad - hpsv2_score_baseline
        if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
            results["comparison"]["hpsv21_difference"] = hpsv21_score_grad - hpsv21_score_baseline
        if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
            results["comparison"]["imagereward_difference"] = imagereward_score_grad - imagereward_score_baseline

    # Save results to output directory
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    results_path = output_path / "evaluation_results.txt"

    with open(results_path, "w") as f:
        for k, v in results.items():
            f.write(f"{k}: {v}\n")

    print(f"\n✓ Results saved to: {results_path}")
    if args.save_images:
        print(f"✓ Generated images saved to: {output_path}/baseline/ and {output_path}/gradient_ascent/")
    print("\n" + "="*70)

    # Close logger
    tee_logger.close()
    sys.stdout = tee_logger.terminal


if __name__ == "__main__":
    main()
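Note on the boundary bookkeeping in the two plots above: learning_rates holds one value per gradient step, while rewards also records the initial reward at each timestep, so the timestep boundaries land at different indices on the two x-axes. A minimal sketch of that arithmetic, with hypothetical counts (3 timesteps, 2 gradient steps each) chosen only to make the indices concrete:

# 3 timesteps, 2 gradient steps per timestep (hypothetical values).
timesteps = [681, 641, 601]
lrs = [0.1, 0.08] * len(timesteps)           # one LR per gradient step -> 6 values
rewards = [0.0, 0.3, 0.5] * len(timesteps)   # initial reward + 2 steps -> 9 values

n_t = len(timesteps)
lr_steps_per_t = len(lrs) // n_t                  # 2
rw_steps_per_t = (len(rewards) - n_t) // n_t      # 2, after removing the initial values

print([i * lr_steps_per_t for i in range(n_t + 1)])        # LR-axis boundaries: [0, 2, 4, 6]
print([i * (rw_steps_per_t + 1) for i in range(n_t + 1)])  # reward-axis boundaries: [0, 3, 6, 9]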
Reward_sdxl_idealized/examples.sh
ADDED
@@ -0,0 +1,19 @@
export HF_HOME=/efs/vkvermaa/2025/diffusion/latent_correct/hf_cache
Dataset_Name="pickapic"  # "coco" or "pickapic"
Grad_Config="one_step_rectification_config"
# constant cosine_nesterov low_to_high_momentum high_to_low_momentum
# low_to_high_nesterov high_to_low_nesterov
Model_Variant="spo"  # lpo, spo
GPU_ID=$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | sort -k2 -n | head -n1 | cut -d',' -f1)

python eval.py \
    --model_variant "$Model_Variant" \
    --dataset_type "$Dataset_Name" \
    --grad_config "$Grad_Config" \
    --metrics fid clip aesthetic pickscore hpsv2 hpsv21 imagereward \
    --max_samples 500 \
    --num_steps 20 \
    --cfg_scale 3 \
    --output_dir "RESULTS/$Dataset_Name/${Grad_Config}_${Model_Variant}" \
    --cuda $GPU_ID \
    --mode "gradient_ascent"  # gradient_ascent / baseline
Reward_sdxl_idealized/gradient_ascent_utils.py
ADDED
@@ -0,0 +1,339 @@
"""
|
| 2 |
+
Gradient Ascent utilities for reward-guided diffusion generation.
|
| 3 |
+
|
| 4 |
+
This module implements gradient ascent on the LRM reward score to guide
|
| 5 |
+
the diffusion process toward higher preference scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from typing import Optional, Tuple, List, Literal
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from lr_scheduler import create_lr_scheduler, LRScheduler
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RewardGuidedDiffusion:
|
| 16 |
+
"""
|
| 17 |
+
Implements reward-guided generation using gradient ascent.
|
| 18 |
+
|
| 19 |
+
During denoising, at specified timesteps, we:
|
| 20 |
+
1. Compute the reward score for current latents
|
| 21 |
+
2. Calculate gradients of reward w.r.t. latents
|
| 22 |
+
3. Update latents in the direction that increases reward
|
| 23 |
+
|
| 24 |
+
This guides generation toward higher preference scores.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
reward_model,
|
| 30 |
+
grad_scale: float = 1.0,
|
| 31 |
+
grad_timestep_range: Optional[Tuple[int, int]] = None,
|
| 32 |
+
num_grad_steps: int = 5,
|
| 33 |
+
grad_step_size: float = 0.1,
|
| 34 |
+
gradient_checkpoint: bool = False,
|
| 35 |
+
# LR Scheduling
|
| 36 |
+
lr_scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"] = "constant",
|
| 37 |
+
lr_scheduler_kwargs: Optional[dict] = None,
|
| 38 |
+
# Momentum
|
| 39 |
+
use_momentum: bool = False,
|
| 40 |
+
momentum: float = 0.9,
|
| 41 |
+
use_nesterov: bool = False,
|
| 42 |
+
use_iso_projection: bool = False
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
Initialize reward-guided diffusion.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
reward_model: LRM reward model for computing preference scores
|
| 49 |
+
grad_scale: Scale factor for gradient updates (default: 1.0)
|
| 50 |
+
grad_timestep_range: Tuple of (min_t, max_t) for gradient ascent.
|
| 51 |
+
If None, applies to all timesteps.
|
| 52 |
+
num_grad_steps: Number of gradient ascent steps per timestep
|
| 53 |
+
grad_step_size: Step size for each gradient update (initial LR)
|
| 54 |
+
gradient_checkpoint: Whether to use gradient checkpointing
|
| 55 |
+
lr_scheduler_type: Type of LR scheduler ("constant", "linear", "cosine", "exponential", "step")
|
| 56 |
+
lr_scheduler_kwargs: Additional kwargs for LR scheduler (e.g., end_lr, min_lr, warmup_steps)
|
| 57 |
+
use_momentum: Whether to use momentum in gradient updates
|
| 58 |
+
momentum: Momentum coefficient (typically 0.9)
|
| 59 |
+
use_nesterov: Whether to use Nesterov momentum
|
| 60 |
+
use_iso_projection: Whether to use Iso Projection
|
| 61 |
+
"""
|
| 62 |
+
self.reward_model = reward_model
|
| 63 |
+
self.grad_scale = grad_scale
|
| 64 |
+
self.grad_timestep_range = grad_timestep_range
|
| 65 |
+
self.num_grad_steps = num_grad_steps
|
| 66 |
+
self.grad_step_size = grad_step_size
|
| 67 |
+
self.gradient_checkpoint = gradient_checkpoint
|
| 68 |
+
|
| 69 |
+
# LR Scheduler
|
| 70 |
+
self.lr_scheduler_type = lr_scheduler_type
|
| 71 |
+
self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
|
| 72 |
+
self.lr_scheduler: Optional[LRScheduler] = None
|
| 73 |
+
self.global_lr_scheduler: Optional[LRScheduler] = None # Scheduler across denoising timesteps
|
| 74 |
+
|
| 75 |
+
# Momentum
|
| 76 |
+
self.use_momentum = use_momentum
|
| 77 |
+
self.momentum = momentum
|
| 78 |
+
self.use_nesterov = use_nesterov
|
| 79 |
+
self.velocity = None # Will be initialized per optimization
|
| 80 |
+
|
| 81 |
+
self.use_iso_projection = use_iso_projection
|
| 82 |
+
|
| 83 |
+
# Statistics
|
| 84 |
+
self.grad_stats = []
|
| 85 |
+
self.timestep_counter = 0 # Track which timestep we're on
|
| 86 |
+
|
| 87 |
+
def should_apply_gradient(self, timestep: int) -> bool:
|
| 88 |
+
"""Check if gradient ascent should be applied at this timestep."""
|
| 89 |
+
|
| 90 |
+
if self.grad_timestep_range is None:
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
min_t, max_t = self.grad_timestep_range
|
| 94 |
+
return min_t <= timestep <= max_t
|
| 95 |
+
|
| 96 |
+
@torch.enable_grad()
|
| 97 |
+
def compute_reward_gradient(
|
| 98 |
+
self,
|
| 99 |
+
latents: torch.Tensor,
|
| 100 |
+
prompt: str,
|
| 101 |
+
timestep: int,
|
| 102 |
+
) -> Tuple[torch.Tensor, float]:
|
| 103 |
+
"""
|
| 104 |
+
Compute gradient of reward score w.r.t. latents in FP32 to prevent underflow.
|
| 105 |
+
"""
|
| 106 |
+
# 1. Cast to FP32 and ensure we are detached from previous iterations
|
| 107 |
+
latents_fp32 = latents.detach().to(torch.float32).clone()
|
| 108 |
+
latents_fp32.requires_grad_(True)
|
| 109 |
+
|
| 110 |
+
# 2. Compute reward score
|
| 111 |
+
# Note: Even if the model internally uses fp16/bf16, autograd will
|
| 112 |
+
# safely accumulate the gradient in fp32 for our leaf node.
|
| 113 |
+
reward_score = self.reward_model.get_reward_score(
|
| 114 |
+
latents_fp32,
|
| 115 |
+
prompt,
|
| 116 |
+
timestep,
|
| 117 |
+
enable_grad=True
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if reward_score.numel() > 1:
|
| 121 |
+
reward_score = reward_score.mean()
|
| 122 |
+
|
| 123 |
+
# 3. Extract gradient
|
| 124 |
+
# CRITICAL: retain_graph=True prevents the graph from dying across multiple
|
| 125 |
+
# gradient steps if your reward model relies on cached text embeddings.
|
| 126 |
+
grad = torch.autograd.grad(
|
| 127 |
+
outputs=reward_score,
|
| 128 |
+
inputs=latents_fp32,
|
| 129 |
+
create_graph=False,
|
| 130 |
+
retain_graph=True, # Keeps the graph alive for the next step!
|
| 131 |
+
allow_unused=True,
|
| 132 |
+
)[0]
|
| 133 |
+
|
| 134 |
+
# 4. Handle None gradients and cast back to the pipeline's original dtype
|
| 135 |
+
if grad is None:
|
| 136 |
+
grad = torch.zeros_like(latents)
|
| 137 |
+
else:
|
| 138 |
+
grad = grad.to(latents.dtype)
|
| 139 |
+
|
| 140 |
+
return grad, reward_score.item()
|
| 141 |
+
|
| 142 |
+
def apply_gradient_ascent(
|
| 143 |
+
self,
|
| 144 |
+
latents: torch.Tensor,
|
| 145 |
+
prompt: str,
|
| 146 |
+
timestep: int,
|
| 147 |
+
base_noise: Optional[torch.Tensor] = None, # Required for Iso-Marginal projection
|
| 148 |
+
verbose: bool = True,
|
| 149 |
+
total_denoising_steps: Optional[int] = None,
|
| 150 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 151 |
+
|
| 152 |
+
# 1. UPCAST TO FP32 AND SETUP OPTIMIZER (Targeting Latents)
|
| 153 |
+
original_latents = latents.detach().clone().to(torch.float32)
|
| 154 |
+
current_latents = torch.nn.Parameter(original_latents.clone())
|
| 155 |
+
self.reward_model.unet.conv_in.weight.requires_grad_(True)
|
| 156 |
+
|
| 157 |
+
# Initial reward tracking
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
initial_reward = self.reward_model.get_reward_score(
|
| 160 |
+
latents,
|
| 161 |
+
prompt,
|
| 162 |
+
timestep
|
| 163 |
+
)
|
| 164 |
+
initial_reward_val = initial_reward.item() if initial_reward.numel() == 1 else initial_reward.mean().item()
|
| 165 |
+
|
| 166 |
+
# Initialize tracking lists
|
| 167 |
+
grad_norms = []
|
| 168 |
+
reward_history = [initial_reward_val]
|
| 169 |
+
lr_history = []
|
| 170 |
+
|
| 171 |
+
# 2. FORWARD PASS (downcast to FP16 just for the model forward pass)
|
| 172 |
+
reward = self.reward_model.get_reward_score(
|
| 173 |
+
current_latents.to(latents.dtype),
|
| 174 |
+
prompt,
|
| 175 |
+
timestep,
|
| 176 |
+
enable_grad=True
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
loss = -reward.mean()
|
| 180 |
+
loss.backward()
|
| 181 |
+
|
| 182 |
+
# Extract latent gradient
|
| 183 |
+
raw_grad = current_latents.grad
|
| 184 |
+
reward_history.append(reward.mean().item())
|
| 185 |
+
|
| 186 |
+
# 3. ISO-MARGINAL PROJECTION WITH ASYMMETRIC INCLUSION
|
| 187 |
+
if raw_grad is not None and base_noise is not None and self.use_iso_projection:
|
| 188 |
+
gamma = 1e-8
|
| 189 |
+
B = raw_grad.shape[0]
|
| 190 |
+
|
| 191 |
+
grad_flat = raw_grad.view(B, -1)
|
| 192 |
+
noise_flat = base_noise.view(B, -1).to(torch.float32)
|
| 193 |
+
|
| 194 |
+
# Compute projection scalar for raw_grad (which is -?R)
|
| 195 |
+
dot_product = (grad_flat * noise_flat).sum(dim=1, keepdim=True)
|
| 196 |
+
noise_norm_sq = (noise_flat * noise_flat).sum(dim=1, keepdim=True)
|
| 197 |
+
|
| 198 |
+
proj_scalar = dot_product / (noise_norm_sq + gamma)
|
| 199 |
+
proj_scalar = proj_scalar.view(B, 1, 1, 1)
|
| 200 |
+
|
| 201 |
+
# 1. Decompose
|
| 202 |
+
grad_parallel = proj_scalar * base_noise.to(torch.float32)
|
| 203 |
+
grad_perp = raw_grad - grad_parallel
|
| 204 |
+
|
| 205 |
+
# 2. Asymmetric Inclusion
|
| 206 |
+
# proj_scalar > 0 means the applied step (+?R) points toward -epsilon (Denoising. GOOD.)
|
| 207 |
+
# proj_scalar < 0 means the applied step (+?R) points toward +epsilon (Noising. BAD.)
|
| 208 |
+
safe_proj_scalar = torch.clamp(proj_scalar, min=0.0)
|
| 209 |
+
|
| 210 |
+
beta = 1.0 # Retention factor for the safe parallel gradient
|
| 211 |
+
safe_grad_parallel = beta * (safe_proj_scalar * base_noise.to(torch.float32))
|
| 212 |
+
|
| 213 |
+
# 3. Recombine
|
| 214 |
+
#grad_perp = grad_perp + safe_grad_parallel
|
| 215 |
+
else:
|
| 216 |
+
grad_perp = raw_grad
|
| 217 |
+
if base_noise is None:
|
| 218 |
+
print("?? WARNING: base_noise missing. Skipping Iso-Marginal projection.")
|
| 219 |
+
|
| 220 |
+
# 4. KINETIC RECTIFICATION (Applied to the projected latent gradient)
|
| 221 |
+
if grad_perp is not None:
|
| 222 |
+
max_grad = grad_perp.norm().item()
|
| 223 |
+
|
| 224 |
+
if max_grad > 0:
|
| 225 |
+
kinetic_direction = grad_perp / (max_grad + 1e-8)
|
| 226 |
+
|
| 227 |
+
# Because the max element is 1.0, alpha is the EXACT float32 change applied.
|
| 228 |
+
alpha = self.grad_step_size
|
| 229 |
+
|
| 230 |
+
with torch.no_grad():
|
| 231 |
+
rectified_latents = original_latents - (alpha * kinetic_direction)
|
| 232 |
+
else:
|
| 233 |
+
print("?? WARNING: Gradient exists but max value is 0.0")
|
| 234 |
+
rectified_latents = original_latents.clone()
|
| 235 |
+
alpha = 0.0
|
| 236 |
+
else:
|
| 237 |
+
print("?? FATAL: PyTorch completely dropped the latent gradient!")
|
| 238 |
+
rectified_latents = original_latents.clone()
|
| 239 |
+
max_grad = 0.0
|
| 240 |
+
alpha = 0.0
|
| 241 |
+
|
| 242 |
+
if verbose:
|
| 243 |
+
print(f" Grad step | LR: {alpha:.6f} | Reward: {reward.mean().item():.4f} | Max Grad: {max_grad:.4f}")
|
| 244 |
+
|
| 245 |
+
# 5. DOWNCAST AND RETURN
|
| 246 |
+
final_latents = rectified_latents.detach().to(latents.dtype)
|
| 247 |
+
|
| 248 |
+
with torch.no_grad():
|
| 249 |
+
final_reward = self.reward_model.get_reward_score(
|
| 250 |
+
final_latents, prompt, timestep
|
| 251 |
+
)
|
| 252 |
+
final_reward_val = final_reward.item() if final_reward.numel() == 1 else final_reward.mean().item()
|
| 253 |
+
|
| 254 |
+
stats = {
|
| 255 |
+
'timestep': timestep,
|
| 256 |
+
'initial_reward': initial_reward_val,
|
| 257 |
+
'final_reward': final_reward_val,
|
| 258 |
+
'reward_improvement': final_reward_val - initial_reward_val,
|
| 259 |
+
'grad_norms': [max_grad],
|
| 260 |
+
'reward_history': reward_history,
|
| 261 |
+
'lr_history': [alpha], # Kept for plotting logic
|
| 262 |
+
'latent_change': (final_latents - original_latents.to(latents.dtype)).norm().item(),
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
self.grad_stats.append(stats)
|
| 266 |
+
|
| 267 |
+
return final_latents, stats
|
| 268 |
+
|
| 269 |
+
def get_statistics(self) -> dict:
|
| 270 |
+
"""Get aggregated statistics across all gradient ascent applications."""
|
| 271 |
+
if not self.grad_stats:
|
| 272 |
+
return {}
|
| 273 |
+
|
| 274 |
+
total_improvement = sum(s['reward_improvement'] for s in self.grad_stats)
|
| 275 |
+
avg_improvement = total_improvement / len(self.grad_stats)
|
| 276 |
+
|
| 277 |
+
all_grad_norms = [n for s in self.grad_stats for n in s['grad_norms']]
|
| 278 |
+
|
| 279 |
+
return {
|
| 280 |
+
'num_applications': len(self.grad_stats),
|
| 281 |
+
'total_reward_improvement': total_improvement,
|
| 282 |
+
'avg_reward_improvement': avg_improvement,
|
| 283 |
+
'avg_grad_norm': sum(all_grad_norms) / len(all_grad_norms) if all_grad_norms else 0,
|
| 284 |
+
'max_grad_norm': max(all_grad_norms) if all_grad_norms else 0,
|
| 285 |
+
'detailed_stats': self.grad_stats,
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
def reset_statistics(self):
|
| 289 |
+
"""Reset statistics and global scheduler."""
|
| 290 |
+
self.grad_stats = []
|
| 291 |
+
self.global_lr_scheduler = None
|
| 292 |
+
self.timestep_counter = 0
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def create_reward_guided_generator(
|
| 296 |
+
reward_model,
|
| 297 |
+
grad_timestep_range: Tuple[int, int] = (500, 700),
|
| 298 |
+
grad_scale: float = 1.0,
|
| 299 |
+
num_grad_steps: int = 5,
|
| 300 |
+
grad_step_size: float = 0.1,
|
| 301 |
+
lr_scheduler_type: str = "constant",
|
| 302 |
+
lr_scheduler_kwargs: Optional[dict] = None,
|
| 303 |
+
use_momentum: bool = False,
|
| 304 |
+
momentum: float = 0.9,
|
| 305 |
+
use_nesterov: bool = False,
|
| 306 |
+
use_iso_projection: bool = False
|
| 307 |
+
) -> RewardGuidedDiffusion:
|
| 308 |
+
"""
|
| 309 |
+
Convenience function to create a reward-guided diffusion generator.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
reward_model: LRM reward model
|
| 313 |
+
grad_timestep_range: Tuple of (min_t, max_t) for applying gradients
|
| 314 |
+
grad_scale: Scale factor for gradient magnitude
|
| 315 |
+
num_grad_steps: Number of gradient ascent iterations per timestep
|
| 316 |
+
grad_step_size: Step size for each gradient update (initial LR)
|
| 317 |
+
lr_scheduler_type: Type of LR scheduler
|
| 318 |
+
lr_scheduler_kwargs: Additional kwargs for LR scheduler
|
| 319 |
+
use_momentum: Whether to use momentum
|
| 320 |
+
momentum: Momentum coefficient
|
| 321 |
+
use_nesterov: Whether to use Nesterov momentum
|
| 322 |
+
use_iso_projection: Whether to use Iso Projection
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
RewardGuidedDiffusion instance
|
| 326 |
+
"""
|
| 327 |
+
return RewardGuidedDiffusion(
|
| 328 |
+
reward_model=reward_model,
|
| 329 |
+
grad_scale=grad_scale,
|
| 330 |
+
grad_timestep_range=grad_timestep_range,
|
| 331 |
+
num_grad_steps=num_grad_steps,
|
| 332 |
+
grad_step_size=grad_step_size,
|
| 333 |
+
lr_scheduler_type=lr_scheduler_type,
|
| 334 |
+
lr_scheduler_kwargs=lr_scheduler_kwargs,
|
| 335 |
+
use_momentum=use_momentum,
|
| 336 |
+
momentum=momentum,
|
| 337 |
+
use_nesterov=use_nesterov,
|
| 338 |
+
use_iso_projection= False
|
| 339 |
+
)
|
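For orientation, a minimal usage sketch of the module above. ToyRewardModel is a hypothetical stand-in (real usage passes the LRM reward model, e.g. LRMRewardModelXL); it only mirrors the interface apply_gradient_ascent relies on: a get_reward_score(latents, prompt, timestep, enable_grad=...) method and a unet.conv_in submodule.

import torch
import torch.nn as nn
from gradient_ascent_utils import create_reward_guided_generator

class ToyRewardModel(nn.Module):
    """Hypothetical reward model exposing the interface used above."""
    def __init__(self):
        super().__init__()
        self.unet = nn.Module()
        self.unet.conv_in = nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def get_reward_score(self, latents, prompt, timestep, enable_grad=False):
        # Scalar "reward": mean response of the conv to the latents.
        return self.unet.conv_in(latents.float()).mean()

guidance = create_reward_guided_generator(
    reward_model=ToyRewardModel(),
    grad_timestep_range=(500, 700),  # only rectify inside this timestep window
    grad_step_size=0.1,
)

latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)  # base noise, used when iso-projection is enabled
if guidance.should_apply_gradient(timestep=600):
    latents, stats = guidance.apply_gradient_ascent(latents, "a photo of a cat", 600, base_noise=noise)
    print(f"reward improvement: {stats['reward_improvement']:+.4f}")

In eval.py above, pipeline.enable_gradient_ascent(**grad_config) wires this same machinery into the denoising loop; the pipeline later reads the aggregate numbers back via pipeline.grad_guidance.get_statistics().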
Reward_sdxl_idealized/lr_scheduler.py
ADDED
@@ -0,0 +1,233 @@
"""
|
| 2 |
+
Learning rate schedulers for gradient ascent optimization.
|
| 3 |
+
|
| 4 |
+
Provides various LR scheduling strategies for reward-guided gradient ascent,
|
| 5 |
+
including cosine annealing, linear decay, and custom schedules.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from typing import Optional, Literal
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LRScheduler:
|
| 13 |
+
"""Base class for learning rate schedulers."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, initial_lr: float, num_steps: int):
|
| 16 |
+
"""
|
| 17 |
+
Initialize LR scheduler.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
initial_lr: Initial learning rate
|
| 21 |
+
num_steps: Total number of optimization steps
|
| 22 |
+
"""
|
| 23 |
+
self.initial_lr = initial_lr
|
| 24 |
+
self.num_steps = num_steps
|
| 25 |
+
self.current_step = 0
|
| 26 |
+
|
| 27 |
+
def get_lr(self) -> float:
|
| 28 |
+
"""Get current learning rate."""
|
| 29 |
+
raise NotImplementedError
|
| 30 |
+
|
| 31 |
+
def step(self):
|
| 32 |
+
"""Update scheduler state after a step."""
|
| 33 |
+
self.current_step += 1
|
| 34 |
+
|
| 35 |
+
def reset(self):
|
| 36 |
+
"""Reset scheduler state."""
|
| 37 |
+
self.current_step = 0
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ConstantLR(LRScheduler):
|
| 41 |
+
"""Constant learning rate (no scheduling)."""
|
| 42 |
+
|
| 43 |
+
def get_lr(self) -> float:
|
| 44 |
+
return self.initial_lr
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LinearLR(LRScheduler):
|
| 48 |
+
"""Linear learning rate decay."""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
initial_lr: float,
|
| 53 |
+
num_steps: int,
|
| 54 |
+
end_lr: float = 0.0,
|
| 55 |
+
start_step: int = 0,
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Initialize linear LR scheduler.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
initial_lr: Starting learning rate
|
| 62 |
+
num_steps: Total number of steps
|
| 63 |
+
end_lr: Ending learning rate (default: 0.0)
|
| 64 |
+
start_step: Step to begin decay (default: 0)
|
| 65 |
+
"""
|
| 66 |
+
super().__init__(initial_lr, num_steps)
|
| 67 |
+
self.end_lr = end_lr
|
| 68 |
+
self.start_step = start_step
|
| 69 |
+
|
| 70 |
+
def get_lr(self) -> float:
|
| 71 |
+
if self.current_step < self.start_step:
|
| 72 |
+
return self.initial_lr
|
| 73 |
+
|
| 74 |
+
progress = (self.current_step - self.start_step) / (self.num_steps - self.start_step)
|
| 75 |
+
progress = min(1.0, progress)
|
| 76 |
+
|
| 77 |
+
return self.initial_lr + (self.end_lr - self.initial_lr) * progress
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class CosineLR(LRScheduler):
|
| 81 |
+
"""Cosine annealing learning rate schedule."""
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
initial_lr: float,
|
| 86 |
+
num_steps: int,
|
| 87 |
+
min_lr: float = 0.0,
|
| 88 |
+
warmup_steps: int = 0,
|
| 89 |
+
):
|
| 90 |
+
"""
|
| 91 |
+
Initialize cosine LR scheduler.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
initial_lr: Maximum learning rate
|
| 95 |
+
num_steps: Total number of steps
|
| 96 |
+
min_lr: Minimum learning rate (default: 0.0)
|
| 97 |
+
warmup_steps: Number of linear warmup steps (default: 0)
|
| 98 |
+
"""
|
| 99 |
+
super().__init__(initial_lr, num_steps)
|
| 100 |
+
self.min_lr = min_lr
|
| 101 |
+
self.warmup_steps = warmup_steps
|
| 102 |
+
|
| 103 |
+
def get_lr(self) -> float:
|
| 104 |
+
if self.current_step < self.warmup_steps:
|
| 105 |
+
# Linear warmup
|
| 106 |
+
return self.initial_lr * (self.current_step / self.warmup_steps)
|
| 107 |
+
|
| 108 |
+
# Cosine annealing
|
| 109 |
+
progress = (self.current_step - self.warmup_steps) / (self.num_steps - self.warmup_steps)
|
| 110 |
+
progress = min(1.0, progress)
|
| 111 |
+
|
| 112 |
+
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
|
| 113 |
+
return self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ExponentialLR(LRScheduler):
|
| 117 |
+
"""Exponential learning rate decay."""
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
initial_lr: float,
|
| 122 |
+
num_steps: int,
|
| 123 |
+
gamma: float = 0.95,
|
| 124 |
+
):
|
| 125 |
+
"""
|
| 126 |
+
Initialize exponential LR scheduler.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
initial_lr: Starting learning rate
|
| 130 |
+
num_steps: Total number of steps
|
| 131 |
+
gamma: Multiplicative decay factor per step
|
| 132 |
+
"""
|
| 133 |
+
super().__init__(initial_lr, num_steps)
|
| 134 |
+
self.gamma = gamma
|
| 135 |
+
|
| 136 |
+
def get_lr(self) -> float:
|
| 137 |
+
return self.initial_lr * (self.gamma ** self.current_step)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class StepLR(LRScheduler):
|
| 141 |
+
"""Step-wise learning rate decay."""
|
| 142 |
+
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
initial_lr: float,
|
| 146 |
+
num_steps: int,
|
| 147 |
+
step_size: int,
|
| 148 |
+
gamma: float = 0.1,
|
| 149 |
+
):
|
| 150 |
+
"""
|
| 151 |
+
Initialize step LR scheduler.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
initial_lr: Starting learning rate
|
| 155 |
+
num_steps: Total number of steps
|
| 156 |
+
step_size: Number of steps between each decay
|
| 157 |
+
gamma: Multiplicative decay factor
|
| 158 |
+
"""
|
| 159 |
+
super().__init__(initial_lr, num_steps)
|
| 160 |
+
self.step_size = step_size
|
| 161 |
+
self.gamma = gamma
|
| 162 |
+
|
| 163 |
+
def get_lr(self) -> float:
|
| 164 |
+
num_decays = self.current_step // self.step_size
|
| 165 |
+
return self.initial_lr * (self.gamma ** num_decays)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def create_lr_scheduler(
|
| 169 |
+
scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"],
|
| 170 |
+
initial_lr: float,
|
| 171 |
+
num_steps: int,
|
| 172 |
+
**kwargs
|
| 173 |
+
) -> LRScheduler:
|
| 174 |
+
"""
|
| 175 |
+
Factory function to create learning rate schedulers.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
scheduler_type: Type of scheduler ("constant", "linear", "cosine", "exponential", "step")
|
| 179 |
+
initial_lr: Initial learning rate
|
| 180 |
+
num_steps: Total number of optimization steps
|
| 181 |
+
**kwargs: Additional scheduler-specific arguments
|
| 182 |
+
For linear: end_lr, start_step
|
| 183 |
+
For cosine: min_lr, warmup_steps
|
| 184 |
+
For exponential: gamma
|
| 185 |
+
For step: step_size, gamma
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
LRScheduler instance
|
| 189 |
+
|
| 190 |
+
Examples:
|
| 191 |
+
# Constant LR
|
| 192 |
+
scheduler = create_lr_scheduler("constant", initial_lr=0.1, num_steps=100)
|
| 193 |
+
|
| 194 |
+
# Linear decay
|
| 195 |
+
scheduler = create_lr_scheduler("linear", initial_lr=0.1, num_steps=100, end_lr=0.01)
|
| 196 |
+
|
| 197 |
+
# Cosine annealing with warmup
|
| 198 |
+
scheduler = create_lr_scheduler("cosine", initial_lr=0.1, num_steps=100,
|
| 199 |
+
min_lr=0.001, warmup_steps=10)
|
| 200 |
+
"""
|
| 201 |
+
if scheduler_type == "constant":
|
| 202 |
+
return ConstantLR(initial_lr, num_steps)
|
| 203 |
+
|
| 204 |
+
elif scheduler_type == "linear":
|
| 205 |
+
return LinearLR(
|
| 206 |
+
initial_lr, num_steps,
|
| 207 |
+
end_lr=kwargs.get("end_lr", 0.0),
|
| 208 |
+
start_step=kwargs.get("start_step", 0),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
elif scheduler_type == "cosine":
|
| 212 |
+
return CosineLR(
|
| 213 |
+
initial_lr, num_steps,
|
| 214 |
+
min_lr=kwargs.get("min_lr", 0.0),
|
| 215 |
+
warmup_steps=kwargs.get("warmup_steps", 0),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
elif scheduler_type == "exponential":
|
| 219 |
+
return ExponentialLR(
|
| 220 |
+
initial_lr, num_steps,
|
| 221 |
+
gamma=kwargs.get("gamma", 0.95),
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
elif scheduler_type == "step":
|
| 225 |
+
return StepLR(
|
| 226 |
+
initial_lr, num_steps,
|
| 227 |
+
step_size=kwargs.get("step_size", 10),
|
| 228 |
+
gamma=kwargs.get("gamma", 0.1),
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
else:
|
| 232 |
+
raise ValueError(f"Unknown scheduler type: {scheduler_type}. "
|
| 233 |
+
f"Choose from: constant, linear, cosine, exponential, step")
|
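A short usage sketch of the factory above, stepping a cosine schedule with warmup (the values are illustrative):

from lr_scheduler import create_lr_scheduler

# Cosine annealing from 0.1 down to 0.001 with 2 linear warmup steps.
sched = create_lr_scheduler("cosine", initial_lr=0.1, num_steps=10,
                            min_lr=0.001, warmup_steps=2)
for step in range(10):
    print(f"step {step}: lr = {sched.get_lr():.4f}")
    sched.step()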
Reward_sdxl_idealized/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .reward_model import LRMRewardModelXL

__all__ = ['LRMRewardModelXL']
Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-310.pyc
ADDED
Binary file (6.25 kB). View file

Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-313.pyc
ADDED
Binary file (16.4 kB). View file

Reward_sdxl_idealized/models/__pycache__/unet_2d_condition_reward.cpython-313.pyc
ADDED
Binary file (57.4 kB). View file

Reward_sdxl_idealized/models/unet_2d_condition_reward.py
ADDED
@@ -0,0 +1,1334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    Attention,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+    FusedAttnProcessor2_0,
+)
+from diffusers.models.embeddings import (
+    GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block,
+)
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    The output of [`UNet2DConditionModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.Tensor = None
+
+
+class UNet2DConditionModel(
+    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
+):
+    r"""
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+            The tuple of upsample blocks to use.
+        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+            Whether to include self-attention in the basic transformer blocks, see
+            [`~models.attention.BasicTransformerBlock`].
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`
+        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
+            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
+            An optional override for the dimension of the projected time embedding.
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
+            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+            otherwise.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        dropout: float = 0.0,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
+        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
+        time_embedding_act_fn: Optional[str] = None,
+        timestep_post_act: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
+        conv_in_kernel: int = 3,
+        conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
+        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads: int = 64,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
+        # Check inputs
+        self._check_config(
+            down_block_types=down_block_types,
+            up_block_types=up_block_types,
+            only_cross_attention=only_cross_attention,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            cross_attention_dim=cross_attention_dim,
+            transformer_layers_per_block=transformer_layers_per_block,
+            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
+        )
+
+        # input
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        time_embed_dim, timestep_input_dim = self._set_time_proj(
+            time_embedding_type,
+            block_out_channels=block_out_channels,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            time_embedding_dim=time_embedding_dim,
+        )
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )
+
+        self._set_encoder_hid_proj(
+            encoder_hid_dim_type,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+        )
+
+        # class embedding
+        self._set_class_embedding(
+            class_embed_type,
+            act_fn=act_fn,
+            num_class_embeds=num_class_embeds,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
+            timestep_input_dim=timestep_input_dim,
+        )
+
+        self._set_add_embedding(
+            addition_embed_type,
+            addition_embed_type_num_heads=addition_embed_type_num_heads,
+            addition_time_embed_dim=addition_time_embed_dim,
+            cross_attention_dim=cross_attention_dim,
+            encoder_hid_dim=encoder_hid_dim,
+            flip_sin_to_cos=flip_sin_to_cos,
+            freq_shift=freq_shift,
+            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+            time_embed_dim=time_embed_dim,
+        )
+
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        self.mid_block = get_mid_block(
+            mid_block_type,
+            temb_channels=blocks_time_embed_dim,
+            in_channels=block_out_channels[-1],
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
+            resnet_groups=norm_num_groups,
+            output_scale_factor=mid_block_scale_factor,
+            transformer_layers_per_block=transformer_layers_per_block[-1],
+            num_attention_heads=num_attention_heads[-1],
+            cross_attention_dim=cross_attention_dim[-1],
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            mid_block_only_cross_attention=mid_block_only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            attention_type=attention_type,
+            resnet_skip_time_act=resnet_skip_time_act,
+            cross_attention_norm=cross_attention_norm,
+            attention_head_dim=attention_head_dim[-1],
+            dropout=dropout,
+        )
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = (
+            list(reversed(transformer_layers_per_block))
+            if reverse_transformer_layers_per_block is None
+            else reverse_transformer_layers_per_block
+        )
+        only_cross_attention = list(reversed(only_cross_attention))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resolution_idx=i,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                dropout=dropout,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+
+            self.conv_act = get_activation(act_fn)
+
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
+        self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+    def _check_config(
+        self,
+        down_block_types: Tuple[str],
+        up_block_types: Tuple[str],
+        only_cross_attention: Union[bool, Tuple[bool]],
+        block_out_channels: Tuple[int],
+        layers_per_block: Union[int, Tuple[int]],
+        cross_attention_dim: Union[int, Tuple[int]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+        reverse_transformer_layers_per_block: bool,
+        attention_head_dim: int,
+        num_attention_heads: Optional[Union[int, Tuple[int]]],
+    ):
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+            for layer_number_per_block in transformer_layers_per_block:
+                if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
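
The constructor above is entirely config-driven, so the model can be instantiated standalone; a minimal shape-check sketch under the signature's own defaults (the forward call assumes the standard diffusers UNet2DConditionModel interface, which the rest of this file is expected to reproduce — that part of the diff appears further down):

import torch
from Reward_sdxl_idealized.models.unet_2d_condition_reward import UNet2DConditionModel

unet = UNet2DConditionModel()  # defaults: 4-channel latents, (320, 640, 1280, 1280) blocks
latents = torch.randn(1, 4, 64, 64)
text = torch.randn(1, 77, 1280)  # cross_attention_dim defaults to 1280
with torch.no_grad():
    out = unet(latents, timestep=10, encoder_hidden_states=text).sample
print(out.shape)  # torch.Size([1, 4, 64, 64])
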
+    def _set_time_proj(
+        self,
+        time_embedding_type: str,
+        block_out_channels: int,
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        time_embedding_dim: int,
+    ) -> Tuple[int, int]:
+        if time_embedding_type == "fourier":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+            if time_embed_dim % 2 != 0:
+                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+            self.time_proj = GaussianFourierProjection(
+                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+            )
+            timestep_input_dim = time_embed_dim
+        elif time_embedding_type == "positional":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+
+        return time_embed_dim, timestep_input_dim
+
+    def _set_encoder_hid_proj(
+        self,
+        encoder_hid_dim_type: Optional[str],
+        cross_attention_dim: Union[int, Tuple[int]],
+        encoder_hid_dim: Optional[int],
+    ):
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+    def _set_class_embedding(
+        self,
+        class_embed_type: Optional[str],
+        act_fn: str,
+        num_class_embeds: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
+        timestep_input_dim: int,
+    ):
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+    def _set_add_embedding(
+        self,
+        addition_embed_type: str,
+        addition_embed_type_num_heads: int,
+        addition_time_embed_dim: Optional[int],
+        flip_sin_to_cos: bool,
+        freq_shift: float,
+        cross_attention_dim: Optional[int],
+        encoder_hid_dim: Optional[int],
+        projection_class_embeddings_input_dim: Optional[int],
+        time_embed_dim: int,
+    ):
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
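
Since the `text_time` branch above is the one SDXL uses (a Timesteps projection over six micro-conditioning ids, concatenated with the pooled text embedding), here is a sketch of the `added_cond_kwargs` it expects; the tensor shapes follow the usual SDXL convention and are assumptions about the surrounding pipeline:

import torch

pooled_text_embeds = torch.randn(1, 1280)  # pooled output of the second text encoder
# (original_h, original_w, crop_top, crop_left, target_h, target_w)
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]], dtype=torch.float32)
added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": time_ids}
# later consumed by get_aug_embed via self.add_time_proj / self.add_embedding
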
+    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+        if attention_type in ["gated", "gated-text-image"]:
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, (list, tuple)):
+                positive_len = cross_attention_dim[0]
+
+            feature_type = "text-only" if attention_type == "gated" else "text-image"
+            self.position_net = GLIGENTextBoundingboxProjection(
+                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+            )
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor)
+
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
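
A short usage sketch of the two memory/processor knobs defined above, assuming `unet` is an instance of this class:

unet.set_default_attn_processor()  # drop any custom (e.g. LoRA) attention processors
unet.set_attention_slice("auto")   # compute attention in two slices per head to save memory
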
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+        The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        for i, upsample_block in enumerate(self.up_blocks):
+            setattr(upsample_block, "s1", s1)
+            setattr(upsample_block, "s2", s2)
+            setattr(upsample_block, "b1", b1)
+            setattr(upsample_block, "b2", b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism."""
+        freeu_keys = {"s1", "s2", "b1", "b2"}
+        for i, upsample_block in enumerate(self.up_blocks):
+            for k in freeu_keys:
+                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+                    setattr(upsample_block, k, None)
+
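
For illustration, enabling FreeU on this model with the values the FreeU authors report for SDXL (the specific numbers come from the FreeU repository, not from this diff):

unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
# ... run denoising ...
unet.disable_freeu()  # resets s1/s2/b1/b2 on every up block
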
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
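
A usage sketch for the experimental fused projections; the only hard requirement visible above is that no added-KV attention processors are in use:

unet.fuse_qkv_projections()    # fuses q/k/v matrices and installs FusedAttnProcessor2_0
# ... run inference ...
unet.unfuse_qkv_projections()  # restores the processors saved before fusing
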
+    def get_time_embed(
+        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+    ) -> Optional[torch.Tensor]:
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        return t_emb
+
+    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+        class_emb = None
+        if self.class_embedding is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+        return class_emb
+
+    def get_aug_embed(
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> Optional[torch.Tensor]:
+        aug_emb = None
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb = self.add_embedding(image_embs, hint)
+        return aug_emb
+
+    def process_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> torch.Tensor:
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 1029 |
+
)
|
| 1030 |
+
|
| 1031 |
+
if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
|
| 1032 |
+
encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
|
| 1033 |
+
|
| 1034 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 1035 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
| 1036 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
| 1037 |
+
return encoder_hidden_states
|
| 1038 |
+
|
| 1039 |
+
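    # For SDXL-style models (`addition_embed_type == "text_time"`), callers are expected
    # to supply both conditioning tensors via `added_cond_kwargs`. A hypothetical shape
    # sketch (names and shapes are illustrative, not taken verbatim from this repo):
    #
    #     added_cond_kwargs = {
    #         "text_embeds": pooled_prompt_embeds,  # (batch, pooled_dim)
    #         "time_ids": add_time_ids,             # (batch, 6), flattened into add_time_proj
    #     }
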
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_up_blocks: bool = False,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s).
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        # import time
        # torch.cuda.synchronize()
        # start_time = time.time()

        for dim in sample.shape[-2:]:
            if dim % default_overall_up_factor != 0:
                # Forward upsample size to force interpolation output size.
                forward_upsample_size = True
                break

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)
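            # Worked example (illustrative): a mask [1, 1, 0] becomes the bias
            # [0.0, 0.0, -10000.0]; added to the attention logits before softmax,
            # the discarded key token receives ~zero attention weight.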
        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
        is_adapter = down_intrablock_additional_residuals is not None
        # maintain backward compatibility for legacy usage, where
        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
        # but can only use one or the other
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
                and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
                for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        # torch.cuda.synchronize()
        # logger.info(f"unet preprocess: {time.time() - start_time}")

        # torch.cuda.synchronize()
        # start_time = time.time()
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples
        # torch.cuda.synchronize()
        # logger.info(f"unet down time: {time.time() - start_time}")
        # torch.cuda.synchronize()
        # start_time = time.time()
        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)

            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_intrablock_additional_residuals) > 0
                and sample.shape == down_intrablock_additional_residuals[0].shape
            ):
                sample += down_intrablock_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual
        # torch.cuda.synchronize()
        # logger.info(f"unet mid time: {time.time() - start_time}")
        mid_sample = sample

        if use_up_blocks:
            # 5. up
            up_block_res_samples = ()
            for i, upsample_block in enumerate(self.up_blocks):
                is_final_block = i == len(self.up_blocks) - 1

                res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
                down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

                # if we have not reached the final block and need to forward the
                # upsample size, we do it here
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        upsample_size=upsample_size,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                    )
                else:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        upsample_size=upsample_size,
                    )
                up_block_res_samples += (sample,)

        # # 6. post-process
        # if self.conv_norm_out:
        #     sample = self.conv_norm_out(sample)
        #     sample = self.conv_act(sample)
        #     sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            if use_up_blocks:
                return (mid_sample, down_block_res_samples, up_block_res_samples)
            else:
                return (mid_sample, down_block_res_samples)

        return UNet2DConditionOutput(sample=sample)
Reward_sdxl_idealized/pipelines/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (287 Bytes).

Reward_sdxl_idealized/pipelines/sdxl_gradient_ascent_pipeline.py
ADDED
@@ -0,0 +1,375 @@
"""
|
| 2 |
+
Stable Diffusion XL Pipeline with Gradient Ascent Reward Guidance
|
| 3 |
+
|
| 4 |
+
Fully compatible with HuggingFace diffusers StableDiffusionXLPipeline.
|
| 5 |
+
Injects the Kinetic Latent Predictor-Corrector operator split safely
|
| 6 |
+
before the ODE integration step.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
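# Example usage -- a minimal sketch. `MyRewardModel` is a placeholder for any reward
# model exposing the interface assumed by RewardGuidedDiffusion (see
# gradient_ascent_utils); it is not a class defined in this file.
#
#     pipe = StableDiffusionXLGradientAscentPipeline.from_pretrained(
#         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#     ).to("cuda")
#     pipe.set_reward_model(MyRewardModel())
#     pipe.enable_gradient_ascent(grad_timestep_range=(500, 700), num_grad_steps=5)
#     image = pipe("a photo of an astronaut riding a horse").images[0]
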
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

from gradient_ascent_utils import RewardGuidedDiffusion


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescales the `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure.
    (Matches the diffusers implementation.)
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
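# Note: in diffusers this rescaling follows "Common Diffusion Noise Schedules and
# Sample Steps are Flawed" (https://arxiv.org/abs/2305.08891), Sec. 3.4: the CFG
# output is rescaled toward the per-sample std of the text-conditional prediction,
# with `guidance_rescale` interpolating between the rescaled and raw predictions.

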
class StableDiffusionXLGradientAscentPipeline(StableDiffusionXLPipeline):
    """
    SDXL Pipeline with KLPC gradient ascent reward guidance.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        tokenizer,
        tokenizer_2,
        unet,
        scheduler,
        image_encoder=None,
        feature_extractor=None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

        # Initialize our KLPC custom state variables
        self.gradient_ascent_enabled = False
        self.grad_guidance = None
        self.reward_model = None
        self.reward_history = []

    def set_reward_model(self, reward_model):
        self.reward_model = reward_model

    def enable_gradient_ascent(
        self,
        grad_timestep_range: Tuple[int, int] = (500, 700),
        grad_scale: float = 1.0,
        num_grad_steps: int = 5,
        grad_step_size: float = 0.1,
        lr_scheduler_type: str = "constant",
        lr_scheduler_kwargs: Optional[dict] = None,
        use_momentum: bool = False,
        momentum: float = 0.9,
        use_nesterov: bool = False,
        use_iso_projection: bool = False,
    ):
        if self.reward_model is None:
            raise ValueError("Reward model must be set first via set_reward_model().")

        self.grad_guidance = RewardGuidedDiffusion(
            reward_model=self.reward_model,
            grad_scale=grad_scale,
            grad_timestep_range=grad_timestep_range,
            num_grad_steps=num_grad_steps,
            grad_step_size=grad_step_size,
            lr_scheduler_type=lr_scheduler_type,
            lr_scheduler_kwargs=lr_scheduler_kwargs or {},
            use_momentum=use_momentum,
            momentum=momentum,
            use_nesterov=use_nesterov,
            use_iso_projection=use_iso_projection,
        )
        self.gradient_ascent_enabled = True

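    # Scheduling note (inferred from the `should_apply_gradient(t)` gate used in
    # __call__ below): the corrector only fires while the scheduler timestep t lies
    # inside `grad_timestep_range`; e.g. (500, 700) restricts reward ascent to
    # mid-denoising, where latents carry enough structure to score but are still
    # malleable.
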
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[Any] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], Dict]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # Custom parameters for KLPC gradient ascent
        track_rewards: bool = True,
        apply_gradient_ascent: bool = True,
        verbose_grad: bool = False,
        **kwargs,
    ):
        # 1. Set up sizes and batch
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 2. Define batch size
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=clip_skip,
        )

        # Keep HF internal state synchronized in case external callbacks need it
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Reset tracking
        self.reward_history = []
        if self.grad_guidance is not None:
            self.grad_guidance.reset_statistics()

        # Extract string prompt for the reward model forward pass
        prompt_str = prompt[0] if isinstance(prompt, list) else prompt

        # 9. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # =========================================================================
                # KINETIC LATENT PREDICTOR-CORRECTOR (KLPC) OPERATOR SPLIT
                # =========================================================================
                # This executes the spatial displacement C_i(z) right BEFORE the scheduler
                # integrates using the "stale" velocity (noise_pred), fulfilling Theorem 1.
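                # In update form, the corrector below is assumed to perform (a sketch of
                # what RewardGuidedDiffusion.apply_gradient_ascent does given the knobs set
                # in enable_gradient_ascent, not a verbatim copy of gradient_ascent_utils):
                #
                #     for _ in range(num_grad_steps):
                #         z = z + lr(t) * grad_z reward(z, prompt, t)
                #
                # optionally with momentum/Nesterov and an isotropic projection.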
                if (
                    self.gradient_ascent_enabled
                    and apply_gradient_ascent
                    and self.grad_guidance
                    and self.grad_guidance.should_apply_gradient(t.item())
                    and prompt_str is not None
                ):
                    with torch.enable_grad():
                        latents, grad_stats = self.grad_guidance.apply_gradient_ascent(
                            latents,
                            prompt_str,
                            t.item(),
                            base_noise=None,
                            verbose=verbose_grad,
                            total_denoising_steps=num_inference_steps,
                        )

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # Track rewards
                if track_rewards and self.reward_model is not None and prompt_str is not None:
                    with torch.no_grad():
                        score = self.reward_model.get_reward_score(latents, prompt_str, t.item())
                        score_val = score.item() if score.numel() == 1 else score.mean().item()
                        self.reward_history.append({'timestep': t.item(), 'reward_score': score_val})
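                # reward_history thus holds one {'timestep', 'reward_score'} dict per step,
                # scored on the pre-integration latents, so it can be plotted afterwards to
                # check that the corrector actually raises the reward inside its window.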
                # =========================================================================

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    latents = latents.to(latents_dtype)

                # Execute the diffusers-native step-end callback if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type != "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype:
                # HF-native fix: cast the VAE to match the latents dtype
                self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None

            if has_latents_mean and has_latents_std:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
Reward_sdxl_idealized/timestep_convergence_analysis.ipynb
ADDED
@@ -0,0 +1,1105 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "513b682a",
   "metadata": {},
   "source": [
    "#### Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9f44d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import warnings\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel\n",
    "from transformers import CLIPModel, CLIPProcessor\n",
    "from torchmetrics.image.fid import FrechetInceptionDistance\n",
    "from torchmetrics.multimodal import CLIPScore\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Import local modules\n",
    "from models import LRMRewardModel\n",
    "from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline\n",
    "from grad_ascent_configs import get_config, list_configs\n",
    "\n",
    "# Import evaluation metrics\n",
    "sys.path.append('../evaluation')\n",
    "from pick_score import PickScorer\n",
    "from hpsv2_score import HPSv2Scorer\n",
    "from imagereward_score import load_imagereward\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1740dd7c",
   "metadata": {},
   "source": [
    "#### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1bc2b07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============ CONFIGURATION ============\n",
    "\n",
    "# Dataset\n",
    "DATA_DIR = \"./data\"\n",
    "DATASET_TYPE = \"coco\"  # \"coco\" or \"pickapic\"\n",
    "NUM_SAMPLES = 20  # Number of samples to analyze\n",
    "\n",
    "# Model\n",
    "BASE_MODEL = \"runwayml/stable-diffusion-v1-5\"\n",
    "MODEL_VARIANT = \"lpo\"  # \"origin\", \"spo\", \"diffusion_dpo\", \"lpo\"\n",
    "LRM_MODEL = \"casiatao/LRM\"\n",
    "\n",
    "# Generation\n",
    "NUM_INFERENCE_STEPS = 100\n",
    "CFG_SCALE = 5.0\n",
    "SEED = 42\n",
    "BATCH_SIZE = 1\n",
    "\n",
    "# Gradient Ascent Config\n",
    "GRAD_CONFIG = \"low_to_high_nesterov\"  # Use None for manual config, or specify preset name\n",
    "GRAD_RANGE_START = 0\n",
    "GRAD_RANGE_END = 500\n",
    "GRAD_STEPS = 1\n",
    "GRAD_STEP_SIZE = 0.1\n",
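    "# (The GRAD_* values above presumably mirror grad_timestep_range, num_grad_steps,\n",
    "# and grad_step_size in enable_gradient_ascent; if GRAD_CONFIG names a preset, it\n",
    "# is assumed to override them via get_config.)\n",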
"\n",
|
| 89 |
+
"# Metrics to compute\n",
|
| 90 |
+
"METRICS = [\"reward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\", \"fid\"] # Add/remove as needed\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"# Device\n",
|
| 93 |
+
"CUDA_DEVICE = 0\n",
|
| 94 |
+
"device = f\"cuda:{CUDA_DEVICE}\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 95 |
+
"dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"# Output\n",
|
| 98 |
+
"OUTPUT_DIR = \"timestep_analysis_results\"\n",
|
| 99 |
+
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"print(f\"Device: {device}\")\n",
|
| 102 |
+
"print(f\"Dataset: {DATASET_TYPE}\")\n",
|
| 103 |
+
"print(f\"Samples to analyze: {NUM_SAMPLES}\")\n",
|
| 104 |
+
"print(f\"Metrics: {METRICS}\")\n",
|
| 105 |
+
"print(f\"Output directory: {OUTPUT_DIR}\")"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"id": "1b1b6d02",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"source": [
|
| 113 |
+
"#### Load Dataset"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": null,
|
| 119 |
+
"id": "a74b2816",
|
| 120 |
+
"metadata": {},
|
| 121 |
+
"outputs": [],
|
| 122 |
+
"source": [
|
| 123 |
+
"def load_validation_data(data_dir, max_samples=None):\n",
|
| 124 |
+
" \"\"\"Load COCO validation prompts and image paths.\"\"\"\n",
|
| 125 |
+
" data_dir = Path(data_dir)\n",
|
| 126 |
+
" val_json = data_dir / \"coco\" / \"caption_val.json\"\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" if not val_json.exists():\n",
|
| 129 |
+
" raise FileNotFoundError(f\"Validation data not found at {val_json}\")\n",
|
| 130 |
+
" \n",
|
| 131 |
+
" with open(val_json, 'r') as f:\n",
|
| 132 |
+
" data = json.load(f)\n",
|
| 133 |
+
" \n",
|
| 134 |
+
" print(f\"Loaded JSON with {len(data)} entries\")\n",
|
| 135 |
+
" \n",
|
| 136 |
+
" # Validate that image folder exists\n",
|
| 137 |
+
" val_img_dir = data_dir / \"coco\" / \"images\" / \"val\"\n",
|
| 138 |
+
" if not val_img_dir.exists():\n",
|
| 139 |
+
" print(f\"Warning: Standard validation directory not found: {val_img_dir}\")\n",
|
| 140 |
+
" \n",
|
| 141 |
+
" # Parse data - img_path already contains \"images/val/\" prefix\n",
|
| 142 |
+
" prompts = []\n",
|
| 143 |
+
" image_paths = []\n",
|
| 144 |
+
" \n",
|
| 145 |
+
" for img_path, caption in data.items():\n",
|
| 146 |
+
" # Try the path as given (relative to data_dir/coco/)\n",
|
| 147 |
+
" full_path = data_dir / \"coco\" / img_path\n",
|
| 148 |
+
" if full_path.exists():\n",
|
| 149 |
+
" prompts.append(caption)\n",
|
| 150 |
+
" image_paths.append(str(full_path))\n",
|
| 151 |
+
" \n",
|
| 152 |
+
" print(f\"Found {len(prompts)} valid image-caption pairs\")\n",
|
| 153 |
+
" \n",
|
| 154 |
+
" if len(prompts) == 0:\n",
|
| 155 |
+
" print(f\"\\n⚠ WARNING: No valid images found!\")\n",
|
| 156 |
+
" print(f\"Debug information:\")\n",
|
| 157 |
+
" print(f\" JSON file: {val_json}\")\n",
|
| 158 |
+
" print(f\" JSON entries: {len(data)}\")\n",
|
| 159 |
+
" print(f\" Sample keys from JSON: {list(data.keys())[:3]}\")\n",
|
| 160 |
+
" \n",
|
| 161 |
+
" # Check if images exist at all\n",
|
| 162 |
+
" coco_dir = data_dir / \"coco\"\n",
|
| 163 |
+
" if coco_dir.exists():\n",
|
| 164 |
+
" print(f\" COCO dir exists: {coco_dir}\")\n",
|
| 165 |
+
" # List subdirectories\n",
|
| 166 |
+
" subdirs = [d.name for d in coco_dir.iterdir() if d.is_dir()]\n",
|
| 167 |
+
" print(f\" Subdirectories in COCO: {subdirs}\")\n",
|
| 168 |
+
" \n",
|
| 169 |
+
" # Try to find images\n",
|
| 170 |
+
" if val_img_dir.exists():\n",
|
| 171 |
+
" img_files = list(val_img_dir.glob(\"*.jpg\"))[:5]\n",
|
| 172 |
+
" print(f\" Sample images in val dir: {[f.name for f in img_files]}\")\n",
|
| 173 |
+
" \n",
|
| 174 |
+
" if max_samples and len(prompts) > 0:\n",
|
| 175 |
+
" prompts = prompts[:max_samples]\n",
|
| 176 |
+
" image_paths = image_paths[:max_samples]\n",
|
| 177 |
+
" \n",
|
| 178 |
+
" return prompts, image_paths\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"# Load data\n",
|
| 181 |
+
"prompts, image_paths = load_validation_data(DATA_DIR, NUM_SAMPLES)\n",
|
| 182 |
+
"print(f\"\\n✓ Loaded {len(prompts)} samples\")\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"if len(prompts) > 0:\n",
|
| 185 |
+
" print(f\"\\nSample prompts:\")\n",
|
| 186 |
+
" for i, prompt in enumerate(prompts[:3]):\n",
|
| 187 |
+
" print(f\" {i+1}. {prompt[:80]}...\")\n",
|
| 188 |
+
" print(f\"\\nSample image paths:\")\n",
|
| 189 |
+
" for i, path in enumerate(image_paths[:3]):\n",
|
| 190 |
+
" print(f\" {i+1}. {path}\")\n",
|
| 191 |
+
"else:\n",
|
| 192 |
+
" print(\"\\n❌ ERROR: No samples loaded! Please check your data directory structure.\")\n",
|
| 193 |
+
" print(\"Expected structure:\")\n",
|
| 194 |
+
" print(\" ./data/coco/caption_val.json\")\n",
|
| 195 |
+
" print(\" ./data/coco/images/val/*.jpg\")"
|
| 196 |
+
]
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"cell_type": "markdown",
|
| 200 |
+
"id": "5ceae64a",
|
| 201 |
+
"metadata": {},
|
| 202 |
+
"source": [
|
| 203 |
+
"#### Load Models and Scorers"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"id": "43ad1f56",
|
| 210 |
+
"metadata": {},
|
| 211 |
+
"outputs": [],
|
| 212 |
+
"source": [
|
| 213 |
+
"# ============ MLP for Aesthetic Scoring ============\n",
|
| 214 |
+
"class MLP(nn.Module):\n",
|
| 215 |
+
" def __init__(self):\n",
|
| 216 |
+
" super().__init__()\n",
|
| 217 |
+
" self.layers = nn.Sequential(\n",
|
| 218 |
+
" nn.Linear(768, 1024),\n",
|
| 219 |
+
" nn.Dropout(0.2),\n",
|
| 220 |
+
" nn.Linear(1024, 128),\n",
|
| 221 |
+
" nn.Dropout(0.2),\n",
|
| 222 |
+
" nn.Linear(128, 64),\n",
|
| 223 |
+
" nn.Dropout(0.1),\n",
|
| 224 |
+
" nn.Linear(64, 16),\n",
|
| 225 |
+
" nn.Linear(16, 1),\n",
|
| 226 |
+
" )\n",
|
| 227 |
+
" \n",
|
| 228 |
+
" @torch.no_grad()\n",
|
| 229 |
+
" def forward(self, embed):\n",
|
| 230 |
+
" return self.layers(embed)\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"class AestheticScorer(torch.nn.Module):\n",
|
| 233 |
+
" def __init__(self, dtype, device):\n",
|
| 234 |
+
" super().__init__()\n",
|
| 235 |
+
" self.clip = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
| 236 |
+
" self.processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
| 237 |
+
" self.mlp = MLP()\n",
|
| 238 |
+
" \n",
|
| 239 |
+
" aesthetic_path = \"../evaluation/sac+logos+ava1-l14-linearMSE.pth\"\n",
|
| 240 |
+
" if os.path.exists(aesthetic_path):\n",
|
| 241 |
+
" state_dict = torch.load(aesthetic_path, map_location='cpu')\n",
|
| 242 |
+
" self.mlp.load_state_dict(state_dict)\n",
|
| 243 |
+
" \n",
|
| 244 |
+
" self.dtype = dtype\n",
|
| 245 |
+
" self.to(device)\n",
|
| 246 |
+
" self.eval()\n",
|
| 247 |
+
" \n",
|
| 248 |
+
" @torch.no_grad()\n",
|
| 249 |
+
" def __call__(self, images):\n",
|
| 250 |
+
" if not isinstance(images, list):\n",
|
| 251 |
+
" images = [images]\n",
|
| 252 |
+
" inputs = self.processor(images=images, return_tensors=\"pt\", padding=True)\n",
|
| 253 |
+
" inputs = {k: v.to(self.clip.device) for k, v in inputs.items()}\n",
|
| 254 |
+
" image_embeds = self.clip.get_image_features(**inputs)\n",
|
| 255 |
+
" image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)\n",
|
| 256 |
+
" scores = self.mlp(image_embeds.float())\n",
|
| 257 |
+
" return scores.squeeze().cpu().numpy()\n",
|
| 258 |
+
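    "# (For reference: this appears to be the LAION aesthetic predictor head, the\n",
    "# sac+logos+ava1-l14-linearMSE MLP checkpoint loaded above, applied on top of\n",
    "# L2-normalized CLIP ViT-L/14 image embeddings.)\n",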
"\n",
|
| 259 |
+
"print(\"Loading models...\")"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "code",
|
| 264 |
+
"execution_count": null,
|
| 265 |
+
"id": "36a70595",
|
| 266 |
+
"metadata": {},
|
| 267 |
+
"outputs": [],
|
| 268 |
+
"source": [
|
| 269 |
+
"# Load Reward Model\n",
|
| 270 |
+
"print(\"Loading reward model...\")\n",
|
| 271 |
+
"reward_model = LRMRewardModel(\n",
|
| 272 |
+
" pretrained_model_name_or_path=BASE_MODEL,\n",
|
| 273 |
+
" lrm_model_path=LRM_MODEL,\n",
|
| 274 |
+
" guidance_scale=CFG_SCALE,\n",
|
| 275 |
+
" device=device\n",
|
| 276 |
+
")\n",
|
| 277 |
+
"if dtype == torch.float16:\n",
|
| 278 |
+
" reward_model = reward_model.half()\n",
|
| 279 |
+
"reward_model.eval()\n",
|
| 280 |
+
"print(\"✓ Reward model loaded\")\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"# Load Pipeline\n",
|
| 283 |
+
"print(\"\\nLoading diffusion pipeline...\")\n",
|
| 284 |
+
"if MODEL_VARIANT == \"origin\":\n",
|
| 285 |
+
" base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
|
| 286 |
+
" BASE_MODEL, torch_dtype=dtype, safety_checker=None\n",
|
| 287 |
+
" )\n",
|
| 288 |
+
"elif MODEL_VARIANT == \"spo\":\n",
|
| 289 |
+
" base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
|
| 290 |
+
" 'SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep',\n",
|
| 291 |
+
" torch_dtype=dtype, safety_checker=None\n",
|
| 292 |
+
" )\n",
|
| 293 |
+
" CFG_SCALE = 5.0\n",
|
| 294 |
+
"elif MODEL_VARIANT == \"diffusion_dpo\":\n",
|
| 295 |
+
" unet = UNet2DConditionModel.from_pretrained(\n",
|
| 296 |
+
" 'mhdang/dpo-sd1.5-text2image-v1', subfolder=\"unet\", torch_dtype=dtype\n",
|
| 297 |
+
" )\n",
|
| 298 |
+
" base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
|
| 299 |
+
" BASE_MODEL, torch_dtype=dtype, safety_checker=None, unet=unet\n",
|
| 300 |
+
" )\n",
|
| 301 |
+
"elif MODEL_VARIANT == \"lpo\":\n",
|
| 302 |
+
" unet = UNet2DConditionModel.from_pretrained(\n",
|
| 303 |
+
" 'casiatao/LPO', subfolder=\"lpo_sd15_merge/unet\", torch_dtype=dtype\n",
|
| 304 |
+
" )\n",
|
| 305 |
+
" base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
|
| 306 |
+
" BASE_MODEL, torch_dtype=dtype, safety_checker=None, unet=unet\n",
|
| 307 |
+
" )\n",
|
| 308 |
+
" CFG_SCALE = 5.0\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"pipeline = StableDiffusionGradientAscentPipeline(**base_pipeline.components)\n",
|
| 311 |
+
"pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\n",
|
| 312 |
+
"pipeline = pipeline.to(device)\n",
|
| 313 |
+
"pipeline.set_reward_model(reward_model)\n",
|
| 314 |
+
"print(\"✓ Pipeline loaded\")"
|
| 315 |
+
]
|
| 316 |
+
},
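`StableDiffusionGradientAscentPipeline(**base_pipeline.components)` rebuilds a new pipeline class around the modules that are already in memory (shared weights, no re-download). The same pattern works with any stock diffusers pipelines; a sketch:

```python
# Sketch of the components-reuse pattern: wrap a second pipeline class
# around the same loaded UNet/VAE/text encoder.
from diffusers import (DDIMScheduler, StableDiffusionImg2ImgPipeline,
                       StableDiffusionPipeline)

base = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
)
img2img = StableDiffusionImg2ImgPipeline(**base.components)
img2img.scheduler = DDIMScheduler.from_config(img2img.scheduler.config)
```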
|
| 317 |
+
{
|
| 318 |
+
"cell_type": "code",
|
| 319 |
+
"execution_count": null,
|
| 320 |
+
"id": "4e18f075",
|
| 321 |
+
"metadata": {},
|
| 322 |
+
"outputs": [],
|
| 323 |
+
"source": [
|
| 324 |
+
"# Load Metric Scorers\n",
|
| 325 |
+
"print(\"\\nLoading metric scorers...\")\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"clip_scorer = None\n",
|
| 328 |
+
"aesthetic_scorer = None\n",
|
| 329 |
+
"pick_scorer = None\n",
|
| 330 |
+
"hpsv2_scorer = None\n",
|
| 331 |
+
"imagereward_scorer = None\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"if \"clip\" in METRICS:\n",
|
| 334 |
+
" print(\" Loading CLIP scorer...\")\n",
|
| 335 |
+
" clip_scorer = CLIPScore(model_name_or_path=\"openai/clip-vit-base-patch16\").to(device)\n",
|
| 336 |
+
" print(\" ✓ CLIP scorer loaded\")\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"if \"aesthetic\" in METRICS:\n",
|
| 339 |
+
" print(\" Loading Aesthetic scorer...\")\n",
|
| 340 |
+
" aesthetic_scorer = AestheticScorer(dtype, device)\n",
|
| 341 |
+
" print(\" ✓ Aesthetic scorer loaded\")\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"if \"pickscore\" in METRICS:\n",
|
| 344 |
+
" print(\" Loading PickScore scorer...\")\n",
|
| 345 |
+
" try:\n",
|
| 346 |
+
" pick_scorer = PickScorer(device=device, dtype=dtype)\n",
|
| 347 |
+
" print(\" ✓ PickScore loaded\")\n",
|
| 348 |
+
" except Exception as e:\n",
|
| 349 |
+
" print(f\" ✗ PickScore failed: {e}\")\n",
|
| 350 |
+
" METRICS.remove(\"pickscore\")\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"if \"hpsv2\" in METRICS:\n",
|
| 353 |
+
" print(\" Loading HPSv2 scorer...\")\n",
|
| 354 |
+
" try:\n",
|
| 355 |
+
" hpsv2_scorer = HPSv2Scorer(device=device, dtype=dtype)\n",
|
| 356 |
+
" print(\" ✓ HPSv2 loaded\")\n",
|
| 357 |
+
" except Exception as e:\n",
|
| 358 |
+
" print(f\" ✗ HPSv2 failed: {e}\")\n",
|
| 359 |
+
" METRICS.remove(\"hpsv2\")\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"if \"imagereward\" in METRICS:\n",
|
| 362 |
+
" print(\" Loading ImageReward scorer...\")\n",
|
| 363 |
+
" try:\n",
|
| 364 |
+
" imagereward_scorer = load_imagereward(device=device)\n",
|
| 365 |
+
" print(\" ✓ ImageReward loaded\")\n",
|
| 366 |
+
" except Exception as e:\n",
|
| 367 |
+
" print(f\" ✗ ImageReward failed: {e}\")\n",
|
| 368 |
+
" METRICS.remove(\"imagereward\")\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"print(f\"\\n✓ Active metrics: {METRICS}\")"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "markdown",
|
| 375 |
+
"id": "70ac047b",
|
| 376 |
+
"metadata": {},
|
| 377 |
+
"source": [
|
| 378 |
+
"#### Configure Gradient Ascent"
|
| 379 |
+
]
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"cell_type": "code",
|
| 383 |
+
"execution_count": null,
|
| 384 |
+
"id": "05996448",
|
| 385 |
+
"metadata": {},
|
| 386 |
+
"outputs": [],
|
| 387 |
+
"source": [
|
| 388 |
+
"# Configure gradient ascent\n",
|
| 389 |
+
"if GRAD_CONFIG:\n",
|
| 390 |
+
" print(f\"Loading gradient ascent config: {GRAD_CONFIG}\")\n",
|
| 391 |
+
" grad_config = get_config(GRAD_CONFIG)\n",
|
| 392 |
+
" print(f\"Config: {grad_config}\")\n",
|
| 393 |
+
"else:\n",
|
| 394 |
+
" grad_config = {\n",
|
| 395 |
+
" \"grad_timestep_range\": (GRAD_RANGE_START, GRAD_RANGE_END),\n",
|
| 396 |
+
" \"num_grad_steps\": GRAD_STEPS,\n",
|
| 397 |
+
" \"grad_step_size\": GRAD_STEP_SIZE,\n",
|
| 398 |
+
" }\n",
|
| 399 |
+
" print(f\"Manual gradient ascent configuration: {grad_config}\")\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"pipeline.enable_gradient_ascent(**grad_config)\n",
|
| 402 |
+
"print(\"\\n✓ Gradient ascent enabled\")"
|
| 403 |
+
]
|
| 404 |
+
},
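`grad_timestep_range` gates which denoising timesteps receive reward-gradient updates. The real gating lives in `should_apply_gradient` (in `gradient_ascent_utils.py`, not shown in this section); a plausible minimal reading, assuming a simple inclusive band:

```python
# Hedged sketch of timestep gating; the real should_apply_gradient may do
# more (e.g. per-step budgeting).
def in_grad_range(t: float, grad_timestep_range=(200, 800)) -> bool:
    lo, hi = sorted(grad_timestep_range)
    return lo <= t <= hi

assert in_grad_range(500, (200, 800))       # mid-denoising step: nudged
assert not in_grad_range(950, (200, 800))   # very noisy step: skipped
```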
|
| 405 |
+
{
|
| 406 |
+
"cell_type": "markdown",
|
| 407 |
+
"id": "1f82c3df",
|
| 408 |
+
"metadata": {},
|
| 409 |
+
"source": [
|
| 410 |
+
"#### Timestep Analysis Functions"
|
| 411 |
+
]
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
"cell_type": "code",
|
| 415 |
+
"execution_count": null,
|
| 416 |
+
"id": "e836d8f2",
|
| 417 |
+
"metadata": {},
|
| 418 |
+
"outputs": [],
|
| 419 |
+
"source": [
|
| 420 |
+
"def latents_to_images(latents, vae):\n",
|
| 421 |
+
" \"\"\"Convert latents to PIL images.\"\"\"\n",
|
| 422 |
+
" latents = 1 / 0.18215 * latents\n",
|
| 423 |
+
" with torch.no_grad():\n",
|
| 424 |
+
" images = vae.decode(latents).sample\n",
|
| 425 |
+
" images = (images / 2 + 0.5).clamp(0, 1)\n",
|
| 426 |
+
" images = images.cpu().permute(0, 2, 3, 1).numpy()\n",
|
| 427 |
+
" images = (images * 255).round().astype(\"uint8\")\n",
|
| 428 |
+
" pil_images = [Image.fromarray(image) for image in images]\n",
|
| 429 |
+
" return pil_images\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"def compute_metrics_for_image(image, prompt, reference_image=None):\n",
|
| 433 |
+
" \"\"\"Compute all metrics for a single image.\"\"\"\n",
|
| 434 |
+
" metrics = {}\n",
|
| 435 |
+
" \n",
|
| 436 |
+
" # CLIP Score\n",
|
| 437 |
+
" if clip_scorer is not None:\n",
|
| 438 |
+
" img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
|
| 439 |
+
" with torch.no_grad():\n",
|
| 440 |
+
" clip_score = clip_scorer(img_tensor, prompt).item()\n",
|
| 441 |
+
" metrics['clip'] = clip_score\n",
|
| 442 |
+
" \n",
|
| 443 |
+
" # Aesthetic Score\n",
|
| 444 |
+
" if aesthetic_scorer is not None:\n",
|
| 445 |
+
" aesthetic_score = aesthetic_scorer([image])\n",
|
| 446 |
+
" if isinstance(aesthetic_score, np.ndarray):\n",
|
| 447 |
+
" aesthetic_score = aesthetic_score.item()\n",
|
| 448 |
+
" metrics['aesthetic'] = aesthetic_score\n",
|
| 449 |
+
" \n",
|
| 450 |
+
" # PickScore\n",
|
| 451 |
+
" if pick_scorer is not None:\n",
|
| 452 |
+
" pick_score = pick_scorer.score(prompt, [image])[0]\n",
|
| 453 |
+
" metrics['pickscore'] = pick_score\n",
|
| 454 |
+
" \n",
|
| 455 |
+
" # HPSv2\n",
|
| 456 |
+
" if hpsv2_scorer is not None:\n",
|
| 457 |
+
" hpsv2_score = hpsv2_scorer.score(prompt, [image])[0]\n",
|
| 458 |
+
" metrics['hpsv2'] = hpsv2_score\n",
|
| 459 |
+
" \n",
|
| 460 |
+
" # ImageReward\n",
|
| 461 |
+
" if imagereward_scorer is not None:\n",
|
| 462 |
+
" imagereward_score = imagereward_scorer.score(prompt, [image])[0]\n",
|
| 463 |
+
" metrics['imagereward'] = imagereward_score\n",
|
| 464 |
+
" \n",
|
| 465 |
+
" # FID (if reference image provided)\n",
|
| 466 |
+
" if reference_image is not None:\n",
|
| 467 |
+
" try:\n",
|
| 468 |
+
" fid_metric = FrechetInceptionDistance(normalize=True).to(device)\n",
|
| 469 |
+
" \n",
|
| 470 |
+
" # Process reference image\n",
|
| 471 |
+
" ref_img = Image.open(reference_image).convert('RGB').resize((299, 299))\n",
|
| 472 |
+
" ref_tensor = torch.from_numpy(np.array(ref_img)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
|
| 473 |
+
" \n",
|
| 474 |
+
" # Process generated image\n",
|
| 475 |
+
" gen_img = image.resize((299, 299))\n",
|
| 476 |
+
" gen_tensor = torch.from_numpy(np.array(gen_img)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
|
| 477 |
+
" \n",
|
| 478 |
+
" if ref_tensor.size(0) == 1:\n",
|
| 479 |
+
" ref_tensor = ref_tensor.repeat(2, 1, 1, 1)\n",
|
| 480 |
+
" if gen_tensor.size(0) == 1:\n",
|
| 481 |
+
" gen_tensor = gen_tensor.repeat(2, 1, 1, 1)\n",
|
| 482 |
+
" \n",
|
| 483 |
+
" fid_metric.update(ref_tensor, real=True)\n",
|
| 484 |
+
" fid_metric.update(gen_tensor, real=False)\n",
|
| 485 |
+
" \n",
|
| 486 |
+
" fid_score = fid_metric.compute().item()/10\n",
|
| 487 |
+
" metrics['fid'] = fid_score\n",
|
| 488 |
+
" except Exception as e:\n",
|
| 489 |
+
" print(f\"FID computation failed: {e}\")\n",
|
| 490 |
+
" \n",
|
| 491 |
+
" return metrics\n",
|
| 492 |
+
"\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"def analyze_sample_timesteps(prompt, reference_image, sample_idx):\n",
|
| 495 |
+
" \"\"\"\n",
|
| 496 |
+
" Generate images and track metrics at each timestep.\n",
|
| 497 |
+
" Returns timestep-wise metrics and intermediate images.\n",
|
| 498 |
+
" \"\"\"\n",
|
| 499 |
+
" print(f\"\\n{'='*70}\")\n",
|
| 500 |
+
" print(f\"Analyzing Sample {sample_idx + 1}\")\n",
|
| 501 |
+
" print(f\"Prompt: {prompt[:80]}...\")\n",
|
| 502 |
+
" print(f\"{'='*70}\")\n",
|
| 503 |
+
" \n",
|
| 504 |
+
" # Storage for results\n",
|
| 505 |
+
" timestep_metrics = {\n",
|
| 506 |
+
" 'timesteps': [],\n",
|
| 507 |
+
" 'reward': [],\n",
|
| 508 |
+
" 'clip': [],\n",
|
| 509 |
+
" 'aesthetic': [],\n",
|
| 510 |
+
" 'pickscore': [],\n",
|
| 511 |
+
" 'hpsv2': [],\n",
|
| 512 |
+
" 'imagereward': [],\n",
|
| 513 |
+
" 'fid': []\n",
|
| 514 |
+
" }\n",
|
| 515 |
+
" intermediate_images = []\n",
|
| 516 |
+
" \n",
|
| 517 |
+
" # Reset gradient stats\n",
|
| 518 |
+
" if hasattr(pipeline, 'grad_guidance'):\n",
|
| 519 |
+
" pipeline.grad_guidance.reset_statistics()\n",
|
| 520 |
+
" \n",
|
| 521 |
+
" # Modified pipeline call to capture intermediate latents\n",
|
| 522 |
+
" generator = torch.Generator(device=device).manual_seed(SEED + sample_idx)\n",
|
| 523 |
+
" \n",
|
| 524 |
+
" # We'll manually step through the denoising process\n",
|
| 525 |
+
" pipeline.set_progress_bar_config(disable=True)\n",
|
| 526 |
+
" \n",
|
| 527 |
+
" # Prepare inputs\n",
|
| 528 |
+
" height = pipeline.unet.config.sample_size * pipeline.vae_scale_factor\n",
|
| 529 |
+
" width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor\n",
|
| 530 |
+
" \n",
|
| 531 |
+
" # Encode prompt\n",
|
| 532 |
+
" text_embeddings = pipeline._encode_prompt(\n",
|
| 533 |
+
" prompt, device, 1, True, None\n",
|
| 534 |
+
" )\n",
|
| 535 |
+
" \n",
|
| 536 |
+
" # Prepare timesteps\n",
|
| 537 |
+
" pipeline.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=device)\n",
|
| 538 |
+
" timesteps = pipeline.scheduler.timesteps\n",
|
| 539 |
+
" \n",
|
| 540 |
+
" # Prepare latents\n",
|
| 541 |
+
" shape = (1, pipeline.unet.config.in_channels, height // 8, width // 8)\n",
|
| 542 |
+
" latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)\n",
|
| 543 |
+
" latents = latents * pipeline.scheduler.init_noise_sigma\n",
|
| 544 |
+
" \n",
|
| 545 |
+
" # Denoising loop with metric tracking\n",
|
| 546 |
+
" for i, t in enumerate(tqdm(timesteps, desc=\"Denoising steps\")):\n",
|
| 547 |
+
" # Apply gradient ascent if enabled\n",
|
| 548 |
+
" if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:\n",
|
| 549 |
+
" if pipeline.grad_guidance.should_apply_gradient(t.item()):\n",
|
| 550 |
+
" latents, grad_stats = pipeline.grad_guidance.apply_gradient_ascent(\n",
|
| 551 |
+
" latents, prompt, t.item(), verbose=False,\n",
|
| 552 |
+
" total_denoising_steps=len(timesteps)\n",
|
| 553 |
+
" )\n",
|
| 554 |
+
" \n",
|
| 555 |
+
" # Expand latents for classifier free guidance\n",
|
| 556 |
+
" latent_model_input = torch.cat([latents] * 2)\n",
|
| 557 |
+
" latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)\n",
|
| 558 |
+
" \n",
|
| 559 |
+
" # Predict noise\n",
|
| 560 |
+
" with torch.no_grad():\n",
|
| 561 |
+
" noise_pred = pipeline.unet(\n",
|
| 562 |
+
" latent_model_input,\n",
|
| 563 |
+
" t,\n",
|
| 564 |
+
" encoder_hidden_states=text_embeddings,\n",
|
| 565 |
+
" ).sample\n",
|
| 566 |
+
" \n",
|
| 567 |
+
" # Perform guidance\n",
|
| 568 |
+
" noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
|
| 569 |
+
" noise_pred = noise_pred_uncond + CFG_SCALE * (noise_pred_text - noise_pred_uncond)\n",
|
| 570 |
+
" \n",
|
| 571 |
+
" # Compute previous noisy sample\n",
|
| 572 |
+
" latents = pipeline.scheduler.step(noise_pred, t, latents).prev_sample\n",
|
| 573 |
+
" \n",
|
| 574 |
+
" # Decode latents to image every few steps\n",
|
| 575 |
+
" if i % 5 == 0 or i == len(timesteps) - 1:\n",
|
| 576 |
+
" # Convert to image\n",
|
| 577 |
+
" images = latents_to_images(latents, pipeline.vae)\n",
|
| 578 |
+
" image = images[0]\n",
|
| 579 |
+
" \n",
|
| 580 |
+
" # Compute reward\n",
|
| 581 |
+
" with torch.no_grad():\n",
|
| 582 |
+
" reward = reward_model.get_reward_score(latents, prompt, t.item())\n",
|
| 583 |
+
" reward_val = reward.mean().item() if reward.numel() > 1 else reward.item()\n",
|
| 584 |
+
" \n",
|
| 585 |
+
" # Compute other metrics\n",
|
| 586 |
+
" metrics = compute_metrics_for_image(image, prompt, reference_image)\n",
|
| 587 |
+
" \n",
|
| 588 |
+
" # Store results\n",
|
| 589 |
+
" timestep_metrics['timesteps'].append(t.item())\n",
|
| 590 |
+
" timestep_metrics['reward'].append(reward_val)\n",
|
| 591 |
+
" \n",
|
| 592 |
+
" for metric_name in ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'imagereward', 'fid']:\n",
|
| 593 |
+
" if metric_name in metrics:\n",
|
| 594 |
+
" timestep_metrics[metric_name].append(metrics[metric_name])\n",
|
| 595 |
+
" else:\n",
|
| 596 |
+
" timestep_metrics[metric_name].append(None)\n",
|
| 597 |
+
" \n",
|
| 598 |
+
" intermediate_images.append(image)\n",
|
| 599 |
+
" \n",
|
| 600 |
+
" print(f\" Step {i}/{len(timesteps)} | t={t.item():.0f} | Reward={reward_val:.4f}\")\n",
|
| 601 |
+
" \n",
|
| 602 |
+
" # Final image\n",
|
| 603 |
+
" final_images = latents_to_images(latents, pipeline.vae)\n",
|
| 604 |
+
" final_image = final_images[0]\n",
|
| 605 |
+
" \n",
|
| 606 |
+
" pipeline.set_progress_bar_config(disable=False)\n",
|
| 607 |
+
" \n",
|
| 608 |
+
" return timestep_metrics, intermediate_images, final_image\n",
|
| 609 |
+
"\n",
|
| 610 |
+
"print(\"✓ Analysis functions defined\")"
|
| 611 |
+
]
|
| 612 |
+
},
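Inside the denoising loop above, `apply_gradient_ascent` nudges the latents uphill on the reward before the scheduler step. Stripped of the momentum and step-size schedules, the core update is plain gradient ascent on the latents; a sketch, assuming a reward function that is differentiable with respect to the latents:

```python
# Core of the reward gradient-ascent update (momentum/schedules omitted).
# reward_fn stands in for a differentiable reward-model call (assumption).
import torch

def gradient_ascent_step(latents, prompt, t, reward_fn,
                         num_grad_steps=1, step_size=0.01):
    for _ in range(num_grad_steps):
        latents = latents.detach().requires_grad_(True)
        reward = reward_fn(latents, prompt, t).mean()    # scalar reward
        (grad,) = torch.autograd.grad(reward, latents)   # d reward / d latents
        latents = (latents + step_size * grad).detach()  # ascend the reward
    return latents
```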
|
| 613 |
+
{
|
| 614 |
+
"cell_type": "markdown",
|
| 615 |
+
"id": "fc089bfd",
|
| 616 |
+
"metadata": {},
|
| 617 |
+
"source": [
|
| 618 |
+
"#### Run Timestep Analysis"
|
| 619 |
+
]
|
| 620 |
+
},
|
| 621 |
+
{
|
| 622 |
+
"cell_type": "code",
|
| 623 |
+
"execution_count": null,
|
| 624 |
+
"id": "67c25164",
|
| 625 |
+
"metadata": {},
|
| 626 |
+
"outputs": [],
|
| 627 |
+
"source": [
|
| 628 |
+
"# Run analysis for all samples\n",
|
| 629 |
+
"all_results = []\n",
|
| 630 |
+
"\n",
|
| 631 |
+
"for idx in range(len(prompts)):\n",
|
| 632 |
+
" prompt = prompts[idx]\n",
|
| 633 |
+
" reference_image = image_paths[idx]\n",
|
| 634 |
+
" \n",
|
| 635 |
+
" # Analyze this sample\n",
|
| 636 |
+
" metrics, images, final_image = analyze_sample_timesteps(prompt, reference_image, idx)\n",
|
| 637 |
+
" \n",
|
| 638 |
+
" # Store results\n",
|
| 639 |
+
" all_results.append({\n",
|
| 640 |
+
" 'prompt': prompt,\n",
|
| 641 |
+
" 'reference_image': reference_image,\n",
|
| 642 |
+
" 'metrics': metrics,\n",
|
| 643 |
+
" 'intermediate_images': images,\n",
|
| 644 |
+
" 'final_image': final_image\n",
|
| 645 |
+
" })\n",
|
| 646 |
+
" \n",
|
| 647 |
+
" # Save intermediate results\n",
|
| 648 |
+
" sample_dir = Path(OUTPUT_DIR) / f\"sample_{idx+1}\"\n",
|
| 649 |
+
" sample_dir.mkdir(exist_ok=True)\n",
|
| 650 |
+
" \n",
|
| 651 |
+
" # Save final image\n",
|
| 652 |
+
" final_image.save(sample_dir / \"final_image.png\")\n",
|
| 653 |
+
" \n",
|
| 654 |
+
" # Save all intermediate images\n",
|
| 655 |
+
" images_dir = sample_dir / \"intermediate_images\"\n",
|
| 656 |
+
" images_dir.mkdir(exist_ok=True)\n",
|
| 657 |
+
" for img_idx, img in enumerate(images):\n",
|
| 658 |
+
" t_val = metrics['timesteps'][img_idx]\n",
|
| 659 |
+
" img.save(images_dir / f\"step_{img_idx:03d}_t{int(t_val)}.png\")\n",
|
| 660 |
+
" \n",
|
| 661 |
+
" # Save metrics\n",
|
| 662 |
+
" with open(sample_dir / \"metrics.json\", 'w') as f:\n",
|
| 663 |
+
" json.dump(metrics, f, indent=2)\n",
|
| 664 |
+
" \n",
|
| 665 |
+
" print(f\"✓ Saved {len(images)} intermediate images for sample {idx+1}\")\n",
|
| 666 |
+
"\n",
|
| 667 |
+
"print(\"\\n✓ Analysis complete for all samples\")"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"cell_type": "markdown",
|
| 672 |
+
"id": "10dd749d",
|
| 673 |
+
"metadata": {},
|
| 674 |
+
"source": [
|
| 675 |
+
"#### Visualization: Intermediate Images"
|
| 676 |
+
]
|
| 677 |
+
},
|
| 678 |
+
{
|
| 679 |
+
"cell_type": "code",
|
| 680 |
+
"execution_count": null,
|
| 681 |
+
"id": "bb32eaa1",
|
| 682 |
+
"metadata": {},
|
| 683 |
+
"outputs": [],
|
| 684 |
+
"source": [
|
| 685 |
+
"def plot_intermediate_images(results, sample_idx, max_images=8):\n",
|
| 686 |
+
" \"\"\"Display intermediate images for a sample showing evolution over timesteps.\"\"\"\n",
|
| 687 |
+
" result = results[sample_idx]\n",
|
| 688 |
+
" images = result['intermediate_images']\n",
|
| 689 |
+
" metrics = result['metrics']\n",
|
| 690 |
+
" timesteps = metrics['timesteps']\n",
|
| 691 |
+
" rewards = metrics['reward']\n",
|
| 692 |
+
" \n",
|
| 693 |
+
" # Select evenly spaced images if too many\n",
|
| 694 |
+
" if len(images) > max_images:\n",
|
| 695 |
+
" indices = np.linspace(0, len(images)-1, max_images, dtype=int)\n",
|
| 696 |
+
" selected_images = [images[i] for i in indices]\n",
|
| 697 |
+
" selected_timesteps = [timesteps[i] for i in indices]\n",
|
| 698 |
+
" selected_rewards = [rewards[i] for i in indices]\n",
|
| 699 |
+
" else:\n",
|
| 700 |
+
" selected_images = images\n",
|
| 701 |
+
" selected_timesteps = timesteps\n",
|
| 702 |
+
" selected_rewards = rewards\n",
|
| 703 |
+
" \n",
|
| 704 |
+
" n_images = len(selected_images)\n",
|
| 705 |
+
" cols = 5\n",
|
| 706 |
+
" rows = (n_images + cols - 1) // cols\n",
|
| 707 |
+
" \n",
|
| 708 |
+
" fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))\n",
|
| 709 |
+
" axes = axes.flatten() if n_images > 1 else [axes]\n",
|
| 710 |
+
" \n",
|
| 711 |
+
" fig.suptitle(f\"Sample {sample_idx + 1}: Image Evolution Over Timesteps\\n\"\n",
|
| 712 |
+
" f\"Prompt: {result['prompt'][:80]}...\", \n",
|
| 713 |
+
" fontsize=12, fontweight='bold')\n",
|
| 714 |
+
" \n",
|
| 715 |
+
" for idx, (img, t, r) in enumerate(zip(selected_images, selected_timesteps, selected_rewards)):\n",
|
| 716 |
+
" ax = axes[idx]\n",
|
| 717 |
+
" ax.imshow(img)\n",
|
| 718 |
+
" ax.axis('off')\n",
|
| 719 |
+
" ax.set_title(f\"t={t:.0f}\\nReward={r:.3f}\", fontsize=10)\n",
|
| 720 |
+
" \n",
|
| 721 |
+
" # Hide unused subplots\n",
|
| 722 |
+
" for idx in range(n_images, len(axes)):\n",
|
| 723 |
+
" axes[idx].axis('off')\n",
|
| 724 |
+
" \n",
|
| 725 |
+
" plt.tight_layout()\n",
|
| 726 |
+
" \n",
|
| 727 |
+
" # Save plot\n",
|
| 728 |
+
" sample_dir = Path(OUTPUT_DIR) / f\"sample_{sample_idx+1}\"\n",
|
| 729 |
+
" plt.savefig(sample_dir / \"image_evolution.png\", dpi=150, bbox_inches='tight')\n",
|
| 730 |
+
" plt.show()\n",
|
| 731 |
+
"\n",
|
| 732 |
+
"# Plot intermediate images for all samples\n",
|
| 733 |
+
"for idx in range(len(all_results)):\n",
|
| 734 |
+
" plot_intermediate_images(all_results, idx)"
|
| 735 |
+
]
|
| 736 |
+
},
|
| 737 |
+
{
|
| 738 |
+
"cell_type": "code",
|
| 739 |
+
"execution_count": null,
|
| 740 |
+
"id": "878b7686",
|
| 741 |
+
"metadata": {},
|
| 742 |
+
"outputs": [],
|
| 743 |
+
"source": [
|
| 744 |
+
"def plot_final_images_grid(results):\n",
|
| 745 |
+
" \"\"\"Display all final images in a grid for comparison.\"\"\"\n",
|
| 746 |
+
" n_samples = len(results)\n",
|
| 747 |
+
" cols = min(10, n_samples)\n",
|
| 748 |
+
" rows = (n_samples + cols - 1) // cols\n",
|
| 749 |
+
" \n",
|
| 750 |
+
" fig, axes = plt.subplots(rows, cols, figsize=(5*cols, 5*rows))\n",
|
| 751 |
+
" if n_samples == 1:\n",
|
| 752 |
+
" axes = [axes]\n",
|
| 753 |
+
" else:\n",
|
| 754 |
+
" axes = axes.flatten()\n",
|
| 755 |
+
" \n",
|
| 756 |
+
" fig.suptitle(\"Final Generated Images: All Samples\", fontsize=14, fontweight='bold')\n",
|
| 757 |
+
" \n",
|
| 758 |
+
" for idx, result in enumerate(results):\n",
|
| 759 |
+
" ax = axes[idx]\n",
|
| 760 |
+
" ax.imshow(result['final_image'])\n",
|
| 761 |
+
" ax.axis('off')\n",
|
| 762 |
+
" \n",
|
| 763 |
+
" # Get final metrics\n",
|
| 764 |
+
" metrics = result['metrics']\n",
|
| 765 |
+
" reward = metrics['reward'][-1] if metrics['reward'] else 0\n",
|
| 766 |
+
" clip_score = metrics['clip'][-1] if 'clip' in metrics and metrics['clip'] and metrics['clip'][-1] is not None else 0\n",
|
| 767 |
+
" \n",
|
| 768 |
+
" ax.set_title(f\"Sample {idx+1}\\nReward: {reward:.3f} | CLIP: {clip_score:.3f}\\n{result['prompt'][:40]}...\", \n",
|
| 769 |
+
" fontsize=9)\n",
|
| 770 |
+
" \n",
|
| 771 |
+
" # Hide unused subplots\n",
|
| 772 |
+
" for idx in range(n_samples, len(axes)):\n",
|
| 773 |
+
" axes[idx].axis('off')\n",
|
| 774 |
+
" \n",
|
| 775 |
+
" plt.tight_layout()\n",
|
| 776 |
+
" plt.savefig(Path(OUTPUT_DIR) / \"final_images_grid.png\", dpi=150, bbox_inches='tight')\n",
|
| 777 |
+
" plt.show()\n",
|
| 778 |
+
"\n",
|
| 779 |
+
"# Display final images\n",
|
| 780 |
+
"plot_final_images_grid(all_results)"
|
| 781 |
+
]
|
| 782 |
+
},
|
| 783 |
+
{
|
| 784 |
+
"cell_type": "markdown",
|
| 785 |
+
"id": "bc5a96a6",
|
| 786 |
+
"metadata": {},
|
| 787 |
+
"source": [
|
| 788 |
+
"#### Debug: Check Data"
|
| 789 |
+
]
|
| 790 |
+
},
|
| 791 |
+
{
|
| 792 |
+
"cell_type": "code",
|
| 793 |
+
"execution_count": null,
|
| 794 |
+
"id": "40008ed4",
|
| 795 |
+
"metadata": {},
|
| 796 |
+
"outputs": [],
|
| 797 |
+
"source": [
|
| 798 |
+
"# Check if data was collected properly\n",
|
| 799 |
+
"print(\"Data Collection Summary:\")\n",
|
| 800 |
+
"print(\"=\"*70)\n",
|
| 801 |
+
"\n",
|
| 802 |
+
"for idx, result in enumerate(all_results):\n",
|
| 803 |
+
" print(f\"\\nSample {idx+1}:\")\n",
|
| 804 |
+
" print(f\" Prompt: {result['prompt'][:60]}...\")\n",
|
| 805 |
+
" \n",
|
| 806 |
+
" metrics = result['metrics']\n",
|
| 807 |
+
" print(f\" Number of timesteps tracked: {len(metrics['timesteps'])}\")\n",
|
| 808 |
+
" print(f\" Number of intermediate images: {len(result['intermediate_images'])}\")\n",
|
| 809 |
+
" \n",
|
| 810 |
+
" # Check which metrics have data\n",
|
| 811 |
+
" for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2', 'fid']:\n",
|
| 812 |
+
" if metric_name in metrics:\n",
|
| 813 |
+
" non_none = [v for v in metrics[metric_name] if v is not None]\n",
|
| 814 |
+
" if non_none:\n",
|
| 815 |
+
" print(f\" {metric_name.upper()}: {len(non_none)} values | \"\n",
|
| 816 |
+
" f\"Range: [{min(non_none):.3f}, {max(non_none):.3f}]\")\n",
|
| 817 |
+
" else:\n",
|
| 818 |
+
" print(f\" {metric_name.upper()}: No valid data\")\n",
|
| 819 |
+
" \n",
|
| 820 |
+
" # Check timestep range\n",
|
| 821 |
+
" if metrics['timesteps']:\n",
|
| 822 |
+
" print(f\" Timestep range: [{max(metrics['timesteps']):.0f}, {min(metrics['timesteps']):.0f}]\")\n",
|
| 823 |
+
"\n",
|
| 824 |
+
"print(\"\\n\" + \"=\"*70)"
|
| 825 |
+
]
|
| 826 |
+
},
|
| 827 |
+
{
|
| 828 |
+
"cell_type": "markdown",
|
| 829 |
+
"id": "5bdacf18",
|
| 830 |
+
"metadata": {},
|
| 831 |
+
"source": [
|
| 832 |
+
"#### Visualization: Metrics Evolution"
|
| 833 |
+
]
|
| 834 |
+
},
|
| 835 |
+
{
|
| 836 |
+
"cell_type": "code",
|
| 837 |
+
"execution_count": null,
|
| 838 |
+
"id": "be2fc746",
|
| 839 |
+
"metadata": {},
|
| 840 |
+
"outputs": [],
|
| 841 |
+
"source": [
|
| 842 |
+
"def plot_metrics_evolution(results, sample_idx):\n",
|
| 843 |
+
" \"\"\"Plot all metrics evolution in a single row for one sample.\"\"\"\n",
|
| 844 |
+
" result = results[sample_idx]\n",
|
| 845 |
+
" metrics = result['metrics']\n",
|
| 846 |
+
" timesteps = metrics['timesteps']\n",
|
| 847 |
+
" \n",
|
| 848 |
+
" # Filter metrics to plot (exclude None values)\n",
|
| 849 |
+
" metrics_to_plot = []\n",
|
| 850 |
+
" for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2', 'imagereward', 'fid']:\n",
|
| 851 |
+
" if metric_name in metrics and any(v is not None for v in metrics[metric_name]):\n",
|
| 852 |
+
" metrics_to_plot.append(metric_name)\n",
|
| 853 |
+
" \n",
|
| 854 |
+
" n_metrics = len(metrics_to_plot)\n",
|
| 855 |
+
" \n",
|
| 856 |
+
" # Create figure with subplots in a row\n",
|
| 857 |
+
" fig, axes = plt.subplots(1, n_metrics, figsize=(5*n_metrics, 4))\n",
|
| 858 |
+
" if n_metrics == 1:\n",
|
| 859 |
+
" axes = [axes]\n",
|
| 860 |
+
" \n",
|
| 861 |
+
" fig.suptitle(f\"Sample {sample_idx + 1}: Metrics Evolution Across Timesteps\\n\"\n",
|
| 862 |
+
" f\"Prompt: {result['prompt'][:80]}...\", fontsize=12, fontweight='bold')\n",
|
| 863 |
+
" \n",
|
| 864 |
+
" colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink']\n",
|
| 865 |
+
" \n",
|
| 866 |
+
" for idx, metric_name in enumerate(metrics_to_plot):\n",
|
| 867 |
+
" ax = axes[idx]\n",
|
| 868 |
+
" values = [v for v in metrics[metric_name] if v is not None]\n",
|
| 869 |
+
" valid_timesteps = [t for t, v in zip(timesteps, metrics[metric_name]) if v is not None]\n",
|
| 870 |
+
" \n",
|
| 871 |
+
" if values:\n",
|
| 872 |
+
" ax.plot(valid_timesteps, values, marker='o', linewidth=2, \n",
|
| 873 |
+
" color=colors[idx % len(colors)], label=metric_name.upper())\n",
|
| 874 |
+
" ax.set_xlabel('Timestep', fontsize=10)\n",
|
| 875 |
+
" ax.set_ylabel(metric_name.upper(), fontsize=10)\n",
|
| 876 |
+
" ax.set_title(f\"{metric_name.upper()}\\n{values[0]:.3f} → {values[-1]:.3f}\", fontsize=10)\n",
|
| 877 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 878 |
+
" ax.invert_xaxis() # Timesteps go from high to low\n",
|
| 879 |
+
" \n",
|
| 880 |
+
" # Add improvement annotation\n",
|
| 881 |
+
" improvement = values[-1] - values[0]\n",
|
| 882 |
+
" color = 'green' if improvement > 0 else 'red'\n",
|
| 883 |
+
" if metric_name == 'fid': # Lower is better for FID\n",
|
| 884 |
+
" color = 'green' if improvement < 0 else 'red'\n",
|
| 885 |
+
" ax.text(0.05, 0.95, f\"Δ: {improvement:+.3f}\", \n",
|
| 886 |
+
" transform=ax.transAxes, fontsize=9, verticalalignment='top',\n",
|
| 887 |
+
" bbox=dict(boxstyle='round', facecolor=color, alpha=0.3))\n",
|
| 888 |
+
" \n",
|
| 889 |
+
" plt.tight_layout()\n",
|
| 890 |
+
" \n",
|
| 891 |
+
" # Save plot\n",
|
| 892 |
+
" sample_dir = Path(OUTPUT_DIR) / f\"sample_{sample_idx+1}\"\n",
|
| 893 |
+
" plt.savefig(sample_dir / \"metrics_evolution.png\", dpi=150, bbox_inches='tight')\n",
|
| 894 |
+
" plt.show()\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"# Plot for all samples\n",
|
| 897 |
+
"for idx in range(len(all_results)):\n",
|
| 898 |
+
" plot_metrics_evolution(all_results, idx)"
|
| 899 |
+
]
|
| 900 |
+
},
|
| 901 |
+
{
|
| 902 |
+
"cell_type": "markdown",
|
| 903 |
+
"id": "45b7abb7",
|
| 904 |
+
"metadata": {},
|
| 905 |
+
"source": [
|
| 906 |
+
"#### Visualization: Compare All Samples"
|
| 907 |
+
]
|
| 908 |
+
},
|
| 909 |
+
{
|
| 910 |
+
"cell_type": "code",
|
| 911 |
+
"execution_count": null,
|
| 912 |
+
"id": "7d3df7e0",
|
| 913 |
+
"metadata": {},
|
| 914 |
+
"outputs": [],
|
| 915 |
+
"source": [
|
| 916 |
+
"def plot_all_samples_comparison(results):\n",
|
| 917 |
+
" \"\"\"Plot metric evolution for all samples in a grid.\"\"\"\n",
|
| 918 |
+
" # Choose key metrics to compare\n",
|
| 919 |
+
" key_metrics = ['reward', 'clip', 'aesthetic', 'fid']\n",
|
| 920 |
+
" n_metrics = len(key_metrics)\n",
|
| 921 |
+
" n_samples = len(results)\n",
|
| 922 |
+
" \n",
|
| 923 |
+
" fig, axes = plt.subplots(n_metrics, 1, figsize=(14, 4*n_metrics))\n",
|
| 924 |
+
" if n_metrics == 1:\n",
|
| 925 |
+
" axes = [axes]\n",
|
| 926 |
+
" \n",
|
| 927 |
+
" fig.suptitle(\"Convergence Analysis: All Samples Comparison\", fontsize=14, fontweight='bold')\n",
|
| 928 |
+
" \n",
|
| 929 |
+
" colors = plt.cm.tab10(np.linspace(0, 1, n_samples))\n",
|
| 930 |
+
" \n",
|
| 931 |
+
" for metric_idx, metric_name in enumerate(key_metrics):\n",
|
| 932 |
+
" ax = axes[metric_idx]\n",
|
| 933 |
+
" \n",
|
| 934 |
+
" for sample_idx, result in enumerate(results):\n",
|
| 935 |
+
" metrics = result['metrics']\n",
|
| 936 |
+
" timesteps = metrics['timesteps']\n",
|
| 937 |
+
" values = [v for v in metrics[metric_name] if v is not None]\n",
|
| 938 |
+
" valid_timesteps = [t for t, v in zip(timesteps, metrics[metric_name]) if v is not None]\n",
|
| 939 |
+
" \n",
|
| 940 |
+
" if values:\n",
|
| 941 |
+
" ax.plot(valid_timesteps, values, marker='o', linewidth=2, \n",
|
| 942 |
+
" color=colors[sample_idx], label=f\"Sample {sample_idx+1}\", alpha=0.7)\n",
|
| 943 |
+
" \n",
|
| 944 |
+
" ax.set_xlabel('Timestep', fontsize=11)\n",
|
| 945 |
+
" ax.set_ylabel(metric_name.upper(), fontsize=11)\n",
|
| 946 |
+
" ax.set_title(f\"{metric_name.upper()} Evolution\", fontsize=12, fontweight='bold')\n",
|
| 947 |
+
" ax.grid(True, alpha=0.3)\n",
|
| 948 |
+
" ax.invert_xaxis()\n",
|
| 949 |
+
" ax.legend(loc='best', fontsize=9)\n",
|
| 950 |
+
" \n",
|
| 951 |
+
" plt.tight_layout()\n",
|
| 952 |
+
" plt.savefig(Path(OUTPUT_DIR) / \"all_samples_comparison.png\", dpi=150, bbox_inches='tight')\n",
|
| 953 |
+
" plt.show()\n",
|
| 954 |
+
"\n",
|
| 955 |
+
"# Plot comparison\n",
|
| 956 |
+
"plot_all_samples_comparison(all_results)"
|
| 957 |
+
]
|
| 958 |
+
},
|
| 959 |
+
{
|
| 960 |
+
"cell_type": "markdown",
|
| 961 |
+
"id": "44f97581",
|
| 962 |
+
"metadata": {},
|
| 963 |
+
"source": [
|
| 964 |
+
"#### Convergence Analysis"
|
| 965 |
+
]
|
| 966 |
+
},
|
| 967 |
+
{
|
| 968 |
+
"cell_type": "code",
|
| 969 |
+
"execution_count": null,
|
| 970 |
+
"id": "9dffd878",
|
| 971 |
+
"metadata": {},
|
| 972 |
+
"outputs": [],
|
| 973 |
+
"source": [
|
| 974 |
+
"def analyze_convergence(results):\n",
|
| 975 |
+
" \"\"\"Analyze convergence behavior across samples.\"\"\"\n",
|
| 976 |
+
" print(\"\\n\" + \"=\"*70)\n",
|
| 977 |
+
" print(\"CONVERGENCE ANALYSIS\")\n",
|
| 978 |
+
" print(\"=\"*70)\n",
|
| 979 |
+
" \n",
|
| 980 |
+
" for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2']:\n",
|
| 981 |
+
" print(f\"\\n{metric_name.upper()} Convergence:\")\n",
|
| 982 |
+
" print(\"-\" * 50)\n",
|
| 983 |
+
" \n",
|
| 984 |
+
" improvements = []\n",
|
| 985 |
+
" initial_values = []\n",
|
| 986 |
+
" final_values = []\n",
|
| 987 |
+
" \n",
|
| 988 |
+
" for idx, result in enumerate(results):\n",
|
| 989 |
+
" metrics = result['metrics']\n",
|
| 990 |
+
" if metric_name in metrics:\n",
|
| 991 |
+
" values = [v for v in metrics[metric_name] if v is not None]\n",
|
| 992 |
+
" if values:\n",
|
| 993 |
+
" initial = values[0]\n",
|
| 994 |
+
" final = values[-1]\n",
|
| 995 |
+
" improvement = final - initial\n",
|
| 996 |
+
" \n",
|
| 997 |
+
" initial_values.append(initial)\n",
|
| 998 |
+
" final_values.append(final)\n",
|
| 999 |
+
" improvements.append(improvement)\n",
|
| 1000 |
+
" \n",
|
| 1001 |
+
" print(f\" Sample {idx+1}: {initial:.4f} → {final:.4f} ({improvement:+.4f})\")\n",
|
| 1002 |
+
" \n",
|
| 1003 |
+
" if improvements:\n",
|
| 1004 |
+
" avg_improvement = np.mean(improvements)\n",
|
| 1005 |
+
" std_improvement = np.std(improvements)\n",
|
| 1006 |
+
" print(f\"\\n Average Improvement: {avg_improvement:+.4f} (±{std_improvement:.4f})\")\n",
|
| 1007 |
+
" print(f\" Converged: {'YES' if std_improvement < 0.1 * abs(avg_improvement) else 'NO'}\")\n",
|
| 1008 |
+
" \n",
|
| 1009 |
+
" # Summary\n",
|
| 1010 |
+
" print(\"\\n\" + \"=\"*70)\n",
|
| 1011 |
+
" print(\"SUMMARY\")\n",
|
| 1012 |
+
" print(\"=\"*70)\n",
|
| 1013 |
+
" print(f\"Total samples analyzed: {len(results)}\")\n",
|
| 1014 |
+
" print(f\"Gradient ascent config: {grad_config}\")\n",
|
| 1015 |
+
" print(f\"\\nConclusion: Analyze the plots above to determine convergence behavior.\")\n",
|
| 1016 |
+
" print(f\"Look for:\")\n",
|
| 1017 |
+
" print(f\" 1. Metrics plateauing (flattening out)\")\n",
|
| 1018 |
+
" print(f\" 2. Consistent improvement across samples\")\n",
|
| 1019 |
+
" print(f\" 3. Low variance in final metric values\")\n",
|
| 1020 |
+
"\n",
|
| 1021 |
+
"analyze_convergence(all_results)"
|
| 1022 |
+
]
|
| 1023 |
+
},
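The YES/NO line above measures consistency of the improvement across samples. A complementary per-curve check is whether each tracked metric has actually plateaued; a sketch, assuming "plateaued" means the last few tracked values move by less than 5% of the curve's range:

```python
# Per-curve plateau check over the tracked values (tolerance is a guess).
import numpy as np

def has_plateaued(values, tail=3, tol=0.05):
    v = np.asarray([x for x in values if x is not None], dtype=float)
    if len(v) <= tail:
        return False
    span = v.max() - v.min()
    return True if span == 0 else abs(v[-1] - v[-tail]) / span < tol

for idx, result in enumerate(all_results):
    print(f"Sample {idx+1}: reward plateaued = "
          f"{has_plateaued(result['metrics']['reward'])}")
```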
|
| 1024 |
+
{
|
| 1025 |
+
"cell_type": "markdown",
|
| 1026 |
+
"id": "d263be5f",
|
| 1027 |
+
"metadata": {},
|
| 1028 |
+
"source": [
|
| 1029 |
+
"#### Save Results Summary"
|
| 1030 |
+
]
|
| 1031 |
+
},
|
| 1032 |
+
{
|
| 1033 |
+
"cell_type": "code",
|
| 1034 |
+
"execution_count": null,
|
| 1035 |
+
"id": "5434e7c0",
|
| 1036 |
+
"metadata": {},
|
| 1037 |
+
"outputs": [],
|
| 1038 |
+
"source": [
|
| 1039 |
+
"# Save comprehensive summary\n",
|
| 1040 |
+
"summary = {\n",
|
| 1041 |
+
" 'config': {\n",
|
| 1042 |
+
" 'num_samples': NUM_SAMPLES,\n",
|
| 1043 |
+
" 'num_inference_steps': NUM_INFERENCE_STEPS,\n",
|
| 1044 |
+
" 'cfg_scale': CFG_SCALE,\n",
|
| 1045 |
+
" 'grad_config': grad_config,\n",
|
| 1046 |
+
" 'metrics': METRICS,\n",
|
| 1047 |
+
" 'model_variant': MODEL_VARIANT\n",
|
| 1048 |
+
" },\n",
|
| 1049 |
+
" 'samples': []\n",
|
| 1050 |
+
"}\n",
|
| 1051 |
+
"\n",
|
| 1052 |
+
"for idx, result in enumerate(all_results):\n",
|
| 1053 |
+
" metrics = result['metrics']\n",
|
| 1054 |
+
" sample_summary = {\n",
|
| 1055 |
+
" 'sample_id': idx + 1,\n",
|
| 1056 |
+
" 'prompt': result['prompt'],\n",
|
| 1057 |
+
" 'reference_image': result['reference_image']\n",
|
| 1058 |
+
" }\n",
|
| 1059 |
+
" \n",
|
| 1060 |
+
" for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2']:\n",
|
| 1061 |
+
" if metric_name in metrics:\n",
|
| 1062 |
+
" values = [v for v in metrics[metric_name] if v is not None]\n",
|
| 1063 |
+
" if values:\n",
|
| 1064 |
+
" sample_summary[metric_name] = {\n",
|
| 1065 |
+
" 'initial': values[0],\n",
|
| 1066 |
+
" 'final': values[-1],\n",
|
| 1067 |
+
" 'improvement': values[-1] - values[0],\n",
|
| 1068 |
+
" 'all_values': values\n",
|
| 1069 |
+
" }\n",
|
| 1070 |
+
" \n",
|
| 1071 |
+
" summary['samples'].append(sample_summary)\n",
|
| 1072 |
+
"\n",
|
| 1073 |
+
"# Save summary\n",
|
| 1074 |
+
"with open(Path(OUTPUT_DIR) / \"convergence_summary.json\", 'w') as f:\n",
|
| 1075 |
+
" json.dump(summary, f, indent=2)\n",
|
| 1076 |
+
"\n",
|
| 1077 |
+
"print(f\"\\n✓ Results saved to: {OUTPUT_DIR}\")\n",
|
| 1078 |
+
"print(f\" - convergence_summary.json\")\n",
|
| 1079 |
+
"print(f\" - all_samples_comparison.png\")\n",
|
| 1080 |
+
"print(f\" - sample_X/ directories with individual results\")"
|
| 1081 |
+
]
|
| 1082 |
+
}
|
| 1083 |
+
],
|
| 1084 |
+
"metadata": {
|
| 1085 |
+
"kernelspec": {
|
| 1086 |
+
"display_name": "Python 3",
|
| 1087 |
+
"language": "python",
|
| 1088 |
+
"name": "python3"
|
| 1089 |
+
},
|
| 1090 |
+
"language_info": {
|
| 1091 |
+
"codemirror_mode": {
|
| 1092 |
+
"name": "ipython",
|
| 1093 |
+
"version": 3
|
| 1094 |
+
},
|
| 1095 |
+
"file_extension": ".py",
|
| 1096 |
+
"mimetype": "text/x-python",
|
| 1097 |
+
"name": "python",
|
| 1098 |
+
"nbconvert_exporter": "python",
|
| 1099 |
+
"pygments_lexer": "ipython3",
|
| 1100 |
+
"version": "3.10.18"
|
| 1101 |
+
}
|
| 1102 |
+
},
|
| 1103 |
+
"nbformat": 4,
|
| 1104 |
+
"nbformat_minor": 5
|
| 1105 |
+
}
|
Reward_sdxl_idealized/tune_hyperparams.py
ADDED
|
@@ -0,0 +1,514 @@
| 1 |
+
"""
|
| 2 |
+
Hyperparameter tuning script for gradient ascent optimization.
|
| 3 |
+
|
| 4 |
+
This script performs a systematic search over hyperparameter combinations
|
| 5 |
+
to find the optimal configuration for maximum evaluation scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import subprocess
|
| 9 |
+
import json
|
| 10 |
+
import argparse
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import itertools
|
| 14 |
+
import numpy as np
|
| 15 |
+
from typing import Dict, List, Any
|
| 16 |
+
import re
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HyperparameterTuner:
|
| 20 |
+
"""Hyperparameter tuner for gradient ascent."""
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
output_dir: str = "tuning_results",
|
| 25 |
+
max_samples: int = 30,
|
| 26 |
+
num_steps: int = 20,
|
| 27 |
+
dataset_type: str = "pickapic",
|
| 28 |
+
model_variant: str = "lpo",
|
| 29 |
+
cuda_id: int = 0,
|
| 30 |
+
metrics: List[str] = None
|
| 31 |
+
):
|
| 32 |
+
self.output_dir = Path(output_dir)
|
| 33 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
self.max_samples = max_samples
|
| 36 |
+
self.num_steps = num_steps
|
| 37 |
+
self.dataset_type = dataset_type
|
| 38 |
+
self.model_variant = model_variant
|
| 39 |
+
self.cuda_id = cuda_id
|
| 40 |
+
self.metrics = metrics or ["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"]
|
| 41 |
+
|
| 42 |
+
# Store results
|
| 43 |
+
self.results = []
|
| 44 |
+
self.baseline_results = None
|
| 45 |
+
|
| 46 |
+
def define_search_space(self) -> List[Dict[str, Any]]:
|
| 47 |
+
"""Define the hyperparameter search space - FULL GRID SEARCH.
|
| 48 |
+
|
| 49 |
+
Tests every combination of the parameters below, including momentum overrides for the configs that support them.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
# Define all parameter values
|
| 53 |
+
cfg_scales = [3.0, 5.0, 7.5]
|
| 54 |
+
|
| 55 |
+
# All available gradient configs from grad_ascent_configs.py
|
| 56 |
+
grad_configs = [
|
| 57 |
+
# "constant",
|
| 58 |
+
# "linear",
|
| 59 |
+
"cosine_nesterov",
|
| 60 |
+
# "low_to_high_nesterov",
|
| 61 |
+
# "high_to_low_nesterov",
|
| 62 |
+
"low_to_high_momentum",
|
| 63 |
+
"high_to_low_momentum",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
num_grad_steps_list = [1, 2]  # previously also swept: 5, 7, 10
|
| 67 |
+
grad_step_sizes = [0.001, 0.005, 0.01, 0.05]
|
| 68 |
+
momentums = [0.5, 0.8, 0.9]
|
| 69 |
+
|
| 70 |
+
# Generate ALL combinations using itertools.product
|
| 71 |
+
configs = []
|
| 72 |
+
for cfg, grad_cfg, num_steps, step_size, momentum in itertools.product(
|
| 73 |
+
cfg_scales, grad_configs, num_grad_steps_list, grad_step_sizes, momentums
|
| 74 |
+
):
|
| 75 |
+
configs.append({
|
| 76 |
+
"cfg_scale": cfg,
|
| 77 |
+
"grad_config": grad_cfg,
|
| 78 |
+
"num_grad_steps": num_steps,
|
| 79 |
+
"grad_step_size": step_size,
|
| 80 |
+
"momentum": momentum,
|
| 81 |
+
})
|
| 82 |
+
|
| 83 |
+
print(f"\nGenerated {len(configs)} total configurations")
|
| 84 |
+
print(f" cfg_scales: {len(cfg_scales)}")
|
| 85 |
+
print(f" grad_configs: {len(grad_configs)}")
|
| 86 |
+
print(f" num_grad_steps: {len(num_grad_steps_list)}")
|
| 87 |
+
print(f" grad_step_sizes: {len(grad_step_sizes)}")
|
| 88 |
+
print(f" momentums: {len(momentums)}")
|
| 89 |
+
print(f" Total: {len(cfg_scales)} × {len(grad_configs)} × {len(num_grad_steps_list)} × {len(grad_step_sizes)} × {len(momentums)} = {len(configs)}")
|
| 90 |
+
|
| 91 |
+
return configs
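With the lists above, the full grid is 3 cfg_scales × 3 grad_configs × 2 num_grad_steps × 4 step_sizes × 3 momentums = 216 configurations. A quick standalone check of that count:

```python
# Sanity-check the grid size implied by define_search_space's lists.
import itertools

grid = list(itertools.product(
    [3.0, 5.0, 7.5],                                   # cfg_scales
    ["cosine_nesterov", "low_to_high_momentum",
     "high_to_low_momentum"],                          # grad_configs
    [1, 2],                                            # num_grad_steps
    [0.001, 0.005, 0.01, 0.05],                        # grad_step_sizes
    [0.5, 0.8, 0.9],                                   # momentums
))
assert len(grid) == 3 * 3 * 2 * 4 * 3 == 216
```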
|
| 92 |
+
|
| 93 |
+
def run_baseline(self) -> Dict[str, float]:
|
| 94 |
+
"""Run baseline evaluation once."""
|
| 95 |
+
print("\n" + "="*80)
|
| 96 |
+
print("RUNNING BASELINE EVALUATION")
|
| 97 |
+
print("="*80)
|
| 98 |
+
|
| 99 |
+
# Use median cfg_scale for baseline
|
| 100 |
+
cfg_scale = 5.0
|
| 101 |
+
|
| 102 |
+
output_dir = self.output_dir / "baseline"
|
| 103 |
+
|
| 104 |
+
cmd = [
|
| 105 |
+
"python", "eval.py",
|
| 106 |
+
"--model_variant", self.model_variant,
|
| 107 |
+
"--dataset_type", self.dataset_type,
|
| 108 |
+
"--max_samples", str(self.max_samples),
|
| 109 |
+
"--num_steps", str(self.num_steps),
|
| 110 |
+
"--cfg_scale", str(cfg_scale),
|
| 111 |
+
"--output_dir", str(output_dir),
|
| 112 |
+
"--cuda", str(self.cuda_id),
|
| 113 |
+
"--mode", "baseline",
|
| 114 |
+
"--metrics", *self.metrics,
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
print(f"Command: {' '.join(cmd)}")
|
| 118 |
+
|
| 119 |
+
try:
|
| 120 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 121 |
+
|
| 122 |
+
# Parse results from output
|
| 123 |
+
metrics = self._parse_metrics(result.stdout, "baseline")
|
| 124 |
+
|
| 125 |
+
print(f"\nBaseline Results:")
|
| 126 |
+
for metric, value in metrics.items():
|
| 127 |
+
print(f" {metric}: {value:.4f}")
|
| 128 |
+
|
| 129 |
+
self.baseline_results = {
|
| 130 |
+
"cfg_scale": cfg_scale,
|
| 131 |
+
"metrics": metrics,
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
return metrics
|
| 135 |
+
|
| 136 |
+
except subprocess.CalledProcessError as e:
|
| 137 |
+
print(f"Error running baseline: {e}")
|
| 138 |
+
print(f"Stdout: {e.stdout}")
|
| 139 |
+
print(f"Stderr: {e.stderr}")
|
| 140 |
+
return {}
|
| 141 |
+
|
| 142 |
+
def run_experiment(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
| 143 |
+
"""Run a single experiment with given hyperparameters."""
|
| 144 |
+
|
| 145 |
+
# Create output directory for this config
|
| 146 |
+
config_name = f"cfg{config['cfg_scale']}_" \
|
| 147 |
+
f"{config['grad_config']}_" \
|
| 148 |
+
f"steps{config['num_grad_steps']}_" \
|
| 149 |
+
f"lr{config['grad_step_size']}_" \
|
| 150 |
+
f"mom{config['momentum']}"
|
| 151 |
+
|
| 152 |
+
output_dir = self.output_dir / config_name
|
| 153 |
+
|
| 154 |
+
# Build command
|
| 155 |
+
cmd = [
|
| 156 |
+
"python", "eval.py",
|
| 157 |
+
"--model_variant", self.model_variant,
|
| 158 |
+
"--dataset_type", self.dataset_type,
|
| 159 |
+
"--grad_config", config["grad_config"],
|
| 160 |
+
"--max_samples", str(self.max_samples),
|
| 161 |
+
"--num_steps", str(self.num_steps),
|
| 162 |
+
"--cfg_scale", str(config["cfg_scale"]),
|
| 163 |
+
"--output_dir", str(output_dir),
|
| 164 |
+
"--cuda", str(self.cuda_id),
|
| 165 |
+
"--mode", "gradient_ascent",
|
| 166 |
+
"--metrics", *self.metrics,
|
| 167 |
+
# Override config parameters
|
| 168 |
+
"--override_num_grad_steps", str(config["num_grad_steps"]),
|
| 169 |
+
"--override_grad_step_size", str(config["grad_step_size"]),
|
| 170 |
+
"--override_momentum", str(config["momentum"]),
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
print(f"\nRunning experiment: {config_name}")
|
| 174 |
+
print(f"Config: {config}")
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 178 |
+
|
| 179 |
+
# Parse metrics from output
|
| 180 |
+
metrics = self._parse_metrics(result.stdout, "gradient_ascent")
|
| 181 |
+
|
| 182 |
+
# Compute improvement over baseline
|
| 183 |
+
improvements = {}
|
| 184 |
+
if self.baseline_results:
|
| 185 |
+
baseline_metrics = self.baseline_results["metrics"]
|
| 186 |
+
for metric, value in metrics.items():
|
| 187 |
+
if metric in baseline_metrics:
|
| 188 |
+
baseline_val = baseline_metrics[metric]
|
| 189 |
+
if baseline_val != 0:
|
| 190 |
+
improvement = ((value - baseline_val) / abs(baseline_val)) * 100
|
| 191 |
+
improvements[f"{metric}_improvement"] = improvement
|
| 192 |
+
|
| 193 |
+
result_dict = {
|
| 194 |
+
"config": config,
|
| 195 |
+
"metrics": metrics,
|
| 196 |
+
"improvements": improvements,
|
| 197 |
+
"output_dir": str(output_dir),
|
| 198 |
+
"timestamp": datetime.now().isoformat(),
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
print(f"Results:")
|
| 202 |
+
for metric, value in metrics.items():
|
| 203 |
+
print(f" {metric}: {value:.4f}")
|
| 204 |
+
if improvements:
|
| 205 |
+
print(f"Improvements over baseline:")
|
| 206 |
+
for metric, value in improvements.items():
|
| 207 |
+
print(f" {metric}: {value:+.2f}%")
|
| 208 |
+
|
| 209 |
+
return result_dict
|
| 210 |
+
|
| 211 |
+
except subprocess.CalledProcessError as e:
|
| 212 |
+
print(f"Error running experiment: {e}")
|
| 213 |
+
print(f"Stderr: {e.stderr}")
|
| 214 |
+
return {
|
| 215 |
+
"config": config,
|
| 216 |
+
"error": str(e),
|
| 217 |
+
"timestamp": datetime.now().isoformat(),
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
def _parse_metrics(self, output: str, mode: str) -> Dict[str, float]:
|
| 221 |
+
"""Parse metrics from eval.py output."""
|
| 222 |
+
metrics = {}
|
| 223 |
+
|
| 224 |
+
# Look for the summary section
|
| 225 |
+
lines = output.split('\n')
|
| 226 |
+
|
| 227 |
+
# Pattern to match metric lines like " Reward: 0.1234"
|
| 228 |
+
metric_patterns = {
|
| 229 |
+
"reward": r"Reward:\s+([-+]?\d*\.?\d+)",
|
| 230 |
+
"clip": r"CLIP Score:\s+([-+]?\d*\.?\d+)",
|
| 231 |
+
"aesthetic": r"Aesthetic Score:\s+([-+]?\d*\.?\d+)",
|
| 232 |
+
"pickscore": r"PickScore:\s+([-+]?\d*\.?\d+)",
|
| 233 |
+
"hpsv2": r"HPSv2 Score:\s+([-+]?\d*\.?\d+)",
|
| 234 |
+
"hpsv21": r"HPSv2\.1 Score:\s+([-+]?\d*\.?\d+)",
|
| 235 |
+
"imagereward": r"ImageReward:\s+([-+]?\d*\.?\d+)",
|
| 236 |
+
"fid": r"FID:\s+([-+]?\d*\.?\d+)",
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
for line in lines:
|
| 240 |
+
for metric_name, pattern in metric_patterns.items():
|
| 241 |
+
match = re.search(pattern, line)
|
| 242 |
+
if match:
|
| 243 |
+
metrics[metric_name] = float(match.group(1))
|
| 244 |
+
|
| 245 |
+
return metrics
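A worked example of what `_parse_metrics` extracts; the summary text here is made up, but the patterns are the ones defined above:

```python
# Demo of the line-by-line regex extraction on a fabricated summary.
import re

sample = """
Reward:          0.1234
CLIP Score:      31.20
Aesthetic Score: 5.87
ImageReward:     1.01
"""
patterns = {
    "reward": r"^\s*Reward:\s+([-+]?\d*\.?\d+)",  # anchored, see note above
    "clip": r"CLIP Score:\s+([-+]?\d*\.?\d+)",
    "aesthetic": r"Aesthetic Score:\s+([-+]?\d*\.?\d+)",
    "imagereward": r"ImageReward:\s+([-+]?\d*\.?\d+)",
}
parsed = {}
for line in sample.splitlines():
    for name, pat in patterns.items():
        m = re.search(pat, line)
        if m:
            parsed[name] = float(m.group(1))
print(parsed)  # {'reward': 0.1234, 'clip': 31.2, 'aesthetic': 5.87, 'imagereward': 1.01}
```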
|
| 246 |
+
|
| 247 |
+
def compute_aggregate_score(self, metrics: Dict[str, float]) -> float:
|
| 248 |
+
"""
|
| 249 |
+
Compute aggregate score for ranking configurations.
|
| 250 |
+
|
| 251 |
+
Uses a weighted combination of metrics (higher is better for every metric,
|
| 252 |
+
except FID, for which lower is better). Note the raw metrics live on different scales, so the fixed weights are heuristic.
|
| 253 |
+
"""
|
| 254 |
+
weights = {
|
| 255 |
+
"reward": 1.0,
|
| 256 |
+
"clip": 0.8,
|
| 257 |
+
"aesthetic": 0.8,
|
| 258 |
+
"pickscore": 1.0,
|
| 259 |
+
"hpsv2": 1.0,
|
| 260 |
+
"hpsv21": 1.0,
|
| 261 |
+
"imagereward": 1.0,
|
| 262 |
+
"fid": -0.5, # Negative weight (lower FID is better)
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
score = 0.0
|
| 266 |
+
total_weight = 0.0
|
| 267 |
+
|
| 268 |
+
for metric, value in metrics.items():
|
| 269 |
+
if metric in weights:
|
| 270 |
+
score += weights[metric] * value
|
| 271 |
+
total_weight += abs(weights[metric])
|
| 272 |
+
|
| 273 |
+
# Normalize by total weight
|
| 274 |
+
if total_weight > 0:
|
| 275 |
+
score /= total_weight
|
| 276 |
+
|
| 277 |
+
return score
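Because the raw metrics sit on very different scales (CLIP around 30, aesthetic around 5, PickScore around 21), the weighted sum above is dominated by the large-scale metrics. A hedged variant that normalizes each metric by its baseline value before weighting:

```python
# Sketch: baseline-relative aggregate so no single metric scale dominates.
def compute_relative_score(metrics, baseline_metrics, weights):
    score, total = 0.0, 0.0
    for name, value in metrics.items():
        base = baseline_metrics.get(name)
        if name in weights and base:
            score += weights[name] * (value / abs(base))
            total += abs(weights[name])
    return score / total if total else 0.0
```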
|
| 278 |
+
|
| 279 |
+
def run_search(
|
| 280 |
+
self,
|
| 281 |
+
search_type: str = "grid",
|
| 282 |
+
start_idx: int = 0,
|
| 283 |
+
end_idx: int = None
|
| 284 |
+
) -> List[Dict[str, Any]]:
|
| 285 |
+
"""
|
| 286 |
+
Run hyperparameter search.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
search_type: Type of search ("grid" or "random")
|
| 290 |
+
start_idx: Starting index for experiments (for GPU distribution)
|
| 291 |
+
end_idx: Ending index for experiments (for GPU distribution)
|
| 292 |
+
"""
|
| 293 |
+
all_configs = self.define_search_space()
|
| 294 |
+
|
| 295 |
+
print("\n" + "="*80)
|
| 296 |
+
print("HYPERPARAMETER SEARCH CONFIGURATION")
|
| 297 |
+
print("="*80)
|
| 298 |
+
print(f"Dataset: {self.dataset_type}")
|
| 299 |
+
print(f"Model: {self.model_variant}")
|
| 300 |
+
print(f"Samples: {self.max_samples}")
|
| 301 |
+
print(f"Inference steps: {self.num_steps}")
|
| 302 |
+
print(f"Metrics: {', '.join(self.metrics)}")
|
| 303 |
+
|
| 304 |
+
# Select subset of configs if indices provided
|
| 305 |
+
if search_type == "grid":
|
| 306 |
+
configs = all_configs
|
| 307 |
+
elif search_type == "random":
|
| 308 |
+
# Random sample from all configs
|
| 309 |
+
n_samples = min(50, len(all_configs))
|
| 310 |
+
indices = np.random.choice(len(all_configs), n_samples, replace=False)
|
| 311 |
+
configs = [all_configs[i] for i in indices]
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Unknown search type: {search_type}")
|
| 314 |
+
|
| 315 |
+
# Apply index slicing for GPU distribution
|
| 316 |
+
if end_idx is None:
|
| 317 |
+
end_idx = len(configs)
|
| 318 |
+
configs = configs[start_idx:end_idx]
|
| 319 |
+
|
| 320 |
+
print(f"\nTotal configurations: {len(all_configs)}")
|
| 321 |
+
print(f"Assigned to this worker: {len(configs)} (indices {start_idx} to {end_idx})")
|
| 322 |
+
|
| 323 |
+
# Run baseline first
|
| 324 |
+
if self.baseline_results is None:
|
| 325 |
+
self.run_baseline()
|
| 326 |
+
|
| 327 |
+
# Run experiments
|
| 328 |
+
print("\n" + "="*80)
|
| 329 |
+
print("RUNNING EXPERIMENTS")
|
| 330 |
+
print("="*80)
|
| 331 |
+
|
| 332 |
+
for i, config in enumerate(configs, 1):
|
| 333 |
+
print(f"\n{'='*80}")
|
| 334 |
+
print(f"Experiment {i}/{len(configs)}")
|
| 335 |
+
print(f"{'='*80}")
|
| 336 |
+
|
| 337 |
+
result = self.run_experiment(config)
|
| 338 |
+
self.results.append(result)
|
| 339 |
+
|
| 340 |
+
# Save intermediate results
|
| 341 |
+
self._save_results()
|
| 342 |
+
|
| 343 |
+
return self.results
|
| 344 |
+
|
| 345 |
+
def _generate_grid_configs(self, search_space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
|
| 346 |
+
"""Generate all combinations for grid search."""
|
| 347 |
+
keys = list(search_space.keys())
|
| 348 |
+
values = list(search_space.values())
|
| 349 |
+
|
| 350 |
+
configs = []
|
| 351 |
+
for combination in itertools.product(*values):
|
| 352 |
+
config = dict(zip(keys, combination))
|
| 353 |
+
configs.append(config)
|
| 354 |
+
|
| 355 |
+
return configs
|
| 356 |
+
|
| 357 |
+
def _generate_random_configs(
|
| 358 |
+
self,
|
| 359 |
+
search_space: Dict[str, List[Any]],
|
| 360 |
+
n_samples: int = 20
|
| 361 |
+
) -> List[Dict[str, Any]]:
|
| 362 |
+
"""Generate random configurations for random search."""
|
| 363 |
+
configs = []
|
| 364 |
+
|
| 365 |
+
for _ in range(n_samples):
|
| 366 |
+
config = {}
|
| 367 |
+
for param, values in search_space.items():
|
| 368 |
+
config[param] = values[int(np.random.randint(len(values)))]  # index the list so values keep native Python types (np.random.choice returns numpy scalars that json.dump rejects)
|
| 369 |
+
configs.append(config)
|
| 370 |
+
|
| 371 |
+
return configs
|
| 372 |
+
|
| 373 |
+
def _save_results(self):
|
| 374 |
+
"""Save results to JSON file."""
|
| 375 |
+
results_file = self.output_dir / "tuning_results.json"
|
| 376 |
+
|
| 377 |
+
data = {
|
| 378 |
+
"baseline": self.baseline_results,
|
| 379 |
+
"experiments": self.results,
|
| 380 |
+
"timestamp": datetime.now().isoformat(),
|
| 381 |
+
"config": {
|
| 382 |
+
"max_samples": self.max_samples,
|
| 383 |
+
"num_steps": self.num_steps,
|
| 384 |
+
"dataset_type": self.dataset_type,
|
| 385 |
+
"model_variant": self.model_variant,
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
with open(results_file, 'w') as f:
|
| 390 |
+
json.dump(data, f, indent=2)
|
| 391 |
+
|
| 392 |
+
print(f"\nResults saved to: {results_file}")
|
| 393 |
+
|
| 394 |
+
def analyze_results(self) -> Dict[str, Any]:
|
| 395 |
+
"""Analyze results and find best configuration."""
|
| 396 |
+
if not self.results:
|
| 397 |
+
print("No results to analyze!")
|
| 398 |
+
return {}
|
| 399 |
+
|
| 400 |
+
print("\n" + "="*80)
|
| 401 |
+
print("ANALYSIS: FINDING BEST CONFIGURATION")
|
| 402 |
+
print("="*80)
|
| 403 |
+
|
| 404 |
+
# Filter out failed experiments
|
| 405 |
+
successful_results = [r for r in self.results if "metrics" in r]
|
| 406 |
+
|
| 407 |
+
if not successful_results:
|
| 408 |
+
print("No successful experiments!")
|
| 409 |
+
return {}
|
| 410 |
+
|
| 411 |
+
# Compute aggregate scores
|
| 412 |
+
for result in successful_results:
|
| 413 |
+
metrics = result["metrics"]
|
| 414 |
+
result["aggregate_score"] = self.compute_aggregate_score(metrics)
|
| 415 |
+
|
| 416 |
+
# Sort by aggregate score
|
| 417 |
+
successful_results.sort(key=lambda x: x["aggregate_score"], reverse=True)
|
| 418 |
+
|
| 419 |
+
# Print top 5 configurations
|
| 420 |
+
print("\nTop 5 Configurations:")
|
| 421 |
+
print("="*80)
|
| 422 |
+
|
| 423 |
+
for i, result in enumerate(successful_results[:5], 1):
|
| 424 |
+
print(f"\n#{i} - Aggregate Score: {result['aggregate_score']:.4f}")
|
| 425 |
+
print(f"Config: {result['config']}")
|
| 426 |
+
print(f"Metrics:")
|
| 427 |
+
for metric, value in result['metrics'].items():
|
| 428 |
+
print(f" {metric}: {value:.4f}")
|
| 429 |
+
if result.get('improvements'):
|
| 430 |
+
print(f"Improvements over baseline:")
|
| 431 |
+
for metric, value in result['improvements'].items():
|
| 432 |
+
print(f" {metric}: {value:+.2f}%")
|
| 433 |
+
|
| 434 |
+
# Save best config
|
| 435 |
+
best_result = successful_results[0]
|
| 436 |
+
best_config_file = self.output_dir / "best_config.json"
|
| 437 |
+
|
| 438 |
+
with open(best_config_file, 'w') as f:
|
| 439 |
+
json.dump({
|
| 440 |
+
"config": best_result["config"],
|
| 441 |
+
"metrics": best_result["metrics"],
|
| 442 |
+
"aggregate_score": best_result["aggregate_score"],
|
| 443 |
+
"improvements": best_result.get("improvements", {}),
|
| 444 |
+
}, f, indent=2)
|
| 445 |
+
|
| 446 |
+
print(f"\n✓ Best configuration saved to: {best_config_file}")
|
| 447 |
+
|
| 448 |
+
return best_result
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def main():
|
| 452 |
+
parser = argparse.ArgumentParser(description="Hyperparameter tuning for gradient ascent")
|
| 453 |
+
parser.add_argument("--output_dir", type=str, default="tuning_results",
|
| 454 |
+
help="Directory to save tuning results")
|
| 455 |
+
parser.add_argument("--max_samples", type=int, default=30,
|
| 456 |
+
help="Number of samples to use for tuning")
|
| 457 |
+
parser.add_argument("--num_steps", type=int, default=20,
|
| 458 |
+
help="Number of inference steps (fixed)")
|
| 459 |
+
parser.add_argument("--dataset_type", type=str, default="pickapic",
|
| 460 |
+
choices=["coco", "pickapic"],
|
| 461 |
+
help="Dataset to use")
|
| 462 |
+
parser.add_argument("--model_variant", type=str, default="lpo",
|
| 463 |
+
choices=["origin", "spo", "diffusion_dpo", "lpo"],
|
| 464 |
+
help="Model variant to use")
|
| 465 |
+
parser.add_argument("--cuda", type=int, default=0,
|
| 466 |
+
help="CUDA device ID")
|
| 467 |
+
parser.add_argument("--search_type", type=str, default="grid",
|
| 468 |
+
choices=["grid", "random"],
|
| 469 |
+
help="Type of hyperparameter search")
|
| 470 |
+
parser.add_argument("--metrics", type=str, nargs="+",
|
| 471 |
+
default=["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"],
|
| 472 |
+
help="Metrics to evaluate")
|
| 473 |
+
parser.add_argument("--start_idx", type=int, default=0,
|
| 474 |
+
help="Starting index for experiments (for GPU distribution)")
|
| 475 |
+
parser.add_argument("--end_idx", type=int, default=None,
|
| 476 |
+
help="Ending index for experiments (for GPU distribution)")
|
| 477 |
+
|
| 478 |
+
args = parser.parse_args()
|
| 479 |
+
|
| 480 |
+
# Create tuner
|
| 481 |
+
tuner = HyperparameterTuner(
|
| 482 |
+
output_dir=args.output_dir,
|
| 483 |
+
max_samples=args.max_samples,
|
| 484 |
+
num_steps=args.num_steps,
|
| 485 |
+
dataset_type=args.dataset_type,
|
| 486 |
+
model_variant=args.model_variant,
|
| 487 |
+
cuda_id=args.cuda,
|
| 488 |
+
metrics=args.metrics,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Run search
|
| 492 |
+
results = tuner.run_search(
|
| 493 |
+
search_type=args.search_type,
|
| 494 |
+
start_idx=args.start_idx,
|
| 495 |
+
end_idx=args.end_idx
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
# Analyze results
|
| 499 |
+
best_result = tuner.analyze_results()
|
| 500 |
+
|
| 501 |
+
print("\n" + "="*80)
|
| 502 |
+
print("TUNING COMPLETE!")
|
| 503 |
+
print("="*80)
|
| 504 |
+
print(f"Total experiments: {len(results)}")
|
| 505 |
+
print(f"Results directory: {args.output_dir}")
|
| 506 |
+
|
| 507 |
+
if best_result:
|
| 508 |
+
print(f"\nBest configuration:")
|
| 509 |
+
print(json.dumps(best_result["config"], indent=2))
|
| 510 |
+
print(f"\nAggregate score: {best_result['aggregate_score']:.4f}")
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
if __name__ == "__main__":
|
| 514 |
+
main()
|
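A note on the random-search sampler above: np.random.choice returns NumPy scalar types (np.int64, np.bool_, ...), and the standard json encoder used by _save_results rejects some of those dtypes with a TypeError. Below is a minimal sketch of a sampler that normalizes values back to native Python types; the helper name sample_config, the seeded generator, and the example search space are illustrative assumptions, not part of this repo.

import numpy as np

def sample_config(search_space, rng=None):
    # Seeded generator so a random search is reproducible (illustrative choice).
    rng = rng or np.random.default_rng(42)
    config = {}
    for param, values in search_space.items():
        choice = rng.choice(values)
        # .item() turns a NumPy scalar (e.g. np.int64) back into a plain
        # Python int/float/str that json.dump accepts.
        config[param] = choice.item() if isinstance(choice, np.generic) else choice
    return config

# Hypothetical search space, for illustration only:
print(sample_config({"lr": [1e-4, 5e-4, 1e-3], "guidance_scale": [5, 7, 9]}))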
Reward_sdxl_idealized/tune_parallel.sh
ADDED
@@ -0,0 +1,253 @@
#!/bin/bash

# Parallel hyperparameter tuning across 8 GPUs
# This script distributes experiments evenly across all available GPUs

clear

# Activate conda environment
source ~/miniconda3/etc/profile.d/conda.sh
conda activate /home/ec2-user/aev

# Configuration
DATASET_TYPE="pickapic"            # "coco" or "pickapic"
MODEL_VARIANT="lpo"                # "origin", "spo", "diffusion_dpo", or "lpo"
MAX_SAMPLES=500                    # Number of samples for tuning
NUM_STEPS=50                       # Fixed inference steps
SEARCH_TYPE="grid"                 # "grid" or "random"
OUTPUT_DIR="RESULTS_TURNING/run_2"
NUM_GPUS=8                         # Number of GPUs to use

echo "=============================================="
echo "  PARALLEL HYPERPARAMETER TUNING"
echo "=============================================="
echo ""
echo "Configuration:"
echo "  Dataset: $DATASET_TYPE"
echo "  Model: $MODEL_VARIANT"
echo "  Samples: $MAX_SAMPLES"
echo "  Inference Steps: $NUM_STEPS"
echo "  Search Type: $SEARCH_TYPE"
echo "  GPUs: $NUM_GPUS"
echo "  Output: $OUTPUT_DIR"
echo ""

# First, calculate total number of experiments
echo "Calculating total experiments..."
TOTAL_CONFIGS=$(python -c "
from tune_hyperparams import HyperparameterTuner
import sys
tuner = HyperparameterTuner()
configs = tuner.define_search_space()
sys.stderr.write(f'Generated {len(configs)} configurations\n')
print(len(configs))
" 2>&1 | tail -1)

echo "Total configurations: $TOTAL_CONFIGS"
echo ""

# Calculate experiments per GPU
CONFIGS_PER_GPU=$((TOTAL_CONFIGS / NUM_GPUS))
REMAINDER=$((TOTAL_CONFIGS % NUM_GPUS))

echo "Distributing work:"
echo "  Base configs per GPU: $CONFIGS_PER_GPU"
echo "  Extra configs for first GPUs: $REMAINDER"
echo ""

# Create output directory
mkdir -p "$OUTPUT_DIR"

# Array to store background process IDs
PIDS=()

# Launch parallel processes on each GPU
for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
    # Calculate start and end indices for this GPU
    START_IDX=$((GPU_ID * CONFIGS_PER_GPU))

    # Give extra configs to first GPUs
    if [ $GPU_ID -lt $REMAINDER ]; then
        START_IDX=$((START_IDX + GPU_ID))
        END_IDX=$((START_IDX + CONFIGS_PER_GPU + 1))
    else
        START_IDX=$((START_IDX + REMAINDER))
        END_IDX=$((START_IDX + CONFIGS_PER_GPU))
    fi

    # Create GPU-specific output directory
    GPU_OUTPUT_DIR="${OUTPUT_DIR}/gpu_${GPU_ID}"
    mkdir -p "$GPU_OUTPUT_DIR"

    echo "GPU $GPU_ID: configs $START_IDX to $END_IDX"

    # Launch tuning process in background
    nohup python tune_hyperparams.py \
        --output_dir "$GPU_OUTPUT_DIR" \
        --max_samples $MAX_SAMPLES \
        --num_steps $NUM_STEPS \
        --dataset_type "$DATASET_TYPE" \
        --model_variant "$MODEL_VARIANT" \
        --cuda $GPU_ID \
        --search_type "$SEARCH_TYPE" \
        --start_idx $START_IDX \
        --end_idx $END_IDX \
        --metrics clip aesthetic pickscore hpsv2 imagereward \
        > "${GPU_OUTPUT_DIR}/tuning.log" 2>&1 &

    # Store PID
    PIDS+=($!)

    echo "  Launched with PID: ${PIDS[$GPU_ID]}"

    # Small delay to avoid race conditions
    sleep 2
done

echo ""
echo "=============================================="
echo "  ALL PROCESSES LAUNCHED"
echo "=============================================="
echo ""
echo "Background processes running:"
for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
    echo "  GPU $GPU_ID: PID ${PIDS[$GPU_ID]} -> ${OUTPUT_DIR}/gpu_${GPU_ID}/tuning.log"
done
echo ""
echo "To monitor progress:"
echo "  tail -f ${OUTPUT_DIR}/gpu_0/tuning.log"
echo "  tail -f ${OUTPUT_DIR}/gpu_1/tuning.log"
echo "  ... etc"
echo ""
echo "To check all GPU processes:"
echo "  ps aux | grep tune_hyperparams.py"
echo ""
echo "To monitor GPU usage:"
echo "  watch -n 1 nvidia-smi"
echo ""
echo "To kill all processes:"
echo "  kill ${PIDS[@]}"
echo ""
echo "Waiting for all processes to complete..."
echo "(Press Ctrl+C to stop waiting, processes will continue in background)"
echo ""

# Wait for all background processes
for PID in "${PIDS[@]}"; do
    wait $PID
done

echo ""
echo "=============================================="
echo "  ALL TUNING PROCESSES COMPLETE"
echo "=============================================="
echo ""

# Merge results from all GPUs
echo "Merging results from all GPUs..."

# Re-activate conda environment for the merge script
source ~/miniconda3/etc/profile.d/conda.sh
conda activate /home/ec2-user/aev

python - <<'EOF'
import json
from pathlib import Path
import sys

# NOTE: the quoted heredoc blocks shell variable expansion, so this path is
# hardcoded and must match OUTPUT_DIR above.
output_dir = Path("RESULTS_TURNING/run_2")
all_results = []
baseline_result = None

# Collect results from each GPU
for gpu_id in range(8):
    gpu_dir = output_dir / f"gpu_{gpu_id}"
    results_file = gpu_dir / "tuning_results.json"

    if results_file.exists():
        with open(results_file, 'r') as f:
            data = json.load(f)

        # Get baseline (should be same from all)
        if baseline_result is None and "baseline" in data:
            baseline_result = data["baseline"]

        # Collect experiments
        if "experiments" in data:
            all_results.extend(data["experiments"])

        print(f"GPU {gpu_id}: {len(data.get('experiments', []))} results")

# Merge all results
merged_data = {
    "baseline": baseline_result,
    "experiments": all_results,
    "num_gpus": 8,
    "total_experiments": len(all_results)
}

# Save merged results
merged_file = output_dir / "merged_results.json"
with open(merged_file, 'w') as f:
    json.dump(merged_data, f, indent=2)

print(f"\nMerged {len(all_results)} total results")
print(f"Saved to: {merged_file}")

# Find best configuration
successful = [r for r in all_results if "metrics" in r]
if successful:
    # Compute aggregate scores
    def compute_score(metrics):
        weights = {
            "reward": 1.0, "clip": 0.8, "aesthetic": 0.8,
            "pickscore": 1.0, "hpsv2": 1.0, "imagereward": 1.0,
            "fid": -0.5
        }
        score = sum(weights.get(k, 0) * v for k, v in metrics.items())
        return score / sum(abs(w) for w in weights.values())

    for r in successful:
        r["aggregate_score"] = compute_score(r["metrics"])

    successful.sort(key=lambda x: x["aggregate_score"], reverse=True)

    best = successful[0]
    best_file = output_dir / "best_config.json"
    with open(best_file, 'w') as f:
        json.dump({
            "config": best["config"],
            "metrics": best["metrics"],
            "aggregate_score": best["aggregate_score"],
            "improvements": best.get("improvements", {})
        }, f, indent=2)

    print(f"\n{'='*60}")
    print("BEST CONFIGURATION:")
    print(f"{'='*60}")
    print(json.dumps(best["config"], indent=2))
    print(f"\nAggregate Score: {best['aggregate_score']:.4f}")
    print(f"Saved to: {best_file}")
else:
    print("\nNo successful experiments found!")
    sys.exit(1)
EOF

if [ $? -eq 0 ]; then
    echo ""
    echo "=============================================="
    echo "  TUNING COMPLETE!"
    echo "=============================================="
    echo ""
    echo "Results:"
    echo "  Merged results: ${OUTPUT_DIR}/merged_results.json"
    echo "  Best config: ${OUTPUT_DIR}/best_config.json"
    echo ""
    echo "View best configuration:"
    echo "  cat ${OUTPUT_DIR}/best_config.json"
    echo ""
else
    echo ""
    echo "ERROR: Failed to merge results"
    exit 1
fi
__pycache__/upload.cpython-311.pyc
ADDED
Binary file (9.83 kB)

evaluation/open_clip/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.01 kB)

evaluation/open_clip/__pycache__/coca_model.cpython-311.pyc
ADDED
Binary file (18.4 kB)

evaluation/open_clip/__pycache__/hf_model.cpython-311.pyc
ADDED
Binary file (10.6 kB)

evaluation/open_clip/__pycache__/loss.cpython-311.pyc
ADDED
Binary file (14.3 kB)

evaluation/open_clip/__pycache__/model.cpython-311.pyc
ADDED
Binary file (25.1 kB)

evaluation/open_clip/__pycache__/pretrained.cpython-311.pyc
ADDED
Binary file (18.5 kB)

evaluation/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc
ADDED
Binary file (9.31 kB)

evaluation/open_clip/__pycache__/tokenizer.cpython-311.pyc
ADDED
Binary file (16 kB)

evaluation/open_clip/__pycache__/transform.cpython-311.pyc
ADDED
Binary file (8.58 kB)

evaluation/open_clip/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (42.6 kB)

evaluation/open_clip/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (3.58 kB)
evaluation/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
evaluation/open_clip/model_configs/ViT-B-32-plus-256.json
ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 896,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
evaluation/open_clip/model_configs/ViT-S-32.json
ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}
evaluation/open_clip/model_configs/convnext_large_d_320.json
ADDED
@@ -0,0 +1,19 @@
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}
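These JSON files follow open_clip's model-config schema, where each file stem doubles as a model name. Purely for illustration (the relative path assumes the repo layout shown here), the fields can be inspected directly:

import json

with open("evaluation/open_clip/model_configs/ViT-S-32.json") as f:
    cfg = json.load(f)

# embed_dim is the shared image/text embedding width; patch_size applies to ViTs.
print(cfg["embed_dim"], cfg["vision_cfg"]["patch_size"])  # -> 384 32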
lrm/flux/.hydra/config.yaml
ADDED
@@ -0,0 +1,124 @@
accelerator:
  _target_: trainer.accelerators.debug_accelerator.DebugAccelerator
  output_dir: ${output_dir}
  mixed_precision: BF16
  gradient_accumulation_steps: 1
  log_with: null
  debug:
    activate: false
    port: 5900
  seed: 42
  resume_from_checkpoint: false
  max_steps: 8000
  num_epochs: 10
  validate_steps: 100
  generalization_validate_steps: 500
  eval_on_start: true
  project_name: reward_model
  run_name: step_flux_schnell_variable-t_lr1e-5_step-8000_cfg0.0_filter2_time951
  max_grad_norm: 1.0
  save_steps: 100
  metric_name: accuracy
  metric_mode: MAX
  limit_num_checkpoints: 1
  save_only_if_best: true
  dynamo_backend: 'NO'
  keep_best_ckpts: true
task:
  limit_examples_to_wandb: 50
  _target_: trainer.tasks.step_flux_task.StepFluxTask
  pretrained_model_name_or_path: ${model.pretrained_model_name_or_path}
  tokenizer_subfolder: tokenizer
  label_0_column_name: ${dataset.label_0_column_name}
  label_1_column_name: ${dataset.label_1_column_name}
  input_ids_column_name: ${dataset.input_ids_column_name}
  input_ids_2_column_name: ${dataset.input_ids_2_column_name}
  pixels_0_column_name: ${dataset.pixels_0_column_name}
  pixels_1_column_name: ${dataset.pixels_1_column_name}
  timestep_column_name: ${dataset.timestep_column_name}
  constant_timestep: ${dataset.constant_timestep}
model:
  _target_: trainer.models.flux_preference_model.FluxPreferenceModel
  pretrained_model_name_or_path: black-forest-labs/FLUX.1-schnell
  pretrained_vae_name_or_path: black-forest-labs/FLUX.1-schnell
  projection_dim: 1024
  text_embed_dim: 768
  logit_scale_init_value: 2.6592
  freeze_text_encoder: false
  guidance_scale: 0.0
  noise_offset: false
  noise_offset_coeff: 0.05
  max_sequence_length: 512
  image_size: 1024
criterion:
  _target_: trainer.criterions.step_clip_criterion_flux.StepFluxCLIPCriterion
  is_distributed: false
  label_0_column_name: ${dataset.label_0_column_name}
  label_1_column_name: ${dataset.label_1_column_name}
  input_ids_column_name: ${dataset.input_ids_column_name}
  input_ids_2_column_name: ${dataset.input_ids_2_column_name}
  pixels_0_column_name: ${dataset.pixels_0_column_name}
  pixels_1_column_name: ${dataset.pixels_1_column_name}
  num_examples_per_prompt_column_name: ${dataset.num_examples_per_prompt_column_name}
  timestep_column_name: ${dataset.timestep_column_name}
  loss_type: pair
  batch_coeff: 1.0
  aux_loss_coeff: 1.0
dataset:
  train_split_name: train
  valid_split_name: validation_unique
  test_split_name: test_unique
  batch_size: 4
  num_workers: 2
  drop_last: true
  _target_: trainer.datasets.step_flux_hf_dataset.StepFluxHFDataset
  dataset_name: pickapic-anonymous/pickapic_v1
  dataset_config_name: null
  from_disk: false
  cache_dir: null
  caption_column_name: caption
  input_ids_column_name: input_ids
  input_ids_2_column_name: input_ids_2
  image_0_column_name: jpg_0
  image_1_column_name: jpg_1
  label_0_column_name: label_0
  label_1_column_name: label_1
  are_different_column_name: are_different
  has_label_column_name: has_label
  pixels_0_column_name: pixel_values_0
  pixels_1_column_name: pixel_values_1
  timestep_column_name: timestep
  constant_timestep: 1
  variable_timestep: true
  largest_timestep: 951
  compare_between_timestep: false
  timestep_comparison_column_name: timestep_comparison
  timestep_interval: 1
  num_examples_per_prompt_column_name: num_example_per_prompt
  keep_only_different: false
  keep_only_with_label: false
  keep_only_with_label_in_non_train: true
  keep_only_with_pesudo_preference: true
  pseudo_preference_path: /g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
  filter_strategy: 2
  processor:
    pretrained_model_name_or_path: ${model.pretrained_model_name_or_path}
    max_sequence_length: ${model.max_sequence_length}
    image_size: ${model.image_size}
    random_crop: false
    no_hflip: true
  limit_examples_per_prompt: -1
  only_on_best: false
optimizer:
  _target_: trainer.optimizers.dummy_optimizer.BaseDummyOptim
  lr: 1.0e-05
  weight_decay: 0.3
lr_scheduler:
  _target_: trainer.lr_schedulers.dummy_lr_scheduler.instantiate_dummy_lr_scheduler
  lr: ${optimizer.lr}
  lr_warmup_steps: 1000
  total_num_steps: ${accelerator.max_steps}
debug:
  activate: false
  port: 5900
output_dir: logs/lrm/${accelerator.project_name}/${accelerator.run_name}
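The ${...} references in this config (for example task.pretrained_model_name_or_path) are OmegaConf interpolations that Hydra resolves when the config is loaded. A minimal sketch of the mechanism, assuming omegaconf is installed; the two-key config below is illustrative only:

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
model:
  pretrained_model_name_or_path: black-forest-labs/FLUX.1-schnell
task:
  pretrained_model_name_or_path: ${model.pretrained_model_name_or_path}
""")

# The interpolation resolves on access:
assert cfg.task.pretrained_model_name_or_path == "black-forest-labs/FLUX.1-schnell"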
lrm/flux/.hydra/hydra.yaml
ADDED
@@ -0,0 +1,166 @@
hydra:
  run:
    dir: .
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
  launcher:
    _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
  sweeper:
    _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
    max_batch_size: null
    params: null
  help:
    app_name: ${hydra.job.name}
    header: '${hydra.help.app_name} is powered by Hydra.

      '
    footer: 'Powered by Hydra (https://hydra.cc)

      Use --hydra-help to view Hydra specific help

      '
    template: '${hydra.help.header}

      == Configuration groups ==

      Compose your configuration from those groups (group=option)


      $APP_CONFIG_GROUPS


      == Config ==

      Override anything in the config (foo.bar=value)


      $CONFIG


      ${hydra.help.footer}

      '
  hydra_help:
    template: 'Hydra (${hydra.runtime.version})

      See https://hydra.cc for more info.


      == Flags ==

      $FLAGS_HELP


      == Configuration groups ==

      Compose your configuration from those groups (For example, append hydra/job_logging=disabled
      to command line)


      $HYDRA_CONFIG_GROUPS


      Use ''--cfg hydra'' to Show the Hydra config.

      '
    hydra_help: ???
  hydra_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][HYDRA] %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
    root:
      level: INFO
      handlers:
      - console
    loggers:
      logging_example:
        level: DEBUG
    disable_existing_loggers: false
  job_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
      file:
        class: logging.FileHandler
        formatter: simple
        filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
    root:
      level: INFO
      handlers:
      - console
      - file
    disable_existing_loggers: false
  env: {}
  mode: RUN
  searchpath: []
  callbacks: {}
  output_subdir: .hydra
  overrides:
    hydra:
    - hydra.mode=RUN
    task:
    - accelerator.mixed_precision=BF16
    - accelerator.log_with=null
    - accelerator=debug
    - criterion.is_distributed=false
    - dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
  job:
    name: train
    chdir: null
    override_dirname: accelerator.log_with=null,accelerator.mixed_precision=BF16,accelerator=debug,criterion.is_distributed=false,dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
    id: ???
    num: ???
    config_name: step_flux_base
    env_set: {}
    env_copy: []
    config:
      override_dirname:
        kv_sep: '='
        item_sep: ','
        exclude_keys: []
  runtime:
    version: 1.3.2
    version_base: '1.3'
    cwd: /g/data/rr81/LPO/lrm/flux
    config_sources:
    - path: hydra.conf
      schema: pkg
      provider: hydra
    - path: /g/data/rr81/LPO/lrm/flux/trainer/conf
      schema: file
      provider: main
    - path: ''
      schema: structured
      provider: schema
    output_dir: /g/data/rr81/LPO/lrm/flux
    choices:
      lr_scheduler: dummy
      optimizer: dummy
      dataset: step_flux
      criterion: step_clip_flux
      model: step_flux_base
      task: step_flux
      accelerator: debug
      hydra/env: default
      hydra/callbacks: null
      hydra/job_logging: default
      hydra/hydra_logging: default
      hydra/hydra_help: default
      hydra/help: default
      hydra/sweeper: basic
      hydra/launcher: basic
      hydra/output: default
  verbose: false
lrm/flux/.hydra/overrides.yaml
ADDED
@@ -0,0 +1,5 @@
- accelerator.mixed_precision=BF16
- accelerator.log_with=null
- accelerator=debug
- criterion.is_distributed=false
- dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
lrm/flux/README.md
ADDED
@@ -0,0 +1,4 @@
Please check the train_flux.sh file.
In line 21 there is a variable named "RUN_PROFILE" which can be set to "main" or "quick". If set to "main", it runs the full training loop; if set to "quick", it runs a minimal training loop for testing purposes.

The accelerator is set to distributed shared for quick testing mode, while main mode uses distributed.
|