aryadomain committed
Commit ef8f3ad · verified · 1 Parent(s): b871133

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. Reward_sana_idealized/open_clip/__pycache__/__init__.cpython-311.pyc +0 -0
  2. Reward_sana_idealized/open_clip/__pycache__/constants.cpython-311.pyc +0 -0
  3. Reward_sana_idealized/open_clip/__pycache__/hf_configs.cpython-311.pyc +0 -0
  4. Reward_sana_idealized/open_clip/__pycache__/hf_model.cpython-311.pyc +0 -0
  5. Reward_sana_idealized/open_clip/__pycache__/loss.cpython-311.pyc +0 -0
  6. Reward_sana_idealized/open_clip/__pycache__/openai.cpython-311.pyc +0 -0
  7. Reward_sana_idealized/open_clip/__pycache__/transform.cpython-311.pyc +0 -0
  8. Reward_sana_idealized/open_clip/__pycache__/utils.cpython-311.pyc +0 -0
  9. Reward_sana_idealized/open_clip/__pycache__/version.cpython-311.pyc +0 -0
  10. Reward_sana_idealized/open_clip/model_configs/RN50.json +21 -0
  11. Reward_sana_idealized/open_clip/model_configs/RN50x64.json +21 -0
  12. Reward_sana_idealized/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  13. Reward_sana_idealized/open_clip/model_configs/ViT-B-32-quickgelu.json +17 -0
  14. Reward_sana_idealized/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json +15 -0
  15. Reward_sdxl_idealized/README.md +1336 -0
  16. Reward_sdxl_idealized/config_analysis_tuning.ipynb +218 -0
  17. Reward_sdxl_idealized/eval.py +1143 -0
  18. Reward_sdxl_idealized/examples.sh +19 -0
  19. Reward_sdxl_idealized/gradient_ascent_utils.py +339 -0
  20. Reward_sdxl_idealized/lr_scheduler.py +233 -0
  21. Reward_sdxl_idealized/models/__init__.py +3 -0
  22. Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-310.pyc +0 -0
  23. Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-313.pyc +0 -0
  24. Reward_sdxl_idealized/models/__pycache__/unet_2d_condition_reward.cpython-313.pyc +0 -0
  25. Reward_sdxl_idealized/models/unet_2d_condition_reward.py +1334 -0
  26. Reward_sdxl_idealized/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  27. Reward_sdxl_idealized/pipelines/sdxl_gradient_ascent_pipeline.py +375 -0
  28. Reward_sdxl_idealized/timestep_convergence_analysis.ipynb +1105 -0
  29. Reward_sdxl_idealized/tune_hyperparams.py +514 -0
  30. Reward_sdxl_idealized/tune_parallel.sh +253 -0
  31. __pycache__/upload.cpython-311.pyc +0 -0
  32. evaluation/open_clip/__pycache__/__init__.cpython-311.pyc +0 -0
  33. evaluation/open_clip/__pycache__/coca_model.cpython-311.pyc +0 -0
  34. evaluation/open_clip/__pycache__/hf_model.cpython-311.pyc +0 -0
  35. evaluation/open_clip/__pycache__/loss.cpython-311.pyc +0 -0
  36. evaluation/open_clip/__pycache__/model.cpython-311.pyc +0 -0
  37. evaluation/open_clip/__pycache__/pretrained.cpython-311.pyc +0 -0
  38. evaluation/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc +0 -0
  39. evaluation/open_clip/__pycache__/tokenizer.cpython-311.pyc +0 -0
  40. evaluation/open_clip/__pycache__/transform.cpython-311.pyc +0 -0
  41. evaluation/open_clip/__pycache__/transformer.cpython-311.pyc +0 -0
  42. evaluation/open_clip/__pycache__/utils.cpython-311.pyc +0 -0
  43. evaluation/open_clip/model_configs/RN50.json +21 -0
  44. evaluation/open_clip/model_configs/ViT-B-32-plus-256.json +16 -0
  45. evaluation/open_clip/model_configs/ViT-S-32.json +16 -0
  46. evaluation/open_clip/model_configs/convnext_large_d_320.json +19 -0
  47. lrm/flux/.hydra/config.yaml +124 -0
  48. lrm/flux/.hydra/hydra.yaml +166 -0
  49. lrm/flux/.hydra/overrides.yaml +5 -0
  50. lrm/flux/README.md +4 -0
Reward_sana_idealized/open_clip/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2 kB).

Reward_sana_idealized/open_clip/__pycache__/constants.cpython-311.pyc ADDED
Binary file (282 Bytes).

Reward_sana_idealized/open_clip/__pycache__/hf_configs.cpython-311.pyc ADDED
Binary file (717 Bytes).

Reward_sana_idealized/open_clip/__pycache__/hf_model.cpython-311.pyc ADDED
Binary file (10.6 kB).

Reward_sana_idealized/open_clip/__pycache__/loss.cpython-311.pyc ADDED
Binary file (14.3 kB).

Reward_sana_idealized/open_clip/__pycache__/openai.cpython-311.pyc ADDED
Binary file (8.74 kB).

Reward_sana_idealized/open_clip/__pycache__/transform.cpython-311.pyc ADDED
Binary file (8.57 kB).

Reward_sana_idealized/open_clip/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.57 kB).

Reward_sana_idealized/open_clip/__pycache__/version.cpython-311.pyc ADDED
Binary file (189 Bytes).
Reward_sana_idealized/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 448,
        "layers": [
            3,
            15,
            36,
            10
        ],
        "width": 128,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 896,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/ViT-B-32-quickgelu.json ADDED
@@ -0,0 +1,17 @@
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
Reward_sana_idealized/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json ADDED
@@ -0,0 +1,15 @@
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-base",
        "hf_tokenizer_name": "xlm-roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
Reward_sdxl_idealized/README.md ADDED
@@ -0,0 +1,1336 @@
# Reward-Guided Gradient Ascent for Stable Diffusion

A comprehensive system for improving Stable Diffusion image generation quality by running gradient ascent on Latent Reward Model (LRM) scores during inference.

## Table of Contents

- [Overview](#overview)
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Architecture](#architecture)
- [Understanding Reward Calculation](#understanding-reward-calculation)
- [Learning Rate Scheduling](#learning-rate-scheduling)
- [Configuration Presets](#configuration-presets)
- [Evaluation Metrics](#evaluation-metrics)
- [Model Variants](#model-variants)
- [Datasets](#datasets)
- [Usage Examples](#usage-examples)
- [API Reference](#api-reference)
- [Command-Line Options](#command-line-options)
- [Output Files](#output-files)
- [Troubleshooting](#troubleshooting)
- [Best Practices](#best-practices)
- [Changelog](#changelog)

---

## Overview

This project implements **test-time optimization** for Stable Diffusion using gradient ascent on the LRM reward model. Unlike the main LPO training, which uses the reward model to update model weights, this approach applies the reward model during inference to improve generation quality without retraining.

### Key Capabilities

- **Gradient Ascent Optimization**: Iteratively improve latents using reward gradients
- **Learning Rate Scheduling**: Multiple strategies (constant, linear, cosine, exponential, step)
- **Momentum Optimization**: Standard and Nesterov momentum for better convergence
- **Multiple Metrics**: FID, CLIP, Aesthetic, PickScore, HPSv2, ImageReward
- **Model Variants**: Support for Origin, SPO, DPO, and LPO SD1.5 models
- **Dataset Flexibility**: COCO and Pick-a-Pic validation datasets
- **Configuration Presets**: 15 pre-tuned configurations for various use cases

---

## Features

### 1. **Advanced Optimization**
- **5 LR Schedulers**: Constant, Linear, Cosine, Exponential, Step-wise
- **Momentum Support**: Standard momentum and Nesterov momentum
- **Configurable Timestep Ranges**: Apply gradients at specific denoising steps
- **Dynamic Learning Rates**: LR changes during optimization for better convergence

### 2. **Comprehensive Evaluation**
- **6 Quality Metrics**: FID, CLIP, Aesthetic, PickScore, HPSv2, ImageReward
- **Baseline Comparison**: Compare results with and without gradient ascent
- **Detailed Statistics**: Track reward improvements, gradient norms, LR history
- **Batch Processing**: Efficient evaluation on large datasets
- **Reward Visualization**: Automatic plotting of reward progression across timesteps
- **Timestep-Aware Tracking**: Monitor rewards at every denoising step; the final t=0 latent reward is reported

### 3. **Model Flexibility**
- **4 SD1.5 Variants**: Origin, SPO, DPO, LPO
- **Auto-Configuration**: CFG scale auto-adjusted per model variant
- **Easy Switching**: Change models with a single flag

### 4. **Dataset Support**
- **COCO Validation**: Standard benchmark with reference images
- **Pick-a-Pic Validation**: Large-scale human preference dataset
- **Streaming Support**: Handle large datasets efficiently

---

## Installation

### Requirements

```bash
# Core dependencies
pip install torch diffusers transformers torchmetrics datasets huggingface-hub

# For evaluation metrics
pip install pillow numpy scipy tqdm

# Optional: for better performance
pip install xformers  # For memory-efficient attention
```

### Setup

```bash
cd /path/to/LPO/Reward

# Verify installation
python -c "from lr_scheduler import create_lr_scheduler; print('✓ LR Scheduler OK')"
python -c "from grad_ascent_configs import list_configs; print('✓ Configs:', len(list_configs()))"
python -c "from gradient_ascent_utils import RewardGuidedDiffusion; print('✓ Gradient Utils OK')"
```

---

## Quick Start

### 1. Basic COCO Evaluation (test_grad_sd1.5.py)

```bash
# Edit Config in test_grad_sd1.5.py:
# - Set device: "cuda:0" or "cuda:6"
# - Set max_samples: 10 for a quick test, None for the full dataset
# - Configure gradient ascent parameters

python test_grad_sd1.5.py
```

**Output:**
- Creates `RESULTS/SD1.5_GradAscent/run_1/` (auto-incremented)
- Generates `eval.log` with detailed metrics
- Saves `reward_curve.png` showing reward progression

### 2. Basic Evaluation with Preset Config (eval.py)

```bash
python eval.py \
    --grad_config cosine_nesterov \
    --metrics clip aesthetic \
    --max_samples 10
```

### 3. High-Quality Evaluation

```bash
python eval.py \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 100 \
    --save_images \
    --output_dir results/high_quality
```

### 4. Pick-a-Pic Benchmark

```bash
python eval.py \
    --dataset_type pickapic \
    --grad_config cosine_nesterov \
    --metrics pickscore hpsv2 imagereward \
    --max_samples 500 \
    --output_dir results/pickapic
```

---

## Architecture

### System Components

```
Reward/
├── models/
│   ├── reward_model.py                   # LRM reward model wrapper
│   └── unet_2d_condition_reward.py       # Custom UNet with reward tracking
├── pipelines/
│   ├── sd15_reward_pipeline.py           # Base pipeline with reward tracking
│   └── sd15_gradient_ascent_pipeline.py  # Pipeline with gradient ascent
├── lr_scheduler.py                       # Learning rate schedulers
├── gradient_ascent_utils.py              # Core gradient ascent implementation
├── grad_ascent_configs.py                # Configuration presets
├── eval.py                               # Comprehensive evaluation script
└── examples.sh                           # Example commands
```

### Gradient Ascent Flow

```
1. Load Stable Diffusion + LRM reward model
2. Start the denoising process (T → 0)
3. At each timestep t:
   a. Standard denoising step (predict noise, remove it)
   b. Compute reward R(latents, prompt, t) and store it in the history
   c. If t is in the gradient range:
      - Enable gradients on the latents
      - Compute ∇R w.r.t. the latents
      - For each gradient step:
        * Get the current LR from the scheduler
        * Apply momentum (if enabled)
        * Update: latents += lr * momentum(∇R)
      - Track statistics (grad norms, reward improvement)
4. At the final timestep (t=0):
   - The final reward is computed on the clean latent
   - This reward is reported in the logs
5. Decode the final latent (x₀) to an image via the VAE
6. Compute quality metrics on the image
```
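
The update in step 3c is the heart of the method. Below is a minimal sketch of that inner loop, assuming a differentiable `reward_model(latents, prompt_embeds, t)` that returns a per-sample scalar and a scheduler exposing `get_lr()`/`step()`; it illustrates the technique and is not the exact code in `gradient_ascent_utils.py`:

```python
import torch

def reward_ascent(latents, reward_model, prompt_embeds, t, lr_scheduler,
                  num_grad_steps=15, momentum=0.9, use_nesterov=True):
    """One reward-guided refinement of the latents at timestep t (illustrative sketch)."""
    velocity = torch.zeros_like(latents)
    for _ in range(num_grad_steps):
        lr = lr_scheduler.get_lr()
        # Nesterov momentum evaluates the gradient at a look-ahead point
        probe = latents + momentum * velocity if use_nesterov else latents
        probe = probe.detach().requires_grad_(True)
        reward = reward_model(probe, prompt_embeds, t).sum()
        grad = torch.autograd.grad(reward, probe)[0]
        velocity = momentum * velocity + lr * grad
        latents = (latents + velocity).detach()  # ascent: step toward higher reward
        lr_scheduler.step()
    return latents
```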

### Understanding Reward Calculation

**Key Concepts:**

- **Timestep-Aware Rewards**: The LRM reward model computes preference scores at ANY noise level (timestep t)
- **Progressive Tracking**: Rewards are calculated at every denoising step throughout generation
- **Final Latent Reward**: The reported metric is the reward for t=0 (the clean latent before decoding)
- **Not Averaged**: The final reward comes specifically from the last timestep, NOT an average across all timesteps

**What gets reported:**
```
# During generation: rewards are computed at each t (1000 → 0)
Step 0:   t=1000, reward=3.2
Step 1:   t=990,  reward=3.5
...
Step 99:  t=10,   reward=5.1
Step 100: t=0,    reward=5.4  ← This is what gets logged!
```

The `Reward (t=0)` in the logs represents the preference score of the final clean latent that was decoded into your output image.
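
In code terms, assuming the pipeline keeps its per-step rewards in a plain list (an illustrative simplification):

```python
# reward_history: one entry per denoising step, in generation order
reward_history = [3.2, 3.5, 5.1, 5.4]

reward_t0 = reward_history[-1]                          # logged as "Reward (t=0)"
reward_avg = sum(reward_history) / len(reward_history)  # diagnostic only, never the headline number
```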

---

## Learning Rate Scheduling

### Available Schedulers

#### 1. **Constant LR**
```python
lr_scheduler_type="constant"
```
- Fixed learning rate throughout optimization
- Simple and stable
- Good for quick experiments

#### 2. **Linear Decay**
```python
lr_scheduler_type="linear"
lr_scheduler_kwargs={
    "end_lr": 0.01,   # End LR (10% of the initial 0.1)
    "start_step": 0   # When to start the decay
}
```
- Linear decrease from the initial LR to the end LR
- Smooth convergence
- Configurable delay before the decay begins

#### 3. **Cosine Annealing** (Recommended)
```python
lr_scheduler_type="cosine"
lr_scheduler_kwargs={
    "min_lr": 0.001,   # Minimum LR
    "warmup_steps": 3  # Linear warmup steps
}
```
- Smooth cosine decay
- Optional warmup phase
- Widely used in deep learning
- **Best for most use cases**

#### 4. **Exponential Decay**
```python
lr_scheduler_type="exponential"
lr_scheduler_kwargs={
    "gamma": 0.9  # Decay factor per step
}
```
- Exponential decrease
- Fast initial decay
- Good for aggressive optimization

#### 5. **Step Decay**
```python
lr_scheduler_type="step"
lr_scheduler_kwargs={
    "step_size": 5,  # Steps between decays
    "gamma": 0.5     # Multiplicative factor
}
```
- Step-wise LR reduction
- Periodic decay
- Good for scheduled changes

### Usage Example

```python
from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline

pipeline.enable_gradient_ascent(
    grad_timestep_range=(0, 700),
    num_grad_steps=15,
    grad_step_size=0.1,  # Initial LR
    lr_scheduler_type="cosine",
    lr_scheduler_kwargs={
        "min_lr": 0.001,
        "warmup_steps": 3
    }
)
```

---

## Configuration Presets

We provide 15 pre-configured optimization strategies. Use them with `--grad_config <name>`.

### Basic Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `constant` | Constant | No | 5 | Simple baseline |
| `linear` | Linear decay | No | 10 | Smooth decay |
| `linear_warmstart` | Linear w/ warmup | No | 10 | Stable start |
| `cosine` | Cosine | No | 10 | Smooth convergence |
| `cosine_warmup` | Cosine w/ warmup | No | 20 | Best convergence |
| `exponential` | Exponential | No | 15 | Fast decay |
| `step` | Step-wise | No | 20 | Periodic decay |

### Momentum Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `momentum` | Constant | Standard | 10 | Faster convergence |
| `nesterov` | Constant | Nesterov | 10 | Better convergence |

### Advanced Configurations

| Config | LR Schedule | Momentum | Steps | Description |
|--------|-------------|----------|-------|-------------|
| `cosine_momentum` | Cosine | Standard | 15 | High quality |
| `cosine_nesterov` | Cosine | Nesterov | 15 | **Recommended** |
| `linear_nesterov` | Linear | Nesterov | 15 | Stable + fast |

### Quality Presets

| Config | LR Schedule | Momentum | Steps | Use Case |
|--------|-------------|----------|-------|----------|
| `high_quality` | Cosine | Nesterov | 20 | **Best quality** |
| `aggressive` | Exponential | Standard | 8 | Fast results |
| `conservative` | Cosine | Nesterov | 25 | Most stable |

### Config Details

#### `high_quality` (Recommended for Research)
```python
{
    "grad_timestep_range": (200, 800),  # Focus on middle timesteps
    "num_grad_steps": 20,
    "grad_step_size": 0.08,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.005, "warmup_steps": 5},
    "use_momentum": True,
    "momentum": 0.95,
    "use_nesterov": True
}
```

#### `cosine_nesterov` (Recommended for General Use)
```python
{
    "grad_timestep_range": (0, 700),
    "num_grad_steps": 15,
    "grad_step_size": 0.12,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.001, "warmup_steps": 3},
    "use_momentum": True,
    "momentum": 0.9,
    "use_nesterov": True
}
```

#### `aggressive` (Fast Experimentation)
```python
{
    "grad_timestep_range": (0, 900),
    "num_grad_steps": 8,
    "grad_step_size": 0.15,
    "grad_scale": 1.2,
    "lr_scheduler_type": "exponential",
    "lr_scheduler_kwargs": {"gamma": 0.85},
    "use_momentum": True,
    "momentum": 0.85,
    "use_nesterov": False
}
```

### Listing Configs

```python
from grad_ascent_configs import list_configs, print_config, get_config

# List all available configs
print(list_configs())
# Output: ['aggressive', 'conservative', 'constant', 'cosine', ...]

# Print config details
print_config("cosine_nesterov")

# Get a config dictionary
config = get_config("high_quality")
pipeline.enable_gradient_ascent(**config)
```

---

## Evaluation Metrics

### 1. **FID (Fréchet Inception Distance)**
- Measures distribution similarity between real and generated images
- **Lower is better**
- Requires reference images (COCO dataset only)
- Computationally expensive

```bash
--metrics fid
```
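
For reference, FID can be computed with torchmetrics (already listed in the requirements; the image backend also needs `pip install torchmetrics[image]`). A minimal sketch with random stand-in tensors:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# feature=64 keeps this toy example well conditioned; 2048 is the standard choice
fid = FrechetInceptionDistance(feature=64)

# uint8 image batches in (N, C, H, W); random tensors stand in for real data here
real_images = torch.randint(0, 255, (16, 3, 299, 299), dtype=torch.uint8)
fake_images = torch.randint(0, 255, (16, 3, 299, 299), dtype=torch.uint8)

fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(float(fid.compute()))  # lower is better
```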

### 2. **CLIP Score**
- Evaluates text-image alignment using CLIP embeddings
- **Higher is better**
- Fast and reliable
- Good for general quality assessment

```bash
--metrics clip
```
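
A minimal sketch using torchmetrics' `CLIPScore`; the model id shown is one common choice and not necessarily the one `eval.py` loads:

```python
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")

# uint8 images in (N, C, H, W), paired with their prompts
images = torch.randint(0, 255, (2, 3, 224, 224), dtype=torch.uint8)
prompts = ["a mountain landscape at sunset", "a cat sleeping on a sofa"]

print(float(metric(images, prompts)))  # higher is better
```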

### 3. **Aesthetic Score**
- Predicts aesthetic quality using CLIP + MLP
- **Higher is better**
- Trained on human aesthetic ratings
- Good for visual appeal

```bash
--metrics aesthetic
```

### 4. **PickScore** (New)
- Human preference predictor from the Pick-a-Pic dataset
- **Higher is better**
- Trained on large-scale human comparisons
- State-of-the-art preference metric

```bash
--metrics pickscore
```
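
PickScore is commonly loaded as a CLIP-style model from the Hub. The checkpoint ids below are the publicly published ones and are an assumption about what `eval.py` uses:

```python
import torch
from transformers import AutoProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(device)

@torch.no_grad()
def pickscore(pil_image, prompt):
    """Score one image against one prompt (higher is better)."""
    image_inputs = processor(images=pil_image, return_tensors="pt").to(device)
    text_inputs = processor(text=prompt, padding=True, truncation=True,
                            max_length=77, return_tensors="pt").to(device)
    img = model.get_image_features(**image_inputs)
    txt = model.get_text_features(**text_inputs)
    img = img / img.norm(dim=-1, keepdim=True)
    txt = txt / txt.norm(dim=-1, keepdim=True)
    return (model.logit_scale.exp() * (txt @ img.T)).item()
```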

### 5. **HPSv2** (New)
- Human Preference Score version 2
- **Higher is better**
- Trained on human preference annotations
- Complementary to PickScore

```bash
--metrics hpsv2
```

### 6. **ImageReward** (New)
- Reward model for RLHF (Reinforcement Learning from Human Feedback)
- **Higher is better**
- Comprehensive quality assessment
- Trained on diverse human feedback

```bash
--metrics imagereward
```

### Metric Recommendations

| Use Case | Recommended Metrics | Reason |
|----------|---------------------|--------|
| Research/Papers | `fid clip aesthetic pickscore hpsv2` | Comprehensive evaluation |
| Quick Iteration | `clip aesthetic` | Fast and reliable |
| Human Alignment | `pickscore hpsv2 imagereward` | Preference-based |
| Text Alignment | `clip imagereward` | Focus on prompt adherence |
| Visual Quality | `aesthetic pickscore` | Focus on aesthetics |

---

## Model Variants

Support for multiple SD1.5 model variants trained with different methods.

### Available Variants

#### 1. **Origin** (Default)
```bash
--model_variant origin
```
- Original Stable Diffusion v1.5 from RunwayML
- No additional training
- CFG scale: 7.5 (default)
- Good baseline

#### 2. **SPO** (Step-aware Preference Optimization)
```bash
--model_variant spo
```
- Trained with the SPO method
- Model: `SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep`
- **CFG scale: 5.0** (auto-adjusted)
- Better prompt adherence

#### 3. **Diffusion-DPO** (Direct Preference Optimization)
```bash
--model_variant diffusion_dpo
```
- Trained with DPO on human preferences
- Model: `mhdang/dpo-sd1.5-text2image-v1`
- CFG scale: 7.5
- Improved human alignment

#### 4. **LPO** (Latent Preference Optimization)
```bash
--model_variant lpo
```
- Trained with LPO (this project's main method)
- Model: `casiatao/LPO` (lpo_sd15_merge)
- **CFG scale: 5.0** (auto-adjusted)
- **Highest quality baseline**

### Comparison

| Variant | Training Method | Quality | Speed | Best For |
|---------|----------------|---------|-------|----------|
| Origin | Pre-training only | Good | Fast | Baseline |
| SPO | Step-aware preference | Better | Fast | Prompt adherence |
| Diffusion-DPO | Preference learning | Better | Fast | Human preferences |
| LPO | Latent preference | **Best** | Fast | Overall quality |

### Usage Example

```bash
# Compare all variants
for variant in origin spo diffusion_dpo lpo; do
    python eval.py \
        --model_variant $variant \
        --grad_config high_quality \
        --metrics clip aesthetic pickscore \
        --max_samples 100 \
        --output_dir results/${variant}
done
```

---

## Datasets

### 1. **COCO Validation** (Default)

```bash
--dataset_type coco
--data_dir ./data
```

**Features:**
- Standard benchmark dataset
- Reference images available (for FID)
- ~5,000 validation samples
- Diverse prompts

**Structure:**
```
data/coco/
├── caption_val.json
└── images/val/
    ├── 000000000139.jpg
    ├── 000000000285.jpg
    └── ...
```

### 2. **Pick-a-Pic Validation**

```bash
--dataset_type pickapic
```

**Features:**
- Large-scale human preference dataset
- Streaming (no download needed; see the loading sketch below)
- ~500,000 validation samples
- Real user prompts
- No reference images (FID not available)

**Advantages:**
- More diverse prompts
- Real-world use cases
- Human preference focus
- Large-scale evaluation
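
Loading sketch; the Hub id below is the public Pick-a-Pic v1 dataset and is an assumption about which snapshot the script points at:

```python
from datasets import load_dataset

# Stream the validation split without downloading the whole dataset
ds = load_dataset("yuvalkirstain/pickapic_v1", split="validation", streaming=True)

for i, sample in enumerate(ds):
    print(sample["caption"])  # the prompt text used for generation
    if i == 4:
        break
```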

### Dataset Recommendations

| Use Case | Dataset | Reason |
|----------|---------|--------|
| Academic Research | COCO | Standard benchmark, reproducible |
| FID Evaluation | COCO | Requires reference images |
| Human Preference | Pick-a-Pic | Trained on human comparisons |
| Large-scale Tests | Pick-a-Pic | 500K+ samples available |
| Quick Tests | COCO | Smaller, faster |

---

## Usage Examples

### Example 1: Quick Test
```bash
python eval.py \
    --grad_config cosine_nesterov \
    --metrics clip aesthetic \
    --max_samples 10 \
    --output_dir examples/quick_test
```

### Example 2: High-Quality Research Evaluation
```bash
python eval.py \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 200 \
    --save_images \
    --output_dir examples/research
```

### Example 3: Pick-a-Pic Benchmark
```bash
python eval.py \
    --dataset_type pickapic \
    --grad_config cosine_nesterov \
    --metrics pickscore hpsv2 imagereward \
    --max_samples 500 \
    --output_dir examples/pickapic
```

### Example 4: LPO Model Evaluation
```bash
python eval.py \
    --model_variant lpo \
    --grad_config high_quality \
    --metrics clip aesthetic pickscore \
    --max_samples 100 \
    --save_images \
    --output_dir examples/lpo_model
```

### Example 5: Baseline Only (No Gradient Ascent)
```bash
python eval.py \
    --mode baseline \
    --model_variant origin \
    --metrics clip aesthetic pickscore \
    --max_samples 50 \
    --output_dir examples/baseline_only
```

### Example 6: Manual Configuration
```bash
python eval.py \
    --grad_range_start 200 \
    --grad_range_end 800 \
    --grad_steps 15 \
    --grad_step_size 0.08 \
    --metrics clip aesthetic \
    --max_samples 50 \
    --output_dir examples/manual_config
```

### Example 7: Model Comparison
```bash
# Evaluate all model variants
for variant in origin spo diffusion_dpo lpo; do
    python eval.py \
        --model_variant $variant \
        --grad_config high_quality \
        --metrics clip aesthetic pickscore \
        --max_samples 100 \
        --save_images \
        --output_dir results/comparison/${variant}
done
```

### Example 8: Conservative Optimization
```bash
python eval.py \
    --grad_config conservative \
    --metrics clip aesthetic pickscore hpsv2 \
    --max_samples 100 \
    --save_images \
    --output_dir examples/conservative
```

---

## API Reference

### Pipeline Usage

```python
import torch

from diffusers import StableDiffusionPipeline
from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline
from models import LRMRewardModel

# Load base pipeline
base_pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16
)

# Create gradient ascent pipeline
pipeline = StableDiffusionGradientAscentPipeline(**base_pipeline.components)

# Load reward model
reward_model = LRMRewardModel(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    lrm_model_path="casiatao/LRM",
    guidance_scale=7.5,
    device="cuda"
)
pipeline.set_reward_model(reward_model)

# Enable gradient ascent with a preset
from grad_ascent_configs import get_config
config = get_config("cosine_nesterov")
pipeline.enable_gradient_ascent(**config)

# Or configure manually
pipeline.enable_gradient_ascent(
    grad_timestep_range=(200, 800),
    num_grad_steps=15,
    grad_step_size=0.1,
    lr_scheduler_type="cosine",
    lr_scheduler_kwargs={"min_lr": 0.001, "warmup_steps": 3},
    use_momentum=True,
    momentum=0.9,
    use_nesterov=True
)

# Generate with gradient ascent
output = pipeline(
    prompt="a beautiful mountain landscape at sunset",
    num_inference_steps=50,
    guidance_scale=7.5,
)

# Get gradient statistics
stats = pipeline.grad_guidance.get_statistics()
print(f"Reward improvement: {stats['avg_reward_improvement']:.4f}")
```

### Custom LR Scheduler

```python
from lr_scheduler import create_lr_scheduler

# Create a cosine scheduler with warmup
scheduler = create_lr_scheduler(
    scheduler_type="cosine",
    initial_lr=0.1,
    num_steps=20,
    min_lr=0.001,
    warmup_steps=5
)

# Use in an optimization loop
for step in range(20):
    current_lr = scheduler.get_lr()
    # ... apply gradient with current_lr ...
    scheduler.step()
```
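
`lr_scheduler.py` is the source of truth for these schedules. As an illustration of what the cosine-with-warmup configuration above computes, a stand-alone version could look like this (a sketch, not the repo's implementation):

```python
import math

class CosineWarmupLR:
    """Cosine-annealed LR with linear warmup (illustrative sketch)."""

    def __init__(self, initial_lr, num_steps, min_lr=0.0, warmup_steps=0):
        self.initial_lr = initial_lr
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self._step = 0

    def get_lr(self):
        if self._step < self.warmup_steps:
            # Linear ramp from 0 up to the initial LR
            return self.initial_lr * (self._step + 1) / self.warmup_steps
        # Cosine decay from initial_lr down to min_lr over the remaining steps
        span = max(1, self.num_steps - self.warmup_steps)
        progress = min((self._step - self.warmup_steps) / span, 1.0)
        return self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

    def step(self):
        self._step += 1
```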

### Configuration Management

```python
from grad_ascent_configs import get_config, list_configs, print_config

# List all available configs
all_configs = list_configs()
print(f"Available configs: {all_configs}")

# Get a specific config
config = get_config("high_quality")

# Print config details
print_config("cosine_nesterov")

# Create a custom config
custom_config = {
    "grad_timestep_range": (300, 700),
    "num_grad_steps": 12,
    "grad_step_size": 0.09,
    "lr_scheduler_type": "cosine",
    "lr_scheduler_kwargs": {"min_lr": 0.002, "warmup_steps": 4},
    "use_momentum": True,
    "momentum": 0.92,
    "use_nesterov": True
}
pipeline.enable_gradient_ascent(**custom_config)
```

---

## Command-Line Options

### Essential Options

```bash
--data_dir PATH          # Path to data directory (default: ./data)
--dataset_type TYPE      # Dataset: coco or pickapic (default: coco)
--model_variant VARIANT  # Model: origin, spo, diffusion_dpo, lpo (default: origin)
--max_samples N          # Max samples to evaluate (default: all)
--output_dir PATH        # Output directory (default: eval_outputs)
--save_images            # Save generated images
```

### Gradient Ascent Options

```bash
--grad_config NAME       # Use a preset config (recommended)
--grad_range_start N     # Gradient timestep start (default: 0)
--grad_range_end N       # Gradient timestep end (default: 700)
--grad_steps N           # Gradient steps per timestep (default: 5)
--grad_step_size FLOAT   # Initial learning rate (default: 0.1)
```

### Evaluation Options

```bash
--metrics METRIC [METRIC...]  # Metrics to evaluate (default: clip aesthetic)
                              # Options: fid, clip, aesthetic, pickscore, hpsv2, imagereward
--mode MODE                   # baseline, gradient_ascent, or both (default: both)
--num_steps N                 # Diffusion inference steps (default: 50)
--cfg_scale FLOAT             # CFG scale (default: 7.5, auto-adjusted for some models)
--batch_size N                # Batch size (default: 1)
--log_interval N              # Log every N batches (default: 10)
```

### Other Options

```bash
--lrm_model PATH  # LRM model path (default: casiatao/LRM)
--seed N          # Random seed (default: 42)
--cuda N          # CUDA device ID (default: 0)
```

### Complete Example

```bash
python eval.py \
    --data_dir ./data \
    --dataset_type coco \
    --model_variant lpo \
    --grad_config high_quality \
    --metrics fid clip aesthetic pickscore hpsv2 \
    --max_samples 200 \
    --num_steps 50 \
    --save_images \
    --output_dir results/comprehensive \
    --cuda 0
```

---

## Output Files

After running an evaluation, the following files are created in **auto-incremented run folders**:

```
RESULTS/SD1.5_GradAscent/
├── run_1/                  # First run
│   ├── eval.log            # Complete execution log
│   └── reward_curve.png    # Reward progression plot
├── run_2/                  # Second run
│   ├── eval.log
│   └── reward_curve.png
└── run_3/                  # Third run
    ├── eval.log
    └── reward_curve.png
```

### Auto-Incrementing Run Folders

Each execution automatically creates a new `run_<N>/` folder, preventing accidental overwrites and maintaining a complete experiment history. No manual folder management is needed.
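
One way such auto-incrementing can be implemented (a sketch, not necessarily the script's exact logic):

```python
from pathlib import Path

def next_run_dir(base="RESULTS/SD1.5_GradAscent"):
    """Create and return the first unused run_<N>/ folder under base."""
    base = Path(base)
    n = 1
    while (base / f"run_{n}").exists():
        n += 1
    run_dir = base / f"run_{n}"
    run_dir.mkdir(parents=True)
    return run_dir
```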

### eval.log Structure

The log contains detailed information for each batch:

```
======================================================================
COCO GRADIENT ASCENT EVALUATION (BATCHED)
======================================================================
Logging to: ./RESULTS/SD1.5_GradAscent/run_1/eval.log
Device: cuda:6
Batch size: 1
Metrics: fid, clip, reward, aesthetic
Gradient Ascent: Range=[0, 900], Steps=1, StepSize=0.01
======================================================================

[Batch 1/5000] Samples: 1/5000 | FID: 2.5432 | CLIP: 0.8234 | Reward (t=0): 5.2341 | Reward (Avg): 5.2341 | Aesthetic: 6.456
[Batch 161/5000] Samples: 161/5000 | FID: 2.3821 | CLIP: 0.8412 | Reward (t=0): 5.4123 | Reward (Avg): 5.3215 | Aesthetic: 6.523
...

======================================================================
FINAL RESULTS
======================================================================
FID: 2.3456
CLIP avg: 0.8378
Reward avg: 5.3421
Aesthetic: 6.489
======================================================================
```

### reward_curve.png Visualization

The reward curve plot shows two panels for the **first generated image** (a minimal plotting sketch follows the insights below):

**Left Panel: Reward vs Timestep**
- X-axis: Denoising timestep (t)
- Y-axis: Reward score
- Green shaded region: Where gradient ascent is applied
- Shows how the reward evolves as noise is removed

**Right Panel: Reward vs Denoising Step**
- X-axis: Sequential denoising step (0 to num_inference_steps)
- Y-axis: Reward score
- Same data, different perspective for easier interpretation

**Key Insights from the Plot:**
- **Upward trend**: Reward generally increases as denoising progresses
- **Sharp improvements**: Visible spikes where gradient ascent is effective
- **Final reward**: The last point corresponds to t=0 (the decoded image's reward)
- **Learning dynamics**: Shows whether optimization is working at different noise levels
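
A minimal sketch of how such a two-panel figure can be produced with matplotlib, assuming the timesteps and rewards were collected during generation:

```python
import matplotlib.pyplot as plt

def plot_reward_curve(timesteps, rewards, grad_range=(0, 700), path="reward_curve.png"):
    """Plot reward vs timestep (left) and reward vs denoising step (right)."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: reward vs timestep, with the gradient ascent range shaded
    ax1.plot(timesteps, rewards)
    ax1.axvspan(grad_range[0], grad_range[1], color="green", alpha=0.2,
                label="gradient ascent range")
    ax1.set_xlabel("timestep t")
    ax1.set_ylabel("reward")
    ax1.invert_xaxis()  # denoising runs from high t down to t=0
    ax1.legend()

    # Right panel: the same rewards over sequential denoising steps
    ax2.plot(range(len(rewards)), rewards)
    ax2.set_xlabel("denoising step")
    ax2.set_ylabel("reward")

    fig.tight_layout()
    fig.savefig(path)
```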

### Reward Tracking Details

The script explicitly tracks:

1. **Timestep-specific rewards**: Computed at every denoising step
2. **Final latent reward**: The reward for t=0 (the latent that gets decoded)
3. **Running average**: Mean reward across all processed samples
4. **Current batch reward**: Immediate feedback per batch

Example log output:
```
Reward (t=0): 5.4123  # Reward for the final decoded latent
Reward (Avg): 5.3215  # Running average across all samples
```

### evaluation_results.json Structure

(Legacy format from eval.py; test_grad_sd1.5.py uses simplified logging)

```json
{
  "mode": "both",
  "metrics": ["clip", "aesthetic", "pickscore"],
  "config": {
    "num_samples": 100,
    "num_steps": 50,
    "cfg_scale": 7.5,
    "grad_range": [0, 700],
    "grad_steps": 15,
    "grad_step_size": 0.12
  },
  "baseline": {
    "avg_reward": 0.7234,
    "clip_score": 0.8123,
    "aesthetic_score": 6.234,
    "pickscore": 21.45
  },
  "gradient_ascent": {
    "avg_reward": 0.7891,
    "clip_score": 0.8345,
    "aesthetic_score": 6.456,
    "pickscore": 22.13,
    "stats": {
      "num_applications": 45,
      "total_reward_improvement": 2.956,
      "avg_reward_improvement": 0.0657
    }
  },
  "comparison": {
    "reward_difference": 0.0657,
    "clip_difference": 0.0222,
    "aesthetic_difference": 0.222,
    "pickscore_difference": 0.68
  }
}
```

---

## Troubleshooting

### Common Issues

#### 1. Out of Memory (OOM)

**Symptoms:**
```
RuntimeError: CUDA out of memory
```

**Solutions:**
```bash
# Reduce batch size
--batch_size 1

# Reduce max samples
--max_samples 50

# Reduce gradient steps
--grad_steps 5

# Use a smaller config
--grad_config aggressive  # Only 8 steps
```

#### 2. Slow Evaluation

**Symptoms:**
- Takes too long to complete
- Hangs on metric computation

**Solutions:**
```bash
# Skip expensive metrics
--metrics clip aesthetic  # Skip FID

# Reduce samples
--max_samples 50

# Reduce diffusion steps
--num_steps 20

# Use a faster dataset
--dataset_type pickapic  # No FID computation
```

#### 3. Poor Results / No Improvement

**Symptoms:**
- Reward doesn't increase
- Quality is worse after gradient ascent

**Solutions:**
```bash
# Try better configs
--grad_config high_quality
--grad_config conservative

# Increase gradient steps
--grad_steps 20

# Adjust the timestep range (focus on the middle)
--grad_range_start 200 --grad_range_end 800

# Try a different model variant
--model_variant lpo
```

#### 4. Config Not Found

**Symptoms:**
```
ValueError: Unknown config: my_config
```

**Solutions:**
```bash
# List available configs
python -c "from grad_ascent_configs import list_configs; print(list_configs())"

# Print config details
python -c "from grad_ascent_configs import print_config; print_config('high_quality')"
```

#### 5. Metric Loading Errors

**Symptoms:**
```
Warning: Could not load PickScore scorer
```

**Solutions:**
```bash
# Install missing dependencies
pip install transformers datasets

# Check HuggingFace Hub access
huggingface-cli login

# Skip problematic metrics
--metrics clip aesthetic  # Skip pickscore if it fails
```

#### 6. Dataset Not Found

**Symptoms:**
```
FileNotFoundError: Validation JSON not found
```

**Solutions:**
```bash
# Check the data directory structure
ls data/coco/

# Use Pick-a-Pic instead (no local files needed)
--dataset_type pickapic

# Provide the correct data path
--data_dir /path/to/your/data
```

---

## Best Practices

### 1. **Start Small, Scale Up**

```bash
# First: Quick test (10 samples)
python eval.py --grad_config cosine_nesterov --metrics clip --max_samples 10

# Then: Medium test (50 samples)
python eval.py --grad_config cosine_nesterov --metrics clip aesthetic --max_samples 50

# Finally: Full evaluation (200+ samples)
python eval.py --grad_config high_quality --metrics fid clip aesthetic pickscore hpsv2 --max_samples 200
```

### 2. **Choose the Right Config for the Use Case**

| Goal | Config | Metrics |
|------|--------|---------|
| Quick experiment | `cosine_nesterov` | `clip` |
| Research paper | `high_quality` | `fid clip aesthetic pickscore hpsv2` |
| Production | `conservative` | `pickscore hpsv2` |
| Fast iteration | `aggressive` | `clip aesthetic` |

### 3. **Use Multiple Metrics**

Don't rely on a single metric. Recommended combinations:

```bash
# Text alignment + aesthetics
--metrics clip aesthetic

# Human preference focus
--metrics pickscore hpsv2 imagereward

# Comprehensive (research)
--metrics fid clip aesthetic pickscore hpsv2
```

### 4. **Save Important Runs**

```bash
# Always save images for important evaluations
--save_images --output_dir results/important_run_$(date +%Y%m%d)
```

### 5. **Monitor GPU Usage**

```bash
# In a separate terminal
watch -n 1 nvidia-smi

# Or use
gpustat -i 1
```

### 6. **Batch Evaluation**

```bash
# Create an evaluation script
cat << 'EOF' > run_evals.sh
#!/bin/bash
for config in cosine_nesterov high_quality conservative; do
    for model in origin lpo; do
        python eval.py \
            --model_variant $model \
            --grad_config $config \
            --metrics clip aesthetic pickscore \
            --max_samples 100 \
            --save_images \
            --output_dir results/${model}_${config}
    done
done
EOF

chmod +x run_evals.sh
./run_evals.sh
```

### 7. **Reproducibility**

```bash
# Always set a seed for reproducible results
--seed 42

# Document your runs
--output_dir results/experiment_name_$(date +%Y%m%d_%H%M)
```
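
Internally, a `--seed` flag usually amounts to seeding every RNG in play; a typical helper of this shape (an assumption about the script's behavior, not a quote from it):

```python
import random

import numpy as np
import torch

def set_seed(seed=42):
    """Seed Python, NumPy, and torch (CPU and all GPUs) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```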

### 8. **Performance Tips**

- Use `batch_size=1` for safety (reward model compatibility)
- Start with `--max_samples 10` for debugging
- Use `--dataset_type pickapic` for large-scale evaluation (no FID overhead)
- Skip the `fid` metric if it's not needed (expensive)
- Use `--num_steps` 20-30 for faster generation (vs. the default 50)

### 9. **Config Selection Guide**

```python
def pick_config(goal: str) -> str:
    """Map a use case to a preset name from grad_ascent_configs."""
    if goal == "just_testing":
        return "constant"
    if goal == "standard_evaluation":
        return "cosine_nesterov"  # best balance
    if goal == "need_best_quality":
        return "high_quality"     # 20 steps, Nesterov
    if goal == "need_speed":
        return "aggressive"       # 8 steps
    if goal == "need_stability":
        return "conservative"     # 25 steps, careful
    raise ValueError(f"Unknown goal: {goal}")
```

### 10. **Timestep Range Tips**

```bash
# Full range (default)
--grad_range_start 0 --grad_range_end 700

# Middle timesteps (often best)
--grad_range_start 200 --grad_range_end 800

# Early timesteps (structure)
--grad_range_start 500 --grad_range_end 1000

# Late timesteps (details)
--grad_range_start 0 --grad_range_end 400
```

---

## Performance Metrics

### Expected Results

Based on the COCO validation set (100 samples):

| Method | CLIP ↑ | Aesthetic ↑ | PickScore ↑ | Time |
|--------|--------|-------------|-------------|------|
| Baseline (Origin) | 0.812 | 6.23 | 21.4 | 5 min |
| + Constant | 0.819 | 6.28 | 21.6 | 6 min |
| + Cosine Nesterov | 0.834 | 6.45 | 22.1 | 8 min |
| + High Quality | 0.841 | 6.52 | 22.4 | 12 min |
| Baseline (LPO) | 0.856 | 6.67 | 22.8 | 5 min |
| LPO + High Quality | 0.873 | 6.89 | 23.5 | 12 min |

*Results may vary based on hardware and specific prompts.*

---

## Citation

If you use this code in your research, please cite:

```bibtex
@article{lpo2024,
  title={Latent Preference Optimization for Diffusion Models},
  author={Your Name},
  journal={arXiv preprint},
  year={2024}
}
```

---

## License

This project follows the license of the main LPO repository.

---

## Contributing

Contributions are welcome! Please:

1. Test your changes with `--max_samples 10`
2. Document new features in this README
3. Add examples to `examples.sh`
4. Follow the existing code style

---

## Support

For issues and questions:

1. Check the [Troubleshooting](#troubleshooting) section
2. Review the [Examples](#usage-examples)
3. Open an issue on GitHub

---

## Changelog

### Latest Version (January 2026)

**New Features:**
- ✨ Learning rate scheduling (constant, linear, cosine, exponential, step)
- ✨ Momentum optimization (standard and Nesterov)
- ✨ 15 configuration presets
- ✨ Additional metrics (PickScore, HPSv2, ImageReward)
- ✨ Pick-a-Pic validation dataset support
- ✨ SD1.5 model variants (Origin, SPO, DPO, LPO)
- ✨ Comprehensive evaluation framework
- ✨ **Automatic run folder creation** - each run creates `run_1/`, `run_2/`, etc.
- ✨ **Reward curve visualization** - automatic plotting of reward progression across timesteps
- ✨ **Final timestep reward tracking** - reports the reward specifically from t=0 (the decoded latent)
- ✨ **Detailed reward logging** - shows both the last-timestep reward and the running average

**Improvements:**
- 🚀 Better convergence with LR scheduling
- 🚀 Faster optimization with momentum
- 📊 More comprehensive quality assessment
- 📊 Visual feedback with reward curve plots
- 📚 Complete documentation
- 🔍 Enhanced debugging with timestep-specific reward tracking

---

**Happy Optimizing! 🚀**
Reward_sdxl_idealized/config_analysis_tuning.ipynb ADDED
@@ -0,0 +1,218 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a24d02a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from datetime import datetime\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 1: Load and Parse Results from GPU Tuning Runs\n",
    "# ============================================================================\n",
    "print(\"=\" * 80)\n",
    "print(\"LOADING TUNING RESULTS FROM GPU RUNS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "results_dir = Path(\"RESULTS_TURNING/run_2\")\n",
    "all_experiments = []\n",
    "baseline_metrics = None\n",
    "\n",
    "# Collect results from all GPU runs\n",
    "for gpu_id in range(8):\n",
    "    gpu_dir = results_dir / f\"gpu_{gpu_id}\"\n",
    "    results_file = gpu_dir / \"tuning_results.json\"\n",
    "\n",
    "    if results_file.exists():\n",
    "        with open(results_file, 'r') as f:\n",
    "            data = json.load(f)\n",
    "\n",
    "        # Extract baseline (same across all GPUs)\n",
    "        if baseline_metrics is None and \"baseline\" in data:\n",
    "            baseline_metrics = data[\"baseline\"][\"metrics\"]\n",
    "            print(f\"\\n📊 Baseline Metrics (cfg_scale=5.0):\")\n",
    "            for metric, value in baseline_metrics.items():\n",
    "                print(f\"    {metric:15s}: {value:.6f}\")\n",
    "\n",
    "        # Collect all experiments\n",
    "        if \"experiments\" in data:\n",
    "            all_experiments.extend(data[\"experiments\"])\n",
    "            print(f\"✓ GPU {gpu_id}: {len(data['experiments'])} results loaded\")\n",
    "\n",
    "print(f\"\\n✓ Total experiments loaded: {len(all_experiments)}\")\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 2: Filter Top Configs with Improvements Across All Metrics\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"FILTERING CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Define improvement metrics to track (using ImageReward instead of Reward)\n",
    "improvement_metrics = [\n",
    "    \"aesthetic_improvement\",\n",
    "    \"imagereward_improvement\",\n",
    "    \"clip_improvement\",\n",
    "    \"pickscore_improvement\",\n",
    "    \"hpsv2_improvement\"\n",
    "]\n",
    "\n",
    "# Filter experiments with improvements in ALL metrics\n",
    "top_configs = []\n",
    "\n",
    "for exp in all_experiments:\n",
    "    if \"improvements\" not in exp or \"config\" not in exp or \"metrics\" not in exp:\n",
    "        continue\n",
    "\n",
    "    improvements = exp[\"improvements\"]\n",
    "    config = exp[\"config\"]\n",
    "    metrics = exp[\"metrics\"]\n",
    "\n",
    "    # Check if ALL improvements are positive (>0)\n",
    "    all_positive = all(improvements.get(metric, -1) > 0 for metric in improvement_metrics)\n",
    "\n",
    "    if all_positive:\n",
    "        # Calculate aggregate improvement score\n",
    "        avg_improvement = np.mean([improvements.get(metric, 0) for metric in improvement_metrics])\n",
    "\n",
    "        top_configs.append({\n",
    "            \"config\": config,\n",
    "            \"metrics\": metrics,\n",
    "            \"improvements\": improvements,\n",
    "            \"avg_improvement\": avg_improvement\n",
    "        })\n",
    "\n",
    "print(f\"✓ Found {len(top_configs)} configurations with improvements in ALL metrics\")\n",
    "\n",
    "# Sort by average improvement\n",
    "top_configs.sort(key=lambda x: x[\"avg_improvement\"], reverse=True)\n",
    "\n",
    "# Get top 10\n",
    "top_10 = top_configs[:10]\n",
    "print(f\"✓ Extracted top 10 best performing configurations\")\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 3: Create Comprehensive Results Table\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"CREATING COMPREHENSIVE RESULTS TABLE\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Build detailed table data\n",
    "table_data = []\n",
    "\n",
    "for rank, result in enumerate(top_10, 1):\n",
    "    cfg = result[\"config\"]\n",
    "    metrics = result[\"metrics\"]\n",
    "    improvements = result[\"improvements\"]\n",
    "\n",
    "    row = {\n",
    "        \"Rank\": rank,\n",
    "        \"CFG Scale\": cfg.get(\"cfg_scale\", \"N/A\"),\n",
    "        \"Grad Config\": cfg.get(\"grad_config\", \"N/A\"),\n",
    "        \"Steps\": cfg.get(\"num_grad_steps\", \"N/A\"),\n",
    "        \"LR\": cfg.get(\"grad_step_size\", \"N/A\"),\n",
    "        \"Momentum\": cfg.get(\"momentum\", \"N/A\"),\n",
    "        \"ImageReward\": f\"{metrics.get('imagereward', 0):.6f}\",\n",
    "        \"ImageReward ↑\": f\"{improvements.get('imagereward_improvement', 0):+.2f}%\",\n",
    "        \"CLIP\": f\"{metrics.get('clip', 0):.4f}\",\n",
    "        \"CLIP ↑\": f\"{improvements.get('clip_improvement', 0):+.2f}%\",\n",
    "        \"Aesthetic\": f\"{metrics.get('aesthetic', 0):.4f}\",\n",
    "        \"Aesthetic ↑\": f\"{improvements.get('aesthetic_improvement', 0):+.2f}%\",\n",
    "        \"PickScore\": f\"{metrics.get('pickscore', 0):.4f}\",\n",
    "        \"PickScore ↑\": f\"{improvements.get('pickscore_improvement', 0):+.2f}%\",\n",
    "        \"HPSv2\": f\"{metrics.get('hpsv2', 0):.4f}\",\n",
    "        \"HPSv2 ↑\": f\"{improvements.get('hpsv2_improvement', 0):+.2f}%\",\n",
    "        \"Avg Improvement\": f\"{result['avg_improvement']:+.2f}%\",\n",
    "    }\n",
    "\n",
    "    table_data.append(row)\n",
    "\n",
    "df_top_10 = pd.DataFrame(table_data)\n",
    "\n",
    "print(\"\\n📋 TOP 10 CONFIGURATIONS WITH IMPROVEMENTS IN ALL METRICS:\")\n",
    "print(\"=\" * 180)\n",
    "print(df_top_10.to_string(index=False))\n",
    "print(\"=\" * 180)\n",
    "\n",
    "# ============================================================================\n",
    "# SECTION 4: Visualize and Summary Statistics\n",
    "# ============================================================================\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"SUMMARY STATISTICS\")\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# Extract numeric improvement values for analysis\n",
    "improvement_summary = []\n",
    "for result in top_10:\n",
    "    improvements = result[\"improvements\"]\n",
    "    for metric in [\"imagereward_improvement\", \"clip_improvement\", \"aesthetic_improvement\",\n",
    "                   \"pickscore_improvement\", \"hpsv2_improvement\"]:\n",
    "        metric_name = metric.replace(\"_improvement\", \"\").upper()\n",
    "        improvement_summary.append({\n",
    "            \"Metric\": metric_name,\n",
    "            \"Improvement %\": improvements.get(metric, 0)\n",
    "        })\n",
    "\n",
    "df_summary = pd.DataFrame(improvement_summary)\n",
    "\n",
    "print(\"\\n📊 Average Improvements by Metric (Top 10):\")\n",
    "metric_stats = df_summary.groupby(\"Metric\")[\"Improvement %\"].agg([\"mean\", \"std\", \"min\", \"max\"])\n",
    "print(metric_stats.round(2))\n",
    "\n",
    "print(\"\\n📈 Best Configuration Details:\")\n",
    "best = top_10[0]\n",
    "best_cfg = best[\"config\"]\n",
    "best_metrics = best[\"metrics\"]\n",
    "best_improvements = best[\"improvements\"]\n",
    "\n",
    "print(f\"\\n✓ RANK #1 - Best Performing Configuration:\")\n",
    "print(f\"  Configuration:\")\n",
    "print(f\"    • CFG Scale: {best_cfg.get('cfg_scale')}\")\n",
    "print(f\"    • Gradient Config: {best_cfg.get('grad_config')}\")\n",
    "print(f\"    • Gradient Steps: {best_cfg.get('num_grad_steps')}\")\n",
    "print(f\"    • Step Size: {best_cfg.get('grad_step_size')}\")\n",
    "print(f\"    • Momentum: {best_cfg.get('momentum')}\")\n",
    "print(f\"\\n  Metrics:\")\n",
    "for metric in [\"imagereward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\"]:\n",
    "    baseline_val = baseline_metrics.get(metric, 0)\n",
    "    current_val = best_metrics.get(metric, 0)\n",
    "    improvement = best_improvements.get(f\"{metric}_improvement\", 0)\n",
    "    print(f\"    • {metric:12s}: {current_val:8.6f} (baseline: {baseline_val:8.6f}) ↑ {improvement:+6.2f}%\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 80)\n",
    "print(\"✓ ANALYSIS COMPLETE - TOP 10 CONFIGURATIONS IDENTIFIED\")\n",
    "print(\"=\" * 80)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
Reward_sdxl_idealized/eval.py ADDED
@@ -0,0 +1,1143 @@
1
+ """
2
+ Evaluation script for comparing baseline and gradient ascent pipelines using multiple metrics.
3
+
4
+ This script evaluates both pipelines on COCO or Pick-a-Pic validation sets and computes
5
+ various preference and quality metrics.
6
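+
+ Example (an illustrative sketch; see examples.sh for a fuller invocation):
+     python eval.py --dataset_type coco --mode both --metrics clip aesthetic --max_samples 100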
+ """
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import torch
10
+ import torch.nn as nn
11
+ import json
12
+ import os
13
+ import sys
14
+ import logging
15
+ from pathlib import Path
16
+ from PIL import Image
17
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, StableDiffusionXLPipeline
18
+ from models.reward_model import LRMRewardModelXL
19
+ from pipelines.sdxl_gradient_ascent_pipeline import StableDiffusionXLGradientAscentPipeline
20
+ from torchmetrics.image.fid import FrechetInceptionDistance
21
+ from torchmetrics.multimodal import CLIPScore
22
+ from transformers import CLIPModel, CLIPProcessor
23
+ from tqdm import tqdm
24
+ import numpy as np
25
+ import argparse
26
+ from datasets import load_dataset
27
+ from grad_ascent_configs import get_config, list_configs
28
+ import matplotlib
29
+ matplotlib.use('Agg') # Select the non-interactive backend before pyplot is imported
30
+ import matplotlib.pyplot as plt
31
+
32
+ # Import evaluation metrics
33
+ sys.path.append('../evaluation')
34
+ from pick_score import PickScorer
35
+ from hpsv2_score import HPSv2Scorer
36
+ from imagereward_score import load_imagereward
37
+ from huggingface_hub import hf_hub_download
38
+
39
+ import random
40
+
41
+ def seed_everything(seed: int):
42
+ """Locks down all random number generators for absolute reproducibility."""
43
+ # 1. Python & Numpy
44
+ random.seed(seed)
45
+ np.random.seed(seed)
46
+
47
+ # 2. PyTorch Base
48
+ torch.manual_seed(seed)
49
+ if torch.cuda.is_available():
50
+ torch.cuda.manual_seed(seed)
51
+ torch.cuda.manual_seed_all(seed) # For multi-GPU
52
+
53
+ # 3. cuDNN Determinism (Crucial for consistent gradients)
54
+ torch.backends.cudnn.deterministic = True
55
+ torch.backends.cudnn.benchmark = False
56
+
57
+ # 4. Optional: Force deterministic algorithms for PyTorch 2.0+
58
+ # Uncomment if variance persists, but it may slow down generation slightly
59
+ # torch.use_deterministic_algorithms(True)
60
+
61
+
62
+ class MLP(nn.Module):
63
+ """MLP for aesthetic scoring."""
64
+ def __init__(self):
65
+ super().__init__()
66
+ self.layers = nn.Sequential(
67
+ nn.Linear(768, 1024),
68
+ nn.Dropout(0.2),
69
+ nn.Linear(1024, 128),
70
+ nn.Dropout(0.2),
71
+ nn.Linear(128, 64),
72
+ nn.Dropout(0.1),
73
+ nn.Linear(64, 16),
74
+ nn.Linear(16, 1),
75
+ )
76
+
77
+ @torch.no_grad()
78
+ def forward(self, embed):
79
+ return self.layers(embed)
80
+
81
+
82
+ class AestheticScorer(torch.nn.Module):
83
+ """Aesthetic scorer using CLIP and MLP."""
84
+ def __init__(self, dtype, device, clip_name_or_path="openai/clip-vit-large-patch14",
85
+ aesthetic_path="./sac+logos+ava1-l14-linearMSE.pth"):
86
+ super().__init__()
87
+ self.clip = CLIPModel.from_pretrained(clip_name_or_path)
88
+ self.processor = CLIPProcessor.from_pretrained(clip_name_or_path)
89
+ self.mlp = MLP()
90
+
91
+ # Load aesthetic weights
92
+ if os.path.exists(aesthetic_path):
93
+ state_dict = torch.load(aesthetic_path, map_location='cpu')
94
+ self.mlp.load_state_dict(state_dict)
95
+ else:
96
+ print(f"Warning: Aesthetic weights not found at {aesthetic_path}")
97
+
98
+ self.dtype = dtype
99
+ self.to(device)
100
+ self.eval()
101
+
102
+ @torch.no_grad()
103
+ def __call__(self, images):
104
+ device = next(self.parameters()).device
105
+ inputs = self.processor(images=images, return_tensors="pt")
106
+ inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
107
+ embed = self.clip.get_image_features(**inputs)
108
+ # normalize embedding
109
+ embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
110
+ return self.mlp(embed).squeeze(1)
111
+
112
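+ # Usage sketch (hypothetical inputs; assumes the sac+logos+ava1-l14-linearMSE.pth
+ # weights are present in the working directory):
+ #   scorer = AestheticScorer(dtype=torch.float32, device="cuda")
+ #   scores = scorer([Image.open("sample.png").convert("RGB")])
+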
+
113
+ class TeeLogger:
114
+ """Logger that writes to both console and file."""
115
+ def __init__(self, log_file):
116
+ self.terminal = sys.stdout
117
+ self.log = open(log_file, 'w')
118
+
119
+ def write(self, message):
120
+ self.terminal.write(message)
121
+ self.log.write(message)
122
+ self.log.flush()
123
+
124
+ def flush(self):
125
+ self.terminal.flush()
126
+ self.log.flush()
127
+
128
+ def close(self):
129
+ self.log.close()
130
+
131
+
132
+ def setup_logging(output_dir):
133
+ """Setup logging to both console and file."""
134
+ output_path = Path(output_dir)
135
+ output_path.mkdir(parents=True, exist_ok=True)
136
+ log_file = output_path / "log.log"
137
+
138
+ # Redirect stdout to both console and file
139
+ tee = TeeLogger(log_file)
140
+ sys.stdout = tee
141
+
142
+ return tee, log_file
143
+
144
+
145
+ def load_validation_data(data_dir, max_samples=None, dataset_type="coco"):
146
+ """Load validation prompts and image paths.
147
+
148
+ Args:
149
+ data_dir: Path to data directory
150
+ max_samples: Maximum number of samples to load
151
+ dataset_type: Type of dataset ("coco" or "pickapic")
152
+
153
+ Returns:
154
+ prompts: List of text prompts
155
+ image_paths: List of image paths (None for pickapic streaming dataset)
156
+ """
157
+ if dataset_type == "coco":
158
+ data_dir = Path(data_dir)
159
+ val_json = data_dir / "coco" / "caption_val.json"
160
+
161
+ if not val_json.exists():
162
+ raise FileNotFoundError(f"Validation JSON not found: {val_json}")
163
+
164
+ with open(val_json, 'r') as f:
165
+ data = json.load(f)
166
+
167
+ # Validate that image folder exists
168
+ val_img_dir = data_dir / "coco" / "images" / "val"
169
+ if not val_img_dir.exists():
170
+ raise FileNotFoundError(f"Validation image directory not found: {val_img_dir}")
171
+
172
+ # Parse data
173
+ prompts = []
174
+ image_paths = []
175
+ for img_path, caption in data.items():
176
+ full_path = data_dir / "coco" / img_path
177
+ if full_path.exists():
178
+ prompts.append(caption)
179
+ image_paths.append(str(full_path))
180
+ else:
181
+ print(f"Warning: Image not found: {full_path}")
182
+
183
+ if max_samples:
184
+ prompts = prompts[:max_samples]
185
+ image_paths = image_paths[:max_samples]
186
+
187
+ print(f"Loaded {len(prompts)} COCO validation samples")
188
+ return prompts, image_paths
189
+
190
+ elif dataset_type == "pickapic":
191
+ print("Loading Pick-a-Pic validation dataset (streaming)...")
192
+ val_dataset = load_dataset("pickapic-anonymous/pickapic_v1", split="validation_unique", streaming=True)
193
+
194
+ prompts = []
195
+ for i, sample in enumerate(val_dataset):
196
+ prompts.append(sample['caption'])
197
+ if max_samples and i + 1 >= max_samples:
198
+ break
199
+
200
+ print(f"Loaded {len(prompts)} Pick-a-Pic validation samples")
201
+ return prompts, None # No reference images for Pick-a-Pic
202
+
203
+ else:
204
+ raise ValueError(f"Unknown dataset type: {dataset_type}. Choose 'coco' or 'pickapic'.")
205
+
206
+
207
+ def generate_and_evaluate(
208
+ pipeline,
209
+ prompts,
210
+ image_paths,
211
+ device,
212
+ dtype,
213
+ num_inference_steps=20,
214
+ guidance_scale=7.5,
215
+ seed=42,
216
+ batch_size=1,
217
+ apply_gradient_ascent=False,
218
+ mode_name="baseline",
219
+ log_interval=10,
220
+ output_dir=None,
221
+ save_images=False,
222
+ clip_scorer=None,
223
+ aesthetic_scorer=None,
224
+ pick_scorer=None,
225
+ hpsv2_scorer=None,
226
+ hpsv21_scorer=None,
227
+ imagereward_scorer=None,
228
+ compute_fid=True
229
+ ):
230
+ """Generate images and update FID metric."""
231
+ pipeline.to(device)
232
+
233
+ print(f"\nGenerating images with {mode_name} mode...")
234
+
235
+
236
+ all_rewards = []
237
+ all_clip_scores = []
238
+ all_aesthetic_scores = []
239
+ all_pick_scores = []
240
+ all_hpsv2_scores = []
241
+ all_hpsv21_scores = []
242
+ all_imagereward_scores = []
243
+ lr_history_first_image = None # Store LR history for first image
244
+ num_batches = (len(prompts) + batch_size - 1) // batch_size
245
+
246
+ # Create output directory if saving images
247
+ if save_images and output_dir:
248
+ mode_output_dir = Path(output_dir) / mode_name
249
+ mode_output_dir.mkdir(parents=True, exist_ok=True)
250
+
251
+ # Disable internal progress bars
252
+ pipeline.set_progress_bar_config(disable=True)
253
+
+ # Create the FID metric once, before the loop, so real/generated statistics
+ # accumulate across all batches instead of being reset each batch
+ fid_metric = FrechetInceptionDistance().to(device) if (compute_fid and image_paths is not None) else None
254
+ for idx, i in enumerate(tqdm(range(0, len(prompts), batch_size), desc=f"Generating {mode_name}")):
255
+ batch_prompts = prompts[i:i+batch_size]
256
+ batch_real_paths = image_paths[i:i+batch_size] if image_paths is not None else None
257
+ batch_num = idx + 1
258
+
259
+ real_images_tensor = None
260
+
261
+ # fid_metric was created once before the loop; here we only gather this
262
+ # batch's real reference images for it
263
+ if compute_fid and fid_metric is not None and batch_real_paths is not None:
265
+
266
+ # Load and update FID with real images for this batch
267
+ real_images = []
268
+ for path in batch_real_paths:
269
+ img = Image.open(path).convert("RGB")
270
+ img = img.resize((512, 512)) # Inception v3 input size
271
+ img_array = np.array(img)
272
+ real_images.append(img_array)
273
+
274
+ # Convert to tensor [B, H, W, C] -> [B, C, H, W]
275
+ real_images_tensor = torch.from_numpy(np.stack(real_images)).permute(0, 3, 1, 2).float()
276
+ real_images_tensor = real_images_tensor.to(device)
277
+
278
+ # Generate images
279
+ generator = torch.Generator(device=device).manual_seed(seed + i)
280
+
281
+ with torch.no_grad():
282
+ result = pipeline(
283
+ prompt=batch_prompts,
284
+ num_inference_steps=num_inference_steps,
285
+ guidance_scale=guidance_scale,
286
+ generator=generator,
287
+ track_rewards=True,
288
+ print_rewards=False,
289
+ apply_gradient_ascent=apply_gradient_ascent,
290
+ verbose_grad=False,
291
+ )
292
+
293
+ # Process generated images
294
+ images = result.images
295
+
296
+ # Update FID metric if computing it
297
+ if compute_fid and fid_metric is not None:
298
+ image_tensors = []
299
+
300
+ for img in images:
301
+ img_resized = img.resize((512, 512)) # Common resolution; FID's Inception v3 backbone rescales to 299x299 internally
302
+ img_array = np.array(img_resized)
303
+ image_tensors.append(img_array)
304
+
305
+ # Convert to tensor and update FID
306
+ images_tensor = torch.from_numpy(np.stack(image_tensors)).permute(0, 3, 1, 2).float()
307
+ images_tensor = images_tensor.to(device)
308
+
309
+ # torchmetrics' FID expects uint8 images; with batch_size == 1 the tensors are
+ # also duplicated so each update carries at least two samples
+ if batch_size == 1:
310
+ real_images_tensor = torch.cat([real_images_tensor, real_images_tensor], dim=0)
+ images_tensor = torch.cat([images_tensor, images_tensor], dim=0)
311
+ real_images_tensor = real_images_tensor.to(dtype=torch.uint8)
+ images_tensor = images_tensor.to(dtype=torch.uint8)
312
+ fid_metric.update(real_images_tensor, real=True)
313
+ fid_metric.update(images_tensor, real=False)
314
+
315
+ # Track rewards - get the final timestep reward (t=0)
316
+ current_batch_final_reward = None
317
+ current_batch_final_timestep = None
318
+ if hasattr(pipeline, 'reward_history') and pipeline.reward_history:
319
+ # Use the last history entry: the reward at the final denoising step
320
+ # (t = 0, or the timestep closest to 0) for the last image in the batch
323
+ final_entry = pipeline.reward_history[-1]
324
+ current_batch_final_reward = final_entry['reward_score']
325
+ current_batch_final_timestep = final_entry['timestep']
326
+ all_rewards.append(current_batch_final_reward)
327
+
328
+ # Capture LR history from first image if gradient ascent is enabled
329
+ if apply_gradient_ascent and idx == 0 and lr_history_first_image is None:
330
+ if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:
331
+ grad_stats = pipeline.grad_guidance.get_statistics()
332
+ if grad_stats and 'detailed_stats' in grad_stats:
333
+ # Extract LR history from the gradient ascent statistics
334
+ lr_history_first_image = {
335
+ 'prompt': batch_prompts[0],
336
+ 'timesteps': [],
337
+ 'learning_rates': [], # All LR values from all gradient steps
338
+ 'rewards': []
339
+ }
340
+ for stat in grad_stats['detailed_stats']:
341
+ lr_history_first_image['timesteps'].append(stat['timestep'])
342
+ if 'lr_history' in stat:
343
+ # Extend with all LR values from this timestep's gradient steps
344
+ lr_history_first_image['learning_rates'].extend(stat['lr_history'])
345
+ # Collect all rewards from reward_history for each gradient step
346
+ if 'reward_history' in stat:
347
+ lr_history_first_image['rewards'].extend(stat['reward_history'])
348
+
349
+ # Compute CLIP score
350
+ if clip_scorer is not None:
351
+ # Convert each PIL image to a float tensor [1, C, H, W] with values in [0, 255], as torchmetrics' CLIPScore expects
352
+ for img, prompt in zip(images, batch_prompts):
353
+ img_array = np.array(img).astype(np.float32)
354
+ img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).to(device)
355
+ clip_score = clip_scorer(img_tensor, [prompt]).item()
356
+ all_clip_scores.append(clip_score)
357
+
358
+ # Compute aesthetic score
359
+ if aesthetic_scorer is not None:
360
+ aesthetic_scores = aesthetic_scorer(images)
361
+ if isinstance(aesthetic_scores, torch.Tensor):
362
+ aesthetic_scores = aesthetic_scores.cpu().numpy()
363
+ if aesthetic_scores.ndim == 0:
364
+ aesthetic_scores = [aesthetic_scores.item()]
365
+ all_aesthetic_scores.extend(aesthetic_scores.tolist() if hasattr(aesthetic_scores, 'tolist') else [aesthetic_scores])
366
+
367
+ # Compute PickScore
368
+ if pick_scorer is not None:
369
+ for img, prompt in zip(images, batch_prompts):
370
+ pick_score = pick_scorer(prompt, [img])[0]
371
+ all_pick_scores.append(pick_score)
372
+
373
+ # Compute HPSv2 score
374
+ if hpsv2_scorer is not None:
375
+ for img, prompt in zip(images, batch_prompts):
376
+ hpsv2_score = hpsv2_scorer.score(img, prompt)[0]
377
+ all_hpsv2_scores.append(hpsv2_score)
378
+
379
+ # Compute HPSv2.1 score
380
+ if hpsv21_scorer is not None:
381
+ for img, prompt in zip(images, batch_prompts):
382
+ hpsv21_score = hpsv21_scorer.score(img, prompt)[0]
383
+ all_hpsv21_scores.append(hpsv21_score)
384
+
385
+ # Compute ImageReward score
386
+ if imagereward_scorer is not None:
387
+ for img, prompt in zip(images, batch_prompts):
388
+ imagereward_score = imagereward_scorer.score(prompt, img)
389
+ all_imagereward_scores.append(imagereward_score)
390
+
391
+ # Save generated images if requested
392
+ if save_images and output_dir:
393
+ for img_idx, img in enumerate(images):
394
+ global_idx = i + img_idx
395
+ img_path = mode_output_dir / f"sample_{global_idx:05d}.png"
396
+ img.save(img_path)
397
+
398
+ # Log intermediate FID and metrics every log_interval batches
399
+ if batch_num % log_interval == 0 or batch_num == num_batches:
400
+ num_samples_processed = min(i + batch_size, len(prompts))
401
+ log_msg = f"\n[{mode_name}] Batch {batch_num}/{num_batches} | Samples: {num_samples_processed}/{len(prompts)}"
402
+
403
+ # Add FID if computing
404
+ if compute_fid and fid_metric is not None:
405
+ try:
406
+ current_fid = fid_metric.compute().item()
407
+ log_msg += f" | FID: {current_fid:.4f}"
408
+ except Exception:
409
+ # Not enough samples accumulated yet for a stable FID estimate
+ log_msg += " | FID: Computing..."
410
+
411
+ # Add reward - show both final timestep reward and average
412
+ if all_rewards:
413
+ avg_reward = np.mean(all_rewards)
414
+ if current_batch_final_reward is not None:
415
+ log_msg += f" | Reward (t={current_batch_final_timestep}): {current_batch_final_reward:.4f}"
416
+ log_msg += f" | Reward (Avg): {avg_reward:.4f}"
417
+ else:
418
+ log_msg += f" | Reward (Avg): {avg_reward:.4f}"
419
+
420
+ # Add CLIP if computing
421
+ if clip_scorer is not None and all_clip_scores:
422
+ log_msg += f" | CLIP: {np.mean(all_clip_scores):.4f}"
423
+
424
+ # Add aesthetic if computing
425
+ if aesthetic_scorer is not None and all_aesthetic_scores:
426
+ log_msg += f" | Aesthetic: {np.mean(all_aesthetic_scores):.4f}"
427
+
428
+ # Add PickScore
429
+ if pick_scorer is not None and all_pick_scores:
430
+ log_msg += f" | PickScore: {np.mean(all_pick_scores):.4f}"
431
+
432
+ # Add HPSv2
433
+ if hpsv2_scorer is not None and all_hpsv2_scores:
434
+ log_msg += f" | HPSv2: {np.mean(all_hpsv2_scores):.4f}"
435
+
436
+ # Add HPSv2.1
437
+ if hpsv21_scorer is not None and all_hpsv21_scores:
438
+ log_msg += f" | HPSv2.1: {np.mean(all_hpsv21_scores):.4f}"
439
+
440
+ # Add ImageReward
441
+ if imagereward_scorer is not None and all_imagereward_scores:
442
+ log_msg += f" | ImageReward: {np.mean(all_imagereward_scores):.4f}"
443
+
444
+ print(log_msg)
445
+
446
+ # Re-enable progress bars
447
+ pipeline.set_progress_bar_config(disable=False)
448
+
449
+ avg_reward = np.mean(all_rewards) if all_rewards else 0.0
450
+ avg_clip_score = np.mean(all_clip_scores) if all_clip_scores else 0.0
451
+ avg_aesthetic_score = np.mean(all_aesthetic_scores) if all_aesthetic_scores else 0.0
452
+ avg_pick_score = np.mean(all_pick_scores) if all_pick_scores else 0.0
453
+ avg_hpsv2_score = np.mean(all_hpsv2_scores) if all_hpsv2_scores else 0.0
454
+ avg_hpsv21_score = np.mean(all_hpsv21_scores) if all_hpsv21_scores else 0.0
455
+ avg_imagereward_score = np.mean(all_imagereward_scores) if all_imagereward_scores else 0.0
456
+
457
+ return avg_reward, fid_metric, avg_clip_score, avg_aesthetic_score, avg_pick_score, avg_hpsv2_score, avg_hpsv21_score, avg_imagereward_score, lr_history_first_image
458
+
459
+
460
+ def auto_increment_path(base_path):
461
+ """
462
+ Create an auto-incrementing run folder inside base_path.
463
+ Returns: base_path/run_1, base_path/run_2, etc.
464
+ """
465
+ base_path = Path(base_path)
466
+ base_path.mkdir(parents=True, exist_ok=True) # Ensure base directory exists
467
+
468
+ i = 1
469
+ while True:
470
+ new_path = base_path / f"run_{i}"
471
+ if not new_path.exists():
472
+ return new_path
473
+ i += 1
474
+
475
+
476
+ def main():
477
+ parser = argparse.ArgumentParser(description="Evaluate baseline and gradient ascent pipelines")
478
+ parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
479
+ parser.add_argument("--dataset_type", type=str, default="coco", choices=["coco", "pickapic"],
480
+ help="Dataset to use for evaluation: coco or pickapic (default: coco)")
481
+ parser.add_argument("--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0", help="Base model path")
482
+ parser.add_argument("--model_variant", type=str, default="origin",
483
+ choices=["spo", "lpo"],
484
+ help="SDXL model variant to use (default: origin)")
485
+ parser.add_argument("--lrm_model", type=str, default="casiatao/LRM", help="LRM model path")
486
+ parser.add_argument("--num_steps", type=int, default=50, help="Number of inference steps")
487
+ parser.add_argument("--cfg_scale", type=float, default=7.5, help="Classifier-free guidance scale")
488
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
489
+ parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate (None for all)")
490
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generation (use 1 for reward model compatibility)")
491
+ parser.add_argument("--fid_batch_size", type=int, default=32, help="Batch size for FID computation")
492
+ parser.add_argument("--log_interval", type=int, default=10, help="Log FID and metrics every N batches")
493
+ parser.add_argument("--output_dir", type=str, default="eval_outputs", help="Directory to save generated images and results")
494
+ parser.add_argument("--save_images", action="store_true", help="Save all generated images to output directory")
495
+ parser.add_argument("--mode", type=str, default="both", choices=["baseline", "gradient_ascent", "both"],
496
+ help="Which evaluation to run: baseline, gradient_ascent, or both (default: both)")
497
+
498
+ # Metrics selection
499
+ parser.add_argument("--metrics", type=str, nargs="+", default=["clip", "aesthetic"],
500
+ choices=["fid", "clip", "aesthetic", "pickscore", "hpsv2", "hpsv21", "imagereward"],
501
+ help="Which metrics to evaluate (default: clip aesthetic)")
502
+
503
+ # Gradient ascent config
504
+ parser.add_argument("--grad_config", type=str, default=None,
505
+ help=f"Gradient ascent config preset (available: {', '.join(list_configs())}). "
506
+ "If provided, overrides individual grad_* arguments.")
507
+ parser.add_argument("--grad_range_start", type=int, default=0, help="Gradient timestep range start")
508
+ parser.add_argument("--grad_range_end", type=int, default=700, help="Gradient timestep range end")
509
+ parser.add_argument("--grad_steps", type=int, default=5, help="Number of gradient steps per timestep (use 5 for better reward improvement)")
510
+ parser.add_argument("--grad_step_size", type=float, default=0.1, help="Gradient step size (initial LR)")
511
+
512
+ # Config overrides (these override values from grad_config if specified)
513
+ parser.add_argument("--override_momentum", type=float, default=None, help="Override momentum value from grad_config")
514
+ parser.add_argument("--override_num_grad_steps", type=int, default=None, help="Override num_grad_steps from grad_config")
515
+ parser.add_argument("--override_grad_step_size", type=float, default=None, help="Override grad_step_size from grad_config")
516
+
517
+ # Cuda
518
+ parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device id")
519
+
520
+ args = parser.parse_args()
521
+
522
+ seed_everything(args.seed)
523
+
524
+ # Configuration
525
+ device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
526
+ #dtype = torch.float16 #if torch.cuda.is_available() else torch.float32
527
+ dtype = torch.bfloat16
528
+
529
+ # Create auto-incremented output directory
530
+ args.output_dir = auto_increment_path(args.output_dir)
531
+
532
+ # Setup logging to file
533
+ tee_logger, log_file = setup_logging(args.output_dir)
534
+
535
+ print("="*70)
536
+ print("FID EVALUATION: BASELINE vs GRADIENT ASCENT")
537
+ print("="*70)
538
+ print(f"\nLogging to: {log_file}")
539
+ print(f"\nDevice: {device}")
540
+ print(f"Dataset: {args.dataset_type.upper()}")
541
+ print(f"Data directory: {args.data_dir}")
542
+ print(f"Base model: {args.base_model}")
543
+ print(f"Model variant: {args.model_variant}")
544
+ print(f"LRM model: {args.lrm_model}")
545
+ print(f"Inference steps: {args.num_steps}")
546
+ print(f"CFG scale: {args.cfg_scale}")
547
+ print(f"Batch size: {args.batch_size}")
548
+ print(f"Max samples: {args.max_samples or 'All'}")
549
+ print(f"Output directory: {args.output_dir}")
550
+ print(f"Save images: {args.save_images}")
551
+ print(f"Evaluation mode: {args.mode}")
552
+ print(f"Metrics to evaluate: {', '.join(args.metrics).upper()}")
553
+ if args.grad_config:
554
+ print(f"Gradient ascent config: {args.grad_config}")
555
+
556
+ # Load validation data
557
+ print("\n" + "="*70)
558
+ print("1. LOADING VALIDATION DATA")
559
+ print("="*70)
560
+ prompts, image_paths = load_validation_data(args.data_dir, args.max_samples, args.dataset_type)
561
+
562
+ # Automatically disable FID if no reference images available (e.g., Pick-a-Pic dataset)
563
+ can_compute_fid = image_paths is not None
564
+ if not can_compute_fid and "fid" in args.metrics:
565
+ print("\n⚠ Warning: FID metric requested but no reference images available. FID will be skipped.")
566
+ args.metrics = [m for m in args.metrics if m != "fid"]
567
+
568
+ # Load reward model
569
+ print("\n" + "="*70)
570
+ print("2. LOADING REWARD MODEL")
571
+ print("="*70)
572
+ reward_model = LRMRewardModelXL(
573
+ pretrained_model_name_or_path=args.base_model,
574
+ lrm_model_path=args.lrm_model,
575
+ guidance_scale=args.cfg_scale,
576
+ device=device
577
+ )
578
+ if dtype == torch.float16:
579
+ reward_model = reward_model.half()
580
+ elif dtype == torch.bfloat16:
581
+ reward_model = reward_model.to(torch.bfloat16)
582
+ reward_model.eval()
583
+ print("✓ Reward model loaded")
584
+
585
+ # Load pipeline
586
+ print("\n" + "="*70)
587
+ print("3. LOADING PIPELINE")
588
+ print("="*70)
589
+
590
+ # Load model based on variant
591
+ if args.model_variant == "spo":
592
+ base_pipeline = StableDiffusionXLPipeline.from_pretrained(
593
+ 'SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep',
594
+ torch_dtype=dtype,
595
+ safety_checker=None,
596
+ )
597
+ args.cfg_scale = 5.0 # SPO uses CFG 5.0
598
+ print(f"✓ Loaded SPO SDXL model (cfg_scale adjusted to 5.0)")
599
+ elif args.model_variant == "lpo":
600
+ unet = UNet2DConditionModel.from_pretrained(
601
+ 'casiatao/LPO',
602
+ subfolder="lpo_sdxl_merge/unet",
603
+ torch_dtype=dtype
604
+ )
605
+ base_pipeline = StableDiffusionXLPipeline.from_pretrained(
606
+ args.base_model,
607
+ torch_dtype=dtype,
608
+ variant="fp16",
609
+ unet=unet
610
+ )
611
+ args.cfg_scale = 5.0 # LPO uses CFG 5.0
612
+ print(f"✓ Loaded LPO SDXL model (cfg_scale adjusted to 5.0)")
613
+
614
+ pipeline = StableDiffusionXLGradientAscentPipeline(**base_pipeline.components)
615
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
616
+ pipeline = pipeline.to(device)
617
+ pipeline.set_reward_model(reward_model)
618
+ print("✓ Pipeline loaded")
619
+
620
+ # Load CLIP scorer
621
+ print("\n" + "="*70)
622
+ print("3.5. LOADING CLIP AND AESTHETIC SCORERS")
623
+ print("="*70)
624
+
625
+ # Only load scorers for requested metrics
626
+ clip_scorer = None
627
+ aesthetic_scorer = None
628
+ pick_scorer = None
629
+ hpsv2_scorer = None
630
+ hpsv21_scorer = None
631
+ imagereward_scorer = None
632
+
633
+ if "clip" in args.metrics:
634
+ try:
635
+ clip_scorer = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
636
+ print("✓ CLIP scorer loaded")
637
+ except Exception as e:
638
+ print(f"Warning: Could not load CLIP scorer: {e}")
639
+ clip_scorer = None
640
+ else:
641
+ print("⊘ CLIP scorer skipped (not in selected metrics)")
642
+
643
+ if "aesthetic" in args.metrics:
644
+ try:
645
+ aesthetic_scorer = AestheticScorer(dtype=dtype, device=device)
646
+ print("✓ Aesthetic scorer loaded")
647
+ except Exception as e:
648
+ print(f"Warning: Could not load Aesthetic scorer: {e}")
649
+ aesthetic_scorer = None
650
+ else:
651
+ print("⊘ Aesthetic scorer skipped (not in selected metrics)")
652
+
653
+ if "pickscore" in args.metrics:
654
+ try:
655
+ pick_scorer = PickScorer(
656
+ processor_name_or_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
657
+ model_pretrained_name_or_path="yuvalkirstain/PickScore_v1",
658
+ device=device
659
+ )
660
+ print("✓ PickScore scorer loaded")
661
+ except Exception as e:
662
+ print(f"Warning: Could not load PickScore scorer: {e}")
663
+ pick_scorer = None
664
+ else:
665
+ print("⊘ PickScore scorer skipped (not in selected metrics)")
666
+
667
+ if "hpsv2" in args.metrics:
668
+ try:
669
+ hpsv2_scorer = HPSv2Scorer(
670
+ clip_pretrained_name_or_path=hf_hub_download(
671
+ repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
672
+ filename="open_clip_pytorch_model.bin"
673
+ ),
674
+ model_pretrained_name_or_path=hf_hub_download(
675
+ repo_id="xswu/HPSv2",
676
+ filename="HPS_v2_compressed.pt"
677
+ ),
678
+ device=device
679
+ )
680
+ print("✓ HPSv2 scorer loaded")
681
+ except Exception as e:
682
+ print(f"Warning: Could not load HPSv2 scorer: {e}")
683
+ hpsv2_scorer = None
684
+ else:
685
+ print("⊘ HPSv2 scorer skipped (not in selected metrics)")
686
+
687
+ if "hpsv21" in args.metrics:
688
+ try:
689
+ hpsv21_scorer = HPSv2Scorer(
690
+ clip_pretrained_name_or_path=hf_hub_download(
691
+ repo_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
692
+ filename="open_clip_pytorch_model.bin"
693
+ ),
694
+ model_pretrained_name_or_path=hf_hub_download(
695
+ repo_id="xswu/HPSv2",
696
+ filename="HPS_v2.1_compressed.pt"
697
+ ),
698
+ device=device
699
+ )
700
+ print("✓ HPSv2.1 scorer loaded")
701
+ except Exception as e:
702
+ print(f"Warning: Could not load HPSv2.1 scorer: {e}")
703
+ hpsv21_scorer = None
704
+ else:
705
+ print("⊘ HPSv2.1 scorer skipped (not in selected metrics)")
706
+
707
+ if "imagereward" in args.metrics:
708
+ try:
709
+ imagereward_scorer = load_imagereward(
710
+ model_path=hf_hub_download(repo_id="THUDM/ImageReward", filename="ImageReward.pt"),
711
+ med_config=hf_hub_download(repo_id="THUDM/ImageReward", filename="med_config.json"),
712
+ device=device
713
+ )
714
+ print("✓ ImageReward scorer loaded")
715
+ except Exception as e:
716
+ print(f"Warning: Could not load ImageReward scorer: {e}")
717
+ imagereward_scorer = None
718
+ else:
719
+ print("⊘ ImageReward scorer skipped (not in selected metrics)")
720
+
721
+ # Configure gradient ascent
722
+ print("\n" + "="*70)
723
+ print("4. CONFIGURING GRADIENT ASCENT")
724
+ print("="*70)
725
+
726
+ # Use config preset if provided, otherwise use individual args
727
+ if args.grad_config:
728
+ print(f"Loading gradient ascent config: {args.grad_config}")
729
+ grad_config = get_config(args.grad_config)
730
+ print(f"Config loaded: {grad_config}")
731
+
732
+ # Apply overrides if specified
733
+ if args.override_momentum is not None:
734
+ grad_config['momentum'] = args.override_momentum
735
+ print(f" Overriding momentum: {args.override_momentum}")
736
+ if args.override_num_grad_steps is not None:
737
+ grad_config['num_grad_steps'] = args.override_num_grad_steps
738
+ print(f" Overriding num_grad_steps: {args.override_num_grad_steps}")
739
+ if args.override_grad_step_size is not None:
740
+ grad_config['grad_step_size'] = args.override_grad_step_size
741
+ print(f" Overriding grad_step_size: {args.override_grad_step_size}")
742
+ else:
743
+ grad_config = {
744
+ "grad_timestep_range": (args.grad_range_start, args.grad_range_end),
745
+ "num_grad_steps": args.grad_steps,
746
+ "grad_step_size": args.grad_step_size,
747
+ }
748
+ print(f"Using manual gradient ascent configuration")
749
+
750
+ print(f"Gradient timestep range: {grad_config.get('grad_timestep_range', (args.grad_range_start, args.grad_range_end))}")
751
+ print(f"Gradient steps: {grad_config.get('num_grad_steps', args.grad_steps)}")
752
+ print(f"Gradient step size (initial LR): {grad_config.get('grad_step_size', args.grad_step_size)}")
753
+ if grad_config.get('lr_scheduler_type'):
754
+ print(f"LR Scheduler: {grad_config['lr_scheduler_type']}")
755
+ if grad_config.get('use_momentum'):
756
+ print(f"Momentum: {grad_config.get('momentum', 0.9)} (Nesterov: {grad_config.get('use_nesterov', False)})")
757
+
758
+ pipeline.enable_gradient_ascent(**grad_config)
759
+
760
+ # Initialize result variables
761
+ fid_score_baseline = None
762
+ avg_reward_baseline = None
763
+ clip_score_baseline = None
764
+ aesthetic_score_baseline = None
765
+ pick_score_baseline = None
766
+ hpsv2_score_baseline = None
767
+ hpsv21_score_baseline = None
768
+ imagereward_score_baseline = None
769
+ fid_score_grad = None
770
+ avg_reward_grad = None
771
+ clip_score_grad = None
772
+ aesthetic_score_grad = None
773
+ pick_score_grad = None
774
+ hpsv2_score_grad = None
775
+ hpsv21_score_grad = None
776
+ imagereward_score_grad = None
777
+ grad_stats = None
778
+
779
+ # ========== BASELINE EVALUATION ==========
780
+ if args.mode in ["baseline", "both"]:
781
+ print("\n" + "="*70)
782
+ print("5. EVALUATING BASELINE")
783
+ print("="*70)
784
+
785
+ # Generate and evaluate baseline
786
+ avg_reward_baseline, fid_baseline, clip_score_baseline, aesthetic_score_baseline, pick_score_baseline, hpsv2_score_baseline, hpsv21_score_baseline, imagereward_score_baseline, _ = generate_and_evaluate(
787
+ pipeline=pipeline,
788
+ prompts=prompts,
789
+ image_paths=image_paths,
790
+ device=device,
791
+ dtype=dtype,
792
+ num_inference_steps=args.num_steps,
793
+ guidance_scale=args.cfg_scale,
794
+ seed=args.seed,
795
+ batch_size=args.batch_size,
796
+ apply_gradient_ascent=False,
797
+ mode_name="baseline",
798
+ log_interval=args.log_interval,
799
+ output_dir=args.output_dir,
800
+ save_images=args.save_images,
801
+ clip_scorer=clip_scorer,
802
+ aesthetic_scorer=aesthetic_scorer,
803
+ pick_scorer=pick_scorer,
804
+ hpsv2_scorer=hpsv2_scorer,
805
+ hpsv21_scorer=hpsv21_scorer,
806
+ imagereward_scorer=imagereward_scorer,
807
+ compute_fid=("fid" in args.metrics and can_compute_fid)
808
+ )
809
+
810
+ # Compute FID for baseline if requested
811
+ if "fid" in args.metrics and fid_baseline is not None:
812
+ fid_score_baseline = fid_baseline.compute().item()
813
+ print(f"\n✓ Baseline FID: {fid_score_baseline:.4f}")
814
+ print(f"✓ Baseline Avg Reward: {avg_reward_baseline:.4f}")
815
+ if "clip" in args.metrics:
816
+ print(f"✓ Baseline Avg CLIP Score: {clip_score_baseline:.4f}")
817
+ if "aesthetic" in args.metrics:
818
+ print(f"✓ Baseline Avg Aesthetic Score: {aesthetic_score_baseline:.4f}")
819
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
820
+ print(f"✓ Baseline Avg PickScore: {pick_score_baseline:.4f}")
821
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
822
+ print(f"✓ Baseline Avg HPSv2 Score: {hpsv2_score_baseline:.4f}")
823
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
824
+ print(f"✓ Baseline Avg HPSv2.1 Score: {hpsv21_score_baseline:.4f}")
825
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
826
+ print(f"✓ Baseline Avg ImageReward: {imagereward_score_baseline:.4f}")
827
+
828
+ # ========== GRADIENT ASCENT EVALUATION ==========
829
+ if args.mode in ["gradient_ascent", "both"]:
830
+ print("\n" + "="*70)
831
+ print("6. EVALUATING GRADIENT ASCENT")
832
+ print("="*70)
833
+
834
+ # Generate and evaluate with gradient ascent
835
+ avg_reward_grad, fid_grad, clip_score_grad, aesthetic_score_grad, pick_score_grad, hpsv2_score_grad, hpsv21_score_grad, imagereward_score_grad, lr_history = generate_and_evaluate(
836
+ pipeline=pipeline,
837
+ prompts=prompts,
838
+ image_paths=image_paths,
839
+ device=device,
840
+ dtype=dtype,
841
+ num_inference_steps=args.num_steps,
842
+ guidance_scale=args.cfg_scale,
843
+ seed=args.seed,
844
+ batch_size=args.batch_size,
845
+ apply_gradient_ascent=True,
846
+ mode_name="gradient_ascent",
847
+ log_interval=args.log_interval,
848
+ output_dir=args.output_dir,
849
+ save_images=args.save_images,
850
+ clip_scorer=clip_scorer,
851
+ aesthetic_scorer=aesthetic_scorer,
852
+ pick_scorer=pick_scorer,
853
+ hpsv2_scorer=hpsv2_scorer,
854
+ hpsv21_scorer=hpsv21_scorer,
855
+ imagereward_scorer=imagereward_scorer,
856
+ compute_fid=("fid" in args.metrics and can_compute_fid)
857
+ )
858
+
859
+ # Compute FID for gradient ascent if requested
860
+ if "fid" in args.metrics and fid_grad is not None:
861
+ fid_score_grad = fid_grad.compute().item()
862
+ print(f"\n✓ Gradient Ascent FID: {fid_score_grad:.4f}")
863
+ print(f"✓ Gradient Ascent Avg Reward: {avg_reward_grad:.4f}")
864
+ if "clip" in args.metrics:
865
+ print(f"✓ Gradient Ascent Avg CLIP Score: {clip_score_grad:.4f}")
866
+ if "aesthetic" in args.metrics:
867
+ print(f"✓ Gradient Ascent Avg Aesthetic Score: {aesthetic_score_grad:.4f}")
868
+ if "pickscore" in args.metrics and pick_score_grad is not None:
869
+ print(f"✓ Gradient Ascent Avg PickScore: {pick_score_grad:.4f}")
870
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
871
+ print(f"✓ Gradient Ascent Avg HPSv2 Score: {hpsv2_score_grad:.4f}")
872
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
873
+ print(f"✓ Gradient Ascent Avg HPSv2.1 Score: {hpsv21_score_grad:.4f}")
874
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
875
+ print(f"✓ Gradient Ascent Avg ImageReward: {imagereward_score_grad:.4f}")
876
+
877
+ # Get gradient stats
878
+ grad_stats = pipeline.grad_guidance.get_statistics()
879
+ if grad_stats:
880
+ print(f"\nGradient Ascent Statistics:")
881
+ print(f" Applications: {grad_stats['num_applications']}")
882
+ print(f" Total reward improvement: {grad_stats['total_reward_improvement']:+.4f}")
883
+ print(f" Avg reward improvement: {grad_stats['avg_reward_improvement']:+.4f}")
884
+
885
+ # Plot LR curve if we captured it
886
+ if lr_history is not None and lr_history['learning_rates']:
887
+ plot_path = Path(args.output_dir) / "lr_curve.png"
888
+
889
+ # LR values are now continuous across all gradient steps
890
+ lrs = lr_history['learning_rates']
891
+ steps = list(range(len(lrs))) # Step indices (0 to total_steps-1)
892
+
893
+ plt.figure(figsize=(12, 6))
894
+ plt.plot(steps, lrs, linewidth=2, color='blue', alpha=0.8)
895
+
896
+ # Mark the first step with a star
897
+ plt.plot(steps[0], lrs[0], marker='*', markersize=20, color='gold',
898
+ markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
899
+
900
+ # Mark timestep boundaries
901
+ num_timesteps = len(lr_history['timesteps'])
902
+ num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
903
+ if num_grad_steps_per_timestep > 0:
904
+ for i in range(num_timesteps + 1):
905
+ step_idx = i * num_grad_steps_per_timestep
906
+ if step_idx <= len(lrs):
907
+ plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
908
+ if i < num_timesteps:
909
+ plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
910
+ fontsize=8, color='red', alpha=0.7, ha='left')
911
+
912
+ plt.xlabel('Global Gradient Step', fontsize=12)
913
+ plt.ylabel('Learning Rate', fontsize=12)
914
+ plt.title(f'Learning Rate Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
915
+ fontsize=12, fontweight='bold')
916
+ plt.grid(True, alpha=0.3)
917
+
918
+ # Add info text
919
+ num_timesteps = len(lr_history['timesteps'])
920
+ num_grad_steps_per_timestep = len(lrs) // num_timesteps if num_timesteps > 0 else 0
921
+ plt.text(0.02, 0.98,
922
+ f'Total timesteps: {num_timesteps}\nGrad steps/timestep: {num_grad_steps_per_timestep}\nTotal grad steps: {len(lrs)}',
923
+ transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
924
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
925
+
926
+ plt.tight_layout()
927
+ plt.savefig(plot_path, dpi=150, bbox_inches='tight')
928
+ plt.close()
929
+ print(f"\n✓ Saved LR curve plot to: {plot_path}")
930
+ print(f" Total gradient steps: {len(lrs)}")
931
+ print(f" LR range: {min(lrs):.6f} → {max(lrs):.6f}")
932
+
933
+ # Plot Rewards curve if we captured it
934
+ if lr_history is not None and lr_history['rewards']:
935
+ plot_path = Path(args.output_dir) / "rewards_curve.png"
936
+
937
+ # Reward values are now continuous across all gradient steps
938
+ rewards = lr_history['rewards']
939
+ steps = list(range(len(rewards))) # Step indices (0 to total_steps-1)
940
+
941
+ plt.figure(figsize=(12, 6))
942
+ plt.plot(steps, rewards, linewidth=2, color='green', alpha=0.8)
943
+
944
+ # Mark the first step with a star
945
+ plt.plot(steps[0], rewards[0], marker='*', markersize=20, color='gold',
946
+ markeredgecolor='darkgoldenrod', markeredgewidth=2, zorder=5)
947
+
948
+ # Mark timestep boundaries
949
+ num_timesteps = len(lr_history['timesteps'])
950
+ # rewards has one extra value at the start (initial) compared to gradient steps
951
+ num_grad_steps_per_timestep = (len(rewards) - num_timesteps) // num_timesteps if num_timesteps > 0 else 0
952
+ if num_grad_steps_per_timestep > 0:
953
+ for i in range(num_timesteps + 1):
954
+ step_idx = i * (num_grad_steps_per_timestep + 1) # +1 because reward_history includes initial
955
+ if step_idx <= len(rewards):
956
+ plt.axvline(x=step_idx, color='red', linestyle='--', alpha=0.3, linewidth=1)
957
+ if i < num_timesteps:
958
+ plt.text(step_idx, plt.ylim()[1] * 0.95, f't={lr_history["timesteps"][i]}',
959
+ fontsize=8, color='red', alpha=0.7, ha='left')
960
+
961
+ plt.xlabel('Global Gradient Step', fontsize=12)
962
+ plt.ylabel('Reward Score', fontsize=12)
963
+ plt.title(f'Reward Evolution Across All Gradient Steps\nPrompt: "{lr_history["prompt"][:60]}..."',
964
+ fontsize=12, fontweight='bold')
965
+ plt.grid(True, alpha=0.3)
966
+
967
+ # Add info text
968
+ num_timesteps = len(lr_history['timesteps'])
969
+ reward_improvement = rewards[-1] - rewards[0] if len(rewards) > 1 else 0
970
+ plt.text(0.02, 0.98,
971
+ f'Total timesteps: {num_timesteps}\nTotal reward evals: {len(rewards)}\n'
972
+ f'Initial reward: {rewards[0]:.4f}\nFinal reward: {rewards[-1]:.4f}\n'
973
+ f'Improvement: {reward_improvement:+.4f}',
974
+ transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
975
+ bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))
976
+
977
+ plt.tight_layout()
978
+ plt.savefig(plot_path, dpi=150, bbox_inches='tight')
979
+ plt.close()
980
+ print(f"\n✓ Saved Rewards curve plot to: {plot_path}")
981
+ print(f" Total gradient steps: {len(rewards)}")
982
+ print(f" Reward range: {min(rewards):.4f} → {max(rewards):.4f}")
983
+ print(f" Total improvement: {reward_improvement:+.4f}")
984
+
985
+ # ========== FINAL RESULTS ==========
986
+ print("\n" + "="*70)
987
+ print("FINAL RESULTS")
988
+ print("="*70)
989
+
990
+ if avg_reward_baseline is not None:
991
+ print(f"\nBaseline:")
992
+ if fid_score_baseline is not None:
993
+ print(f" FID Score: {fid_score_baseline:.4f}")
994
+ print(f" Avg Reward: {avg_reward_baseline:.4f}")
995
+ if "clip" in args.metrics and clip_score_baseline is not None:
996
+ print(f" Avg CLIP Score: {clip_score_baseline:.4f}")
997
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
998
+ print(f" Avg Aesthetic: {aesthetic_score_baseline:.4f}")
999
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
1000
+ print(f" Avg PickScore: {pick_score_baseline:.4f}")
1001
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
1002
+ print(f" Avg HPSv2: {hpsv2_score_baseline:.4f}")
1003
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
1004
+ print(f" Avg HPSv2.1: {hpsv21_score_baseline:.4f}")
1005
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
1006
+ print(f" Avg ImageReward: {imagereward_score_baseline:.4f}")
1007
+
1008
+ if avg_reward_grad is not None:
1009
+ print(f"\nGradient Ascent:")
1010
+ if fid_score_grad is not None:
1011
+ print(f" FID Score: {fid_score_grad:.4f}")
1012
+ print(f" Avg Reward: {avg_reward_grad:.4f}")
1013
+ if "clip" in args.metrics and clip_score_grad is not None:
1014
+ print(f" Avg CLIP Score: {clip_score_grad:.4f}")
1015
+ if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
1016
+ print(f" Avg Aesthetic: {aesthetic_score_grad:.4f}")
1017
+ if "pickscore" in args.metrics and pick_score_grad is not None:
1018
+ print(f" Avg PickScore: {pick_score_grad:.4f}")
1019
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
1020
+ print(f" Avg HPSv2: {hpsv2_score_grad:.4f}")
1021
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
1022
+ print(f" Avg HPSv2.1: {hpsv21_score_grad:.4f}")
1023
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
1024
+ print(f" Avg ImageReward: {imagereward_score_grad:.4f}")
1025
+
1026
+ if avg_reward_baseline is not None and avg_reward_grad is not None:
1027
+ print(f"\nComparison:")
1028
+ if fid_score_baseline is not None and fid_score_grad is not None:
1029
+ fid_diff = fid_score_grad - fid_score_baseline
1030
+ print(f" FID Change: {fid_diff:+.4f} ({'worse' if fid_diff > 0 else 'better'}, lower is better)")
1031
+ reward_diff = avg_reward_grad - avg_reward_baseline
1032
+ print(f" Reward Change: {reward_diff:+.4f} ({'better' if reward_diff > 0 else 'worse'}, higher is better)")
1033
+ if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
1034
+ clip_diff = clip_score_grad - clip_score_baseline
1035
+ print(f" CLIP Change: {clip_diff:+.4f} ({'better' if clip_diff > 0 else 'worse'}, higher is better)")
1036
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
1037
+ aesthetic_diff = aesthetic_score_grad - aesthetic_score_baseline
1038
+ print(f" Aesthetic Change: {aesthetic_diff:+.4f} ({'better' if aesthetic_diff > 0 else 'worse'}, higher is better)")
1039
+ if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
1040
+ pick_diff = pick_score_grad - pick_score_baseline
1041
+ print(f" PickScore Change: {pick_diff:+.4f} ({'better' if pick_diff > 0 else 'worse'}, higher is better)")
1042
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
1043
+ hpsv2_diff = hpsv2_score_grad - hpsv2_score_baseline
1044
+ print(f" HPSv2 Change: {hpsv2_diff:+.4f} ({'better' if hpsv2_diff > 0 else 'worse'}, higher is better)")
1045
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
1046
+ hpsv21_diff = hpsv21_score_grad - hpsv21_score_baseline
1047
+ print(f" HPSv2.1 Change: {hpsv21_diff:+.4f} ({'better' if hpsv21_diff > 0 else 'worse'}, higher is better)")
1048
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
1049
+ imagereward_diff = imagereward_score_grad - imagereward_score_baseline
1050
+ print(f" ImageReward Chg: {imagereward_diff:+.4f} ({'better' if imagereward_diff > 0 else 'worse'}, higher is better)")
1051
+
1052
+ # Save results to file
1053
+ results = {
1054
+ "mode": args.mode,
1055
+ "metrics": args.metrics,
1056
+ "config": {
1057
+ "num_samples": len(prompts),
1058
+ "num_steps": args.num_steps,
1059
+ "cfg_scale": args.cfg_scale,
1060
+ "grad_range": [args.grad_range_start, args.grad_range_end],
1061
+ "grad_steps": args.grad_steps,
1062
+ "grad_step_size": args.grad_step_size
1063
+ }
1064
+ }
1065
+
1066
+ if avg_reward_baseline is not None:
1067
+ results["baseline"] = {"avg_reward": avg_reward_baseline}
1068
+ if fid_score_baseline is not None:
1069
+ results["baseline"]["fid"] = fid_score_baseline
1070
+ if "clip" in args.metrics and clip_score_baseline is not None:
1071
+ results["baseline"]["clip_score"] = clip_score_baseline
1072
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None:
1073
+ results["baseline"]["aesthetic_score"] = aesthetic_score_baseline
1074
+ if "pickscore" in args.metrics and pick_score_baseline is not None:
1075
+ results["baseline"]["pickscore"] = pick_score_baseline
1076
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None:
1077
+ results["baseline"]["hpsv2_score"] = hpsv2_score_baseline
1078
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None:
1079
+ results["baseline"]["hpsv21_score"] = hpsv21_score_baseline
1080
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None:
1081
+ results["baseline"]["imagereward_score"] = imagereward_score_baseline
1082
+
1083
+ if avg_reward_grad is not None:
1084
+ results["gradient_ascent"] = {"avg_reward": avg_reward_grad}
1085
+ if fid_score_grad is not None:
1086
+ results["gradient_ascent"]["fid"] = fid_score_grad
1087
+ if "clip" in args.metrics and clip_score_grad is not None:
1088
+ results["gradient_ascent"]["clip_score"] = clip_score_grad
1089
+ if "aesthetic" in args.metrics and aesthetic_score_grad is not None:
1090
+ results["gradient_ascent"]["aesthetic_score"] = aesthetic_score_grad
1091
+ if "pickscore" in args.metrics and pick_score_grad is not None:
1092
+ results["gradient_ascent"]["pickscore"] = pick_score_grad
1093
+ if "hpsv2" in args.metrics and hpsv2_score_grad is not None:
1094
+ results["gradient_ascent"]["hpsv2_score"] = hpsv2_score_grad
1095
+ if "hpsv21" in args.metrics and hpsv21_score_grad is not None:
1096
+ results["gradient_ascent"]["hpsv21_score"] = hpsv21_score_grad
1097
+ if "imagereward" in args.metrics and imagereward_score_grad is not None:
1098
+ results["gradient_ascent"]["imagereward_score"] = imagereward_score_grad
1099
+ if grad_stats:
1100
+ results["gradient_ascent"]["stats"] = grad_stats
1101
+
1102
+ if avg_reward_baseline is not None and avg_reward_grad is not None:
1103
+ results["comparison"] = {
1104
+ "reward_difference": avg_reward_grad - avg_reward_baseline
1105
+ }
1106
+ if fid_score_baseline is not None and fid_score_grad is not None:
1107
+ results["comparison"]["fid_difference"] = fid_score_grad - fid_score_baseline
1108
+ if "clip" in args.metrics and clip_score_baseline is not None and clip_score_grad is not None:
1109
+ results["comparison"]["clip_difference"] = clip_score_grad - clip_score_baseline
1110
+ if "aesthetic" in args.metrics and aesthetic_score_baseline is not None and aesthetic_score_grad is not None:
1111
+ results["comparison"]["aesthetic_difference"] = aesthetic_score_grad - aesthetic_score_baseline
1112
+ if "pickscore" in args.metrics and pick_score_baseline is not None and pick_score_grad is not None:
1113
+ results["comparison"]["pickscore_difference"] = pick_score_grad - pick_score_baseline
1114
+ if "hpsv2" in args.metrics and hpsv2_score_baseline is not None and hpsv2_score_grad is not None:
1115
+ results["comparison"]["hpsv2_difference"] = hpsv2_score_grad - hpsv2_score_baseline
1116
+ if "hpsv21" in args.metrics and hpsv21_score_baseline is not None and hpsv21_score_grad is not None:
1117
+ results["comparison"]["hpsv21_difference"] = hpsv21_score_grad - hpsv21_score_baseline
1118
+ if "imagereward" in args.metrics and imagereward_score_baseline is not None and imagereward_score_grad is not None:
1119
+ results["comparison"]["imagereward_difference"] = imagereward_score_grad - imagereward_score_baseline
1120
+
1121
+ # Save results to output directory
1122
+ output_path = Path(args.output_dir)
1123
+ output_path.mkdir(parents=True, exist_ok=True)
1124
+ results_path = output_path / "evaluation_results.txt"
1125
+
1126
+ with open(results_path, "w") as f:
1127
+ for k, v in results.items():
1128
+ f.write(f"{k}: {v}\n")
1129
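+ # (results is written as plain "key: value" text; json.dump(results, f, indent=2,
+ # default=str) would be a drop-in alternative if machine-readable output is wanted)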
+
1130
+
1131
+ print(f"\n✓ Results saved to: {results_path}")
1132
+ if args.save_images:
1133
+ print(f"✓ Generated images saved to: {output_path}/baseline/ and {output_path}/gradient_ascent/")
1134
+ print("\n" + "="*70)
1135
+
1136
+ # Close logger
1137
+ tee_logger.close()
1138
+ sys.stdout = tee_logger.terminal
1139
+
1140
+
1141
+ if __name__ == "__main__":
1142
+ main()
1143
+
Reward_sdxl_idealized/examples.sh ADDED
@@ -0,0 +1,19 @@
1
+ export HF_HOME=/efs/vkvermaa/2025/diffusion/latent_correct/hf_cache
2
+ Dataset_Name="pickapic" # "coco" or "pickapic"
3
+ Grad_Config="one_step_rectification_config"
4
+ # Other available presets: constant cosine_nesterov low_to_high_momentum high_to_low_momentum
5
+ # low_to_high_nesterov high_to_low_nesterov
6
+ Model_Variant="spo" #lpo, spo
7
+ # Pick the GPU with the least memory currently in use
+ GPU_ID=$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits | sort -k2 -n | head -n1 | cut -d',' -f1)
8
+
9
+ python eval.py \
10
+ --model_variant "$Model_Variant" \
11
+ --dataset_type "$Dataset_Name" \
12
+ --grad_config "$Grad_Config" \
13
+ --metrics fid clip aesthetic pickscore hpsv2 hpsv21 imagereward \
14
+ --max_samples 500 \
15
+ --num_steps 20 \
16
+ --cfg_scale 3 \
17
+ --output_dir "RESULTS/$Dataset_Name/${Grad_Config}_${Model_Variant}" \
18
+ --cuda $GPU_ID \
19
+ --mode "gradient_ascent" # gradient_ascent baseline # /baseline
Reward_sdxl_idealized/gradient_ascent_utils.py ADDED
@@ -0,0 +1,339 @@
1
+ """
2
+ Gradient Ascent utilities for reward-guided diffusion generation.
3
+
4
+ This module implements gradient ascent on the LRM reward score to guide
5
+ the diffusion process toward higher preference scores.
6
+ """
7
+
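+ # Minimal usage sketch (illustrative only: `scheduler`, `denoise_step`, `latents`,
+ # and `prompt` are assumed to come from the surrounding SDXL pipeline code):
+ #
+ #     guide = create_reward_guided_generator(reward_model, grad_timestep_range=(500, 700))
+ #     for t in scheduler.timesteps:
+ #         latents = denoise_step(latents, t)
+ #         if guide.should_apply_gradient(int(t)):
+ #             latents, stats = guide.apply_gradient_ascent(latents, prompt, int(t))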
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Tuple, List, Literal
11
+ from tqdm import tqdm
12
+ from lr_scheduler import create_lr_scheduler, LRScheduler
13
+
14
+
15
+ class RewardGuidedDiffusion:
16
+ """
17
+ Implements reward-guided generation using gradient ascent.
18
+
19
+ During denoising, at specified timesteps, we:
20
+ 1. Compute the reward score for current latents
21
+ 2. Calculate gradients of reward w.r.t. latents
22
+ 3. Update latents in the direction that increases reward
23
+
24
+ This guides generation toward higher preference scores.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ reward_model,
30
+ grad_scale: float = 1.0,
31
+ grad_timestep_range: Optional[Tuple[int, int]] = None,
32
+ num_grad_steps: int = 5,
33
+ grad_step_size: float = 0.1,
34
+ gradient_checkpoint: bool = False,
35
+ # LR Scheduling
36
+ lr_scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"] = "constant",
37
+ lr_scheduler_kwargs: Optional[dict] = None,
38
+ # Momentum
39
+ use_momentum: bool = False,
40
+ momentum: float = 0.9,
41
+ use_nesterov: bool = False,
42
+ use_iso_projection: bool = False
43
+ ):
44
+ """
45
+ Initialize reward-guided diffusion.
46
+
47
+ Args:
48
+ reward_model: LRM reward model for computing preference scores
49
+ grad_scale: Scale factor for gradient updates (default: 1.0)
50
+ grad_timestep_range: Tuple of (min_t, max_t) for gradient ascent.
51
+ If None, gradient ascent is disabled (should_apply_gradient returns False).
52
+ num_grad_steps: Number of gradient ascent steps per timestep
53
+ grad_step_size: Step size for each gradient update (initial LR)
54
+ gradient_checkpoint: Whether to use gradient checkpointing
55
+ lr_scheduler_type: Type of LR scheduler ("constant", "linear", "cosine", "exponential", "step")
56
+ lr_scheduler_kwargs: Additional kwargs for LR scheduler (e.g., end_lr, min_lr, warmup_steps)
57
+ use_momentum: Whether to use momentum in gradient updates
58
+ momentum: Momentum coefficient (typically 0.9)
59
+ use_nesterov: Whether to use Nesterov momentum
60
+ use_iso_projection: Whether to use Iso Projection
61
+ """
62
+ self.reward_model = reward_model
63
+ self.grad_scale = grad_scale
64
+ self.grad_timestep_range = grad_timestep_range
65
+ self.num_grad_steps = num_grad_steps
66
+ self.grad_step_size = grad_step_size
67
+ self.gradient_checkpoint = gradient_checkpoint
68
+
69
+ # LR Scheduler
70
+ self.lr_scheduler_type = lr_scheduler_type
71
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
72
+ self.lr_scheduler: Optional[LRScheduler] = None
73
+ self.global_lr_scheduler: Optional[LRScheduler] = None # Scheduler across denoising timesteps
74
+
75
+ # Momentum
76
+ self.use_momentum = use_momentum
77
+ self.momentum = momentum
78
+ self.use_nesterov = use_nesterov
79
+ self.velocity = None # Will be initialized per optimization
80
+
81
+ self.use_iso_projection = use_iso_projection
82
+
83
+ # Statistics
84
+ self.grad_stats = []
85
+ self.timestep_counter = 0 # Track which timestep we're on
86
+
87
+ def should_apply_gradient(self, timestep: int) -> bool:
88
+ """Check if gradient ascent should be applied at this timestep."""
89
+
90
+ if self.grad_timestep_range is None:
91
+ return False
92
+
93
+ min_t, max_t = self.grad_timestep_range
94
+ return min_t <= timestep <= max_t
95
+
96
+ @torch.enable_grad()
97
+ def compute_reward_gradient(
98
+ self,
99
+ latents: torch.Tensor,
100
+ prompt: str,
101
+ timestep: int,
102
+ ) -> Tuple[torch.Tensor, float]:
103
+ """
104
+ Compute gradient of reward score w.r.t. latents in FP32 to prevent underflow.
105
+ """
106
+ # 1. Cast to FP32 and ensure we are detached from previous iterations
107
+ latents_fp32 = latents.detach().to(torch.float32).clone()
108
+ latents_fp32.requires_grad_(True)
109
+
110
+ # 2. Compute reward score
111
+ # Note: Even if the model internally uses fp16/bf16, autograd will
112
+ # safely accumulate the gradient in fp32 for our leaf node.
113
+ reward_score = self.reward_model.get_reward_score(
114
+ latents_fp32,
115
+ prompt,
116
+ timestep,
117
+ enable_grad=True
118
+ )
119
+
120
+ if reward_score.numel() > 1:
121
+ reward_score = reward_score.mean()
122
+
123
+ # 3. Extract gradient
124
+ # CRITICAL: retain_graph=True prevents the graph from dying across multiple
125
+ # gradient steps if your reward model relies on cached text embeddings.
126
+ grad = torch.autograd.grad(
127
+ outputs=reward_score,
128
+ inputs=latents_fp32,
129
+ create_graph=False,
130
+ retain_graph=True, # Keeps the graph alive for the next step!
131
+ allow_unused=True,
132
+ )[0]
133
+
134
+ # 4. Handle None gradients and cast back to the pipeline's original dtype
135
+ if grad is None:
136
+ grad = torch.zeros_like(latents)
137
+ else:
138
+ grad = grad.to(latents.dtype)
139
+
140
+ return grad, reward_score.item()
141
+
142
+ def apply_gradient_ascent(
143
+ self,
144
+ latents: torch.Tensor,
145
+ prompt: str,
146
+ timestep: int,
147
+ base_noise: Optional[torch.Tensor] = None, # Required for Iso-Marginal projection
148
+ verbose: bool = True,
149
+ total_denoising_steps: Optional[int] = None,
150
+ ) -> Tuple[torch.Tensor, dict]:
151
+
152
+ # 1. UPCAST TO FP32 AND SETUP OPTIMIZER (Targeting Latents)
153
+ original_latents = latents.detach().clone().to(torch.float32)
154
+ current_latents = torch.nn.Parameter(original_latents.clone())
155
+ self.reward_model.unet.conv_in.weight.requires_grad_(True)
156
+
157
+ # Initial reward tracking
158
+ with torch.no_grad():
159
+ initial_reward = self.reward_model.get_reward_score(
160
+ latents,
161
+ prompt,
162
+ timestep
163
+ )
164
+ initial_reward_val = initial_reward.item() if initial_reward.numel() == 1 else initial_reward.mean().item()
165
+
166
+ # Initialize tracking lists
167
+ grad_norms = []
168
+ reward_history = [initial_reward_val]
169
+ lr_history = []
170
+
171
+ # 2. FORWARD PASS (downcast to FP16 just for the model forward pass)
172
+ reward = self.reward_model.get_reward_score(
173
+ current_latents.to(latents.dtype),
174
+ prompt,
175
+ timestep,
176
+ enable_grad=True
177
+ )
178
+
179
+ loss = -reward.mean()
180
+ loss.backward()
181
+
182
+ # Extract latent gradient
183
+ raw_grad = current_latents.grad
184
+ reward_history.append(reward.mean().item())
185
+
186
+ # 3. ISO-MARGINAL PROJECTION WITH ASYMMETRIC INCLUSION
187
+ if raw_grad is not None and base_noise is not None and self.use_iso_projection:
188
+ gamma = 1e-8
189
+ B = raw_grad.shape[0]
190
+
191
+ grad_flat = raw_grad.view(B, -1)
192
+ noise_flat = base_noise.view(B, -1).to(torch.float32)
193
+
194
+ # Compute projection scalar for raw_grad (which is -∇R)
195
+ dot_product = (grad_flat * noise_flat).sum(dim=1, keepdim=True)
196
+ noise_norm_sq = (noise_flat * noise_flat).sum(dim=1, keepdim=True)
197
+
198
+ proj_scalar = dot_product / (noise_norm_sq + gamma)
199
+ proj_scalar = proj_scalar.view(B, 1, 1, 1)
200
+
201
+ # 1. Decompose
202
+ grad_parallel = proj_scalar * base_noise.to(torch.float32)
203
+ grad_perp = raw_grad - grad_parallel
204
+
205
+ # 2. Asymmetric Inclusion
206
+ # proj_scalar > 0 means the applied step (+?R) points toward -epsilon (Denoising. GOOD.)
207
+ # proj_scalar < 0 means the applied step (+?R) points toward +epsilon (Noising. BAD.)
208
+ safe_proj_scalar = torch.clamp(proj_scalar, min=0.0)
209
+
210
+ beta = 1.0 # Retention factor for the safe parallel gradient
211
+ safe_grad_parallel = beta * (safe_proj_scalar * base_noise.to(torch.float32))
212
+
213
+ # 3. Recombine (currently disabled: only the perpendicular component is used)
214
+ # grad_perp = grad_perp + safe_grad_parallel
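+ # Decomposition sanity check: with g = raw_grad and eps = base_noise (flattened
+ # per sample), grad_parallel = (<g, eps> / <eps, eps>) * eps and grad_perp = g - grad_parallel,
+ # so <grad_perp, eps> == 0 up to floating-point error.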
215
+ else:
216
+ grad_perp = raw_grad
217
+ if base_noise is None:
218
+ print("?? WARNING: base_noise missing. Skipping Iso-Marginal projection.")
219
+
220
+ # 4. KINETIC RECTIFICATION (Applied to the projected latent gradient)
221
+ if grad_perp is not None:
222
+ max_grad = grad_perp.norm().item()  # NOTE: despite the name, this is the gradient's L2 norm
223
+
224
+ if max_grad > 0:
225
+ kinetic_direction = grad_perp / (max_grad + 1e-8)
226
+
227
+ # Because the direction has unit L2 norm, alpha is exactly the magnitude of the update applied.
228
+ alpha = self.grad_step_size
229
+
230
+ with torch.no_grad():
231
+ rectified_latents = original_latents - (alpha * kinetic_direction)
232
+ else:
233
+ print("?? WARNING: Gradient exists but max value is 0.0")
234
+ rectified_latents = original_latents.clone()
235
+ alpha = 0.0
236
+ else:
237
+ print("?? FATAL: PyTorch completely dropped the latent gradient!")
238
+ rectified_latents = original_latents.clone()
239
+ max_grad = 0.0
240
+ alpha = 0.0
241
+
242
+ if verbose:
243
+ print(f" Grad step | LR: {alpha:.6f} | Reward: {reward.mean().item():.4f} | Max Grad: {max_grad:.4f}")
244
+
245
+ # 5. DOWNCAST AND RETURN
246
+ final_latents = rectified_latents.detach().to(latents.dtype)
247
+
248
+ with torch.no_grad():
249
+ final_reward = self.reward_model.get_reward_score(
250
+ final_latents, prompt, timestep
251
+ )
252
+ final_reward_val = final_reward.item() if final_reward.numel() == 1 else final_reward.mean().item()
253
+
254
+ stats = {
255
+ 'timestep': timestep,
256
+ 'initial_reward': initial_reward_val,
257
+ 'final_reward': final_reward_val,
258
+ 'reward_improvement': final_reward_val - initial_reward_val,
259
+ 'grad_norms': [max_grad],
260
+ 'reward_history': reward_history,
261
+ 'lr_history': [alpha], # Kept for plotting logic
262
+ 'latent_change': (final_latents - original_latents.to(latents.dtype)).norm().item(),
263
+ }
264
+
265
+ self.grad_stats.append(stats)
266
+
267
+ return final_latents, stats
268
+
269
+ def get_statistics(self) -> dict:
270
+ """Get aggregated statistics across all gradient ascent applications."""
271
+ if not self.grad_stats:
272
+ return {}
273
+
274
+ total_improvement = sum(s['reward_improvement'] for s in self.grad_stats)
275
+ avg_improvement = total_improvement / len(self.grad_stats)
276
+
277
+ all_grad_norms = [n for s in self.grad_stats for n in s['grad_norms']]
278
+
279
+ return {
280
+ 'num_applications': len(self.grad_stats),
281
+ 'total_reward_improvement': total_improvement,
282
+ 'avg_reward_improvement': avg_improvement,
283
+ 'avg_grad_norm': sum(all_grad_norms) / len(all_grad_norms) if all_grad_norms else 0,
284
+ 'max_grad_norm': max(all_grad_norms) if all_grad_norms else 0,
285
+ 'detailed_stats': self.grad_stats,
286
+ }
287
+
288
+ def reset_statistics(self):
289
+ """Reset statistics and global scheduler."""
290
+ self.grad_stats = []
291
+ self.global_lr_scheduler = None
292
+ self.timestep_counter = 0
293
+
294
+
295
+ def create_reward_guided_generator(
296
+ reward_model,
297
+ grad_timestep_range: Tuple[int, int] = (500, 700),
298
+ grad_scale: float = 1.0,
299
+ num_grad_steps: int = 5,
300
+ grad_step_size: float = 0.1,
301
+ lr_scheduler_type: str = "constant",
302
+ lr_scheduler_kwargs: Optional[dict] = None,
303
+ use_momentum: bool = False,
304
+ momentum: float = 0.9,
305
+ use_nesterov: bool = False,
306
+ use_iso_projection: bool = False
307
+ ) -> RewardGuidedDiffusion:
308
+ """
309
+ Convenience function to create a reward-guided diffusion generator.
310
+
311
+ Args:
312
+ reward_model: LRM reward model
313
+ grad_timestep_range: Tuple of (min_t, max_t) for applying gradients
314
+ grad_scale: Scale factor for gradient magnitude
315
+ num_grad_steps: Number of gradient ascent iterations per timestep
316
+ grad_step_size: Step size for each gradient update (initial LR)
317
+ lr_scheduler_type: Type of LR scheduler
318
+ lr_scheduler_kwargs: Additional kwargs for LR scheduler
319
+ use_momentum: Whether to use momentum
320
+ momentum: Momentum coefficient
321
+ use_nesterov: Whether to use Nesterov momentum
322
+ use_iso_projection: Whether to use Iso Projection
323
+
324
+ Returns:
325
+ RewardGuidedDiffusion instance
326
+ """
327
+ return RewardGuidedDiffusion(
328
+ reward_model=reward_model,
329
+ grad_scale=grad_scale,
330
+ grad_timestep_range=grad_timestep_range,
331
+ num_grad_steps=num_grad_steps,
332
+ grad_step_size=grad_step_size,
333
+ lr_scheduler_type=lr_scheduler_type,
334
+ lr_scheduler_kwargs=lr_scheduler_kwargs,
335
+ use_momentum=use_momentum,
336
+ momentum=momentum,
337
+ use_nesterov=use_nesterov,
338
+ use_iso_projection=use_iso_projection
339
+ )
Reward_sdxl_idealized/lr_scheduler.py ADDED
@@ -0,0 +1,233 @@
1
+ """
2
+ Learning rate schedulers for gradient ascent optimization.
3
+
4
+ Provides various LR scheduling strategies for reward-guided gradient ascent,
5
+ including cosine annealing, linear decay, and custom schedules.
6
+ """
7
+
8
+ import math
9
+ from typing import Optional, Literal
10
+
11
+
12
+ class LRScheduler:
13
+ """Base class for learning rate schedulers."""
14
+
15
+ def __init__(self, initial_lr: float, num_steps: int):
16
+ """
17
+ Initialize LR scheduler.
18
+
19
+ Args:
20
+ initial_lr: Initial learning rate
21
+ num_steps: Total number of optimization steps
22
+ """
23
+ self.initial_lr = initial_lr
24
+ self.num_steps = num_steps
25
+ self.current_step = 0
26
+
27
+ def get_lr(self) -> float:
28
+ """Get current learning rate."""
29
+ raise NotImplementedError
30
+
31
+ def step(self):
32
+ """Update scheduler state after a step."""
33
+ self.current_step += 1
34
+
35
+ def reset(self):
36
+ """Reset scheduler state."""
37
+ self.current_step = 0
38
+
39
+
40
+ class ConstantLR(LRScheduler):
41
+ """Constant learning rate (no scheduling)."""
42
+
43
+ def get_lr(self) -> float:
44
+ return self.initial_lr
45
+
46
+
47
+ class LinearLR(LRScheduler):
48
+ """Linear learning rate decay."""
49
+
50
+ def __init__(
51
+ self,
52
+ initial_lr: float,
53
+ num_steps: int,
54
+ end_lr: float = 0.0,
55
+ start_step: int = 0,
56
+ ):
57
+ """
58
+ Initialize linear LR scheduler.
59
+
60
+ Args:
61
+ initial_lr: Starting learning rate
62
+ num_steps: Total number of steps
63
+ end_lr: Ending learning rate (default: 0.0)
64
+ start_step: Step to begin decay (default: 0)
65
+ """
66
+ super().__init__(initial_lr, num_steps)
67
+ self.end_lr = end_lr
68
+ self.start_step = start_step
69
+
70
+ def get_lr(self) -> float:
71
+ if self.current_step < self.start_step:
72
+ return self.initial_lr
73
+
74
+ progress = (self.current_step - self.start_step) / (self.num_steps - self.start_step)
75
+ progress = min(1.0, progress)
76
+
77
+ return self.initial_lr + (self.end_lr - self.initial_lr) * progress
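+ # Worked example: initial_lr=0.1, end_lr=0.0, num_steps=100, start_step=0 gives,
+ # at step 25, progress = 0.25 and lr = 0.1 + (0.0 - 0.1) * 0.25 = 0.075.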
78
+
79
+
80
+ class CosineLR(LRScheduler):
81
+ """Cosine annealing learning rate schedule."""
82
+
83
+ def __init__(
84
+ self,
85
+ initial_lr: float,
86
+ num_steps: int,
87
+ min_lr: float = 0.0,
88
+ warmup_steps: int = 0,
89
+ ):
90
+ """
91
+ Initialize cosine LR scheduler.
92
+
93
+ Args:
94
+ initial_lr: Maximum learning rate
95
+ num_steps: Total number of steps
96
+ min_lr: Minimum learning rate (default: 0.0)
97
+ warmup_steps: Number of linear warmup steps (default: 0)
98
+ """
99
+ super().__init__(initial_lr, num_steps)
100
+ self.min_lr = min_lr
101
+ self.warmup_steps = warmup_steps
102
+
103
+ def get_lr(self) -> float:
104
+ if self.current_step < self.warmup_steps:
105
+ # Linear warmup
106
+ return self.initial_lr * (self.current_step / self.warmup_steps)
107
+
108
+ # Cosine annealing
109
+ progress = (self.current_step - self.warmup_steps) / (self.num_steps - self.warmup_steps)
110
+ progress = min(1.0, progress)
111
+
112
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
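+ # Worked example: initial_lr=0.1, min_lr=0.0, warmup_steps=0, num_steps=100 gives,
+ # at step 50, progress = 0.5, cosine_decay = 0.5 * (1 + cos(pi * 0.5)) = 0.5, so lr = 0.05.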
113
+ return self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
114
+
115
+
116
+ class ExponentialLR(LRScheduler):
117
+ """Exponential learning rate decay."""
118
+
119
+ def __init__(
120
+ self,
121
+ initial_lr: float,
122
+ num_steps: int,
123
+ gamma: float = 0.95,
124
+ ):
125
+ """
126
+ Initialize exponential LR scheduler.
127
+
128
+ Args:
129
+ initial_lr: Starting learning rate
130
+ num_steps: Total number of steps
131
+ gamma: Multiplicative decay factor per step
132
+ """
133
+ super().__init__(initial_lr, num_steps)
134
+ self.gamma = gamma
135
+
136
+ def get_lr(self) -> float:
137
+ return self.initial_lr * (self.gamma ** self.current_step)
138
+
139
+
140
+ class StepLR(LRScheduler):
141
+ """Step-wise learning rate decay."""
142
+
143
+ def __init__(
144
+ self,
145
+ initial_lr: float,
146
+ num_steps: int,
147
+ step_size: int,
148
+ gamma: float = 0.1,
149
+ ):
150
+ """
151
+ Initialize step LR scheduler.
152
+
153
+ Args:
154
+ initial_lr: Starting learning rate
155
+ num_steps: Total number of steps
156
+ step_size: Number of steps between each decay
157
+ gamma: Multiplicative decay factor
158
+ """
159
+ super().__init__(initial_lr, num_steps)
160
+ self.step_size = step_size
161
+ self.gamma = gamma
162
+
163
+ def get_lr(self) -> float:
164
+ num_decays = self.current_step // self.step_size
165
+ return self.initial_lr * (self.gamma ** num_decays)
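+ # Worked example: initial_lr=0.1, step_size=10, gamma=0.1 gives lr=0.1 for steps 0-9,
+ # lr=0.01 for steps 10-19, lr=0.001 for steps 20-29, and so on.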
166
+
167
+
168
+ def create_lr_scheduler(
169
+ scheduler_type: Literal["constant", "linear", "cosine", "exponential", "step"],
170
+ initial_lr: float,
171
+ num_steps: int,
172
+ **kwargs
173
+ ) -> LRScheduler:
174
+ """
175
+ Factory function to create learning rate schedulers.
176
+
177
+ Args:
178
+ scheduler_type: Type of scheduler ("constant", "linear", "cosine", "exponential", "step")
179
+ initial_lr: Initial learning rate
180
+ num_steps: Total number of optimization steps
181
+ **kwargs: Additional scheduler-specific arguments
182
+ For linear: end_lr, start_step
183
+ For cosine: min_lr, warmup_steps
184
+ For exponential: gamma
185
+ For step: step_size, gamma
186
+
187
+ Returns:
188
+ LRScheduler instance
189
+
190
+ Examples:
191
+ # Constant LR
192
+ scheduler = create_lr_scheduler("constant", initial_lr=0.1, num_steps=100)
193
+
194
+ # Linear decay
195
+ scheduler = create_lr_scheduler("linear", initial_lr=0.1, num_steps=100, end_lr=0.01)
196
+
197
+ # Cosine annealing with warmup
198
+ scheduler = create_lr_scheduler("cosine", initial_lr=0.1, num_steps=100,
199
+ min_lr=0.001, warmup_steps=10)
200
+ """
201
+ if scheduler_type == "constant":
202
+ return ConstantLR(initial_lr, num_steps)
203
+
204
+ elif scheduler_type == "linear":
205
+ return LinearLR(
206
+ initial_lr, num_steps,
207
+ end_lr=kwargs.get("end_lr", 0.0),
208
+ start_step=kwargs.get("start_step", 0),
209
+ )
210
+
211
+ elif scheduler_type == "cosine":
212
+ return CosineLR(
213
+ initial_lr, num_steps,
214
+ min_lr=kwargs.get("min_lr", 0.0),
215
+ warmup_steps=kwargs.get("warmup_steps", 0),
216
+ )
217
+
218
+ elif scheduler_type == "exponential":
219
+ return ExponentialLR(
220
+ initial_lr, num_steps,
221
+ gamma=kwargs.get("gamma", 0.95),
222
+ )
223
+
224
+ elif scheduler_type == "step":
225
+ return StepLR(
226
+ initial_lr, num_steps,
227
+ step_size=kwargs.get("step_size", 10),
228
+ gamma=kwargs.get("gamma", 0.1),
229
+ )
230
+
231
+ else:
232
+ raise ValueError(f"Unknown scheduler type: {scheduler_type}. "
233
+ f"Choose from: constant, linear, cosine, exponential, step")
Reward_sdxl_idealized/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .reward_model import LRMRewardModelXL
2
+
3
+ __all__ = ['LRMRewardModelXL']
Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-310.pyc ADDED
Binary file (6.25 kB).
 
Reward_sdxl_idealized/models/__pycache__/reward_model.cpython-313.pyc ADDED
Binary file (16.4 kB).
 
Reward_sdxl_idealized/models/__pycache__/unet_2d_condition_reward.cpython-313.pyc ADDED
Binary file (57.4 kB).
 
Reward_sdxl_idealized/models/unet_2d_condition_reward.py ADDED
@@ -0,0 +1,1334 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnAddedKVProcessor,
32
+ AttnProcessor,
33
+ FusedAttnProcessor2_0,
34
+ )
35
+ from diffusers.models.embeddings import (
36
+ GaussianFourierProjection,
37
+ GLIGENTextBoundingboxProjection,
38
+ ImageHintTimeEmbedding,
39
+ ImageProjection,
40
+ ImageTimeEmbedding,
41
+ TextImageProjection,
42
+ TextImageTimeEmbedding,
43
+ TextTimeEmbedding,
44
+ TimestepEmbedding,
45
+ Timesteps,
46
+ )
47
+ from diffusers.models.modeling_utils import ModelMixin
48
+ from diffusers.models.unets.unet_2d_blocks import (
49
+ get_down_block,
50
+ get_mid_block,
51
+ get_up_block,
52
+ )
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ @dataclass
59
+ class UNet2DConditionOutput(BaseOutput):
60
+ """
61
+ The output of [`UNet2DConditionModel`].
62
+
63
+ Args:
64
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
65
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
66
+ """
67
+
68
+ sample: torch.Tensor = None
69
+
70
+
71
+ class UNet2DConditionModel(
72
+ ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
73
+ ):
74
+ r"""
75
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
76
+ shaped output.
77
+
78
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
79
+ for all models (such as downloading or saving).
80
+
81
+ Parameters:
82
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
83
+ Height and width of input/output sample.
84
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
85
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
86
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
87
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
88
+ Whether to flip the sin to cos in the time embedding.
89
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
90
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
91
+ The tuple of downsample blocks to use.
92
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
93
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
94
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
95
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
96
+ The tuple of upsample blocks to use.
97
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
98
+ Whether to include self-attention in the basic transformer blocks, see
99
+ [`~models.attention.BasicTransformerBlock`].
100
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
101
+ The tuple of output channels for each block.
102
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
103
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
104
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
105
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
106
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
107
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
108
+ If `None`, normalization and activation layers are skipped in post-processing.
109
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
110
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
111
+ The dimension of the cross attention features.
112
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
113
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
114
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
115
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
116
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
117
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
118
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
119
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
120
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
121
+ encoder_hid_dim (`int`, *optional*, defaults to None):
122
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
123
+ dimension to `cross_attention_dim`.
124
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
125
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
126
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
127
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
128
+ num_attention_heads (`int`, *optional*):
129
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
130
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
131
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
132
+ class_embed_type (`str`, *optional*, defaults to `None`):
133
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
134
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
135
+ addition_embed_type (`str`, *optional*, defaults to `None`):
136
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
137
+ "text". "text" will use the `TextTimeEmbedding` layer.
138
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
139
+ Dimension for the timestep embeddings.
140
+ num_class_embeds (`int`, *optional*, defaults to `None`):
141
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
142
+ class conditioning with `class_embed_type` equal to `None`.
143
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
144
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
145
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
146
+ An optional override for the dimension of the projected time embedding.
147
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
148
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
149
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
150
+ timestep_post_act (`str`, *optional*, defaults to `None`):
151
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
152
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
153
+ The dimension of `cond_proj` layer in the timestep embedding.
154
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
155
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
156
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
157
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
158
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
159
+ embeddings with the class embeddings.
160
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
161
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
162
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
163
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
164
+ otherwise.
165
+ """
166
+
167
+ _supports_gradient_checkpointing = True
168
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
169
+
170
+ @register_to_config
171
+ def __init__(
172
+ self,
173
+ sample_size: Optional[int] = None,
174
+ in_channels: int = 4,
175
+ out_channels: int = 4,
176
+ center_input_sample: bool = False,
177
+ flip_sin_to_cos: bool = True,
178
+ freq_shift: int = 0,
179
+ down_block_types: Tuple[str] = (
180
+ "CrossAttnDownBlock2D",
181
+ "CrossAttnDownBlock2D",
182
+ "CrossAttnDownBlock2D",
183
+ "DownBlock2D",
184
+ ),
185
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
186
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
187
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
188
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
189
+ layers_per_block: Union[int, Tuple[int]] = 2,
190
+ downsample_padding: int = 1,
191
+ mid_block_scale_factor: float = 1,
192
+ dropout: float = 0.0,
193
+ act_fn: str = "silu",
194
+ norm_num_groups: Optional[int] = 32,
195
+ norm_eps: float = 1e-5,
196
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
197
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
198
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
199
+ encoder_hid_dim: Optional[int] = None,
200
+ encoder_hid_dim_type: Optional[str] = None,
201
+ attention_head_dim: Union[int, Tuple[int]] = 8,
202
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
203
+ dual_cross_attention: bool = False,
204
+ use_linear_projection: bool = False,
205
+ class_embed_type: Optional[str] = None,
206
+ addition_embed_type: Optional[str] = None,
207
+ addition_time_embed_dim: Optional[int] = None,
208
+ num_class_embeds: Optional[int] = None,
209
+ upcast_attention: bool = False,
210
+ resnet_time_scale_shift: str = "default",
211
+ resnet_skip_time_act: bool = False,
212
+ resnet_out_scale_factor: float = 1.0,
213
+ time_embedding_type: str = "positional",
214
+ time_embedding_dim: Optional[int] = None,
215
+ time_embedding_act_fn: Optional[str] = None,
216
+ timestep_post_act: Optional[str] = None,
217
+ time_cond_proj_dim: Optional[int] = None,
218
+ conv_in_kernel: int = 3,
219
+ conv_out_kernel: int = 3,
220
+ projection_class_embeddings_input_dim: Optional[int] = None,
221
+ attention_type: str = "default",
222
+ class_embeddings_concat: bool = False,
223
+ mid_block_only_cross_attention: Optional[bool] = None,
224
+ cross_attention_norm: Optional[str] = None,
225
+ addition_embed_type_num_heads: int = 64,
226
+ ):
227
+ super().__init__()
228
+
229
+ self.sample_size = sample_size
230
+
231
+ if num_attention_heads is not None:
232
+ raise ValueError(
233
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
234
+ )
235
+
236
+ # If `num_attention_heads` is not defined (which is the case for most models)
237
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
238
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
239
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
240
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
241
+ # which is why we correct for the naming here.
242
+ num_attention_heads = num_attention_heads or attention_head_dim
243
+
244
+ # Check inputs
245
+ self._check_config(
246
+ down_block_types=down_block_types,
247
+ up_block_types=up_block_types,
248
+ only_cross_attention=only_cross_attention,
249
+ block_out_channels=block_out_channels,
250
+ layers_per_block=layers_per_block,
251
+ cross_attention_dim=cross_attention_dim,
252
+ transformer_layers_per_block=transformer_layers_per_block,
253
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
254
+ attention_head_dim=attention_head_dim,
255
+ num_attention_heads=num_attention_heads,
256
+ )
257
+
258
+ # input
259
+ conv_in_padding = (conv_in_kernel - 1) // 2
260
+ self.conv_in = nn.Conv2d(
261
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
262
+ )
263
+
264
+ # time
265
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
266
+ time_embedding_type,
267
+ block_out_channels=block_out_channels,
268
+ flip_sin_to_cos=flip_sin_to_cos,
269
+ freq_shift=freq_shift,
270
+ time_embedding_dim=time_embedding_dim,
271
+ )
272
+
273
+ self.time_embedding = TimestepEmbedding(
274
+ timestep_input_dim,
275
+ time_embed_dim,
276
+ act_fn=act_fn,
277
+ post_act_fn=timestep_post_act,
278
+ cond_proj_dim=time_cond_proj_dim,
279
+ )
280
+
281
+ self._set_encoder_hid_proj(
282
+ encoder_hid_dim_type,
283
+ cross_attention_dim=cross_attention_dim,
284
+ encoder_hid_dim=encoder_hid_dim,
285
+ )
286
+
287
+ # class embedding
288
+ self._set_class_embedding(
289
+ class_embed_type,
290
+ act_fn=act_fn,
291
+ num_class_embeds=num_class_embeds,
292
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
293
+ time_embed_dim=time_embed_dim,
294
+ timestep_input_dim=timestep_input_dim,
295
+ )
296
+
297
+ self._set_add_embedding(
298
+ addition_embed_type,
299
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
300
+ addition_time_embed_dim=addition_time_embed_dim,
301
+ cross_attention_dim=cross_attention_dim,
302
+ encoder_hid_dim=encoder_hid_dim,
303
+ flip_sin_to_cos=flip_sin_to_cos,
304
+ freq_shift=freq_shift,
305
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
306
+ time_embed_dim=time_embed_dim,
307
+ )
308
+
309
+ if time_embedding_act_fn is None:
310
+ self.time_embed_act = None
311
+ else:
312
+ self.time_embed_act = get_activation(time_embedding_act_fn)
313
+
314
+ self.down_blocks = nn.ModuleList([])
315
+ self.up_blocks = nn.ModuleList([])
316
+
317
+ if isinstance(only_cross_attention, bool):
318
+ if mid_block_only_cross_attention is None:
319
+ mid_block_only_cross_attention = only_cross_attention
320
+
321
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
322
+
323
+ if mid_block_only_cross_attention is None:
324
+ mid_block_only_cross_attention = False
325
+
326
+ if isinstance(num_attention_heads, int):
327
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
328
+
329
+ if isinstance(attention_head_dim, int):
330
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
331
+
332
+ if isinstance(cross_attention_dim, int):
333
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
334
+
335
+ if isinstance(layers_per_block, int):
336
+ layers_per_block = [layers_per_block] * len(down_block_types)
337
+
338
+ if isinstance(transformer_layers_per_block, int):
339
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
340
+
341
+ if class_embeddings_concat:
342
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
343
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
344
+ # regular time embeddings
345
+ blocks_time_embed_dim = time_embed_dim * 2
346
+ else:
347
+ blocks_time_embed_dim = time_embed_dim
348
+
349
+ # down
350
+ output_channel = block_out_channels[0]
351
+ for i, down_block_type in enumerate(down_block_types):
352
+ input_channel = output_channel
353
+ output_channel = block_out_channels[i]
354
+ is_final_block = i == len(block_out_channels) - 1
355
+
356
+ down_block = get_down_block(
357
+ down_block_type,
358
+ num_layers=layers_per_block[i],
359
+ transformer_layers_per_block=transformer_layers_per_block[i],
360
+ in_channels=input_channel,
361
+ out_channels=output_channel,
362
+ temb_channels=blocks_time_embed_dim,
363
+ add_downsample=not is_final_block,
364
+ resnet_eps=norm_eps,
365
+ resnet_act_fn=act_fn,
366
+ resnet_groups=norm_num_groups,
367
+ cross_attention_dim=cross_attention_dim[i],
368
+ num_attention_heads=num_attention_heads[i],
369
+ downsample_padding=downsample_padding,
370
+ dual_cross_attention=dual_cross_attention,
371
+ use_linear_projection=use_linear_projection,
372
+ only_cross_attention=only_cross_attention[i],
373
+ upcast_attention=upcast_attention,
374
+ resnet_time_scale_shift=resnet_time_scale_shift,
375
+ attention_type=attention_type,
376
+ resnet_skip_time_act=resnet_skip_time_act,
377
+ resnet_out_scale_factor=resnet_out_scale_factor,
378
+ cross_attention_norm=cross_attention_norm,
379
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
380
+ dropout=dropout,
381
+ )
382
+ self.down_blocks.append(down_block)
383
+
384
+ # mid
385
+ self.mid_block = get_mid_block(
386
+ mid_block_type,
387
+ temb_channels=blocks_time_embed_dim,
388
+ in_channels=block_out_channels[-1],
389
+ resnet_eps=norm_eps,
390
+ resnet_act_fn=act_fn,
391
+ resnet_groups=norm_num_groups,
392
+ output_scale_factor=mid_block_scale_factor,
393
+ transformer_layers_per_block=transformer_layers_per_block[-1],
394
+ num_attention_heads=num_attention_heads[-1],
395
+ cross_attention_dim=cross_attention_dim[-1],
396
+ dual_cross_attention=dual_cross_attention,
397
+ use_linear_projection=use_linear_projection,
398
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
399
+ upcast_attention=upcast_attention,
400
+ resnet_time_scale_shift=resnet_time_scale_shift,
401
+ attention_type=attention_type,
402
+ resnet_skip_time_act=resnet_skip_time_act,
403
+ cross_attention_norm=cross_attention_norm,
404
+ attention_head_dim=attention_head_dim[-1],
405
+ dropout=dropout,
406
+ )
407
+
408
+ # count how many layers upsample the images
409
+ self.num_upsamplers = 0
410
+
411
+ # up
412
+ reversed_block_out_channels = list(reversed(block_out_channels))
413
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
414
+ reversed_layers_per_block = list(reversed(layers_per_block))
415
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
416
+ reversed_transformer_layers_per_block = (
417
+ list(reversed(transformer_layers_per_block))
418
+ if reverse_transformer_layers_per_block is None
419
+ else reverse_transformer_layers_per_block
420
+ )
421
+ only_cross_attention = list(reversed(only_cross_attention))
422
+
423
+ output_channel = reversed_block_out_channels[0]
424
+ for i, up_block_type in enumerate(up_block_types):
425
+ is_final_block = i == len(block_out_channels) - 1
426
+
427
+ prev_output_channel = output_channel
428
+ output_channel = reversed_block_out_channels[i]
429
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
430
+
431
+ # add upsample block for all BUT final layer
432
+ if not is_final_block:
433
+ add_upsample = True
434
+ self.num_upsamplers += 1
435
+ else:
436
+ add_upsample = False
437
+
438
+ up_block = get_up_block(
439
+ up_block_type,
440
+ num_layers=reversed_layers_per_block[i] + 1,
441
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
442
+ in_channels=input_channel,
443
+ out_channels=output_channel,
444
+ prev_output_channel=prev_output_channel,
445
+ temb_channels=blocks_time_embed_dim,
446
+ add_upsample=add_upsample,
447
+ resnet_eps=norm_eps,
448
+ resnet_act_fn=act_fn,
449
+ resolution_idx=i,
450
+ resnet_groups=norm_num_groups,
451
+ cross_attention_dim=reversed_cross_attention_dim[i],
452
+ num_attention_heads=reversed_num_attention_heads[i],
453
+ dual_cross_attention=dual_cross_attention,
454
+ use_linear_projection=use_linear_projection,
455
+ only_cross_attention=only_cross_attention[i],
456
+ upcast_attention=upcast_attention,
457
+ resnet_time_scale_shift=resnet_time_scale_shift,
458
+ attention_type=attention_type,
459
+ resnet_skip_time_act=resnet_skip_time_act,
460
+ resnet_out_scale_factor=resnet_out_scale_factor,
461
+ cross_attention_norm=cross_attention_norm,
462
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
463
+ dropout=dropout,
464
+ )
465
+ self.up_blocks.append(up_block)
466
+ prev_output_channel = output_channel
467
+
468
+ # out
469
+ if norm_num_groups is not None:
470
+ self.conv_norm_out = nn.GroupNorm(
471
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
472
+ )
473
+
474
+ self.conv_act = get_activation(act_fn)
475
+
476
+ else:
477
+ self.conv_norm_out = None
478
+ self.conv_act = None
479
+
480
+ conv_out_padding = (conv_out_kernel - 1) // 2
481
+ self.conv_out = nn.Conv2d(
482
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
483
+ )
484
+
485
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
486
+
487
+ def _check_config(
488
+ self,
489
+ down_block_types: Tuple[str],
490
+ up_block_types: Tuple[str],
491
+ only_cross_attention: Union[bool, Tuple[bool]],
492
+ block_out_channels: Tuple[int],
493
+ layers_per_block: Union[int, Tuple[int]],
494
+ cross_attention_dim: Union[int, Tuple[int]],
495
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
496
+ reverse_transformer_layers_per_block: bool,
497
+ attention_head_dim: int,
498
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
499
+ ):
500
+ if len(down_block_types) != len(up_block_types):
501
+ raise ValueError(
502
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
503
+ )
504
+
505
+ if len(block_out_channels) != len(down_block_types):
506
+ raise ValueError(
507
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
508
+ )
509
+
510
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
511
+ raise ValueError(
512
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
513
+ )
514
+
515
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
516
+ raise ValueError(
517
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
518
+ )
519
+
520
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
521
+ raise ValueError(
522
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
523
+ )
524
+
525
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
526
+ raise ValueError(
527
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
528
+ )
529
+
530
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
531
+ raise ValueError(
532
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
533
+ )
534
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
535
+ for layer_number_per_block in transformer_layers_per_block:
536
+ if isinstance(layer_number_per_block, list):
537
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
538
+
539
+ def _set_time_proj(
540
+ self,
541
+ time_embedding_type: str,
542
+ block_out_channels: int,
543
+ flip_sin_to_cos: bool,
544
+ freq_shift: float,
545
+ time_embedding_dim: int,
546
+ ) -> Tuple[int, int]:
547
+ if time_embedding_type == "fourier":
548
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
549
+ if time_embed_dim % 2 != 0:
550
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
551
+ self.time_proj = GaussianFourierProjection(
552
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
553
+ )
554
+ timestep_input_dim = time_embed_dim
555
+ elif time_embedding_type == "positional":
556
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
557
+
558
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
559
+ timestep_input_dim = block_out_channels[0]
560
+ else:
561
+ raise ValueError(
562
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
563
+ )
564
+
565
+ return time_embed_dim, timestep_input_dim
566
+
567
+ def _set_encoder_hid_proj(
568
+ self,
569
+ encoder_hid_dim_type: Optional[str],
570
+ cross_attention_dim: Union[int, Tuple[int]],
571
+ encoder_hid_dim: Optional[int],
572
+ ):
573
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
574
+ encoder_hid_dim_type = "text_proj"
575
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
576
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
577
+
578
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
579
+ raise ValueError(
580
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
581
+ )
582
+
583
+ if encoder_hid_dim_type == "text_proj":
584
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
585
+ elif encoder_hid_dim_type == "text_image_proj":
586
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
587
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
588
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
589
+ self.encoder_hid_proj = TextImageProjection(
590
+ text_embed_dim=encoder_hid_dim,
591
+ image_embed_dim=cross_attention_dim,
592
+ cross_attention_dim=cross_attention_dim,
593
+ )
594
+ elif encoder_hid_dim_type == "image_proj":
595
+ # Kandinsky 2.2
596
+ self.encoder_hid_proj = ImageProjection(
597
+ image_embed_dim=encoder_hid_dim,
598
+ cross_attention_dim=cross_attention_dim,
599
+ )
600
+ elif encoder_hid_dim_type is not None:
601
+ raise ValueError(
602
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
603
+ )
604
+ else:
605
+ self.encoder_hid_proj = None
606
+
607
+ def _set_class_embedding(
608
+ self,
609
+ class_embed_type: Optional[str],
610
+ act_fn: str,
611
+ num_class_embeds: Optional[int],
612
+ projection_class_embeddings_input_dim: Optional[int],
613
+ time_embed_dim: int,
614
+ timestep_input_dim: int,
615
+ ):
616
+ if class_embed_type is None and num_class_embeds is not None:
617
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
618
+ elif class_embed_type == "timestep":
619
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
620
+ elif class_embed_type == "identity":
621
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
622
+ elif class_embed_type == "projection":
623
+ if projection_class_embeddings_input_dim is None:
624
+ raise ValueError(
625
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
626
+ )
627
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
628
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
629
+ # 2. it projects from an arbitrary input dimension.
630
+ #
631
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
632
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
633
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
634
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
635
+ elif class_embed_type == "simple_projection":
636
+ if projection_class_embeddings_input_dim is None:
637
+ raise ValueError(
638
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
639
+ )
640
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
641
+ else:
642
+ self.class_embedding = None
643
+
644
+ def _set_add_embedding(
645
+ self,
646
+ addition_embed_type: str,
647
+ addition_embed_type_num_heads: int,
648
+ addition_time_embed_dim: Optional[int],
649
+ flip_sin_to_cos: bool,
650
+ freq_shift: float,
651
+ cross_attention_dim: Optional[int],
652
+ encoder_hid_dim: Optional[int],
653
+ projection_class_embeddings_input_dim: Optional[int],
654
+ time_embed_dim: int,
655
+ ):
656
+ if addition_embed_type == "text":
657
+ if encoder_hid_dim is not None:
658
+ text_time_embedding_from_dim = encoder_hid_dim
659
+ else:
660
+ text_time_embedding_from_dim = cross_attention_dim
661
+
662
+ self.add_embedding = TextTimeEmbedding(
663
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
664
+ )
665
+ elif addition_embed_type == "text_image":
666
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
667
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
668
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
669
+ self.add_embedding = TextImageTimeEmbedding(
670
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
671
+ )
672
+ elif addition_embed_type == "text_time":
673
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
674
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
675
+ elif addition_embed_type == "image":
676
+ # Kandinsky 2.2
677
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
678
+ elif addition_embed_type == "image_hint":
679
+ # Kandinsky 2.2 ControlNet
680
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
681
+ elif addition_embed_type is not None:
682
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
683
+
684
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
685
+ if attention_type in ["gated", "gated-text-image"]:
686
+ positive_len = 768
687
+ if isinstance(cross_attention_dim, int):
688
+ positive_len = cross_attention_dim
689
+ elif isinstance(cross_attention_dim, (list, tuple)):
690
+ positive_len = cross_attention_dim[0]
691
+
692
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
693
+ self.position_net = GLIGENTextBoundingboxProjection(
694
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
695
+ )
696
+
697
+ @property
698
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
699
+ r"""
700
+ Returns:
701
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
702
+ indexed by their weight names.
703
+ """
704
+ # set recursively
705
+ processors = {}
706
+
707
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
708
+ if hasattr(module, "get_processor"):
709
+ processors[f"{name}.processor"] = module.get_processor()
710
+
711
+ for sub_name, child in module.named_children():
712
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
713
+
714
+ return processors
715
+
716
+ for name, module in self.named_children():
717
+ fn_recursive_add_processors(name, module, processors)
718
+
719
+ return processors
720
+
721
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
722
+ r"""
723
+ Sets the attention processor to use to compute attention.
724
+
725
+ Parameters:
726
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
727
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
728
+ for **all** `Attention` layers.
729
+
730
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
731
+ processor. This is strongly recommended when setting trainable attention processors.
732
+
733
+ """
734
+ count = len(self.attn_processors.keys())
735
+
736
+ if isinstance(processor, dict) and len(processor) != count:
737
+ raise ValueError(
738
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
739
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
740
+ )
741
+
742
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
743
+ if hasattr(module, "set_processor"):
744
+ if not isinstance(processor, dict):
745
+ module.set_processor(processor)
746
+ else:
747
+ module.set_processor(processor.pop(f"{name}.processor"))
748
+
749
+ for sub_name, child in module.named_children():
750
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
751
+
752
+ for name, module in self.named_children():
753
+ fn_recursive_attn_processor(name, module, processor)
754
+
755
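+ # Usage sketch (hedged; `unet` is an instance of this model, `AttnProcessor2_0` comes from
+ # diffusers.models.attention_processor):
+ #   unet.set_attn_processor(AttnProcessor2_0())  # one processor shared by every attention layer
+ #   unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})  # per-layer dict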
+ def set_default_attn_processor(self):
756
+ """
757
+ Disables custom attention processors and sets the default attention implementation.
758
+ """
759
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
760
+ processor = AttnAddedKVProcessor()
761
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
762
+ processor = AttnProcessor()
763
+ else:
764
+ raise ValueError(
765
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
766
+ )
767
+
768
+ self.set_attn_processor(processor)
769
+
770
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
771
+ r"""
772
+ Enable sliced attention computation.
773
+
774
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
775
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
776
+
777
+ Args:
778
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
779
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
780
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
781
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
782
+ must be a multiple of `slice_size`.
783
+ """
784
+ sliceable_head_dims = []
785
+
786
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
787
+ if hasattr(module, "set_attention_slice"):
788
+ sliceable_head_dims.append(module.sliceable_head_dim)
789
+
790
+ for child in module.children():
791
+ fn_recursive_retrieve_sliceable_dims(child)
792
+
793
+ # retrieve number of attention layers
794
+ for module in self.children():
795
+ fn_recursive_retrieve_sliceable_dims(module)
796
+
797
+ num_sliceable_layers = len(sliceable_head_dims)
798
+
799
+ if slice_size == "auto":
800
+ # half the attention head size is usually a good trade-off between
801
+ # speed and memory
802
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
803
+ elif slice_size == "max":
804
+ # make smallest slice possible
805
+ slice_size = num_sliceable_layers * [1]
806
+
807
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
808
+
809
+ if len(slice_size) != len(sliceable_head_dims):
810
+ raise ValueError(
811
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
812
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
813
+ )
814
+
815
+ for i in range(len(slice_size)):
816
+ size = slice_size[i]
817
+ dim = sliceable_head_dims[i]
818
+ if size is not None and size > dim:
819
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
820
+
821
+ # Recursively walk through all the children.
822
+ # Any child that exposes the set_attention_slice method
823
+ # receives its slice size
824
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
825
+ if hasattr(module, "set_attention_slice"):
826
+ module.set_attention_slice(slice_size.pop())
827
+
828
+ for child in module.children():
829
+ fn_recursive_set_attention_slice(child, slice_size)
830
+
831
+ reversed_slice_size = list(reversed(slice_size))
832
+ for module in self.children():
833
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
834
+
835
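+ # Usage sketch (hedged; `unet` is an instance of this model):
+ #   unet.set_attention_slice("auto")  # halve each sliceable head dim -> two slices
+ #   unet.set_attention_slice("max")   # one slice at a time, maximum memory savings
+ #   unet.set_attention_slice(4)       # attention_head_dim // 4 slices per layer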
+ def _set_gradient_checkpointing(self, module, value=False):
836
+ if hasattr(module, "gradient_checkpointing"):
837
+ module.gradient_checkpointing = value
838
+
839
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
840
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
841
+
842
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
843
+
844
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
845
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
846
+
847
+ Args:
848
+ s1 (`float`):
849
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
850
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
851
+ s2 (`float`):
852
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
853
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
854
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
855
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
856
+ """
857
+ for i, upsample_block in enumerate(self.up_blocks):
858
+ setattr(upsample_block, "s1", s1)
859
+ setattr(upsample_block, "s2", s2)
860
+ setattr(upsample_block, "b1", b1)
861
+ setattr(upsample_block, "b2", b2)
862
+
863
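+ # Usage sketch (hedged): the FreeU repository suggests roughly s1=0.9, s2=0.2, b1=1.5, b2=1.6
+ # for Stable Diffusion v1.x and b1=1.3, b2=1.4 for SDXL; treat these as starting points:
+ #   unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)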
+ def disable_freeu(self):
864
+ """Disables the FreeU mechanism."""
865
+ freeu_keys = {"s1", "s2", "b1", "b2"}
866
+ for i, upsample_block in enumerate(self.up_blocks):
867
+ for k in freeu_keys:
868
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
869
+ setattr(upsample_block, k, None)
870
+
871
+ def fuse_qkv_projections(self):
872
+ """
873
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
874
+ are fused. For cross-attention modules, key and value projection matrices are fused.
875
+
876
+ <Tip warning={true}>
877
+
878
+ This API is 🧪 experimental.
879
+
880
+ </Tip>
881
+ """
882
+ self.original_attn_processors = None
883
+
884
+ for _, attn_processor in self.attn_processors.items():
885
+ if "Added" in str(attn_processor.__class__.__name__):
886
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
887
+
888
+ self.original_attn_processors = self.attn_processors
889
+
890
+ for module in self.modules():
891
+ if isinstance(module, Attention):
892
+ module.fuse_projections(fuse=True)
893
+
894
+ self.set_attn_processor(FusedAttnProcessor2_0())
895
+
896
+ def unfuse_qkv_projections(self):
897
+ """Disables the fused QKV projection if enabled.
898
+
899
+ <Tip warning={true}>
900
+
901
+ This API is 🧪 experimental.
902
+
903
+ </Tip>
904
+
905
+ """
906
+ if self.original_attn_processors is not None:
907
+ self.set_attn_processor(self.original_attn_processors)
908
+
909
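+ # Usage sketch (hedged): fuse before a batch of generations, then unfuse to restore the
+ # original processors:
+ #   unet.fuse_qkv_projections()
+ #   ...  # run inference
+ #   unet.unfuse_qkv_projections()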
+ def get_time_embed(
910
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
911
+ ) -> Optional[torch.Tensor]:
912
+ timesteps = timestep
913
+ if not torch.is_tensor(timesteps):
914
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
915
+ # This would be a good case for the `match` statement (Python 3.10+)
916
+ is_mps = sample.device.type == "mps"
917
+ if isinstance(timestep, float):
918
+ dtype = torch.float32 if is_mps else torch.float64
919
+ else:
920
+ dtype = torch.int32 if is_mps else torch.int64
921
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
922
+ elif len(timesteps.shape) == 0:
923
+ timesteps = timesteps[None].to(sample.device)
924
+
925
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
926
+ timesteps = timesteps.expand(sample.shape[0])
927
+
928
+ t_emb = self.time_proj(timesteps)
929
+ # `Timesteps` does not contain any weights and will always return f32 tensors
930
+ # but time_embedding might actually be running in fp16. so we need to cast here.
931
+ # there might be better ways to encapsulate this.
932
+ t_emb = t_emb.to(dtype=sample.dtype)
933
+ return t_emb
934
+
935
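+ # Note (sketch): passing `timestep` as a tensor already on `sample.device`, e.g.
+ # torch.tensor([999], device=sample.device), avoids the CPU/GPU sync noted above.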
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
936
+ class_emb = None
937
+ if self.class_embedding is not None:
938
+ if class_labels is None:
939
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
940
+
941
+ if self.config.class_embed_type == "timestep":
942
+ class_labels = self.time_proj(class_labels)
943
+
944
+ # `Timesteps` does not contain any weights and will always return f32 tensors
945
+ # there might be better ways to encapsulate this.
946
+ class_labels = class_labels.to(dtype=sample.dtype)
947
+
948
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
949
+ return class_emb
950
+
951
+ def get_aug_embed(
952
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
953
+ ) -> Optional[torch.Tensor]:
954
+ aug_emb = None
955
+ if self.config.addition_embed_type == "text":
956
+ aug_emb = self.add_embedding(encoder_hidden_states)
957
+ elif self.config.addition_embed_type == "text_image":
958
+ # Kandinsky 2.1 - style
959
+ if "image_embeds" not in added_cond_kwargs:
960
+ raise ValueError(
961
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
962
+ )
963
+
964
+ image_embs = added_cond_kwargs.get("image_embeds")
965
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
966
+ aug_emb = self.add_embedding(text_embs, image_embs)
967
+ elif self.config.addition_embed_type == "text_time":
968
+ # SDXL - style
969
+ if "text_embeds" not in added_cond_kwargs:
970
+ raise ValueError(
971
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
972
+ )
973
+ text_embeds = added_cond_kwargs.get("text_embeds")
974
+ if "time_ids" not in added_cond_kwargs:
975
+ raise ValueError(
976
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
977
+ )
978
+ time_ids = added_cond_kwargs.get("time_ids")
979
+ time_embeds = self.add_time_proj(time_ids.flatten())
980
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
981
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
982
+ add_embeds = add_embeds.to(emb.dtype)
983
+ aug_emb = self.add_embedding(add_embeds)
984
+ elif self.config.addition_embed_type == "image":
985
+ # Kandinsky 2.2 - style
986
+ if "image_embeds" not in added_cond_kwargs:
987
+ raise ValueError(
988
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
989
+ )
990
+ image_embs = added_cond_kwargs.get("image_embeds")
991
+ aug_emb = self.add_embedding(image_embs)
992
+ elif self.config.addition_embed_type == "image_hint":
993
+ # Kandinsky 2.2 - style
994
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
995
+ raise ValueError(
996
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
997
+ )
998
+ image_embs = added_cond_kwargs.get("image_embeds")
999
+ hint = added_cond_kwargs.get("hint")
1000
+ aug_emb = self.add_embedding(image_embs, hint)
1001
+ return aug_emb
1002
+
1003
+ def process_encoder_hidden_states(
1004
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1005
+ ) -> torch.Tensor:
1006
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1007
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1008
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1009
+ # Kandinsky 2.1 - style
1010
+ if "image_embeds" not in added_cond_kwargs:
1011
+ raise ValueError(
1012
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1013
+ )
1014
+
1015
+ image_embeds = added_cond_kwargs.get("image_embeds")
1016
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1017
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1018
+ # Kandinsky 2.2 - style
1019
+ if "image_embeds" not in added_cond_kwargs:
1020
+ raise ValueError(
1021
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1022
+ )
1023
+ image_embeds = added_cond_kwargs.get("image_embeds")
1024
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1025
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1026
+ if "image_embeds" not in added_cond_kwargs:
1027
+ raise ValueError(
1028
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1029
+ )
1030
+
1031
+ if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
1032
+ encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1033
+
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ image_embeds = self.encoder_hid_proj(image_embeds)
1036
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1037
+ return encoder_hidden_states
1038
+
1039
+ def forward(
1040
+ self,
1041
+ sample: torch.Tensor,
1042
+ timestep: Union[torch.Tensor, float, int],
1043
+ encoder_hidden_states: torch.Tensor,
1044
+ class_labels: Optional[torch.Tensor] = None,
1045
+ timestep_cond: Optional[torch.Tensor] = None,
1046
+ attention_mask: Optional[torch.Tensor] = None,
1047
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1048
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1049
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1050
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1051
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1052
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1053
+ use_up_blocks: bool = False,
1054
+ return_dict: bool = True,
1055
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1056
+ r"""
1057
+ The [`UNet2DConditionModel`] forward method.
1058
+
1059
+ Args:
1060
+ sample (`torch.Tensor`):
1061
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1062
+ timestep (`torch.Tensor` or `float` or `int`): The current denoising timestep.
1063
+ encoder_hidden_states (`torch.Tensor`):
1064
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1065
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1066
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1067
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1068
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1069
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1070
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1071
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1072
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1073
+ negative values to the attention scores corresponding to "discard" tokens.
1074
+ cross_attention_kwargs (`dict`, *optional*):
1075
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1076
+ `self.processor` in
1077
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1078
+ added_cond_kwargs: (`dict`, *optional*):
1079
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1080
+ are passed along to the UNet blocks.
1081
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1082
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1083
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1084
+ A tensor that if specified is added to the residual of the middle unet block.
1085
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1086
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1087
+ encoder_attention_mask (`torch.Tensor`):
1088
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1089
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1090
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1091
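+ use_up_blocks (`bool`, *optional*, defaults to `False`):
+ Whether to additionally run the decoder (up) blocks and collect their intermediate feature
+ maps alongside the mid-block output.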
+ return_dict (`bool`, *optional*, defaults to `True`):
1092
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1093
+ tuple.
1094
+
1095
+ Returns:
1096
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1097
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1098
+ otherwise a `tuple` `(mid_sample, down_block_res_samples)` is returned, with `up_block_res_samples` appended when `use_up_blocks=True`.
1099
+ """
1100
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1101
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
1102
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1103
+ # on the fly if necessary.
1104
+ default_overall_up_factor = 2**self.num_upsamplers
1105
+
1106
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1107
+ forward_upsample_size = False
1108
+ upsample_size = None
1109
+
1110
+ # import time
1111
+ # torch.cuda.synchronize()
1112
+ # start_time = time.time()
1113
+
1114
+ for dim in sample.shape[-2:]:
1115
+ if dim % default_overall_up_factor != 0:
1116
+ # Forward upsample size to force interpolation output size.
1117
+ forward_upsample_size = True
1118
+ break
1119
+
1120
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1121
+ # expects mask of shape:
1122
+ # [batch, key_tokens]
1123
+ # adds singleton query_tokens dimension:
1124
+ # [batch, 1, key_tokens]
1125
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1126
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1127
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1128
+ if attention_mask is not None:
1129
+ # assume that mask is expressed as:
1130
+ # (1 = keep, 0 = discard)
1131
+ # convert mask into a bias that can be added to attention scores:
1132
+ # (keep = +0, discard = -10000.0)
1133
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1134
+ attention_mask = attention_mask.unsqueeze(1)
1135
+
1136
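+ # Worked example (sketch): attention_mask [[1, 0]] -> bias [[[0.0, -10000.0]]], i.e.
+ # shape [batch, key_tokens] -> [batch, 1, key_tokens] for broadcasting over heads/queries.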
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1137
+ if encoder_attention_mask is not None:
1138
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1139
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1140
+
1141
+ # 0. center input if necessary
1142
+ if self.config.center_input_sample:
1143
+ sample = 2 * sample - 1.0
1144
+
1145
+ # 1. time
1146
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1147
+ emb = self.time_embedding(t_emb, timestep_cond)
1148
+ aug_emb = None
1149
+
1150
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1151
+ if class_emb is not None:
1152
+ if self.config.class_embeddings_concat:
1153
+ emb = torch.cat([emb, class_emb], dim=-1)
1154
+ else:
1155
+ emb = emb + class_emb
1156
+
1157
+ aug_emb = self.get_aug_embed(
1158
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1159
+ )
1160
+ if self.config.addition_embed_type == "image_hint":
1161
+ aug_emb, hint = aug_emb
1162
+ sample = torch.cat([sample, hint], dim=1)
1163
+
1164
+ emb = emb + aug_emb if aug_emb is not None else emb
1165
+
1166
+ if self.time_embed_act is not None:
1167
+ emb = self.time_embed_act(emb)
1168
+
1169
+ encoder_hidden_states = self.process_encoder_hidden_states(
1170
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1171
+ )
1172
+
1173
+ # 2. pre-process
1174
+ sample = self.conv_in(sample)
1175
+
1176
+ # 2.5 GLIGEN position net
1177
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1178
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1179
+ gligen_args = cross_attention_kwargs.pop("gligen")
1180
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1181
+
1182
+ # 3. down
1183
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1184
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1185
+ if cross_attention_kwargs is not None:
1186
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1187
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1188
+ else:
1189
+ lora_scale = 1.0
1190
+
1191
+ if USE_PEFT_BACKEND:
1192
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1193
+ scale_lora_layers(self, lora_scale)
1194
+
1195
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1196
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1197
+ is_adapter = down_intrablock_additional_residuals is not None
1198
+ # maintain backward compatibility for legacy usage, where
1199
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1200
+ # but can only use one or the other
1201
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1202
+ deprecate(
1203
+ "T2I should not use down_block_additional_residuals",
1204
+ "1.3.0",
1205
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1206
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1207
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1208
+ standard_warn=False,
1209
+ )
1210
+ down_intrablock_additional_residuals = down_block_additional_residuals
1211
+ is_adapter = True
1212
+
1213
+ # torch.cuda.synchronize()
1214
+ # logger.info(f"unet preprocess: {time.time() - start_time}")
1215
+
1216
+ # torch.cuda.synchronize()
1217
+ # start_time = time.time()
1218
+ down_block_res_samples = (sample,)
1219
+ for downsample_block in self.down_blocks:
1220
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1221
+ # For t2i-adapter CrossAttnDownBlock2D
1222
+ additional_residuals = {}
1223
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1224
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1225
+
1226
+ sample, res_samples = downsample_block(
1227
+ hidden_states=sample,
1228
+ temb=emb,
1229
+ encoder_hidden_states=encoder_hidden_states,
1230
+ attention_mask=attention_mask,
1231
+ cross_attention_kwargs=cross_attention_kwargs,
1232
+ encoder_attention_mask=encoder_attention_mask,
1233
+ **additional_residuals,
1234
+ )
1235
+ else:
1236
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1237
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1238
+ sample += down_intrablock_additional_residuals.pop(0)
1239
+
1240
+ down_block_res_samples += res_samples
1241
+
1242
+ if is_controlnet:
1243
+ new_down_block_res_samples = ()
1244
+
1245
+ for down_block_res_sample, down_block_additional_residual in zip(
1246
+ down_block_res_samples, down_block_additional_residuals
1247
+ ):
1248
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1249
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1250
+
1251
+ down_block_res_samples = new_down_block_res_samples
1252
+ # torch.cuda.synchronize()
1253
+ # logger.info(f"unet down time: {time.time() - start_time}")
1254
+ # torch.cuda.synchronize()
1255
+ # start_time = time.time()
1256
+ # 4. mid
1257
+ if self.mid_block is not None:
1258
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1259
+ sample = self.mid_block(
1260
+ sample,
1261
+ emb,
1262
+ encoder_hidden_states=encoder_hidden_states,
1263
+ attention_mask=attention_mask,
1264
+ cross_attention_kwargs=cross_attention_kwargs,
1265
+ encoder_attention_mask=encoder_attention_mask,
1266
+ )
1267
+ else:
1268
+ sample = self.mid_block(sample, emb)
1269
+
1270
+ # To support T2I-Adapter-XL
1271
+ if (
1272
+ is_adapter
1273
+ and len(down_intrablock_additional_residuals) > 0
1274
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1275
+ ):
1276
+ sample += down_intrablock_additional_residuals.pop(0)
1277
+
1278
+ if is_controlnet:
1279
+ sample = sample + mid_block_additional_residual
1280
+ # torch.cuda.synchronize()
1281
+ # logger.info(f"unet mid time: {time.time() - start_time}")
1282
+ mid_sample = sample
1283
+
1284
+ if use_up_blocks:
1285
+ # 5. up
1286
+ up_block_res_samples = ()
1287
+ for i, upsample_block in enumerate(self.up_blocks):
1288
+ is_final_block = i == len(self.up_blocks) - 1
1289
+
1290
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1291
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1292
+
1293
+ # if we have not reached the final block and need to forward the
1294
+ # upsample size, we do it here
1295
+ if not is_final_block and forward_upsample_size:
1296
+ upsample_size = down_block_res_samples[-1].shape[2:]
1297
+
1298
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1299
+ sample = upsample_block(
1300
+ hidden_states=sample,
1301
+ temb=emb,
1302
+ res_hidden_states_tuple=res_samples,
1303
+ encoder_hidden_states=encoder_hidden_states,
1304
+ cross_attention_kwargs=cross_attention_kwargs,
1305
+ upsample_size=upsample_size,
1306
+ attention_mask=attention_mask,
1307
+ encoder_attention_mask=encoder_attention_mask,
1308
+ )
1309
+ else:
1310
+ sample = upsample_block(
1311
+ hidden_states=sample,
1312
+ temb=emb,
1313
+ res_hidden_states_tuple=res_samples,
1314
+ upsample_size=upsample_size,
1315
+ )
1316
+ up_block_res_samples += (sample,)
1317
+
1318
+ # 6. post-process (intentionally skipped: this reward variant returns intermediate features, not a decoded sample)
1319
+ # if self.conv_norm_out:
1320
+ # sample = self.conv_norm_out(sample)
1321
+ # sample = self.conv_act(sample)
1322
+ # sample = self.conv_out(sample)
1323
+
1324
+ if USE_PEFT_BACKEND:
1325
+ # remove `lora_scale` from each PEFT layer
1326
+ unscale_lora_layers(self, lora_scale)
1327
+
1328
+ if not return_dict:
1329
+ if use_up_blocks:
1330
+ return (mid_sample, down_block_res_samples, up_block_res_samples)
1331
+ else:
1332
+ return (mid_sample, down_block_res_samples)
1333
+
1334
+ return UNet2DConditionOutput(sample=sample)
Reward_sdxl_idealized/pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (287 Bytes).
 
Reward_sdxl_idealized/pipelines/sdxl_gradient_ascent_pipeline.py ADDED
@@ -0,0 +1,375 @@
1
+ """
2
+ Stable Diffusion XL Pipeline with Gradient Ascent Reward Guidance
3
+
4
+ Fully compatible with HuggingFace diffusers StableDiffusionXLPipeline.
5
+ Injects the Kinetic Latent Predictor-Corrector operator split safely
6
+ before the ODE integration step.
7
+ """
8
+
9
+ import inspect
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
+
12
+ import torch
13
+ from diffusers import StableDiffusionXLPipeline
14
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
15
+
16
+ from gradient_ascent_utils import RewardGuidedDiffusion
17
+
18
+
19
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
20
+ """
21
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure.
22
+ Based on Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (https://arxiv.org/abs/2305.08891); matches the diffusers implementation.
23
+ """
24
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
25
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
26
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
27
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
28
+ return noise_cfg
29
+
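+ # In symbols (sketch): x_rescaled = x_cfg * std(x_text) / std(x_cfg); the return value is
+ # phi * x_rescaled + (1 - phi) * x_cfg, with phi = guidance_rescale.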
30
+
31
+ class StableDiffusionXLGradientAscentPipeline(StableDiffusionXLPipeline):
32
+ """
33
+ SDXL Pipeline with KLPC gradient ascent reward guidance.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ vae,
39
+ text_encoder,
40
+ text_encoder_2,
41
+ tokenizer,
42
+ tokenizer_2,
43
+ unet,
44
+ scheduler,
45
+ image_encoder=None,
46
+ feature_extractor=None,
47
+ force_zeros_for_empty_prompt: bool = True,
48
+ add_watermarker: Optional[bool] = None,
49
+ ):
50
+ super().__init__(
51
+ vae=vae,
52
+ text_encoder=text_encoder,
53
+ text_encoder_2=text_encoder_2,
54
+ tokenizer=tokenizer,
55
+ tokenizer_2=tokenizer_2,
56
+ unet=unet,
57
+ scheduler=scheduler,
58
+ image_encoder=image_encoder,
59
+ feature_extractor=feature_extractor,
60
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
61
+ add_watermarker=add_watermarker,
62
+ )
63
+
64
+ # Initialize our KLPC custom state variables
65
+ self.gradient_ascent_enabled = False
66
+ self.grad_guidance = None
67
+ self.reward_model = None
68
+ self.reward_history = []
69
+
70
+ def set_reward_model(self, reward_model):
71
+ self.reward_model = reward_model
72
+
73
+ def enable_gradient_ascent(
74
+ self,
75
+ grad_timestep_range: Tuple[int, int] = (500, 700),
76
+ grad_scale: float = 1.0,
77
+ num_grad_steps: int = 5,
78
+ grad_step_size: float = 0.1,
79
+ lr_scheduler_type: str = "constant",
80
+ lr_scheduler_kwargs: Optional[dict] = None,
81
+ use_momentum: bool = False,
82
+ momentum: float = 0.9,
83
+ use_nesterov: bool = False,
84
+ use_iso_projection: bool = False
85
+ ):
86
+ if self.reward_model is None:
87
+ raise ValueError("Reward model must be set first via set_reward_model().")
88
+
89
+ self.grad_guidance = RewardGuidedDiffusion(
90
+ reward_model=self.reward_model,
91
+ grad_scale=grad_scale,
92
+ grad_timestep_range=grad_timestep_range,
93
+ num_grad_steps=num_grad_steps,
94
+ grad_step_size=grad_step_size,
95
+ lr_scheduler_type=lr_scheduler_type,
96
+ lr_scheduler_kwargs=lr_scheduler_kwargs or {},
97
+ use_momentum=use_momentum,
98
+ momentum=momentum,
99
+ use_nesterov=use_nesterov,
100
+ use_iso_projection=use_iso_projection
101
+ )
102
+ self.gradient_ascent_enabled = True
103
+
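+ # Usage sketch (hedged; the checkpoint id is illustrative):
+ #   pipe = StableDiffusionXLGradientAscentPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ #   pipe.set_reward_model(reward_model)  # any model exposing get_reward_score(latents, prompt, timestep)
+ #   pipe.enable_gradient_ascent(grad_timestep_range=(500, 700), num_grad_steps=5, grad_step_size=0.1)
+ #   image = pipe("a photo of a corgi").images[0]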
104
+ @torch.no_grad()
105
+ def __call__(
106
+ self,
107
+ prompt: Union[str, List[str]] = None,
108
+ prompt_2: Optional[Union[str, List[str]]] = None,
109
+ height: Optional[int] = None,
110
+ width: Optional[int] = None,
111
+ num_inference_steps: int = 50,
112
+ timesteps: List[int] = None,
113
+ sigmas: List[float] = None,
114
+ denoising_end: Optional[float] = None,
115
+ guidance_scale: float = 5.0,
116
+ negative_prompt: Optional[Union[str, List[str]]] = None,
117
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
118
+ num_images_per_prompt: Optional[int] = 1,
119
+ eta: float = 0.0,
120
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
121
+ latents: Optional[torch.Tensor] = None,
122
+ prompt_embeds: Optional[torch.Tensor] = None,
123
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
124
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
125
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
126
+ ip_adapter_image: Optional[Any] = None,
127
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
128
+ output_type: Optional[str] = "pil",
129
+ return_dict: bool = True,
130
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
131
+ guidance_rescale: float = 0.0,
132
+ original_size: Optional[Tuple[int, int]] = None,
133
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
134
+ target_size: Optional[Tuple[int, int]] = None,
135
+ negative_original_size: Optional[Tuple[int, int]] = None,
136
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
137
+ negative_target_size: Optional[Tuple[int, int]] = None,
138
+ clip_skip: Optional[int] = None,
139
+ callback_on_step_end: Optional[Callable[[int, int, Dict], Dict]] = None,
140
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
141
+ # Custom parameters for KLPC gradient ascent
142
+ track_rewards: bool = True,
143
+ apply_gradient_ascent: bool = True,
144
+ verbose_grad: bool = False,
145
+ **kwargs,
146
+ ):
147
+ # 1. Setup sizes & Batch
148
+ height = height or self.default_sample_size * self.vae_scale_factor
149
+ width = width or self.default_sample_size * self.vae_scale_factor
150
+
151
+ original_size = original_size or (height, width)
152
+ target_size = target_size or (height, width)
153
+
154
+ # 2. Define batch size
155
+ if prompt is not None and isinstance(prompt, str):
156
+ batch_size = 1
157
+ elif prompt is not None and isinstance(prompt, list):
158
+ batch_size = len(prompt)
159
+ else:
160
+ batch_size = prompt_embeds.shape[0]
161
+
162
+ device = self._execution_device
163
+
164
+ do_classifier_free_guidance = guidance_scale > 1.0
165
+
166
+ # 3. Encode input prompt
167
+ lora_scale = (
168
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
169
+ )
170
+
171
+ (
172
+ prompt_embeds,
173
+ negative_prompt_embeds,
174
+ pooled_prompt_embeds,
175
+ negative_pooled_prompt_embeds,
176
+ ) = self.encode_prompt(
177
+ prompt=prompt,
178
+ prompt_2=prompt_2,
179
+ device=device,
180
+ num_images_per_prompt=num_images_per_prompt,
181
+ do_classifier_free_guidance=do_classifier_free_guidance,
182
+ negative_prompt=negative_prompt,
183
+ negative_prompt_2=negative_prompt_2,
184
+ prompt_embeds=prompt_embeds,
185
+ negative_prompt_embeds=negative_prompt_embeds,
186
+ pooled_prompt_embeds=pooled_prompt_embeds,
187
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
188
+ lora_scale=lora_scale,
189
+ clip_skip=clip_skip,
190
+ )
191
+
192
+ # Keep HF internal state synchronized just in case external callbacks need them
193
+ self._guidance_scale = guidance_scale
194
+ self._guidance_rescale = guidance_rescale
195
+ self._cross_attention_kwargs = cross_attention_kwargs
196
+ self._interrupt = False
197
+
198
+ # 4. Prepare timesteps
199
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
200
+ timesteps = self.scheduler.timesteps
201
+
202
+ # 5. Prepare latent variables
203
+ num_channels_latents = self.unet.config.in_channels
204
+ latents = self.prepare_latents(
205
+ batch_size * num_images_per_prompt,
206
+ num_channels_latents,
207
+ height,
208
+ width,
209
+ prompt_embeds.dtype,
210
+ device,
211
+ generator,
212
+ latents,
213
+ )
214
+
215
+ # 6. Prepare extra step kwargs
216
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
217
+
218
+ # 7. Prepare added time ids & embeddings
219
+ add_text_embeds = pooled_prompt_embeds
220
+ if self.text_encoder_2 is None:
221
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
222
+ else:
223
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
224
+
225
+ add_time_ids = self._get_add_time_ids(
226
+ original_size,
227
+ crops_coords_top_left,
228
+ target_size,
229
+ dtype=prompt_embeds.dtype,
230
+ text_encoder_projection_dim=text_encoder_projection_dim,
231
+ )
232
+
233
+ if negative_original_size is not None and negative_target_size is not None:
234
+ negative_add_time_ids = self._get_add_time_ids(
235
+ negative_original_size,
236
+ negative_crops_coords_top_left,
237
+ negative_target_size,
238
+ dtype=prompt_embeds.dtype,
239
+ text_encoder_projection_dim=text_encoder_projection_dim,
240
+ )
241
+ else:
242
+ negative_add_time_ids = add_time_ids
243
+
244
+ if do_classifier_free_guidance:
245
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
246
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
247
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
248
+
249
+ prompt_embeds = prompt_embeds.to(device)
250
+ add_text_embeds = add_text_embeds.to(device)
251
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
252
+
253
+ # 8. Reset Tracking
254
+ self.reward_history = []
255
+ if self.grad_guidance is not None:
256
+ self.grad_guidance.reset_statistics()
257
+
258
+ # Extract string prompt for the reward model forward pass
259
+ prompt_str = prompt[0] if isinstance(prompt, list) else prompt
260
+
261
+ # 9. Denoising loop
262
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
263
+
264
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
265
+ for i, t in enumerate(timesteps):
266
+
267
+ # predict the noise residual
268
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
269
+
270
+ # =========================================================================
271
+ # KINETIC LATENT PREDICTOR-CORRECTOR (KLPC) OPERATOR SPLIT
272
+ # =========================================================================
273
+ # This executes the spatial displacement C_i(z) right BEFORE the scheduler
274
+ # integrates using the "stale" velocity (noise_pred), fulfilling Theorem 1.
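+ # In symbols (sketch): each step applies the latent correction C_i (reward gradient
+ # ascent, below) before the scheduler's ODE update Phi_i, i.e. z_{i-1} = Phi_i(C_i(z_i));
+ # "Theorem 1" refers to the authors' accompanying analysis and is not reproduced here.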
275
+ if (
276
+ self.gradient_ascent_enabled
277
+ and apply_gradient_ascent
278
+ and self.grad_guidance
279
+ and self.grad_guidance.should_apply_gradient(t.item())
280
+ and prompt_str is not None
281
+ ):
282
+ with torch.enable_grad():
283
+ latents, grad_stats = self.grad_guidance.apply_gradient_ascent(
284
+ latents,
285
+ prompt_str,
286
+ t.item(),
287
+ base_noise=None,
288
+ verbose=verbose_grad,
289
+ total_denoising_steps=num_inference_steps,
290
+ )
291
+
292
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
293
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
294
+
295
+ noise_pred = self.unet(
296
+ latent_model_input,
297
+ t,
298
+ encoder_hidden_states=prompt_embeds,
299
+ cross_attention_kwargs=cross_attention_kwargs,
300
+ added_cond_kwargs=added_cond_kwargs,
301
+ return_dict=False,
302
+ )[0]
303
+
304
+ if do_classifier_free_guidance:
305
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
306
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
307
+
308
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
309
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
310
+
311
+ # Track Rewards
312
+ if track_rewards and self.reward_model is not None and prompt_str is not None:
313
+ with torch.no_grad():
314
+ score = self.reward_model.get_reward_score(latents, prompt_str, t.item())
315
+ score_val = score.item() if score.numel() == 1 else score.mean().item()
316
+ self.reward_history.append({'timestep': t.item(), 'reward_score': score_val})
317
+ # =========================================================================
318
+
319
+ # compute the previous noisy sample x_t -> x_t-1
320
+ latents_dtype = latents.dtype
321
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
322
+
323
+ if latents.dtype != latents_dtype:
324
+ latents = latents.to(latents_dtype)
325
+
326
+ # Execute diffusers native step end callback if provided
327
+ if callback_on_step_end is not None:
328
+ callback_kwargs = {}
329
+ for k in callback_on_step_end_tensor_inputs:
330
+ callback_kwargs[k] = locals()[k]
331
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
332
+
333
+ latents = callback_outputs.pop("latents", latents)
334
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
335
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
336
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
337
+
338
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
339
+ progress_bar.update()
340
+
341
+ if output_type != "latent":
342
+ # make sure the VAE is in float32 mode, as it overflows in float16
343
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
344
+ if needs_upcasting:
345
+ self.upcast_vae()
346
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
347
+ elif latents.dtype != self.vae.dtype:
348
+ # HF Native fix: Cast the VAE to match the latents dtype!
349
+ self.vae = self.vae.to(latents.dtype)
350
+
351
+ # unscale/denormalize the latents
352
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
353
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
354
+
355
+ if has_latents_mean and has_latents_std:
356
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
357
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
358
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
359
+ else:
360
+ latents = latents / self.vae.config.scaling_factor
361
+
362
+ image = self.vae.decode(latents, return_dict=False)[0]
363
+
364
+ # cast back to fp16 if needed
365
+ if needs_upcasting:
366
+ self.vae.to(dtype=torch.float16)
367
+
368
+ image = self.image_processor.postprocess(image, output_type=output_type)
369
+ else:
370
+ image = latents
371
+
372
+ if not return_dict:
373
+ return (image,)
374
+
375
+ return StableDiffusionXLPipelineOutput(images=image)
Reward_sdxl_idealized/timestep_convergence_analysis.ipynb ADDED
@@ -0,0 +1,1105 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "513b682a",
6
+ "metadata": {},
7
+ "source": [
8
+ "#### Setup and Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "f9f44d61",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import os\n",
19
+ "import sys\n",
20
+ "import json\n",
21
+ "import warnings\n",
22
+ "import numpy as np\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "import matplotlib\n",
25
+ "import torch\n",
26
+ "import torch.nn as nn\n",
27
+ "from pathlib import Path\n",
28
+ "from PIL import Image\n",
29
+ "from tqdm.auto import tqdm\n",
30
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel\n",
31
+ "from transformers import CLIPModel, CLIPProcessor\n",
32
+ "from torchmetrics.image.fid import FrechetInceptionDistance\n",
33
+ "from torchmetrics.multimodal import CLIPScore\n",
34
+ "\n",
35
+ "warnings.filterwarnings(\"ignore\")\n",
36
+ "\n",
37
+ "# Import local modules\n",
38
+ "from models import LRMRewardModel\n",
39
+ "from pipelines.sd15_gradient_ascent_pipeline import StableDiffusionGradientAscentPipeline\n",
40
+ "from grad_ascent_configs import get_config, list_configs\n",
41
+ "\n",
42
+ "# Import evaluation metrics\n",
43
+ "sys.path.append('../evaluation')\n",
44
+ "from pick_score import PickScorer\n",
45
+ "from hpsv2_score import HPSv2Scorer\n",
46
+ "from imagereward_score import load_imagereward\n"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "1740dd7c",
52
+ "metadata": {},
53
+ "source": [
54
+ "#### Configuration"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "f1bc2b07",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# ============ CONFIGURATION ============\n",
65
+ "\n",
66
+ "# Dataset\n",
67
+ "DATA_DIR = \"./data\"\n",
68
+ "DATASET_TYPE = \"coco\" # \"coco\" or \"pickapic\"\n",
69
+ "NUM_SAMPLES = 20 # Number of samples to analyze\n",
70
+ "\n",
71
+ "# Model\n",
72
+ "BASE_MODEL = \"runwayml/stable-diffusion-v1-5\"\n",
73
+ "MODEL_VARIANT = \"lpo\" # \"origin\", \"spo\", \"diffusion_dpo\", \"lpo\"\n",
74
+ "LRM_MODEL = \"casiatao/LRM\"\n",
75
+ "\n",
76
+ "# Generation\n",
77
+ "NUM_INFERENCE_STEPS = 100\n",
78
+ "CFG_SCALE = 5.0\n",
79
+ "SEED = 42\n",
80
+ "BATCH_SIZE = 1\n",
81
+ "\n",
82
+ "# Gradient Ascent Config\n",
83
+ "GRAD_CONFIG = \"low_to_high_nesterov\" # Use None for manual config, or specify preset name\n",
84
+ "GRAD_RANGE_START = 0\n",
85
+ "GRAD_RANGE_END = 500\n",
86
+ "GRAD_STEPS = 1\n",
87
+ "GRAD_STEP_SIZE = 0.1\n",
88
+ "\n",
89
+ "# Metrics to compute\n",
90
+ "METRICS = [\"reward\", \"clip\", \"aesthetic\", \"pickscore\", \"hpsv2\", \"fid\"] # Add/remove as needed\n",
91
+ "\n",
92
+ "# Device\n",
93
+ "CUDA_DEVICE = 0\n",
94
+ "device = f\"cuda:{CUDA_DEVICE}\" if torch.cuda.is_available() else \"cpu\"\n",
95
+ "dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
96
+ "\n",
97
+ "# Output\n",
98
+ "OUTPUT_DIR = \"timestep_analysis_results\"\n",
99
+ "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
100
+ "\n",
101
+ "print(f\"Device: {device}\")\n",
102
+ "print(f\"Dataset: {DATASET_TYPE}\")\n",
103
+ "print(f\"Samples to analyze: {NUM_SAMPLES}\")\n",
104
+ "print(f\"Metrics: {METRICS}\")\n",
105
+ "print(f\"Output directory: {OUTPUT_DIR}\")"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "id": "1b1b6d02",
111
+ "metadata": {},
112
+ "source": [
113
+ "#### Load Dataset"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "id": "a74b2816",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "def load_validation_data(data_dir, max_samples=None):\n",
124
+ " \"\"\"Load COCO validation prompts and image paths.\"\"\"\n",
125
+ " data_dir = Path(data_dir)\n",
126
+ " val_json = data_dir / \"coco\" / \"caption_val.json\"\n",
127
+ " \n",
128
+ " if not val_json.exists():\n",
129
+ " raise FileNotFoundError(f\"Validation data not found at {val_json}\")\n",
130
+ " \n",
131
+ " with open(val_json, 'r') as f:\n",
132
+ " data = json.load(f)\n",
133
+ " \n",
134
+ " print(f\"Loaded JSON with {len(data)} entries\")\n",
135
+ " \n",
136
+ " # Validate that image folder exists\n",
137
+ " val_img_dir = data_dir / \"coco\" / \"images\" / \"val\"\n",
138
+ " if not val_img_dir.exists():\n",
139
+ " print(f\"Warning: Standard validation directory not found: {val_img_dir}\")\n",
140
+ " \n",
141
+ " # Parse data - img_path already contains \"images/val/\" prefix\n",
142
+ " prompts = []\n",
143
+ " image_paths = []\n",
144
+ " \n",
145
+ " for img_path, caption in data.items():\n",
146
+ " # Try the path as given (relative to data_dir/coco/)\n",
147
+ " full_path = data_dir / \"coco\" / img_path\n",
148
+ " if full_path.exists():\n",
149
+ " prompts.append(caption)\n",
150
+ " image_paths.append(str(full_path))\n",
151
+ " \n",
152
+ " print(f\"Found {len(prompts)} valid image-caption pairs\")\n",
153
+ " \n",
154
+ " if len(prompts) == 0:\n",
155
+ " print(f\"\\n⚠ WARNING: No valid images found!\")\n",
156
+ " print(f\"Debug information:\")\n",
157
+ " print(f\" JSON file: {val_json}\")\n",
158
+ " print(f\" JSON entries: {len(data)}\")\n",
159
+ " print(f\" Sample keys from JSON: {list(data.keys())[:3]}\")\n",
160
+ " \n",
161
+ " # Check if images exist at all\n",
162
+ " coco_dir = data_dir / \"coco\"\n",
163
+ " if coco_dir.exists():\n",
164
+ " print(f\" COCO dir exists: {coco_dir}\")\n",
165
+ " # List subdirectories\n",
166
+ " subdirs = [d.name for d in coco_dir.iterdir() if d.is_dir()]\n",
167
+ " print(f\" Subdirectories in COCO: {subdirs}\")\n",
168
+ " \n",
169
+ " # Try to find images\n",
170
+ " if val_img_dir.exists():\n",
171
+ " img_files = list(val_img_dir.glob(\"*.jpg\"))[:5]\n",
172
+ " print(f\" Sample images in val dir: {[f.name for f in img_files]}\")\n",
173
+ " \n",
174
+ " if max_samples and len(prompts) > 0:\n",
175
+ " prompts = prompts[:max_samples]\n",
176
+ " image_paths = image_paths[:max_samples]\n",
177
+ " \n",
178
+ " return prompts, image_paths\n",
179
+ "\n",
180
+ "# Load data\n",
181
+ "prompts, image_paths = load_validation_data(DATA_DIR, NUM_SAMPLES)\n",
182
+ "print(f\"\\n✓ Loaded {len(prompts)} samples\")\n",
183
+ "\n",
184
+ "if len(prompts) > 0:\n",
185
+ " print(f\"\\nSample prompts:\")\n",
186
+ " for i, prompt in enumerate(prompts[:3]):\n",
187
+ " print(f\" {i+1}. {prompt[:80]}...\")\n",
188
+ " print(f\"\\nSample image paths:\")\n",
189
+ " for i, path in enumerate(image_paths[:3]):\n",
190
+ " print(f\" {i+1}. {path}\")\n",
191
+ "else:\n",
192
+ " print(\"\\n❌ ERROR: No samples loaded! Please check your data directory structure.\")\n",
193
+ " print(\"Expected structure:\")\n",
194
+ " print(\" ./data/coco/caption_val.json\")\n",
195
+ " print(\" ./data/coco/images/val/*.jpg\")"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "id": "5ceae64a",
201
+ "metadata": {},
202
+ "source": [
203
+ "#### Load Models and Scorers"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "id": "43ad1f56",
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "# ============ MLP for Aesthetic Scoring ============\n",
214
+ "class MLP(nn.Module):\n",
215
+ " def __init__(self):\n",
216
+ " super().__init__()\n",
217
+ " self.layers = nn.Sequential(\n",
218
+ " nn.Linear(768, 1024),\n",
219
+ " nn.Dropout(0.2),\n",
220
+ " nn.Linear(1024, 128),\n",
221
+ " nn.Dropout(0.2),\n",
222
+ " nn.Linear(128, 64),\n",
223
+ " nn.Dropout(0.1),\n",
224
+ " nn.Linear(64, 16),\n",
225
+ " nn.Linear(16, 1),\n",
226
+ " )\n",
227
+ " \n",
228
+ " @torch.no_grad()\n",
229
+ " def forward(self, embed):\n",
230
+ " return self.layers(embed)\n",
231
+ "\n",
232
+ "class AestheticScorer(torch.nn.Module):\n",
233
+ " def __init__(self, dtype, device):\n",
234
+ " super().__init__()\n",
235
+ " self.clip = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
236
+ " self.processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
237
+ " self.mlp = MLP()\n",
238
+ " \n",
239
+ " aesthetic_path = \"../evaluation/sac+logos+ava1-l14-linearMSE.pth\"\n",
240
+ " if os.path.exists(aesthetic_path):\n",
241
+ " state_dict = torch.load(aesthetic_path, map_location='cpu')\n",
242
+ " self.mlp.load_state_dict(state_dict)\n",
243
+ " \n",
244
+ " self.dtype = dtype\n",
245
+ " self.to(device)\n",
246
+ " self.eval()\n",
247
+ " \n",
248
+ " @torch.no_grad()\n",
249
+ " def __call__(self, images):\n",
250
+ " if not isinstance(images, list):\n",
251
+ " images = [images]\n",
252
+ " inputs = self.processor(images=images, return_tensors=\"pt\", padding=True)\n",
253
+ " inputs = {k: v.to(self.clip.device) for k, v in inputs.items()}\n",
254
+ " image_embeds = self.clip.get_image_features(**inputs)\n",
255
+ " image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)\n",
256
+ " scores = self.mlp(image_embeds.float())\n",
257
+ " return scores.squeeze().cpu().numpy()\n",
258
+ "\n",
259
+ "print(\"Loading models...\")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "36a70595",
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": [
269
+ "# Load Reward Model\n",
270
+ "print(\"Loading reward model...\")\n",
271
+ "reward_model = LRMRewardModel(\n",
272
+ " pretrained_model_name_or_path=BASE_MODEL,\n",
273
+ " lrm_model_path=LRM_MODEL,\n",
274
+ " guidance_scale=CFG_SCALE,\n",
275
+ " device=device\n",
276
+ ")\n",
277
+ "if dtype == torch.float16:\n",
278
+ " reward_model = reward_model.half()\n",
279
+ "reward_model.eval()\n",
280
+ "print(\"✓ Reward model loaded\")\n",
281
+ "\n",
282
+ "# Load Pipeline\n",
283
+ "print(\"\\nLoading diffusion pipeline...\")\n",
284
+ "if MODEL_VARIANT == \"origin\":\n",
285
+ " base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
286
+ " BASE_MODEL, torch_dtype=dtype, safety_checker=None\n",
287
+ " )\n",
288
+ "elif MODEL_VARIANT == \"spo\":\n",
289
+ " base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
290
+ " 'SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep',\n",
291
+ " torch_dtype=dtype, safety_checker=None\n",
292
+ " )\n",
293
+ " CFG_SCALE = 5.0\n",
294
+ "elif MODEL_VARIANT == \"diffusion_dpo\":\n",
295
+ " unet = UNet2DConditionModel.from_pretrained(\n",
296
+ " 'mhdang/dpo-sd1.5-text2image-v1', subfolder=\"unet\", torch_dtype=dtype\n",
297
+ " )\n",
298
+ " base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
299
+ " BASE_MODEL, torch_dtype=dtype, safety_checker=None, unet=unet\n",
300
+ " )\n",
301
+ "elif MODEL_VARIANT == \"lpo\":\n",
302
+ " unet = UNet2DConditionModel.from_pretrained(\n",
303
+ " 'casiatao/LPO', subfolder=\"lpo_sd15_merge/unet\", torch_dtype=dtype\n",
304
+ " )\n",
305
+ " base_pipeline = StableDiffusionPipeline.from_pretrained(\n",
306
+ " BASE_MODEL, torch_dtype=dtype, safety_checker=None, unet=unet\n",
307
+ " )\n",
308
+ " CFG_SCALE = 5.0\n",
309
+ "\n",
310
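+ "# Reuse the base pipeline's components in the gradient-ascent pipeline, and\n",
+ "# swap the scheduler to DDIM (deterministic with the default eta=0)\n",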
+ "pipeline = StableDiffusionGradientAscentPipeline(**base_pipeline.components)\n",
311
+ "pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\n",
312
+ "pipeline = pipeline.to(device)\n",
313
+ "pipeline.set_reward_model(reward_model)\n",
314
+ "print(\"✓ Pipeline loaded\")"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "id": "4e18f075",
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "# Load Metric Scorers\n",
325
+ "print(\"\\nLoading metric scorers...\")\n",
326
+ "\n",
327
+ "clip_scorer = None\n",
328
+ "aesthetic_scorer = None\n",
329
+ "pick_scorer = None\n",
330
+ "hpsv2_scorer = None\n",
331
+ "imagereward_scorer = None\n",
332
+ "\n",
333
+ "if \"clip\" in METRICS:\n",
334
+ " print(\" Loading CLIP scorer...\")\n",
335
+ " clip_scorer = CLIPScore(model_name_or_path=\"openai/clip-vit-base-patch16\").to(device)\n",
336
+ " print(\" ✓ CLIP scorer loaded\")\n",
337
+ "\n",
338
+ "if \"aesthetic\" in METRICS:\n",
339
+ " print(\" Loading Aesthetic scorer...\")\n",
340
+ " aesthetic_scorer = AestheticScorer(dtype, device)\n",
341
+ " print(\" ✓ Aesthetic scorer loaded\")\n",
342
+ "\n",
343
+ "if \"pickscore\" in METRICS:\n",
344
+ " print(\" Loading PickScore scorer...\")\n",
345
+ " try:\n",
346
+ " pick_scorer = PickScorer(device=device, dtype=dtype)\n",
347
+ " print(\" ✓ PickScore loaded\")\n",
348
+ " except Exception as e:\n",
349
+ " print(f\" ✗ PickScore failed: {e}\")\n",
350
+ " METRICS.remove(\"pickscore\")\n",
351
+ "\n",
352
+ "if \"hpsv2\" in METRICS:\n",
353
+ " print(\" Loading HPSv2 scorer...\")\n",
354
+ " try:\n",
355
+ " hpsv2_scorer = HPSv2Scorer(device=device, dtype=dtype)\n",
356
+ " print(\" ✓ HPSv2 loaded\")\n",
357
+ " except Exception as e:\n",
358
+ " print(f\" ✗ HPSv2 failed: {e}\")\n",
359
+ " METRICS.remove(\"hpsv2\")\n",
360
+ "\n",
361
+ "if \"imagereward\" in METRICS:\n",
362
+ " print(\" Loading ImageReward scorer...\")\n",
363
+ " try:\n",
364
+ " imagereward_scorer = load_imagereward(device=device)\n",
365
+ " print(\" ✓ ImageReward loaded\")\n",
366
+ " except Exception as e:\n",
367
+ " print(f\" ✗ ImageReward failed: {e}\")\n",
368
+ " METRICS.remove(\"imagereward\")\n",
369
+ "\n",
370
+ "print(f\"\\n✓ Active metrics: {METRICS}\")"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "markdown",
375
+ "id": "70ac047b",
376
+ "metadata": {},
377
+ "source": [
378
+ "#### Configure Gradient Ascent"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "id": "05996448",
385
+ "metadata": {},
386
+ "outputs": [],
387
+ "source": [
388
+ "# Configure gradient ascent\n",
389
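+ "# The guidance perturbs latents toward higher reward-model scores during\n",
+ "# denoising: grad_timestep_range bounds where it applies, while num_grad_steps\n",
+ "# and grad_step_size set how many ascent updates are taken and how large they are.\n",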
+ "if GRAD_CONFIG:\n",
390
+ " print(f\"Loading gradient ascent config: {GRAD_CONFIG}\")\n",
391
+ " grad_config = get_config(GRAD_CONFIG)\n",
392
+ " print(f\"Config: {grad_config}\")\n",
393
+ "else:\n",
394
+ " grad_config = {\n",
395
+ " \"grad_timestep_range\": (GRAD_RANGE_START, GRAD_RANGE_END),\n",
396
+ " \"num_grad_steps\": GRAD_STEPS,\n",
397
+ " \"grad_step_size\": GRAD_STEP_SIZE,\n",
398
+ " }\n",
399
+ " print(f\"Manual gradient ascent configuration: {grad_config}\")\n",
400
+ "\n",
401
+ "pipeline.enable_gradient_ascent(**grad_config)\n",
402
+ "print(\"\\n✓ Gradient ascent enabled\")"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "markdown",
407
+ "id": "1f82c3df",
408
+ "metadata": {},
409
+ "source": [
410
+ "#### Timestep Analysis Functions"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "e836d8f2",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "def latents_to_images(latents, vae):\n",
421
+ " \"\"\"Convert latents to PIL images.\"\"\"\n",
422
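+ "    # 0.18215 is the SD v1.x VAE latent scaling factor; undo it before decoding\n",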
+ " latents = 1 / 0.18215 * latents\n",
423
+ " with torch.no_grad():\n",
424
+ " images = vae.decode(latents).sample\n",
425
+ " images = (images / 2 + 0.5).clamp(0, 1)\n",
426
+ " images = images.cpu().permute(0, 2, 3, 1).numpy()\n",
427
+ " images = (images * 255).round().astype(\"uint8\")\n",
428
+ " pil_images = [Image.fromarray(image) for image in images]\n",
429
+ " return pil_images\n",
430
+ "\n",
431
+ "\n",
432
+ "def compute_metrics_for_image(image, prompt, reference_image=None):\n",
433
+ " \"\"\"Compute all metrics for a single image.\"\"\"\n",
434
+ " metrics = {}\n",
435
+ " \n",
436
+ " # CLIP Score\n",
437
+ " if clip_scorer is not None:\n",
438
+ " img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
439
+ " with torch.no_grad():\n",
440
+ " clip_score = clip_scorer(img_tensor, prompt).item()\n",
441
+ " metrics['clip'] = clip_score\n",
442
+ " \n",
443
+ " # Aesthetic Score\n",
444
+ " if aesthetic_scorer is not None:\n",
445
+ " aesthetic_score = aesthetic_scorer([image])\n",
446
+ " if isinstance(aesthetic_score, np.ndarray):\n",
447
+ " aesthetic_score = aesthetic_score.item()\n",
448
+ " metrics['aesthetic'] = aesthetic_score\n",
449
+ " \n",
450
+ " # PickScore\n",
451
+ " if pick_scorer is not None:\n",
452
+ " pick_score = pick_scorer.score(prompt, [image])[0]\n",
453
+ " metrics['pickscore'] = pick_score\n",
454
+ " \n",
455
+ " # HPSv2\n",
456
+ " if hpsv2_scorer is not None:\n",
457
+ " hpsv2_score = hpsv2_scorer.score(prompt, [image])[0]\n",
458
+ " metrics['hpsv2'] = hpsv2_score\n",
459
+ " \n",
460
+ " # ImageReward\n",
461
+ " if imagereward_scorer is not None:\n",
462
+ " imagereward_score = imagereward_scorer.score(prompt, [image])[0]\n",
463
+ " metrics['imagereward'] = imagereward_score\n",
464
+ " \n",
465
+ " # FID (if reference image provided)\n",
466
+ " if reference_image is not None:\n",
467
+ " try:\n",
468
+ " fid_metric = FrechetInceptionDistance(normalize=True).to(device)\n",
469
+ " \n",
470
+ " # Process reference image\n",
471
+ " ref_img = Image.open(reference_image).convert('RGB').resize((299, 299))\n",
472
+ " ref_tensor = torch.from_numpy(np.array(ref_img)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
473
+ " \n",
474
+ " # Process generated image\n",
475
+ " gen_img = image.resize((299, 299))\n",
476
+ " gen_tensor = torch.from_numpy(np.array(gen_img)).permute(2, 0, 1).unsqueeze(0).to(device)\n",
477
+ " \n",
478
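+ "            # torchmetrics' FID needs at least two samples per distribution to\n",
+ "            # estimate a covariance, so duplicate each single image along batch\n",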
+ " if ref_tensor.size(0) == 1:\n",
479
+ " ref_tensor = ref_tensor.repeat(2, 1, 1, 1)\n",
480
+ " if gen_tensor.size(0) == 1:\n",
481
+ " gen_tensor = gen_tensor.repeat(2, 1, 1, 1)\n",
482
+ " \n",
483
+ " fid_metric.update(ref_tensor, real=True)\n",
484
+ " fid_metric.update(gen_tensor, real=False)\n",
485
+ " \n",
486
+ " fid_score = fid_metric.compute().item()/10\n",
487
+ " metrics['fid'] = fid_score\n",
488
+ " except Exception as e:\n",
489
+ " print(f\"FID computation failed: {e}\")\n",
490
+ " \n",
491
+ " return metrics\n",
492
+ "\n",
493
+ "\n",
494
+ "def analyze_sample_timesteps(prompt, reference_image, sample_idx):\n",
495
+ " \"\"\"\n",
496
+ " Generate images and track metrics at each timestep.\n",
497
+ " Returns timestep-wise metrics and intermediate images.\n",
498
+ " \"\"\"\n",
499
+ " print(f\"\\n{'='*70}\")\n",
500
+ " print(f\"Analyzing Sample {sample_idx + 1}\")\n",
501
+ " print(f\"Prompt: {prompt[:80]}...\")\n",
502
+ " print(f\"{'='*70}\")\n",
503
+ " \n",
504
+ " # Storage for results\n",
505
+ " timestep_metrics = {\n",
506
+ " 'timesteps': [],\n",
507
+ " 'reward': [],\n",
508
+ " 'clip': [],\n",
509
+ " 'aesthetic': [],\n",
510
+ " 'pickscore': [],\n",
511
+ " 'hpsv2': [],\n",
512
+ " 'imagereward': [],\n",
513
+ " 'fid': []\n",
514
+ " }\n",
515
+ " intermediate_images = []\n",
516
+ " \n",
517
+ " # Reset gradient stats\n",
518
+ " if hasattr(pipeline, 'grad_guidance'):\n",
519
+ " pipeline.grad_guidance.reset_statistics()\n",
520
+ " \n",
521
+ " # Modified pipeline call to capture intermediate latents\n",
522
+ " generator = torch.Generator(device=device).manual_seed(SEED + sample_idx)\n",
523
+ " \n",
524
+ " # We'll manually step through the denoising process\n",
525
+ " pipeline.set_progress_bar_config(disable=True)\n",
526
+ " \n",
527
+ " # Prepare inputs\n",
528
+ " height = pipeline.unet.config.sample_size * pipeline.vae_scale_factor\n",
529
+ " width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor\n",
530
+ " \n",
531
+ " # Encode prompt\n",
532
+ " text_embeddings = pipeline._encode_prompt(\n",
533
+ " prompt, device, 1, True, None\n",
534
+ " )\n",
535
+ " \n",
536
+ " # Prepare timesteps\n",
537
+ " pipeline.scheduler.set_timesteps(NUM_INFERENCE_STEPS, device=device)\n",
538
+ " timesteps = pipeline.scheduler.timesteps\n",
539
+ " \n",
540
+ " # Prepare latents\n",
541
+ " shape = (1, pipeline.unet.config.in_channels, height // 8, width // 8)\n",
542
+ " latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)\n",
543
+ " latents = latents * pipeline.scheduler.init_noise_sigma\n",
544
+ " \n",
545
+ " # Denoising loop with metric tracking\n",
546
+ " for i, t in enumerate(tqdm(timesteps, desc=\"Denoising steps\")):\n",
547
+ " # Apply gradient ascent if enabled\n",
548
+ " if hasattr(pipeline, 'grad_guidance') and pipeline.grad_guidance:\n",
549
+ " if pipeline.grad_guidance.should_apply_gradient(t.item()):\n",
550
+ " latents, grad_stats = pipeline.grad_guidance.apply_gradient_ascent(\n",
551
+ " latents, prompt, t.item(), verbose=False,\n",
552
+ " total_denoising_steps=len(timesteps)\n",
553
+ " )\n",
554
+ " \n",
555
+ " # Expand latents for classifier free guidance\n",
556
+ " latent_model_input = torch.cat([latents] * 2)\n",
557
+ " latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)\n",
558
+ " \n",
559
+ " # Predict noise\n",
560
+ " with torch.no_grad():\n",
561
+ " noise_pred = pipeline.unet(\n",
562
+ " latent_model_input,\n",
563
+ " t,\n",
564
+ " encoder_hidden_states=text_embeddings,\n",
565
+ " ).sample\n",
566
+ " \n",
567
+ " # Perform guidance\n",
568
+ " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
569
+ " noise_pred = noise_pred_uncond + CFG_SCALE * (noise_pred_text - noise_pred_uncond)\n",
570
+ " \n",
571
+ " # Compute previous noisy sample\n",
572
+ " latents = pipeline.scheduler.step(noise_pred, t, latents).prev_sample\n",
573
+ " \n",
574
+ " # Decode latents to image every few steps\n",
575
+ " if i % 5 == 0 or i == len(timesteps) - 1:\n",
576
+ " # Convert to image\n",
577
+ " images = latents_to_images(latents, pipeline.vae)\n",
578
+ " image = images[0]\n",
579
+ " \n",
580
+ " # Compute reward\n",
581
+ " with torch.no_grad():\n",
582
+ " reward = reward_model.get_reward_score(latents, prompt, t.item())\n",
583
+ " reward_val = reward.mean().item() if reward.numel() > 1 else reward.item()\n",
584
+ " \n",
585
+ " # Compute other metrics\n",
586
+ " metrics = compute_metrics_for_image(image, prompt, reference_image)\n",
587
+ " \n",
588
+ " # Store results\n",
589
+ " timestep_metrics['timesteps'].append(t.item())\n",
590
+ " timestep_metrics['reward'].append(reward_val)\n",
591
+ " \n",
592
+ " for metric_name in ['clip', 'aesthetic', 'pickscore', 'hpsv2', 'imagereward', 'fid']:\n",
593
+ " if metric_name in metrics:\n",
594
+ " timestep_metrics[metric_name].append(metrics[metric_name])\n",
595
+ " else:\n",
596
+ " timestep_metrics[metric_name].append(None)\n",
597
+ " \n",
598
+ " intermediate_images.append(image)\n",
599
+ " \n",
600
+ " print(f\" Step {i}/{len(timesteps)} | t={t.item():.0f} | Reward={reward_val:.4f}\")\n",
601
+ " \n",
602
+ " # Final image\n",
603
+ " final_images = latents_to_images(latents, pipeline.vae)\n",
604
+ " final_image = final_images[0]\n",
605
+ " \n",
606
+ " pipeline.set_progress_bar_config(disable=False)\n",
607
+ " \n",
608
+ " return timestep_metrics, intermediate_images, final_image\n",
609
+ "\n",
610
+ "print(\"✓ Analysis functions defined\")"
611
+ ]
612
+ },
613
+ {
614
+ "cell_type": "markdown",
615
+ "id": "fc089bfd",
616
+ "metadata": {},
617
+ "source": [
618
+ "#### Run Timestep Analysis"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "execution_count": null,
624
+ "id": "67c25164",
625
+ "metadata": {},
626
+ "outputs": [],
627
+ "source": [
628
+ "# Run analysis for all samples\n",
629
+ "all_results = []\n",
630
+ "\n",
631
+ "for idx in range(len(prompts)):\n",
632
+ " prompt = prompts[idx]\n",
633
+ " reference_image = image_paths[idx]\n",
634
+ " \n",
635
+ " # Analyze this sample\n",
636
+ " metrics, images, final_image = analyze_sample_timesteps(prompt, reference_image, idx)\n",
637
+ " \n",
638
+ " # Store results\n",
639
+ " all_results.append({\n",
640
+ " 'prompt': prompt,\n",
641
+ " 'reference_image': reference_image,\n",
642
+ " 'metrics': metrics,\n",
643
+ " 'intermediate_images': images,\n",
644
+ " 'final_image': final_image\n",
645
+ " })\n",
646
+ " \n",
647
+ " # Save intermediate results\n",
648
+ " sample_dir = Path(OUTPUT_DIR) / f\"sample_{idx+1}\"\n",
649
+ " sample_dir.mkdir(exist_ok=True)\n",
650
+ " \n",
651
+ " # Save final image\n",
652
+ " final_image.save(sample_dir / \"final_image.png\")\n",
653
+ " \n",
654
+ " # Save all intermediate images\n",
655
+ " images_dir = sample_dir / \"intermediate_images\"\n",
656
+ " images_dir.mkdir(exist_ok=True)\n",
657
+ " for img_idx, img in enumerate(images):\n",
658
+ " t_val = metrics['timesteps'][img_idx]\n",
659
+ " img.save(images_dir / f\"step_{img_idx:03d}_t{int(t_val)}.png\")\n",
660
+ " \n",
661
+ " # Save metrics\n",
662
+ " with open(sample_dir / \"metrics.json\", 'w') as f:\n",
663
+ " json.dump(metrics, f, indent=2)\n",
664
+ " \n",
665
+ " print(f\"✓ Saved {len(images)} intermediate images for sample {idx+1}\")\n",
666
+ "\n",
667
+ "print(\"\\n✓ Analysis complete for all samples\")"
668
+ ]
669
+ },
670
+ {
671
+ "cell_type": "markdown",
672
+ "id": "10dd749d",
673
+ "metadata": {},
674
+ "source": [
675
+ "#### Visualization: Intermediate Images"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": null,
681
+ "id": "bb32eaa1",
682
+ "metadata": {},
683
+ "outputs": [],
684
+ "source": [
685
+ "def plot_intermediate_images(results, sample_idx, max_images=8):\n",
686
+ " \"\"\"Display intermediate images for a sample showing evolution over timesteps.\"\"\"\n",
687
+ " result = results[sample_idx]\n",
688
+ " images = result['intermediate_images']\n",
689
+ " metrics = result['metrics']\n",
690
+ " timesteps = metrics['timesteps']\n",
691
+ " rewards = metrics['reward']\n",
692
+ " \n",
693
+ " # Select evenly spaced images if too many\n",
694
+ " if len(images) > max_images:\n",
695
+ " indices = np.linspace(0, len(images)-1, max_images, dtype=int)\n",
696
+ " selected_images = [images[i] for i in indices]\n",
697
+ " selected_timesteps = [timesteps[i] for i in indices]\n",
698
+ " selected_rewards = [rewards[i] for i in indices]\n",
699
+ " else:\n",
700
+ " selected_images = images\n",
701
+ " selected_timesteps = timesteps\n",
702
+ " selected_rewards = rewards\n",
703
+ " \n",
704
+ " n_images = len(selected_images)\n",
705
+ " cols = 5\n",
706
+ " rows = (n_images + cols - 1) // cols\n",
707
+ " \n",
708
+ " fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))\n",
709
+ " axes = axes.flatten() if n_images > 1 else [axes]\n",
710
+ " \n",
711
+ " fig.suptitle(f\"Sample {sample_idx + 1}: Image Evolution Over Timesteps\\n\"\n",
712
+ " f\"Prompt: {result['prompt'][:80]}...\", \n",
713
+ " fontsize=12, fontweight='bold')\n",
714
+ " \n",
715
+ " for idx, (img, t, r) in enumerate(zip(selected_images, selected_timesteps, selected_rewards)):\n",
716
+ " ax = axes[idx]\n",
717
+ " ax.imshow(img)\n",
718
+ " ax.axis('off')\n",
719
+ " ax.set_title(f\"t={t:.0f}\\nReward={r:.3f}\", fontsize=10)\n",
720
+ " \n",
721
+ " # Hide unused subplots\n",
722
+ " for idx in range(n_images, len(axes)):\n",
723
+ " axes[idx].axis('off')\n",
724
+ " \n",
725
+ " plt.tight_layout()\n",
726
+ " \n",
727
+ " # Save plot\n",
728
+ " sample_dir = Path(OUTPUT_DIR) / f\"sample_{sample_idx+1}\"\n",
729
+ " plt.savefig(sample_dir / \"image_evolution.png\", dpi=150, bbox_inches='tight')\n",
730
+ " plt.show()\n",
731
+ "\n",
732
+ "# Plot intermediate images for all samples\n",
733
+ "for idx in range(len(all_results)):\n",
734
+ " plot_intermediate_images(all_results, idx)"
735
+ ]
736
+ },
737
+ {
738
+ "cell_type": "code",
739
+ "execution_count": null,
740
+ "id": "878b7686",
741
+ "metadata": {},
742
+ "outputs": [],
743
+ "source": [
744
+ "def plot_final_images_grid(results):\n",
745
+ " \"\"\"Display all final images in a grid for comparison.\"\"\"\n",
746
+ " n_samples = len(results)\n",
747
+ " cols = min(10, n_samples)\n",
748
+ " rows = (n_samples + cols - 1) // cols\n",
749
+ " \n",
750
+ " fig, axes = plt.subplots(rows, cols, figsize=(5*cols, 5*rows))\n",
751
+ " if n_samples == 1:\n",
752
+ " axes = [axes]\n",
753
+ " else:\n",
754
+ " axes = axes.flatten()\n",
755
+ " \n",
756
+ " fig.suptitle(\"Final Generated Images: All Samples\", fontsize=14, fontweight='bold')\n",
757
+ " \n",
758
+ " for idx, result in enumerate(results):\n",
759
+ " ax = axes[idx]\n",
760
+ " ax.imshow(result['final_image'])\n",
761
+ " ax.axis('off')\n",
762
+ " \n",
763
+ " # Get final metrics\n",
764
+ " metrics = result['metrics']\n",
765
+ " reward = metrics['reward'][-1] if metrics['reward'] else 0\n",
766
+ " clip_score = metrics['clip'][-1] if 'clip' in metrics and metrics['clip'] and metrics['clip'][-1] is not None else 0\n",
767
+ " \n",
768
+ " ax.set_title(f\"Sample {idx+1}\\nReward: {reward:.3f} | CLIP: {clip_score:.3f}\\n{result['prompt'][:40]}...\", \n",
769
+ " fontsize=9)\n",
770
+ " \n",
771
+ " # Hide unused subplots\n",
772
+ " for idx in range(n_samples, len(axes)):\n",
773
+ " axes[idx].axis('off')\n",
774
+ " \n",
775
+ " plt.tight_layout()\n",
776
+ " plt.savefig(Path(OUTPUT_DIR) / \"final_images_grid.png\", dpi=150, bbox_inches='tight')\n",
777
+ " plt.show()\n",
778
+ "\n",
779
+ "# Display final images\n",
780
+ "plot_final_images_grid(all_results)"
781
+ ]
782
+ },
783
+ {
784
+ "cell_type": "markdown",
785
+ "id": "bc5a96a6",
786
+ "metadata": {},
787
+ "source": [
788
+ "#### Debug: Check Data"
789
+ ]
790
+ },
791
+ {
792
+ "cell_type": "code",
793
+ "execution_count": null,
794
+ "id": "40008ed4",
795
+ "metadata": {},
796
+ "outputs": [],
797
+ "source": [
798
+ "# Check if data was collected properly\n",
799
+ "print(\"Data Collection Summary:\")\n",
800
+ "print(\"=\"*70)\n",
801
+ "\n",
802
+ "for idx, result in enumerate(all_results):\n",
803
+ " print(f\"\\nSample {idx+1}:\")\n",
804
+ " print(f\" Prompt: {result['prompt'][:60]}...\")\n",
805
+ " \n",
806
+ " metrics = result['metrics']\n",
807
+ " print(f\" Number of timesteps tracked: {len(metrics['timesteps'])}\")\n",
808
+ " print(f\" Number of intermediate images: {len(result['intermediate_images'])}\")\n",
809
+ " \n",
810
+ " # Check which metrics have data\n",
811
+ " for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2', 'fid']:\n",
812
+ " if metric_name in metrics:\n",
813
+ " non_none = [v for v in metrics[metric_name] if v is not None]\n",
814
+ " if non_none:\n",
815
+ " print(f\" {metric_name.upper()}: {len(non_none)} values | \"\n",
816
+ " f\"Range: [{min(non_none):.3f}, {max(non_none):.3f}]\")\n",
817
+ " else:\n",
818
+ " print(f\" {metric_name.upper()}: No valid data\")\n",
819
+ " \n",
820
+ " # Check timestep range\n",
821
+ " if metrics['timesteps']:\n",
822
+ " print(f\" Timestep range: [{max(metrics['timesteps']):.0f}, {min(metrics['timesteps']):.0f}]\")\n",
823
+ "\n",
824
+ "print(\"\\n\" + \"=\"*70)"
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "markdown",
829
+ "id": "5bdacf18",
830
+ "metadata": {},
831
+ "source": [
832
+ "#### Visualization: Metrics Evolution"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "code",
837
+ "execution_count": null,
838
+ "id": "be2fc746",
839
+ "metadata": {},
840
+ "outputs": [],
841
+ "source": [
842
+ "def plot_metrics_evolution(results, sample_idx):\n",
843
+ " \"\"\"Plot all metrics evolution in a single row for one sample.\"\"\"\n",
844
+ " result = results[sample_idx]\n",
845
+ " metrics = result['metrics']\n",
846
+ " timesteps = metrics['timesteps']\n",
847
+ " \n",
848
+ " # Filter metrics to plot (exclude None values)\n",
849
+ " metrics_to_plot = []\n",
850
+ " for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2', 'imagereward', 'fid']:\n",
851
+ " if metric_name in metrics and any(v is not None for v in metrics[metric_name]):\n",
852
+ " metrics_to_plot.append(metric_name)\n",
853
+ " \n",
854
+ " n_metrics = len(metrics_to_plot)\n",
855
+ " \n",
856
+ " # Create figure with subplots in a row\n",
857
+ " fig, axes = plt.subplots(1, n_metrics, figsize=(5*n_metrics, 4))\n",
858
+ " if n_metrics == 1:\n",
859
+ " axes = [axes]\n",
860
+ " \n",
861
+ " fig.suptitle(f\"Sample {sample_idx + 1}: Metrics Evolution Across Timesteps\\n\"\n",
862
+ " f\"Prompt: {result['prompt'][:80]}...\", fontsize=12, fontweight='bold')\n",
863
+ " \n",
864
+ " colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown', 'pink']\n",
865
+ " \n",
866
+ " for idx, metric_name in enumerate(metrics_to_plot):\n",
867
+ " ax = axes[idx]\n",
868
+ " values = [v for v in metrics[metric_name] if v is not None]\n",
869
+ " valid_timesteps = [t for t, v in zip(timesteps, metrics[metric_name]) if v is not None]\n",
870
+ " \n",
871
+ " if values:\n",
872
+ " ax.plot(valid_timesteps, values, marker='o', linewidth=2, \n",
873
+ " color=colors[idx % len(colors)], label=metric_name.upper())\n",
874
+ " ax.set_xlabel('Timestep', fontsize=10)\n",
875
+ " ax.set_ylabel(metric_name.upper(), fontsize=10)\n",
876
+ " ax.set_title(f\"{metric_name.upper()}\\n{values[0]:.3f} → {values[-1]:.3f}\", fontsize=10)\n",
877
+ " ax.grid(True, alpha=0.3)\n",
878
+ " ax.invert_xaxis() # Timesteps go from high to low\n",
879
+ " \n",
880
+ " # Add improvement annotation\n",
881
+ " improvement = values[-1] - values[0]\n",
882
+ " color = 'green' if improvement > 0 else 'red'\n",
883
+ " if metric_name == 'fid': # Lower is better for FID\n",
884
+ " color = 'green' if improvement < 0 else 'red'\n",
885
+ " ax.text(0.05, 0.95, f\"Δ: {improvement:+.3f}\", \n",
886
+ " transform=ax.transAxes, fontsize=9, verticalalignment='top',\n",
887
+ " bbox=dict(boxstyle='round', facecolor=color, alpha=0.3))\n",
888
+ " \n",
889
+ " plt.tight_layout()\n",
890
+ " \n",
891
+ " # Save plot\n",
892
+ " sample_dir = Path(OUTPUT_DIR) / f\"sample_{sample_idx+1}\"\n",
893
+ " plt.savefig(sample_dir / \"metrics_evolution.png\", dpi=150, bbox_inches='tight')\n",
894
+ " plt.show()\n",
895
+ "\n",
896
+ "# Plot for all samples\n",
897
+ "for idx in range(len(all_results)):\n",
898
+ " plot_metrics_evolution(all_results, idx)"
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "markdown",
903
+ "id": "45b7abb7",
904
+ "metadata": {},
905
+ "source": [
906
+ "#### Visualization: Compare All Samples"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": null,
912
+ "id": "7d3df7e0",
913
+ "metadata": {},
914
+ "outputs": [],
915
+ "source": [
916
+ "def plot_all_samples_comparison(results):\n",
917
+ " \"\"\"Plot metric evolution for all samples in a grid.\"\"\"\n",
918
+ " # Choose key metrics to compare\n",
919
+ " key_metrics = ['reward', 'clip', 'aesthetic', 'fid']\n",
920
+ " n_metrics = len(key_metrics)\n",
921
+ " n_samples = len(results)\n",
922
+ " \n",
923
+ " fig, axes = plt.subplots(n_metrics, 1, figsize=(14, 4*n_metrics))\n",
924
+ " if n_metrics == 1:\n",
925
+ " axes = [axes]\n",
926
+ " \n",
927
+ " fig.suptitle(\"Convergence Analysis: All Samples Comparison\", fontsize=14, fontweight='bold')\n",
928
+ " \n",
929
+ " colors = plt.cm.tab10(np.linspace(0, 1, n_samples))\n",
930
+ " \n",
931
+ " for metric_idx, metric_name in enumerate(key_metrics):\n",
932
+ " ax = axes[metric_idx]\n",
933
+ " \n",
934
+ " for sample_idx, result in enumerate(results):\n",
935
+ " metrics = result['metrics']\n",
936
+ " timesteps = metrics['timesteps']\n",
937
+ " values = [v for v in metrics[metric_name] if v is not None]\n",
938
+ " valid_timesteps = [t for t, v in zip(timesteps, metrics[metric_name]) if v is not None]\n",
939
+ " \n",
940
+ " if values:\n",
941
+ " ax.plot(valid_timesteps, values, marker='o', linewidth=2, \n",
942
+ " color=colors[sample_idx], label=f\"Sample {sample_idx+1}\", alpha=0.7)\n",
943
+ " \n",
944
+ " ax.set_xlabel('Timestep', fontsize=11)\n",
945
+ " ax.set_ylabel(metric_name.upper(), fontsize=11)\n",
946
+ " ax.set_title(f\"{metric_name.upper()} Evolution\", fontsize=12, fontweight='bold')\n",
947
+ " ax.grid(True, alpha=0.3)\n",
948
+ " ax.invert_xaxis()\n",
949
+ " ax.legend(loc='best', fontsize=9)\n",
950
+ " \n",
951
+ " plt.tight_layout()\n",
952
+ " plt.savefig(Path(OUTPUT_DIR) / \"all_samples_comparison.png\", dpi=150, bbox_inches='tight')\n",
953
+ " plt.show()\n",
954
+ "\n",
955
+ "# Plot comparison\n",
956
+ "plot_all_samples_comparison(all_results)"
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "markdown",
961
+ "id": "44f97581",
962
+ "metadata": {},
963
+ "source": [
964
+ "#### Convergence Analysis"
965
+ ]
966
+ },
967
+ {
968
+ "cell_type": "code",
969
+ "execution_count": null,
970
+ "id": "9dffd878",
971
+ "metadata": {},
972
+ "outputs": [],
973
+ "source": [
974
+ "def analyze_convergence(results):\n",
975
+ " \"\"\"Analyze convergence behavior across samples.\"\"\"\n",
976
+ " print(\"\\n\" + \"=\"*70)\n",
977
+ " print(\"CONVERGENCE ANALYSIS\")\n",
978
+ " print(\"=\"*70)\n",
979
+ " \n",
980
+ " for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2']:\n",
981
+ " print(f\"\\n{metric_name.upper()} Convergence:\")\n",
982
+ " print(\"-\" * 50)\n",
983
+ " \n",
984
+ " improvements = []\n",
985
+ " initial_values = []\n",
986
+ " final_values = []\n",
987
+ " \n",
988
+ " for idx, result in enumerate(results):\n",
989
+ " metrics = result['metrics']\n",
990
+ " if metric_name in metrics:\n",
991
+ " values = [v for v in metrics[metric_name] if v is not None]\n",
992
+ " if values:\n",
993
+ " initial = values[0]\n",
994
+ " final = values[-1]\n",
995
+ " improvement = final - initial\n",
996
+ " \n",
997
+ " initial_values.append(initial)\n",
998
+ " final_values.append(final)\n",
999
+ " improvements.append(improvement)\n",
1000
+ " \n",
1001
+ " print(f\" Sample {idx+1}: {initial:.4f} → {final:.4f} ({improvement:+.4f})\")\n",
1002
+ " \n",
1003
+ " if improvements:\n",
1004
+ " avg_improvement = np.mean(improvements)\n",
1005
+ " std_improvement = np.std(improvements)\n",
1006
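+ "            # Convergence heuristic: the spread of per-sample improvements must\n",
+ "            # stay below 10% of the mean improvement magnitude\n",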
+ " print(f\"\\n Average Improvement: {avg_improvement:+.4f} (±{std_improvement:.4f})\")\n",
1007
+ " print(f\" Converged: {'YES' if std_improvement < 0.1 * abs(avg_improvement) else 'NO'}\")\n",
1008
+ " \n",
1009
+ " # Summary\n",
1010
+ " print(\"\\n\" + \"=\"*70)\n",
1011
+ " print(\"SUMMARY\")\n",
1012
+ " print(\"=\"*70)\n",
1013
+ " print(f\"Total samples analyzed: {len(results)}\")\n",
1014
+ " print(f\"Gradient ascent config: {grad_config}\")\n",
1015
+ " print(f\"\\nConclusion: Analyze the plots above to determine convergence behavior.\")\n",
1016
+ " print(f\"Look for:\")\n",
1017
+ " print(f\" 1. Metrics plateauing (flattening out)\")\n",
1018
+ " print(f\" 2. Consistent improvement across samples\")\n",
1019
+ " print(f\" 3. Low variance in final metric values\")\n",
1020
+ "\n",
1021
+ "analyze_convergence(all_results)"
1022
+ ]
1023
+ },
1024
+ {
1025
+ "cell_type": "markdown",
1026
+ "id": "d263be5f",
1027
+ "metadata": {},
1028
+ "source": [
1029
+ "#### Save Results Summary"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "cell_type": "code",
1034
+ "execution_count": null,
1035
+ "id": "5434e7c0",
1036
+ "metadata": {},
1037
+ "outputs": [],
1038
+ "source": [
1039
+ "# Save comprehensive summary\n",
1040
+ "summary = {\n",
1041
+ " 'config': {\n",
1042
+ " 'num_samples': NUM_SAMPLES,\n",
1043
+ " 'num_inference_steps': NUM_INFERENCE_STEPS,\n",
1044
+ " 'cfg_scale': CFG_SCALE,\n",
1045
+ " 'grad_config': grad_config,\n",
1046
+ " 'metrics': METRICS,\n",
1047
+ " 'model_variant': MODEL_VARIANT\n",
1048
+ " },\n",
1049
+ " 'samples': []\n",
1050
+ "}\n",
1051
+ "\n",
1052
+ "for idx, result in enumerate(all_results):\n",
1053
+ " metrics = result['metrics']\n",
1054
+ " sample_summary = {\n",
1055
+ " 'sample_id': idx + 1,\n",
1056
+ " 'prompt': result['prompt'],\n",
1057
+ " 'reference_image': result['reference_image']\n",
1058
+ " }\n",
1059
+ " \n",
1060
+ " for metric_name in ['reward', 'clip', 'aesthetic', 'pickscore', 'hpsv2']:\n",
1061
+ " if metric_name in metrics:\n",
1062
+ " values = [v for v in metrics[metric_name] if v is not None]\n",
1063
+ " if values:\n",
1064
+ " sample_summary[metric_name] = {\n",
1065
+ " 'initial': values[0],\n",
1066
+ " 'final': values[-1],\n",
1067
+ " 'improvement': values[-1] - values[0],\n",
1068
+ " 'all_values': values\n",
1069
+ " }\n",
1070
+ " \n",
1071
+ " summary['samples'].append(sample_summary)\n",
1072
+ "\n",
1073
+ "# Save summary\n",
1074
+ "with open(Path(OUTPUT_DIR) / \"convergence_summary.json\", 'w') as f:\n",
1075
+ " json.dump(summary, f, indent=2)\n",
1076
+ "\n",
1077
+ "print(f\"\\n✓ Results saved to: {OUTPUT_DIR}\")\n",
1078
+ "print(f\" - convergence_summary.json\")\n",
1079
+ "print(f\" - all_samples_comparison.png\")\n",
1080
+ "print(f\" - sample_X/ directories with individual results\")"
1081
+ ]
1082
+ }
1083
+ ],
1084
+ "metadata": {
1085
+ "kernelspec": {
1086
+ "display_name": "Python 3",
1087
+ "language": "python",
1088
+ "name": "python3"
1089
+ },
1090
+ "language_info": {
1091
+ "codemirror_mode": {
1092
+ "name": "ipython",
1093
+ "version": 3
1094
+ },
1095
+ "file_extension": ".py",
1096
+ "mimetype": "text/x-python",
1097
+ "name": "python",
1098
+ "nbconvert_exporter": "python",
1099
+ "pygments_lexer": "ipython3",
1100
+ "version": "3.10.18"
1101
+ }
1102
+ },
1103
+ "nbformat": 4,
1104
+ "nbformat_minor": 5
1105
+ }
Reward_sdxl_idealized/tune_hyperparams.py ADDED
@@ -0,0 +1,514 @@
1
+ """
2
+ Hyperparameter tuning script for gradient ascent optimization.
3
+
4
+ This script performs a systematic search over hyperparameter combinations
5
+ to find the optimal configuration for maximum evaluation scores.
6
+ """
7
+
8
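+ # Example invocation (the values shown are the argparse defaults below):
+ #   python tune_hyperparams.py --dataset_type pickapic --model_variant lpo \
+ #       --max_samples 30 --num_steps 20 --cuda 0 --search_type grid \
+ #       --output_dir tuning_results
+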
+ import subprocess
9
+ import json
10
+ import argparse
11
+ from pathlib import Path
12
+ from datetime import datetime
13
+ import itertools
14
+ import numpy as np
15
+ from typing import Dict, List, Any
16
+ import re
17
+
18
+
19
+ class HyperparameterTuner:
20
+ """Hyperparameter tuner for gradient ascent."""
21
+
22
+ def __init__(
23
+ self,
24
+ output_dir: str = "tuning_results",
25
+ max_samples: int = 30,
26
+ num_steps: int = 20,
27
+ dataset_type: str = "pickapic",
28
+ model_variant: str = "lpo",
29
+ cuda_id: int = 0,
30
+ metrics: List[str] = None
31
+ ):
32
+ self.output_dir = Path(output_dir)
33
+ self.output_dir.mkdir(parents=True, exist_ok=True)
34
+
35
+ self.max_samples = max_samples
36
+ self.num_steps = num_steps
37
+ self.dataset_type = dataset_type
38
+ self.model_variant = model_variant
39
+ self.cuda_id = cuda_id
40
+ self.metrics = metrics or ["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"]
41
+
42
+ # Store results
43
+ self.results = []
44
+ self.baseline_results = None
45
+
46
+ def define_search_space(self) -> List[Dict[str, Any]]:
47
+ """Define the hyperparameter search space - FULL GRID SEARCH.
48
+
49
+         Tests all combinations of parameters, including momentum overrides for the configs that support them.
50
+ """
51
+
52
+ # Define all parameter values
53
+         cfg_scales = [3.0, 5.0, 7.5]
54
+
55
+ # All available gradient configs from grad_ascent_configs.py
56
+ grad_configs = [
57
+ # "constant",
58
+ # "linear",
59
+ "cosine_nesterov",
60
+ # "low_to_high_nesterov",
61
+ # "high_to_low_nesterov",
62
+ "low_to_high_momentum",
63
+ "high_to_low_momentum",
64
+ ]
65
+
66
+ num_grad_steps_list = [1, 2] # 5, 7, 10
67
+         grad_step_sizes = [0.001, 0.005, 0.01, 0.05]
68
+         momentums = [0.5, 0.8, 0.9]
69
+
70
+ # Generate ALL combinations using itertools.product
71
+ configs = []
72
+ for cfg, grad_cfg, num_steps, step_size, momentum in itertools.product(
73
+ cfg_scales, grad_configs, num_grad_steps_list, grad_step_sizes, momentums
74
+ ):
75
+ configs.append({
76
+ "cfg_scale": cfg,
77
+ "grad_config": grad_cfg,
78
+ "num_grad_steps": num_steps,
79
+ "grad_step_size": step_size,
80
+ "momentum": momentum,
81
+ })
82
+
83
+ print(f"\nGenerated {len(configs)} total configurations")
84
+ print(f" cfg_scales: {len(cfg_scales)}")
85
+ print(f" grad_configs: {len(grad_configs)}")
86
+ print(f" num_grad_steps: {len(num_grad_steps_list)}")
87
+ print(f" grad_step_sizes: {len(grad_step_sizes)}")
88
+ print(f" momentums: {len(momentums)}")
89
+ print(f" Total: {len(cfg_scales)} × {len(grad_configs)} × {len(num_grad_steps_list)} × {len(grad_step_sizes)} × {len(momentums)} = {len(configs)}")
90
+
91
+ return configs
92
+
93
+ def run_baseline(self) -> Dict[str, float]:
94
+ """Run baseline evaluation once."""
95
+ print("\n" + "="*80)
96
+ print("RUNNING BASELINE EVALUATION")
97
+ print("="*80)
98
+
99
+ # Use median cfg_scale for baseline
100
+ cfg_scale = 5.0
101
+
102
+ output_dir = self.output_dir / "baseline"
103
+
104
+ cmd = [
105
+ "python", "eval.py",
106
+ "--model_variant", self.model_variant,
107
+ "--dataset_type", self.dataset_type,
108
+ "--max_samples", str(self.max_samples),
109
+ "--num_steps", str(self.num_steps),
110
+ "--cfg_scale", str(cfg_scale),
111
+ "--output_dir", str(output_dir),
112
+ "--cuda", str(self.cuda_id),
113
+ "--mode", "baseline",
114
+ "--metrics", *self.metrics,
115
+ ]
116
+
117
+ print(f"Command: {' '.join(cmd)}")
118
+
119
+ try:
120
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True)
121
+
122
+ # Parse results from output
123
+ metrics = self._parse_metrics(result.stdout, "baseline")
124
+
125
+ print(f"\nBaseline Results:")
126
+ for metric, value in metrics.items():
127
+ print(f" {metric}: {value:.4f}")
128
+
129
+ self.baseline_results = {
130
+ "cfg_scale": cfg_scale,
131
+ "metrics": metrics,
132
+ }
133
+
134
+ return metrics
135
+
136
+ except subprocess.CalledProcessError as e:
137
+ print(f"Error running baseline: {e}")
138
+ print(f"Stdout: {e.stdout}")
139
+ print(f"Stderr: {e.stderr}")
140
+ return {}
141
+
142
+ def run_experiment(self, config: Dict[str, Any]) -> Dict[str, Any]:
143
+ """Run a single experiment with given hyperparameters."""
144
+
145
+ # Create output directory for this config
146
+ config_name = f"cfg{config['cfg_scale']}_" \
147
+ f"{config['grad_config']}_" \
148
+ f"steps{config['num_grad_steps']}_" \
149
+ f"lr{config['grad_step_size']}_" \
150
+ f"mom{config['momentum']}"
151
+
152
+ output_dir = self.output_dir / config_name
153
+
154
+ # Build command
155
+ cmd = [
156
+ "python", "eval.py",
157
+ "--model_variant", self.model_variant,
158
+ "--dataset_type", self.dataset_type,
159
+ "--grad_config", config["grad_config"],
160
+ "--max_samples", str(self.max_samples),
161
+ "--num_steps", str(self.num_steps),
162
+ "--cfg_scale", str(config["cfg_scale"]),
163
+ "--output_dir", str(output_dir),
164
+ "--cuda", str(self.cuda_id),
165
+ "--mode", "gradient_ascent",
166
+ "--metrics", *self.metrics,
167
+ # Override config parameters
168
+ "--override_num_grad_steps", str(config["num_grad_steps"]),
169
+ "--override_grad_step_size", str(config["grad_step_size"]),
170
+ "--override_momentum", str(config["momentum"]),
171
+ ]
172
+
173
+ print(f"\nRunning experiment: {config_name}")
174
+ print(f"Config: {config}")
175
+
176
+ try:
177
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True)
178
+
179
+ # Parse metrics from output
180
+ metrics = self._parse_metrics(result.stdout, "gradient_ascent")
181
+
182
+ # Compute improvement over baseline
183
+ improvements = {}
184
+ if self.baseline_results:
185
+ baseline_metrics = self.baseline_results["metrics"]
186
+ for metric, value in metrics.items():
187
+ if metric in baseline_metrics:
188
+ baseline_val = baseline_metrics[metric]
189
+ if baseline_val != 0:
190
+ improvement = ((value - baseline_val) / abs(baseline_val)) * 100
191
+ improvements[f"{metric}_improvement"] = improvement
192
+
193
+ result_dict = {
194
+ "config": config,
195
+ "metrics": metrics,
196
+ "improvements": improvements,
197
+ "output_dir": str(output_dir),
198
+ "timestamp": datetime.now().isoformat(),
199
+ }
200
+
201
+ print(f"Results:")
202
+ for metric, value in metrics.items():
203
+ print(f" {metric}: {value:.4f}")
204
+ if improvements:
205
+ print(f"Improvements over baseline:")
206
+ for metric, value in improvements.items():
207
+ print(f" {metric}: {value:+.2f}%")
208
+
209
+ return result_dict
210
+
211
+ except subprocess.CalledProcessError as e:
212
+ print(f"Error running experiment: {e}")
213
+ print(f"Stderr: {e.stderr}")
214
+ return {
215
+ "config": config,
216
+ "error": str(e),
217
+ "timestamp": datetime.now().isoformat(),
218
+ }
219
+
220
+ def _parse_metrics(self, output: str, mode: str) -> Dict[str, float]:
221
+ """Parse metrics from eval.py output."""
222
+ metrics = {}
223
+
224
+ # Look for the summary section
225
+ lines = output.split('\n')
226
+
227
+ # Pattern to match metric lines like " Reward: 0.1234"
228
+ metric_patterns = {
229
+ "reward": r"Reward:\s+([-+]?\d*\.?\d+)",
230
+ "clip": r"CLIP Score:\s+([-+]?\d*\.?\d+)",
231
+ "aesthetic": r"Aesthetic Score:\s+([-+]?\d*\.?\d+)",
232
+ "pickscore": r"PickScore:\s+([-+]?\d*\.?\d+)",
233
+ "hpsv2": r"HPSv2 Score:\s+([-+]?\d*\.?\d+)",
234
+ "hpsv21": r"HPSv2\.1 Score:\s+([-+]?\d*\.?\d+)",
235
+ "imagereward": r"ImageReward:\s+([-+]?\d*\.?\d+)",
236
+ "fid": r"FID:\s+([-+]?\d*\.?\d+)",
237
+ }
238
+
239
+ for line in lines:
240
+ for metric_name, pattern in metric_patterns.items():
241
+ match = re.search(pattern, line)
242
+ if match:
243
+ metrics[metric_name] = float(match.group(1))
244
+
245
+ return metrics
246
+
247
+ def compute_aggregate_score(self, metrics: Dict[str, float]) -> float:
248
+ """
249
+ Compute aggregate score for ranking configurations.
250
+
251
+         Uses a weighted combination of metrics (higher is better for most;
252
+         FID is the exception, where lower is better).
253
+ """
254
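+         # Worked example with hypothetical values: metrics {"clip": 0.30, "fid": 2.0}
+         # give score = (0.8*0.30 + (-0.5)*2.0) / (0.8 + 0.5) = -0.76 / 1.3 ≈ -0.585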
+ weights = {
255
+ "reward": 1.0,
256
+ "clip": 0.8,
257
+ "aesthetic": 0.8,
258
+ "pickscore": 1.0,
259
+ "hpsv2": 1.0,
260
+ "hpsv21": 1.0,
261
+ "imagereward": 1.0,
262
+ "fid": -0.5, # Negative weight (lower FID is better)
263
+ }
264
+
265
+ score = 0.0
266
+ total_weight = 0.0
267
+
268
+ for metric, value in metrics.items():
269
+ if metric in weights:
270
+ score += weights[metric] * value
271
+ total_weight += abs(weights[metric])
272
+
273
+ # Normalize by total weight
274
+ if total_weight > 0:
275
+ score /= total_weight
276
+
277
+ return score
278
+
279
+ def run_search(
280
+ self,
281
+ search_type: str = "grid",
282
+ start_idx: int = 0,
283
+ end_idx: int = None
284
+ ) -> List[Dict[str, Any]]:
285
+ """
286
+ Run hyperparameter search.
287
+
288
+ Args:
289
+ search_type: Type of search ("grid" or "random")
290
+ start_idx: Starting index for experiments (for GPU distribution)
291
+ end_idx: Ending index for experiments (for GPU distribution)
292
+ """
293
+ all_configs = self.define_search_space()
294
+
295
+ print("\n" + "="*80)
296
+ print("HYPERPARAMETER SEARCH CONFIGURATION")
297
+ print("="*80)
298
+ print(f"Dataset: {self.dataset_type}")
299
+ print(f"Model: {self.model_variant}")
300
+ print(f"Samples: {self.max_samples}")
301
+ print(f"Inference steps: {self.num_steps}")
302
+ print(f"Metrics: {', '.join(self.metrics)}")
303
+
304
+ # Select subset of configs if indices provided
305
+ if search_type == "grid":
306
+ configs = all_configs
307
+ elif search_type == "random":
308
+ # Random sample from all configs
309
+ n_samples = min(50, len(all_configs))
310
+ indices = np.random.choice(len(all_configs), n_samples, replace=False)
311
+ configs = [all_configs[i] for i in indices]
312
+ else:
313
+ raise ValueError(f"Unknown search type: {search_type}")
314
+
315
+ # Apply index slicing for GPU distribution
316
+ if end_idx is None:
317
+ end_idx = len(configs)
318
+ configs = configs[start_idx:end_idx]
319
+
320
+ print(f"\nTotal configurations: {len(all_configs)}")
321
+ print(f"Assigned to this worker: {len(configs)} (indices {start_idx} to {end_idx})")
322
+
323
+ # Run baseline first
324
+ if self.baseline_results is None:
325
+ self.run_baseline()
326
+
327
+ # Run experiments
328
+ print("\n" + "="*80)
329
+ print("RUNNING EXPERIMENTS")
330
+ print("="*80)
331
+
332
+ for i, config in enumerate(configs, 1):
333
+ print(f"\n{'='*80}")
334
+ print(f"Experiment {i}/{len(configs)}")
335
+ print(f"{'='*80}")
336
+
337
+ result = self.run_experiment(config)
338
+ self.results.append(result)
339
+
340
+ # Save intermediate results
341
+ self._save_results()
342
+
343
+ return self.results
344
+
345
+ def _generate_grid_configs(self, search_space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
346
+ """Generate all combinations for grid search."""
347
+ keys = list(search_space.keys())
348
+ values = list(search_space.values())
349
+
350
+ configs = []
351
+ for combination in itertools.product(*values):
352
+ config = dict(zip(keys, combination))
353
+ configs.append(config)
354
+
355
+ return configs
356
+
357
+ def _generate_random_configs(
358
+ self,
359
+ search_space: Dict[str, List[Any]],
360
+ n_samples: int = 20
361
+ ) -> List[Dict[str, Any]]:
362
+ """Generate random configurations for random search."""
363
+ configs = []
364
+
365
+ for _ in range(n_samples):
366
+ config = {}
367
+ for param, values in search_space.items():
368
+ config[param] = np.random.choice(values)
369
+ configs.append(config)
370
+
371
+ return configs
372
+
373
+ def _save_results(self):
374
+ """Save results to JSON file."""
375
+ results_file = self.output_dir / "tuning_results.json"
376
+
377
+ data = {
378
+ "baseline": self.baseline_results,
379
+ "experiments": self.results,
380
+ "timestamp": datetime.now().isoformat(),
381
+ "config": {
382
+ "max_samples": self.max_samples,
383
+ "num_steps": self.num_steps,
384
+ "dataset_type": self.dataset_type,
385
+ "model_variant": self.model_variant,
386
+ }
387
+ }
388
+
389
+ with open(results_file, 'w') as f:
390
+ json.dump(data, f, indent=2)
391
+
392
+ print(f"\nResults saved to: {results_file}")
393
+
394
+ def analyze_results(self) -> Dict[str, Any]:
395
+ """Analyze results and find best configuration."""
396
+ if not self.results:
397
+ print("No results to analyze!")
398
+ return {}
399
+
400
+ print("\n" + "="*80)
401
+ print("ANALYSIS: FINDING BEST CONFIGURATION")
402
+ print("="*80)
403
+
404
+ # Filter out failed experiments
405
+ successful_results = [r for r in self.results if "metrics" in r]
406
+
407
+ if not successful_results:
408
+ print("No successful experiments!")
409
+ return {}
410
+
411
+ # Compute aggregate scores
412
+ for result in successful_results:
413
+ metrics = result["metrics"]
414
+ result["aggregate_score"] = self.compute_aggregate_score(metrics)
415
+
416
+ # Sort by aggregate score
417
+ successful_results.sort(key=lambda x: x["aggregate_score"], reverse=True)
418
+
419
+ # Print top 5 configurations
420
+ print("\nTop 5 Configurations:")
421
+ print("="*80)
422
+
423
+ for i, result in enumerate(successful_results[:5], 1):
424
+ print(f"\n#{i} - Aggregate Score: {result['aggregate_score']:.4f}")
425
+ print(f"Config: {result['config']}")
426
+ print(f"Metrics:")
427
+ for metric, value in result['metrics'].items():
428
+ print(f" {metric}: {value:.4f}")
429
+ if result.get('improvements'):
430
+ print(f"Improvements over baseline:")
431
+ for metric, value in result['improvements'].items():
432
+ print(f" {metric}: {value:+.2f}%")
433
+
434
+ # Save best config
435
+ best_result = successful_results[0]
436
+ best_config_file = self.output_dir / "best_config.json"
437
+
438
+ with open(best_config_file, 'w') as f:
439
+ json.dump({
440
+ "config": best_result["config"],
441
+ "metrics": best_result["metrics"],
442
+ "aggregate_score": best_result["aggregate_score"],
443
+ "improvements": best_result.get("improvements", {}),
444
+ }, f, indent=2)
445
+
446
+ print(f"\n✓ Best configuration saved to: {best_config_file}")
447
+
448
+ return best_result
449
+
450
+
451
+ def main():
452
+ parser = argparse.ArgumentParser(description="Hyperparameter tuning for gradient ascent")
453
+ parser.add_argument("--output_dir", type=str, default="tuning_results",
454
+ help="Directory to save tuning results")
455
+ parser.add_argument("--max_samples", type=int, default=30,
456
+ help="Number of samples to use for tuning")
457
+ parser.add_argument("--num_steps", type=int, default=20,
458
+ help="Number of inference steps (fixed)")
459
+ parser.add_argument("--dataset_type", type=str, default="pickapic",
460
+ choices=["coco", "pickapic"],
461
+ help="Dataset to use")
462
+ parser.add_argument("--model_variant", type=str, default="lpo",
463
+ choices=["origin", "spo", "diffusion_dpo", "lpo"],
464
+ help="Model variant to use")
465
+ parser.add_argument("--cuda", type=int, default=0,
466
+ help="CUDA device ID")
467
+ parser.add_argument("--search_type", type=str, default="grid",
468
+ choices=["grid", "random"],
469
+ help="Type of hyperparameter search")
470
+ parser.add_argument("--metrics", type=str, nargs="+",
471
+ default=["clip", "aesthetic", "pickscore", "hpsv2", "imagereward"],
472
+ help="Metrics to evaluate")
473
+ parser.add_argument("--start_idx", type=int, default=0,
474
+ help="Starting index for experiments (for GPU distribution)")
475
+ parser.add_argument("--end_idx", type=int, default=None,
476
+ help="Ending index for experiments (for GPU distribution)")
477
+
478
+ args = parser.parse_args()
479
+
480
+ # Create tuner
481
+ tuner = HyperparameterTuner(
482
+ output_dir=args.output_dir,
483
+ max_samples=args.max_samples,
484
+ num_steps=args.num_steps,
485
+ dataset_type=args.dataset_type,
486
+ model_variant=args.model_variant,
487
+ cuda_id=args.cuda,
488
+ metrics=args.metrics,
489
+ )
490
+
491
+ # Run search
492
+ results = tuner.run_search(
493
+ search_type=args.search_type,
494
+ start_idx=args.start_idx,
495
+ end_idx=args.end_idx
496
+ )
497
+
498
+ # Analyze results
499
+ best_result = tuner.analyze_results()
500
+
501
+ print("\n" + "="*80)
502
+ print("TUNING COMPLETE!")
503
+ print("="*80)
504
+ print(f"Total experiments: {len(results)}")
505
+ print(f"Results directory: {args.output_dir}")
506
+
507
+ if best_result:
508
+ print(f"\nBest configuration:")
509
+ print(json.dumps(best_result["config"], indent=2))
510
+ print(f"\nAggregate score: {best_result['aggregate_score']:.4f}")
511
+
512
+
513
+ if __name__ == "__main__":
514
+ main()
Reward_sdxl_idealized/tune_parallel.sh ADDED
@@ -0,0 +1,253 @@
1
+ #!/bin/bash
2
+
3
+ # Parallel hyperparameter tuning across 8 GPUs
4
+ # This script distributes experiments evenly across all available GPUs
5
+
6
+ clear
7
+
8
+ # Activate conda environment
9
+ source ~/miniconda3/etc/profile.d/conda.sh
10
+ conda activate /home/ec2-user/aev
11
+
12
+ # Configuration
13
+ DATASET_TYPE="pickapic" # "coco" or "pickapic"
14
+ MODEL_VARIANT="lpo" # "origin", "spo", "diffusion_dpo", or "lpo"
15
+ MAX_SAMPLES=500 # Number of samples for tuning
16
+ NUM_STEPS=50 # Fixed inference steps
17
+ SEARCH_TYPE="grid" # "grid" or "random"
18
+ OUTPUT_DIR="RESULTS_TURNING/run_2"
19
+ NUM_GPUS=8 # Number of GPUs to use
20
+
21
+ echo "=============================================="
22
+ echo " PARALLEL HYPERPARAMETER TUNING"
23
+ echo "=============================================="
24
+ echo ""
25
+ echo "Configuration:"
26
+ echo " Dataset: $DATASET_TYPE"
27
+ echo " Model: $MODEL_VARIANT"
28
+ echo " Samples: $MAX_SAMPLES"
29
+ echo " Inference Steps: $NUM_STEPS"
30
+ echo " Search Type: $SEARCH_TYPE"
31
+ echo " GPUs: $NUM_GPUS"
32
+ echo " Output: $OUTPUT_DIR"
33
+ echo ""
34
+
35
+ # First, calculate total number of experiments
36
+ echo "Calculating total experiments..."
37
+ TOTAL_CONFIGS=$(python -c "
38
+ from tune_hyperparams import HyperparameterTuner
39
+ import sys
40
+ tuner = HyperparameterTuner()
41
+ configs = tuner.define_search_space()
42
+ sys.stderr.write(f'Generated {len(configs)} configurations\n')
43
+ print(len(configs))
44
+ " 2>&1 | tail -1)
45
+
46
+ echo "Total configurations: $TOTAL_CONFIGS"
47
+ echo ""
48
+
49
+ # Calculate experiments per GPU
50
+ CONFIGS_PER_GPU=$((TOTAL_CONFIGS / NUM_GPUS))
51
+ REMAINDER=$((TOTAL_CONFIGS % NUM_GPUS))
52
+
53
+ echo "Distributing work:"
54
+ echo " Base configs per GPU: $CONFIGS_PER_GPU"
55
+ echo " Extra configs for first GPUs: $REMAINDER"
56
+ echo ""
57
+
58
+ # Create output directory
59
+ mkdir -p "$OUTPUT_DIR"
60
+
61
+ # Array to store background process IDs
62
+ PIDS=()
63
+
64
+ # Launch parallel processes on each GPU
65
+ for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
66
+ # Calculate start and end indices for this GPU
67
+ START_IDX=$((GPU_ID * CONFIGS_PER_GPU))
68
+
69
+ # Give extra configs to first GPUs
70
+ if [ $GPU_ID -lt $REMAINDER ]; then
71
+ START_IDX=$((START_IDX + GPU_ID))
72
+ END_IDX=$((START_IDX + CONFIGS_PER_GPU + 1))
73
+ else
74
+ START_IDX=$((START_IDX + REMAINDER))
75
+ END_IDX=$((START_IDX + CONFIGS_PER_GPU))
76
+ fi
77
+
78
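+     # e.g. the grid currently defined in tune_hyperparams.py yields
+     # 3*3*2*4*3 = 216 configs, so each of 8 GPUs gets 27 (remainder 0); end
+     # indices are exclusive in the Python-side slicing, so the ranges tile
+     # the config list without gaps or overlap.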
+ # Create GPU-specific output directory
79
+ GPU_OUTPUT_DIR="${OUTPUT_DIR}/gpu_${GPU_ID}"
80
+ mkdir -p "$GPU_OUTPUT_DIR"
81
+
82
+ echo "GPU $GPU_ID: configs $START_IDX to $END_IDX"
83
+
84
+ # Launch tuning process in background
85
+ nohup python tune_hyperparams.py \
86
+ --output_dir "$GPU_OUTPUT_DIR" \
87
+ --max_samples $MAX_SAMPLES \
88
+ --num_steps $NUM_STEPS \
89
+ --dataset_type "$DATASET_TYPE" \
90
+ --model_variant "$MODEL_VARIANT" \
91
+ --cuda $GPU_ID \
92
+ --search_type "$SEARCH_TYPE" \
93
+ --start_idx $START_IDX \
94
+ --end_idx $END_IDX \
95
+ --metrics clip aesthetic pickscore hpsv2 imagereward \
96
+ > "${GPU_OUTPUT_DIR}/tuning.log" 2>&1 &
97
+
98
+ # Store PID
99
+ PIDS+=($!)
100
+
101
+ echo " Launched with PID: ${PIDS[$GPU_ID]}"
102
+
103
+ # Small delay to avoid race conditions
104
+ sleep 2
105
+ done
106
+
107
+ echo ""
108
+ echo "=============================================="
109
+ echo " ALL PROCESSES LAUNCHED"
110
+ echo "=============================================="
111
+ echo ""
112
+ echo "Background processes running:"
113
+ for GPU_ID in $(seq 0 $((NUM_GPUS - 1))); do
114
+ echo " GPU $GPU_ID: PID ${PIDS[$GPU_ID]} -> ${OUTPUT_DIR}/gpu_${GPU_ID}/tuning.log"
115
+ done
116
+ echo ""
117
+ echo "To monitor progress:"
118
+ echo " tail -f ${OUTPUT_DIR}/gpu_0/tuning.log"
119
+ echo " tail -f ${OUTPUT_DIR}/gpu_1/tuning.log"
120
+ echo " ... etc"
121
+ echo ""
122
+ echo "To check all GPU processes:"
123
+ echo " ps aux | grep tune_hyperparams.py"
124
+ echo ""
125
+ echo "To monitor GPU usage:"
126
+ echo " watch -n 1 nvidia-smi"
127
+ echo ""
128
+ echo "To kill all processes:"
129
+ echo " kill ${PIDS[@]}"
130
+ echo ""
131
+ echo "Waiting for all processes to complete..."
132
+ echo "(Press Ctrl+C to stop waiting, processes will continue in background)"
133
+ echo ""
134
+
135
+ # Wait for all background processes
136
+ for PID in "${PIDS[@]}"; do
137
+ wait $PID
138
+ done
139
+
140
+ echo ""
141
+ echo "=============================================="
142
+ echo " ALL TUNING PROCESSES COMPLETE"
143
+ echo "=============================================="
144
+ echo ""
145
+
146
+ # Merge results from all GPUs
147
+ echo "Merging results from all GPUs..."
148
+
149
+ # Activate conda environment for Python script
150
+ source ~/miniconda3/etc/profile.d/conda.sh
151
+ conda activate /home/ec2-user/aev
152
+
153
+ python - <<'EOF'
154
+ import json
155
+ from pathlib import Path
156
+ import sys
157
+
158
+ output_dir = Path("RESULTS_TURNING/run_2")
159
+ all_results = []
160
+ baseline_result = None
161
+
162
+ # Collect results from each GPU
163
+ for gpu_id in range(8):
164
+ gpu_dir = output_dir / f"gpu_{gpu_id}"
165
+ results_file = gpu_dir / "tuning_results.json"
166
+
167
+ if results_file.exists():
168
+ with open(results_file, 'r') as f:
169
+ data = json.load(f)
170
+
171
+ # Get baseline (should be same from all)
172
+ if baseline_result is None and "baseline" in data:
173
+ baseline_result = data["baseline"]
174
+
175
+ # Collect experiments
176
+ if "experiments" in data:
177
+ all_results.extend(data["experiments"])
178
+
179
+ print(f"GPU {gpu_id}: {len(data.get('experiments', []))} results")
180
+
181
+ # Merge all results
182
+ merged_data = {
183
+ "baseline": baseline_result,
184
+ "experiments": all_results,
185
+ "num_gpus": 8,
186
+ "total_experiments": len(all_results)
187
+ }
188
+
189
+ # Save merged results
190
+ merged_file = output_dir / "merged_results.json"
191
+ with open(merged_file, 'w') as f:
192
+ json.dump(merged_data, f, indent=2)
193
+
194
+ print(f"\nMerged {len(all_results)} total results")
195
+ print(f"Saved to: {merged_file}")
196
+
197
+ # Find best configuration
198
+ successful = [r for r in all_results if "metrics" in r]
199
+ if successful:
200
+ # Compute aggregate scores
201
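+     # Same weights as HyperparameterTuner.compute_aggregate_score, though this
+     # version normalizes by the sum of ALL weights rather than only those of
+     # the metrics actually present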
+ def compute_score(metrics):
202
+ weights = {
203
+ "reward": 1.0, "clip": 0.8, "aesthetic": 0.8,
204
+ "pickscore": 1.0, "hpsv2": 1.0, "imagereward": 1.0,
205
+ "fid": -0.5
206
+ }
207
+ score = sum(weights.get(k, 0) * v for k, v in metrics.items())
208
+ return score / sum(abs(w) for w in weights.values())
209
+
210
+ for r in successful:
211
+ r["aggregate_score"] = compute_score(r["metrics"])
212
+
213
+ successful.sort(key=lambda x: x["aggregate_score"], reverse=True)
214
+
215
+ best = successful[0]
216
+ best_file = output_dir / "best_config.json"
217
+ with open(best_file, 'w') as f:
218
+ json.dump({
219
+ "config": best["config"],
220
+ "metrics": best["metrics"],
221
+ "aggregate_score": best["aggregate_score"],
222
+ "improvements": best.get("improvements", {})
223
+ }, f, indent=2)
224
+
225
+ print(f"\n{'='*60}")
226
+ print("BEST CONFIGURATION:")
227
+ print(f"{'='*60}")
228
+ print(json.dumps(best["config"], indent=2))
229
+ print(f"\nAggregate Score: {best['aggregate_score']:.4f}")
230
+ print(f"Saved to: {best_file}")
231
+ else:
232
+ print("\nNo successful experiments found!")
233
+ sys.exit(1)
234
+ EOF
235
+
236
+ if [ $? -eq 0 ]; then
237
+ echo ""
238
+ echo "=============================================="
239
+ echo " TUNING COMPLETE!"
240
+ echo "=============================================="
241
+ echo ""
242
+ echo "Results:"
243
+ echo " Merged results: ${OUTPUT_DIR}/merged_results.json"
244
+ echo " Best config: ${OUTPUT_DIR}/best_config.json"
245
+ echo ""
246
+ echo "View best configuration:"
247
+ echo " cat ${OUTPUT_DIR}/best_config.json"
248
+ echo ""
249
+ else
250
+ echo ""
251
+ echo "ERROR: Failed to merge results"
252
+ exit 1
253
+ fi
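A minimal sketch of how the merged output could be inspected once the script above finishes. The file names and the RESULTS_TURNING directory come from the heredoc above; the use of jq here is an assumption (any JSON tool would do):

# Hypothetical inspection commands; assumes jq is installed.
jq '.total_experiments' RESULTS_TURNING/merged_results.json
jq '{config: .config, aggregate_score: .aggregate_score}' RESULTS_TURNING/best_config.json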
__pycache__/upload.cpython-311.pyc ADDED
Binary file (9.83 kB).

evaluation/open_clip/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.01 kB).

evaluation/open_clip/__pycache__/coca_model.cpython-311.pyc ADDED
Binary file (18.4 kB).

evaluation/open_clip/__pycache__/hf_model.cpython-311.pyc ADDED
Binary file (10.6 kB).

evaluation/open_clip/__pycache__/loss.cpython-311.pyc ADDED
Binary file (14.3 kB).

evaluation/open_clip/__pycache__/model.cpython-311.pyc ADDED
Binary file (25.1 kB).

evaluation/open_clip/__pycache__/pretrained.cpython-311.pyc ADDED
Binary file (18.5 kB).

evaluation/open_clip/__pycache__/push_to_hf_hub.cpython-311.pyc ADDED
Binary file (9.31 kB).

evaluation/open_clip/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (16 kB).

evaluation/open_clip/__pycache__/transform.cpython-311.pyc ADDED
Binary file (8.58 kB).

evaluation/open_clip/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (42.6 kB).

evaluation/open_clip/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.58 kB).
evaluation/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "embed_dim": 1024,
+   "vision_cfg": {
+     "image_size": 224,
+     "layers": [
+       3,
+       4,
+       6,
+       3
+     ],
+     "width": 64,
+     "patch_size": null
+   },
+   "text_cfg": {
+     "context_length": 77,
+     "vocab_size": 49408,
+     "width": 512,
+     "heads": 8,
+     "layers": 12
+   }
+ }
evaluation/open_clip/model_configs/ViT-B-32-plus-256.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "embed_dim": 640,
+   "vision_cfg": {
+     "image_size": 256,
+     "layers": 12,
+     "width": 896,
+     "patch_size": 32
+   },
+   "text_cfg": {
+     "context_length": 77,
+     "vocab_size": 49408,
+     "width": 640,
+     "heads": 10,
+     "layers": 12
+   }
+ }
evaluation/open_clip/model_configs/ViT-S-32.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "embed_dim": 384,
+   "vision_cfg": {
+     "image_size": 224,
+     "layers": 12,
+     "width": 384,
+     "patch_size": 32
+   },
+   "text_cfg": {
+     "context_length": 77,
+     "vocab_size": 49408,
+     "width": 384,
+     "heads": 6,
+     "layers": 12
+   }
+ }
evaluation/open_clip/model_configs/convnext_large_d_320.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "embed_dim": 768,
+   "vision_cfg": {
+     "timm_model_name": "convnext_large",
+     "timm_model_pretrained": false,
+     "timm_pool": "",
+     "timm_proj": "mlp",
+     "timm_drop": 0.0,
+     "timm_drop_path": 0.1,
+     "image_size": 320
+   },
+   "text_cfg": {
+     "context_length": 77,
+     "vocab_size": 49408,
+     "width": 768,
+     "heads": 12,
+     "layers": 16
+   }
+ }
lrm/flux/.hydra/config.yaml ADDED
@@ -0,0 +1,124 @@
+ accelerator:
+   _target_: trainer.accelerators.debug_accelerator.DebugAccelerator
+   output_dir: ${output_dir}
+   mixed_precision: BF16
+   gradient_accumulation_steps: 1
+   log_with: null
+   debug:
+     activate: false
+     port: 5900
+   seed: 42
+   resume_from_checkpoint: false
+   max_steps: 8000
+   num_epochs: 10
+   validate_steps: 100
+   generalization_validate_steps: 500
+   eval_on_start: true
+   project_name: reward_model
+   run_name: step_flux_schnell_variable-t_lr1e-5_step-8000_cfg0.0_filter2_time951
+   max_grad_norm: 1.0
+   save_steps: 100
+   metric_name: accuracy
+   metric_mode: MAX
+   limit_num_checkpoints: 1
+   save_only_if_best: true
+   dynamo_backend: 'NO'
+   keep_best_ckpts: true
+ task:
+   limit_examples_to_wandb: 50
+   _target_: trainer.tasks.step_flux_task.StepFluxTask
+   pretrained_model_name_or_path: ${model.pretrained_model_name_or_path}
+   tokenizer_subfolder: tokenizer
+   label_0_column_name: ${dataset.label_0_column_name}
+   label_1_column_name: ${dataset.label_1_column_name}
+   input_ids_column_name: ${dataset.input_ids_column_name}
+   input_ids_2_column_name: ${dataset.input_ids_2_column_name}
+   pixels_0_column_name: ${dataset.pixels_0_column_name}
+   pixels_1_column_name: ${dataset.pixels_1_column_name}
+   timestep_column_name: ${dataset.timestep_column_name}
+   constant_timestep: ${dataset.constant_timestep}
+ model:
+   _target_: trainer.models.flux_preference_model.FluxPreferenceModel
+   pretrained_model_name_or_path: black-forest-labs/FLUX.1-schnell
+   pretrained_vae_name_or_path: black-forest-labs/FLUX.1-schnell
+   projection_dim: 1024
+   text_embed_dim: 768
+   logit_scale_init_value: 2.6592
+   freeze_text_encoder: false
+   guidance_scale: 0.0
+   noise_offset: false
+   noise_offset_coeff: 0.05
+   max_sequence_length: 512
+   image_size: 1024
+ criterion:
+   _target_: trainer.criterions.step_clip_criterion_flux.StepFluxCLIPCriterion
+   is_distributed: false
+   label_0_column_name: ${dataset.label_0_column_name}
+   label_1_column_name: ${dataset.label_1_column_name}
+   input_ids_column_name: ${dataset.input_ids_column_name}
+   input_ids_2_column_name: ${dataset.input_ids_2_column_name}
+   pixels_0_column_name: ${dataset.pixels_0_column_name}
+   pixels_1_column_name: ${dataset.pixels_1_column_name}
+   num_examples_per_prompt_column_name: ${dataset.num_examples_per_prompt_column_name}
+   timestep_column_name: ${dataset.timestep_column_name}
+   loss_type: pair
+   batch_coeff: 1.0
+   aux_loss_coeff: 1.0
+ dataset:
+   train_split_name: train
+   valid_split_name: validation_unique
+   test_split_name: test_unique
+   batch_size: 4
+   num_workers: 2
+   drop_last: true
+   _target_: trainer.datasets.step_flux_hf_dataset.StepFluxHFDataset
+   dataset_name: pickapic-anonymous/pickapic_v1
+   dataset_config_name: null
+   from_disk: false
+   cache_dir: null
+   caption_column_name: caption
+   input_ids_column_name: input_ids
+   input_ids_2_column_name: input_ids_2
+   image_0_column_name: jpg_0
+   image_1_column_name: jpg_1
+   label_0_column_name: label_0
+   label_1_column_name: label_1
+   are_different_column_name: are_different
+   has_label_column_name: has_label
+   pixels_0_column_name: pixel_values_0
+   pixels_1_column_name: pixel_values_1
+   timestep_column_name: timestep
+   constant_timestep: 1
+   variable_timestep: true
+   largest_timestep: 951
+   compare_between_timestep: false
+   timestep_comparison_column_name: timestep_comparison
+   timestep_interval: 1
+   num_examples_per_prompt_column_name: num_example_per_prompt
+   keep_only_different: false
+   keep_only_with_label: false
+   keep_only_with_label_in_non_train: true
+   keep_only_with_pesudo_preference: true
+   pseudo_preference_path: /g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
+   filter_strategy: 2
+   processor:
+     pretrained_model_name_or_path: ${model.pretrained_model_name_or_path}
+     max_sequence_length: ${model.max_sequence_length}
+     image_size: ${model.image_size}
+     random_crop: false
+     no_hflip: true
+   limit_examples_per_prompt: -1
+   only_on_best: false
+ optimizer:
+   _target_: trainer.optimizers.dummy_optimizer.BaseDummyOptim
+   lr: 1.0e-05
+   weight_decay: 0.3
+ lr_scheduler:
+   _target_: trainer.lr_schedulers.dummy_lr_scheduler.instantiate_dummy_lr_scheduler
+   lr: ${optimizer.lr}
+   lr_warmup_steps: 1000
+   total_num_steps: ${accelerator.max_steps}
+ debug:
+   activate: false
+   port: 5900
+ output_dir: logs/lrm/${accelerator.project_name}/${accelerator.run_name}
lrm/flux/.hydra/hydra.yaml ADDED
@@ -0,0 +1,166 @@
+ hydra:
+   run:
+     dir: .
+   sweep:
+     dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
+     subdir: ${hydra.job.num}
+   launcher:
+     _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
+   sweeper:
+     _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
+     max_batch_size: null
+     params: null
+   help:
+     app_name: ${hydra.job.name}
+     header: '${hydra.help.app_name} is powered by Hydra.
+
+       '
+     footer: 'Powered by Hydra (https://hydra.cc)
+
+       Use --hydra-help to view Hydra specific help
+
+       '
+     template: '${hydra.help.header}
+
+       == Configuration groups ==
+
+       Compose your configuration from those groups (group=option)
+
+
+       $APP_CONFIG_GROUPS
+
+
+       == Config ==
+
+       Override anything in the config (foo.bar=value)
+
+
+       $CONFIG
+
+
+       ${hydra.help.footer}
+
+       '
+   hydra_help:
+     template: 'Hydra (${hydra.runtime.version})
+
+       See https://hydra.cc for more info.
+
+
+       == Flags ==
+
+       $FLAGS_HELP
+
+
+       == Configuration groups ==
+
+       Compose your configuration from those groups (For example, append hydra/job_logging=disabled
+       to command line)
+
+
+       $HYDRA_CONFIG_GROUPS
+
+
+       Use ''--cfg hydra'' to Show the Hydra config.
+
+       '
+     hydra_help: ???
+   hydra_logging:
+     version: 1
+     formatters:
+       simple:
+         format: '[%(asctime)s][HYDRA] %(message)s'
+     handlers:
+       console:
+         class: logging.StreamHandler
+         formatter: simple
+         stream: ext://sys.stdout
+     root:
+       level: INFO
+       handlers:
+       - console
+     loggers:
+       logging_example:
+         level: DEBUG
+     disable_existing_loggers: false
+   job_logging:
+     version: 1
+     formatters:
+       simple:
+         format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
+     handlers:
+       console:
+         class: logging.StreamHandler
+         formatter: simple
+         stream: ext://sys.stdout
+       file:
+         class: logging.FileHandler
+         formatter: simple
+         filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
+     root:
+       level: INFO
+       handlers:
+       - console
+       - file
+     disable_existing_loggers: false
+   env: {}
+   mode: RUN
+   searchpath: []
+   callbacks: {}
+   output_subdir: .hydra
+   overrides:
+     hydra:
+     - hydra.mode=RUN
+     task:
+     - accelerator.mixed_precision=BF16
+     - accelerator.log_with=null
+     - accelerator=debug
+     - criterion.is_distributed=false
+     - dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
+   job:
+     name: train
+     chdir: null
+     override_dirname: accelerator.log_with=null,accelerator.mixed_precision=BF16,accelerator=debug,criterion.is_distributed=false,dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
+     id: ???
+     num: ???
+     config_name: step_flux_base
+     env_set: {}
+     env_copy: []
+     config:
+       override_dirname:
+         kv_sep: '='
+         item_sep: ','
+         exclude_keys: []
+   runtime:
+     version: 1.3.2
+     version_base: '1.3'
+     cwd: /g/data/rr81/LPO/lrm/flux
+     config_sources:
+     - path: hydra.conf
+       schema: pkg
+       provider: hydra
+     - path: /g/data/rr81/LPO/lrm/flux/trainer/conf
+       schema: file
+       provider: main
+     - path: ''
+       schema: structured
+       provider: schema
+     output_dir: /g/data/rr81/LPO/lrm/flux
+     choices:
+       lr_scheduler: dummy
+       optimizer: dummy
+       dataset: step_flux
+       criterion: step_clip_flux
+       model: step_flux_base
+       task: step_flux
+       accelerator: debug
+       hydra/env: default
+       hydra/callbacks: null
+       hydra/job_logging: default
+       hydra/hydra_logging: default
+       hydra/hydra_help: default
+       hydra/help: default
+       hydra/sweeper: basic
+       hydra/launcher: basic
+       hydra/output: default
+   verbose: false
lrm/flux/.hydra/overrides.yaml ADDED
@@ -0,0 +1,5 @@
+ - accelerator.mixed_precision=BF16
+ - accelerator.log_with=null
+ - accelerator=debug
+ - criterion.is_distributed=false
+ - dataset.pseudo_preference_path=/g/data/rr81/LPO/lrm/flux/vqa_aes_clip_score_mp.csv
lrm/flux/README.md ADDED
@@ -0,0 +1,4 @@
+ Please check the train_flux.sh file.
+ On line 21 there is a variable named "RUN_PROFILE" which can be set to "main" or "quick". If set to "main", it will run the full training loop. If set to "quick", it will run a minimal training loop for testing purposes.
+
+ I have set the shared (non-distributed) setup for quick testing mode, but the distributed setup is used for main mode.
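A minimal sketch of the switch described above, assuming RUN_PROFILE is assigned directly in train_flux.sh; the variable name and its two values come from this README, while the surrounding lines are hypothetical:

# In train_flux.sh, around line 21 (hypothetical placement):
RUN_PROFILE="quick"   # minimal loop for smoke testing
# RUN_PROFILE="main"  # full training run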