Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +10 -0
- LICENSE +21 -0
- README.md +228 -0
- assets/cat_2x.gif +3 -0
- assets/clear2rainy_results.jpg +3 -0
- assets/day2night_results.jpg +3 -0
- assets/edge_to_image_results.jpg +3 -0
- assets/examples/bird.png +3 -0
- assets/examples/bird_canny.png +0 -0
- assets/examples/bird_canny_blue.png +0 -0
- assets/examples/circles_inference_input.png +0 -0
- assets/examples/circles_inference_output.png +0 -0
- assets/examples/clear2rainy_input.png +0 -0
- assets/examples/clear2rainy_output.png +0 -0
- assets/examples/day2night_input.png +0 -0
- assets/examples/day2night_output.png +0 -0
- assets/examples/my_horse2zebra_input.jpg +0 -0
- assets/examples/my_horse2zebra_output.jpg +0 -0
- assets/examples/night2day_input.png +0 -0
- assets/examples/night2day_output.png +0 -0
- assets/examples/rainy2clear_input.png +0 -0
- assets/examples/rainy2clear_output.png +0 -0
- assets/examples/sketch_input.png +0 -0
- assets/examples/sketch_output.png +0 -0
- assets/examples/training_evaluation.png +0 -0
- assets/examples/training_evaluation_unpaired.png +0 -0
- assets/examples/training_step_0.png +0 -0
- assets/examples/training_step_500.png +0 -0
- assets/examples/training_step_6000.png +0 -0
- assets/fish_2x.gif +3 -0
- assets/gen_variations.jpg +3 -0
- assets/method.jpg +0 -0
- assets/night2day_results.jpg +3 -0
- assets/rainy2clear.jpg +3 -0
- assets/teaser_results.jpg +3 -0
- docs/training_cyclegan_turbo.md +98 -0
- docs/training_pix2pix_turbo.md +118 -0
- environment.yaml +34 -0
- gradio_canny2image.py +78 -0
- gradio_sketch2image.py +382 -0
- requirements.txt +29 -0
- scripts/download_fill50k.sh +5 -0
- scripts/download_horse2zebra.sh +5 -0
- src/cyclegan_turbo.py +254 -0
- src/image_prep.py +12 -0
- src/inference_paired.py +75 -0
- src/inference_unpaired.py +58 -0
- src/model.py +73 -0
- src/my_utils/dino_struct.py +185 -0
- src/my_utils/training_utils.py +409 -0
.gitattributes
CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/cat_2x.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/clear2rainy_results.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/day2night_results.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/edge_to_image_results.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/examples/bird.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/fish_2x.gif filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/gen_variations.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/night2day_results.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/rainy2clear.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/teaser_results.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 img-to-img-turbo
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# img2img-turbo
|
2 |
+
|
3 |
+
[**Paper**](https://arxiv.org/abs/2403.12036) | [**Sketch2Image Demo**](https://huggingface.co/spaces/gparmar/img2img-turbo-sketch)
|
4 |
+
#### **Quick start:** [**Running Locally**](#getting-started) | [**Gradio (locally hosted)**](#gradio-demo) | [**Training**](#training-with-your-own-data)
|
5 |
+
|
6 |
+
### Cat Sketching
|
7 |
+
<p align="left" >
|
8 |
+
<img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/cat_2x.gif" width="800" />
|
9 |
+
</p>
|
10 |
+
|
11 |
+
### Fish Sketching
|
12 |
+
<p align="left">
|
13 |
+
<img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/fish_2x.gif" width="800" />
|
14 |
+
</p>
|
15 |
+
|
16 |
+
|
17 |
+
We propose a general method for adapting a single-step diffusion model, such as SD-Turbo, to new tasks and domains through adversarial learning. This enables us to leverage the internal knowledge of pre-trained diffusion models while achieving efficient inference (e.g., for 512x512 images, 0.29 seconds on A6000 and 0.11 seconds on A100).
|
18 |
+
|
19 |
+
Our one-step conditional models **CycleGAN-Turbo** and **pix2pix-turbo** can perform various image-to-image translation tasks for both unpaired and paired settings. CycleGAN-Turbo outperforms existing GAN-based and diffusion-based methods, while pix2pix-turbo is on par with recent works such as ControlNet for Sketch2Photo and Edge2Image, but with one-step inference.
|
20 |
+
|
21 |
+
[One-Step Image Translation with Text-to-Image Models](https://arxiv.org/abs/2403.12036)<br>
|
22 |
+
[Gaurav Parmar](https://gauravparmar.com/), [Taesung Park](https://taesung.me/), [Srinivasa Narasimhan](https://www.cs.cmu.edu/~srinivas/), [Jun-Yan Zhu](https://github.com/junyanz/)<br>
|
23 |
+
CMU and Adobe, arXiv 2403.12036
|
24 |
+
|
25 |
+
<br>
|
26 |
+
<div>
|
27 |
+
<p align="center">
|
28 |
+
<img src='assets/teaser_results.jpg' align="center" width=1000px>
|
29 |
+
</p>
|
30 |
+
</div>
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
## Results
|
36 |
+
|
37 |
+
### Paired Translation with pix2pix-turbo
|
38 |
+
**Edge to Image**
|
39 |
+
<div>
|
40 |
+
<p align="center">
|
41 |
+
<img src='assets/edge_to_image_results.jpg' align="center" width=800px>
|
42 |
+
</p>
|
43 |
+
</div>
|
44 |
+
|
45 |
+
<!-- **Sketch to Image**
|
46 |
+
TODO -->
|
47 |
+
### Generating Diverse Outputs
|
48 |
+
By varying the input noise map, our method can generate diverse outputs from the same input conditioning.
|
49 |
+
The output style can be controlled by changing the text prompt.
|
50 |
+
<div> <p align="center">
|
51 |
+
<img src='assets/gen_variations.jpg' align="center" width=800px>
|
52 |
+
</p> </div>
|
53 |
+
|
54 |
+
### Unpaired Translation with CycleGAN-Turbo
|
55 |
+
|
56 |
+
**Day to Night**
|
57 |
+
<div> <p align="center">
|
58 |
+
<img src='assets/day2night_results.jpg' align="center" width=800px>
|
59 |
+
</p> </div>
|
60 |
+
|
61 |
+
**Night to Day**
|
62 |
+
<div><p align="center">
|
63 |
+
<img src='assets/night2day_results.jpg' align="center" width=800px>
|
64 |
+
</p> </div>
|
65 |
+
|
66 |
+
**Clear to Rainy**
|
67 |
+
<div>
|
68 |
+
<p align="center">
|
69 |
+
<img src='assets/clear2rainy_results.jpg' align="center" width=800px>
|
70 |
+
</p>
|
71 |
+
</div>
|
72 |
+
|
73 |
+
**Rainy to Clear**
|
74 |
+
<div>
|
75 |
+
<p align="center">
|
76 |
+
<img src='assets/rainy2clear.jpg' align="center" width=800px>
|
77 |
+
</p>
|
78 |
+
</div>
|
79 |
+
<hr>
|
80 |
+
|
81 |
+
|
82 |
+
## Method
|
83 |
+
**Our Generator Architecture:**
|
84 |
+
We tightly integrate three separate modules in the original latent diffusion models into a single end-to-end network with small trainable weights. This architecture allows us to translate the input image x to the output y, while retaining the input scene structure. We use LoRA adapters in each module, introduce skip connections and Zero-Convs between input and output, and retrain the first layer of the U-Net. Blue boxes indicate trainable layers. Semi-transparent layers are frozen. The same generator can be used for various GAN objectives.
|
85 |
+
<div>
|
86 |
+
<p align="center">
|
87 |
+
<img src='assets/method.jpg' align="center" width=900px>
|
88 |
+
</p>
|
89 |
+
</div>
|
90 |
+
|
91 |
+
|
92 |
+
## Getting Started
|
93 |
+
**Environment Setup**
|
94 |
+
- We provide a [conda env file](environment.yml) that contains all the required dependencies.
|
95 |
+
```
|
96 |
+
conda env create -f environment.yaml
|
97 |
+
```
|
98 |
+
- Following this, you can activate the conda environment with the command below.
|
99 |
+
```
|
100 |
+
conda activate img2img-turbo
|
101 |
+
```
|
102 |
+
- Or use virtual environment:
|
103 |
+
```
|
104 |
+
python3 -m venv venv
|
105 |
+
source venv/bin/activate
|
106 |
+
pip install -r requirements.txt
|
107 |
+
```
|
108 |
+
**Paired Image Translation (pix2pix-turbo)**
|
109 |
+
- The following command takes an image file and a prompt as inputs, extracts the canny edges, and saves the results in the directory specified.
|
110 |
+
```bash
|
111 |
+
python src/inference_paired.py --model_name "edge_to_image" \
|
112 |
+
--input_image "assets/examples/bird.png" \
|
113 |
+
--prompt "a blue bird" \
|
114 |
+
--output_dir "outputs"
|
115 |
+
```
|
116 |
+
<table>
|
117 |
+
<th>Input Image</th>
|
118 |
+
<th>Canny Edges</th>
|
119 |
+
<th>Model Output</th>
|
120 |
+
</tr>
|
121 |
+
<tr>
|
122 |
+
<td><img src='assets/examples/bird.png' width="200px"></td>
|
123 |
+
<td><img src='assets/examples/bird_canny.png' width="200px"></td>
|
124 |
+
<td><img src='assets/examples/bird_canny_blue.png' width="200px"></td>
|
125 |
+
</tr>
|
126 |
+
</table>
|
127 |
+
<br>
|
128 |
+
|
129 |
+
- The following command takes a sketch and a prompt as inputs, and saves the results in the directory specified.
|
130 |
+
```bash
|
131 |
+
python src/inference_paired.py --model_name "sketch_to_image_stochastic" \
|
132 |
+
--input_image "assets/examples/sketch_input.png" --gamma 0.4 \
|
133 |
+
--prompt "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy" \
|
134 |
+
--output_dir "outputs"
|
135 |
+
```
|
136 |
+
<table>
|
137 |
+
<th>Input</th>
|
138 |
+
<th>Model Output</th>
|
139 |
+
</tr>
|
140 |
+
<tr>
|
141 |
+
<td><img src='assets/examples/sketch_input.png' width="400px"></td>
|
142 |
+
<td><img src='assets/examples/sketch_output.png' width="400px"></td>
|
143 |
+
</tr>
|
144 |
+
</table>
|
145 |
+
<br>
|
146 |
+
|
147 |
+
**Unpaired Image Translation (CycleGAN-Turbo)**
|
148 |
+
- The following command takes a **day** image file as input, and saves the output **night** in the directory specified.
|
149 |
+
```
|
150 |
+
python src/inference_unpaired.py --model_name "day_to_night" \
|
151 |
+
--input_image "assets/examples/day2night_input.png" --output_dir "outputs"
|
152 |
+
```
|
153 |
+
<table>
|
154 |
+
<th>Input (day)</th>
|
155 |
+
<th>Model Output (night)</th>
|
156 |
+
</tr>
|
157 |
+
<tr>
|
158 |
+
<td><img src='assets/examples/day2night_input.png' width="400px"></td>
|
159 |
+
<td><img src='assets/examples/day2night_output.png' width="400px"></td>
|
160 |
+
</tr>
|
161 |
+
</table>
|
162 |
+
|
163 |
+
- The following command takes a **night** image file as input, and saves the output **day** in the directory specified.
|
164 |
+
```
|
165 |
+
python src/inference_unpaired.py --model_name "night_to_day" \
|
166 |
+
--input_image "assets/examples/night2day_input.png" --output_dir "outputs"
|
167 |
+
```
|
168 |
+
<table>
|
169 |
+
<th>Input (night)</th>
|
170 |
+
<th>Model Output (day)</th>
|
171 |
+
</tr>
|
172 |
+
<tr>
|
173 |
+
<td><img src='assets/examples/night2day_input.png' width="400px"></td>
|
174 |
+
<td><img src='assets/examples/night2day_output.png' width="400px"></td>
|
175 |
+
</tr>
|
176 |
+
</table>
|
177 |
+
|
178 |
+
- The following command takes a **clear** image file as input, and saves the output **rainy** in the directory specified.
|
179 |
+
```
|
180 |
+
python src/inference_unpaired.py --model_name "clear_to_rainy" \
|
181 |
+
--input_image "assets/examples/clear2rainy_input.png" --output_dir "outputs"
|
182 |
+
```
|
183 |
+
<table>
|
184 |
+
<th>Input (clear)</th>
|
185 |
+
<th>Model Output (rainy)</th>
|
186 |
+
</tr>
|
187 |
+
<tr>
|
188 |
+
<td><img src='assets/examples/clear2rainy_input.png' width="400px"></td>
|
189 |
+
<td><img src='assets/examples/clear2rainy_output.png' width="400px"></td>
|
190 |
+
</tr>
|
191 |
+
</table>
|
192 |
+
|
193 |
+
- The following command takes a **rainy** image file as input, and saves the output **clear** in the directory specified.
|
194 |
+
```
|
195 |
+
python src/inference_unpaired.py --model_name "rainy_to_clear" \
|
196 |
+
--input_image "assets/examples/rainy2clear_input.png" --output_dir "outputs"
|
197 |
+
```
|
198 |
+
<table>
|
199 |
+
<th>Input (rainy)</th>
|
200 |
+
<th>Model Output (clear)</th>
|
201 |
+
</tr>
|
202 |
+
<tr>
|
203 |
+
<td><img src='assets/examples/rainy2clear_input.png' width="400px"></td>
|
204 |
+
<td><img src='assets/examples/rainy2clear_output.png' width="400px"></td>
|
205 |
+
</tr>
|
206 |
+
</table>
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
## Gradio Demo
|
211 |
+
- We provide a Gradio demo for the paired image translation tasks.
|
212 |
+
- The following command will launch the sketch to image locally using gradio.
|
213 |
+
```
|
214 |
+
gradio gradio_sketch2image.py
|
215 |
+
```
|
216 |
+
- The following command will launch the canny edge to image gradio demo locally.
|
217 |
+
```
|
218 |
+
gradio gradio_canny2image.py
|
219 |
+
```
|
220 |
+
|
221 |
+
|
222 |
+
## Training with your own data
|
223 |
+
- See the steps [here](docs/training_pix2pix_turbo.md) for training a pix2pix-turbo model on your paired data.
|
224 |
+
- See the steps [here](docs/training_cyclegan_turbo.md) for training a CycleGAN-Turbo model on your unpaired data.
|
225 |
+
|
226 |
+
|
227 |
+
## Acknowledgment
|
228 |
+
Our work uses the Stable Diffusion-Turbo as the base model with the following [LICENSE](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE).
|
assets/cat_2x.gif
ADDED
Git LFS Details
|
assets/clear2rainy_results.jpg
ADDED
Git LFS Details
|
assets/day2night_results.jpg
ADDED
Git LFS Details
|
assets/edge_to_image_results.jpg
ADDED
Git LFS Details
|
assets/examples/bird.png
ADDED
Git LFS Details
|
assets/examples/bird_canny.png
ADDED
assets/examples/bird_canny_blue.png
ADDED
assets/examples/circles_inference_input.png
ADDED
assets/examples/circles_inference_output.png
ADDED
assets/examples/clear2rainy_input.png
ADDED
assets/examples/clear2rainy_output.png
ADDED
assets/examples/day2night_input.png
ADDED
assets/examples/day2night_output.png
ADDED
assets/examples/my_horse2zebra_input.jpg
ADDED
assets/examples/my_horse2zebra_output.jpg
ADDED
assets/examples/night2day_input.png
ADDED
assets/examples/night2day_output.png
ADDED
assets/examples/rainy2clear_input.png
ADDED
assets/examples/rainy2clear_output.png
ADDED
assets/examples/sketch_input.png
ADDED
assets/examples/sketch_output.png
ADDED
assets/examples/training_evaluation.png
ADDED
assets/examples/training_evaluation_unpaired.png
ADDED
assets/examples/training_step_0.png
ADDED
assets/examples/training_step_500.png
ADDED
assets/examples/training_step_6000.png
ADDED
assets/fish_2x.gif
ADDED
Git LFS Details
|
assets/gen_variations.jpg
ADDED
Git LFS Details
|
assets/method.jpg
ADDED
assets/night2day_results.jpg
ADDED
Git LFS Details
|
assets/rainy2clear.jpg
ADDED
Git LFS Details
|
assets/teaser_results.jpg
ADDED
Git LFS Details
|
docs/training_cyclegan_turbo.md
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Training with Unpaired Data (CycleGAN-turbo)
|
2 |
+
Here, we show how to train a CycleGAN-turbo model using unpaired data.
|
3 |
+
We will use the [horse2zebra dataset](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md) introduced by [CycleGAN](https://junyanz.github.io/CycleGAN/) as an example dataset.
|
4 |
+
|
5 |
+
|
6 |
+
### Step 1. Get the Dataset
|
7 |
+
- First download the horse2zebra dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip) using the command below.
|
8 |
+
```
|
9 |
+
bash scripts/download_horse2zebra.sh
|
10 |
+
```
|
11 |
+
|
12 |
+
- Our training scripts expect the dataset to be in the following format:
|
13 |
+
```
|
14 |
+
data
|
15 |
+
├── dataset_name
|
16 |
+
│ ├── train_A
|
17 |
+
│ │ ├── 000000.png
|
18 |
+
│ │ ├── 000001.png
|
19 |
+
│ │ └── ...
|
20 |
+
│ ├── train_B
|
21 |
+
│ │ ├── 000000.png
|
22 |
+
│ │ ├── 000001.png
|
23 |
+
│ │ └── ...
|
24 |
+
│ └── fixed_prompt_a.txt
|
25 |
+
| └── fixed_prompt_b.txt
|
26 |
+
|
|
27 |
+
| ├── test_A
|
28 |
+
│ │ ├── 000000.png
|
29 |
+
│ │ ├── 000001.png
|
30 |
+
│ │ └── ...
|
31 |
+
│ ├── test_B
|
32 |
+
│ │ ├── 000000.png
|
33 |
+
│ │ ├── 000001.png
|
34 |
+
│ │ └── ...
|
35 |
+
```
|
36 |
+
- The `fixed_prompt_a.txt` and `fixed_prompt_b.txt` files contain the **fixed caption** used for the source and target domains respectively.
|
37 |
+
|
38 |
+
|
39 |
+
### Step 2. Train the Model
|
40 |
+
- Initialize the `accelerate` environment with the following command:
|
41 |
+
```
|
42 |
+
accelerate config
|
43 |
+
```
|
44 |
+
|
45 |
+
- Run the following command to train the model.
|
46 |
+
```
|
47 |
+
export NCCL_P2P_DISABLE=1
|
48 |
+
accelerate launch --main_process_port 29501 src/train_cyclegan_turbo.py \
|
49 |
+
--pretrained_model_name_or_path="stabilityai/sd-turbo" \
|
50 |
+
--output_dir="output/cyclegan_turbo/my_horse2zebra" \
|
51 |
+
--dataset_folder "data/my_horse2zebra" \
|
52 |
+
--train_img_prep "resize_286_randomcrop_256x256_hflip" --val_img_prep "no_resize" \
|
53 |
+
--learning_rate="1e-5" --max_train_steps=25000 \
|
54 |
+
--train_batch_size=1 --gradient_accumulation_steps=1 \
|
55 |
+
--report_to "wandb" --tracker_project_name "gparmar_unpaired_h2z_cycle_debug_v2" \
|
56 |
+
--enable_xformers_memory_efficient_attention --validation_steps 250 \
|
57 |
+
--lambda_gan 0.5 --lambda_idt 1 --lambda_cycle 1
|
58 |
+
```
|
59 |
+
|
60 |
+
- Additional optional flags:
|
61 |
+
- `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
|
62 |
+
|
63 |
+
### Step 3. Monitor the training progress
|
64 |
+
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
|
65 |
+
|
66 |
+
- The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
|
67 |
+
<div>
|
68 |
+
<p align="center">
|
69 |
+
<img src='../assets/examples/training_evaluation_unpaired.png' align="center" width=800px>
|
70 |
+
</p>
|
71 |
+
</div>
|
72 |
+
|
73 |
+
|
74 |
+
- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
|
75 |
+
|
76 |
+
|
77 |
+
### Step 4. Running Inference with the trained models
|
78 |
+
|
79 |
+
- You can run inference using the trained model using the following command:
|
80 |
+
```
|
81 |
+
python src/inference_unpaired.py --model_path "output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl" \
|
82 |
+
--input_image "data/my_horse2zebra/test_A/n02381460_20.jpg" \
|
83 |
+
--prompt "picture of a zebra" --direction "a2b" \
|
84 |
+
--output_dir "outputs" --image_prep "no_resize"
|
85 |
+
```
|
86 |
+
|
87 |
+
- The above command should generate the following output:
|
88 |
+
<table>
|
89 |
+
<tr>
|
90 |
+
<th>Model Input</th>
|
91 |
+
<th>Model Output</th>
|
92 |
+
</tr>
|
93 |
+
<tr>
|
94 |
+
<td><img src='../assets/examples/my_horse2zebra_input.jpg' width="200px"></td>
|
95 |
+
<td><img src='../assets/examples/my_horse2zebra_output.jpg' width="200px"></td>
|
96 |
+
</tr>
|
97 |
+
</table>
|
98 |
+
|
docs/training_pix2pix_turbo.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Training with Paired Data (pix2pix-turbo)
|
2 |
+
Here, we show how to train a pix2pix-turbo model using paired data.
|
3 |
+
We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.
|
4 |
+
|
5 |
+
|
6 |
+
### Step 1. Get the Dataset
|
7 |
+
- First download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
|
8 |
+
```
|
9 |
+
bash scripts/download_fill50k.sh
|
10 |
+
```
|
11 |
+
|
12 |
+
- Our training scripts expect the dataset to be in the following format:
|
13 |
+
```
|
14 |
+
data
|
15 |
+
├── dataset_name
|
16 |
+
│ ├── train_A
|
17 |
+
│ │ ├── 000000.png
|
18 |
+
│ │ ├── 000001.png
|
19 |
+
│ │ └── ...
|
20 |
+
│ ├── train_B
|
21 |
+
│ │ ├── 000000.png
|
22 |
+
│ │ ├── 000001.png
|
23 |
+
│ │ └── ...
|
24 |
+
│ └── train_prompts.json
|
25 |
+
|
|
26 |
+
| ├── test_A
|
27 |
+
│ │ ├── 000000.png
|
28 |
+
│ │ ├── 000001.png
|
29 |
+
│ │ └── ...
|
30 |
+
│ ├── test_B
|
31 |
+
│ │ ├── 000000.png
|
32 |
+
│ │ ├── 000001.png
|
33 |
+
│ │ └── ...
|
34 |
+
│ └── test_prompts.json
|
35 |
+
```
|
36 |
+
|
37 |
+
|
38 |
+
### Step 2. Train the Model
|
39 |
+
- Initialize the `accelerate` environment with the following command:
|
40 |
+
```
|
41 |
+
accelerate config
|
42 |
+
```
|
43 |
+
|
44 |
+
- Run the following command to train the model.
|
45 |
+
```
|
46 |
+
accelerate launch src/train_pix2pix_turbo.py \
|
47 |
+
--pretrained_model_name_or_path="stabilityai/sd-turbo" \
|
48 |
+
--output_dir="output/pix2pix_turbo/fill50k" \
|
49 |
+
--dataset_folder="data/my_fill50k" \
|
50 |
+
--resolution=512 \
|
51 |
+
--train_batch_size=2 \
|
52 |
+
--enable_xformers_memory_efficient_attention --viz_freq 25 \
|
53 |
+
--track_val_fid \
|
54 |
+
--report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
|
55 |
+
```
|
56 |
+
|
57 |
+
- Additional optional flags:
|
58 |
+
- `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation.
|
59 |
+
- `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
|
60 |
+
- `--viz_freq`: Frequency of visualizing the results during training.
|
61 |
+
|
62 |
+
### Step 3. Monitor the training progress
|
63 |
+
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
|
64 |
+
|
65 |
+
- The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
|
66 |
+
<div>
|
67 |
+
<p align="center">
|
68 |
+
<img src='../assets/examples/training_evaluation.png' align="center" width=800px>
|
69 |
+
</p>
|
70 |
+
</div>
|
71 |
+
|
72 |
+
|
73 |
+
- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
|
74 |
+
|
75 |
+
- Screenshots of the training progress are shown below:
|
76 |
+
- Step 0:
|
77 |
+
<div>
|
78 |
+
<p align="center">
|
79 |
+
<img src='../assets/examples/training_step_0.png' align="center" width=800px>
|
80 |
+
</p>
|
81 |
+
</div>
|
82 |
+
|
83 |
+
- Step 500:
|
84 |
+
<div>
|
85 |
+
<p align="center">
|
86 |
+
<img src='../assets/examples/training_step_500.png' align="center" width=800px>
|
87 |
+
</p>
|
88 |
+
</div>
|
89 |
+
|
90 |
+
- Step 6000:
|
91 |
+
<div>
|
92 |
+
<p align="center">
|
93 |
+
<img src='../assets/examples/training_step_6000.png' align="center" width=800px>
|
94 |
+
</p>
|
95 |
+
</div>
|
96 |
+
|
97 |
+
|
98 |
+
### Step 4. Running Inference with the trained models
|
99 |
+
|
100 |
+
- You can run inference using the trained model using the following command:
|
101 |
+
```
|
102 |
+
python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
|
103 |
+
--input_image "data/my_fill50k/test_A/40000.png" \
|
104 |
+
--prompt "violet circle with orange background" \
|
105 |
+
--output_dir "outputs"
|
106 |
+
```
|
107 |
+
|
108 |
+
- The above command should generate the following output:
|
109 |
+
<table>
|
110 |
+
<tr>
|
111 |
+
<th>Model Input</th>
|
112 |
+
<th>Model Output</th>
|
113 |
+
</tr>
|
114 |
+
<tr>
|
115 |
+
<td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
|
116 |
+
<td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
|
117 |
+
</tr>
|
118 |
+
</table>
|
environment.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: img2img-turbo
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.10
|
7 |
+
- pip:
|
8 |
+
- clip @ git+https://github.com/openai/CLIP.git
|
9 |
+
- einops>=0.6.1
|
10 |
+
- numpy>=1.24.4
|
11 |
+
- open-clip-torch>=2.20.0
|
12 |
+
- opencv-python==4.6.0.66
|
13 |
+
- pillow>=9.5.0
|
14 |
+
- scipy==1.11.1
|
15 |
+
- timm>=0.9.2
|
16 |
+
- tokenizers
|
17 |
+
- torch>=2.0.1
|
18 |
+
|
19 |
+
- torchaudio>=2.0.2
|
20 |
+
- torchdata==0.6.1
|
21 |
+
- torchmetrics>=1.0.1
|
22 |
+
- torchvision>=0.15.2
|
23 |
+
|
24 |
+
- tqdm>=4.65.0
|
25 |
+
- transformers==4.35.2
|
26 |
+
- urllib3<1.27,>=1.25.4
|
27 |
+
- xformers>=0.0.20
|
28 |
+
- streamlit-keyup==0.2.0
|
29 |
+
- lpips
|
30 |
+
- clean-fid
|
31 |
+
- peft
|
32 |
+
- dominate
|
33 |
+
- diffusers==0.25.1
|
34 |
+
- gradio==3.43.1
|
gradio_canny2image.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
import gradio as gr
|
6 |
+
from src.image_prep import canny_from_pil
|
7 |
+
from src.pix2pix_turbo import Pix2Pix_Turbo
|
8 |
+
|
9 |
+
model = Pix2Pix_Turbo("edge_to_image")
|
10 |
+
|
11 |
+
|
12 |
+
def process(input_image, prompt, low_threshold, high_threshold):
|
13 |
+
# resize to be a multiple of 8
|
14 |
+
new_width = input_image.width - input_image.width % 8
|
15 |
+
new_height = input_image.height - input_image.height % 8
|
16 |
+
input_image = input_image.resize((new_width, new_height))
|
17 |
+
canny = canny_from_pil(input_image, low_threshold, high_threshold)
|
18 |
+
with torch.no_grad():
|
19 |
+
c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
|
20 |
+
output_image = model(c_t, prompt)
|
21 |
+
output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
|
22 |
+
# flippy canny values, map all 0s to 1s and 1s to 0s
|
23 |
+
canny_viz = 1 - (np.array(canny) / 255)
|
24 |
+
canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
|
25 |
+
return canny_viz, output_pil
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
# load the model
|
30 |
+
with gr.Blocks() as demo:
|
31 |
+
gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
|
32 |
+
with gr.Row():
|
33 |
+
with gr.Column():
|
34 |
+
input_image = gr.Image(sources="upload", type="pil")
|
35 |
+
prompt = gr.Textbox(label="Prompt")
|
36 |
+
low_threshold = gr.Slider(
|
37 |
+
label="Canny low threshold",
|
38 |
+
minimum=1,
|
39 |
+
maximum=255,
|
40 |
+
value=100,
|
41 |
+
step=10,
|
42 |
+
)
|
43 |
+
high_threshold = gr.Slider(
|
44 |
+
label="Canny high threshold",
|
45 |
+
minimum=1,
|
46 |
+
maximum=255,
|
47 |
+
value=200,
|
48 |
+
step=10,
|
49 |
+
)
|
50 |
+
run_button = gr.Button(value="Run")
|
51 |
+
with gr.Column():
|
52 |
+
result_canny = gr.Image(type="pil")
|
53 |
+
with gr.Column():
|
54 |
+
result_output = gr.Image(type="pil")
|
55 |
+
|
56 |
+
prompt.submit(
|
57 |
+
fn=process,
|
58 |
+
inputs=[input_image, prompt, low_threshold, high_threshold],
|
59 |
+
outputs=[result_canny, result_output],
|
60 |
+
)
|
61 |
+
low_threshold.change(
|
62 |
+
fn=process,
|
63 |
+
inputs=[input_image, prompt, low_threshold, high_threshold],
|
64 |
+
outputs=[result_canny, result_output],
|
65 |
+
)
|
66 |
+
high_threshold.change(
|
67 |
+
fn=process,
|
68 |
+
inputs=[input_image, prompt, low_threshold, high_threshold],
|
69 |
+
outputs=[result_canny, result_output],
|
70 |
+
)
|
71 |
+
run_button.click(
|
72 |
+
fn=process,
|
73 |
+
inputs=[input_image, prompt, low_threshold, high_threshold],
|
74 |
+
outputs=[result_canny, result_output],
|
75 |
+
)
|
76 |
+
|
77 |
+
demo.queue()
|
78 |
+
demo.launch(debug=True, share=False)
|
gradio_sketch2image.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
from src.pix2pix_turbo import Pix2Pix_Turbo
|
12 |
+
|
13 |
+
model = Pix2Pix_Turbo("sketch_to_image_stochastic")
|
14 |
+
|
15 |
+
style_list = [
|
16 |
+
{
|
17 |
+
"name": "Cinematic",
|
18 |
+
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"name": "3D Model",
|
22 |
+
"prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"name": "Anime",
|
26 |
+
"prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"name": "Digital Art",
|
30 |
+
"prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"name": "Photographic",
|
34 |
+
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"name": "Pixel art",
|
38 |
+
"prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"name": "Fantasy art",
|
42 |
+
"prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"name": "Neonpunk",
|
46 |
+
"prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"name": "Manga",
|
50 |
+
"prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
|
51 |
+
},
|
52 |
+
]
|
53 |
+
|
54 |
+
styles = {k["name"]: k["prompt"] for k in style_list}
|
55 |
+
STYLE_NAMES = list(styles.keys())
|
56 |
+
DEFAULT_STYLE_NAME = "Fantasy art"
|
57 |
+
MAX_SEED = np.iinfo(np.int32).max
|
58 |
+
|
59 |
+
|
60 |
+
def pil_image_to_data_uri(img, format="PNG"):
|
61 |
+
buffered = BytesIO()
|
62 |
+
img.save(buffered, format=format)
|
63 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
64 |
+
return f"data:image/{format.lower()};base64,{img_str}"
|
65 |
+
|
66 |
+
|
67 |
+
def run(image, prompt, prompt_template, style_name, seed, val_r):
|
68 |
+
print(f"prompt: {prompt}")
|
69 |
+
print("sketch updated")
|
70 |
+
if image is None:
|
71 |
+
ones = Image.new("L", (512, 512), 255)
|
72 |
+
temp_uri = pil_image_to_data_uri(ones)
|
73 |
+
return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
|
74 |
+
prompt = prompt_template.replace("{prompt}", prompt)
|
75 |
+
image = image.convert("RGB")
|
76 |
+
image_t = F.to_tensor(image) > 0.5
|
77 |
+
print(f"r_val={val_r}, seed={seed}")
|
78 |
+
with torch.no_grad():
|
79 |
+
c_t = image_t.unsqueeze(0).cuda().float()
|
80 |
+
torch.manual_seed(seed)
|
81 |
+
B, C, H, W = c_t.shape
|
82 |
+
noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
|
83 |
+
output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
|
84 |
+
output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
|
85 |
+
input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
|
86 |
+
output_image_uri = pil_image_to_data_uri(output_pil)
|
87 |
+
return (
|
88 |
+
output_pil,
|
89 |
+
gr.update(link=input_sketch_uri),
|
90 |
+
gr.update(link=output_image_uri),
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
def update_canvas(use_line, use_eraser):
|
95 |
+
if use_eraser:
|
96 |
+
_color = "#ffffff"
|
97 |
+
brush_size = 20
|
98 |
+
if use_line:
|
99 |
+
_color = "#000000"
|
100 |
+
brush_size = 4
|
101 |
+
return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
|
102 |
+
|
103 |
+
|
104 |
+
def upload_sketch(file):
|
105 |
+
_img = Image.open(file.name)
|
106 |
+
_img = _img.convert("L")
|
107 |
+
return gr.update(value=_img, source="upload", interactive=True)
|
108 |
+
|
109 |
+
|
110 |
+
scripts = """
|
111 |
+
async () => {
|
112 |
+
globalThis.theSketchDownloadFunction = () => {
|
113 |
+
console.log("test")
|
114 |
+
var link = document.createElement("a");
|
115 |
+
dataUri = document.getElementById('download_sketch').href
|
116 |
+
link.setAttribute("href", dataUri)
|
117 |
+
link.setAttribute("download", "sketch.png")
|
118 |
+
document.body.appendChild(link); // Required for Firefox
|
119 |
+
link.click();
|
120 |
+
document.body.removeChild(link); // Clean up
|
121 |
+
|
122 |
+
// also call the output download function
|
123 |
+
theOutputDownloadFunction();
|
124 |
+
return false
|
125 |
+
}
|
126 |
+
|
127 |
+
globalThis.theOutputDownloadFunction = () => {
|
128 |
+
console.log("test output download function")
|
129 |
+
var link = document.createElement("a");
|
130 |
+
dataUri = document.getElementById('download_output').href
|
131 |
+
link.setAttribute("href", dataUri);
|
132 |
+
link.setAttribute("download", "output.png");
|
133 |
+
document.body.appendChild(link); // Required for Firefox
|
134 |
+
link.click();
|
135 |
+
document.body.removeChild(link); // Clean up
|
136 |
+
return false
|
137 |
+
}
|
138 |
+
|
139 |
+
globalThis.UNDO_SKETCH_FUNCTION = () => {
|
140 |
+
console.log("undo sketch function")
|
141 |
+
var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
|
142 |
+
// Create a new 'click' event
|
143 |
+
var event = new MouseEvent('click', {
|
144 |
+
'view': window,
|
145 |
+
'bubbles': true,
|
146 |
+
'cancelable': true
|
147 |
+
});
|
148 |
+
button_undo.dispatchEvent(event);
|
149 |
+
}
|
150 |
+
|
151 |
+
globalThis.DELETE_SKETCH_FUNCTION = () => {
|
152 |
+
console.log("delete sketch function")
|
153 |
+
var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
|
154 |
+
// Create a new 'click' event
|
155 |
+
var event = new MouseEvent('click', {
|
156 |
+
'view': window,
|
157 |
+
'bubbles': true,
|
158 |
+
'cancelable': true
|
159 |
+
});
|
160 |
+
button_del.dispatchEvent(event);
|
161 |
+
}
|
162 |
+
|
163 |
+
globalThis.togglePencil = () => {
|
164 |
+
el_pencil = document.getElementById('my-toggle-pencil');
|
165 |
+
el_pencil.classList.toggle('clicked');
|
166 |
+
// simulate a click on the gradio button
|
167 |
+
btn_gradio = document.querySelector("#cb-line > label > input");
|
168 |
+
var event = new MouseEvent('click', {
|
169 |
+
'view': window,
|
170 |
+
'bubbles': true,
|
171 |
+
'cancelable': true
|
172 |
+
});
|
173 |
+
btn_gradio.dispatchEvent(event);
|
174 |
+
if (el_pencil.classList.contains('clicked')) {
|
175 |
+
document.getElementById('my-toggle-eraser').classList.remove('clicked');
|
176 |
+
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
|
177 |
+
document.getElementById('my-div-eraser').style.backgroundColor = "white";
|
178 |
+
}
|
179 |
+
else {
|
180 |
+
document.getElementById('my-toggle-eraser').classList.add('clicked');
|
181 |
+
document.getElementById('my-div-pencil').style.backgroundColor = "white";
|
182 |
+
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
|
183 |
+
}
|
184 |
+
}
|
185 |
+
|
186 |
+
globalThis.toggleEraser = () => {
|
187 |
+
element = document.getElementById('my-toggle-eraser');
|
188 |
+
element.classList.toggle('clicked');
|
189 |
+
// simulate a click on the gradio button
|
190 |
+
btn_gradio = document.querySelector("#cb-eraser > label > input");
|
191 |
+
var event = new MouseEvent('click', {
|
192 |
+
'view': window,
|
193 |
+
'bubbles': true,
|
194 |
+
'cancelable': true
|
195 |
+
});
|
196 |
+
btn_gradio.dispatchEvent(event);
|
197 |
+
if (element.classList.contains('clicked')) {
|
198 |
+
document.getElementById('my-toggle-pencil').classList.remove('clicked');
|
199 |
+
document.getElementById('my-div-pencil').style.backgroundColor = "white";
|
200 |
+
document.getElementById('my-div-eraser').style.backgroundColor = "gray";
|
201 |
+
}
|
202 |
+
else {
|
203 |
+
document.getElementById('my-toggle-pencil').classList.add('clicked');
|
204 |
+
document.getElementById('my-div-pencil').style.backgroundColor = "gray";
|
205 |
+
document.getElementById('my-div-eraser').style.backgroundColor = "white";
|
206 |
+
}
|
207 |
+
}
|
208 |
+
}
|
209 |
+
"""
|
210 |
+
|
211 |
+
with gr.Blocks(css="style.css") as demo:
|
212 |
+
|
213 |
+
gr.HTML(
|
214 |
+
"""
|
215 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
216 |
+
<div>
|
217 |
+
<h2><a href="https://github.com/GaParmar/img2img-turbo">One-Step Image Translation with Text-to-Image Models</a></h2>
|
218 |
+
<div>
|
219 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
220 |
+
<a href='https://gauravparmar.com/'>Gaurav Parmar, </a>
|
221 |
+
|
222 |
+
<a href='https://taesung.me/'> Taesung Park,</a>
|
223 |
+
|
224 |
+
<a href='https://www.cs.cmu.edu/~srinivas/'>Srinivasa Narasimhan, </a>
|
225 |
+
|
226 |
+
<a href='https://www.cs.cmu.edu/~junyanz/'> Jun-Yan Zhu </a>
|
227 |
+
</div>
|
228 |
+
</div>
|
229 |
+
</br>
|
230 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
231 |
+
<a href='https://arxiv.org/abs/2403.12036'>
|
232 |
+
<img src="https://img.shields.io/badge/arXiv-2403.12036-red">
|
233 |
+
</a>
|
234 |
+
|
235 |
+
<a href='https://github.com/GaParmar/img2img-turbo'>
|
236 |
+
<img src='https://img.shields.io/badge/github-%23121011.svg'>
|
237 |
+
</a>
|
238 |
+
|
239 |
+
<a href='https://github.com/GaParmar/img2img-turbo/blob/main/LICENSE'>
|
240 |
+
<img src='https://img.shields.io/badge/license-MIT-lightgrey'>
|
241 |
+
</a>
|
242 |
+
</div>
|
243 |
+
</div>
|
244 |
+
</div>
|
245 |
+
<div>
|
246 |
+
</br>
|
247 |
+
</div>
|
248 |
+
"""
|
249 |
+
)
|
250 |
+
|
251 |
+
# these are hidden buttons that are used to trigger the canvas changes
|
252 |
+
line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
|
253 |
+
eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
|
254 |
+
with gr.Row(elem_id="main_row"):
|
255 |
+
with gr.Column(elem_id="column_input"):
|
256 |
+
gr.Markdown("## INPUT", elem_id="input_header")
|
257 |
+
image = gr.Image(
|
258 |
+
source="canvas",
|
259 |
+
tool="color-sketch",
|
260 |
+
type="pil",
|
261 |
+
image_mode="L",
|
262 |
+
invert_colors=True,
|
263 |
+
shape=(512, 512),
|
264 |
+
brush_radius=4,
|
265 |
+
height=440,
|
266 |
+
width=440,
|
267 |
+
brush_color="#000000",
|
268 |
+
interactive=True,
|
269 |
+
show_download_button=True,
|
270 |
+
elem_id="input_image",
|
271 |
+
show_label=False,
|
272 |
+
)
|
273 |
+
download_sketch = gr.Button(
|
274 |
+
"Download sketch", scale=1, elem_id="download_sketch"
|
275 |
+
)
|
276 |
+
|
277 |
+
gr.HTML(
|
278 |
+
"""
|
279 |
+
<div class="button-row">
|
280 |
+
<div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
|
281 |
+
<div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
|
282 |
+
<div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
|
283 |
+
<div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
|
284 |
+
<div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
|
285 |
+
</div>
|
286 |
+
"""
|
287 |
+
)
|
288 |
+
# gr.Markdown("## Prompt", elem_id="tools_header")
|
289 |
+
prompt = gr.Textbox(label="Prompt", value="", show_label=True)
|
290 |
+
with gr.Row():
|
291 |
+
style = gr.Dropdown(
|
292 |
+
label="Style",
|
293 |
+
choices=STYLE_NAMES,
|
294 |
+
value=DEFAULT_STYLE_NAME,
|
295 |
+
scale=1,
|
296 |
+
)
|
297 |
+
prompt_temp = gr.Textbox(
|
298 |
+
label="Prompt Style Template",
|
299 |
+
value=styles[DEFAULT_STYLE_NAME],
|
300 |
+
scale=2,
|
301 |
+
max_lines=1,
|
302 |
+
)
|
303 |
+
|
304 |
+
with gr.Row():
|
305 |
+
val_r = gr.Slider(
|
306 |
+
label="Sketch guidance: ",
|
307 |
+
show_label=True,
|
308 |
+
minimum=0,
|
309 |
+
maximum=1,
|
310 |
+
value=0.4,
|
311 |
+
step=0.01,
|
312 |
+
scale=3,
|
313 |
+
)
|
314 |
+
seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
|
315 |
+
randomize_seed = gr.Button("Random", scale=1, min_width=50)
|
316 |
+
|
317 |
+
with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
|
318 |
+
gr.Markdown("## pix2pix-turbo", elem_id="description")
|
319 |
+
run_button = gr.Button("Run", min_width=50)
|
320 |
+
|
321 |
+
with gr.Column(elem_id="column_output"):
|
322 |
+
gr.Markdown("## OUTPUT", elem_id="output_header")
|
323 |
+
result = gr.Image(
|
324 |
+
label="Result",
|
325 |
+
height=440,
|
326 |
+
width=440,
|
327 |
+
elem_id="output_image",
|
328 |
+
show_label=False,
|
329 |
+
show_download_button=True,
|
330 |
+
)
|
331 |
+
download_output = gr.Button("Download output", elem_id="download_output")
|
332 |
+
gr.Markdown("### Instructions")
|
333 |
+
gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
|
334 |
+
gr.Markdown("**2**. Start sketching")
|
335 |
+
gr.Markdown("**3**. Change the image style using a style template")
|
336 |
+
gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
|
337 |
+
gr.Markdown("**5**. Try different seeds to generate different results")
|
338 |
+
|
339 |
+
eraser.change(
|
340 |
+
fn=lambda x: gr.update(value=not x),
|
341 |
+
inputs=[eraser],
|
342 |
+
outputs=[line],
|
343 |
+
queue=False,
|
344 |
+
api_name=False,
|
345 |
+
).then(update_canvas, [line, eraser], [image])
|
346 |
+
line.change(
|
347 |
+
fn=lambda x: gr.update(value=not x),
|
348 |
+
inputs=[line],
|
349 |
+
outputs=[eraser],
|
350 |
+
queue=False,
|
351 |
+
api_name=False,
|
352 |
+
).then(update_canvas, [line, eraser], [image])
|
353 |
+
|
354 |
+
demo.load(None, None, None, _js=scripts)
|
355 |
+
randomize_seed.click(
|
356 |
+
lambda x: random.randint(0, MAX_SEED),
|
357 |
+
inputs=[],
|
358 |
+
outputs=seed,
|
359 |
+
queue=False,
|
360 |
+
api_name=False,
|
361 |
+
)
|
362 |
+
inputs = [image, prompt, prompt_temp, style, seed, val_r]
|
363 |
+
outputs = [result, download_sketch, download_output]
|
364 |
+
prompt.submit(fn=run, inputs=inputs, outputs=outputs, api_name=False)
|
365 |
+
style.change(
|
366 |
+
lambda x: styles[x],
|
367 |
+
inputs=[style],
|
368 |
+
outputs=[prompt_temp],
|
369 |
+
queue=False,
|
370 |
+
api_name=False,
|
371 |
+
).then(
|
372 |
+
fn=run,
|
373 |
+
inputs=inputs,
|
374 |
+
outputs=outputs,
|
375 |
+
api_name=False,
|
376 |
+
)
|
377 |
+
val_r.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
|
378 |
+
run_button.click(fn=run, inputs=inputs, outputs=outputs, api_name=False)
|
379 |
+
image.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
|
380 |
+
|
381 |
+
if __name__ == "__main__":
|
382 |
+
demo.queue().launch(debug=True, share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip @ git+https://github.com/openai/CLIP.git
|
2 |
+
einops>=0.6.1
|
3 |
+
numpy>=1.24.4
|
4 |
+
open-clip-torch>=2.20.0
|
5 |
+
opencv-python==4.6.0.66
|
6 |
+
pillow>=9.5.0
|
7 |
+
scipy==1.11.1
|
8 |
+
timm>=0.9.2
|
9 |
+
tokenizers
|
10 |
+
torch>=2.0.1
|
11 |
+
|
12 |
+
torchaudio>=2.0.2
|
13 |
+
torchdata==0.6.1
|
14 |
+
torchmetrics>=1.0.1
|
15 |
+
torchvision>=0.15.2
|
16 |
+
|
17 |
+
tqdm>=4.65.0
|
18 |
+
transformers==4.35.2
|
19 |
+
urllib3<1.27,>=1.25.4
|
20 |
+
xformers>=0.0.20
|
21 |
+
streamlit-keyup==0.2.0
|
22 |
+
lpips
|
23 |
+
clean-fid
|
24 |
+
peft
|
25 |
+
dominate
|
26 |
+
diffusers==0.25.1
|
27 |
+
gradio==3.43.1
|
28 |
+
|
29 |
+
vision_aided_loss
|
scripts/download_fill50k.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p data
|
2 |
+
wget https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip -O data/my_fill50k.zip
|
3 |
+
cd data
|
4 |
+
unzip my_fill50k.zip
|
5 |
+
rm my_fill50k.zip
|
scripts/download_horse2zebra.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p data
|
2 |
+
wget https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip -O data/my_horse2zebra.zip
|
3 |
+
cd data
|
4 |
+
unzip my_horse2zebra.zip
|
5 |
+
rm my_horse2zebra.zip
|
src/cyclegan_turbo.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import copy
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
7 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel
|
8 |
+
from peft import LoraConfig
|
9 |
+
from peft.utils import get_peft_model_state_dict
|
10 |
+
p = "src/"
|
11 |
+
sys.path.append(p)
|
12 |
+
from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
|
13 |
+
|
14 |
+
|
15 |
+
class VAE_encode(nn.Module):
|
16 |
+
def __init__(self, vae, vae_b2a=None):
|
17 |
+
super(VAE_encode, self).__init__()
|
18 |
+
self.vae = vae
|
19 |
+
self.vae_b2a = vae_b2a
|
20 |
+
|
21 |
+
def forward(self, x, direction):
|
22 |
+
assert direction in ["a2b", "b2a"]
|
23 |
+
if direction == "a2b":
|
24 |
+
_vae = self.vae
|
25 |
+
else:
|
26 |
+
_vae = self.vae_b2a
|
27 |
+
return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor
|
28 |
+
|
29 |
+
|
30 |
+
class VAE_decode(nn.Module):
|
31 |
+
def __init__(self, vae, vae_b2a=None):
|
32 |
+
super(VAE_decode, self).__init__()
|
33 |
+
self.vae = vae
|
34 |
+
self.vae_b2a = vae_b2a
|
35 |
+
|
36 |
+
def forward(self, x, direction):
|
37 |
+
assert direction in ["a2b", "b2a"]
|
38 |
+
if direction == "a2b":
|
39 |
+
_vae = self.vae
|
40 |
+
else:
|
41 |
+
_vae = self.vae_b2a
|
42 |
+
assert _vae.encoder.current_down_blocks is not None
|
43 |
+
_vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
|
44 |
+
x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
|
45 |
+
return x_decoded
|
46 |
+
|
47 |
+
|
48 |
+
def initialize_unet(rank, return_lora_module_names=False):
|
49 |
+
unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
|
50 |
+
unet.requires_grad_(False)
|
51 |
+
unet.train()
|
52 |
+
l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
|
53 |
+
l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
|
54 |
+
for n, p in unet.named_parameters():
|
55 |
+
if "bias" in n or "norm" in n: continue
|
56 |
+
for pattern in l_grep:
|
57 |
+
if pattern in n and ("down_blocks" in n or "conv_in" in n):
|
58 |
+
l_target_modules_encoder.append(n.replace(".weight",""))
|
59 |
+
break
|
60 |
+
elif pattern in n and "up_blocks" in n:
|
61 |
+
l_target_modules_decoder.append(n.replace(".weight",""))
|
62 |
+
break
|
63 |
+
elif pattern in n:
|
64 |
+
l_modules_others.append(n.replace(".weight",""))
|
65 |
+
break
|
66 |
+
lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder, lora_alpha=rank)
|
67 |
+
lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder, lora_alpha=rank)
|
68 |
+
lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others, lora_alpha=rank)
|
69 |
+
unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
|
70 |
+
unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
|
71 |
+
unet.add_adapter(lora_conf_others, adapter_name="default_others")
|
72 |
+
unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
|
73 |
+
if return_lora_module_names:
|
74 |
+
return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
|
75 |
+
else:
|
76 |
+
return unet
|
77 |
+
|
78 |
+
|
79 |
+
def initialize_vae(rank=4, return_lora_module_names=False):
|
80 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
|
81 |
+
vae.requires_grad_(False)
|
82 |
+
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
|
83 |
+
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
|
84 |
+
vae.requires_grad_(True)
|
85 |
+
vae.train()
|
86 |
+
# add the skip connection convs
|
87 |
+
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
|
88 |
+
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
|
89 |
+
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
|
90 |
+
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
|
91 |
+
torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
|
92 |
+
torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
|
93 |
+
torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
|
94 |
+
torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
|
95 |
+
vae.decoder.ignore_skip = False
|
96 |
+
vae.decoder.gamma = 1
|
97 |
+
l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
|
98 |
+
"conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
|
99 |
+
"skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
|
100 |
+
]
|
101 |
+
vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
|
102 |
+
vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
|
103 |
+
if return_lora_module_names:
|
104 |
+
return vae, l_vae_target_modules
|
105 |
+
else:
|
106 |
+
return vae
|
107 |
+
|
108 |
+
|
109 |
+
class CycleGAN_Turbo(torch.nn.Module):
|
110 |
+
def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
|
111 |
+
super().__init__()
|
112 |
+
self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
|
113 |
+
self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
|
114 |
+
self.sched = make_1step_sched()
|
115 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
|
116 |
+
unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
|
117 |
+
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
|
118 |
+
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
|
119 |
+
# add the skip connection convs
|
120 |
+
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
121 |
+
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
122 |
+
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
123 |
+
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
124 |
+
vae.decoder.ignore_skip = False
|
125 |
+
self.unet, self.vae = unet, vae
|
126 |
+
if pretrained_name == "day_to_night":
|
127 |
+
url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
|
128 |
+
self.load_ckpt_from_url(url, ckpt_folder)
|
129 |
+
self.timesteps = torch.tensor([999], device="cuda").long()
|
130 |
+
self.caption = "driving in the night"
|
131 |
+
self.direction = "a2b"
|
132 |
+
elif pretrained_name == "night_to_day":
|
133 |
+
url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
|
134 |
+
self.load_ckpt_from_url(url, ckpt_folder)
|
135 |
+
self.timesteps = torch.tensor([999], device="cuda").long()
|
136 |
+
self.caption = "driving in the day"
|
137 |
+
self.direction = "b2a"
|
138 |
+
elif pretrained_name == "clear_to_rainy":
|
139 |
+
url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
|
140 |
+
self.load_ckpt_from_url(url, ckpt_folder)
|
141 |
+
self.timesteps = torch.tensor([999], device="cuda").long()
|
142 |
+
self.caption = "driving in heavy rain"
|
143 |
+
self.direction = "a2b"
|
144 |
+
elif pretrained_name == "rainy_to_clear":
|
145 |
+
url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
|
146 |
+
self.load_ckpt_from_url(url, ckpt_folder)
|
147 |
+
self.timesteps = torch.tensor([999], device="cuda").long()
|
148 |
+
self.caption = "driving in the day"
|
149 |
+
self.direction = "b2a"
|
150 |
+
|
151 |
+
elif pretrained_path is not None:
|
152 |
+
sd = torch.load(pretrained_path)
|
153 |
+
self.load_ckpt_from_state_dict(sd)
|
154 |
+
self.timesteps = torch.tensor([999], device="cuda").long()
|
155 |
+
self.caption = None
|
156 |
+
self.direction = None
|
157 |
+
|
158 |
+
self.vae_enc.cuda()
|
159 |
+
self.vae_dec.cuda()
|
160 |
+
self.unet.cuda()
|
161 |
+
|
162 |
+
def load_ckpt_from_state_dict(self, sd):
|
163 |
+
lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
|
164 |
+
lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
|
165 |
+
lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
|
166 |
+
self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
|
167 |
+
self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
|
168 |
+
self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
|
169 |
+
for n, p in self.unet.named_parameters():
|
170 |
+
name_sd = n.replace(".default_encoder.weight", ".weight")
|
171 |
+
if "lora" in n and "default_encoder" in n:
|
172 |
+
p.data.copy_(sd["sd_encoder"][name_sd])
|
173 |
+
for n, p in self.unet.named_parameters():
|
174 |
+
name_sd = n.replace(".default_decoder.weight", ".weight")
|
175 |
+
if "lora" in n and "default_decoder" in n:
|
176 |
+
p.data.copy_(sd["sd_decoder"][name_sd])
|
177 |
+
for n, p in self.unet.named_parameters():
|
178 |
+
name_sd = n.replace(".default_others.weight", ".weight")
|
179 |
+
if "lora" in n and "default_others" in n:
|
180 |
+
p.data.copy_(sd["sd_other"][name_sd])
|
181 |
+
self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
|
182 |
+
|
183 |
+
vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
|
184 |
+
self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
|
185 |
+
self.vae.decoder.gamma = 1
|
186 |
+
self.vae_b2a = copy.deepcopy(self.vae)
|
187 |
+
self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
|
188 |
+
self.vae_enc.load_state_dict(sd["sd_vae_enc"])
|
189 |
+
self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
|
190 |
+
self.vae_dec.load_state_dict(sd["sd_vae_dec"])
|
191 |
+
|
192 |
+
def load_ckpt_from_url(self, url, ckpt_folder):
|
193 |
+
os.makedirs(ckpt_folder, exist_ok=True)
|
194 |
+
outf = os.path.join(ckpt_folder, os.path.basename(url))
|
195 |
+
download_url(url, outf)
|
196 |
+
sd = torch.load(outf)
|
197 |
+
self.load_ckpt_from_state_dict(sd)
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
|
201 |
+
B = x.shape[0]
|
202 |
+
assert direction in ["a2b", "b2a"]
|
203 |
+
x_enc = vae_enc(x, direction=direction).to(x.dtype)
|
204 |
+
model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
|
205 |
+
x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
|
206 |
+
x_out_decoded = vae_dec(x_out, direction=direction)
|
207 |
+
return x_out_decoded
|
208 |
+
|
209 |
+
@staticmethod
|
210 |
+
def get_traininable_params(unet, vae_a2b, vae_b2a):
|
211 |
+
# add all unet parameters
|
212 |
+
params_gen = list(unet.conv_in.parameters())
|
213 |
+
unet.conv_in.requires_grad_(True)
|
214 |
+
unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
|
215 |
+
for n,p in unet.named_parameters():
|
216 |
+
if "lora" in n and "default" in n:
|
217 |
+
assert p.requires_grad
|
218 |
+
params_gen.append(p)
|
219 |
+
|
220 |
+
# add all vae_a2b parameters
|
221 |
+
for n,p in vae_a2b.named_parameters():
|
222 |
+
if "lora" in n and "vae_skip" in n:
|
223 |
+
assert p.requires_grad
|
224 |
+
params_gen.append(p)
|
225 |
+
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
|
226 |
+
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
|
227 |
+
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
|
228 |
+
params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())
|
229 |
+
|
230 |
+
# add all vae_b2a parameters
|
231 |
+
for n,p in vae_b2a.named_parameters():
|
232 |
+
if "lora" in n and "vae_skip" in n:
|
233 |
+
assert p.requires_grad
|
234 |
+
params_gen.append(p)
|
235 |
+
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
|
236 |
+
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
|
237 |
+
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
|
238 |
+
params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
|
239 |
+
return params_gen
|
240 |
+
|
241 |
+
def forward(self, x_t, direction=None, caption=None, caption_emb=None):
|
242 |
+
if direction is None:
|
243 |
+
assert self.direction is not None
|
244 |
+
direction = self.direction
|
245 |
+
if caption is None and caption_emb is None:
|
246 |
+
assert self.caption is not None
|
247 |
+
caption = self.caption
|
248 |
+
if caption_emb is not None:
|
249 |
+
caption_enc = caption_emb
|
250 |
+
else:
|
251 |
+
caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
|
252 |
+
padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
|
253 |
+
caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
|
254 |
+
return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
|
src/image_prep.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
def canny_from_pil(image, low_threshold=100, high_threshold=200):
|
7 |
+
image = np.array(image)
|
8 |
+
image = cv2.Canny(image, low_threshold, high_threshold)
|
9 |
+
image = image[:, :, None]
|
10 |
+
image = np.concatenate([image, image, image], axis=2)
|
11 |
+
control_image = Image.fromarray(image)
|
12 |
+
return control_image
|
src/inference_paired.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
from torchvision import transforms
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
from pix2pix_turbo import Pix2Pix_Turbo
|
9 |
+
from image_prep import canny_from_pil
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
|
14 |
+
parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
|
15 |
+
parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
|
16 |
+
parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
|
17 |
+
parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
|
18 |
+
parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
|
19 |
+
parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
|
20 |
+
parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
|
21 |
+
parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
|
22 |
+
parser.add_argument('--use_fp16', action='store_true', help='Use Float16 precision for faster inference')
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
# only one of model_name and model_path should be provided
|
26 |
+
if args.model_name == '' != args.model_path == '':
|
27 |
+
raise ValueError('Either model_name or model_path should be provided')
|
28 |
+
|
29 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
30 |
+
|
31 |
+
# initialize the model
|
32 |
+
model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
|
33 |
+
model.set_eval()
|
34 |
+
if args.use_fp16:
|
35 |
+
model.half()
|
36 |
+
|
37 |
+
# make sure that the input image is a multiple of 8
|
38 |
+
input_image = Image.open(args.input_image).convert('RGB')
|
39 |
+
new_width = input_image.width - input_image.width % 8
|
40 |
+
new_height = input_image.height - input_image.height % 8
|
41 |
+
input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
|
42 |
+
bname = os.path.basename(args.input_image)
|
43 |
+
|
44 |
+
# translate the image
|
45 |
+
with torch.no_grad():
|
46 |
+
if args.model_name == 'edge_to_image':
|
47 |
+
canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
|
48 |
+
canny_viz_inv = Image.fromarray(255 - np.array(canny))
|
49 |
+
canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
|
50 |
+
c_t = F.to_tensor(canny).unsqueeze(0).cuda()
|
51 |
+
if args.use_fp16:
|
52 |
+
c_t = c_t.half()
|
53 |
+
output_image = model(c_t, args.prompt)
|
54 |
+
|
55 |
+
elif args.model_name == 'sketch_to_image_stochastic':
|
56 |
+
image_t = F.to_tensor(input_image) < 0.5
|
57 |
+
c_t = image_t.unsqueeze(0).cuda().float()
|
58 |
+
torch.manual_seed(args.seed)
|
59 |
+
B, C, H, W = c_t.shape
|
60 |
+
noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
|
61 |
+
if args.use_fp16:
|
62 |
+
c_t = c_t.half()
|
63 |
+
noise = noise.half()
|
64 |
+
output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
|
65 |
+
|
66 |
+
else:
|
67 |
+
c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
|
68 |
+
if args.use_fp16:
|
69 |
+
c_t = c_t.half()
|
70 |
+
output_image = model(c_t, args.prompt)
|
71 |
+
|
72 |
+
output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
|
73 |
+
|
74 |
+
# save the output image
|
75 |
+
output_pil.save(os.path.join(args.output_dir, bname))
|
src/inference_unpaired.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
from cyclegan_turbo import CycleGAN_Turbo
|
7 |
+
from my_utils.training_utils import build_transform
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
parser = argparse.ArgumentParser()
|
12 |
+
parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
|
13 |
+
parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
|
14 |
+
parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
|
15 |
+
parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
|
16 |
+
parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
|
17 |
+
parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
|
18 |
+
parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
|
19 |
+
parser.add_argument('--use_fp16', action='store_true', help='Use Float16 precision for faster inference')
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
# only one of model_name and model_path should be provided
|
23 |
+
if args.model_name is None != args.model_path is None:
|
24 |
+
raise ValueError('Either model_name or model_path should be provided')
|
25 |
+
|
26 |
+
if args.model_path is not None and args.prompt is None:
|
27 |
+
raise ValueError('prompt is required when loading a custom model_path.')
|
28 |
+
|
29 |
+
if args.model_name is not None:
|
30 |
+
assert args.prompt is None, 'prompt is not required when loading a pretrained model.'
|
31 |
+
assert args.direction is None, 'direction is not required when loading a pretrained model.'
|
32 |
+
|
33 |
+
# initialize the model
|
34 |
+
model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
|
35 |
+
model.eval()
|
36 |
+
model.unet.enable_xformers_memory_efficient_attention()
|
37 |
+
if args.use_fp16:
|
38 |
+
model.half()
|
39 |
+
|
40 |
+
T_val = build_transform(args.image_prep)
|
41 |
+
|
42 |
+
input_image = Image.open(args.input_image).convert('RGB')
|
43 |
+
# translate the image
|
44 |
+
with torch.no_grad():
|
45 |
+
input_img = T_val(input_image)
|
46 |
+
x_t = transforms.ToTensor()(input_img)
|
47 |
+
x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
|
48 |
+
if args.use_fp16:
|
49 |
+
x_t = x_t.half()
|
50 |
+
output = model(x_t, direction=args.direction, caption=args.prompt)
|
51 |
+
|
52 |
+
output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
|
53 |
+
output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
|
54 |
+
|
55 |
+
# save the output image
|
56 |
+
bname = os.path.basename(args.input_image)
|
57 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
58 |
+
output_pil.save(os.path.join(args.output_dir, bname))
|
src/model.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
from diffusers import DDPMScheduler
|
5 |
+
|
6 |
+
|
7 |
+
def make_1step_sched():
|
8 |
+
noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
|
9 |
+
noise_scheduler_1step.set_timesteps(1, device="cuda")
|
10 |
+
noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
|
11 |
+
return noise_scheduler_1step
|
12 |
+
|
13 |
+
|
14 |
+
def my_vae_encoder_fwd(self, sample):
|
15 |
+
sample = self.conv_in(sample)
|
16 |
+
l_blocks = []
|
17 |
+
# down
|
18 |
+
for down_block in self.down_blocks:
|
19 |
+
l_blocks.append(sample)
|
20 |
+
sample = down_block(sample)
|
21 |
+
# middle
|
22 |
+
sample = self.mid_block(sample)
|
23 |
+
sample = self.conv_norm_out(sample)
|
24 |
+
sample = self.conv_act(sample)
|
25 |
+
sample = self.conv_out(sample)
|
26 |
+
self.current_down_blocks = l_blocks
|
27 |
+
return sample
|
28 |
+
|
29 |
+
|
30 |
+
def my_vae_decoder_fwd(self, sample, latent_embeds=None):
|
31 |
+
sample = self.conv_in(sample)
|
32 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
33 |
+
# middle
|
34 |
+
sample = self.mid_block(sample, latent_embeds)
|
35 |
+
sample = sample.to(upscale_dtype)
|
36 |
+
if not self.ignore_skip:
|
37 |
+
skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
|
38 |
+
# up
|
39 |
+
for idx, up_block in enumerate(self.up_blocks):
|
40 |
+
skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
|
41 |
+
# add skip
|
42 |
+
sample = sample + skip_in
|
43 |
+
sample = up_block(sample, latent_embeds)
|
44 |
+
else:
|
45 |
+
for idx, up_block in enumerate(self.up_blocks):
|
46 |
+
sample = up_block(sample, latent_embeds)
|
47 |
+
# post-process
|
48 |
+
if latent_embeds is None:
|
49 |
+
sample = self.conv_norm_out(sample)
|
50 |
+
else:
|
51 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
52 |
+
sample = self.conv_act(sample)
|
53 |
+
sample = self.conv_out(sample)
|
54 |
+
return sample
|
55 |
+
|
56 |
+
|
57 |
+
def download_url(url, outf):
|
58 |
+
if not os.path.exists(outf):
|
59 |
+
print(f"Downloading checkpoint to {outf}")
|
60 |
+
response = requests.get(url, stream=True)
|
61 |
+
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
62 |
+
block_size = 1024 # 1 Kibibyte
|
63 |
+
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
64 |
+
with open(outf, 'wb') as file:
|
65 |
+
for data in response.iter_content(block_size):
|
66 |
+
progress_bar.update(len(data))
|
67 |
+
file.write(data)
|
68 |
+
progress_bar.close()
|
69 |
+
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
70 |
+
print("ERROR, something went wrong")
|
71 |
+
print(f"Downloaded successfully to {outf}")
|
72 |
+
else:
|
73 |
+
print(f"Skipping download, {outf} already exists")
|
src/my_utils/dino_struct.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def attn_cosine_sim(x, eps=1e-08):
|
7 |
+
x = x[0] # TEMP: getting rid of redundant dimension, TBF
|
8 |
+
norm1 = x.norm(dim=2, keepdim=True)
|
9 |
+
factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
|
10 |
+
sim_matrix = (x @ x.permute(0, 2, 1)) / factor
|
11 |
+
return sim_matrix
|
12 |
+
|
13 |
+
|
14 |
+
class VitExtractor:
|
15 |
+
BLOCK_KEY = 'block'
|
16 |
+
ATTN_KEY = 'attn'
|
17 |
+
PATCH_IMD_KEY = 'patch_imd'
|
18 |
+
QKV_KEY = 'qkv'
|
19 |
+
KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
|
20 |
+
|
21 |
+
def __init__(self, model_name, device):
|
22 |
+
# pdb.set_trace()
|
23 |
+
self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
|
24 |
+
self.model.eval()
|
25 |
+
self.model_name = model_name
|
26 |
+
self.hook_handlers = []
|
27 |
+
self.layers_dict = {}
|
28 |
+
self.outputs_dict = {}
|
29 |
+
for key in VitExtractor.KEY_LIST:
|
30 |
+
self.layers_dict[key] = []
|
31 |
+
self.outputs_dict[key] = []
|
32 |
+
self._init_hooks_data()
|
33 |
+
|
34 |
+
def _init_hooks_data(self):
|
35 |
+
self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
36 |
+
self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
37 |
+
self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
38 |
+
self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
39 |
+
for key in VitExtractor.KEY_LIST:
|
40 |
+
# self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
|
41 |
+
self.outputs_dict[key] = []
|
42 |
+
|
43 |
+
def _register_hooks(self, **kwargs):
|
44 |
+
for block_idx, block in enumerate(self.model.blocks):
|
45 |
+
if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
|
46 |
+
self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
|
47 |
+
if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
|
48 |
+
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
|
49 |
+
if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
|
50 |
+
self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
|
51 |
+
if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
|
52 |
+
self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
|
53 |
+
|
54 |
+
def _clear_hooks(self):
|
55 |
+
for handler in self.hook_handlers:
|
56 |
+
handler.remove()
|
57 |
+
self.hook_handlers = []
|
58 |
+
|
59 |
+
def _get_block_hook(self):
|
60 |
+
def _get_block_output(model, input, output):
|
61 |
+
self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
|
62 |
+
|
63 |
+
return _get_block_output
|
64 |
+
|
65 |
+
def _get_attn_hook(self):
|
66 |
+
def _get_attn_output(model, inp, output):
|
67 |
+
self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
|
68 |
+
|
69 |
+
return _get_attn_output
|
70 |
+
|
71 |
+
def _get_qkv_hook(self):
|
72 |
+
def _get_qkv_output(model, inp, output):
|
73 |
+
self.outputs_dict[VitExtractor.QKV_KEY].append(output)
|
74 |
+
|
75 |
+
return _get_qkv_output
|
76 |
+
|
77 |
+
# TODO: CHECK ATTN OUTPUT TUPLE
|
78 |
+
def _get_patch_imd_hook(self):
|
79 |
+
def _get_attn_output(model, inp, output):
|
80 |
+
self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
|
81 |
+
|
82 |
+
return _get_attn_output
|
83 |
+
|
84 |
+
def get_feature_from_input(self, input_img): # List([B, N, D])
|
85 |
+
self._register_hooks()
|
86 |
+
self.model(input_img)
|
87 |
+
feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
|
88 |
+
self._clear_hooks()
|
89 |
+
self._init_hooks_data()
|
90 |
+
return feature
|
91 |
+
|
92 |
+
def get_qkv_feature_from_input(self, input_img):
|
93 |
+
self._register_hooks()
|
94 |
+
self.model(input_img)
|
95 |
+
feature = self.outputs_dict[VitExtractor.QKV_KEY]
|
96 |
+
self._clear_hooks()
|
97 |
+
self._init_hooks_data()
|
98 |
+
return feature
|
99 |
+
|
100 |
+
def get_attn_feature_from_input(self, input_img):
|
101 |
+
self._register_hooks()
|
102 |
+
self.model(input_img)
|
103 |
+
feature = self.outputs_dict[VitExtractor.ATTN_KEY]
|
104 |
+
self._clear_hooks()
|
105 |
+
self._init_hooks_data()
|
106 |
+
return feature
|
107 |
+
|
108 |
+
def get_patch_size(self):
|
109 |
+
return 8 if "8" in self.model_name else 16
|
110 |
+
|
111 |
+
def get_width_patch_num(self, input_img_shape):
|
112 |
+
b, c, h, w = input_img_shape
|
113 |
+
patch_size = self.get_patch_size()
|
114 |
+
return w // patch_size
|
115 |
+
|
116 |
+
def get_height_patch_num(self, input_img_shape):
|
117 |
+
b, c, h, w = input_img_shape
|
118 |
+
patch_size = self.get_patch_size()
|
119 |
+
return h // patch_size
|
120 |
+
|
121 |
+
def get_patch_num(self, input_img_shape):
|
122 |
+
patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
|
123 |
+
return patch_num
|
124 |
+
|
125 |
+
def get_head_num(self):
|
126 |
+
if "dino" in self.model_name:
|
127 |
+
return 6 if "s" in self.model_name else 12
|
128 |
+
return 6 if "small" in self.model_name else 12
|
129 |
+
|
130 |
+
def get_embedding_dim(self):
|
131 |
+
if "dino" in self.model_name:
|
132 |
+
return 384 if "s" in self.model_name else 768
|
133 |
+
return 384 if "small" in self.model_name else 768
|
134 |
+
|
135 |
+
def get_queries_from_qkv(self, qkv, input_img_shape):
|
136 |
+
patch_num = self.get_patch_num(input_img_shape)
|
137 |
+
head_num = self.get_head_num()
|
138 |
+
embedding_dim = self.get_embedding_dim()
|
139 |
+
q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
|
140 |
+
return q
|
141 |
+
|
142 |
+
def get_keys_from_qkv(self, qkv, input_img_shape):
|
143 |
+
patch_num = self.get_patch_num(input_img_shape)
|
144 |
+
head_num = self.get_head_num()
|
145 |
+
embedding_dim = self.get_embedding_dim()
|
146 |
+
k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
|
147 |
+
return k
|
148 |
+
|
149 |
+
def get_values_from_qkv(self, qkv, input_img_shape):
|
150 |
+
patch_num = self.get_patch_num(input_img_shape)
|
151 |
+
head_num = self.get_head_num()
|
152 |
+
embedding_dim = self.get_embedding_dim()
|
153 |
+
v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
|
154 |
+
return v
|
155 |
+
|
156 |
+
def get_keys_from_input(self, input_img, layer_num):
|
157 |
+
qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
|
158 |
+
keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
|
159 |
+
return keys
|
160 |
+
|
161 |
+
def get_keys_self_sim_from_input(self, input_img, layer_num):
|
162 |
+
keys = self.get_keys_from_input(input_img, layer_num=layer_num)
|
163 |
+
h, t, d = keys.shape
|
164 |
+
concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
|
165 |
+
ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
|
166 |
+
return ssim_map
|
167 |
+
|
168 |
+
|
169 |
+
class DinoStructureLoss:
|
170 |
+
def __init__(self, ):
|
171 |
+
self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
|
172 |
+
self.preprocess = torchvision.transforms.Compose([
|
173 |
+
torchvision.transforms.Resize(224),
|
174 |
+
torchvision.transforms.ToTensor(),
|
175 |
+
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
176 |
+
])
|
177 |
+
|
178 |
+
def calculate_global_ssim_loss(self, outputs, inputs):
|
179 |
+
loss = 0.0
|
180 |
+
for a, b in zip(inputs, outputs): # avoid memory limitations
|
181 |
+
with torch.no_grad():
|
182 |
+
target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
|
183 |
+
keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
|
184 |
+
loss += F.mse_loss(keys_ssim, target_keys_self_sim)
|
185 |
+
return loss
|
src/my_utils/training_utils.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
|
12 |
+
def parse_args_paired_training(input_args=None):
|
13 |
+
"""
|
14 |
+
Parses command-line arguments used for configuring an paired session (pix2pix-Turbo).
|
15 |
+
This function sets up an argument parser to handle various training options.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
argparse.Namespace: The parsed command-line arguments.
|
19 |
+
"""
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
# args for the loss function
|
22 |
+
parser.add_argument("--gan_disc_type", default="vagan_clip")
|
23 |
+
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
|
24 |
+
parser.add_argument("--lambda_gan", default=0.5, type=float)
|
25 |
+
parser.add_argument("--lambda_lpips", default=5, type=float)
|
26 |
+
parser.add_argument("--lambda_l2", default=1.0, type=float)
|
27 |
+
parser.add_argument("--lambda_clipsim", default=5.0, type=float)
|
28 |
+
|
29 |
+
# dataset options
|
30 |
+
parser.add_argument("--dataset_folder", required=True, type=str)
|
31 |
+
parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
|
32 |
+
parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
|
33 |
+
|
34 |
+
# validation eval args
|
35 |
+
parser.add_argument("--eval_freq", default=100, type=int)
|
36 |
+
parser.add_argument("--track_val_fid", default=False, action="store_true")
|
37 |
+
parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
|
38 |
+
|
39 |
+
parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
|
40 |
+
parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
|
41 |
+
|
42 |
+
# details about the model architecture
|
43 |
+
parser.add_argument("--pretrained_model_name_or_path")
|
44 |
+
parser.add_argument("--revision", type=str, default=None,)
|
45 |
+
parser.add_argument("--variant", type=str, default=None,)
|
46 |
+
parser.add_argument("--tokenizer_name", type=str, default=None)
|
47 |
+
parser.add_argument("--lora_rank_unet", default=8, type=int)
|
48 |
+
parser.add_argument("--lora_rank_vae", default=4, type=int)
|
49 |
+
|
50 |
+
# training details
|
51 |
+
parser.add_argument("--output_dir", required=True)
|
52 |
+
parser.add_argument("--cache_dir", default=None,)
|
53 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
54 |
+
parser.add_argument("--resolution", type=int, default=512,)
|
55 |
+
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
|
56 |
+
parser.add_argument("--num_training_epochs", type=int, default=10)
|
57 |
+
parser.add_argument("--max_train_steps", type=int, default=10_000,)
|
58 |
+
parser.add_argument("--checkpointing_steps", type=int, default=500,)
|
59 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
|
60 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",)
|
61 |
+
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
62 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
63 |
+
help=(
|
64 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
65 |
+
' "constant", "constant_with_warmup"]'
|
66 |
+
),
|
67 |
+
)
|
68 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
|
69 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1,
|
70 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
71 |
+
)
|
72 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
73 |
+
|
74 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=0,)
|
75 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
76 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
77 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
78 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
79 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
80 |
+
parser.add_argument("--allow_tf32", action="store_true",
|
81 |
+
help=(
|
82 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
83 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
84 |
+
),
|
85 |
+
)
|
86 |
+
parser.add_argument("--report_to", type=str, default="wandb",
|
87 |
+
help=(
|
88 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
89 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
90 |
+
),
|
91 |
+
)
|
92 |
+
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
|
93 |
+
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
|
94 |
+
parser.add_argument("--set_grads_to_none", action="store_true",)
|
95 |
+
|
96 |
+
if input_args is not None:
|
97 |
+
args = parser.parse_args(input_args)
|
98 |
+
else:
|
99 |
+
args = parser.parse_args()
|
100 |
+
|
101 |
+
return args
|
102 |
+
|
103 |
+
|
104 |
+
def parse_args_unpaired_training():
|
105 |
+
"""
|
106 |
+
Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo).
|
107 |
+
This function sets up an argument parser to handle various training options.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
argparse.Namespace: The parsed command-line arguments.
|
111 |
+
"""
|
112 |
+
|
113 |
+
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
|
114 |
+
|
115 |
+
# fixed random seed
|
116 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
117 |
+
|
118 |
+
# args for the loss function
|
119 |
+
parser.add_argument("--gan_disc_type", default="vagan_clip")
|
120 |
+
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid")
|
121 |
+
parser.add_argument("--lambda_gan", default=0.5, type=float)
|
122 |
+
parser.add_argument("--lambda_idt", default=1, type=float)
|
123 |
+
parser.add_argument("--lambda_cycle", default=1, type=float)
|
124 |
+
parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float)
|
125 |
+
parser.add_argument("--lambda_idt_lpips", default=1.0, type=float)
|
126 |
+
|
127 |
+
# args for dataset and dataloader options
|
128 |
+
parser.add_argument("--dataset_folder", required=True, type=str)
|
129 |
+
parser.add_argument("--train_img_prep", required=True)
|
130 |
+
parser.add_argument("--val_img_prep", required=True)
|
131 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=0)
|
132 |
+
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
|
133 |
+
parser.add_argument("--max_train_epochs", type=int, default=100)
|
134 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
135 |
+
|
136 |
+
# args for the model
|
137 |
+
parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo")
|
138 |
+
parser.add_argument("--revision", default=None, type=str)
|
139 |
+
parser.add_argument("--variant", default=None, type=str)
|
140 |
+
parser.add_argument("--lora_rank_unet", default=128, type=int)
|
141 |
+
parser.add_argument("--lora_rank_vae", default=4, type=int)
|
142 |
+
|
143 |
+
# args for validation and logging
|
144 |
+
parser.add_argument("--viz_freq", type=int, default=20)
|
145 |
+
parser.add_argument("--output_dir", type=str, required=True)
|
146 |
+
parser.add_argument("--report_to", type=str, default="wandb")
|
147 |
+
parser.add_argument("--tracker_project_name", type=str, required=True)
|
148 |
+
parser.add_argument("--validation_steps", type=int, default=500,)
|
149 |
+
parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.")
|
150 |
+
parser.add_argument("--checkpointing_steps", type=int, default=500)
|
151 |
+
|
152 |
+
# args for the optimization options
|
153 |
+
parser.add_argument("--learning_rate", type=float, default=5e-6,)
|
154 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
155 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
156 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
157 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
158 |
+
parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.")
|
159 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant", help=(
|
160 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
161 |
+
' "constant", "constant_with_warmup"]'
|
162 |
+
),
|
163 |
+
)
|
164 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
|
165 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",)
|
166 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
167 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
168 |
+
|
169 |
+
# memory saving options
|
170 |
+
parser.add_argument("--allow_tf32", action="store_true",
|
171 |
+
help=(
|
172 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
173 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
174 |
+
),
|
175 |
+
)
|
176 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
177 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.")
|
178 |
+
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
|
179 |
+
|
180 |
+
args = parser.parse_args()
|
181 |
+
return args
|
182 |
+
|
183 |
+
|
184 |
+
def build_transform(image_prep):
|
185 |
+
"""
|
186 |
+
Constructs a transformation pipeline based on the specified image preparation method.
|
187 |
+
|
188 |
+
Parameters:
|
189 |
+
- image_prep (str): A string describing the desired image preparation
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
- torchvision.transforms.Compose: A composable sequence of transformations to be applied to images.
|
193 |
+
"""
|
194 |
+
if image_prep == "resized_crop_512":
|
195 |
+
T = transforms.Compose([
|
196 |
+
transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
|
197 |
+
transforms.CenterCrop(512),
|
198 |
+
])
|
199 |
+
elif image_prep == "resize_286_randomcrop_256x256_hflip":
|
200 |
+
T = transforms.Compose([
|
201 |
+
transforms.Resize((286, 286), interpolation=Image.LANCZOS),
|
202 |
+
transforms.RandomCrop((256, 256)),
|
203 |
+
transforms.RandomHorizontalFlip(),
|
204 |
+
])
|
205 |
+
elif image_prep in ["resize_256", "resize_256x256"]:
|
206 |
+
T = transforms.Compose([
|
207 |
+
transforms.Resize((256, 256), interpolation=Image.LANCZOS)
|
208 |
+
])
|
209 |
+
elif image_prep in ["resize_512", "resize_512x512"]:
|
210 |
+
T = transforms.Compose([
|
211 |
+
transforms.Resize((512, 512), interpolation=Image.LANCZOS)
|
212 |
+
])
|
213 |
+
elif image_prep == "no_resize":
|
214 |
+
T = transforms.Lambda(lambda x: x)
|
215 |
+
return T
|
216 |
+
|
217 |
+
|
218 |
+
class PairedDataset(torch.utils.data.Dataset):
|
219 |
+
def __init__(self, dataset_folder, split, image_prep, tokenizer):
|
220 |
+
"""
|
221 |
+
Itialize the paired dataset object for loading and transforming paired data samples
|
222 |
+
from specified dataset folders.
|
223 |
+
|
224 |
+
This constructor sets up the paths to input and output folders based on the specified 'split',
|
225 |
+
loads the captions (or prompts) for the input images, and prepares the transformations and
|
226 |
+
tokenizer to be applied on the data.
|
227 |
+
|
228 |
+
Parameters:
|
229 |
+
- dataset_folder (str): The root folder containing the dataset, expected to include
|
230 |
+
sub-folders for different splits (e.g., 'train_A', 'train_B').
|
231 |
+
- split (str): The dataset split to use ('train' or 'test'), used to select the appropriate
|
232 |
+
sub-folders and caption files within the dataset folder.
|
233 |
+
- image_prep (str): The image preprocessing transformation to apply to each image.
|
234 |
+
- tokenizer: The tokenizer used for tokenizing the captions (or prompts).
|
235 |
+
"""
|
236 |
+
super().__init__()
|
237 |
+
if split == "train":
|
238 |
+
self.input_folder = os.path.join(dataset_folder, "train_A")
|
239 |
+
self.output_folder = os.path.join(dataset_folder, "train_B")
|
240 |
+
captions = os.path.join(dataset_folder, "train_prompts.json")
|
241 |
+
elif split == "test":
|
242 |
+
self.input_folder = os.path.join(dataset_folder, "test_A")
|
243 |
+
self.output_folder = os.path.join(dataset_folder, "test_B")
|
244 |
+
captions = os.path.join(dataset_folder, "test_prompts.json")
|
245 |
+
with open(captions, "r") as f:
|
246 |
+
self.captions = json.load(f)
|
247 |
+
self.img_names = list(self.captions.keys())
|
248 |
+
self.T = build_transform(image_prep)
|
249 |
+
self.tokenizer = tokenizer
|
250 |
+
|
251 |
+
def __len__(self):
|
252 |
+
"""
|
253 |
+
Returns:
|
254 |
+
int: The total number of items in the dataset.
|
255 |
+
"""
|
256 |
+
return len(self.captions)
|
257 |
+
|
258 |
+
def __getitem__(self, idx):
|
259 |
+
"""
|
260 |
+
Retrieves a dataset item given its index. Each item consists of an input image,
|
261 |
+
its corresponding output image, the captions associated with the input image,
|
262 |
+
and the tokenized form of this caption.
|
263 |
+
|
264 |
+
This method performs the necessary preprocessing on both the input and output images,
|
265 |
+
including scaling and normalization, as well as tokenizing the caption using a provided tokenizer.
|
266 |
+
|
267 |
+
Parameters:
|
268 |
+
- idx (int): The index of the item to retrieve.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
dict: A dictionary containing the following key-value pairs:
|
272 |
+
- "output_pixel_values": a tensor of the preprocessed output image with pixel values
|
273 |
+
scaled to [-1, 1].
|
274 |
+
- "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values
|
275 |
+
scaled to [0, 1].
|
276 |
+
- "caption": the text caption.
|
277 |
+
- "input_ids": a tensor of the tokenized caption.
|
278 |
+
|
279 |
+
Note:
|
280 |
+
The actual preprocessing steps (scaling and normalization) for images are defined externally
|
281 |
+
and passed to this class through the `image_prep` parameter during initialization. The
|
282 |
+
tokenization process relies on the `tokenizer` also provided at initialization, which
|
283 |
+
should be compatible with the models intended to be used with this dataset.
|
284 |
+
"""
|
285 |
+
img_name = self.img_names[idx]
|
286 |
+
input_img = Image.open(os.path.join(self.input_folder, img_name))
|
287 |
+
output_img = Image.open(os.path.join(self.output_folder, img_name))
|
288 |
+
caption = self.captions[img_name]
|
289 |
+
|
290 |
+
# input images scaled to 0,1
|
291 |
+
img_t = self.T(input_img)
|
292 |
+
img_t = F.to_tensor(img_t)
|
293 |
+
# output images scaled to -1,1
|
294 |
+
output_t = self.T(output_img)
|
295 |
+
output_t = F.to_tensor(output_t)
|
296 |
+
output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
|
297 |
+
|
298 |
+
input_ids = self.tokenizer(
|
299 |
+
caption, max_length=self.tokenizer.model_max_length,
|
300 |
+
padding="max_length", truncation=True, return_tensors="pt"
|
301 |
+
).input_ids
|
302 |
+
|
303 |
+
return {
|
304 |
+
"output_pixel_values": output_t,
|
305 |
+
"conditioning_pixel_values": img_t,
|
306 |
+
"caption": caption,
|
307 |
+
"input_ids": input_ids,
|
308 |
+
}
|
309 |
+
|
310 |
+
|
311 |
+
class UnpairedDataset(torch.utils.data.Dataset):
|
312 |
+
def __init__(self, dataset_folder, split, image_prep, tokenizer):
|
313 |
+
"""
|
314 |
+
A dataset class for loading unpaired data samples from two distinct domains (source and target),
|
315 |
+
typically used in unsupervised learning tasks like image-to-image translation.
|
316 |
+
|
317 |
+
The class supports loading images from specified dataset folders, applying predefined image
|
318 |
+
preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain,
|
319 |
+
tokenized using a provided tokenizer.
|
320 |
+
|
321 |
+
Parameters:
|
322 |
+
- dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B)
|
323 |
+
- split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'.
|
324 |
+
- image_prep (str): he image preprocessing transformation to apply to each image.
|
325 |
+
- tokenizer: The tokenizer used for tokenizing the captions (or prompts).
|
326 |
+
"""
|
327 |
+
super().__init__()
|
328 |
+
if split == "train":
|
329 |
+
self.source_folder = os.path.join(dataset_folder, "train_A")
|
330 |
+
self.target_folder = os.path.join(dataset_folder, "train_B")
|
331 |
+
elif split == "test":
|
332 |
+
self.source_folder = os.path.join(dataset_folder, "test_A")
|
333 |
+
self.target_folder = os.path.join(dataset_folder, "test_B")
|
334 |
+
self.tokenizer = tokenizer
|
335 |
+
with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f:
|
336 |
+
self.fixed_caption_src = f.read().strip()
|
337 |
+
self.input_ids_src = self.tokenizer(
|
338 |
+
self.fixed_caption_src, max_length=self.tokenizer.model_max_length,
|
339 |
+
padding="max_length", truncation=True, return_tensors="pt"
|
340 |
+
).input_ids
|
341 |
+
|
342 |
+
with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f:
|
343 |
+
self.fixed_caption_tgt = f.read().strip()
|
344 |
+
self.input_ids_tgt = self.tokenizer(
|
345 |
+
self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length,
|
346 |
+
padding="max_length", truncation=True, return_tensors="pt"
|
347 |
+
).input_ids
|
348 |
+
# find all images in the source and target folders with all IMG extensions
|
349 |
+
self.l_imgs_src = []
|
350 |
+
for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
|
351 |
+
self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext)))
|
352 |
+
self.l_imgs_tgt = []
|
353 |
+
for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
|
354 |
+
self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext)))
|
355 |
+
self.T = build_transform(image_prep)
|
356 |
+
|
357 |
+
def __len__(self):
|
358 |
+
"""
|
359 |
+
Returns:
|
360 |
+
int: The total number of items in the dataset.
|
361 |
+
"""
|
362 |
+
return len(self.l_imgs_src) + len(self.l_imgs_tgt)
|
363 |
+
|
364 |
+
def __getitem__(self, index):
|
365 |
+
"""
|
366 |
+
Fetches a pair of unaligned images from the source and target domains along with their
|
367 |
+
corresponding tokenized captions.
|
368 |
+
|
369 |
+
For the source domain, if the requested index is within the range of available images,
|
370 |
+
the specific image at that index is chosen. If the index exceeds the number of source
|
371 |
+
images, a random source image is selected. For the target domain,
|
372 |
+
an image is always randomly selected, irrespective of the index, to maintain the
|
373 |
+
unpaired nature of the dataset.
|
374 |
+
|
375 |
+
Both images are preprocessed according to the specified image transformation `T`, and normalized.
|
376 |
+
The fixed captions for both domains
|
377 |
+
are included along with their tokenized forms.
|
378 |
+
|
379 |
+
Parameters:
|
380 |
+
- index (int): The index of the source image to retrieve.
|
381 |
+
|
382 |
+
Returns:
|
383 |
+
dict: A dictionary containing processed data for a single training example, with the following keys:
|
384 |
+
- "pixel_values_src": The processed source image
|
385 |
+
- "pixel_values_tgt": The processed target image
|
386 |
+
- "caption_src": The fixed caption of the source domain.
|
387 |
+
- "caption_tgt": The fixed caption of the target domain.
|
388 |
+
- "input_ids_src": The source domain's fixed caption tokenized.
|
389 |
+
- "input_ids_tgt": The target domain's fixed caption tokenized.
|
390 |
+
"""
|
391 |
+
if index < len(self.l_imgs_src):
|
392 |
+
img_path_src = self.l_imgs_src[index]
|
393 |
+
else:
|
394 |
+
img_path_src = random.choice(self.l_imgs_src)
|
395 |
+
img_path_tgt = random.choice(self.l_imgs_tgt)
|
396 |
+
img_pil_src = Image.open(img_path_src).convert("RGB")
|
397 |
+
img_pil_tgt = Image.open(img_path_tgt).convert("RGB")
|
398 |
+
img_t_src = F.to_tensor(self.T(img_pil_src))
|
399 |
+
img_t_tgt = F.to_tensor(self.T(img_pil_tgt))
|
400 |
+
img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5])
|
401 |
+
img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5])
|
402 |
+
return {
|
403 |
+
"pixel_values_src": img_t_src,
|
404 |
+
"pixel_values_tgt": img_t_tgt,
|
405 |
+
"caption_src": self.fixed_caption_src,
|
406 |
+
"caption_tgt": self.fixed_caption_tgt,
|
407 |
+
"input_ids_src": self.input_ids_src,
|
408 |
+
"input_ids_tgt": self.input_ids_tgt,
|
409 |
+
}
|