Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- LICENSE +201 -0
- ORIGINAL_README.md +191 -0
- assets/.DS_Store +0 -0
- assets/Konan.png +0 -0
- assets/Naruto.png +0 -0
- assets/cottage.png +0 -0
- assets/dog.png +0 -0
- assets/lady.png +0 -0
- assets/man.png +0 -0
- assets/panda.png +0 -0
- assets/sculpture.png +3 -0
- assets/teaser_figure.png +3 -0
- config_files/IR_dataset.yaml +9 -0
- config_files/losses.yaml +19 -0
- config_files/val_dataset.yaml +7 -0
- data/data_config.py +14 -0
- data/dataset.py +202 -0
- docs/.DS_Store +0 -0
- docs/static/.DS_Store +0 -0
- environment.yaml +37 -0
- gradio_demo/app.py +250 -0
- infer.py +381 -0
- infer.sh +6 -0
- losses/loss_config.py +15 -0
- losses/losses.py +465 -0
- module/aggregator.py +983 -0
- module/attention.py +259 -0
- module/diffusers_vae/autoencoder_kl.py +489 -0
- module/diffusers_vae/vae.py +985 -0
- module/ip_adapter/attention_processor.py +1467 -0
- module/ip_adapter/ip_adapter.py +236 -0
- module/ip_adapter/resampler.py +158 -0
- module/ip_adapter/utils.py +248 -0
- module/min_sdxl.py +915 -0
- module/unet/unet_2d_ZeroSFT.py +1397 -0
- module/unet/unet_2d_ZeroSFT_blocks.py +0 -0
- pipelines/sdxl_instantir.py +1740 -0
- pipelines/stage1_sdxl_pipeline.py +1283 -0
- requirements.txt +14 -0
- schedulers/lcm_single_step_scheduler.py +537 -0
- train_previewer_lora.py +1712 -0
- train_previewer_lora.sh +24 -0
- train_stage1_adapter.py +1259 -0
- train_stage1_adapter.sh +17 -0
- train_stage2_aggregator.py +1698 -0
- train_stage2_aggregator.sh +24 -0
- utils/degradation_pipeline.py +353 -0
- utils/matlab_cp2tform.py +350 -0
- utils/parser.py +452 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/sculpture.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
ORIGINAL_README.md
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1>InstantIR: Blind Image Restoration with</br>Instant Generative Reference</h1>
|
3 |
+
|
4 |
+
[**Jen-Yuan Huang**](https://jy-joy.github.io)<sup>1 2</sup>, [**Haofan Wang**](https://haofanwang.github.io/)<sup>2</sup>, [**Qixun Wang**](https://github.com/wangqixun)<sup>2</sup>, [**Xu Bai**](https://huggingface.co/baymin0220)<sup>2</sup>, Hao Ai<sup>2</sup>, Peng Xing<sup>2</sup>, [**Jen-Tse Huang**](https://penguinnnnn.github.io)<sup>3</sup> <br>
|
5 |
+
|
6 |
+
<sup>1</sup>Peking University · <sup>2</sup>InstantX Team · <sup>3</sup>The Chinese University of Hong Kong
|
7 |
+
|
8 |
+
<!-- <sup>*</sup>corresponding authors -->
|
9 |
+
|
10 |
+
<a href='https://arxiv.org/abs/2410.06551'><img src='https://img.shields.io/badge/arXiv-2410.06551-b31b1b.svg'>
|
11 |
+
<a href='https://jy-joy.github.io/InstantIR/'><img src='https://img.shields.io/badge/Project-Website-green'></a>
|
12 |
+
<a href='https://huggingface.co/InstantX/InstantIR'><img src='https://img.shields.io/static/v1?label=Model&message=Huggingface&color=orange'></a>
|
13 |
+
<!-- [![GitHub](https://img.shields.io/github/stars/InstantID/InstantID?style=social)](https://github.com/InstantID/InstantID) -->
|
14 |
+
|
15 |
+
<!-- <a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
16 |
+
[![ModelScope](https://img.shields.io/badge/ModelScope-Studios-blue)](https://modelscope.cn/studios/instantx/InstantID/summary)
|
17 |
+
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/InstantX/InstantID) -->
|
18 |
+
|
19 |
+
</div>
|
20 |
+
|
21 |
+
**InstantIR** is a novel single-image restoration model designed to resurrect your damaged images, delivering extrem-quality yet realistic details. You can further boost **InstantIR** performance with additional text prompts, even achieve customized editing!
|
22 |
+
|
23 |
+
|
24 |
+
<!-- >**Abstract**: <br>
|
25 |
+
> Handling test-time unknown degradation is the major challenge in Blind Image Restoration (BIR), necessitating high model generalization. An effective strategy is to incorporate prior knowledge, either from human input or generative model. In this paper, we introduce Instant-reference Image Restoration (InstantIR), a novel diffusion-based BIR method which dynamically adjusts generation condition during inference. We first extract a compact representation of the input via a pre-trained vision encoder. At each generation step, this representation is used to decode current diffusion latent and instantiate it in the generative prior. The degraded image is then encoded with this reference, providing robust generation condition. We observe the variance of generative references fluctuate with degradation intensity, which we further leverage as an indicator for developing a sampling algorithm adaptive to input quality. Extensive experiments demonstrate InstantIR achieves state-of-the-art performance and offering outstanding visual quality. Through modulating generative references with textual description, InstantIR can restore extreme degradation and additionally feature creative restoration. -->
|
26 |
+
|
27 |
+
<img src='assets/teaser_figure.png'>
|
28 |
+
|
29 |
+
## 📢 News
|
30 |
+
- **11/03/2024** 🔥 We provide a Gradio launching script for InstantIR, you can now deploy it on your local machine!
|
31 |
+
- **11/02/2024** 🔥 InstantIR is now compatitble with 🧨 `diffusers`, you can utilize features from this fascinating package!
|
32 |
+
- **10/15/2024** 🔥 Code and model released!
|
33 |
+
|
34 |
+
## 📝 TODOs:
|
35 |
+
- [ ] Launch online demo
|
36 |
+
- [x] Remove dependency on local `diffusers`
|
37 |
+
- [x] Gradio launching script
|
38 |
+
|
39 |
+
## ✨ Usage
|
40 |
+
<!-- ### Online Demo
|
41 |
+
We provide a Gradio Demo on 🤗, click the button below and have fun with InstantIR! -->
|
42 |
+
|
43 |
+
### Quick start
|
44 |
+
#### 1. Clone this repo and setting up environment
|
45 |
+
```sh
|
46 |
+
git clone https://github.com/JY-Joy/InstantIR.git
|
47 |
+
cd InstantIR
|
48 |
+
conda create -n instantir python=3.9 -y
|
49 |
+
conda activate instantir
|
50 |
+
pip install -r requirements.txt
|
51 |
+
```
|
52 |
+
|
53 |
+
#### 2. Download pre-trained models
|
54 |
+
|
55 |
+
InstantIR is built on SDXL and DINOv2. You can download them either directly from 🤗 huggingface or using Python package.
|
56 |
+
|
57 |
+
| 🤗 link | Python command
|
58 |
+
| :--- | :----------
|
59 |
+
|[SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | `hf_hub_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0")`
|
60 |
+
|[facebook/dinov2-large](https://huggingface.co/facebook/dinov2-large) | `hf_hub_download(repo_id="facebook/dinov2-large")`
|
61 |
+
|[InstantX/InstantIR](https://huggingface.co/InstantX/InstantIR) | `hf_hub_download(repo_id="InstantX/InstantIR")`
|
62 |
+
|
63 |
+
Note: Make sure to import the package first with `from huggingface_hub import hf_hub_download` if you are using Python script.
|
64 |
+
|
65 |
+
#### 3. Inference
|
66 |
+
|
67 |
+
You can run InstantIR inference using `infer.sh` with the following arguments specified.
|
68 |
+
|
69 |
+
```sh
|
70 |
+
infer.sh \
|
71 |
+
--sdxl_path <path_to_SDXL> \
|
72 |
+
--vision_encoder_path <path_to_DINOv2> \
|
73 |
+
--instantir_path <path_to_InstantIR> \
|
74 |
+
--test_path <path_to_input> \
|
75 |
+
--out_path <path_to_output>
|
76 |
+
```
|
77 |
+
|
78 |
+
See `infer.py` for more config options.
|
79 |
+
|
80 |
+
#### 4. Using tips
|
81 |
+
|
82 |
+
InstantIR is powerful, but with your help it can do better. InstantIR's flexible pipeline makes it tunable to a large extent. Here are some tips we found particularly useful for various cases you may encounter:
|
83 |
+
- **Over-smoothing**: reduce `--cfg` to 3.0~5.0. Higher CFG scales can sometimes rigid lines or lack of details.
|
84 |
+
- **Low fidelity**: set `--preview_start` to 0.1~0.4 to preserve fidelity from inputs. The previewer can yield misleading references when input latent is too noisy. In such cases, we suggest to disable the previewer at early timesteps.
|
85 |
+
- **Local distortions**: set `--creative_start` to 0.6~0.8. This will let InstantIR render freely in the late diffusion process, where the high-frequency details are generated. Smaller `--creative_start` spares more spaces for creative restoration, but will diminish fidelity.
|
86 |
+
- **Faster inference**: higher `--preview_start` and lower `--creative_start` can both reduce computational costs and accelerate InstantIR inference.
|
87 |
+
|
88 |
+
> [!CAUTION]
|
89 |
+
> These features are training-free and thus experimental. If you would like to try, we suggest to tune these parameters case-by-case.
|
90 |
+
|
91 |
+
### Use InstantIR with diffusers 🧨
|
92 |
+
|
93 |
+
InstantIR is fully compatible with `diffusers` and is supported by all those powerful features in this package. You can directly load InstantIR via `diffusers` snippet:
|
94 |
+
|
95 |
+
```py
|
96 |
+
# !pip install diffusers opencv-python transformers accelerate
|
97 |
+
import torch
|
98 |
+
from PIL import Image
|
99 |
+
|
100 |
+
from diffusers import DDPMScheduler
|
101 |
+
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
102 |
+
|
103 |
+
from module.ip_adapter.utils import load_adapter_to_pipe
|
104 |
+
from pipelines.sdxl_instantir import InstantIRPipeline
|
105 |
+
|
106 |
+
# suppose you have InstantIR weights under ./models
|
107 |
+
instantir_path = f'./models'
|
108 |
+
|
109 |
+
# load pretrained models
|
110 |
+
pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
|
111 |
+
|
112 |
+
# load adapter
|
113 |
+
load_adapter_to_pipe(
|
114 |
+
pipe,
|
115 |
+
f"{instantir_path}/adapter.pt",
|
116 |
+
image_encoder_or_path = 'facebook/dinov2-large',
|
117 |
+
)
|
118 |
+
|
119 |
+
# load previewer lora
|
120 |
+
pipe.prepare_previewers(instantir_path)
|
121 |
+
pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
|
122 |
+
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
123 |
+
|
124 |
+
# load aggregator weights
|
125 |
+
pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt")
|
126 |
+
pipe.aggregator.load_state_dict(pretrained_state_dict)
|
127 |
+
|
128 |
+
# send to GPU and fp16
|
129 |
+
pipe.to(device='cuda', dtype=torch.float16)
|
130 |
+
pipe.aggregator.to(device='cuda', dtype=torch.float16)
|
131 |
+
```
|
132 |
+
|
133 |
+
Then, you just need to call the `pipe` and InstantIR will handle your image!
|
134 |
+
|
135 |
+
```py
|
136 |
+
# load a broken image
|
137 |
+
low_quality_image = Image.open('./assets/sculpture.png').convert("RGB")
|
138 |
+
|
139 |
+
# InstantIR restoration
|
140 |
+
image = pipe(
|
141 |
+
image=low_quality_image,
|
142 |
+
previewer_scheduler=lcm_scheduler,
|
143 |
+
).images[0]
|
144 |
+
```
|
145 |
+
|
146 |
+
### Deploy local gradio demo
|
147 |
+
|
148 |
+
We provide a python script to launch a local gradio demo of InstantIR, with basic and some advanced features implemented. Start by running the following command in your terminal:
|
149 |
+
|
150 |
+
```sh
|
151 |
+
INSTANTIR_PATH=<path_to_InstantIR> python gradio_demo/app.py
|
152 |
+
```
|
153 |
+
|
154 |
+
Then, visit your local demo via your browser at `http://localhost:7860`.
|
155 |
+
|
156 |
+
|
157 |
+
## ⚙️ Training
|
158 |
+
|
159 |
+
### Prepare data
|
160 |
+
|
161 |
+
InstantIR is trained on [DIV2K](https://www.kaggle.com/datasets/joe1995/div2k-dataset), [Flickr2K](https://www.kaggle.com/datasets/daehoyang/flickr2k), [LSDIR](https://data.vision.ee.ethz.ch/yawli/index.html) and [FFHQ](https://www.kaggle.com/datasets/rahulbhalley/ffhq-1024x1024). We adopt dataset weighting to balance the distribution. You can config their weights in ```config_files/IR_dataset.yaml```. Download these training sets and put them under a same directory, which will be used in the following training configurations.
|
162 |
+
|
163 |
+
### Two-stage training
|
164 |
+
As described in our paper, the training of InstantIR is conducted in two stages. We provide corresponding `.sh` training scripts for each stage. Make sure you have the following arguments adapted to your own use case:
|
165 |
+
|
166 |
+
| Argument | Value
|
167 |
+
| :--- | :----------
|
168 |
+
| `--pretrained_model_name_or_path` | path to your SDXL folder
|
169 |
+
| `--feature_extractor_path` | path to your DINOv2 folder
|
170 |
+
| `--train_data_dir` | your training data directory
|
171 |
+
| `--output_dir` | path to save model weights
|
172 |
+
| `--logging_dir` | path to save logs
|
173 |
+
| `<num_of_gpus>` | number of available GPUs
|
174 |
+
|
175 |
+
Other training hyperparameters we used in our experiments are provided in the corresponding `.sh` scripts. You can tune them according to your own needs.
|
176 |
+
|
177 |
+
## 👏 Acknowledgment
|
178 |
+
Our work is sponsored by [HuggingFace](https://huggingface.co) and [fal.ai](https://fal.ai).
|
179 |
+
|
180 |
+
## 🎓 Citation
|
181 |
+
|
182 |
+
If InstantIR is helpful to your work, please cite our paper via:
|
183 |
+
|
184 |
+
```
|
185 |
+
@article{huang2024instantir,
|
186 |
+
title={InstantIR: Blind Image Restoration with Instant Generative Reference},
|
187 |
+
author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
|
188 |
+
journal={arXiv preprint arXiv:2410.06551},
|
189 |
+
year={2024}
|
190 |
+
}
|
191 |
+
```
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/Konan.png
ADDED
assets/Naruto.png
ADDED
assets/cottage.png
ADDED
assets/dog.png
ADDED
assets/lady.png
ADDED
assets/man.png
ADDED
assets/panda.png
ADDED
assets/sculpture.png
ADDED
Git LFS Details
|
assets/teaser_figure.png
ADDED
Git LFS Details
|
config_files/IR_dataset.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
- dataset_folder: 'ffhq'
|
3 |
+
dataset_weight: 0.1
|
4 |
+
- dataset_folder: 'DIV2K'
|
5 |
+
dataset_weight: 0.3
|
6 |
+
- dataset_folder: 'LSDIR'
|
7 |
+
dataset_weight: 0.3
|
8 |
+
- dataset_folder: 'Flickr2K'
|
9 |
+
dataset_weight: 0.1
|
config_files/losses.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusion_losses:
|
2 |
+
- name: L2Loss
|
3 |
+
weight: 1
|
4 |
+
lcm_losses:
|
5 |
+
- name: HuberLoss
|
6 |
+
weight: 1
|
7 |
+
# - name: DINOLoss
|
8 |
+
# weight: 1e-3
|
9 |
+
# - name: L2Loss
|
10 |
+
# weight: 5e-2
|
11 |
+
# - name: LPIPSLoss
|
12 |
+
# weight: 1e-3
|
13 |
+
# - name: DreamSIMLoss
|
14 |
+
# weight: 1e-3
|
15 |
+
# - name: IDLoss
|
16 |
+
# weight: 1e-3
|
17 |
+
# visualize_every_k: 50
|
18 |
+
# init_params:
|
19 |
+
# pretrained_arcface_path: /home/dcor/orlichter/consistency_encoder_private/pretrained_models/model_ir_se50.pth
|
config_files/val_dataset.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
- dataset_folder: 'ffhq'
|
3 |
+
dataset_weight: 0.1
|
4 |
+
- dataset_folder: 'DIV2K'
|
5 |
+
dataset_weight: 0.45
|
6 |
+
- dataset_folder: 'LSDIR'
|
7 |
+
dataset_weight: 0.45
|
data/data_config.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional, List
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class SingleDataConfig:
|
7 |
+
dataset_folder: str
|
8 |
+
imagefolder: bool = True
|
9 |
+
dataset_weight: float = 1.0 # Not used yet
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class DataConfig:
|
13 |
+
datasets: List[SingleDataConfig]
|
14 |
+
val_dataset: Optional[SingleDataConfig] = None
|
data/dataset.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from PIL.ImageOps import exif_transpose
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
from torchvision import transforms
|
8 |
+
import json
|
9 |
+
import random
|
10 |
+
from facenet_pytorch import MTCNN
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE
|
14 |
+
|
15 |
+
def load_image(image_path: str) -> Image:
|
16 |
+
image = Image.open(image_path)
|
17 |
+
image = exif_transpose(image)
|
18 |
+
if not image.mode == "RGB":
|
19 |
+
image = image.convert("RGB")
|
20 |
+
return image
|
21 |
+
|
22 |
+
|
23 |
+
class ImageDataset(Dataset):
|
24 |
+
"""
|
25 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
26 |
+
It pre-processes the images.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
instance_data_root,
|
32 |
+
instance_prompt,
|
33 |
+
metadata_path: Optional[str] = None,
|
34 |
+
prompt_in_filename=False,
|
35 |
+
use_only_vanilla_for_encoder=False,
|
36 |
+
concept_placeholder='a face',
|
37 |
+
size=1024,
|
38 |
+
center_crop=False,
|
39 |
+
aug_images=False,
|
40 |
+
use_only_decoder_prompts=False,
|
41 |
+
crop_head_for_encoder_image=False,
|
42 |
+
random_target_prob=0.0,
|
43 |
+
):
|
44 |
+
self.mtcnn = MTCNN(device='cuda:0')
|
45 |
+
self.mtcnn.forward = self.mtcnn.detect
|
46 |
+
resize_factor = 1.3
|
47 |
+
self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor)
|
48 |
+
self.size = size
|
49 |
+
self.center_crop = center_crop
|
50 |
+
self.concept_placeholder = concept_placeholder
|
51 |
+
self.prompt_in_filename = prompt_in_filename
|
52 |
+
self.aug_images = aug_images
|
53 |
+
|
54 |
+
self.instance_prompt = instance_prompt
|
55 |
+
self.custom_instance_prompts = None
|
56 |
+
self.name_to_label = None
|
57 |
+
self.crop_head_for_encoder_image = crop_head_for_encoder_image
|
58 |
+
self.random_target_prob = random_target_prob
|
59 |
+
|
60 |
+
self.use_only_decoder_prompts = use_only_decoder_prompts
|
61 |
+
|
62 |
+
self.instance_data_root = Path(instance_data_root)
|
63 |
+
|
64 |
+
if not self.instance_data_root.exists():
|
65 |
+
raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")
|
66 |
+
|
67 |
+
if metadata_path is not None:
|
68 |
+
with open(metadata_path, 'r') as f:
|
69 |
+
self.name_to_label = json.load(f) # dict of filename: label
|
70 |
+
# Create a reversed mapping
|
71 |
+
self.label_to_names = {}
|
72 |
+
for name, label in self.name_to_label.items():
|
73 |
+
if use_only_vanilla_for_encoder and 'vanilla' not in name:
|
74 |
+
continue
|
75 |
+
if label not in self.label_to_names:
|
76 |
+
self.label_to_names[label] = []
|
77 |
+
self.label_to_names[label].append(name)
|
78 |
+
self.all_paths = [self.instance_data_root / filename for filename in self.name_to_label.keys()]
|
79 |
+
|
80 |
+
# Verify all paths exist
|
81 |
+
n_all_paths = len(self.all_paths)
|
82 |
+
self.all_paths = [path for path in self.all_paths if path.exists()]
|
83 |
+
print(f'Found {len(self.all_paths)} out of {n_all_paths} paths.')
|
84 |
+
else:
|
85 |
+
self.all_paths = [path for path in list(Path(instance_data_root).glob('**/*')) if
|
86 |
+
path.suffix.lower() in [".png", ".jpg", ".jpeg"]]
|
87 |
+
# Sort by name so that order for validation remains the same across runs
|
88 |
+
self.all_paths = sorted(self.all_paths, key=lambda x: x.stem)
|
89 |
+
|
90 |
+
self.custom_instance_prompts = None
|
91 |
+
|
92 |
+
self._length = len(self.all_paths)
|
93 |
+
|
94 |
+
self.class_data_root = None
|
95 |
+
|
96 |
+
self.image_transforms = transforms.Compose(
|
97 |
+
[
|
98 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
|
99 |
+
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
|
100 |
+
transforms.ToTensor(),
|
101 |
+
transforms.Normalize([0.5], [0.5]),
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
if self.prompt_in_filename:
|
106 |
+
self.prompts_set = set([self._path_to_prompt(path) for path in self.all_paths])
|
107 |
+
else:
|
108 |
+
self.prompts_set = set([self.instance_prompt])
|
109 |
+
|
110 |
+
if self.aug_images:
|
111 |
+
self.aug_transforms = transforms.Compose(
|
112 |
+
[
|
113 |
+
transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
|
114 |
+
transforms.RandomHorizontalFlip(p=0.5)
|
115 |
+
]
|
116 |
+
)
|
117 |
+
|
118 |
+
def __len__(self):
|
119 |
+
return self._length
|
120 |
+
|
121 |
+
def _path_to_prompt(self, path):
|
122 |
+
# Remove the extension and seed
|
123 |
+
split_path = path.stem.split('_')
|
124 |
+
while split_path[-1].isnumeric():
|
125 |
+
split_path = split_path[:-1]
|
126 |
+
|
127 |
+
prompt = ' '.join(split_path)
|
128 |
+
# Replace placeholder in prompt with training placeholder
|
129 |
+
prompt = prompt.replace('conceptname', self.concept_placeholder)
|
130 |
+
return prompt
|
131 |
+
|
132 |
+
def __getitem__(self, index):
|
133 |
+
example = {}
|
134 |
+
instance_path = self.all_paths[index]
|
135 |
+
instance_image = load_image(instance_path)
|
136 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
137 |
+
if self.prompt_in_filename:
|
138 |
+
example["instance_prompt"] = self._path_to_prompt(instance_path)
|
139 |
+
else:
|
140 |
+
example["instance_prompt"] = self.instance_prompt
|
141 |
+
|
142 |
+
if self.name_to_label is None:
|
143 |
+
# If no labels, simply take the same image but with different augmentation
|
144 |
+
example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"]
|
145 |
+
example["encoder_prompt"] = example["instance_prompt"]
|
146 |
+
else:
|
147 |
+
# Randomly select another image with the same label
|
148 |
+
instance_name = str(instance_path.relative_to(self.instance_data_root))
|
149 |
+
instance_label = self.name_to_label[instance_name]
|
150 |
+
label_set = set(self.label_to_names[instance_label])
|
151 |
+
if len(label_set) == 1:
|
152 |
+
# We are not supposed to have only one image per label, but just in case
|
153 |
+
encoder_image_name = instance_name
|
154 |
+
print(f'WARNING: Only one image for label {instance_label}.')
|
155 |
+
else:
|
156 |
+
encoder_image_name = random.choice(list(label_set - {instance_name}))
|
157 |
+
encoder_image = load_image(self.instance_data_root / encoder_image_name)
|
158 |
+
example["encoder_images"] = self.image_transforms(encoder_image)
|
159 |
+
|
160 |
+
if self.prompt_in_filename:
|
161 |
+
example["encoder_prompt"] = self._path_to_prompt(self.instance_data_root / encoder_image_name)
|
162 |
+
else:
|
163 |
+
example["encoder_prompt"] = self.instance_prompt
|
164 |
+
|
165 |
+
if self.crop_head_for_encoder_image:
|
166 |
+
example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0]
|
167 |
+
example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="<ph>")
|
168 |
+
example["instance_prompt"] = example["instance_prompt"].format(placeholder="<s*>")
|
169 |
+
|
170 |
+
if random.random() < self.random_target_prob:
|
171 |
+
random_path = random.choice(self.all_paths)
|
172 |
+
|
173 |
+
random_image = load_image(random_path)
|
174 |
+
example["instance_images"] = self.image_transforms(random_image)
|
175 |
+
if self.prompt_in_filename:
|
176 |
+
example["instance_prompt"] = self._path_to_prompt(random_path)
|
177 |
+
|
178 |
+
|
179 |
+
if self.use_only_decoder_prompts:
|
180 |
+
example["encoder_prompt"] = example["instance_prompt"]
|
181 |
+
|
182 |
+
return example
|
183 |
+
|
184 |
+
|
185 |
+
def collate_fn(examples, with_prior_preservation=False):
|
186 |
+
pixel_values = [example["instance_images"] for example in examples]
|
187 |
+
encoder_pixel_values = [example["encoder_images"] for example in examples]
|
188 |
+
prompts = [example["instance_prompt"] for example in examples]
|
189 |
+
encoder_prompts = [example["encoder_prompt"] for example in examples]
|
190 |
+
|
191 |
+
if with_prior_preservation:
|
192 |
+
raise NotImplementedError("Prior preservation not implemented.")
|
193 |
+
|
194 |
+
pixel_values = torch.stack(pixel_values)
|
195 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
196 |
+
|
197 |
+
encoder_pixel_values = torch.stack(encoder_pixel_values)
|
198 |
+
encoder_pixel_values = encoder_pixel_values.to(memory_format=torch.contiguous_format).float()
|
199 |
+
|
200 |
+
batch = {"pixel_values": pixel_values, "encoder_pixel_values": encoder_pixel_values,
|
201 |
+
"prompts": prompts, "encoder_prompts": encoder_prompts}
|
202 |
+
return batch
|
docs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
docs/static/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
environment.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: instantir
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- numpy
|
9 |
+
- pandas
|
10 |
+
- pillow
|
11 |
+
- pip
|
12 |
+
- python=3.9.15
|
13 |
+
- pytorch=2.2.2
|
14 |
+
- pytorch-lightning=1.6.5
|
15 |
+
- pytorch-cuda=12.1
|
16 |
+
- setuptools
|
17 |
+
- torchaudio=2.2.2
|
18 |
+
- torchmetrics
|
19 |
+
- torchvision=0.17.2
|
20 |
+
- tqdm
|
21 |
+
- pip:
|
22 |
+
- accelerate==0.25.0
|
23 |
+
- diffusers==0.24.0
|
24 |
+
- einops
|
25 |
+
- open-clip-torch
|
26 |
+
- opencv-python==4.8.1.78
|
27 |
+
- tokenizers
|
28 |
+
- transformers==4.36.2
|
29 |
+
- kornia
|
30 |
+
- facenet_pytorch
|
31 |
+
- lpips
|
32 |
+
- dreamsim
|
33 |
+
- pyrallis
|
34 |
+
- wandb
|
35 |
+
- insightface
|
36 |
+
- onnxruntime==1.17.0
|
37 |
+
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
gradio_demo/app.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import gradio as gr
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from diffusers import DDPMScheduler
|
11 |
+
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
12 |
+
|
13 |
+
from module.ip_adapter.utils import load_adapter_to_pipe
|
14 |
+
from pipelines.sdxl_instantir import InstantIRPipeline
|
15 |
+
|
16 |
+
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
|
17 |
+
pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
|
18 |
+
|
19 |
+
w, h = input_image.size
|
20 |
+
if size is not None:
|
21 |
+
w_resize_new, h_resize_new = size
|
22 |
+
else:
|
23 |
+
# ratio = min_side / min(h, w)
|
24 |
+
# w, h = round(ratio*w), round(ratio*h)
|
25 |
+
ratio = max_side / max(h, w)
|
26 |
+
input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
|
27 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
28 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
29 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
30 |
+
|
31 |
+
if pad_to_max_side:
|
32 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
33 |
+
offset_x = (max_side - w_resize_new) // 2
|
34 |
+
offset_y = (max_side - h_resize_new) // 2
|
35 |
+
res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
|
36 |
+
input_image = Image.fromarray(res)
|
37 |
+
return input_image
|
38 |
+
|
39 |
+
instantir_path = os.environ['INSTANTIR_PATH']
|
40 |
+
|
41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
42 |
+
sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
43 |
+
dinov2_repo_id = "facebook/dinov2-large"
|
44 |
+
lcm_repo_id = "latent-consistency/lcm-lora-sdxl"
|
45 |
+
|
46 |
+
if torch.cuda.is_available():
|
47 |
+
torch_dtype = torch.float16
|
48 |
+
else:
|
49 |
+
torch_dtype = torch.float32
|
50 |
+
|
51 |
+
# Load pretrained models.
|
52 |
+
print("Initializing pipeline...")
|
53 |
+
pipe = InstantIRPipeline.from_pretrained(
|
54 |
+
sdxl_repo_id,
|
55 |
+
torch_dtype=torch_dtype,
|
56 |
+
)
|
57 |
+
|
58 |
+
# Image prompt projector.
|
59 |
+
print("Loading LQ-Adapter...")
|
60 |
+
load_adapter_to_pipe(
|
61 |
+
pipe,
|
62 |
+
f"{instantir_path}/adapter.pt",
|
63 |
+
dinov2_repo_id,
|
64 |
+
)
|
65 |
+
|
66 |
+
# Prepare previewer
|
67 |
+
lora_alpha = pipe.prepare_previewers(instantir_path)
|
68 |
+
print(f"use lora alpha {lora_alpha}")
|
69 |
+
lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
|
70 |
+
print(f"use lora alpha {lora_alpha}")
|
71 |
+
pipe.to(device=device, dtype=torch_dtype)
|
72 |
+
pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
|
73 |
+
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
74 |
+
|
75 |
+
# Load weights.
|
76 |
+
print("Loading checkpoint...")
|
77 |
+
aggregator_state_dict = torch.load(
|
78 |
+
f"{instantir_path}/aggregator.pt",
|
79 |
+
map_location="cpu"
|
80 |
+
)
|
81 |
+
pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
|
82 |
+
pipe.aggregator.to(device=device, dtype=torch_dtype)
|
83 |
+
|
84 |
+
MAX_SEED = np.iinfo(np.int32).max
|
85 |
+
MAX_IMAGE_SIZE = 1024
|
86 |
+
|
87 |
+
PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
|
88 |
+
ultra HD, extreme meticulous detailing, skin pore detailing, \
|
89 |
+
hyper sharpness, perfect without deformations, \
|
90 |
+
taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
|
91 |
+
|
92 |
+
NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \
|
93 |
+
sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
|
94 |
+
dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
|
95 |
+
watermark, signature, jpeg artifacts, deformed, lowres"
|
96 |
+
|
97 |
+
def unpack_pipe_out(preview_row, index):
|
98 |
+
return preview_row[index][0]
|
99 |
+
|
100 |
+
def dynamic_preview_slider(sampling_steps):
|
101 |
+
print(sampling_steps)
|
102 |
+
return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
|
103 |
+
|
104 |
+
def dynamic_guidance_slider(sampling_steps):
|
105 |
+
return gr.Slider(label="Start Free Rendering", value=sampling_steps, minimum=0, maximum=sampling_steps, step=1)
|
106 |
+
|
107 |
+
def show_final_preview(preview_row):
|
108 |
+
return preview_row[-1][0]
|
109 |
+
|
110 |
+
# @spaces.GPU #[uncomment to use ZeroGPU]
|
111 |
+
@torch.no_grad()
|
112 |
+
def instantir_restore(
|
113 |
+
lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
|
114 |
+
creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0):
|
115 |
+
if creative_restoration:
|
116 |
+
if "lcm" not in pipe.unet.active_adapters():
|
117 |
+
pipe.unet.set_adapter('lcm')
|
118 |
+
else:
|
119 |
+
if "previewer" not in pipe.unet.active_adapters():
|
120 |
+
pipe.unet.set_adapter('previewer')
|
121 |
+
|
122 |
+
if isinstance(guidance_end, int):
|
123 |
+
guidance_end = guidance_end / steps
|
124 |
+
elif guidance_end > 1.0:
|
125 |
+
guidance_end = guidance_end / steps
|
126 |
+
if isinstance(preview_start, int):
|
127 |
+
preview_start = preview_start / steps
|
128 |
+
elif preview_start > 1.0:
|
129 |
+
preview_start = preview_start / steps
|
130 |
+
lq = [resize_img(lq.convert("RGB"), size=(width, height))]
|
131 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
132 |
+
timesteps = [
|
133 |
+
i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
|
134 |
+
]
|
135 |
+
timesteps = timesteps[::-1]
|
136 |
+
|
137 |
+
prompt = PROMPT if len(prompt)==0 else prompt
|
138 |
+
neg_prompt = NEG_PROMPT
|
139 |
+
|
140 |
+
out = pipe(
|
141 |
+
prompt=[prompt]*len(lq),
|
142 |
+
image=lq,
|
143 |
+
num_inference_steps=steps,
|
144 |
+
generator=generator,
|
145 |
+
timesteps=timesteps,
|
146 |
+
negative_prompt=[neg_prompt]*len(lq),
|
147 |
+
guidance_scale=cfg_scale,
|
148 |
+
control_guidance_end=guidance_end,
|
149 |
+
preview_start=preview_start,
|
150 |
+
previewer_scheduler=lcm_scheduler,
|
151 |
+
return_dict=False,
|
152 |
+
save_preview_row=True,
|
153 |
+
)
|
154 |
+
for i, preview_img in enumerate(out[1]):
|
155 |
+
preview_img.append(f"preview_{i}")
|
156 |
+
return out[0][0], out[1]
|
157 |
+
|
158 |
+
examples = [
|
159 |
+
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
160 |
+
"An astronaut riding a green horse",
|
161 |
+
"A delicious ceviche cheesecake slice",
|
162 |
+
]
|
163 |
+
|
164 |
+
css="""
|
165 |
+
#col-container {
|
166 |
+
margin: 0 auto;
|
167 |
+
max-width: 640px;
|
168 |
+
}
|
169 |
+
"""
|
170 |
+
|
171 |
+
with gr.Blocks() as demo:
|
172 |
+
gr.Markdown(
|
173 |
+
"""
|
174 |
+
# InstantIR: Blind Image Restoration with Instant Generative Reference.
|
175 |
+
|
176 |
+
### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
|
177 |
+
### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!**
|
178 |
+
## Basic usage: revitalize your image
|
179 |
+
1. Upload an image you want to restore;
|
180 |
+
2. Optionally, tune the `Steps` `CFG Scale` parameters. Typically higher steps lead to better results, but less than 50 is recommended for efficiency;
|
181 |
+
3. Click `InstantIR magic!`.
|
182 |
+
""")
|
183 |
+
with gr.Row():
|
184 |
+
lq_img = gr.Image(label="Low-quality image", type="pil")
|
185 |
+
with gr.Column():
|
186 |
+
with gr.Row():
|
187 |
+
steps = gr.Number(label="Steps", value=30, step=1)
|
188 |
+
cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
|
189 |
+
with gr.Row():
|
190 |
+
height = gr.Number(label="Height", value=1024, step=1)
|
191 |
+
weight = gr.Number(label="Weight", value=1024, step=1)
|
192 |
+
seed = gr.Number(label="Seed", value=42, step=1)
|
193 |
+
# guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
|
194 |
+
guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
|
195 |
+
preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
|
196 |
+
prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
|
197 |
+
mode = gr.Checkbox(label="Creative Restoration", value=False)
|
198 |
+
with gr.Row():
|
199 |
+
with gr.Row():
|
200 |
+
restore_btn = gr.Button("InstantIR magic!")
|
201 |
+
clear_btn = gr.ClearButton()
|
202 |
+
index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
|
203 |
+
with gr.Row():
|
204 |
+
output = gr.Image(label="InstantIR restored", type="pil")
|
205 |
+
preview = gr.Image(label="Preview", type="pil")
|
206 |
+
pipe_out = gr.Gallery(visible=False)
|
207 |
+
clear_btn.add([lq_img, output, preview])
|
208 |
+
restore_btn.click(
|
209 |
+
instantir_restore, inputs=[
|
210 |
+
lq_img, prompt, steps, cfg_scale, guidance_end,
|
211 |
+
mode, seed, height, weight, preview_start,
|
212 |
+
],
|
213 |
+
outputs=[output, pipe_out], api_name="InstantIR"
|
214 |
+
)
|
215 |
+
steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end)
|
216 |
+
output.change(dynamic_preview_slider, inputs=steps, outputs=index)
|
217 |
+
index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview)
|
218 |
+
output.change(show_final_preview, inputs=pipe_out, outputs=preview)
|
219 |
+
gr.Markdown(
|
220 |
+
"""
|
221 |
+
## Advance usage:
|
222 |
+
### Browse restoration variants:
|
223 |
+
1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions;
|
224 |
+
2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result.
|
225 |
+
### Creative restoration:
|
226 |
+
1. Check the `Creative Restoration` checkbox;
|
227 |
+
2. Input your text prompts in the `Restoration prompts` textbox;
|
228 |
+
3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
|
229 |
+
|
230 |
+
## Examples
|
231 |
+
Here are some examplar usage of InstantIR:
|
232 |
+
""")
|
233 |
+
# examples = gr.Gallery(label="Examples")
|
234 |
+
|
235 |
+
gr.Markdown(
|
236 |
+
"""
|
237 |
+
## Citation
|
238 |
+
If InstantIR is helpful to your work, please cite our paper via:
|
239 |
+
|
240 |
+
```
|
241 |
+
@article{huang2024instantir,
|
242 |
+
title={InstantIR: Blind Image Restoration with Instant Generative Reference},
|
243 |
+
author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
|
244 |
+
journal={arXiv preprint arXiv:2410.06551},
|
245 |
+
year={2024}
|
246 |
+
}
|
247 |
+
```
|
248 |
+
""")
|
249 |
+
|
250 |
+
demo.queue().launch()
|
infer.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
8 |
+
|
9 |
+
from diffusers import DDPMScheduler
|
10 |
+
|
11 |
+
from module.ip_adapter.utils import load_adapter_to_pipe
|
12 |
+
from pipelines.sdxl_instantir import InstantIRPipeline
|
13 |
+
|
14 |
+
|
15 |
+
def name_unet_submodules(unet):
|
16 |
+
def recursive_find_module(name, module, end=False):
|
17 |
+
if end:
|
18 |
+
for sub_name, sub_module in module.named_children():
|
19 |
+
sub_module.full_name = f"{name}.{sub_name}"
|
20 |
+
return
|
21 |
+
if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
|
22 |
+
elif "resnets" in name: return
|
23 |
+
for sub_name, sub_module in module.named_children():
|
24 |
+
end = True if sub_name == "transformer_blocks" else False
|
25 |
+
recursive_find_module(f"{name}.{sub_name}", sub_module, end)
|
26 |
+
|
27 |
+
for name, module in unet.named_children():
|
28 |
+
recursive_find_module(name, module)
|
29 |
+
|
30 |
+
|
31 |
+
def resize_img(input_image, max_side=1280, min_side=1024, size=None,
|
32 |
+
pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
|
33 |
+
|
34 |
+
w, h = input_image.size
|
35 |
+
if size is not None:
|
36 |
+
w_resize_new, h_resize_new = size
|
37 |
+
else:
|
38 |
+
# ratio = min_side / min(h, w)
|
39 |
+
# w, h = round(ratio*w), round(ratio*h)
|
40 |
+
ratio = max_side / max(h, w)
|
41 |
+
input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
|
42 |
+
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
|
43 |
+
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
|
44 |
+
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
|
45 |
+
|
46 |
+
if pad_to_max_side:
|
47 |
+
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
|
48 |
+
offset_x = (max_side - w_resize_new) // 2
|
49 |
+
offset_y = (max_side - h_resize_new) // 2
|
50 |
+
res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
|
51 |
+
input_image = Image.fromarray(res)
|
52 |
+
return input_image
|
53 |
+
|
54 |
+
|
55 |
+
def tensor_to_pil(images):
|
56 |
+
"""
|
57 |
+
Convert image tensor or a batch of image tensors to PIL image(s).
|
58 |
+
"""
|
59 |
+
images = images.clamp(0, 1)
|
60 |
+
images_np = images.detach().cpu().numpy()
|
61 |
+
if images_np.ndim == 4:
|
62 |
+
images_np = np.transpose(images_np, (0, 2, 3, 1))
|
63 |
+
elif images_np.ndim == 3:
|
64 |
+
images_np = np.transpose(images_np, (1, 2, 0))
|
65 |
+
images_np = images_np[None, ...]
|
66 |
+
images_np = (images_np * 255).round().astype("uint8")
|
67 |
+
if images_np.shape[-1] == 1:
|
68 |
+
# special case for grayscale (single channel) images
|
69 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
|
70 |
+
else:
|
71 |
+
pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]
|
72 |
+
|
73 |
+
return pil_images
|
74 |
+
|
75 |
+
|
76 |
+
def calc_mean_std(feat, eps=1e-5):
|
77 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
78 |
+
Args:
|
79 |
+
feat (Tensor): 4D tensor.
|
80 |
+
eps (float): A small value added to the variance to avoid
|
81 |
+
divide-by-zero. Default: 1e-5.
|
82 |
+
"""
|
83 |
+
size = feat.size()
|
84 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
85 |
+
b, c = size[:2]
|
86 |
+
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
87 |
+
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
88 |
+
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
89 |
+
return feat_mean, feat_std
|
90 |
+
|
91 |
+
|
92 |
+
def adaptive_instance_normalization(content_feat, style_feat):
|
93 |
+
size = content_feat.size()
|
94 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
95 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
96 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
97 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
98 |
+
|
99 |
+
|
100 |
+
def main(args, device):
|
101 |
+
|
102 |
+
# Load pretrained models.
|
103 |
+
pipe = InstantIRPipeline.from_pretrained(
|
104 |
+
args.sdxl_path,
|
105 |
+
torch_dtype=torch.float16,
|
106 |
+
)
|
107 |
+
|
108 |
+
# Image prompt projector.
|
109 |
+
print("Loading LQ-Adapter...")
|
110 |
+
load_adapter_to_pipe(
|
111 |
+
pipe,
|
112 |
+
args.adapter_model_path if args.adapter_model_path is not None else os.path.join(args.instantir_path, 'adapter.pt'),
|
113 |
+
args.vision_encoder_path,
|
114 |
+
use_clip_encoder=args.use_clip_encoder,
|
115 |
+
)
|
116 |
+
|
117 |
+
# Prepare previewer
|
118 |
+
previewer_lora_path = args.previewer_lora_path if args.previewer_lora_path is not None else args.instantir_path
|
119 |
+
if previewer_lora_path is not None:
|
120 |
+
lora_alpha = pipe.prepare_previewers(previewer_lora_path)
|
121 |
+
print(f"use lora alpha {lora_alpha}")
|
122 |
+
pipe.to(device=device, dtype=torch.float16)
|
123 |
+
pipe.scheduler = DDPMScheduler.from_pretrained(args.sdxl_path, subfolder="scheduler")
|
124 |
+
lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
125 |
+
|
126 |
+
# Load weights.
|
127 |
+
print("Loading checkpoint...")
|
128 |
+
pretrained_state_dict = torch.load(os.path.join(args.instantir_path, "aggregator.pt"), map_location="cpu")
|
129 |
+
pipe.aggregator.load_state_dict(pretrained_state_dict)
|
130 |
+
pipe.aggregator.to(device, dtype=torch.float16)
|
131 |
+
|
132 |
+
#################### Restoration ####################
|
133 |
+
|
134 |
+
post_fix = f"_{args.post_fix}" if args.post_fix else ""
|
135 |
+
os.makedirs(f"{args.out_path}/{post_fix}", exist_ok=True)
|
136 |
+
|
137 |
+
processed_imgs = os.listdir(os.path.join(args.out_path, post_fix))
|
138 |
+
lq_files = []
|
139 |
+
lq_batch = []
|
140 |
+
if os.path.isfile(args.test_path):
|
141 |
+
all_inputs = [args.test_path.split("/")[-1]]
|
142 |
+
else:
|
143 |
+
all_inputs = os.listdir(args.test_path)
|
144 |
+
all_inputs.sort()
|
145 |
+
for file in all_inputs:
|
146 |
+
if file in processed_imgs:
|
147 |
+
print(f"Skip {file}")
|
148 |
+
continue
|
149 |
+
lq_batch.append(f"{file}")
|
150 |
+
if len(lq_batch) == args.batch_size:
|
151 |
+
lq_files.append(lq_batch)
|
152 |
+
lq_batch = []
|
153 |
+
|
154 |
+
if len(lq_batch) > 0:
|
155 |
+
lq_files.append(lq_batch)
|
156 |
+
|
157 |
+
for lq_batch in lq_files:
|
158 |
+
generator = torch.Generator(device=device).manual_seed(args.seed)
|
159 |
+
pil_lqs = [Image.open(os.path.join(args.test_path, file)) for file in lq_batch]
|
160 |
+
if args.width is None or args.height is None:
|
161 |
+
lq = [resize_img(pil_lq.convert("RGB"), size=None) for pil_lq in pil_lqs]
|
162 |
+
else:
|
163 |
+
lq = [resize_img(pil_lq.convert("RGB"), size=(args.width, args.height)) for pil_lq in pil_lqs]
|
164 |
+
timesteps = None
|
165 |
+
if args.denoising_start < 1000:
|
166 |
+
timesteps = [
|
167 |
+
i * (args.denoising_start//args.num_inference_steps) + pipe.scheduler.config.steps_offset for i in range(0, args.num_inference_steps)
|
168 |
+
]
|
169 |
+
timesteps = timesteps[::-1]
|
170 |
+
pipe.scheduler.set_timesteps(args.num_inference_steps, device)
|
171 |
+
timesteps = pipe.scheduler.timesteps
|
172 |
+
if args.prompt is None or len(args.prompt) == 0:
|
173 |
+
prompt = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
|
174 |
+
ultra HD, extreme meticulous detailing, skin pore detailing, \
|
175 |
+
hyper sharpness, perfect without deformations, \
|
176 |
+
taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
|
177 |
+
else:
|
178 |
+
prompt = args.prompt
|
179 |
+
if not isinstance(prompt, list):
|
180 |
+
prompt = [prompt]
|
181 |
+
prompt = prompt*len(lq)
|
182 |
+
if args.neg_prompt is None or len(args.neg_prompt) == 0:
|
183 |
+
neg_prompt = "blurry, out of focus, unclear, depth of field, over-smooth, \
|
184 |
+
sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
|
185 |
+
dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
|
186 |
+
watermark, signature, jpeg artifacts, deformed, lowres"
|
187 |
+
else:
|
188 |
+
neg_prompt = args.neg_prompt
|
189 |
+
if not isinstance(neg_prompt, list):
|
190 |
+
neg_prompt = [neg_prompt]
|
191 |
+
neg_prompt = neg_prompt*len(lq)
|
192 |
+
image = pipe(
|
193 |
+
prompt=prompt,
|
194 |
+
image=lq,
|
195 |
+
num_inference_steps=args.num_inference_steps,
|
196 |
+
generator=generator,
|
197 |
+
timesteps=timesteps,
|
198 |
+
negative_prompt=neg_prompt,
|
199 |
+
guidance_scale=args.cfg,
|
200 |
+
previewer_scheduler=lcm_scheduler,
|
201 |
+
preview_start=args.preview_start,
|
202 |
+
control_guidance_end=args.creative_start,
|
203 |
+
).images
|
204 |
+
|
205 |
+
if args.save_preview_row:
|
206 |
+
for i, lcm_image in enumerate(image[1]):
|
207 |
+
lcm_image.save(f"./lcm/{i}.png")
|
208 |
+
for i, rec_image in enumerate(image):
|
209 |
+
rec_image.save(f"{args.out_path}/{post_fix}/{lq_batch[i]}")
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
parser = argparse.ArgumentParser(description="InstantIR pipeline")
|
214 |
+
parser.add_argument(
|
215 |
+
"--sdxl_path",
|
216 |
+
type=str,
|
217 |
+
default=None,
|
218 |
+
required=True,
|
219 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--previewer_lora_path",
|
223 |
+
type=str,
|
224 |
+
default=None,
|
225 |
+
help="Path to LCM lora or model identifier from huggingface.co/models.",
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
"--pretrained_vae_model_name_or_path",
|
229 |
+
type=str,
|
230 |
+
default=None,
|
231 |
+
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--instantir_path",
|
235 |
+
type=str,
|
236 |
+
default=None,
|
237 |
+
required=True,
|
238 |
+
help="Path to pretrained instantir model.",
|
239 |
+
)
|
240 |
+
parser.add_argument(
|
241 |
+
"--vision_encoder_path",
|
242 |
+
type=str,
|
243 |
+
default='/share/huangrenyuan/model_zoo/vis_backbone/dinov2_large',
|
244 |
+
help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
|
245 |
+
)
|
246 |
+
parser.add_argument(
|
247 |
+
"--adapter_model_path",
|
248 |
+
type=str,
|
249 |
+
default=None,
|
250 |
+
help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--adapter_tokens",
|
254 |
+
type=int,
|
255 |
+
default=64,
|
256 |
+
help="Number of tokens to use in IP-adapter cross attention mechanism.",
|
257 |
+
)
|
258 |
+
parser.add_argument(
|
259 |
+
"--use_clip_encoder",
|
260 |
+
action="store_true",
|
261 |
+
help="Whether or not to use DINO as image encoder, else CLIP encoder.",
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--denoising_start",
|
265 |
+
type=int,
|
266 |
+
default=1000,
|
267 |
+
help="Diffusion start timestep."
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--num_inference_steps",
|
271 |
+
type=int,
|
272 |
+
default=30,
|
273 |
+
help="Diffusion steps."
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--creative_start",
|
277 |
+
type=float,
|
278 |
+
default=1.0,
|
279 |
+
help="Proportion of timesteps for creative restoration. 1.0 means no creative restoration while 0.0 means completely free rendering."
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--preview_start",
|
283 |
+
type=float,
|
284 |
+
default=0.0,
|
285 |
+
help="Proportion of timesteps to stop previewing at the begining to enhance fidelity to input."
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--resolution",
|
289 |
+
type=int,
|
290 |
+
default=1024,
|
291 |
+
help="Number of tokens to use in IP-adapter cross attention mechanism.",
|
292 |
+
)
|
293 |
+
parser.add_argument(
|
294 |
+
"--batch_size",
|
295 |
+
type=int,
|
296 |
+
default=6,
|
297 |
+
help="Test batch size."
|
298 |
+
)
|
299 |
+
parser.add_argument(
|
300 |
+
"--width",
|
301 |
+
type=int,
|
302 |
+
default=None,
|
303 |
+
help="Output image width."
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--height",
|
307 |
+
type=int,
|
308 |
+
default=None,
|
309 |
+
help="Output image height."
|
310 |
+
)
|
311 |
+
parser.add_argument(
|
312 |
+
"--cfg",
|
313 |
+
type=float,
|
314 |
+
default=7.0,
|
315 |
+
help="Scale of Classifier-Free-Guidance (CFG).",
|
316 |
+
)
|
317 |
+
parser.add_argument(
|
318 |
+
"--post_fix",
|
319 |
+
type=str,
|
320 |
+
default=None,
|
321 |
+
help="Subfolder name for restoration output under the output directory.",
|
322 |
+
)
|
323 |
+
parser.add_argument(
|
324 |
+
"--variant",
|
325 |
+
type=str,
|
326 |
+
default='fp16',
|
327 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--revision",
|
331 |
+
type=str,
|
332 |
+
default=None,
|
333 |
+
required=False,
|
334 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--save_preview_row",
|
338 |
+
action="store_true",
|
339 |
+
help="Whether or not to save the intermediate lcm outputs.",
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--prompt",
|
343 |
+
type=str,
|
344 |
+
default='',
|
345 |
+
nargs="+",
|
346 |
+
help=(
|
347 |
+
"A set of prompts for creative restoration. Provide either a matching number of test images,"
|
348 |
+
" or a single prompt to be used with all inputs."
|
349 |
+
),
|
350 |
+
)
|
351 |
+
parser.add_argument(
|
352 |
+
"--neg_prompt",
|
353 |
+
type=str,
|
354 |
+
default='',
|
355 |
+
nargs="+",
|
356 |
+
help=(
|
357 |
+
"A set of negative prompts for creative restoration. Provide either a matching number of test images,"
|
358 |
+
" or a single negative prompt to be used with all inputs."
|
359 |
+
),
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--test_path",
|
363 |
+
type=str,
|
364 |
+
default=None,
|
365 |
+
required=True,
|
366 |
+
help="Test directory.",
|
367 |
+
)
|
368 |
+
parser.add_argument(
|
369 |
+
"--out_path",
|
370 |
+
type=str,
|
371 |
+
default="./output",
|
372 |
+
help="Output directory.",
|
373 |
+
)
|
374 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
375 |
+
args = parser.parse_args()
|
376 |
+
args.height = args.height or args.width
|
377 |
+
args.width = args.width or args.height
|
378 |
+
if args.height is not None and (args.width % 64 != 0 or args.height % 64 != 0):
|
379 |
+
raise ValueError("Image resolution must be divisible by 64.")
|
380 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
381 |
+
main(args, device)
|
infer.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python infer.py \
|
2 |
+
--sdxl_path path/to/sdxl \
|
3 |
+
--vision_encoder_path path/to/dinov2_large \
|
4 |
+
--instantir_path path/to/instantir \
|
5 |
+
--test_path path/to/input \
|
6 |
+
--out_path path/to/output
|
losses/loss_config.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class SingleLossConfig:
|
6 |
+
name: str
|
7 |
+
weight: float = 1.
|
8 |
+
init_params: dict = field(default_factory=dict)
|
9 |
+
visualize_every_k: int = -1
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class LossesConfig:
|
14 |
+
diffusion_losses: List[SingleLossConfig]
|
15 |
+
lcm_losses: List[SingleLossConfig]
|
losses/losses.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import wandb
|
3 |
+
import cv2
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
from facenet_pytorch import MTCNN
|
7 |
+
from torchvision import transforms
|
8 |
+
from dreamsim import dreamsim
|
9 |
+
from einops import rearrange
|
10 |
+
import kornia.augmentation as K
|
11 |
+
import lpips
|
12 |
+
|
13 |
+
from pretrained_models.arcface import Backbone
|
14 |
+
from utils.vis_utils import add_text_to_image
|
15 |
+
from utils.utils import extract_faces_and_landmarks
|
16 |
+
import clip
|
17 |
+
|
18 |
+
|
19 |
+
class Loss():
|
20 |
+
"""
|
21 |
+
General purpose loss class.
|
22 |
+
Mainly handles dtype and visualize_every_k.
|
23 |
+
keeps current iteration of loss, mainly for visualization purposes.
|
24 |
+
"""
|
25 |
+
def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
|
26 |
+
self.visualize_every_k = visualize_every_k
|
27 |
+
self.iteration = -1
|
28 |
+
self.dtype=dtype
|
29 |
+
self.accelerator = accelerator
|
30 |
+
|
31 |
+
def __call__(self, **kwargs):
|
32 |
+
self.iteration += 1
|
33 |
+
return self.forward(**kwargs)
|
34 |
+
|
35 |
+
|
36 |
+
class L1Loss(Loss):
|
37 |
+
"""
|
38 |
+
Simple L1 loss between predicted_pixel_values and pixel_values
|
39 |
+
|
40 |
+
Args:
|
41 |
+
predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
|
42 |
+
encoder_pixel_values (torch.Tesnor): The input image to the encoder
|
43 |
+
"""
|
44 |
+
def forward(
|
45 |
+
self,
|
46 |
+
predict: torch.Tensor,
|
47 |
+
target: torch.Tensor,
|
48 |
+
**kwargs
|
49 |
+
) -> torch.Tensor:
|
50 |
+
return F.l1_loss(predict, target, reduction="mean")
|
51 |
+
|
52 |
+
|
53 |
+
class DreamSIMLoss(Loss):
|
54 |
+
"""DreamSIM loss between predicted_pixel_values and pixel_values.
|
55 |
+
DreamSIM is similar to LPIPS (https://dreamsim-nights.github.io/) but is trained on more human defined similarity dataset
|
56 |
+
DreamSIM expects an RGB image of size 224x224 and values between 0 and 1. So we need to normalize the input images to 0-1 range and resize them to 224x224.
|
57 |
+
Args:
|
58 |
+
predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
|
59 |
+
encoder_pixel_values (torch.Tesnor): The input image to the encoder
|
60 |
+
"""
|
61 |
+
def __init__(self, device: str='cuda:0', **kwargs):
|
62 |
+
super().__init__(**kwargs)
|
63 |
+
self.model, _ = dreamsim(pretrained=True, device=device)
|
64 |
+
self.model.to(dtype=self.dtype, device=device)
|
65 |
+
self.model = self.accelerator.prepare(self.model)
|
66 |
+
self.transforms = transforms.Compose([
|
67 |
+
transforms.Lambda(lambda x: (x + 1) / 2),
|
68 |
+
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])
|
69 |
+
|
70 |
+
def forward(
|
71 |
+
self,
|
72 |
+
predicted_pixel_values: torch.Tensor,
|
73 |
+
encoder_pixel_values: torch.Tensor,
|
74 |
+
**kwargs,
|
75 |
+
) -> torch.Tensor:
|
76 |
+
predicted_pixel_values.to(dtype=self.dtype)
|
77 |
+
encoder_pixel_values.to(dtype=self.dtype)
|
78 |
+
return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()
|
79 |
+
|
80 |
+
|
81 |
+
class LPIPSLoss(Loss):
|
82 |
+
"""LPIPS loss between predicted_pixel_values and pixel_values.
|
83 |
+
Args:
|
84 |
+
predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
|
85 |
+
encoder_pixel_values (torch.Tesnor): The input image to the encoder
|
86 |
+
"""
|
87 |
+
def __init__(self, **kwargs):
|
88 |
+
super().__init__(**kwargs)
|
89 |
+
self.model = lpips.LPIPS(net='vgg')
|
90 |
+
self.model.to(dtype=self.dtype, device=self.accelerator.device)
|
91 |
+
self.model = self.accelerator.prepare(self.model)
|
92 |
+
|
93 |
+
def forward(self, predict, target, **kwargs):
|
94 |
+
predict.to(dtype=self.dtype)
|
95 |
+
target.to(dtype=self.dtype)
|
96 |
+
return self.model(predict, target).mean()
|
97 |
+
|
98 |
+
|
99 |
+
class LCMVisualization(Loss):
|
100 |
+
"""Dummy loss used to visualize the LCM outputs
|
101 |
+
Args:
|
102 |
+
predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
|
103 |
+
pixel_values (torch.Tensor): The input image to the decoder
|
104 |
+
encoder_pixel_values (torch.Tesnor): The input image to the encoder
|
105 |
+
"""
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
predicted_pixel_values: torch.Tensor,
|
109 |
+
pixel_values: torch.Tensor,
|
110 |
+
encoder_pixel_values: torch.Tensor,
|
111 |
+
timesteps: torch.Tensor,
|
112 |
+
**kwargs,
|
113 |
+
) -> None:
|
114 |
+
if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
|
115 |
+
predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
|
116 |
+
pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
|
117 |
+
encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
|
118 |
+
image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values])
|
119 |
+
for tracker in self.accelerator.trackers:
|
120 |
+
if tracker.name == 'wandb':
|
121 |
+
tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
|
122 |
+
return torch.tensor(0.0)
|
123 |
+
|
124 |
+
|
125 |
+
class L2Loss(Loss):
|
126 |
+
"""
|
127 |
+
Regular diffusion loss between predicted noise and target noise.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
predicted_noise (torch.Tensor): noise predicted by the diffusion model
|
131 |
+
target_noise (torch.Tensor): actual noise added to the image.
|
132 |
+
"""
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
predict: torch.Tensor,
|
136 |
+
target: torch.Tensor,
|
137 |
+
weights: torch.Tensor = None,
|
138 |
+
**kwargs
|
139 |
+
) -> torch.Tensor:
|
140 |
+
if weights is not None:
|
141 |
+
loss = (predict.float() - target.float()).pow(2) * weights
|
142 |
+
return loss.mean()
|
143 |
+
return F.mse_loss(predict.float(), target.float(), reduction="mean")
|
144 |
+
|
145 |
+
|
146 |
+
class HuberLoss(Loss):
|
147 |
+
"""Huber loss between predicted_pixel_values and pixel_values.
|
148 |
+
Args:
|
149 |
+
predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
|
150 |
+
encoder_pixel_values (torch.Tesnor): The input image to the encoder
|
151 |
+
"""
|
152 |
+
def __init__(self, huber_c=0.001, **kwargs):
|
153 |
+
super().__init__(**kwargs)
|
154 |
+
self.huber_c = huber_c
|
155 |
+
|
156 |
+
def forward(
|
157 |
+
self,
|
158 |
+
predict: torch.Tensor,
|
159 |
+
target: torch.Tensor,
|
160 |
+
weights: torch.Tensor = None,
|
161 |
+
**kwargs
|
162 |
+
) -> torch.Tensor:
|
163 |
+
loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c
|
164 |
+
if weights is not None:
|
165 |
+
return (loss * weights).mean()
|
166 |
+
return loss.mean()
|
167 |
+
|
168 |
+
|
169 |
+
class WeightedNoiseLoss(Loss):
|
170 |
+
"""
|
171 |
+
Weighted diffusion loss between predicted noise and target noise.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
predicted_noise (torch.Tensor): noise predicted by the diffusion model
|
175 |
+
target_noise (torch.Tensor): actual noise added to the image.
|
176 |
+
loss_batch_weights (torch.Tensor): weighting for each batch item. Can be used to e.g. zero-out loss for InstantID training if keypoint extraction fails.
|
177 |
+
"""
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
predict: torch.Tensor,
|
181 |
+
target: torch.Tensor,
|
182 |
+
weights,
|
183 |
+
**kwargs
|
184 |
+
) -> torch.Tensor:
|
185 |
+
return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean")
|
186 |
+
|
187 |
+
|
188 |
+
class IDLoss(Loss):
|
189 |
+
"""
|
190 |
+
Use pretrained facenet model to extract features from the face of the predicted image and target image.
|
191 |
+
Facenet expects 112x112 images, so we crop the face using MTCNN and resize it to 112x112.
|
192 |
+
Then we use the cosine similarity between the features to calculate the loss. (The cosine similarity is 1 - cosine distance).
|
193 |
+
Also notice that the outputs of facenet are normalized so the dot product is the same as cosine distance.
|
194 |
+
"""
|
195 |
+
def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
|
196 |
+
super().__init__(**kwargs)
|
197 |
+
assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from "\
|
198 |
+
"https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
|
199 |
+
self.mtcnn = MTCNN(device=self.accelerator.device)
|
200 |
+
self.mtcnn.forward = self.mtcnn.detect
|
201 |
+
self.facenet_input_size = 112 # Has to be 112, can't find weights for 224 size.
|
202 |
+
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
203 |
+
self.facenet.load_state_dict(torch.load(pretrained_arcface_path))
|
204 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
|
205 |
+
self.facenet.requires_grad_(False)
|
206 |
+
self.facenet.eval()
|
207 |
+
self.facenet.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
|
208 |
+
self.face_pool.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
|
209 |
+
self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
|
210 |
+
self.reference_facial_points = np.array([[38.29459953, 51.69630051],
|
211 |
+
[72.53179932, 51.50139999],
|
212 |
+
[56.02519989, 71.73660278],
|
213 |
+
[41.54930115, 92.3655014],
|
214 |
+
[70.72990036, 92.20410156]
|
215 |
+
]) # Original points are 112 * 96 added 8 to the x axis to make it 112 * 112
|
216 |
+
self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)
|
217 |
+
|
218 |
+
self.skip_not_found = skip_not_found
|
219 |
+
|
220 |
+
def extract_feats(self, x: torch.Tensor):
|
221 |
+
"""
|
222 |
+
Extract features from the face of the image using facenet model.
|
223 |
+
"""
|
224 |
+
x = self.face_pool(x)
|
225 |
+
x_feats = self.facenet(x)
|
226 |
+
|
227 |
+
return x_feats
|
228 |
+
|
229 |
+
def forward(
|
230 |
+
self,
|
231 |
+
predicted_pixel_values: torch.Tensor,
|
232 |
+
encoder_pixel_values: torch.Tensor,
|
233 |
+
timesteps: torch.Tensor,
|
234 |
+
**kwargs
|
235 |
+
):
|
236 |
+
encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
|
237 |
+
predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
|
238 |
+
|
239 |
+
predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
|
240 |
+
with torch.no_grad():
|
241 |
+
encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)
|
242 |
+
|
243 |
+
if self.skip_not_found:
|
244 |
+
valid_indices = []
|
245 |
+
for i in range(predicted_pixel_values.shape[0]):
|
246 |
+
if i not in predicted_invalid_indices and i not in source_invalid_indices:
|
247 |
+
valid_indices.append(i)
|
248 |
+
else:
|
249 |
+
valid_indices = list(range(predicted_pixel_values))
|
250 |
+
|
251 |
+
valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device)
|
252 |
+
|
253 |
+
if len(valid_indices) == 0:
|
254 |
+
loss = (predicted_pixel_values_face * 0.0).mean() # It's done this way so the `backwards` will delete the computation graph of the predicted_pixel_values.
|
255 |
+
if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
|
256 |
+
self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
|
257 |
+
return loss
|
258 |
+
|
259 |
+
with torch.no_grad():
|
260 |
+
pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])
|
261 |
+
|
262 |
+
predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
|
263 |
+
loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)
|
264 |
+
|
265 |
+
if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
|
266 |
+
self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
|
267 |
+
return loss.mean()
|
268 |
+
|
269 |
+
def visualize(
|
270 |
+
self,
|
271 |
+
predicted_pixel_values: torch.Tensor,
|
272 |
+
encoder_pixel_values: torch.Tensor,
|
273 |
+
predicted_pixel_values_face: torch.Tensor,
|
274 |
+
encoder_pixel_values_face: torch.Tensor,
|
275 |
+
timesteps: torch.Tensor,
|
276 |
+
valid_indices: torch.Tensor,
|
277 |
+
loss: torch.Tensor,
|
278 |
+
) -> None:
|
279 |
+
small_predicted_pixel_values = (rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy())
|
280 |
+
small_pixle_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
|
281 |
+
small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
|
282 |
+
small_pixle_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
|
283 |
+
|
284 |
+
small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False)
|
285 |
+
small_pixle_values = add_text_to_image(((small_pixle_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False)
|
286 |
+
small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False)
|
287 |
+
small_pixle_values_face = add_text_to_image(((small_pixle_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False)
|
288 |
+
|
289 |
+
|
290 |
+
final_image = np.hstack([small_predicted_pixel_values, small_pixle_values, small_predicted_pixel_values_face, small_pixle_values_face])
|
291 |
+
for tracker in self.accelerator.trackers:
|
292 |
+
if tracker.name == 'wandb':
|
293 |
+
tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})
|
294 |
+
|
295 |
+
|
296 |
+
class ImageAugmentations(torch.nn.Module):
|
297 |
+
# Standard image augmentations used for CLIP loss to discourage adversarial outputs.
|
298 |
+
def __init__(self, output_size, augmentations_number, p=0.7):
|
299 |
+
super().__init__()
|
300 |
+
self.output_size = output_size
|
301 |
+
self.augmentations_number = augmentations_number
|
302 |
+
|
303 |
+
self.augmentations = torch.nn.Sequential(
|
304 |
+
K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"), # type: ignore
|
305 |
+
K.RandomPerspective(0.7, p=p),
|
306 |
+
)
|
307 |
+
|
308 |
+
self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
|
309 |
+
|
310 |
+
self.device = None
|
311 |
+
|
312 |
+
def forward(self, input):
|
313 |
+
"""Extents the input batch with augmentations
|
314 |
+
If the input is consists of images [I1, I2] the extended augmented output
|
315 |
+
will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...]
|
316 |
+
Args:
|
317 |
+
input ([type]): input batch of shape [batch, C, H, W]
|
318 |
+
Returns:
|
319 |
+
updated batch: of shape [batch * augmentations_number, C, H, W]
|
320 |
+
"""
|
321 |
+
# We want to multiply the number of images in the batch in contrast to regular augmantations
|
322 |
+
# that do not change the number of samples in the batch)
|
323 |
+
resized_images = self.avg_pool(input)
|
324 |
+
resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))
|
325 |
+
|
326 |
+
batch_size = input.shape[0]
|
327 |
+
# We want at least one non augmented image
|
328 |
+
non_augmented_batch = resized_images[:batch_size]
|
329 |
+
augmented_batch = self.augmentations(resized_images[batch_size:])
|
330 |
+
updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)
|
331 |
+
|
332 |
+
return updated_batch
|
333 |
+
|
334 |
+
|
335 |
+
class CLIPLoss(Loss):
|
336 |
+
def __init__(self, augmentations_number: int = 4, **kwargs):
|
337 |
+
super().__init__(**kwargs)
|
338 |
+
|
339 |
+
self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)
|
340 |
+
|
341 |
+
self.clip_model.device = None
|
342 |
+
|
343 |
+
self.clip_model.eval().requires_grad_(False)
|
344 |
+
|
345 |
+
self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
|
346 |
+
clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
|
347 |
+
clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
|
348 |
+
|
349 |
+
self.clip_size = self.clip_model.visual.input_resolution
|
350 |
+
|
351 |
+
self.clip_normalize = transforms.Normalize(
|
352 |
+
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
|
353 |
+
)
|
354 |
+
|
355 |
+
self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
|
356 |
+
augmentations_number=augmentations_number)
|
357 |
+
|
358 |
+
self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)
|
359 |
+
|
360 |
+
def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
|
361 |
+
|
362 |
+
if not isinstance(decoder_prompts, list):
|
363 |
+
decoder_prompts = [decoder_prompts]
|
364 |
+
|
365 |
+
tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device)
|
366 |
+
image = self.preprocess(predicted_pixel_values)
|
367 |
+
|
368 |
+
logits_per_image, _ = self.clip_model(image, tokens)
|
369 |
+
|
370 |
+
logits_per_image = torch.diagonal(logits_per_image)
|
371 |
+
|
372 |
+
return (1. - logits_per_image / 100).mean()
|
373 |
+
|
374 |
+
|
375 |
+
class DINOLoss(Loss):
|
376 |
+
def __init__(
|
377 |
+
self,
|
378 |
+
dino_model,
|
379 |
+
dino_preprocess,
|
380 |
+
output_hidden_states: bool = False,
|
381 |
+
center_momentum: float = 0.9,
|
382 |
+
student_temp: float = 0.1,
|
383 |
+
teacher_temp: float = 0.04,
|
384 |
+
warmup_teacher_temp: float = 0.04,
|
385 |
+
warmup_teacher_temp_epochs: int = 30,
|
386 |
+
**kwargs):
|
387 |
+
super().__init__(**kwargs)
|
388 |
+
|
389 |
+
self.dino_model = dino_model
|
390 |
+
self.output_hidden_states = output_hidden_states
|
391 |
+
self.rescale_factor = dino_preprocess.rescale_factor
|
392 |
+
|
393 |
+
# Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
|
394 |
+
self.preprocess = transforms.Compose(
|
395 |
+
[
|
396 |
+
transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
|
397 |
+
transforms.Resize(size=256),
|
398 |
+
transforms.CenterCrop(size=(224, 224)),
|
399 |
+
transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
|
400 |
+
]
|
401 |
+
)
|
402 |
+
|
403 |
+
self.student_temp = student_temp
|
404 |
+
self.teacher_temp = teacher_temp
|
405 |
+
self.center_momentum = center_momentum
|
406 |
+
self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)
|
407 |
+
|
408 |
+
# TODO: add temp, now fixed to 0.04
|
409 |
+
# we apply a warm up for the teacher temperature because
|
410 |
+
# a too high temperature makes the training instable at the beginning
|
411 |
+
# self.teacher_temp_schedule = np.concatenate((
|
412 |
+
# np.linspace(warmup_teacher_temp,
|
413 |
+
# teacher_temp, warmup_teacher_temp_epochs),
|
414 |
+
# np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
|
415 |
+
# ))
|
416 |
+
|
417 |
+
self.dino_model = self.accelerator.prepare(self.dino_model)
|
418 |
+
|
419 |
+
def forward(
|
420 |
+
self,
|
421 |
+
target: torch.Tensor,
|
422 |
+
predict: torch.Tensor,
|
423 |
+
weights: torch.Tensor = None,
|
424 |
+
**kwargs) -> torch.Tensor:
|
425 |
+
|
426 |
+
predict = self.preprocess(predict)
|
427 |
+
target = self.preprocess(target)
|
428 |
+
|
429 |
+
encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)
|
430 |
+
|
431 |
+
if self.output_hidden_states:
|
432 |
+
raise ValueError("Output hidden states not supported for DINO loss.")
|
433 |
+
image_enc_hidden_states = self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2]
|
434 |
+
else:
|
435 |
+
image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state
|
436 |
+
|
437 |
+
teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0) # [B, 257, 1024]
|
438 |
+
|
439 |
+
student_out = student_output.float() / self.student_temp
|
440 |
+
|
441 |
+
# teacher centering and sharpening
|
442 |
+
# temp = self.teacher_temp_schedule[epoch]
|
443 |
+
temp = self.teacher_temp
|
444 |
+
teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1)
|
445 |
+
teacher_out = teacher_out.detach()
|
446 |
+
|
447 |
+
loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)
|
448 |
+
# self.update_center(teacher_output)
|
449 |
+
|
450 |
+
if weights is not None:
|
451 |
+
loss = loss * weights
|
452 |
+
return loss.mean()
|
453 |
+
return loss.mean()
|
454 |
+
|
455 |
+
@torch.no_grad()
|
456 |
+
def update_center(self, teacher_output):
|
457 |
+
"""
|
458 |
+
Update center used for teacher output.
|
459 |
+
"""
|
460 |
+
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
461 |
+
self.accelerator.reduce(batch_center, reduction="sum")
|
462 |
+
batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)
|
463 |
+
|
464 |
+
# ema update
|
465 |
+
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
|
module/aggregator.py
ADDED
@@ -0,0 +1,983 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
10 |
+
from diffusers.utils import BaseOutput, logging
|
11 |
+
from diffusers.models.attention_processor import (
|
12 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
13 |
+
CROSS_ATTENTION_PROCESSORS,
|
14 |
+
AttentionProcessor,
|
15 |
+
AttnAddedKVProcessor,
|
16 |
+
AttnProcessor,
|
17 |
+
)
|
18 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
19 |
+
from diffusers.models.modeling_utils import ModelMixin
|
20 |
+
from diffusers.models.unets.unet_2d_blocks import (
|
21 |
+
CrossAttnDownBlock2D,
|
22 |
+
DownBlock2D,
|
23 |
+
UNetMidBlock2D,
|
24 |
+
UNetMidBlock2DCrossAttn,
|
25 |
+
get_down_block,
|
26 |
+
)
|
27 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
class ZeroConv(nn.Module):
|
34 |
+
def __init__(self, label_nc, norm_nc, mask=False):
|
35 |
+
super().__init__()
|
36 |
+
self.zero_conv = zero_module(nn.Conv2d(label_nc+norm_nc, norm_nc, 1, 1, 0))
|
37 |
+
self.mask = mask
|
38 |
+
|
39 |
+
def forward(self, hidden_states, h_ori=None):
|
40 |
+
# with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
41 |
+
c, h = hidden_states
|
42 |
+
if not self.mask:
|
43 |
+
h = self.zero_conv(torch.cat([c, h], dim=1))
|
44 |
+
else:
|
45 |
+
h = self.zero_conv(torch.cat([c, h], dim=1)) * torch.zeros_like(h)
|
46 |
+
if h_ori is not None:
|
47 |
+
h = torch.cat([h_ori, h], dim=1)
|
48 |
+
return h
|
49 |
+
|
50 |
+
|
51 |
+
class SFT(nn.Module):
|
52 |
+
def __init__(self, label_nc, norm_nc, mask=False):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
# param_free_norm_type = str(parsed.group(1))
|
56 |
+
ks = 3
|
57 |
+
pw = ks // 2
|
58 |
+
|
59 |
+
self.mask = mask
|
60 |
+
|
61 |
+
nhidden = 128
|
62 |
+
|
63 |
+
self.mlp_shared = nn.Sequential(
|
64 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
65 |
+
nn.SiLU()
|
66 |
+
)
|
67 |
+
self.mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
68 |
+
self.add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
69 |
+
|
70 |
+
def forward(self, hidden_states, mask=False):
|
71 |
+
|
72 |
+
c, h = hidden_states
|
73 |
+
mask = mask or self.mask
|
74 |
+
assert mask is False
|
75 |
+
|
76 |
+
actv = self.mlp_shared(c)
|
77 |
+
gamma = self.mul(actv)
|
78 |
+
beta = self.add(actv)
|
79 |
+
|
80 |
+
if self.mask:
|
81 |
+
gamma = gamma * torch.zeros_like(gamma)
|
82 |
+
beta = beta * torch.zeros_like(beta)
|
83 |
+
# gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1)
|
84 |
+
# beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1)
|
85 |
+
# print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean())
|
86 |
+
h = h * (gamma + 1) + beta
|
87 |
+
# sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1)
|
88 |
+
# print(sample_ori.mean(), sample_res.mean())
|
89 |
+
|
90 |
+
return h
|
91 |
+
|
92 |
+
|
93 |
+
@dataclass
|
94 |
+
class AggregatorOutput(BaseOutput):
|
95 |
+
"""
|
96 |
+
The output of [`Aggregator`].
|
97 |
+
|
98 |
+
Args:
|
99 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
100 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
101 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
102 |
+
used to condition the original UNet's downsampling activations.
|
103 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
104 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
105 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
106 |
+
Output can be used to condition the original UNet's middle block activation.
|
107 |
+
"""
|
108 |
+
|
109 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
110 |
+
mid_block_res_sample: torch.Tensor
|
111 |
+
|
112 |
+
|
113 |
+
class ConditioningEmbedding(nn.Module):
|
114 |
+
"""
|
115 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
116 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
117 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
118 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
119 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
120 |
+
model) to encode image-space conditions ... into feature maps ..."
|
121 |
+
"""
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
conditioning_embedding_channels: int,
|
126 |
+
conditioning_channels: int = 3,
|
127 |
+
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
128 |
+
):
|
129 |
+
super().__init__()
|
130 |
+
|
131 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
132 |
+
|
133 |
+
self.blocks = nn.ModuleList([])
|
134 |
+
|
135 |
+
for i in range(len(block_out_channels) - 1):
|
136 |
+
channel_in = block_out_channels[i]
|
137 |
+
channel_out = block_out_channels[i + 1]
|
138 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
139 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
140 |
+
|
141 |
+
self.conv_out = zero_module(
|
142 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
143 |
+
)
|
144 |
+
|
145 |
+
def forward(self, conditioning):
|
146 |
+
embedding = self.conv_in(conditioning)
|
147 |
+
embedding = F.silu(embedding)
|
148 |
+
|
149 |
+
for block in self.blocks:
|
150 |
+
embedding = block(embedding)
|
151 |
+
embedding = F.silu(embedding)
|
152 |
+
|
153 |
+
embedding = self.conv_out(embedding)
|
154 |
+
|
155 |
+
return embedding
|
156 |
+
|
157 |
+
|
158 |
+
class Aggregator(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
159 |
+
"""
|
160 |
+
Aggregator model.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
in_channels (`int`, defaults to 4):
|
164 |
+
The number of channels in the input sample.
|
165 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
166 |
+
Whether to flip the sin to cos in the time embedding.
|
167 |
+
freq_shift (`int`, defaults to 0):
|
168 |
+
The frequency shift to apply to the time embedding.
|
169 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
170 |
+
The tuple of downsample blocks to use.
|
171 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
172 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
173 |
+
The tuple of output channels for each block.
|
174 |
+
layers_per_block (`int`, defaults to 2):
|
175 |
+
The number of layers per block.
|
176 |
+
downsample_padding (`int`, defaults to 1):
|
177 |
+
The padding to use for the downsampling convolution.
|
178 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
179 |
+
The scale factor to use for the mid block.
|
180 |
+
act_fn (`str`, defaults to "silu"):
|
181 |
+
The activation function to use.
|
182 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
183 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
184 |
+
in post-processing.
|
185 |
+
norm_eps (`float`, defaults to 1e-5):
|
186 |
+
The epsilon to use for the normalization.
|
187 |
+
cross_attention_dim (`int`, defaults to 1280):
|
188 |
+
The dimension of the cross attention features.
|
189 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
190 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
191 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
192 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
193 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
194 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
195 |
+
dimension to `cross_attention_dim`.
|
196 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
197 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
198 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
199 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
200 |
+
The dimension of the attention heads.
|
201 |
+
use_linear_projection (`bool`, defaults to `False`):
|
202 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
203 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
204 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
205 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
206 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
207 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
208 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
209 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
210 |
+
class conditioning with `class_embed_type` equal to `None`.
|
211 |
+
upcast_attention (`bool`, defaults to `False`):
|
212 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
213 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
214 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
215 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
216 |
+
`class_embed_type="projection"`.
|
217 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
218 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
219 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
220 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
221 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
222 |
+
TODO(Patrick) - unused parameter.
|
223 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
224 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
225 |
+
"""
|
226 |
+
|
227 |
+
_supports_gradient_checkpointing = True
|
228 |
+
|
229 |
+
@register_to_config
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
in_channels: int = 4,
|
233 |
+
conditioning_channels: int = 3,
|
234 |
+
flip_sin_to_cos: bool = True,
|
235 |
+
freq_shift: int = 0,
|
236 |
+
down_block_types: Tuple[str, ...] = (
|
237 |
+
"CrossAttnDownBlock2D",
|
238 |
+
"CrossAttnDownBlock2D",
|
239 |
+
"CrossAttnDownBlock2D",
|
240 |
+
"DownBlock2D",
|
241 |
+
),
|
242 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
243 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
244 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
245 |
+
layers_per_block: int = 2,
|
246 |
+
downsample_padding: int = 1,
|
247 |
+
mid_block_scale_factor: float = 1,
|
248 |
+
act_fn: str = "silu",
|
249 |
+
norm_num_groups: Optional[int] = 32,
|
250 |
+
norm_eps: float = 1e-5,
|
251 |
+
cross_attention_dim: int = 1280,
|
252 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
253 |
+
encoder_hid_dim: Optional[int] = None,
|
254 |
+
encoder_hid_dim_type: Optional[str] = None,
|
255 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
256 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
257 |
+
use_linear_projection: bool = False,
|
258 |
+
class_embed_type: Optional[str] = None,
|
259 |
+
addition_embed_type: Optional[str] = None,
|
260 |
+
addition_time_embed_dim: Optional[int] = None,
|
261 |
+
num_class_embeds: Optional[int] = None,
|
262 |
+
upcast_attention: bool = False,
|
263 |
+
resnet_time_scale_shift: str = "default",
|
264 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
265 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
266 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
267 |
+
global_pool_conditions: bool = False,
|
268 |
+
addition_embed_type_num_heads: int = 64,
|
269 |
+
pad_concat: bool = False,
|
270 |
+
):
|
271 |
+
super().__init__()
|
272 |
+
|
273 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
274 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
275 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
276 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
277 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
278 |
+
# which is why we correct for the naming here.
|
279 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
280 |
+
self.pad_concat = pad_concat
|
281 |
+
|
282 |
+
# Check inputs
|
283 |
+
if len(block_out_channels) != len(down_block_types):
|
284 |
+
raise ValueError(
|
285 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
286 |
+
)
|
287 |
+
|
288 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
289 |
+
raise ValueError(
|
290 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
291 |
+
)
|
292 |
+
|
293 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
294 |
+
raise ValueError(
|
295 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
296 |
+
)
|
297 |
+
|
298 |
+
if isinstance(transformer_layers_per_block, int):
|
299 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
300 |
+
|
301 |
+
# input
|
302 |
+
conv_in_kernel = 3
|
303 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
304 |
+
self.conv_in = nn.Conv2d(
|
305 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
306 |
+
)
|
307 |
+
|
308 |
+
# time
|
309 |
+
time_embed_dim = block_out_channels[0] * 4
|
310 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
311 |
+
timestep_input_dim = block_out_channels[0]
|
312 |
+
self.time_embedding = TimestepEmbedding(
|
313 |
+
timestep_input_dim,
|
314 |
+
time_embed_dim,
|
315 |
+
act_fn=act_fn,
|
316 |
+
)
|
317 |
+
|
318 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
319 |
+
encoder_hid_dim_type = "text_proj"
|
320 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
321 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
322 |
+
|
323 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
324 |
+
raise ValueError(
|
325 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
326 |
+
)
|
327 |
+
|
328 |
+
if encoder_hid_dim_type == "text_proj":
|
329 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
330 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
331 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
332 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
333 |
+
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
|
334 |
+
self.encoder_hid_proj = TextImageProjection(
|
335 |
+
text_embed_dim=encoder_hid_dim,
|
336 |
+
image_embed_dim=cross_attention_dim,
|
337 |
+
cross_attention_dim=cross_attention_dim,
|
338 |
+
)
|
339 |
+
|
340 |
+
elif encoder_hid_dim_type is not None:
|
341 |
+
raise ValueError(
|
342 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
self.encoder_hid_proj = None
|
346 |
+
|
347 |
+
# class embedding
|
348 |
+
if class_embed_type is None and num_class_embeds is not None:
|
349 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
350 |
+
elif class_embed_type == "timestep":
|
351 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
352 |
+
elif class_embed_type == "identity":
|
353 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
354 |
+
elif class_embed_type == "projection":
|
355 |
+
if projection_class_embeddings_input_dim is None:
|
356 |
+
raise ValueError(
|
357 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
358 |
+
)
|
359 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
360 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
361 |
+
# 2. it projects from an arbitrary input dimension.
|
362 |
+
#
|
363 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
364 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
365 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
366 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
367 |
+
else:
|
368 |
+
self.class_embedding = None
|
369 |
+
|
370 |
+
if addition_embed_type == "text":
|
371 |
+
if encoder_hid_dim is not None:
|
372 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
373 |
+
else:
|
374 |
+
text_time_embedding_from_dim = cross_attention_dim
|
375 |
+
|
376 |
+
self.add_embedding = TextTimeEmbedding(
|
377 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
378 |
+
)
|
379 |
+
elif addition_embed_type == "text_image":
|
380 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
381 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
382 |
+
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
383 |
+
self.add_embedding = TextImageTimeEmbedding(
|
384 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
385 |
+
)
|
386 |
+
elif addition_embed_type == "text_time":
|
387 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
388 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
389 |
+
|
390 |
+
elif addition_embed_type is not None:
|
391 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
392 |
+
|
393 |
+
# control net conditioning embedding
|
394 |
+
self.ref_conv_in = nn.Conv2d(
|
395 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
396 |
+
)
|
397 |
+
|
398 |
+
self.down_blocks = nn.ModuleList([])
|
399 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
400 |
+
|
401 |
+
if isinstance(only_cross_attention, bool):
|
402 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
403 |
+
|
404 |
+
if isinstance(attention_head_dim, int):
|
405 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
406 |
+
|
407 |
+
if isinstance(num_attention_heads, int):
|
408 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
409 |
+
|
410 |
+
# down
|
411 |
+
output_channel = block_out_channels[0]
|
412 |
+
|
413 |
+
# controlnet_block = ZeroConv(output_channel, output_channel)
|
414 |
+
controlnet_block = nn.Sequential(
|
415 |
+
SFT(output_channel, output_channel),
|
416 |
+
zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
|
417 |
+
)
|
418 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
419 |
+
|
420 |
+
for i, down_block_type in enumerate(down_block_types):
|
421 |
+
input_channel = output_channel
|
422 |
+
output_channel = block_out_channels[i]
|
423 |
+
is_final_block = i == len(block_out_channels) - 1
|
424 |
+
|
425 |
+
down_block = get_down_block(
|
426 |
+
down_block_type,
|
427 |
+
num_layers=layers_per_block,
|
428 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
429 |
+
in_channels=input_channel,
|
430 |
+
out_channels=output_channel,
|
431 |
+
temb_channels=time_embed_dim,
|
432 |
+
add_downsample=not is_final_block,
|
433 |
+
resnet_eps=norm_eps,
|
434 |
+
resnet_act_fn=act_fn,
|
435 |
+
resnet_groups=norm_num_groups,
|
436 |
+
cross_attention_dim=cross_attention_dim,
|
437 |
+
num_attention_heads=num_attention_heads[i],
|
438 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
439 |
+
downsample_padding=downsample_padding,
|
440 |
+
use_linear_projection=use_linear_projection,
|
441 |
+
only_cross_attention=only_cross_attention[i],
|
442 |
+
upcast_attention=upcast_attention,
|
443 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
444 |
+
)
|
445 |
+
self.down_blocks.append(down_block)
|
446 |
+
|
447 |
+
for _ in range(layers_per_block):
|
448 |
+
# controlnet_block = ZeroConv(output_channel, output_channel)
|
449 |
+
controlnet_block = nn.Sequential(
|
450 |
+
SFT(output_channel, output_channel),
|
451 |
+
zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
|
452 |
+
)
|
453 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
454 |
+
|
455 |
+
if not is_final_block:
|
456 |
+
# controlnet_block = ZeroConv(output_channel, output_channel)
|
457 |
+
controlnet_block = nn.Sequential(
|
458 |
+
SFT(output_channel, output_channel),
|
459 |
+
zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
|
460 |
+
)
|
461 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
462 |
+
|
463 |
+
# mid
|
464 |
+
mid_block_channel = block_out_channels[-1]
|
465 |
+
|
466 |
+
# controlnet_block = ZeroConv(mid_block_channel, mid_block_channel)
|
467 |
+
controlnet_block = nn.Sequential(
|
468 |
+
SFT(mid_block_channel, mid_block_channel),
|
469 |
+
zero_module(nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1))
|
470 |
+
)
|
471 |
+
self.controlnet_mid_block = controlnet_block
|
472 |
+
|
473 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
474 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
475 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
476 |
+
in_channels=mid_block_channel,
|
477 |
+
temb_channels=time_embed_dim,
|
478 |
+
resnet_eps=norm_eps,
|
479 |
+
resnet_act_fn=act_fn,
|
480 |
+
output_scale_factor=mid_block_scale_factor,
|
481 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
482 |
+
cross_attention_dim=cross_attention_dim,
|
483 |
+
num_attention_heads=num_attention_heads[-1],
|
484 |
+
resnet_groups=norm_num_groups,
|
485 |
+
use_linear_projection=use_linear_projection,
|
486 |
+
upcast_attention=upcast_attention,
|
487 |
+
)
|
488 |
+
elif mid_block_type == "UNetMidBlock2D":
|
489 |
+
self.mid_block = UNetMidBlock2D(
|
490 |
+
in_channels=block_out_channels[-1],
|
491 |
+
temb_channels=time_embed_dim,
|
492 |
+
num_layers=0,
|
493 |
+
resnet_eps=norm_eps,
|
494 |
+
resnet_act_fn=act_fn,
|
495 |
+
output_scale_factor=mid_block_scale_factor,
|
496 |
+
resnet_groups=norm_num_groups,
|
497 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
498 |
+
add_attention=False,
|
499 |
+
)
|
500 |
+
else:
|
501 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
502 |
+
|
503 |
+
@classmethod
|
504 |
+
def from_unet(
|
505 |
+
cls,
|
506 |
+
unet: UNet2DConditionModel,
|
507 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
508 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
509 |
+
load_weights_from_unet: bool = True,
|
510 |
+
conditioning_channels: int = 3,
|
511 |
+
):
|
512 |
+
r"""
|
513 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
514 |
+
|
515 |
+
Parameters:
|
516 |
+
unet (`UNet2DConditionModel`):
|
517 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
518 |
+
where applicable.
|
519 |
+
"""
|
520 |
+
transformer_layers_per_block = (
|
521 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
522 |
+
)
|
523 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
524 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
525 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
526 |
+
addition_time_embed_dim = (
|
527 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
528 |
+
)
|
529 |
+
|
530 |
+
controlnet = cls(
|
531 |
+
encoder_hid_dim=encoder_hid_dim,
|
532 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
533 |
+
addition_embed_type=addition_embed_type,
|
534 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
535 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
536 |
+
in_channels=unet.config.in_channels,
|
537 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
538 |
+
freq_shift=unet.config.freq_shift,
|
539 |
+
down_block_types=unet.config.down_block_types,
|
540 |
+
only_cross_attention=unet.config.only_cross_attention,
|
541 |
+
block_out_channels=unet.config.block_out_channels,
|
542 |
+
layers_per_block=unet.config.layers_per_block,
|
543 |
+
downsample_padding=unet.config.downsample_padding,
|
544 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
545 |
+
act_fn=unet.config.act_fn,
|
546 |
+
norm_num_groups=unet.config.norm_num_groups,
|
547 |
+
norm_eps=unet.config.norm_eps,
|
548 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
549 |
+
attention_head_dim=unet.config.attention_head_dim,
|
550 |
+
num_attention_heads=unet.config.num_attention_heads,
|
551 |
+
use_linear_projection=unet.config.use_linear_projection,
|
552 |
+
class_embed_type=unet.config.class_embed_type,
|
553 |
+
num_class_embeds=unet.config.num_class_embeds,
|
554 |
+
upcast_attention=unet.config.upcast_attention,
|
555 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
556 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
557 |
+
mid_block_type=unet.config.mid_block_type,
|
558 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
559 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
560 |
+
conditioning_channels=conditioning_channels,
|
561 |
+
)
|
562 |
+
|
563 |
+
if load_weights_from_unet:
|
564 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
565 |
+
controlnet.ref_conv_in.load_state_dict(unet.conv_in.state_dict())
|
566 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
567 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
568 |
+
|
569 |
+
if controlnet.class_embedding:
|
570 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
571 |
+
|
572 |
+
if hasattr(controlnet, "add_embedding"):
|
573 |
+
controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
|
574 |
+
|
575 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
576 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
577 |
+
|
578 |
+
return controlnet
|
579 |
+
|
580 |
+
@property
|
581 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
582 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
583 |
+
r"""
|
584 |
+
Returns:
|
585 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
586 |
+
indexed by its weight name.
|
587 |
+
"""
|
588 |
+
# set recursively
|
589 |
+
processors = {}
|
590 |
+
|
591 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
592 |
+
if hasattr(module, "get_processor"):
|
593 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
594 |
+
|
595 |
+
for sub_name, child in module.named_children():
|
596 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
597 |
+
|
598 |
+
return processors
|
599 |
+
|
600 |
+
for name, module in self.named_children():
|
601 |
+
fn_recursive_add_processors(name, module, processors)
|
602 |
+
|
603 |
+
return processors
|
604 |
+
|
605 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
606 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
607 |
+
r"""
|
608 |
+
Sets the attention processor to use to compute attention.
|
609 |
+
|
610 |
+
Parameters:
|
611 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
612 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
613 |
+
for **all** `Attention` layers.
|
614 |
+
|
615 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
616 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
617 |
+
|
618 |
+
"""
|
619 |
+
count = len(self.attn_processors.keys())
|
620 |
+
|
621 |
+
if isinstance(processor, dict) and len(processor) != count:
|
622 |
+
raise ValueError(
|
623 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
624 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
625 |
+
)
|
626 |
+
|
627 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
628 |
+
if hasattr(module, "set_processor"):
|
629 |
+
if not isinstance(processor, dict):
|
630 |
+
module.set_processor(processor)
|
631 |
+
else:
|
632 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
633 |
+
|
634 |
+
for sub_name, child in module.named_children():
|
635 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
636 |
+
|
637 |
+
for name, module in self.named_children():
|
638 |
+
fn_recursive_attn_processor(name, module, processor)
|
639 |
+
|
640 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
641 |
+
def set_default_attn_processor(self):
|
642 |
+
"""
|
643 |
+
Disables custom attention processors and sets the default attention implementation.
|
644 |
+
"""
|
645 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
646 |
+
processor = AttnAddedKVProcessor()
|
647 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
648 |
+
processor = AttnProcessor()
|
649 |
+
else:
|
650 |
+
raise ValueError(
|
651 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
652 |
+
)
|
653 |
+
|
654 |
+
self.set_attn_processor(processor)
|
655 |
+
|
656 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
657 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
658 |
+
r"""
|
659 |
+
Enable sliced attention computation.
|
660 |
+
|
661 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
662 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
663 |
+
|
664 |
+
Args:
|
665 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
666 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
667 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
668 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
669 |
+
must be a multiple of `slice_size`.
|
670 |
+
"""
|
671 |
+
sliceable_head_dims = []
|
672 |
+
|
673 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
674 |
+
if hasattr(module, "set_attention_slice"):
|
675 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
676 |
+
|
677 |
+
for child in module.children():
|
678 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
679 |
+
|
680 |
+
# retrieve number of attention layers
|
681 |
+
for module in self.children():
|
682 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
683 |
+
|
684 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
685 |
+
|
686 |
+
if slice_size == "auto":
|
687 |
+
# half the attention head size is usually a good trade-off between
|
688 |
+
# speed and memory
|
689 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
690 |
+
elif slice_size == "max":
|
691 |
+
# make smallest slice possible
|
692 |
+
slice_size = num_sliceable_layers * [1]
|
693 |
+
|
694 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
695 |
+
|
696 |
+
if len(slice_size) != len(sliceable_head_dims):
|
697 |
+
raise ValueError(
|
698 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
699 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
700 |
+
)
|
701 |
+
|
702 |
+
for i in range(len(slice_size)):
|
703 |
+
size = slice_size[i]
|
704 |
+
dim = sliceable_head_dims[i]
|
705 |
+
if size is not None and size > dim:
|
706 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
707 |
+
|
708 |
+
# Recursively walk through all the children.
|
709 |
+
# Any children which exposes the set_attention_slice method
|
710 |
+
# gets the message
|
711 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
712 |
+
if hasattr(module, "set_attention_slice"):
|
713 |
+
module.set_attention_slice(slice_size.pop())
|
714 |
+
|
715 |
+
for child in module.children():
|
716 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
717 |
+
|
718 |
+
reversed_slice_size = list(reversed(slice_size))
|
719 |
+
for module in self.children():
|
720 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
721 |
+
|
722 |
+
def process_encoder_hidden_states(
|
723 |
+
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
724 |
+
) -> torch.Tensor:
|
725 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
726 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
727 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
728 |
+
# Kandinsky 2.1 - style
|
729 |
+
if "image_embeds" not in added_cond_kwargs:
|
730 |
+
raise ValueError(
|
731 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
732 |
+
)
|
733 |
+
|
734 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
735 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
736 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
737 |
+
# Kandinsky 2.2 - style
|
738 |
+
if "image_embeds" not in added_cond_kwargs:
|
739 |
+
raise ValueError(
|
740 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
741 |
+
)
|
742 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
743 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
744 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
745 |
+
if "image_embeds" not in added_cond_kwargs:
|
746 |
+
raise ValueError(
|
747 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
748 |
+
)
|
749 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
750 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
751 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
752 |
+
return encoder_hidden_states
|
753 |
+
|
754 |
+
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
755 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
756 |
+
module.gradient_checkpointing = value
|
757 |
+
|
758 |
+
def forward(
|
759 |
+
self,
|
760 |
+
sample: torch.FloatTensor,
|
761 |
+
timestep: Union[torch.Tensor, float, int],
|
762 |
+
encoder_hidden_states: torch.Tensor,
|
763 |
+
controlnet_cond: torch.FloatTensor,
|
764 |
+
cat_dim: int = -2,
|
765 |
+
conditioning_scale: float = 1.0,
|
766 |
+
class_labels: Optional[torch.Tensor] = None,
|
767 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
768 |
+
attention_mask: Optional[torch.Tensor] = None,
|
769 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
770 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
771 |
+
return_dict: bool = True,
|
772 |
+
) -> Union[AggregatorOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
773 |
+
"""
|
774 |
+
The [`Aggregator`] forward method.
|
775 |
+
|
776 |
+
Args:
|
777 |
+
sample (`torch.FloatTensor`):
|
778 |
+
The noisy input tensor.
|
779 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
780 |
+
The number of timesteps to denoise an input.
|
781 |
+
encoder_hidden_states (`torch.Tensor`):
|
782 |
+
The encoder hidden states.
|
783 |
+
controlnet_cond (`torch.FloatTensor`):
|
784 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
785 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
786 |
+
The scale factor for ControlNet outputs.
|
787 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
788 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
789 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
790 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
791 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
792 |
+
embeddings.
|
793 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
794 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
795 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
796 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
797 |
+
added_cond_kwargs (`dict`):
|
798 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
799 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
800 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
801 |
+
return_dict (`bool`, defaults to `True`):
|
802 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
803 |
+
|
804 |
+
Returns:
|
805 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
806 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
807 |
+
returned where the first element is the sample tensor.
|
808 |
+
"""
|
809 |
+
# check channel order
|
810 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
811 |
+
|
812 |
+
if channel_order == "rgb":
|
813 |
+
# in rgb order by default
|
814 |
+
...
|
815 |
+
else:
|
816 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
817 |
+
|
818 |
+
# prepare attention_mask
|
819 |
+
if attention_mask is not None:
|
820 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
821 |
+
attention_mask = attention_mask.unsqueeze(1)
|
822 |
+
|
823 |
+
# 1. time
|
824 |
+
timesteps = timestep
|
825 |
+
if not torch.is_tensor(timesteps):
|
826 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
827 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
828 |
+
is_mps = sample.device.type == "mps"
|
829 |
+
if isinstance(timestep, float):
|
830 |
+
dtype = torch.float32 if is_mps else torch.float64
|
831 |
+
else:
|
832 |
+
dtype = torch.int32 if is_mps else torch.int64
|
833 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
834 |
+
elif len(timesteps.shape) == 0:
|
835 |
+
timesteps = timesteps[None].to(sample.device)
|
836 |
+
|
837 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
838 |
+
timesteps = timesteps.expand(sample.shape[0])
|
839 |
+
|
840 |
+
t_emb = self.time_proj(timesteps)
|
841 |
+
|
842 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
843 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
844 |
+
# there might be better ways to encapsulate this.
|
845 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
846 |
+
|
847 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
848 |
+
aug_emb = None
|
849 |
+
|
850 |
+
if self.class_embedding is not None:
|
851 |
+
if class_labels is None:
|
852 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
853 |
+
|
854 |
+
if self.config.class_embed_type == "timestep":
|
855 |
+
class_labels = self.time_proj(class_labels)
|
856 |
+
|
857 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
858 |
+
emb = emb + class_emb
|
859 |
+
|
860 |
+
if self.config.addition_embed_type is not None:
|
861 |
+
if self.config.addition_embed_type == "text":
|
862 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
863 |
+
|
864 |
+
elif self.config.addition_embed_type == "text_time":
|
865 |
+
if "text_embeds" not in added_cond_kwargs:
|
866 |
+
raise ValueError(
|
867 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
868 |
+
)
|
869 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
870 |
+
if "time_ids" not in added_cond_kwargs:
|
871 |
+
raise ValueError(
|
872 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
873 |
+
)
|
874 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
875 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
876 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
877 |
+
|
878 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
879 |
+
add_embeds = add_embeds.to(emb.dtype)
|
880 |
+
aug_emb = self.add_embedding(add_embeds)
|
881 |
+
|
882 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
883 |
+
|
884 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
885 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
886 |
+
)
|
887 |
+
|
888 |
+
# 2. prepare input
|
889 |
+
cond_latent = self.conv_in(sample)
|
890 |
+
ref_latent = self.ref_conv_in(controlnet_cond)
|
891 |
+
batch_size, channel, height, width = cond_latent.shape
|
892 |
+
if self.pad_concat:
|
893 |
+
if cat_dim == -2 or cat_dim == 2:
|
894 |
+
concat_pad = torch.zeros(batch_size, channel, 1, width)
|
895 |
+
elif cat_dim == -1 or cat_dim == 3:
|
896 |
+
concat_pad = torch.zeros(batch_size, channel, height, 1)
|
897 |
+
else:
|
898 |
+
raise ValueError(f"Aggregator shall concat along spatial dimension, but is asked to concat dim: {cat_dim}.")
|
899 |
+
concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
|
900 |
+
sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
|
901 |
+
else:
|
902 |
+
sample = torch.cat([cond_latent, ref_latent], dim=cat_dim)
|
903 |
+
|
904 |
+
# 3. down
|
905 |
+
down_block_res_samples = (sample,)
|
906 |
+
for downsample_block in self.down_blocks:
|
907 |
+
sample, res_samples = downsample_block(
|
908 |
+
hidden_states=sample,
|
909 |
+
temb=emb,
|
910 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
911 |
+
)
|
912 |
+
|
913 |
+
# rebuild sample: split and concat
|
914 |
+
if self.pad_concat:
|
915 |
+
batch_size, channel, height, width = sample.shape
|
916 |
+
if cat_dim == -2 or cat_dim == 2:
|
917 |
+
cond_latent = sample[:, :, :height//2, :]
|
918 |
+
ref_latent = sample[:, :, -(height//2):, :]
|
919 |
+
concat_pad = torch.zeros(batch_size, channel, 1, width)
|
920 |
+
elif cat_dim == -1 or cat_dim == 3:
|
921 |
+
cond_latent = sample[:, :, :, :width//2]
|
922 |
+
ref_latent = sample[:, :, :, -(width//2):]
|
923 |
+
concat_pad = torch.zeros(batch_size, channel, height, 1)
|
924 |
+
concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
|
925 |
+
sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
|
926 |
+
res_samples = res_samples[:-1] + (sample,)
|
927 |
+
|
928 |
+
down_block_res_samples += res_samples
|
929 |
+
|
930 |
+
# 4. mid
|
931 |
+
if self.mid_block is not None:
|
932 |
+
sample = self.mid_block(
|
933 |
+
sample,
|
934 |
+
emb,
|
935 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
936 |
+
)
|
937 |
+
|
938 |
+
# 5. split samples and SFT.
|
939 |
+
controlnet_down_block_res_samples = ()
|
940 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
941 |
+
batch_size, channel, height, width = down_block_res_sample.shape
|
942 |
+
if cat_dim == -2 or cat_dim == 2:
|
943 |
+
cond_latent = down_block_res_sample[:, :, :height//2, :]
|
944 |
+
ref_latent = down_block_res_sample[:, :, -(height//2):, :]
|
945 |
+
elif cat_dim == -1 or cat_dim == 3:
|
946 |
+
cond_latent = down_block_res_sample[:, :, :, :width//2]
|
947 |
+
ref_latent = down_block_res_sample[:, :, :, -(width//2):]
|
948 |
+
down_block_res_sample = controlnet_block((cond_latent, ref_latent), )
|
949 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
950 |
+
|
951 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
952 |
+
|
953 |
+
batch_size, channel, height, width = sample.shape
|
954 |
+
if cat_dim == -2 or cat_dim == 2:
|
955 |
+
cond_latent = sample[:, :, :height//2, :]
|
956 |
+
ref_latent = sample[:, :, -(height//2):, :]
|
957 |
+
elif cat_dim == -1 or cat_dim == 3:
|
958 |
+
cond_latent = sample[:, :, :, :width//2]
|
959 |
+
ref_latent = sample[:, :, :, -(width//2):]
|
960 |
+
mid_block_res_sample = self.controlnet_mid_block((cond_latent, ref_latent), )
|
961 |
+
|
962 |
+
# 6. scaling
|
963 |
+
down_block_res_samples = [sample*conditioning_scale for sample in down_block_res_samples]
|
964 |
+
mid_block_res_sample = mid_block_res_sample*conditioning_scale
|
965 |
+
|
966 |
+
if self.config.global_pool_conditions:
|
967 |
+
down_block_res_samples = [
|
968 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
969 |
+
]
|
970 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
971 |
+
|
972 |
+
if not return_dict:
|
973 |
+
return (down_block_res_samples, mid_block_res_sample)
|
974 |
+
|
975 |
+
return AggregatorOutput(
|
976 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
977 |
+
)
|
978 |
+
|
979 |
+
|
980 |
+
def zero_module(module):
|
981 |
+
for p in module.parameters():
|
982 |
+
nn.init.zeros_(p)
|
983 |
+
return module
|
module/attention.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from diffusers.models.attention.py
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from typing import Any, Dict, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
from diffusers.utils import deprecate, logging
|
23 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
24 |
+
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
|
25 |
+
from diffusers.models.attention_processor import Attention
|
26 |
+
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
27 |
+
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
28 |
+
|
29 |
+
from module.min_sdxl import LoRACompatibleLinear, LoRALinearLayer
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
def create_custom_forward(module):
|
35 |
+
def custom_forward(*inputs):
|
36 |
+
return module(*inputs)
|
37 |
+
|
38 |
+
return custom_forward
|
39 |
+
|
40 |
+
def maybe_grad_checkpoint(resnet, attn, hidden_states, temb, encoder_hidden_states, adapter_hidden_states, do_ckpt=True):
|
41 |
+
|
42 |
+
if do_ckpt:
|
43 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
44 |
+
hidden_states, extracted_kv = torch.utils.checkpoint.checkpoint(
|
45 |
+
create_custom_forward(attn), hidden_states, encoder_hidden_states, adapter_hidden_states, use_reentrant=False
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
hidden_states = resnet(hidden_states, temb)
|
49 |
+
hidden_states, extracted_kv = attn(
|
50 |
+
hidden_states,
|
51 |
+
encoder_hidden_states=encoder_hidden_states,
|
52 |
+
adapter_hidden_states=adapter_hidden_states,
|
53 |
+
)
|
54 |
+
return hidden_states, extracted_kv
|
55 |
+
|
56 |
+
|
57 |
+
def init_lora_in_attn(attn_module, rank: int = 4, is_kvcopy=False):
|
58 |
+
# Set the `lora_layer` attribute of the attention-related matrices.
|
59 |
+
|
60 |
+
attn_module.to_k.set_lora_layer(
|
61 |
+
LoRALinearLayer(
|
62 |
+
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank
|
63 |
+
)
|
64 |
+
)
|
65 |
+
attn_module.to_v.set_lora_layer(
|
66 |
+
LoRALinearLayer(
|
67 |
+
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank
|
68 |
+
)
|
69 |
+
)
|
70 |
+
|
71 |
+
if not is_kvcopy:
|
72 |
+
attn_module.to_q.set_lora_layer(
|
73 |
+
LoRALinearLayer(
|
74 |
+
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank
|
75 |
+
)
|
76 |
+
)
|
77 |
+
|
78 |
+
attn_module.to_out[0].set_lora_layer(
|
79 |
+
LoRALinearLayer(
|
80 |
+
in_features=attn_module.to_out[0].in_features,
|
81 |
+
out_features=attn_module.to_out[0].out_features,
|
82 |
+
rank=rank,
|
83 |
+
)
|
84 |
+
)
|
85 |
+
|
86 |
+
def drop_kvs(encoder_kvs, drop_chance):
|
87 |
+
for layer in encoder_kvs:
|
88 |
+
len_tokens = encoder_kvs[layer].self_attention.k.shape[1]
|
89 |
+
idx_to_keep = (torch.rand(len_tokens) > drop_chance)
|
90 |
+
|
91 |
+
encoder_kvs[layer].self_attention.k = encoder_kvs[layer].self_attention.k[:, idx_to_keep]
|
92 |
+
encoder_kvs[layer].self_attention.v = encoder_kvs[layer].self_attention.v[:, idx_to_keep]
|
93 |
+
|
94 |
+
return encoder_kvs
|
95 |
+
|
96 |
+
def clone_kvs(encoder_kvs):
|
97 |
+
cloned_kvs = {}
|
98 |
+
for layer in encoder_kvs:
|
99 |
+
sa_cpy = KVCache(k=encoder_kvs[layer].self_attention.k.clone(),
|
100 |
+
v=encoder_kvs[layer].self_attention.v.clone())
|
101 |
+
|
102 |
+
ca_cpy = KVCache(k=encoder_kvs[layer].cross_attention.k.clone(),
|
103 |
+
v=encoder_kvs[layer].cross_attention.v.clone())
|
104 |
+
|
105 |
+
cloned_layer_cache = AttentionCache(self_attention=sa_cpy, cross_attention=ca_cpy)
|
106 |
+
|
107 |
+
cloned_kvs[layer] = cloned_layer_cache
|
108 |
+
|
109 |
+
return cloned_kvs
|
110 |
+
|
111 |
+
|
112 |
+
class KVCache(object):
|
113 |
+
def __init__(self, k, v):
|
114 |
+
self.k = k
|
115 |
+
self.v = v
|
116 |
+
|
117 |
+
class AttentionCache(object):
|
118 |
+
def __init__(self, self_attention: KVCache, cross_attention: KVCache):
|
119 |
+
self.self_attention = self_attention
|
120 |
+
self.cross_attention = cross_attention
|
121 |
+
|
122 |
+
class KVCopy(nn.Module):
|
123 |
+
def __init__(
|
124 |
+
self, inner_dim, cross_attention_dim=None,
|
125 |
+
):
|
126 |
+
super(KVCopy, self).__init__()
|
127 |
+
|
128 |
+
in_dim = cross_attention_dim or inner_dim
|
129 |
+
|
130 |
+
self.to_k = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
|
131 |
+
self.to_v = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
|
132 |
+
|
133 |
+
def forward(self, hidden_states):
|
134 |
+
|
135 |
+
k = self.to_k(hidden_states)
|
136 |
+
v = self.to_v(hidden_states)
|
137 |
+
|
138 |
+
return KVCache(k=k, v=v)
|
139 |
+
|
140 |
+
def init_kv_copy(self, source_attn):
|
141 |
+
with torch.no_grad():
|
142 |
+
self.to_k.weight.copy_(source_attn.to_k.weight)
|
143 |
+
self.to_v.weight.copy_(source_attn.to_v.weight)
|
144 |
+
|
145 |
+
|
146 |
+
class FeedForward(nn.Module):
|
147 |
+
r"""
|
148 |
+
A feed-forward layer.
|
149 |
+
|
150 |
+
Parameters:
|
151 |
+
dim (`int`): The number of channels in the input.
|
152 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
153 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
154 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
155 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
156 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
157 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
158 |
+
"""
|
159 |
+
|
160 |
+
def __init__(
|
161 |
+
self,
|
162 |
+
dim: int,
|
163 |
+
dim_out: Optional[int] = None,
|
164 |
+
mult: int = 4,
|
165 |
+
dropout: float = 0.0,
|
166 |
+
activation_fn: str = "geglu",
|
167 |
+
final_dropout: bool = False,
|
168 |
+
inner_dim=None,
|
169 |
+
bias: bool = True,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
if inner_dim is None:
|
173 |
+
inner_dim = int(dim * mult)
|
174 |
+
dim_out = dim_out if dim_out is not None else dim
|
175 |
+
|
176 |
+
if activation_fn == "gelu":
|
177 |
+
act_fn = GELU(dim, inner_dim, bias=bias)
|
178 |
+
if activation_fn == "gelu-approximate":
|
179 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
180 |
+
elif activation_fn == "geglu":
|
181 |
+
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
182 |
+
elif activation_fn == "geglu-approximate":
|
183 |
+
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
184 |
+
|
185 |
+
self.net = nn.ModuleList([])
|
186 |
+
# project in
|
187 |
+
self.net.append(act_fn)
|
188 |
+
# project dropout
|
189 |
+
self.net.append(nn.Dropout(dropout))
|
190 |
+
# project out
|
191 |
+
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
|
192 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
193 |
+
if final_dropout:
|
194 |
+
self.net.append(nn.Dropout(dropout))
|
195 |
+
|
196 |
+
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
197 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
198 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
199 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
200 |
+
for module in self.net:
|
201 |
+
hidden_states = module(hidden_states)
|
202 |
+
return hidden_states
|
203 |
+
|
204 |
+
|
205 |
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
206 |
+
# "feed_forward_chunk_size" can be used to save memory
|
207 |
+
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
208 |
+
raise ValueError(
|
209 |
+
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
210 |
+
)
|
211 |
+
|
212 |
+
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
213 |
+
ff_output = torch.cat(
|
214 |
+
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
215 |
+
dim=chunk_dim,
|
216 |
+
)
|
217 |
+
return ff_output
|
218 |
+
|
219 |
+
|
220 |
+
@maybe_allow_in_graph
|
221 |
+
class GatedSelfAttentionDense(nn.Module):
|
222 |
+
r"""
|
223 |
+
A gated self-attention dense layer that combines visual features and object features.
|
224 |
+
|
225 |
+
Parameters:
|
226 |
+
query_dim (`int`): The number of channels in the query.
|
227 |
+
context_dim (`int`): The number of channels in the context.
|
228 |
+
n_heads (`int`): The number of heads to use for attention.
|
229 |
+
d_head (`int`): The number of channels in each head.
|
230 |
+
"""
|
231 |
+
|
232 |
+
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
233 |
+
super().__init__()
|
234 |
+
|
235 |
+
# we need a linear projection since we need cat visual feature and obj feature
|
236 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
237 |
+
|
238 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
239 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
240 |
+
|
241 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
242 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
243 |
+
|
244 |
+
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
245 |
+
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
246 |
+
|
247 |
+
self.enabled = True
|
248 |
+
|
249 |
+
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
250 |
+
if not self.enabled:
|
251 |
+
return x
|
252 |
+
|
253 |
+
n_visual = x.shape[1]
|
254 |
+
objs = self.linear(objs)
|
255 |
+
|
256 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
257 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
258 |
+
|
259 |
+
return x
|
module/diffusers_vae/autoencoder_kl.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from diffusers.loaders import FromOriginalVAEMixin
|
21 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
22 |
+
from diffusers.models.attention_processor import (
|
23 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
24 |
+
CROSS_ATTENTION_PROCESSORS,
|
25 |
+
Attention,
|
26 |
+
AttentionProcessor,
|
27 |
+
AttnAddedKVProcessor,
|
28 |
+
AttnProcessor,
|
29 |
+
)
|
30 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
31 |
+
from diffusers.models.modeling_utils import ModelMixin
|
32 |
+
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
33 |
+
|
34 |
+
|
35 |
+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
36 |
+
r"""
|
37 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
38 |
+
|
39 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
40 |
+
for all models (such as downloading or saving).
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
44 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
45 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
46 |
+
Tuple of downsample block types.
|
47 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
48 |
+
Tuple of upsample block types.
|
49 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
50 |
+
Tuple of block output channels.
|
51 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
52 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
53 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
54 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
55 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
56 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
57 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
58 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
59 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
60 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
61 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
62 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
63 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
64 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
65 |
+
"""
|
66 |
+
|
67 |
+
_supports_gradient_checkpointing = True
|
68 |
+
|
69 |
+
@register_to_config
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
in_channels: int = 3,
|
73 |
+
out_channels: int = 3,
|
74 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
75 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
76 |
+
block_out_channels: Tuple[int] = (64,),
|
77 |
+
layers_per_block: int = 1,
|
78 |
+
act_fn: str = "silu",
|
79 |
+
latent_channels: int = 4,
|
80 |
+
norm_num_groups: int = 32,
|
81 |
+
sample_size: int = 32,
|
82 |
+
scaling_factor: float = 0.18215,
|
83 |
+
force_upcast: float = True,
|
84 |
+
):
|
85 |
+
super().__init__()
|
86 |
+
|
87 |
+
# pass init params to Encoder
|
88 |
+
self.encoder = Encoder(
|
89 |
+
in_channels=in_channels,
|
90 |
+
out_channels=latent_channels,
|
91 |
+
down_block_types=down_block_types,
|
92 |
+
block_out_channels=block_out_channels,
|
93 |
+
layers_per_block=layers_per_block,
|
94 |
+
act_fn=act_fn,
|
95 |
+
norm_num_groups=norm_num_groups,
|
96 |
+
double_z=True,
|
97 |
+
)
|
98 |
+
|
99 |
+
# pass init params to Decoder
|
100 |
+
self.decoder = Decoder(
|
101 |
+
in_channels=latent_channels,
|
102 |
+
out_channels=out_channels,
|
103 |
+
up_block_types=up_block_types,
|
104 |
+
block_out_channels=block_out_channels,
|
105 |
+
layers_per_block=layers_per_block,
|
106 |
+
norm_num_groups=norm_num_groups,
|
107 |
+
act_fn=act_fn,
|
108 |
+
)
|
109 |
+
|
110 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
111 |
+
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
112 |
+
|
113 |
+
self.use_slicing = False
|
114 |
+
self.use_tiling = False
|
115 |
+
|
116 |
+
# only relevant if vae tiling is enabled
|
117 |
+
self.tile_sample_min_size = self.config.sample_size
|
118 |
+
sample_size = (
|
119 |
+
self.config.sample_size[0]
|
120 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
121 |
+
else self.config.sample_size
|
122 |
+
)
|
123 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
124 |
+
self.tile_overlap_factor = 0.25
|
125 |
+
|
126 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
127 |
+
if isinstance(module, (Encoder, Decoder)):
|
128 |
+
module.gradient_checkpointing = value
|
129 |
+
|
130 |
+
def enable_tiling(self, use_tiling: bool = True):
|
131 |
+
r"""
|
132 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
133 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
134 |
+
processing larger images.
|
135 |
+
"""
|
136 |
+
self.use_tiling = use_tiling
|
137 |
+
|
138 |
+
def disable_tiling(self):
|
139 |
+
r"""
|
140 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
141 |
+
decoding in one step.
|
142 |
+
"""
|
143 |
+
self.enable_tiling(False)
|
144 |
+
|
145 |
+
def enable_slicing(self):
|
146 |
+
r"""
|
147 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
148 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
149 |
+
"""
|
150 |
+
self.use_slicing = True
|
151 |
+
|
152 |
+
def disable_slicing(self):
|
153 |
+
r"""
|
154 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
155 |
+
decoding in one step.
|
156 |
+
"""
|
157 |
+
self.use_slicing = False
|
158 |
+
|
159 |
+
@property
|
160 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
161 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
162 |
+
r"""
|
163 |
+
Returns:
|
164 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
165 |
+
indexed by its weight name.
|
166 |
+
"""
|
167 |
+
# set recursively
|
168 |
+
processors = {}
|
169 |
+
|
170 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
171 |
+
if hasattr(module, "get_processor"):
|
172 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
173 |
+
|
174 |
+
for sub_name, child in module.named_children():
|
175 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
176 |
+
|
177 |
+
return processors
|
178 |
+
|
179 |
+
for name, module in self.named_children():
|
180 |
+
fn_recursive_add_processors(name, module, processors)
|
181 |
+
|
182 |
+
return processors
|
183 |
+
|
184 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
185 |
+
def set_attn_processor(
|
186 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
187 |
+
):
|
188 |
+
r"""
|
189 |
+
Sets the attention processor to use to compute attention.
|
190 |
+
|
191 |
+
Parameters:
|
192 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
193 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
194 |
+
for **all** `Attention` layers.
|
195 |
+
|
196 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
197 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
198 |
+
|
199 |
+
"""
|
200 |
+
count = len(self.attn_processors.keys())
|
201 |
+
|
202 |
+
if isinstance(processor, dict) and len(processor) != count:
|
203 |
+
raise ValueError(
|
204 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
205 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
206 |
+
)
|
207 |
+
|
208 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
209 |
+
if hasattr(module, "set_processor"):
|
210 |
+
if not isinstance(processor, dict):
|
211 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
212 |
+
else:
|
213 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
214 |
+
|
215 |
+
for sub_name, child in module.named_children():
|
216 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
217 |
+
|
218 |
+
for name, module in self.named_children():
|
219 |
+
fn_recursive_attn_processor(name, module, processor)
|
220 |
+
|
221 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
222 |
+
def set_default_attn_processor(self):
|
223 |
+
"""
|
224 |
+
Disables custom attention processors and sets the default attention implementation.
|
225 |
+
"""
|
226 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
227 |
+
processor = AttnAddedKVProcessor()
|
228 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
229 |
+
processor = AttnProcessor()
|
230 |
+
else:
|
231 |
+
raise ValueError(
|
232 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
233 |
+
)
|
234 |
+
|
235 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
236 |
+
|
237 |
+
@apply_forward_hook
|
238 |
+
def encode(
|
239 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
240 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
241 |
+
"""
|
242 |
+
Encode a batch of images into latents.
|
243 |
+
|
244 |
+
Args:
|
245 |
+
x (`torch.FloatTensor`): Input batch of images.
|
246 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
247 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
251 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
252 |
+
"""
|
253 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
254 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
255 |
+
|
256 |
+
if self.use_slicing and x.shape[0] > 1:
|
257 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
258 |
+
h = torch.cat(encoded_slices)
|
259 |
+
else:
|
260 |
+
h = self.encoder(x)
|
261 |
+
|
262 |
+
moments = self.quant_conv(h)
|
263 |
+
posterior = DiagonalGaussianDistribution(moments)
|
264 |
+
|
265 |
+
if not return_dict:
|
266 |
+
return (posterior,)
|
267 |
+
|
268 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
269 |
+
|
270 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
271 |
+
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
272 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
273 |
+
|
274 |
+
z = self.post_quant_conv(z)
|
275 |
+
dec = self.decoder(z)
|
276 |
+
|
277 |
+
if not return_dict:
|
278 |
+
return (dec,)
|
279 |
+
|
280 |
+
return DecoderOutput(sample=dec)
|
281 |
+
|
282 |
+
@apply_forward_hook
|
283 |
+
def decode(
|
284 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
285 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
286 |
+
"""
|
287 |
+
Decode a batch of images.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
291 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
292 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
293 |
+
|
294 |
+
Returns:
|
295 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
296 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
297 |
+
returned.
|
298 |
+
|
299 |
+
"""
|
300 |
+
if self.use_slicing and z.shape[0] > 1:
|
301 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
302 |
+
decoded = torch.cat(decoded_slices)
|
303 |
+
else:
|
304 |
+
decoded = self._decode(z).sample
|
305 |
+
|
306 |
+
if not return_dict:
|
307 |
+
return (decoded,)
|
308 |
+
|
309 |
+
return DecoderOutput(sample=decoded)
|
310 |
+
|
311 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
312 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
313 |
+
for y in range(blend_extent):
|
314 |
+
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
315 |
+
return b
|
316 |
+
|
317 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
318 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
319 |
+
for x in range(blend_extent):
|
320 |
+
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
321 |
+
return b
|
322 |
+
|
323 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
324 |
+
r"""Encode a batch of images using a tiled encoder.
|
325 |
+
|
326 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
327 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
328 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
329 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
330 |
+
output, but they should be much less noticeable.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
x (`torch.FloatTensor`): Input batch of images.
|
334 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
335 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
339 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
340 |
+
`tuple` is returned.
|
341 |
+
"""
|
342 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
343 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
344 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
345 |
+
|
346 |
+
# Split the image into 512x512 tiles and encode them separately.
|
347 |
+
rows = []
|
348 |
+
for i in range(0, x.shape[2], overlap_size):
|
349 |
+
row = []
|
350 |
+
for j in range(0, x.shape[3], overlap_size):
|
351 |
+
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
352 |
+
tile = self.encoder(tile)
|
353 |
+
tile = self.quant_conv(tile)
|
354 |
+
row.append(tile)
|
355 |
+
rows.append(row)
|
356 |
+
result_rows = []
|
357 |
+
for i, row in enumerate(rows):
|
358 |
+
result_row = []
|
359 |
+
for j, tile in enumerate(row):
|
360 |
+
# blend the above tile and the left tile
|
361 |
+
# to the current tile and add the current tile to the result row
|
362 |
+
if i > 0:
|
363 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
364 |
+
if j > 0:
|
365 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
366 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
367 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
368 |
+
|
369 |
+
moments = torch.cat(result_rows, dim=2)
|
370 |
+
posterior = DiagonalGaussianDistribution(moments)
|
371 |
+
|
372 |
+
if not return_dict:
|
373 |
+
return (posterior,)
|
374 |
+
|
375 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
376 |
+
|
377 |
+
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
378 |
+
r"""
|
379 |
+
Decode a batch of images using a tiled decoder.
|
380 |
+
|
381 |
+
Args:
|
382 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
383 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
384 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
388 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
389 |
+
returned.
|
390 |
+
"""
|
391 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
392 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
393 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
394 |
+
|
395 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
396 |
+
# The tiles have an overlap to avoid seams between tiles.
|
397 |
+
rows = []
|
398 |
+
for i in range(0, z.shape[2], overlap_size):
|
399 |
+
row = []
|
400 |
+
for j in range(0, z.shape[3], overlap_size):
|
401 |
+
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
402 |
+
tile = self.post_quant_conv(tile)
|
403 |
+
decoded = self.decoder(tile)
|
404 |
+
row.append(decoded)
|
405 |
+
rows.append(row)
|
406 |
+
result_rows = []
|
407 |
+
for i, row in enumerate(rows):
|
408 |
+
result_row = []
|
409 |
+
for j, tile in enumerate(row):
|
410 |
+
# blend the above tile and the left tile
|
411 |
+
# to the current tile and add the current tile to the result row
|
412 |
+
if i > 0:
|
413 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
414 |
+
if j > 0:
|
415 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
416 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
417 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
418 |
+
|
419 |
+
dec = torch.cat(result_rows, dim=2)
|
420 |
+
if not return_dict:
|
421 |
+
return (dec,)
|
422 |
+
|
423 |
+
return DecoderOutput(sample=dec)
|
424 |
+
|
425 |
+
def forward(
|
426 |
+
self,
|
427 |
+
sample: torch.FloatTensor,
|
428 |
+
sample_posterior: bool = False,
|
429 |
+
return_dict: bool = True,
|
430 |
+
generator: Optional[torch.Generator] = None,
|
431 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
432 |
+
r"""
|
433 |
+
Args:
|
434 |
+
sample (`torch.FloatTensor`): Input sample.
|
435 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
436 |
+
Whether to sample from the posterior.
|
437 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
438 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
439 |
+
"""
|
440 |
+
x = sample
|
441 |
+
posterior = self.encode(x).latent_dist
|
442 |
+
if sample_posterior:
|
443 |
+
z = posterior.sample(generator=generator)
|
444 |
+
else:
|
445 |
+
z = posterior.mode()
|
446 |
+
dec = self.decode(z).sample
|
447 |
+
|
448 |
+
if not return_dict:
|
449 |
+
return (dec,)
|
450 |
+
|
451 |
+
return DecoderOutput(sample=dec)
|
452 |
+
|
453 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
454 |
+
def fuse_qkv_projections(self):
|
455 |
+
"""
|
456 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
457 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
458 |
+
|
459 |
+
<Tip warning={true}>
|
460 |
+
|
461 |
+
This API is 🧪 experimental.
|
462 |
+
|
463 |
+
</Tip>
|
464 |
+
"""
|
465 |
+
self.original_attn_processors = None
|
466 |
+
|
467 |
+
for _, attn_processor in self.attn_processors.items():
|
468 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
469 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
470 |
+
|
471 |
+
self.original_attn_processors = self.attn_processors
|
472 |
+
|
473 |
+
for module in self.modules():
|
474 |
+
if isinstance(module, Attention):
|
475 |
+
module.fuse_projections(fuse=True)
|
476 |
+
|
477 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
478 |
+
def unfuse_qkv_projections(self):
|
479 |
+
"""Disables the fused QKV projection if enabled.
|
480 |
+
|
481 |
+
<Tip warning={true}>
|
482 |
+
|
483 |
+
This API is 🧪 experimental.
|
484 |
+
|
485 |
+
</Tip>
|
486 |
+
|
487 |
+
"""
|
488 |
+
if self.original_attn_processors is not None:
|
489 |
+
self.set_attn_processor(self.original_attn_processors)
|
module/diffusers_vae/vae.py
ADDED
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
22 |
+
from diffusers.utils.torch_utils import randn_tensor
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import SpatialNorm
|
25 |
+
from diffusers.models.unet_2d_blocks import (
|
26 |
+
AutoencoderTinyBlock,
|
27 |
+
UNetMidBlock2D,
|
28 |
+
get_down_block,
|
29 |
+
get_up_block,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class DecoderOutput(BaseOutput):
|
35 |
+
r"""
|
36 |
+
Output of decoding method.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
40 |
+
The decoded output sample from the last layer of the model.
|
41 |
+
"""
|
42 |
+
|
43 |
+
sample: torch.FloatTensor
|
44 |
+
|
45 |
+
|
46 |
+
class Encoder(nn.Module):
|
47 |
+
r"""
|
48 |
+
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
in_channels (`int`, *optional*, defaults to 3):
|
52 |
+
The number of input channels.
|
53 |
+
out_channels (`int`, *optional*, defaults to 3):
|
54 |
+
The number of output channels.
|
55 |
+
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
56 |
+
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
57 |
+
options.
|
58 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
59 |
+
The number of output channels for each block.
|
60 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
61 |
+
The number of layers per block.
|
62 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
63 |
+
The number of groups for normalization.
|
64 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
65 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
66 |
+
double_z (`bool`, *optional*, defaults to `True`):
|
67 |
+
Whether to double the number of output channels for the last block.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
in_channels: int = 3,
|
73 |
+
out_channels: int = 3,
|
74 |
+
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
75 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
76 |
+
layers_per_block: int = 2,
|
77 |
+
norm_num_groups: int = 32,
|
78 |
+
act_fn: str = "silu",
|
79 |
+
double_z: bool = True,
|
80 |
+
mid_block_add_attention=True,
|
81 |
+
):
|
82 |
+
super().__init__()
|
83 |
+
self.layers_per_block = layers_per_block
|
84 |
+
|
85 |
+
self.conv_in = nn.Conv2d(
|
86 |
+
in_channels,
|
87 |
+
block_out_channels[0],
|
88 |
+
kernel_size=3,
|
89 |
+
stride=1,
|
90 |
+
padding=1,
|
91 |
+
)
|
92 |
+
|
93 |
+
self.mid_block = None
|
94 |
+
self.down_blocks = nn.ModuleList([])
|
95 |
+
|
96 |
+
# down
|
97 |
+
output_channel = block_out_channels[0]
|
98 |
+
for i, down_block_type in enumerate(down_block_types):
|
99 |
+
input_channel = output_channel
|
100 |
+
output_channel = block_out_channels[i]
|
101 |
+
is_final_block = i == len(block_out_channels) - 1
|
102 |
+
|
103 |
+
down_block = get_down_block(
|
104 |
+
down_block_type,
|
105 |
+
num_layers=self.layers_per_block,
|
106 |
+
in_channels=input_channel,
|
107 |
+
out_channels=output_channel,
|
108 |
+
add_downsample=not is_final_block,
|
109 |
+
resnet_eps=1e-6,
|
110 |
+
downsample_padding=0,
|
111 |
+
resnet_act_fn=act_fn,
|
112 |
+
resnet_groups=norm_num_groups,
|
113 |
+
attention_head_dim=output_channel,
|
114 |
+
temb_channels=None,
|
115 |
+
)
|
116 |
+
self.down_blocks.append(down_block)
|
117 |
+
|
118 |
+
# mid
|
119 |
+
self.mid_block = UNetMidBlock2D(
|
120 |
+
in_channels=block_out_channels[-1],
|
121 |
+
resnet_eps=1e-6,
|
122 |
+
resnet_act_fn=act_fn,
|
123 |
+
output_scale_factor=1,
|
124 |
+
resnet_time_scale_shift="default",
|
125 |
+
attention_head_dim=block_out_channels[-1],
|
126 |
+
resnet_groups=norm_num_groups,
|
127 |
+
temb_channels=None,
|
128 |
+
add_attention=mid_block_add_attention,
|
129 |
+
)
|
130 |
+
|
131 |
+
# out
|
132 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
133 |
+
self.conv_act = nn.SiLU()
|
134 |
+
|
135 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
136 |
+
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
137 |
+
|
138 |
+
self.gradient_checkpointing = False
|
139 |
+
|
140 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
141 |
+
r"""The forward method of the `Encoder` class."""
|
142 |
+
|
143 |
+
sample = self.conv_in(sample)
|
144 |
+
|
145 |
+
if self.training and self.gradient_checkpointing:
|
146 |
+
|
147 |
+
def create_custom_forward(module):
|
148 |
+
def custom_forward(*inputs):
|
149 |
+
return module(*inputs)
|
150 |
+
|
151 |
+
return custom_forward
|
152 |
+
|
153 |
+
# down
|
154 |
+
if is_torch_version(">=", "1.11.0"):
|
155 |
+
for down_block in self.down_blocks:
|
156 |
+
sample = torch.utils.checkpoint.checkpoint(
|
157 |
+
create_custom_forward(down_block), sample, use_reentrant=False
|
158 |
+
)
|
159 |
+
# middle
|
160 |
+
sample = torch.utils.checkpoint.checkpoint(
|
161 |
+
create_custom_forward(self.mid_block), sample, use_reentrant=False
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
for down_block in self.down_blocks:
|
165 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
166 |
+
# middle
|
167 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
168 |
+
|
169 |
+
else:
|
170 |
+
# down
|
171 |
+
for down_block in self.down_blocks:
|
172 |
+
sample = down_block(sample)
|
173 |
+
|
174 |
+
# middle
|
175 |
+
sample = self.mid_block(sample)
|
176 |
+
|
177 |
+
# post-process
|
178 |
+
sample = self.conv_norm_out(sample)
|
179 |
+
sample = self.conv_act(sample)
|
180 |
+
sample = self.conv_out(sample)
|
181 |
+
|
182 |
+
return sample
|
183 |
+
|
184 |
+
|
185 |
+
class Decoder(nn.Module):
|
186 |
+
r"""
|
187 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
in_channels (`int`, *optional*, defaults to 3):
|
191 |
+
The number of input channels.
|
192 |
+
out_channels (`int`, *optional*, defaults to 3):
|
193 |
+
The number of output channels.
|
194 |
+
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
195 |
+
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
196 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
197 |
+
The number of output channels for each block.
|
198 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
199 |
+
The number of layers per block.
|
200 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
201 |
+
The number of groups for normalization.
|
202 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
203 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
204 |
+
norm_type (`str`, *optional*, defaults to `"group"`):
|
205 |
+
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
206 |
+
"""
|
207 |
+
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
in_channels: int = 3,
|
211 |
+
out_channels: int = 3,
|
212 |
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
213 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
214 |
+
layers_per_block: int = 2,
|
215 |
+
norm_num_groups: int = 32,
|
216 |
+
act_fn: str = "silu",
|
217 |
+
norm_type: str = "group", # group, spatial
|
218 |
+
mid_block_add_attention=True,
|
219 |
+
):
|
220 |
+
super().__init__()
|
221 |
+
self.layers_per_block = layers_per_block
|
222 |
+
|
223 |
+
self.conv_in = nn.Conv2d(
|
224 |
+
in_channels,
|
225 |
+
block_out_channels[-1],
|
226 |
+
kernel_size=3,
|
227 |
+
stride=1,
|
228 |
+
padding=1,
|
229 |
+
)
|
230 |
+
|
231 |
+
self.mid_block = None
|
232 |
+
self.up_blocks = nn.ModuleList([])
|
233 |
+
|
234 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
235 |
+
|
236 |
+
# mid
|
237 |
+
self.mid_block = UNetMidBlock2D(
|
238 |
+
in_channels=block_out_channels[-1],
|
239 |
+
resnet_eps=1e-6,
|
240 |
+
resnet_act_fn=act_fn,
|
241 |
+
output_scale_factor=1,
|
242 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
243 |
+
attention_head_dim=block_out_channels[-1],
|
244 |
+
resnet_groups=norm_num_groups,
|
245 |
+
temb_channels=temb_channels,
|
246 |
+
add_attention=mid_block_add_attention,
|
247 |
+
)
|
248 |
+
|
249 |
+
# up
|
250 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
251 |
+
output_channel = reversed_block_out_channels[0]
|
252 |
+
for i, up_block_type in enumerate(up_block_types):
|
253 |
+
prev_output_channel = output_channel
|
254 |
+
output_channel = reversed_block_out_channels[i]
|
255 |
+
|
256 |
+
is_final_block = i == len(block_out_channels) - 1
|
257 |
+
|
258 |
+
up_block = get_up_block(
|
259 |
+
up_block_type,
|
260 |
+
num_layers=self.layers_per_block + 1,
|
261 |
+
in_channels=prev_output_channel,
|
262 |
+
out_channels=output_channel,
|
263 |
+
prev_output_channel=None,
|
264 |
+
add_upsample=not is_final_block,
|
265 |
+
resnet_eps=1e-6,
|
266 |
+
resnet_act_fn=act_fn,
|
267 |
+
resnet_groups=norm_num_groups,
|
268 |
+
attention_head_dim=output_channel,
|
269 |
+
temb_channels=temb_channels,
|
270 |
+
resnet_time_scale_shift=norm_type,
|
271 |
+
)
|
272 |
+
self.up_blocks.append(up_block)
|
273 |
+
prev_output_channel = output_channel
|
274 |
+
|
275 |
+
# out
|
276 |
+
if norm_type == "spatial":
|
277 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
278 |
+
else:
|
279 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
280 |
+
self.conv_act = nn.SiLU()
|
281 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
282 |
+
|
283 |
+
self.gradient_checkpointing = False
|
284 |
+
|
285 |
+
def forward(
|
286 |
+
self,
|
287 |
+
sample: torch.FloatTensor,
|
288 |
+
latent_embeds: Optional[torch.FloatTensor] = None,
|
289 |
+
) -> torch.FloatTensor:
|
290 |
+
r"""The forward method of the `Decoder` class."""
|
291 |
+
|
292 |
+
sample = self.conv_in(sample)
|
293 |
+
sample = sample.to(torch.float32)
|
294 |
+
|
295 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
296 |
+
|
297 |
+
if self.training and self.gradient_checkpointing:
|
298 |
+
|
299 |
+
def create_custom_forward(module):
|
300 |
+
def custom_forward(*inputs):
|
301 |
+
return module(*inputs)
|
302 |
+
|
303 |
+
return custom_forward
|
304 |
+
|
305 |
+
if is_torch_version(">=", "1.11.0"):
|
306 |
+
# middle
|
307 |
+
sample = torch.utils.checkpoint.checkpoint(
|
308 |
+
create_custom_forward(self.mid_block),
|
309 |
+
sample,
|
310 |
+
latent_embeds,
|
311 |
+
use_reentrant=False,
|
312 |
+
)
|
313 |
+
sample = sample.to(upscale_dtype)
|
314 |
+
|
315 |
+
# up
|
316 |
+
for up_block in self.up_blocks:
|
317 |
+
sample = torch.utils.checkpoint.checkpoint(
|
318 |
+
create_custom_forward(up_block),
|
319 |
+
sample,
|
320 |
+
latent_embeds,
|
321 |
+
use_reentrant=False,
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
# middle
|
325 |
+
sample = torch.utils.checkpoint.checkpoint(
|
326 |
+
create_custom_forward(self.mid_block), sample, latent_embeds
|
327 |
+
)
|
328 |
+
sample = sample.to(upscale_dtype)
|
329 |
+
|
330 |
+
# up
|
331 |
+
for up_block in self.up_blocks:
|
332 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
333 |
+
else:
|
334 |
+
# middle
|
335 |
+
sample = self.mid_block(sample, latent_embeds)
|
336 |
+
sample = sample.to(upscale_dtype)
|
337 |
+
|
338 |
+
# up
|
339 |
+
for up_block in self.up_blocks:
|
340 |
+
sample = up_block(sample, latent_embeds)
|
341 |
+
|
342 |
+
# post-process
|
343 |
+
if latent_embeds is None:
|
344 |
+
sample = self.conv_norm_out(sample)
|
345 |
+
else:
|
346 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
347 |
+
sample = self.conv_act(sample)
|
348 |
+
sample = self.conv_out(sample)
|
349 |
+
|
350 |
+
return sample
|
351 |
+
|
352 |
+
|
353 |
+
class UpSample(nn.Module):
|
354 |
+
r"""
|
355 |
+
The `UpSample` layer of a variational autoencoder that upsamples its input.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
in_channels (`int`, *optional*, defaults to 3):
|
359 |
+
The number of input channels.
|
360 |
+
out_channels (`int`, *optional*, defaults to 3):
|
361 |
+
The number of output channels.
|
362 |
+
"""
|
363 |
+
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
in_channels: int,
|
367 |
+
out_channels: int,
|
368 |
+
) -> None:
|
369 |
+
super().__init__()
|
370 |
+
self.in_channels = in_channels
|
371 |
+
self.out_channels = out_channels
|
372 |
+
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
373 |
+
|
374 |
+
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
375 |
+
r"""The forward method of the `UpSample` class."""
|
376 |
+
x = torch.relu(x)
|
377 |
+
x = self.deconv(x)
|
378 |
+
return x
|
379 |
+
|
380 |
+
|
381 |
+
class MaskConditionEncoder(nn.Module):
|
382 |
+
"""
|
383 |
+
used in AsymmetricAutoencoderKL
|
384 |
+
"""
|
385 |
+
|
386 |
+
def __init__(
|
387 |
+
self,
|
388 |
+
in_ch: int,
|
389 |
+
out_ch: int = 192,
|
390 |
+
res_ch: int = 768,
|
391 |
+
stride: int = 16,
|
392 |
+
) -> None:
|
393 |
+
super().__init__()
|
394 |
+
|
395 |
+
channels = []
|
396 |
+
while stride > 1:
|
397 |
+
stride = stride // 2
|
398 |
+
in_ch_ = out_ch * 2
|
399 |
+
if out_ch > res_ch:
|
400 |
+
out_ch = res_ch
|
401 |
+
if stride == 1:
|
402 |
+
in_ch_ = res_ch
|
403 |
+
channels.append((in_ch_, out_ch))
|
404 |
+
out_ch *= 2
|
405 |
+
|
406 |
+
out_channels = []
|
407 |
+
for _in_ch, _out_ch in channels:
|
408 |
+
out_channels.append(_out_ch)
|
409 |
+
out_channels.append(channels[-1][0])
|
410 |
+
|
411 |
+
layers = []
|
412 |
+
in_ch_ = in_ch
|
413 |
+
for l in range(len(out_channels)):
|
414 |
+
out_ch_ = out_channels[l]
|
415 |
+
if l == 0 or l == 1:
|
416 |
+
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
|
417 |
+
else:
|
418 |
+
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
|
419 |
+
in_ch_ = out_ch_
|
420 |
+
|
421 |
+
self.layers = nn.Sequential(*layers)
|
422 |
+
|
423 |
+
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
|
424 |
+
r"""The forward method of the `MaskConditionEncoder` class."""
|
425 |
+
out = {}
|
426 |
+
for l in range(len(self.layers)):
|
427 |
+
layer = self.layers[l]
|
428 |
+
x = layer(x)
|
429 |
+
out[str(tuple(x.shape))] = x
|
430 |
+
x = torch.relu(x)
|
431 |
+
return out
|
432 |
+
|
433 |
+
|
434 |
+
class MaskConditionDecoder(nn.Module):
|
435 |
+
r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
|
436 |
+
decoder with a conditioner on the mask and masked image.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
in_channels (`int`, *optional*, defaults to 3):
|
440 |
+
The number of input channels.
|
441 |
+
out_channels (`int`, *optional*, defaults to 3):
|
442 |
+
The number of output channels.
|
443 |
+
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
444 |
+
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
445 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
446 |
+
The number of output channels for each block.
|
447 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
448 |
+
The number of layers per block.
|
449 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
450 |
+
The number of groups for normalization.
|
451 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
452 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
453 |
+
norm_type (`str`, *optional*, defaults to `"group"`):
|
454 |
+
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(
|
458 |
+
self,
|
459 |
+
in_channels: int = 3,
|
460 |
+
out_channels: int = 3,
|
461 |
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
462 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
463 |
+
layers_per_block: int = 2,
|
464 |
+
norm_num_groups: int = 32,
|
465 |
+
act_fn: str = "silu",
|
466 |
+
norm_type: str = "group", # group, spatial
|
467 |
+
):
|
468 |
+
super().__init__()
|
469 |
+
self.layers_per_block = layers_per_block
|
470 |
+
|
471 |
+
self.conv_in = nn.Conv2d(
|
472 |
+
in_channels,
|
473 |
+
block_out_channels[-1],
|
474 |
+
kernel_size=3,
|
475 |
+
stride=1,
|
476 |
+
padding=1,
|
477 |
+
)
|
478 |
+
|
479 |
+
self.mid_block = None
|
480 |
+
self.up_blocks = nn.ModuleList([])
|
481 |
+
|
482 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
483 |
+
|
484 |
+
# mid
|
485 |
+
self.mid_block = UNetMidBlock2D(
|
486 |
+
in_channels=block_out_channels[-1],
|
487 |
+
resnet_eps=1e-6,
|
488 |
+
resnet_act_fn=act_fn,
|
489 |
+
output_scale_factor=1,
|
490 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
491 |
+
attention_head_dim=block_out_channels[-1],
|
492 |
+
resnet_groups=norm_num_groups,
|
493 |
+
temb_channels=temb_channels,
|
494 |
+
)
|
495 |
+
|
496 |
+
# up
|
497 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
498 |
+
output_channel = reversed_block_out_channels[0]
|
499 |
+
for i, up_block_type in enumerate(up_block_types):
|
500 |
+
prev_output_channel = output_channel
|
501 |
+
output_channel = reversed_block_out_channels[i]
|
502 |
+
|
503 |
+
is_final_block = i == len(block_out_channels) - 1
|
504 |
+
|
505 |
+
up_block = get_up_block(
|
506 |
+
up_block_type,
|
507 |
+
num_layers=self.layers_per_block + 1,
|
508 |
+
in_channels=prev_output_channel,
|
509 |
+
out_channels=output_channel,
|
510 |
+
prev_output_channel=None,
|
511 |
+
add_upsample=not is_final_block,
|
512 |
+
resnet_eps=1e-6,
|
513 |
+
resnet_act_fn=act_fn,
|
514 |
+
resnet_groups=norm_num_groups,
|
515 |
+
attention_head_dim=output_channel,
|
516 |
+
temb_channels=temb_channels,
|
517 |
+
resnet_time_scale_shift=norm_type,
|
518 |
+
)
|
519 |
+
self.up_blocks.append(up_block)
|
520 |
+
prev_output_channel = output_channel
|
521 |
+
|
522 |
+
# condition encoder
|
523 |
+
self.condition_encoder = MaskConditionEncoder(
|
524 |
+
in_ch=out_channels,
|
525 |
+
out_ch=block_out_channels[0],
|
526 |
+
res_ch=block_out_channels[-1],
|
527 |
+
)
|
528 |
+
|
529 |
+
# out
|
530 |
+
if norm_type == "spatial":
|
531 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
532 |
+
else:
|
533 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
534 |
+
self.conv_act = nn.SiLU()
|
535 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
536 |
+
|
537 |
+
self.gradient_checkpointing = False
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
z: torch.FloatTensor,
|
542 |
+
image: Optional[torch.FloatTensor] = None,
|
543 |
+
mask: Optional[torch.FloatTensor] = None,
|
544 |
+
latent_embeds: Optional[torch.FloatTensor] = None,
|
545 |
+
) -> torch.FloatTensor:
|
546 |
+
r"""The forward method of the `MaskConditionDecoder` class."""
|
547 |
+
sample = z
|
548 |
+
sample = self.conv_in(sample)
|
549 |
+
|
550 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
551 |
+
if self.training and self.gradient_checkpointing:
|
552 |
+
|
553 |
+
def create_custom_forward(module):
|
554 |
+
def custom_forward(*inputs):
|
555 |
+
return module(*inputs)
|
556 |
+
|
557 |
+
return custom_forward
|
558 |
+
|
559 |
+
if is_torch_version(">=", "1.11.0"):
|
560 |
+
# middle
|
561 |
+
sample = torch.utils.checkpoint.checkpoint(
|
562 |
+
create_custom_forward(self.mid_block),
|
563 |
+
sample,
|
564 |
+
latent_embeds,
|
565 |
+
use_reentrant=False,
|
566 |
+
)
|
567 |
+
sample = sample.to(upscale_dtype)
|
568 |
+
|
569 |
+
# condition encoder
|
570 |
+
if image is not None and mask is not None:
|
571 |
+
masked_image = (1 - mask) * image
|
572 |
+
im_x = torch.utils.checkpoint.checkpoint(
|
573 |
+
create_custom_forward(self.condition_encoder),
|
574 |
+
masked_image,
|
575 |
+
mask,
|
576 |
+
use_reentrant=False,
|
577 |
+
)
|
578 |
+
|
579 |
+
# up
|
580 |
+
for up_block in self.up_blocks:
|
581 |
+
if image is not None and mask is not None:
|
582 |
+
sample_ = im_x[str(tuple(sample.shape))]
|
583 |
+
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
584 |
+
sample = sample * mask_ + sample_ * (1 - mask_)
|
585 |
+
sample = torch.utils.checkpoint.checkpoint(
|
586 |
+
create_custom_forward(up_block),
|
587 |
+
sample,
|
588 |
+
latent_embeds,
|
589 |
+
use_reentrant=False,
|
590 |
+
)
|
591 |
+
if image is not None and mask is not None:
|
592 |
+
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
593 |
+
else:
|
594 |
+
# middle
|
595 |
+
sample = torch.utils.checkpoint.checkpoint(
|
596 |
+
create_custom_forward(self.mid_block), sample, latent_embeds
|
597 |
+
)
|
598 |
+
sample = sample.to(upscale_dtype)
|
599 |
+
|
600 |
+
# condition encoder
|
601 |
+
if image is not None and mask is not None:
|
602 |
+
masked_image = (1 - mask) * image
|
603 |
+
im_x = torch.utils.checkpoint.checkpoint(
|
604 |
+
create_custom_forward(self.condition_encoder),
|
605 |
+
masked_image,
|
606 |
+
mask,
|
607 |
+
)
|
608 |
+
|
609 |
+
# up
|
610 |
+
for up_block in self.up_blocks:
|
611 |
+
if image is not None and mask is not None:
|
612 |
+
sample_ = im_x[str(tuple(sample.shape))]
|
613 |
+
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
614 |
+
sample = sample * mask_ + sample_ * (1 - mask_)
|
615 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
616 |
+
if image is not None and mask is not None:
|
617 |
+
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
618 |
+
else:
|
619 |
+
# middle
|
620 |
+
sample = self.mid_block(sample, latent_embeds)
|
621 |
+
sample = sample.to(upscale_dtype)
|
622 |
+
|
623 |
+
# condition encoder
|
624 |
+
if image is not None and mask is not None:
|
625 |
+
masked_image = (1 - mask) * image
|
626 |
+
im_x = self.condition_encoder(masked_image, mask)
|
627 |
+
|
628 |
+
# up
|
629 |
+
for up_block in self.up_blocks:
|
630 |
+
if image is not None and mask is not None:
|
631 |
+
sample_ = im_x[str(tuple(sample.shape))]
|
632 |
+
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
633 |
+
sample = sample * mask_ + sample_ * (1 - mask_)
|
634 |
+
sample = up_block(sample, latent_embeds)
|
635 |
+
if image is not None and mask is not None:
|
636 |
+
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
637 |
+
|
638 |
+
# post-process
|
639 |
+
if latent_embeds is None:
|
640 |
+
sample = self.conv_norm_out(sample)
|
641 |
+
else:
|
642 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
643 |
+
sample = self.conv_act(sample)
|
644 |
+
sample = self.conv_out(sample)
|
645 |
+
|
646 |
+
return sample
|
647 |
+
|
648 |
+
|
649 |
+
class VectorQuantizer(nn.Module):
|
650 |
+
"""
|
651 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
|
652 |
+
multiplications and allows for post-hoc remapping of indices.
|
653 |
+
"""
|
654 |
+
|
655 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
656 |
+
# backwards compatibility we use the buggy version by default, but you can
|
657 |
+
# specify legacy=False to fix it.
|
658 |
+
def __init__(
|
659 |
+
self,
|
660 |
+
n_e: int,
|
661 |
+
vq_embed_dim: int,
|
662 |
+
beta: float,
|
663 |
+
remap=None,
|
664 |
+
unknown_index: str = "random",
|
665 |
+
sane_index_shape: bool = False,
|
666 |
+
legacy: bool = True,
|
667 |
+
):
|
668 |
+
super().__init__()
|
669 |
+
self.n_e = n_e
|
670 |
+
self.vq_embed_dim = vq_embed_dim
|
671 |
+
self.beta = beta
|
672 |
+
self.legacy = legacy
|
673 |
+
|
674 |
+
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
|
675 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
676 |
+
|
677 |
+
self.remap = remap
|
678 |
+
if self.remap is not None:
|
679 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
680 |
+
self.used: torch.Tensor
|
681 |
+
self.re_embed = self.used.shape[0]
|
682 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
683 |
+
if self.unknown_index == "extra":
|
684 |
+
self.unknown_index = self.re_embed
|
685 |
+
self.re_embed = self.re_embed + 1
|
686 |
+
print(
|
687 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
688 |
+
f"Using {self.unknown_index} for unknown indices."
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
self.re_embed = n_e
|
692 |
+
|
693 |
+
self.sane_index_shape = sane_index_shape
|
694 |
+
|
695 |
+
def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
|
696 |
+
ishape = inds.shape
|
697 |
+
assert len(ishape) > 1
|
698 |
+
inds = inds.reshape(ishape[0], -1)
|
699 |
+
used = self.used.to(inds)
|
700 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
701 |
+
new = match.argmax(-1)
|
702 |
+
unknown = match.sum(2) < 1
|
703 |
+
if self.unknown_index == "random":
|
704 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
705 |
+
else:
|
706 |
+
new[unknown] = self.unknown_index
|
707 |
+
return new.reshape(ishape)
|
708 |
+
|
709 |
+
def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
|
710 |
+
ishape = inds.shape
|
711 |
+
assert len(ishape) > 1
|
712 |
+
inds = inds.reshape(ishape[0], -1)
|
713 |
+
used = self.used.to(inds)
|
714 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
715 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
716 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
717 |
+
return back.reshape(ishape)
|
718 |
+
|
719 |
+
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
|
720 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
721 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
722 |
+
z_flattened = z.view(-1, self.vq_embed_dim)
|
723 |
+
|
724 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
725 |
+
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
|
726 |
+
|
727 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
728 |
+
perplexity = None
|
729 |
+
min_encodings = None
|
730 |
+
|
731 |
+
# compute loss for embedding
|
732 |
+
if not self.legacy:
|
733 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
734 |
+
else:
|
735 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
736 |
+
|
737 |
+
# preserve gradients
|
738 |
+
z_q: torch.FloatTensor = z + (z_q - z).detach()
|
739 |
+
|
740 |
+
# reshape back to match original input shape
|
741 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
742 |
+
|
743 |
+
if self.remap is not None:
|
744 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
745 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
746 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
747 |
+
|
748 |
+
if self.sane_index_shape:
|
749 |
+
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
750 |
+
|
751 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
752 |
+
|
753 |
+
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
|
754 |
+
# shape specifying (batch, height, width, channel)
|
755 |
+
if self.remap is not None:
|
756 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
757 |
+
indices = self.unmap_to_all(indices)
|
758 |
+
indices = indices.reshape(-1) # flatten again
|
759 |
+
|
760 |
+
# get quantized latent vectors
|
761 |
+
z_q: torch.FloatTensor = self.embedding(indices)
|
762 |
+
|
763 |
+
if shape is not None:
|
764 |
+
z_q = z_q.view(shape)
|
765 |
+
# reshape back to match original input shape
|
766 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
767 |
+
|
768 |
+
return z_q
|
769 |
+
|
770 |
+
|
771 |
+
class DiagonalGaussianDistribution(object):
|
772 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
773 |
+
self.parameters = parameters
|
774 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
775 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
776 |
+
self.deterministic = deterministic
|
777 |
+
self.std = torch.exp(0.5 * self.logvar)
|
778 |
+
self.var = torch.exp(self.logvar)
|
779 |
+
if self.deterministic:
|
780 |
+
self.var = self.std = torch.zeros_like(
|
781 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
782 |
+
)
|
783 |
+
|
784 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
785 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
786 |
+
sample = randn_tensor(
|
787 |
+
self.mean.shape,
|
788 |
+
generator=generator,
|
789 |
+
device=self.parameters.device,
|
790 |
+
dtype=self.parameters.dtype,
|
791 |
+
)
|
792 |
+
x = self.mean + self.std * sample
|
793 |
+
return x
|
794 |
+
|
795 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
796 |
+
if self.deterministic:
|
797 |
+
return torch.Tensor([0.0])
|
798 |
+
else:
|
799 |
+
if other is None:
|
800 |
+
return 0.5 * torch.sum(
|
801 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
802 |
+
dim=[1, 2, 3],
|
803 |
+
)
|
804 |
+
else:
|
805 |
+
return 0.5 * torch.sum(
|
806 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
807 |
+
+ self.var / other.var
|
808 |
+
- 1.0
|
809 |
+
- self.logvar
|
810 |
+
+ other.logvar,
|
811 |
+
dim=[1, 2, 3],
|
812 |
+
)
|
813 |
+
|
814 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
815 |
+
if self.deterministic:
|
816 |
+
return torch.Tensor([0.0])
|
817 |
+
logtwopi = np.log(2.0 * np.pi)
|
818 |
+
return 0.5 * torch.sum(
|
819 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
820 |
+
dim=dims,
|
821 |
+
)
|
822 |
+
|
823 |
+
def mode(self) -> torch.Tensor:
|
824 |
+
return self.mean
|
825 |
+
|
826 |
+
|
827 |
+
class EncoderTiny(nn.Module):
|
828 |
+
r"""
|
829 |
+
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
in_channels (`int`):
|
833 |
+
The number of input channels.
|
834 |
+
out_channels (`int`):
|
835 |
+
The number of output channels.
|
836 |
+
num_blocks (`Tuple[int, ...]`):
|
837 |
+
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
838 |
+
use.
|
839 |
+
block_out_channels (`Tuple[int, ...]`):
|
840 |
+
The number of output channels for each block.
|
841 |
+
act_fn (`str`):
|
842 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
843 |
+
"""
|
844 |
+
|
845 |
+
def __init__(
|
846 |
+
self,
|
847 |
+
in_channels: int,
|
848 |
+
out_channels: int,
|
849 |
+
num_blocks: Tuple[int, ...],
|
850 |
+
block_out_channels: Tuple[int, ...],
|
851 |
+
act_fn: str,
|
852 |
+
):
|
853 |
+
super().__init__()
|
854 |
+
|
855 |
+
layers = []
|
856 |
+
for i, num_block in enumerate(num_blocks):
|
857 |
+
num_channels = block_out_channels[i]
|
858 |
+
|
859 |
+
if i == 0:
|
860 |
+
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
|
861 |
+
else:
|
862 |
+
layers.append(
|
863 |
+
nn.Conv2d(
|
864 |
+
num_channels,
|
865 |
+
num_channels,
|
866 |
+
kernel_size=3,
|
867 |
+
padding=1,
|
868 |
+
stride=2,
|
869 |
+
bias=False,
|
870 |
+
)
|
871 |
+
)
|
872 |
+
|
873 |
+
for _ in range(num_block):
|
874 |
+
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
875 |
+
|
876 |
+
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
|
877 |
+
|
878 |
+
self.layers = nn.Sequential(*layers)
|
879 |
+
self.gradient_checkpointing = False
|
880 |
+
|
881 |
+
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
882 |
+
r"""The forward method of the `EncoderTiny` class."""
|
883 |
+
if self.training and self.gradient_checkpointing:
|
884 |
+
|
885 |
+
def create_custom_forward(module):
|
886 |
+
def custom_forward(*inputs):
|
887 |
+
return module(*inputs)
|
888 |
+
|
889 |
+
return custom_forward
|
890 |
+
|
891 |
+
if is_torch_version(">=", "1.11.0"):
|
892 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
893 |
+
else:
|
894 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
895 |
+
|
896 |
+
else:
|
897 |
+
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
898 |
+
x = self.layers(x.add(1).div(2))
|
899 |
+
|
900 |
+
return x
|
901 |
+
|
902 |
+
|
903 |
+
class DecoderTiny(nn.Module):
|
904 |
+
r"""
|
905 |
+
The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
|
906 |
+
|
907 |
+
Args:
|
908 |
+
in_channels (`int`):
|
909 |
+
The number of input channels.
|
910 |
+
out_channels (`int`):
|
911 |
+
The number of output channels.
|
912 |
+
num_blocks (`Tuple[int, ...]`):
|
913 |
+
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
|
914 |
+
use.
|
915 |
+
block_out_channels (`Tuple[int, ...]`):
|
916 |
+
The number of output channels for each block.
|
917 |
+
upsampling_scaling_factor (`int`):
|
918 |
+
The scaling factor to use for upsampling.
|
919 |
+
act_fn (`str`):
|
920 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
921 |
+
"""
|
922 |
+
|
923 |
+
def __init__(
|
924 |
+
self,
|
925 |
+
in_channels: int,
|
926 |
+
out_channels: int,
|
927 |
+
num_blocks: Tuple[int, ...],
|
928 |
+
block_out_channels: Tuple[int, ...],
|
929 |
+
upsampling_scaling_factor: int,
|
930 |
+
act_fn: str,
|
931 |
+
):
|
932 |
+
super().__init__()
|
933 |
+
|
934 |
+
layers = [
|
935 |
+
nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
|
936 |
+
get_activation(act_fn),
|
937 |
+
]
|
938 |
+
|
939 |
+
for i, num_block in enumerate(num_blocks):
|
940 |
+
is_final_block = i == (len(num_blocks) - 1)
|
941 |
+
num_channels = block_out_channels[i]
|
942 |
+
|
943 |
+
for _ in range(num_block):
|
944 |
+
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
945 |
+
|
946 |
+
if not is_final_block:
|
947 |
+
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
|
948 |
+
|
949 |
+
conv_out_channel = num_channels if not is_final_block else out_channels
|
950 |
+
layers.append(
|
951 |
+
nn.Conv2d(
|
952 |
+
num_channels,
|
953 |
+
conv_out_channel,
|
954 |
+
kernel_size=3,
|
955 |
+
padding=1,
|
956 |
+
bias=is_final_block,
|
957 |
+
)
|
958 |
+
)
|
959 |
+
|
960 |
+
self.layers = nn.Sequential(*layers)
|
961 |
+
self.gradient_checkpointing = False
|
962 |
+
|
963 |
+
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
964 |
+
r"""The forward method of the `DecoderTiny` class."""
|
965 |
+
# Clamp.
|
966 |
+
x = torch.tanh(x / 3) * 3
|
967 |
+
|
968 |
+
if self.training and self.gradient_checkpointing:
|
969 |
+
|
970 |
+
def create_custom_forward(module):
|
971 |
+
def custom_forward(*inputs):
|
972 |
+
return module(*inputs)
|
973 |
+
|
974 |
+
return custom_forward
|
975 |
+
|
976 |
+
if is_torch_version(">=", "1.11.0"):
|
977 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
978 |
+
else:
|
979 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
980 |
+
|
981 |
+
else:
|
982 |
+
x = self.layers(x)
|
983 |
+
|
984 |
+
# scale image from [0, 1] to [-1, 1] to match diffusers convention
|
985 |
+
return x.mul(2).sub(1)
|
module/ip_adapter/attention_processor.py
ADDED
@@ -0,0 +1,1467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
class AdaLayerNorm(nn.Module):
|
7 |
+
def __init__(self, embedding_dim: int, time_embedding_dim: int = None):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
if time_embedding_dim is None:
|
11 |
+
time_embedding_dim = embedding_dim
|
12 |
+
|
13 |
+
self.silu = nn.SiLU()
|
14 |
+
self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
|
15 |
+
nn.init.zeros_(self.linear.weight)
|
16 |
+
nn.init.zeros_(self.linear.bias)
|
17 |
+
|
18 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
19 |
+
|
20 |
+
def forward(
|
21 |
+
self, x: torch.Tensor, timestep_embedding: torch.Tensor
|
22 |
+
):
|
23 |
+
emb = self.linear(self.silu(timestep_embedding))
|
24 |
+
shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
|
25 |
+
x = self.norm(x) * (1 + scale) + shift
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
class AttnProcessor(nn.Module):
|
30 |
+
r"""
|
31 |
+
Default processor for performing attention-related computations.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
hidden_size=None,
|
37 |
+
cross_attention_dim=None,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
def __call__(
|
42 |
+
self,
|
43 |
+
attn,
|
44 |
+
hidden_states,
|
45 |
+
encoder_hidden_states=None,
|
46 |
+
attention_mask=None,
|
47 |
+
temb=None,
|
48 |
+
):
|
49 |
+
residual = hidden_states
|
50 |
+
|
51 |
+
if attn.spatial_norm is not None:
|
52 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
53 |
+
|
54 |
+
input_ndim = hidden_states.ndim
|
55 |
+
|
56 |
+
if input_ndim == 4:
|
57 |
+
batch_size, channel, height, width = hidden_states.shape
|
58 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
59 |
+
|
60 |
+
batch_size, sequence_length, _ = (
|
61 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
62 |
+
)
|
63 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
64 |
+
|
65 |
+
if attn.group_norm is not None:
|
66 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
67 |
+
|
68 |
+
query = attn.to_q(hidden_states)
|
69 |
+
|
70 |
+
if encoder_hidden_states is None:
|
71 |
+
encoder_hidden_states = hidden_states
|
72 |
+
elif attn.norm_cross:
|
73 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
74 |
+
|
75 |
+
key = attn.to_k(encoder_hidden_states)
|
76 |
+
value = attn.to_v(encoder_hidden_states)
|
77 |
+
|
78 |
+
query = attn.head_to_batch_dim(query)
|
79 |
+
key = attn.head_to_batch_dim(key)
|
80 |
+
value = attn.head_to_batch_dim(value)
|
81 |
+
|
82 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
83 |
+
hidden_states = torch.bmm(attention_probs, value)
|
84 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
85 |
+
|
86 |
+
# linear proj
|
87 |
+
hidden_states = attn.to_out[0](hidden_states)
|
88 |
+
# dropout
|
89 |
+
hidden_states = attn.to_out[1](hidden_states)
|
90 |
+
|
91 |
+
if input_ndim == 4:
|
92 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
93 |
+
|
94 |
+
if attn.residual_connection:
|
95 |
+
hidden_states = hidden_states + residual
|
96 |
+
|
97 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
98 |
+
|
99 |
+
return hidden_states
|
100 |
+
|
101 |
+
|
102 |
+
class IPAttnProcessor(nn.Module):
|
103 |
+
r"""
|
104 |
+
Attention processor for IP-Adapater.
|
105 |
+
Args:
|
106 |
+
hidden_size (`int`):
|
107 |
+
The hidden size of the attention layer.
|
108 |
+
cross_attention_dim (`int`):
|
109 |
+
The number of channels in the `encoder_hidden_states`.
|
110 |
+
scale (`float`, defaults to 1.0):
|
111 |
+
the weight scale of image prompt.
|
112 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
113 |
+
The context length of the image features.
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
self.hidden_size = hidden_size
|
120 |
+
self.cross_attention_dim = cross_attention_dim
|
121 |
+
self.scale = scale
|
122 |
+
self.num_tokens = num_tokens
|
123 |
+
|
124 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
125 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
126 |
+
|
127 |
+
def __call__(
|
128 |
+
self,
|
129 |
+
attn,
|
130 |
+
hidden_states,
|
131 |
+
encoder_hidden_states=None,
|
132 |
+
attention_mask=None,
|
133 |
+
temb=None,
|
134 |
+
):
|
135 |
+
residual = hidden_states
|
136 |
+
|
137 |
+
if attn.spatial_norm is not None:
|
138 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
139 |
+
|
140 |
+
input_ndim = hidden_states.ndim
|
141 |
+
|
142 |
+
if input_ndim == 4:
|
143 |
+
batch_size, channel, height, width = hidden_states.shape
|
144 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
145 |
+
|
146 |
+
batch_size, sequence_length, _ = (
|
147 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
148 |
+
)
|
149 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
150 |
+
|
151 |
+
if attn.group_norm is not None:
|
152 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
153 |
+
|
154 |
+
query = attn.to_q(hidden_states)
|
155 |
+
|
156 |
+
if encoder_hidden_states is None:
|
157 |
+
encoder_hidden_states = hidden_states
|
158 |
+
else:
|
159 |
+
# get encoder_hidden_states, ip_hidden_states
|
160 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
161 |
+
encoder_hidden_states, ip_hidden_states = (
|
162 |
+
encoder_hidden_states[:, :end_pos, :],
|
163 |
+
encoder_hidden_states[:, end_pos:, :],
|
164 |
+
)
|
165 |
+
if attn.norm_cross:
|
166 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
167 |
+
|
168 |
+
key = attn.to_k(encoder_hidden_states)
|
169 |
+
value = attn.to_v(encoder_hidden_states)
|
170 |
+
|
171 |
+
query = attn.head_to_batch_dim(query)
|
172 |
+
key = attn.head_to_batch_dim(key)
|
173 |
+
value = attn.head_to_batch_dim(value)
|
174 |
+
|
175 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
176 |
+
hidden_states = torch.bmm(attention_probs, value)
|
177 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
178 |
+
|
179 |
+
# for ip-adapter
|
180 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
181 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
182 |
+
|
183 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
184 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
185 |
+
|
186 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
187 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
188 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
189 |
+
|
190 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
191 |
+
|
192 |
+
# linear proj
|
193 |
+
hidden_states = attn.to_out[0](hidden_states)
|
194 |
+
# dropout
|
195 |
+
hidden_states = attn.to_out[1](hidden_states)
|
196 |
+
|
197 |
+
if input_ndim == 4:
|
198 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
199 |
+
|
200 |
+
if attn.residual_connection:
|
201 |
+
hidden_states = hidden_states + residual
|
202 |
+
|
203 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
204 |
+
|
205 |
+
return hidden_states
|
206 |
+
|
207 |
+
|
208 |
+
class TA_IPAttnProcessor(nn.Module):
|
209 |
+
r"""
|
210 |
+
Attention processor for IP-Adapater.
|
211 |
+
Args:
|
212 |
+
hidden_size (`int`):
|
213 |
+
The hidden size of the attention layer.
|
214 |
+
cross_attention_dim (`int`):
|
215 |
+
The number of channels in the `encoder_hidden_states`.
|
216 |
+
scale (`float`, defaults to 1.0):
|
217 |
+
the weight scale of image prompt.
|
218 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
219 |
+
The context length of the image features.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
|
223 |
+
super().__init__()
|
224 |
+
|
225 |
+
self.hidden_size = hidden_size
|
226 |
+
self.cross_attention_dim = cross_attention_dim
|
227 |
+
self.scale = scale
|
228 |
+
self.num_tokens = num_tokens
|
229 |
+
|
230 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
231 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
232 |
+
|
233 |
+
self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
|
234 |
+
self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
|
235 |
+
|
236 |
+
def __call__(
|
237 |
+
self,
|
238 |
+
attn,
|
239 |
+
hidden_states,
|
240 |
+
encoder_hidden_states=None,
|
241 |
+
attention_mask=None,
|
242 |
+
temb=None,
|
243 |
+
):
|
244 |
+
assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
|
245 |
+
|
246 |
+
residual = hidden_states
|
247 |
+
|
248 |
+
if attn.spatial_norm is not None:
|
249 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
250 |
+
|
251 |
+
input_ndim = hidden_states.ndim
|
252 |
+
|
253 |
+
if input_ndim == 4:
|
254 |
+
batch_size, channel, height, width = hidden_states.shape
|
255 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
256 |
+
|
257 |
+
batch_size, sequence_length, _ = (
|
258 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
259 |
+
)
|
260 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
261 |
+
|
262 |
+
if attn.group_norm is not None:
|
263 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
264 |
+
|
265 |
+
query = attn.to_q(hidden_states)
|
266 |
+
|
267 |
+
if encoder_hidden_states is None:
|
268 |
+
encoder_hidden_states = hidden_states
|
269 |
+
else:
|
270 |
+
# get encoder_hidden_states, ip_hidden_states
|
271 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
272 |
+
encoder_hidden_states, ip_hidden_states = (
|
273 |
+
encoder_hidden_states[:, :end_pos, :],
|
274 |
+
encoder_hidden_states[:, end_pos:, :],
|
275 |
+
)
|
276 |
+
if attn.norm_cross:
|
277 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
278 |
+
|
279 |
+
key = attn.to_k(encoder_hidden_states)
|
280 |
+
value = attn.to_v(encoder_hidden_states)
|
281 |
+
|
282 |
+
query = attn.head_to_batch_dim(query)
|
283 |
+
key = attn.head_to_batch_dim(key)
|
284 |
+
value = attn.head_to_batch_dim(value)
|
285 |
+
|
286 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
287 |
+
hidden_states = torch.bmm(attention_probs, value)
|
288 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
289 |
+
|
290 |
+
# for ip-adapter
|
291 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
292 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
293 |
+
|
294 |
+
# time-dependent adaLN
|
295 |
+
ip_key = self.ln_k_ip(ip_key, temb)
|
296 |
+
ip_value = self.ln_v_ip(ip_value, temb)
|
297 |
+
|
298 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
299 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
300 |
+
|
301 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
302 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
303 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
304 |
+
|
305 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
306 |
+
|
307 |
+
# linear proj
|
308 |
+
hidden_states = attn.to_out[0](hidden_states)
|
309 |
+
# dropout
|
310 |
+
hidden_states = attn.to_out[1](hidden_states)
|
311 |
+
|
312 |
+
if input_ndim == 4:
|
313 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
314 |
+
|
315 |
+
if attn.residual_connection:
|
316 |
+
hidden_states = hidden_states + residual
|
317 |
+
|
318 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
319 |
+
|
320 |
+
return hidden_states
|
321 |
+
|
322 |
+
|
323 |
+
class AttnProcessor2_0(torch.nn.Module):
|
324 |
+
r"""
|
325 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
326 |
+
"""
|
327 |
+
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
hidden_size=None,
|
331 |
+
cross_attention_dim=None,
|
332 |
+
):
|
333 |
+
super().__init__()
|
334 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
335 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
336 |
+
|
337 |
+
def __call__(
|
338 |
+
self,
|
339 |
+
attn,
|
340 |
+
hidden_states,
|
341 |
+
encoder_hidden_states=None,
|
342 |
+
attention_mask=None,
|
343 |
+
external_kv=None,
|
344 |
+
temb=None,
|
345 |
+
):
|
346 |
+
residual = hidden_states
|
347 |
+
|
348 |
+
if attn.spatial_norm is not None:
|
349 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
350 |
+
|
351 |
+
input_ndim = hidden_states.ndim
|
352 |
+
|
353 |
+
if input_ndim == 4:
|
354 |
+
batch_size, channel, height, width = hidden_states.shape
|
355 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
356 |
+
|
357 |
+
batch_size, sequence_length, _ = (
|
358 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
359 |
+
)
|
360 |
+
|
361 |
+
if attention_mask is not None:
|
362 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
363 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
364 |
+
# (batch, heads, source_length, target_length)
|
365 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
366 |
+
|
367 |
+
if attn.group_norm is not None:
|
368 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
369 |
+
|
370 |
+
query = attn.to_q(hidden_states)
|
371 |
+
|
372 |
+
if encoder_hidden_states is None:
|
373 |
+
encoder_hidden_states = hidden_states
|
374 |
+
elif attn.norm_cross:
|
375 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
376 |
+
|
377 |
+
key = attn.to_k(encoder_hidden_states)
|
378 |
+
value = attn.to_v(encoder_hidden_states)
|
379 |
+
|
380 |
+
if external_kv:
|
381 |
+
key = torch.cat([key, external_kv.k], axis=1)
|
382 |
+
value = torch.cat([value, external_kv.v], axis=1)
|
383 |
+
|
384 |
+
inner_dim = key.shape[-1]
|
385 |
+
head_dim = inner_dim // attn.heads
|
386 |
+
|
387 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
388 |
+
|
389 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
390 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
391 |
+
|
392 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
393 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
394 |
+
hidden_states = F.scaled_dot_product_attention(
|
395 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
396 |
+
)
|
397 |
+
|
398 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
399 |
+
hidden_states = hidden_states.to(query.dtype)
|
400 |
+
|
401 |
+
# linear proj
|
402 |
+
hidden_states = attn.to_out[0](hidden_states)
|
403 |
+
# dropout
|
404 |
+
hidden_states = attn.to_out[1](hidden_states)
|
405 |
+
|
406 |
+
if input_ndim == 4:
|
407 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
408 |
+
|
409 |
+
if attn.residual_connection:
|
410 |
+
hidden_states = hidden_states + residual
|
411 |
+
|
412 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
413 |
+
|
414 |
+
return hidden_states
|
415 |
+
|
416 |
+
|
417 |
+
class split_AttnProcessor2_0(torch.nn.Module):
|
418 |
+
r"""
|
419 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
420 |
+
"""
|
421 |
+
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
hidden_size=None,
|
425 |
+
cross_attention_dim=None,
|
426 |
+
time_embedding_dim=None,
|
427 |
+
):
|
428 |
+
super().__init__()
|
429 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
430 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
431 |
+
|
432 |
+
def __call__(
|
433 |
+
self,
|
434 |
+
attn,
|
435 |
+
hidden_states,
|
436 |
+
encoder_hidden_states=None,
|
437 |
+
attention_mask=None,
|
438 |
+
external_kv=None,
|
439 |
+
temb=None,
|
440 |
+
cat_dim=-2,
|
441 |
+
original_shape=None,
|
442 |
+
):
|
443 |
+
residual = hidden_states
|
444 |
+
|
445 |
+
if attn.spatial_norm is not None:
|
446 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
447 |
+
|
448 |
+
input_ndim = hidden_states.ndim
|
449 |
+
|
450 |
+
if input_ndim == 4:
|
451 |
+
# 2d to sequence.
|
452 |
+
height, width = hidden_states.shape[-2:]
|
453 |
+
if cat_dim==-2 or cat_dim==2:
|
454 |
+
hidden_states_0 = hidden_states[:, :, :height//2, :]
|
455 |
+
hidden_states_1 = hidden_states[:, :, -(height//2):, :]
|
456 |
+
elif cat_dim==-1 or cat_dim==3:
|
457 |
+
hidden_states_0 = hidden_states[:, :, :, :width//2]
|
458 |
+
hidden_states_1 = hidden_states[:, :, :, -(width//2):]
|
459 |
+
batch_size, channel, height, width = hidden_states_0.shape
|
460 |
+
hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
|
461 |
+
hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
|
462 |
+
else:
|
463 |
+
# directly split sqeuence according to concat dim.
|
464 |
+
single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
|
465 |
+
hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
|
466 |
+
hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
|
467 |
+
|
468 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1)
|
469 |
+
batch_size, sequence_length, _ = (
|
470 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
471 |
+
)
|
472 |
+
|
473 |
+
if attention_mask is not None:
|
474 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
475 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
476 |
+
# (batch, heads, source_length, target_length)
|
477 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
478 |
+
|
479 |
+
if attn.group_norm is not None:
|
480 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
481 |
+
|
482 |
+
query = attn.to_q(hidden_states)
|
483 |
+
key = attn.to_k(hidden_states)
|
484 |
+
value = attn.to_v(hidden_states)
|
485 |
+
|
486 |
+
if external_kv:
|
487 |
+
key = torch.cat([key, external_kv.k], dim=1)
|
488 |
+
value = torch.cat([value, external_kv.v], dim=1)
|
489 |
+
|
490 |
+
inner_dim = key.shape[-1]
|
491 |
+
head_dim = inner_dim // attn.heads
|
492 |
+
|
493 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
494 |
+
|
495 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
496 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
497 |
+
|
498 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
499 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
500 |
+
hidden_states = F.scaled_dot_product_attention(
|
501 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
502 |
+
)
|
503 |
+
|
504 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
505 |
+
hidden_states = hidden_states.to(query.dtype)
|
506 |
+
|
507 |
+
# linear proj
|
508 |
+
hidden_states = attn.to_out[0](hidden_states)
|
509 |
+
# dropout
|
510 |
+
hidden_states = attn.to_out[1](hidden_states)
|
511 |
+
|
512 |
+
# spatially split.
|
513 |
+
hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1)
|
514 |
+
|
515 |
+
if input_ndim == 4:
|
516 |
+
hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
517 |
+
hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
518 |
+
|
519 |
+
if cat_dim==-2 or cat_dim==2:
|
520 |
+
hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
|
521 |
+
elif cat_dim==-1 or cat_dim==3:
|
522 |
+
hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
|
523 |
+
hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
|
524 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
|
525 |
+
assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
|
526 |
+
else:
|
527 |
+
batch_size, sequence_length, inner_dim = hidden_states.shape
|
528 |
+
hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
|
529 |
+
hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
|
530 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
|
531 |
+
assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
|
532 |
+
|
533 |
+
if attn.residual_connection:
|
534 |
+
hidden_states = hidden_states + residual
|
535 |
+
|
536 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
537 |
+
|
538 |
+
return hidden_states
|
539 |
+
|
540 |
+
|
541 |
+
class sep_split_AttnProcessor2_0(torch.nn.Module):
|
542 |
+
r"""
|
543 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
544 |
+
"""
|
545 |
+
|
546 |
+
def __init__(
|
547 |
+
self,
|
548 |
+
hidden_size=None,
|
549 |
+
cross_attention_dim=None,
|
550 |
+
time_embedding_dim=None,
|
551 |
+
):
|
552 |
+
super().__init__()
|
553 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
554 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
555 |
+
self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
|
556 |
+
self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
|
557 |
+
# self.hidden_size = hidden_size
|
558 |
+
# self.cross_attention_dim = cross_attention_dim
|
559 |
+
# self.scale = scale
|
560 |
+
# self.num_tokens = num_tokens
|
561 |
+
|
562 |
+
# self.to_q_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
563 |
+
# self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
564 |
+
# self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
565 |
+
|
566 |
+
def __call__(
|
567 |
+
self,
|
568 |
+
attn,
|
569 |
+
hidden_states,
|
570 |
+
encoder_hidden_states=None,
|
571 |
+
attention_mask=None,
|
572 |
+
external_kv=None,
|
573 |
+
temb=None,
|
574 |
+
cat_dim=-2,
|
575 |
+
original_shape=None,
|
576 |
+
ref_scale=1.0,
|
577 |
+
):
|
578 |
+
residual = hidden_states
|
579 |
+
|
580 |
+
if attn.spatial_norm is not None:
|
581 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
582 |
+
|
583 |
+
input_ndim = hidden_states.ndim
|
584 |
+
|
585 |
+
if input_ndim == 4:
|
586 |
+
# 2d to sequence.
|
587 |
+
height, width = hidden_states.shape[-2:]
|
588 |
+
if cat_dim==-2 or cat_dim==2:
|
589 |
+
hidden_states_0 = hidden_states[:, :, :height//2, :]
|
590 |
+
hidden_states_1 = hidden_states[:, :, -(height//2):, :]
|
591 |
+
elif cat_dim==-1 or cat_dim==3:
|
592 |
+
hidden_states_0 = hidden_states[:, :, :, :width//2]
|
593 |
+
hidden_states_1 = hidden_states[:, :, :, -(width//2):]
|
594 |
+
batch_size, channel, height, width = hidden_states_0.shape
|
595 |
+
hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
|
596 |
+
hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
|
597 |
+
else:
|
598 |
+
# directly split sqeuence according to concat dim.
|
599 |
+
single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
|
600 |
+
hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
|
601 |
+
hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
|
602 |
+
|
603 |
+
batch_size, sequence_length, _ = (
|
604 |
+
hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
605 |
+
)
|
606 |
+
|
607 |
+
if attention_mask is not None:
|
608 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
609 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
610 |
+
# (batch, heads, source_length, target_length)
|
611 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
612 |
+
|
613 |
+
if attn.group_norm is not None:
|
614 |
+
hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2)
|
615 |
+
hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2)
|
616 |
+
|
617 |
+
query_0 = attn.to_q(hidden_states_0)
|
618 |
+
query_1 = attn.to_q(hidden_states_1)
|
619 |
+
key_0 = attn.to_k(hidden_states_0)
|
620 |
+
key_1 = attn.to_k(hidden_states_1)
|
621 |
+
value_0 = attn.to_v(hidden_states_0)
|
622 |
+
value_1 = attn.to_v(hidden_states_1)
|
623 |
+
|
624 |
+
# time-dependent adaLN
|
625 |
+
key_1 = self.ln_k_ref(key_1, temb)
|
626 |
+
value_1 = self.ln_v_ref(value_1, temb)
|
627 |
+
|
628 |
+
if external_kv:
|
629 |
+
key_1 = torch.cat([key_1, external_kv.k], dim=1)
|
630 |
+
value_1 = torch.cat([value_1, external_kv.v], dim=1)
|
631 |
+
|
632 |
+
inner_dim = key_0.shape[-1]
|
633 |
+
head_dim = inner_dim // attn.heads
|
634 |
+
|
635 |
+
query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
636 |
+
query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
637 |
+
key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
638 |
+
key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
639 |
+
value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
640 |
+
value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
641 |
+
|
642 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
643 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
644 |
+
hidden_states_0 = F.scaled_dot_product_attention(
|
645 |
+
query_0, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
646 |
+
)
|
647 |
+
hidden_states_1 = F.scaled_dot_product_attention(
|
648 |
+
query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
649 |
+
)
|
650 |
+
|
651 |
+
# cross-attn
|
652 |
+
_hidden_states_0 = F.scaled_dot_product_attention(
|
653 |
+
query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
654 |
+
)
|
655 |
+
hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10
|
656 |
+
|
657 |
+
# TODO: drop this cross-attn
|
658 |
+
_hidden_states_1 = F.scaled_dot_product_attention(
|
659 |
+
query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
660 |
+
)
|
661 |
+
hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1
|
662 |
+
|
663 |
+
hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
664 |
+
hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
665 |
+
hidden_states_0 = hidden_states_0.to(query_0.dtype)
|
666 |
+
hidden_states_1 = hidden_states_1.to(query_1.dtype)
|
667 |
+
|
668 |
+
|
669 |
+
# linear proj
|
670 |
+
hidden_states_0 = attn.to_out[0](hidden_states_0)
|
671 |
+
hidden_states_1 = attn.to_out[0](hidden_states_1)
|
672 |
+
# dropout
|
673 |
+
hidden_states_0 = attn.to_out[1](hidden_states_0)
|
674 |
+
hidden_states_1 = attn.to_out[1](hidden_states_1)
|
675 |
+
|
676 |
+
|
677 |
+
if input_ndim == 4:
|
678 |
+
hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
679 |
+
hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
680 |
+
|
681 |
+
if cat_dim==-2 or cat_dim==2:
|
682 |
+
hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
|
683 |
+
elif cat_dim==-1 or cat_dim==3:
|
684 |
+
hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
|
685 |
+
hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
|
686 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
|
687 |
+
assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
|
688 |
+
else:
|
689 |
+
batch_size, sequence_length, inner_dim = hidden_states.shape
|
690 |
+
hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
|
691 |
+
hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
|
692 |
+
hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
|
693 |
+
assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
|
694 |
+
|
695 |
+
if attn.residual_connection:
|
696 |
+
hidden_states = hidden_states + residual
|
697 |
+
|
698 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
699 |
+
|
700 |
+
return hidden_states
|
701 |
+
|
702 |
+
|
703 |
+
class AdditiveKV_AttnProcessor2_0(torch.nn.Module):
|
704 |
+
r"""
|
705 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
706 |
+
"""
|
707 |
+
|
708 |
+
def __init__(
|
709 |
+
self,
|
710 |
+
hidden_size: int = None,
|
711 |
+
cross_attention_dim: int = None,
|
712 |
+
time_embedding_dim: int = None,
|
713 |
+
additive_scale: float = 1.0,
|
714 |
+
):
|
715 |
+
super().__init__()
|
716 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
717 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
718 |
+
self.additive_scale = additive_scale
|
719 |
+
|
720 |
+
def __call__(
|
721 |
+
self,
|
722 |
+
attn,
|
723 |
+
hidden_states,
|
724 |
+
encoder_hidden_states=None,
|
725 |
+
external_kv=None,
|
726 |
+
attention_mask=None,
|
727 |
+
temb=None,
|
728 |
+
):
|
729 |
+
assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
|
730 |
+
|
731 |
+
residual = hidden_states
|
732 |
+
|
733 |
+
if attn.spatial_norm is not None:
|
734 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
735 |
+
|
736 |
+
input_ndim = hidden_states.ndim
|
737 |
+
|
738 |
+
if input_ndim == 4:
|
739 |
+
batch_size, channel, height, width = hidden_states.shape
|
740 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
741 |
+
|
742 |
+
batch_size, sequence_length, _ = (
|
743 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
744 |
+
)
|
745 |
+
|
746 |
+
if attention_mask is not None:
|
747 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
748 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
749 |
+
# (batch, heads, source_length, target_length)
|
750 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
751 |
+
|
752 |
+
if attn.group_norm is not None:
|
753 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
754 |
+
|
755 |
+
query = attn.to_q(hidden_states)
|
756 |
+
|
757 |
+
if encoder_hidden_states is None:
|
758 |
+
encoder_hidden_states = hidden_states
|
759 |
+
elif attn.norm_cross:
|
760 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
761 |
+
|
762 |
+
key = attn.to_k(encoder_hidden_states)
|
763 |
+
value = attn.to_v(encoder_hidden_states)
|
764 |
+
|
765 |
+
inner_dim = key.shape[-1]
|
766 |
+
head_dim = inner_dim // attn.heads
|
767 |
+
|
768 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
769 |
+
|
770 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
771 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
772 |
+
|
773 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
774 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
775 |
+
hidden_states = F.scaled_dot_product_attention(
|
776 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
777 |
+
)
|
778 |
+
|
779 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
780 |
+
|
781 |
+
if external_kv:
|
782 |
+
key = external_kv.k
|
783 |
+
value = external_kv.v
|
784 |
+
|
785 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
786 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
787 |
+
|
788 |
+
external_attn_output = F.scaled_dot_product_attention(
|
789 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
790 |
+
)
|
791 |
+
|
792 |
+
external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
793 |
+
hidden_states = hidden_states + self.additive_scale * external_attn_output
|
794 |
+
|
795 |
+
hidden_states = hidden_states.to(query.dtype)
|
796 |
+
|
797 |
+
# linear proj
|
798 |
+
hidden_states = attn.to_out[0](hidden_states)
|
799 |
+
# dropout
|
800 |
+
hidden_states = attn.to_out[1](hidden_states)
|
801 |
+
|
802 |
+
if input_ndim == 4:
|
803 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
804 |
+
|
805 |
+
if attn.residual_connection:
|
806 |
+
hidden_states = hidden_states + residual
|
807 |
+
|
808 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
809 |
+
|
810 |
+
return hidden_states
|
811 |
+
|
812 |
+
|
813 |
+
class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module):
|
814 |
+
r"""
|
815 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
816 |
+
"""
|
817 |
+
|
818 |
+
def __init__(
|
819 |
+
self,
|
820 |
+
hidden_size: int = None,
|
821 |
+
cross_attention_dim: int = None,
|
822 |
+
time_embedding_dim: int = None,
|
823 |
+
additive_scale: float = 1.0,
|
824 |
+
):
|
825 |
+
super().__init__()
|
826 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
827 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
828 |
+
self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim)
|
829 |
+
self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim)
|
830 |
+
self.additive_scale = additive_scale
|
831 |
+
|
832 |
+
def __call__(
|
833 |
+
self,
|
834 |
+
attn,
|
835 |
+
hidden_states,
|
836 |
+
encoder_hidden_states=None,
|
837 |
+
external_kv=None,
|
838 |
+
attention_mask=None,
|
839 |
+
temb=None,
|
840 |
+
):
|
841 |
+
assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
|
842 |
+
|
843 |
+
residual = hidden_states
|
844 |
+
|
845 |
+
if attn.spatial_norm is not None:
|
846 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
847 |
+
|
848 |
+
input_ndim = hidden_states.ndim
|
849 |
+
|
850 |
+
if input_ndim == 4:
|
851 |
+
batch_size, channel, height, width = hidden_states.shape
|
852 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
853 |
+
|
854 |
+
batch_size, sequence_length, _ = (
|
855 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
856 |
+
)
|
857 |
+
|
858 |
+
if attention_mask is not None:
|
859 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
860 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
861 |
+
# (batch, heads, source_length, target_length)
|
862 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
863 |
+
|
864 |
+
if attn.group_norm is not None:
|
865 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
866 |
+
|
867 |
+
query = attn.to_q(hidden_states)
|
868 |
+
|
869 |
+
if encoder_hidden_states is None:
|
870 |
+
encoder_hidden_states = hidden_states
|
871 |
+
elif attn.norm_cross:
|
872 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
873 |
+
|
874 |
+
key = attn.to_k(encoder_hidden_states)
|
875 |
+
value = attn.to_v(encoder_hidden_states)
|
876 |
+
|
877 |
+
inner_dim = key.shape[-1]
|
878 |
+
head_dim = inner_dim // attn.heads
|
879 |
+
|
880 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
881 |
+
|
882 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
883 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
884 |
+
|
885 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
886 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
887 |
+
hidden_states = F.scaled_dot_product_attention(
|
888 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
889 |
+
)
|
890 |
+
|
891 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
892 |
+
|
893 |
+
if external_kv:
|
894 |
+
key = external_kv.k
|
895 |
+
value = external_kv.v
|
896 |
+
|
897 |
+
# time-dependent adaLN
|
898 |
+
key = self.ln_k(key, temb)
|
899 |
+
value = self.ln_v(value, temb)
|
900 |
+
|
901 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
902 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
903 |
+
|
904 |
+
external_attn_output = F.scaled_dot_product_attention(
|
905 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
906 |
+
)
|
907 |
+
|
908 |
+
external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
909 |
+
hidden_states = hidden_states + self.additive_scale * external_attn_output
|
910 |
+
|
911 |
+
hidden_states = hidden_states.to(query.dtype)
|
912 |
+
|
913 |
+
# linear proj
|
914 |
+
hidden_states = attn.to_out[0](hidden_states)
|
915 |
+
# dropout
|
916 |
+
hidden_states = attn.to_out[1](hidden_states)
|
917 |
+
|
918 |
+
if input_ndim == 4:
|
919 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
920 |
+
|
921 |
+
if attn.residual_connection:
|
922 |
+
hidden_states = hidden_states + residual
|
923 |
+
|
924 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
925 |
+
|
926 |
+
return hidden_states
|
927 |
+
|
928 |
+
|
929 |
+
class IPAttnProcessor2_0(torch.nn.Module):
|
930 |
+
r"""
|
931 |
+
Attention processor for IP-Adapater for PyTorch 2.0.
|
932 |
+
Args:
|
933 |
+
hidden_size (`int`):
|
934 |
+
The hidden size of the attention layer.
|
935 |
+
cross_attention_dim (`int`):
|
936 |
+
The number of channels in the `encoder_hidden_states`.
|
937 |
+
scale (`float`, defaults to 1.0):
|
938 |
+
the weight scale of image prompt.
|
939 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
940 |
+
The context length of the image features.
|
941 |
+
"""
|
942 |
+
|
943 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
944 |
+
super().__init__()
|
945 |
+
|
946 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
947 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
948 |
+
|
949 |
+
self.hidden_size = hidden_size
|
950 |
+
self.cross_attention_dim = cross_attention_dim
|
951 |
+
self.scale = scale
|
952 |
+
self.num_tokens = num_tokens
|
953 |
+
|
954 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
955 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
956 |
+
|
957 |
+
def __call__(
|
958 |
+
self,
|
959 |
+
attn,
|
960 |
+
hidden_states,
|
961 |
+
encoder_hidden_states=None,
|
962 |
+
attention_mask=None,
|
963 |
+
temb=None,
|
964 |
+
):
|
965 |
+
residual = hidden_states
|
966 |
+
|
967 |
+
if attn.spatial_norm is not None:
|
968 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
969 |
+
|
970 |
+
input_ndim = hidden_states.ndim
|
971 |
+
|
972 |
+
if input_ndim == 4:
|
973 |
+
batch_size, channel, height, width = hidden_states.shape
|
974 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
975 |
+
|
976 |
+
if isinstance(encoder_hidden_states, tuple):
|
977 |
+
# FIXME: now hard coded to single image prompt.
|
978 |
+
batch_size, _, hid_dim = encoder_hidden_states[0].shape
|
979 |
+
ip_tokens = encoder_hidden_states[1][0]
|
980 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1)
|
981 |
+
|
982 |
+
batch_size, sequence_length, _ = (
|
983 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
984 |
+
)
|
985 |
+
|
986 |
+
if attention_mask is not None:
|
987 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
988 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
989 |
+
# (batch, heads, source_length, target_length)
|
990 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
991 |
+
|
992 |
+
if attn.group_norm is not None:
|
993 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
994 |
+
|
995 |
+
query = attn.to_q(hidden_states)
|
996 |
+
|
997 |
+
if encoder_hidden_states is None:
|
998 |
+
encoder_hidden_states = hidden_states
|
999 |
+
else:
|
1000 |
+
# get encoder_hidden_states, ip_hidden_states
|
1001 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
1002 |
+
encoder_hidden_states, ip_hidden_states = (
|
1003 |
+
encoder_hidden_states[:, :end_pos, :],
|
1004 |
+
encoder_hidden_states[:, end_pos:, :],
|
1005 |
+
)
|
1006 |
+
if attn.norm_cross:
|
1007 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1008 |
+
|
1009 |
+
key = attn.to_k(encoder_hidden_states)
|
1010 |
+
value = attn.to_v(encoder_hidden_states)
|
1011 |
+
|
1012 |
+
inner_dim = key.shape[-1]
|
1013 |
+
head_dim = inner_dim // attn.heads
|
1014 |
+
|
1015 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1016 |
+
|
1017 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1018 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1019 |
+
|
1020 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1021 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1022 |
+
hidden_states = F.scaled_dot_product_attention(
|
1023 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1027 |
+
hidden_states = hidden_states.to(query.dtype)
|
1028 |
+
|
1029 |
+
# for ip-adapter
|
1030 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
1031 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
1032 |
+
|
1033 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1034 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1035 |
+
|
1036 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1037 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1038 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
1039 |
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1043 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
1044 |
+
|
1045 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
1046 |
+
|
1047 |
+
# linear proj
|
1048 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1049 |
+
# dropout
|
1050 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1051 |
+
|
1052 |
+
if input_ndim == 4:
|
1053 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1054 |
+
|
1055 |
+
if attn.residual_connection:
|
1056 |
+
hidden_states = hidden_states + residual
|
1057 |
+
|
1058 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1059 |
+
|
1060 |
+
return hidden_states
|
1061 |
+
|
1062 |
+
|
1063 |
+
class TA_IPAttnProcessor2_0(torch.nn.Module):
|
1064 |
+
r"""
|
1065 |
+
Attention processor for IP-Adapater for PyTorch 2.0.
|
1066 |
+
Args:
|
1067 |
+
hidden_size (`int`):
|
1068 |
+
The hidden size of the attention layer.
|
1069 |
+
cross_attention_dim (`int`):
|
1070 |
+
The number of channels in the `encoder_hidden_states`.
|
1071 |
+
scale (`float`, defaults to 1.0):
|
1072 |
+
the weight scale of image prompt.
|
1073 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
1074 |
+
The context length of the image features.
|
1075 |
+
"""
|
1076 |
+
|
1077 |
+
def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
|
1078 |
+
super().__init__()
|
1079 |
+
|
1080 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1081 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1082 |
+
|
1083 |
+
self.hidden_size = hidden_size
|
1084 |
+
self.cross_attention_dim = cross_attention_dim
|
1085 |
+
self.scale = scale
|
1086 |
+
self.num_tokens = num_tokens
|
1087 |
+
|
1088 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1089 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1090 |
+
self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
|
1091 |
+
self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
|
1092 |
+
|
1093 |
+
def __call__(
|
1094 |
+
self,
|
1095 |
+
attn,
|
1096 |
+
hidden_states,
|
1097 |
+
encoder_hidden_states=None,
|
1098 |
+
attention_mask=None,
|
1099 |
+
external_kv=None,
|
1100 |
+
temb=None,
|
1101 |
+
):
|
1102 |
+
assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
|
1103 |
+
|
1104 |
+
residual = hidden_states
|
1105 |
+
|
1106 |
+
if attn.spatial_norm is not None:
|
1107 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1108 |
+
|
1109 |
+
input_ndim = hidden_states.ndim
|
1110 |
+
|
1111 |
+
if input_ndim == 4:
|
1112 |
+
batch_size, channel, height, width = hidden_states.shape
|
1113 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1114 |
+
|
1115 |
+
if not isinstance(encoder_hidden_states, tuple):
|
1116 |
+
# get encoder_hidden_states, ip_hidden_states
|
1117 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
1118 |
+
encoder_hidden_states, ip_hidden_states = (
|
1119 |
+
encoder_hidden_states[:, :end_pos, :],
|
1120 |
+
encoder_hidden_states[:, end_pos:, :],
|
1121 |
+
)
|
1122 |
+
else:
|
1123 |
+
# FIXME: now hard coded to single image prompt.
|
1124 |
+
batch_size, _, hid_dim = encoder_hidden_states[0].shape
|
1125 |
+
ip_hidden_states = encoder_hidden_states[1][0]
|
1126 |
+
encoder_hidden_states = encoder_hidden_states[0]
|
1127 |
+
batch_size, sequence_length, _ = (
|
1128 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
if attention_mask is not None:
|
1132 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1133 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1134 |
+
# (batch, heads, source_length, target_length)
|
1135 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1136 |
+
|
1137 |
+
if attn.group_norm is not None:
|
1138 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1139 |
+
|
1140 |
+
query = attn.to_q(hidden_states)
|
1141 |
+
|
1142 |
+
if encoder_hidden_states is None:
|
1143 |
+
encoder_hidden_states = hidden_states
|
1144 |
+
else:
|
1145 |
+
if attn.norm_cross:
|
1146 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1147 |
+
|
1148 |
+
key = attn.to_k(encoder_hidden_states)
|
1149 |
+
value = attn.to_v(encoder_hidden_states)
|
1150 |
+
|
1151 |
+
if external_kv:
|
1152 |
+
key = torch.cat([key, external_kv.k], axis=1)
|
1153 |
+
value = torch.cat([value, external_kv.v], axis=1)
|
1154 |
+
|
1155 |
+
inner_dim = key.shape[-1]
|
1156 |
+
head_dim = inner_dim // attn.heads
|
1157 |
+
|
1158 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1159 |
+
|
1160 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1161 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1162 |
+
|
1163 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1164 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1165 |
+
hidden_states = F.scaled_dot_product_attention(
|
1166 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1167 |
+
)
|
1168 |
+
|
1169 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1170 |
+
hidden_states = hidden_states.to(query.dtype)
|
1171 |
+
|
1172 |
+
# for ip-adapter
|
1173 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
1174 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
1175 |
+
|
1176 |
+
# time-dependent adaLN
|
1177 |
+
ip_key = self.ln_k_ip(ip_key, temb)
|
1178 |
+
ip_value = self.ln_v_ip(ip_value, temb)
|
1179 |
+
|
1180 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1181 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1182 |
+
|
1183 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1184 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1185 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
1186 |
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
1187 |
+
)
|
1188 |
+
|
1189 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1190 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
1191 |
+
|
1192 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
1193 |
+
|
1194 |
+
# linear proj
|
1195 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1196 |
+
# dropout
|
1197 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1198 |
+
|
1199 |
+
if input_ndim == 4:
|
1200 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1201 |
+
|
1202 |
+
if attn.residual_connection:
|
1203 |
+
hidden_states = hidden_states + residual
|
1204 |
+
|
1205 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1206 |
+
|
1207 |
+
return hidden_states
|
1208 |
+
|
1209 |
+
|
1210 |
+
## for controlnet
|
1211 |
+
class CNAttnProcessor:
|
1212 |
+
r"""
|
1213 |
+
Default processor for performing attention-related computations.
|
1214 |
+
"""
|
1215 |
+
|
1216 |
+
def __init__(self, num_tokens=4):
|
1217 |
+
self.num_tokens = num_tokens
|
1218 |
+
|
1219 |
+
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
1220 |
+
residual = hidden_states
|
1221 |
+
|
1222 |
+
if attn.spatial_norm is not None:
|
1223 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1224 |
+
|
1225 |
+
input_ndim = hidden_states.ndim
|
1226 |
+
|
1227 |
+
if input_ndim == 4:
|
1228 |
+
batch_size, channel, height, width = hidden_states.shape
|
1229 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1230 |
+
|
1231 |
+
batch_size, sequence_length, _ = (
|
1232 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1233 |
+
)
|
1234 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1235 |
+
|
1236 |
+
if attn.group_norm is not None:
|
1237 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1238 |
+
|
1239 |
+
query = attn.to_q(hidden_states)
|
1240 |
+
|
1241 |
+
if encoder_hidden_states is None:
|
1242 |
+
encoder_hidden_states = hidden_states
|
1243 |
+
else:
|
1244 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
1245 |
+
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
1246 |
+
if attn.norm_cross:
|
1247 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1248 |
+
|
1249 |
+
key = attn.to_k(encoder_hidden_states)
|
1250 |
+
value = attn.to_v(encoder_hidden_states)
|
1251 |
+
|
1252 |
+
query = attn.head_to_batch_dim(query)
|
1253 |
+
key = attn.head_to_batch_dim(key)
|
1254 |
+
value = attn.head_to_batch_dim(value)
|
1255 |
+
|
1256 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
1257 |
+
hidden_states = torch.bmm(attention_probs, value)
|
1258 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1259 |
+
|
1260 |
+
# linear proj
|
1261 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1262 |
+
# dropout
|
1263 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1264 |
+
|
1265 |
+
if input_ndim == 4:
|
1266 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1267 |
+
|
1268 |
+
if attn.residual_connection:
|
1269 |
+
hidden_states = hidden_states + residual
|
1270 |
+
|
1271 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1272 |
+
|
1273 |
+
return hidden_states
|
1274 |
+
|
1275 |
+
|
1276 |
+
class CNAttnProcessor2_0:
|
1277 |
+
r"""
|
1278 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1279 |
+
"""
|
1280 |
+
|
1281 |
+
def __init__(self, num_tokens=4):
|
1282 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1283 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1284 |
+
self.num_tokens = num_tokens
|
1285 |
+
|
1286 |
+
def __call__(
|
1287 |
+
self,
|
1288 |
+
attn,
|
1289 |
+
hidden_states,
|
1290 |
+
encoder_hidden_states=None,
|
1291 |
+
attention_mask=None,
|
1292 |
+
temb=None,
|
1293 |
+
):
|
1294 |
+
residual = hidden_states
|
1295 |
+
|
1296 |
+
if attn.spatial_norm is not None:
|
1297 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1298 |
+
|
1299 |
+
input_ndim = hidden_states.ndim
|
1300 |
+
|
1301 |
+
if input_ndim == 4:
|
1302 |
+
batch_size, channel, height, width = hidden_states.shape
|
1303 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1304 |
+
|
1305 |
+
batch_size, sequence_length, _ = (
|
1306 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1307 |
+
)
|
1308 |
+
|
1309 |
+
if attention_mask is not None:
|
1310 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1311 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1312 |
+
# (batch, heads, source_length, target_length)
|
1313 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1314 |
+
|
1315 |
+
if attn.group_norm is not None:
|
1316 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1317 |
+
|
1318 |
+
query = attn.to_q(hidden_states)
|
1319 |
+
|
1320 |
+
if encoder_hidden_states is None:
|
1321 |
+
encoder_hidden_states = hidden_states
|
1322 |
+
else:
|
1323 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
1324 |
+
encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
|
1325 |
+
if attn.norm_cross:
|
1326 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1327 |
+
|
1328 |
+
key = attn.to_k(encoder_hidden_states)
|
1329 |
+
value = attn.to_v(encoder_hidden_states)
|
1330 |
+
|
1331 |
+
inner_dim = key.shape[-1]
|
1332 |
+
head_dim = inner_dim // attn.heads
|
1333 |
+
|
1334 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1335 |
+
|
1336 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1337 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1338 |
+
|
1339 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1340 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1341 |
+
hidden_states = F.scaled_dot_product_attention(
|
1342 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1343 |
+
)
|
1344 |
+
|
1345 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1346 |
+
hidden_states = hidden_states.to(query.dtype)
|
1347 |
+
|
1348 |
+
# linear proj
|
1349 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1350 |
+
# dropout
|
1351 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1352 |
+
|
1353 |
+
if input_ndim == 4:
|
1354 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1355 |
+
|
1356 |
+
if attn.residual_connection:
|
1357 |
+
hidden_states = hidden_states + residual
|
1358 |
+
|
1359 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1360 |
+
|
1361 |
+
return hidden_states
|
1362 |
+
|
1363 |
+
|
1364 |
+
def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, use_external_kv=False):
|
1365 |
+
attn_procs = {}
|
1366 |
+
unet_sd = unet.state_dict()
|
1367 |
+
for name in unet.attn_processors.keys():
|
1368 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
1369 |
+
if name.startswith("mid_block"):
|
1370 |
+
hidden_size = unet.config.block_out_channels[-1]
|
1371 |
+
elif name.startswith("up_blocks"):
|
1372 |
+
block_id = int(name[len("up_blocks.")])
|
1373 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
1374 |
+
elif name.startswith("down_blocks"):
|
1375 |
+
block_id = int(name[len("down_blocks.")])
|
1376 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
1377 |
+
if cross_attention_dim is None:
|
1378 |
+
if use_external_kv:
|
1379 |
+
attn_procs[name] = AdditiveKV_AttnProcessor2_0(
|
1380 |
+
hidden_size=hidden_size,
|
1381 |
+
cross_attention_dim=cross_attention_dim,
|
1382 |
+
time_embedding_dim=1280,
|
1383 |
+
) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
|
1384 |
+
else:
|
1385 |
+
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
1386 |
+
else:
|
1387 |
+
if use_adaln:
|
1388 |
+
layer_name = name.split(".processor")[0]
|
1389 |
+
if use_lcm:
|
1390 |
+
weights = {
|
1391 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.base_layer.weight"],
|
1392 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.base_layer.weight"],
|
1393 |
+
}
|
1394 |
+
else:
|
1395 |
+
weights = {
|
1396 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
1397 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
1398 |
+
}
|
1399 |
+
attn_procs[name] = TA_IPAttnProcessor2_0(
|
1400 |
+
hidden_size=hidden_size,
|
1401 |
+
cross_attention_dim=cross_attention_dim,
|
1402 |
+
num_tokens=ip_adapter_tokens,
|
1403 |
+
time_embedding_dim=1280,
|
1404 |
+
) if hasattr(F, "scaled_dot_product_attention") else \
|
1405 |
+
TA_IPAttnProcessor(
|
1406 |
+
hidden_size=hidden_size,
|
1407 |
+
cross_attention_dim=cross_attention_dim,
|
1408 |
+
num_tokens=ip_adapter_tokens,
|
1409 |
+
time_embedding_dim=1280,
|
1410 |
+
)
|
1411 |
+
attn_procs[name].load_state_dict(weights, strict=False)
|
1412 |
+
else:
|
1413 |
+
attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
1414 |
+
|
1415 |
+
return attn_procs
|
1416 |
+
|
1417 |
+
|
1418 |
+
def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False):
|
1419 |
+
attn_procs = {}
|
1420 |
+
unet_sd = unet.state_dict()
|
1421 |
+
for name in unet.attn_processors.keys():
|
1422 |
+
# get layer name and hidden dim
|
1423 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
1424 |
+
if name.startswith("mid_block"):
|
1425 |
+
hidden_size = unet.config.block_out_channels[-1]
|
1426 |
+
elif name.startswith("up_blocks"):
|
1427 |
+
block_id = int(name[len("up_blocks.")])
|
1428 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
1429 |
+
elif name.startswith("down_blocks"):
|
1430 |
+
block_id = int(name[len("down_blocks.")])
|
1431 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
1432 |
+
# init attn proc
|
1433 |
+
if split_attn:
|
1434 |
+
# layer_name = name.split(".processor")[0]
|
1435 |
+
# weights = {
|
1436 |
+
# "to_q_ref.weight": unet_sd[layer_name + ".to_q.weight"],
|
1437 |
+
# "to_k_ref.weight": unet_sd[layer_name + ".to_k.weight"],
|
1438 |
+
# "to_v_ref.weight": unet_sd[layer_name + ".to_v.weight"],
|
1439 |
+
# }
|
1440 |
+
attn_procs[name] = (
|
1441 |
+
sep_split_AttnProcessor2_0(
|
1442 |
+
hidden_size=hidden_size,
|
1443 |
+
cross_attention_dim=hidden_size,
|
1444 |
+
time_embedding_dim=1280,
|
1445 |
+
)
|
1446 |
+
if use_adaln
|
1447 |
+
else split_AttnProcessor2_0(
|
1448 |
+
hidden_size=hidden_size,
|
1449 |
+
cross_attention_dim=cross_attention_dim,
|
1450 |
+
time_embedding_dim=1280,
|
1451 |
+
)
|
1452 |
+
)
|
1453 |
+
# attn_procs[name].load_state_dict(weights, strict=False)
|
1454 |
+
else:
|
1455 |
+
attn_procs[name] = (
|
1456 |
+
AttnProcessor2_0(
|
1457 |
+
hidden_size=hidden_size,
|
1458 |
+
cross_attention_dim=hidden_size,
|
1459 |
+
)
|
1460 |
+
if hasattr(F, "scaled_dot_product_attention")
|
1461 |
+
else AttnProcessor(
|
1462 |
+
hidden_size=hidden_size,
|
1463 |
+
cross_attention_dim=hidden_size,
|
1464 |
+
)
|
1465 |
+
)
|
1466 |
+
|
1467 |
+
return attn_procs
|
module/ip_adapter/ip_adapter.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from typing import List
|
4 |
+
from collections import namedtuple, OrderedDict
|
5 |
+
|
6 |
+
def is_torch2_available():
|
7 |
+
return hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
8 |
+
|
9 |
+
if is_torch2_available():
|
10 |
+
from .attention_processor import (
|
11 |
+
AttnProcessor2_0 as AttnProcessor,
|
12 |
+
)
|
13 |
+
from .attention_processor import (
|
14 |
+
CNAttnProcessor2_0 as CNAttnProcessor,
|
15 |
+
)
|
16 |
+
from .attention_processor import (
|
17 |
+
IPAttnProcessor2_0 as IPAttnProcessor,
|
18 |
+
)
|
19 |
+
from .attention_processor import (
|
20 |
+
TA_IPAttnProcessor2_0 as TA_IPAttnProcessor,
|
21 |
+
)
|
22 |
+
else:
|
23 |
+
from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor, TA_IPAttnProcessor
|
24 |
+
|
25 |
+
|
26 |
+
class ImageProjModel(torch.nn.Module):
|
27 |
+
"""Projection Model"""
|
28 |
+
|
29 |
+
def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
self.cross_attention_dim = cross_attention_dim
|
33 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
34 |
+
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
35 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
36 |
+
|
37 |
+
def forward(self, image_embeds):
|
38 |
+
embeds = image_embeds
|
39 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
40 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
41 |
+
)
|
42 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
43 |
+
return clip_extra_context_tokens
|
44 |
+
|
45 |
+
|
46 |
+
class MLPProjModel(torch.nn.Module):
|
47 |
+
"""SD model with image prompt"""
|
48 |
+
def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
self.proj = torch.nn.Sequential(
|
52 |
+
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
|
53 |
+
torch.nn.GELU(),
|
54 |
+
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
|
55 |
+
torch.nn.LayerNorm(cross_attention_dim)
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, image_embeds):
|
59 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
60 |
+
return clip_extra_context_tokens
|
61 |
+
|
62 |
+
|
63 |
+
class MultiIPAdapterImageProjection(torch.nn.Module):
|
64 |
+
def __init__(self, IPAdapterImageProjectionLayers):
|
65 |
+
super().__init__()
|
66 |
+
self.image_projection_layers = torch.nn.ModuleList(IPAdapterImageProjectionLayers)
|
67 |
+
|
68 |
+
def forward(self, image_embeds: List[torch.FloatTensor]):
|
69 |
+
projected_image_embeds = []
|
70 |
+
|
71 |
+
# currently, we accept `image_embeds` as
|
72 |
+
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
|
73 |
+
# 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
|
74 |
+
if not isinstance(image_embeds, list):
|
75 |
+
image_embeds = [image_embeds.unsqueeze(1)]
|
76 |
+
|
77 |
+
if len(image_embeds) != len(self.image_projection_layers):
|
78 |
+
raise ValueError(
|
79 |
+
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
|
80 |
+
)
|
81 |
+
|
82 |
+
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
|
83 |
+
batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
|
84 |
+
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
|
85 |
+
image_embed = image_projection_layer(image_embed)
|
86 |
+
# image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
|
87 |
+
|
88 |
+
projected_image_embeds.append(image_embed)
|
89 |
+
|
90 |
+
return projected_image_embeds
|
91 |
+
|
92 |
+
|
93 |
+
class IPAdapter(torch.nn.Module):
|
94 |
+
"""IP-Adapter"""
|
95 |
+
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
96 |
+
super().__init__()
|
97 |
+
self.unet = unet
|
98 |
+
self.image_proj = image_proj_model
|
99 |
+
self.ip_adapter = adapter_modules
|
100 |
+
|
101 |
+
if ckpt_path is not None:
|
102 |
+
self.load_from_checkpoint(ckpt_path)
|
103 |
+
|
104 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
105 |
+
ip_tokens = self.image_proj(image_embeds)
|
106 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
107 |
+
# Predict the noise residual
|
108 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
109 |
+
return noise_pred
|
110 |
+
|
111 |
+
def load_from_checkpoint(self, ckpt_path: str):
|
112 |
+
# Calculate original checksums
|
113 |
+
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
|
114 |
+
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
|
115 |
+
|
116 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
117 |
+
keys = list(state_dict.keys())
|
118 |
+
if keys != ["image_proj", "ip_adapter"]:
|
119 |
+
state_dict = revise_state_dict(state_dict)
|
120 |
+
|
121 |
+
# Load state dict for image_proj_model and adapter_modules
|
122 |
+
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
|
123 |
+
self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=True)
|
124 |
+
|
125 |
+
# Calculate new checksums
|
126 |
+
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
|
127 |
+
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
|
128 |
+
|
129 |
+
# Verify if the weights have changed
|
130 |
+
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
131 |
+
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
132 |
+
|
133 |
+
|
134 |
+
class IPAdapterPlus(torch.nn.Module):
|
135 |
+
"""IP-Adapter"""
|
136 |
+
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
137 |
+
super().__init__()
|
138 |
+
self.unet = unet
|
139 |
+
self.image_proj = image_proj_model
|
140 |
+
self.ip_adapter = adapter_modules
|
141 |
+
|
142 |
+
if ckpt_path is not None:
|
143 |
+
self.load_from_checkpoint(ckpt_path)
|
144 |
+
|
145 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
146 |
+
ip_tokens = self.image_proj(image_embeds)
|
147 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
148 |
+
# Predict the noise residual
|
149 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
150 |
+
return noise_pred
|
151 |
+
|
152 |
+
def load_from_checkpoint(self, ckpt_path: str):
|
153 |
+
# Calculate original checksums
|
154 |
+
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
|
155 |
+
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
|
156 |
+
org_unet_sum = []
|
157 |
+
for attn_name, attn_proc in self.unet.attn_processors.items():
|
158 |
+
if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
|
159 |
+
org_unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
|
160 |
+
org_unet_sum = torch.sum(torch.stack(org_unet_sum))
|
161 |
+
|
162 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
163 |
+
keys = list(state_dict.keys())
|
164 |
+
if keys != ["image_proj", "ip_adapter"]:
|
165 |
+
state_dict = revise_state_dict(state_dict)
|
166 |
+
|
167 |
+
# Check if 'latents' exists in both the saved state_dict and the current model's state_dict
|
168 |
+
strict_load_image_proj_model = True
|
169 |
+
if "latents" in state_dict["image_proj"] and "latents" in self.image_proj.state_dict():
|
170 |
+
# Check if the shapes are mismatched
|
171 |
+
if state_dict["image_proj"]["latents"].shape != self.image_proj.state_dict()["latents"].shape:
|
172 |
+
print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
|
173 |
+
print("Removing 'latents' from checkpoint and loading the rest of the weights.")
|
174 |
+
del state_dict["image_proj"]["latents"]
|
175 |
+
strict_load_image_proj_model = False
|
176 |
+
|
177 |
+
# Load state dict for image_proj_model and adapter_modules
|
178 |
+
self.image_proj.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
|
179 |
+
missing_key, unexpected_key = self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=False)
|
180 |
+
if len(missing_key) > 0:
|
181 |
+
for ms in missing_key:
|
182 |
+
if "ln" not in ms:
|
183 |
+
raise ValueError(f"Missing key in adapter_modules: {len(missing_key)}")
|
184 |
+
if len(unexpected_key) > 0:
|
185 |
+
raise ValueError(f"Unexpected key in adapter_modules: {len(unexpected_key)}")
|
186 |
+
|
187 |
+
# Calculate new checksums
|
188 |
+
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
|
189 |
+
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
|
190 |
+
|
191 |
+
# Verify if the weights loaded to unet
|
192 |
+
unet_sum = []
|
193 |
+
for attn_name, attn_proc in self.unet.attn_processors.items():
|
194 |
+
if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
|
195 |
+
unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
|
196 |
+
unet_sum = torch.sum(torch.stack(unet_sum))
|
197 |
+
|
198 |
+
assert org_unet_sum != unet_sum, "Weights of adapter_modules in unet did not change!"
|
199 |
+
assert (unet_sum - new_adapter_sum < 1e-4), "Weights of adapter_modules did not load to unet!"
|
200 |
+
|
201 |
+
# Verify if the weights have changed
|
202 |
+
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
203 |
+
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_mod`ules did not change!"
|
204 |
+
|
205 |
+
|
206 |
+
class IPAdapterXL(IPAdapter):
|
207 |
+
"""SDXL"""
|
208 |
+
|
209 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
|
210 |
+
ip_tokens = self.image_proj(image_embeds)
|
211 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
212 |
+
# Predict the noise residual
|
213 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
|
214 |
+
return noise_pred
|
215 |
+
|
216 |
+
|
217 |
+
class IPAdapterPlusXL(IPAdapterPlus):
|
218 |
+
"""IP-Adapter with fine-grained features"""
|
219 |
+
|
220 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
|
221 |
+
ip_tokens = self.image_proj(image_embeds)
|
222 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
223 |
+
# Predict the noise residual
|
224 |
+
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
|
225 |
+
return noise_pred
|
226 |
+
|
227 |
+
|
228 |
+
class IPAdapterFull(IPAdapterPlus):
|
229 |
+
"""IP-Adapter with full features"""
|
230 |
+
|
231 |
+
def init_proj(self):
|
232 |
+
image_proj_model = MLPProjModel(
|
233 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
234 |
+
clip_embeddings_dim=self.image_encoder.config.hidden_size,
|
235 |
+
).to(self.device, dtype=torch.float16)
|
236 |
+
return image_proj_model
|
module/ip_adapter/resampler.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
|
11 |
+
|
12 |
+
# FFN
|
13 |
+
def FeedForward(dim, mult=4):
|
14 |
+
inner_dim = int(dim * mult)
|
15 |
+
return nn.Sequential(
|
16 |
+
nn.LayerNorm(dim),
|
17 |
+
nn.Linear(dim, inner_dim, bias=False),
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(inner_dim, dim, bias=False),
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def reshape_tensor(x, heads):
|
24 |
+
bs, length, width = x.shape
|
25 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
26 |
+
x = x.view(bs, length, heads, -1)
|
27 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
28 |
+
x = x.transpose(1, 2)
|
29 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
30 |
+
x = x.reshape(bs, heads, length, -1)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class PerceiverAttention(nn.Module):
|
35 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
36 |
+
super().__init__()
|
37 |
+
self.scale = dim_head**-0.5
|
38 |
+
self.dim_head = dim_head
|
39 |
+
self.heads = heads
|
40 |
+
inner_dim = dim_head * heads
|
41 |
+
|
42 |
+
self.norm1 = nn.LayerNorm(dim)
|
43 |
+
self.norm2 = nn.LayerNorm(dim)
|
44 |
+
|
45 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
46 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
47 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
48 |
+
|
49 |
+
def forward(self, x, latents):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): image features
|
53 |
+
shape (b, n1, D)
|
54 |
+
latent (torch.Tensor): latent features
|
55 |
+
shape (b, n2, D)
|
56 |
+
"""
|
57 |
+
x = self.norm1(x)
|
58 |
+
latents = self.norm2(latents)
|
59 |
+
|
60 |
+
b, l, _ = latents.shape
|
61 |
+
|
62 |
+
q = self.to_q(latents)
|
63 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
64 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
65 |
+
|
66 |
+
q = reshape_tensor(q, self.heads)
|
67 |
+
k = reshape_tensor(k, self.heads)
|
68 |
+
v = reshape_tensor(v, self.heads)
|
69 |
+
|
70 |
+
# attention
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
72 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
73 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
74 |
+
out = weight @ v
|
75 |
+
|
76 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
77 |
+
|
78 |
+
return self.to_out(out)
|
79 |
+
|
80 |
+
|
81 |
+
class Resampler(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
dim=1280,
|
85 |
+
depth=4,
|
86 |
+
dim_head=64,
|
87 |
+
heads=20,
|
88 |
+
num_queries=64,
|
89 |
+
embedding_dim=768,
|
90 |
+
output_dim=1024,
|
91 |
+
ff_mult=4,
|
92 |
+
max_seq_len: int = 257, # CLIP tokens + CLS token
|
93 |
+
apply_pos_emb: bool = False,
|
94 |
+
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
98 |
+
|
99 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
100 |
+
|
101 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
102 |
+
|
103 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
104 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
105 |
+
|
106 |
+
self.to_latents_from_mean_pooled_seq = (
|
107 |
+
nn.Sequential(
|
108 |
+
nn.LayerNorm(dim),
|
109 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
110 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
111 |
+
)
|
112 |
+
if num_latents_mean_pooled > 0
|
113 |
+
else None
|
114 |
+
)
|
115 |
+
|
116 |
+
self.layers = nn.ModuleList([])
|
117 |
+
for _ in range(depth):
|
118 |
+
self.layers.append(
|
119 |
+
nn.ModuleList(
|
120 |
+
[
|
121 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
122 |
+
FeedForward(dim=dim, mult=ff_mult),
|
123 |
+
]
|
124 |
+
)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
if self.pos_emb is not None:
|
129 |
+
n, device = x.shape[1], x.device
|
130 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
131 |
+
x = x + pos_emb
|
132 |
+
|
133 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
134 |
+
|
135 |
+
x = self.proj_in(x)
|
136 |
+
|
137 |
+
if self.to_latents_from_mean_pooled_seq:
|
138 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
139 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
140 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
141 |
+
|
142 |
+
for attn, ff in self.layers:
|
143 |
+
latents = attn(x, latents) + latents
|
144 |
+
latents = ff(latents) + latents
|
145 |
+
|
146 |
+
latents = self.proj_out(latents)
|
147 |
+
return self.norm_out(latents)
|
148 |
+
|
149 |
+
|
150 |
+
def masked_mean(t, *, dim, mask=None):
|
151 |
+
if mask is None:
|
152 |
+
return t.mean(dim=dim)
|
153 |
+
|
154 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
155 |
+
mask = rearrange(mask, "b n -> b n 1")
|
156 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
157 |
+
|
158 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
module/ip_adapter/utils.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import namedtuple, OrderedDict
|
3 |
+
from safetensors import safe_open
|
4 |
+
from .attention_processor import init_attn_proc
|
5 |
+
from .ip_adapter import MultiIPAdapterImageProjection
|
6 |
+
from .resampler import Resampler
|
7 |
+
from transformers import (
|
8 |
+
AutoModel, AutoImageProcessor,
|
9 |
+
CLIPVisionModelWithProjection, CLIPImageProcessor)
|
10 |
+
|
11 |
+
|
12 |
+
def init_adapter_in_unet(
|
13 |
+
unet,
|
14 |
+
image_proj_model=None,
|
15 |
+
pretrained_model_path_or_dict=None,
|
16 |
+
adapter_tokens=64,
|
17 |
+
embedding_dim=None,
|
18 |
+
use_lcm=False,
|
19 |
+
use_adaln=True,
|
20 |
+
):
|
21 |
+
device = unet.device
|
22 |
+
dtype = unet.dtype
|
23 |
+
if image_proj_model is None:
|
24 |
+
assert embedding_dim is not None, "embedding_dim must be provided if image_proj_model is None."
|
25 |
+
image_proj_model = Resampler(
|
26 |
+
embedding_dim=embedding_dim,
|
27 |
+
output_dim=unet.config.cross_attention_dim,
|
28 |
+
num_queries=adapter_tokens,
|
29 |
+
)
|
30 |
+
if pretrained_model_path_or_dict is not None:
|
31 |
+
if not isinstance(pretrained_model_path_or_dict, dict):
|
32 |
+
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
33 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
34 |
+
with safe_open(pretrained_model_path_or_dict, framework="pt", device=unet.device) as f:
|
35 |
+
for key in f.keys():
|
36 |
+
if key.startswith("image_proj."):
|
37 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
38 |
+
elif key.startswith("ip_adapter."):
|
39 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
40 |
+
else:
|
41 |
+
state_dict = torch.load(pretrained_model_path_or_dict, map_location=unet.device)
|
42 |
+
else:
|
43 |
+
state_dict = pretrained_model_path_or_dict
|
44 |
+
keys = list(state_dict.keys())
|
45 |
+
if "image_proj" not in keys and "ip_adapter" not in keys:
|
46 |
+
state_dict = revise_state_dict(state_dict)
|
47 |
+
|
48 |
+
# Creat IP cross-attention in unet.
|
49 |
+
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
50 |
+
unet.set_attn_processor(attn_procs)
|
51 |
+
|
52 |
+
# Load pretrinaed model if needed.
|
53 |
+
if pretrained_model_path_or_dict is not None:
|
54 |
+
if "ip_adapter" in state_dict.keys():
|
55 |
+
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
56 |
+
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
57 |
+
for mk in missing:
|
58 |
+
if "ln" not in mk:
|
59 |
+
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
60 |
+
if "image_proj" in state_dict.keys():
|
61 |
+
image_proj_model.load_state_dict(state_dict["image_proj"])
|
62 |
+
|
63 |
+
# Load image projectors into iterable ModuleList.
|
64 |
+
image_projection_layers = []
|
65 |
+
image_projection_layers.append(image_proj_model)
|
66 |
+
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
67 |
+
|
68 |
+
# Adjust unet config to handle addtional ip hidden states.
|
69 |
+
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
70 |
+
unet.to(dtype=dtype, device=device)
|
71 |
+
|
72 |
+
|
73 |
+
def load_adapter_to_pipe(
|
74 |
+
pipe,
|
75 |
+
pretrained_model_path_or_dict,
|
76 |
+
image_encoder_or_path=None,
|
77 |
+
feature_extractor_or_path=None,
|
78 |
+
use_clip_encoder=False,
|
79 |
+
adapter_tokens=64,
|
80 |
+
use_lcm=False,
|
81 |
+
use_adaln=True,
|
82 |
+
):
|
83 |
+
|
84 |
+
if not isinstance(pretrained_model_path_or_dict, dict):
|
85 |
+
if pretrained_model_path_or_dict.endswith(".safetensors"):
|
86 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
87 |
+
with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
|
88 |
+
for key in f.keys():
|
89 |
+
if key.startswith("image_proj."):
|
90 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
91 |
+
elif key.startswith("ip_adapter."):
|
92 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
93 |
+
else:
|
94 |
+
state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
|
95 |
+
else:
|
96 |
+
state_dict = pretrained_model_path_or_dict
|
97 |
+
keys = list(state_dict.keys())
|
98 |
+
if "image_proj" not in keys and "ip_adapter" not in keys:
|
99 |
+
state_dict = revise_state_dict(state_dict)
|
100 |
+
|
101 |
+
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
102 |
+
if image_encoder_or_path is not None:
|
103 |
+
if isinstance(image_encoder_or_path, str):
|
104 |
+
feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
|
105 |
+
|
106 |
+
image_encoder_or_path = (
|
107 |
+
CLIPVisionModelWithProjection.from_pretrained(
|
108 |
+
image_encoder_or_path
|
109 |
+
) if use_clip_encoder else
|
110 |
+
AutoModel.from_pretrained(image_encoder_or_path)
|
111 |
+
)
|
112 |
+
|
113 |
+
if feature_extractor_or_path is not None:
|
114 |
+
if isinstance(feature_extractor_or_path, str):
|
115 |
+
feature_extractor_or_path = (
|
116 |
+
CLIPImageProcessor() if use_clip_encoder else
|
117 |
+
AutoImageProcessor.from_pretrained(feature_extractor_or_path)
|
118 |
+
)
|
119 |
+
|
120 |
+
# create image encoder if it has not been registered to the pipeline yet
|
121 |
+
if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
|
122 |
+
image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
|
123 |
+
pipe.register_modules(image_encoder=image_encoder)
|
124 |
+
else:
|
125 |
+
image_encoder = pipe.image_encoder
|
126 |
+
|
127 |
+
# create feature extractor if it has not been registered to the pipeline yet
|
128 |
+
if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
|
129 |
+
feature_extractor = feature_extractor_or_path
|
130 |
+
pipe.register_modules(feature_extractor=feature_extractor)
|
131 |
+
else:
|
132 |
+
feature_extractor = pipe.feature_extractor
|
133 |
+
|
134 |
+
# load adapter into unet
|
135 |
+
unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
|
136 |
+
attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
|
137 |
+
unet.set_attn_processor(attn_procs)
|
138 |
+
image_proj_model = Resampler(
|
139 |
+
embedding_dim=image_encoder.config.hidden_size,
|
140 |
+
output_dim=unet.config.cross_attention_dim,
|
141 |
+
num_queries=adapter_tokens,
|
142 |
+
)
|
143 |
+
|
144 |
+
# Load pretrinaed model if needed.
|
145 |
+
if "ip_adapter" in state_dict.keys():
|
146 |
+
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
147 |
+
missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
|
148 |
+
for mk in missing:
|
149 |
+
if "ln" not in mk:
|
150 |
+
raise ValueError(f"Missing keys in adapter_modules: {missing}")
|
151 |
+
if "image_proj" in state_dict.keys():
|
152 |
+
image_proj_model.load_state_dict(state_dict["image_proj"])
|
153 |
+
|
154 |
+
# convert IP-Adapter Image Projection layers to diffusers
|
155 |
+
image_projection_layers = []
|
156 |
+
image_projection_layers.append(image_proj_model)
|
157 |
+
unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
158 |
+
|
159 |
+
# Adjust unet config to handle addtional ip hidden states.
|
160 |
+
unet.config.encoder_hid_dim_type = "ip_image_proj"
|
161 |
+
unet.to(dtype=pipe.dtype, device=pipe.device)
|
162 |
+
|
163 |
+
|
164 |
+
def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
|
165 |
+
new_state_dict = OrderedDict()
|
166 |
+
new_state_dict["image_proj"] = OrderedDict()
|
167 |
+
new_state_dict["ip_adapter"] = OrderedDict()
|
168 |
+
if isinstance(old_state_dict_or_path, str):
|
169 |
+
old_state_dict = torch.load(old_state_dict_or_path, map_location=map_location)
|
170 |
+
else:
|
171 |
+
old_state_dict = old_state_dict_or_path
|
172 |
+
for name, weight in old_state_dict.items():
|
173 |
+
if name.startswith("image_proj_model."):
|
174 |
+
new_state_dict["image_proj"][name[len("image_proj_model."):]] = weight
|
175 |
+
elif name.startswith("adapter_modules."):
|
176 |
+
new_state_dict["ip_adapter"][name[len("adapter_modules."):]] = weight
|
177 |
+
return new_state_dict
|
178 |
+
|
179 |
+
|
180 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
181 |
+
def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None):
|
182 |
+
dtype = next(image_encoder.parameters()).dtype
|
183 |
+
|
184 |
+
if not isinstance(image, torch.Tensor):
|
185 |
+
image = feature_extractor(image, return_tensors="pt").pixel_values
|
186 |
+
|
187 |
+
image = image.to(device=device, dtype=dtype)
|
188 |
+
if output_hidden_states:
|
189 |
+
image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
190 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
191 |
+
return image_enc_hidden_states
|
192 |
+
else:
|
193 |
+
if isinstance(image_encoder, CLIPVisionModelWithProjection):
|
194 |
+
# CLIP image encoder.
|
195 |
+
image_embeds = image_encoder(image).image_embeds
|
196 |
+
else:
|
197 |
+
# DINO image encoder.
|
198 |
+
image_embeds = image_encoder(image).last_hidden_state
|
199 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
200 |
+
return image_embeds
|
201 |
+
|
202 |
+
|
203 |
+
def prepare_training_image_embeds(
|
204 |
+
image_encoder, feature_extractor,
|
205 |
+
ip_adapter_image, ip_adapter_image_embeds,
|
206 |
+
device, drop_rate, output_hidden_state, idx_to_replace=None
|
207 |
+
):
|
208 |
+
if ip_adapter_image_embeds is None:
|
209 |
+
if not isinstance(ip_adapter_image, list):
|
210 |
+
ip_adapter_image = [ip_adapter_image]
|
211 |
+
|
212 |
+
# if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
|
213 |
+
# raise ValueError(
|
214 |
+
# f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
215 |
+
# )
|
216 |
+
|
217 |
+
image_embeds = []
|
218 |
+
for single_ip_adapter_image in ip_adapter_image:
|
219 |
+
if idx_to_replace is None:
|
220 |
+
idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
|
221 |
+
zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
|
222 |
+
single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
|
223 |
+
single_image_embeds = encode_image(
|
224 |
+
image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
|
225 |
+
)
|
226 |
+
single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME
|
227 |
+
|
228 |
+
image_embeds.append(single_image_embeds)
|
229 |
+
else:
|
230 |
+
repeat_dims = [1]
|
231 |
+
image_embeds = []
|
232 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
233 |
+
if do_classifier_free_guidance:
|
234 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
235 |
+
single_image_embeds = single_image_embeds.repeat(
|
236 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
237 |
+
)
|
238 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
239 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
240 |
+
)
|
241 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
242 |
+
else:
|
243 |
+
single_image_embeds = single_image_embeds.repeat(
|
244 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
245 |
+
)
|
246 |
+
image_embeds.append(single_image_embeds)
|
247 |
+
|
248 |
+
return image_embeds
|
module/min_sdxl.py
ADDED
@@ -0,0 +1,915 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from minSDXL by Simo Ryu:
|
2 |
+
# https://github.com/cloneofsimo/minSDXL ,
|
3 |
+
# which is in turn modified from the original code of:
|
4 |
+
# https://github.com/huggingface/diffusers
|
5 |
+
# So has APACHE 2.0 license
|
6 |
+
|
7 |
+
from typing import Optional, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import math
|
13 |
+
import inspect
|
14 |
+
|
15 |
+
from collections import namedtuple
|
16 |
+
|
17 |
+
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
18 |
+
|
19 |
+
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
|
20 |
+
|
21 |
+
# Implementation of FreeU for minsdxl
|
22 |
+
|
23 |
+
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
|
24 |
+
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
25 |
+
|
26 |
+
This version of the method comes from here:
|
27 |
+
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
|
28 |
+
"""
|
29 |
+
x = x_in
|
30 |
+
B, C, H, W = x.shape
|
31 |
+
|
32 |
+
# Non-power of 2 images must be float32
|
33 |
+
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
|
34 |
+
x = x.to(dtype=torch.float32)
|
35 |
+
|
36 |
+
# FFT
|
37 |
+
x_freq = fftn(x, dim=(-2, -1))
|
38 |
+
x_freq = fftshift(x_freq, dim=(-2, -1))
|
39 |
+
|
40 |
+
B, C, H, W = x_freq.shape
|
41 |
+
mask = torch.ones((B, C, H, W), device=x.device)
|
42 |
+
|
43 |
+
crow, ccol = H // 2, W // 2
|
44 |
+
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
|
45 |
+
x_freq = x_freq * mask
|
46 |
+
|
47 |
+
# IFFT
|
48 |
+
x_freq = ifftshift(x_freq, dim=(-2, -1))
|
49 |
+
x_filtered = ifftn(x_freq, dim=(-2, -1)).real
|
50 |
+
|
51 |
+
return x_filtered.to(dtype=x_in.dtype)
|
52 |
+
|
53 |
+
|
54 |
+
def apply_freeu(
|
55 |
+
resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs):
|
56 |
+
"""Applies the FreeU mechanism as introduced in https:
|
57 |
+
//arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
|
61 |
+
hidden_states (`torch.Tensor`): Inputs to the underlying block.
|
62 |
+
res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
|
63 |
+
s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
|
64 |
+
s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
|
65 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
66 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
67 |
+
"""
|
68 |
+
if resolution_idx == 0:
|
69 |
+
num_half_channels = hidden_states.shape[1] // 2
|
70 |
+
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
|
71 |
+
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
|
72 |
+
if resolution_idx == 1:
|
73 |
+
num_half_channels = hidden_states.shape[1] // 2
|
74 |
+
hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
|
75 |
+
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
|
76 |
+
|
77 |
+
return hidden_states, res_hidden_states
|
78 |
+
|
79 |
+
# Diffusers-style LoRA to keep everything in the min_sdxl.py file
|
80 |
+
|
81 |
+
class LoRALinearLayer(nn.Module):
|
82 |
+
r"""
|
83 |
+
A linear layer that is used with LoRA.
|
84 |
+
|
85 |
+
Parameters:
|
86 |
+
in_features (`int`):
|
87 |
+
Number of input features.
|
88 |
+
out_features (`int`):
|
89 |
+
Number of output features.
|
90 |
+
rank (`int`, `optional`, defaults to 4):
|
91 |
+
The rank of the LoRA layer.
|
92 |
+
network_alpha (`float`, `optional`, defaults to `None`):
|
93 |
+
The value of the network alpha used for stable learning and preventing underflow. This value has the same
|
94 |
+
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
|
95 |
+
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
96 |
+
device (`torch.device`, `optional`, defaults to `None`):
|
97 |
+
The device to use for the layer's weights.
|
98 |
+
dtype (`torch.dtype`, `optional`, defaults to `None`):
|
99 |
+
The dtype to use for the layer's weights.
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
in_features: int,
|
105 |
+
out_features: int,
|
106 |
+
rank: int = 4,
|
107 |
+
network_alpha: Optional[float] = None,
|
108 |
+
device: Optional[Union[torch.device, str]] = None,
|
109 |
+
dtype: Optional[torch.dtype] = None,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
|
113 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
114 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
115 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
116 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
117 |
+
self.network_alpha = network_alpha
|
118 |
+
self.rank = rank
|
119 |
+
self.out_features = out_features
|
120 |
+
self.in_features = in_features
|
121 |
+
|
122 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
123 |
+
nn.init.zeros_(self.up.weight)
|
124 |
+
|
125 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
126 |
+
orig_dtype = hidden_states.dtype
|
127 |
+
dtype = self.down.weight.dtype
|
128 |
+
|
129 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
130 |
+
up_hidden_states = self.up(down_hidden_states)
|
131 |
+
|
132 |
+
if self.network_alpha is not None:
|
133 |
+
up_hidden_states *= self.network_alpha / self.rank
|
134 |
+
|
135 |
+
return up_hidden_states.to(orig_dtype)
|
136 |
+
|
137 |
+
class LoRACompatibleLinear(nn.Linear):
|
138 |
+
"""
|
139 |
+
A Linear layer that can be used with LoRA.
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
|
143 |
+
super().__init__(*args, **kwargs)
|
144 |
+
self.lora_layer = lora_layer
|
145 |
+
|
146 |
+
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
147 |
+
self.lora_layer = lora_layer
|
148 |
+
|
149 |
+
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
150 |
+
if self.lora_layer is None:
|
151 |
+
return
|
152 |
+
|
153 |
+
dtype, device = self.weight.data.dtype, self.weight.data.device
|
154 |
+
|
155 |
+
w_orig = self.weight.data.float()
|
156 |
+
w_up = self.lora_layer.up.weight.data.float()
|
157 |
+
w_down = self.lora_layer.down.weight.data.float()
|
158 |
+
|
159 |
+
if self.lora_layer.network_alpha is not None:
|
160 |
+
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
161 |
+
|
162 |
+
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
163 |
+
|
164 |
+
if safe_fusing and torch.isnan(fused_weight).any().item():
|
165 |
+
raise ValueError(
|
166 |
+
"This LoRA weight seems to be broken. "
|
167 |
+
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
168 |
+
"LoRA weights will not be fused."
|
169 |
+
)
|
170 |
+
|
171 |
+
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
172 |
+
|
173 |
+
# we can drop the lora layer now
|
174 |
+
self.lora_layer = None
|
175 |
+
|
176 |
+
# offload the up and down matrices to CPU to not blow the memory
|
177 |
+
self.w_up = w_up.cpu()
|
178 |
+
self.w_down = w_down.cpu()
|
179 |
+
self._lora_scale = lora_scale
|
180 |
+
|
181 |
+
def _unfuse_lora(self):
|
182 |
+
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
|
183 |
+
return
|
184 |
+
|
185 |
+
fused_weight = self.weight.data
|
186 |
+
dtype, device = fused_weight.dtype, fused_weight.device
|
187 |
+
|
188 |
+
w_up = self.w_up.to(device=device).float()
|
189 |
+
w_down = self.w_down.to(device).float()
|
190 |
+
|
191 |
+
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
192 |
+
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
193 |
+
|
194 |
+
self.w_up = None
|
195 |
+
self.w_down = None
|
196 |
+
|
197 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
198 |
+
if self.lora_layer is None:
|
199 |
+
out = super().forward(hidden_states)
|
200 |
+
return out
|
201 |
+
else:
|
202 |
+
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
203 |
+
return out
|
204 |
+
|
205 |
+
class Timesteps(nn.Module):
|
206 |
+
def __init__(self, num_channels: int = 320):
|
207 |
+
super().__init__()
|
208 |
+
self.num_channels = num_channels
|
209 |
+
|
210 |
+
def forward(self, timesteps):
|
211 |
+
half_dim = self.num_channels // 2
|
212 |
+
exponent = -math.log(10000) * torch.arange(
|
213 |
+
half_dim, dtype=torch.float32, device=timesteps.device
|
214 |
+
)
|
215 |
+
exponent = exponent / (half_dim - 0.0)
|
216 |
+
|
217 |
+
emb = torch.exp(exponent)
|
218 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
219 |
+
|
220 |
+
sin_emb = torch.sin(emb)
|
221 |
+
cos_emb = torch.cos(emb)
|
222 |
+
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
223 |
+
|
224 |
+
return emb
|
225 |
+
|
226 |
+
|
227 |
+
class TimestepEmbedding(nn.Module):
|
228 |
+
def __init__(self, in_features, out_features):
|
229 |
+
super(TimestepEmbedding, self).__init__()
|
230 |
+
self.linear_1 = nn.Linear(in_features, out_features, bias=True)
|
231 |
+
self.act = nn.SiLU()
|
232 |
+
self.linear_2 = nn.Linear(out_features, out_features, bias=True)
|
233 |
+
|
234 |
+
def forward(self, sample):
|
235 |
+
sample = self.linear_1(sample)
|
236 |
+
sample = self.act(sample)
|
237 |
+
sample = self.linear_2(sample)
|
238 |
+
|
239 |
+
return sample
|
240 |
+
|
241 |
+
|
242 |
+
class ResnetBlock2D(nn.Module):
|
243 |
+
def __init__(self, in_channels, out_channels, conv_shortcut=True):
|
244 |
+
super(ResnetBlock2D, self).__init__()
|
245 |
+
self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True)
|
246 |
+
self.conv1 = nn.Conv2d(
|
247 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
248 |
+
)
|
249 |
+
self.time_emb_proj = nn.Linear(1280, out_channels, bias=True)
|
250 |
+
self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True)
|
251 |
+
self.dropout = nn.Dropout(p=0.0, inplace=False)
|
252 |
+
self.conv2 = nn.Conv2d(
|
253 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
254 |
+
)
|
255 |
+
self.nonlinearity = nn.SiLU()
|
256 |
+
self.conv_shortcut = None
|
257 |
+
if conv_shortcut:
|
258 |
+
self.conv_shortcut = nn.Conv2d(
|
259 |
+
in_channels, out_channels, kernel_size=1, stride=1
|
260 |
+
)
|
261 |
+
|
262 |
+
def forward(self, input_tensor, temb):
|
263 |
+
hidden_states = input_tensor
|
264 |
+
hidden_states = self.norm1(hidden_states)
|
265 |
+
hidden_states = self.nonlinearity(hidden_states)
|
266 |
+
|
267 |
+
hidden_states = self.conv1(hidden_states)
|
268 |
+
|
269 |
+
temb = self.nonlinearity(temb)
|
270 |
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
271 |
+
hidden_states = hidden_states + temb
|
272 |
+
hidden_states = self.norm2(hidden_states)
|
273 |
+
|
274 |
+
hidden_states = self.nonlinearity(hidden_states)
|
275 |
+
hidden_states = self.dropout(hidden_states)
|
276 |
+
hidden_states = self.conv2(hidden_states)
|
277 |
+
|
278 |
+
if self.conv_shortcut is not None:
|
279 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
280 |
+
|
281 |
+
output_tensor = input_tensor + hidden_states
|
282 |
+
|
283 |
+
return output_tensor
|
284 |
+
|
285 |
+
|
286 |
+
class Attention(nn.Module):
|
287 |
+
def __init__(
|
288 |
+
self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0, processor=None, scale_qk=True
|
289 |
+
):
|
290 |
+
super(Attention, self).__init__()
|
291 |
+
if num_heads is None:
|
292 |
+
self.head_dim = 64
|
293 |
+
self.num_heads = inner_dim // self.head_dim
|
294 |
+
else:
|
295 |
+
self.num_heads = num_heads
|
296 |
+
self.head_dim = inner_dim // num_heads
|
297 |
+
|
298 |
+
self.scale = self.head_dim**-0.5
|
299 |
+
if cross_attention_dim is None:
|
300 |
+
cross_attention_dim = inner_dim
|
301 |
+
self.to_q = LoRACompatibleLinear(inner_dim, inner_dim, bias=False)
|
302 |
+
self.to_k = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
|
303 |
+
self.to_v = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
|
304 |
+
|
305 |
+
self.to_out = nn.ModuleList(
|
306 |
+
[LoRACompatibleLinear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]
|
307 |
+
)
|
308 |
+
|
309 |
+
self.scale_qk = scale_qk
|
310 |
+
if processor is None:
|
311 |
+
processor = (
|
312 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
313 |
+
)
|
314 |
+
self.set_processor(processor)
|
315 |
+
|
316 |
+
def forward(
|
317 |
+
self,
|
318 |
+
hidden_states: torch.FloatTensor,
|
319 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
320 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
321 |
+
**cross_attention_kwargs,
|
322 |
+
) -> torch.Tensor:
|
323 |
+
r"""
|
324 |
+
The forward method of the `Attention` class.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
hidden_states (`torch.Tensor`):
|
328 |
+
The hidden states of the query.
|
329 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
330 |
+
The hidden states of the encoder.
|
331 |
+
attention_mask (`torch.Tensor`, *optional*):
|
332 |
+
The attention mask to use. If `None`, no mask is applied.
|
333 |
+
**cross_attention_kwargs:
|
334 |
+
Additional keyword arguments to pass along to the cross attention.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
`torch.Tensor`: The output of the attention layer.
|
338 |
+
"""
|
339 |
+
# The `Attention` class can call different attention processors / attention functions
|
340 |
+
# here we simply pass along all tensors to the selected processor class
|
341 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
342 |
+
|
343 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
344 |
+
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
|
345 |
+
if len(unused_kwargs) > 0:
|
346 |
+
print(
|
347 |
+
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
348 |
+
)
|
349 |
+
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
350 |
+
|
351 |
+
return self.processor(
|
352 |
+
self,
|
353 |
+
hidden_states,
|
354 |
+
encoder_hidden_states=encoder_hidden_states,
|
355 |
+
attention_mask=attention_mask,
|
356 |
+
**cross_attention_kwargs,
|
357 |
+
)
|
358 |
+
|
359 |
+
def orig_forward(self, hidden_states, encoder_hidden_states=None):
|
360 |
+
q = self.to_q(hidden_states)
|
361 |
+
k = (
|
362 |
+
self.to_k(encoder_hidden_states)
|
363 |
+
if encoder_hidden_states is not None
|
364 |
+
else self.to_k(hidden_states)
|
365 |
+
)
|
366 |
+
v = (
|
367 |
+
self.to_v(encoder_hidden_states)
|
368 |
+
if encoder_hidden_states is not None
|
369 |
+
else self.to_v(hidden_states)
|
370 |
+
)
|
371 |
+
b, t, c = q.size()
|
372 |
+
|
373 |
+
q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
|
374 |
+
k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
|
375 |
+
v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)
|
376 |
+
|
377 |
+
# scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
378 |
+
# attn_weights = torch.softmax(scores, dim=-1)
|
379 |
+
# attn_output = torch.matmul(attn_weights, v)
|
380 |
+
|
381 |
+
attn_output = F.scaled_dot_product_attention(
|
382 |
+
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale,
|
383 |
+
)
|
384 |
+
|
385 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)
|
386 |
+
|
387 |
+
for layer in self.to_out:
|
388 |
+
attn_output = layer(attn_output)
|
389 |
+
|
390 |
+
return attn_output
|
391 |
+
|
392 |
+
def set_processor(self, processor) -> None:
|
393 |
+
r"""
|
394 |
+
Set the attention processor to use.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
processor (`AttnProcessor`):
|
398 |
+
The attention processor to use.
|
399 |
+
"""
|
400 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
401 |
+
# pop `processor` from `self._modules`
|
402 |
+
if (
|
403 |
+
hasattr(self, "processor")
|
404 |
+
and isinstance(self.processor, torch.nn.Module)
|
405 |
+
and not isinstance(processor, torch.nn.Module)
|
406 |
+
):
|
407 |
+
print(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
408 |
+
self._modules.pop("processor")
|
409 |
+
|
410 |
+
self.processor = processor
|
411 |
+
|
412 |
+
def get_processor(self, return_deprecated_lora: bool = False):
|
413 |
+
r"""
|
414 |
+
Get the attention processor in use.
|
415 |
+
|
416 |
+
Args:
|
417 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
418 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
"AttentionProcessor": The attention processor in use.
|
422 |
+
"""
|
423 |
+
if not return_deprecated_lora:
|
424 |
+
return self.processor
|
425 |
+
|
426 |
+
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
|
427 |
+
# serialization format for LoRA Attention Processors. It should be deleted once the integration
|
428 |
+
# with PEFT is completed.
|
429 |
+
is_lora_activated = {
|
430 |
+
name: module.lora_layer is not None
|
431 |
+
for name, module in self.named_modules()
|
432 |
+
if hasattr(module, "lora_layer")
|
433 |
+
}
|
434 |
+
|
435 |
+
# 1. if no layer has a LoRA activated we can return the processor as usual
|
436 |
+
if not any(is_lora_activated.values()):
|
437 |
+
return self.processor
|
438 |
+
|
439 |
+
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
|
440 |
+
is_lora_activated.pop("add_k_proj", None)
|
441 |
+
is_lora_activated.pop("add_v_proj", None)
|
442 |
+
# 2. else it is not possible that only some layers have LoRA activated
|
443 |
+
if not all(is_lora_activated.values()):
|
444 |
+
raise ValueError(
|
445 |
+
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
|
446 |
+
)
|
447 |
+
|
448 |
+
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
449 |
+
non_lora_processor_cls_name = self.processor.__class__.__name__
|
450 |
+
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
|
451 |
+
|
452 |
+
hidden_size = self.inner_dim
|
453 |
+
|
454 |
+
# now create a LoRA attention processor from the LoRA layers
|
455 |
+
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
|
456 |
+
kwargs = {
|
457 |
+
"cross_attention_dim": self.cross_attention_dim,
|
458 |
+
"rank": self.to_q.lora_layer.rank,
|
459 |
+
"network_alpha": self.to_q.lora_layer.network_alpha,
|
460 |
+
"q_rank": self.to_q.lora_layer.rank,
|
461 |
+
"q_hidden_size": self.to_q.lora_layer.out_features,
|
462 |
+
"k_rank": self.to_k.lora_layer.rank,
|
463 |
+
"k_hidden_size": self.to_k.lora_layer.out_features,
|
464 |
+
"v_rank": self.to_v.lora_layer.rank,
|
465 |
+
"v_hidden_size": self.to_v.lora_layer.out_features,
|
466 |
+
"out_rank": self.to_out[0].lora_layer.rank,
|
467 |
+
"out_hidden_size": self.to_out[0].lora_layer.out_features,
|
468 |
+
}
|
469 |
+
|
470 |
+
if hasattr(self.processor, "attention_op"):
|
471 |
+
kwargs["attention_op"] = self.processor.attention_op
|
472 |
+
|
473 |
+
lora_processor = lora_processor_cls(hidden_size, **kwargs)
|
474 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
475 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
476 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
477 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
478 |
+
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
479 |
+
lora_processor = lora_processor_cls(
|
480 |
+
hidden_size,
|
481 |
+
cross_attention_dim=self.add_k_proj.weight.shape[0],
|
482 |
+
rank=self.to_q.lora_layer.rank,
|
483 |
+
network_alpha=self.to_q.lora_layer.network_alpha,
|
484 |
+
)
|
485 |
+
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
486 |
+
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
487 |
+
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
488 |
+
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
|
489 |
+
|
490 |
+
# only save if used
|
491 |
+
if self.add_k_proj.lora_layer is not None:
|
492 |
+
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
|
493 |
+
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
|
494 |
+
else:
|
495 |
+
lora_processor.add_k_proj_lora = None
|
496 |
+
lora_processor.add_v_proj_lora = None
|
497 |
+
else:
|
498 |
+
raise ValueError(f"{lora_processor_cls} does not exist.")
|
499 |
+
|
500 |
+
return lora_processor
|
501 |
+
|
502 |
+
class GEGLU(nn.Module):
|
503 |
+
def __init__(self, in_features, out_features):
|
504 |
+
super(GEGLU, self).__init__()
|
505 |
+
self.proj = nn.Linear(in_features, out_features * 2, bias=True)
|
506 |
+
|
507 |
+
def forward(self, x):
|
508 |
+
x_proj = self.proj(x)
|
509 |
+
x1, x2 = x_proj.chunk(2, dim=-1)
|
510 |
+
return x1 * torch.nn.functional.gelu(x2)
|
511 |
+
|
512 |
+
|
513 |
+
class FeedForward(nn.Module):
|
514 |
+
def __init__(self, in_features, out_features):
|
515 |
+
super(FeedForward, self).__init__()
|
516 |
+
|
517 |
+
self.net = nn.ModuleList(
|
518 |
+
[
|
519 |
+
GEGLU(in_features, out_features * 4),
|
520 |
+
nn.Dropout(p=0.0, inplace=False),
|
521 |
+
nn.Linear(out_features * 4, out_features, bias=True),
|
522 |
+
]
|
523 |
+
)
|
524 |
+
|
525 |
+
def forward(self, x):
|
526 |
+
for layer in self.net:
|
527 |
+
x = layer(x)
|
528 |
+
return x
|
529 |
+
|
530 |
+
|
531 |
+
class BasicTransformerBlock(nn.Module):
|
532 |
+
def __init__(self, hidden_size):
|
533 |
+
super(BasicTransformerBlock, self).__init__()
|
534 |
+
self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
|
535 |
+
self.attn1 = Attention(hidden_size)
|
536 |
+
self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
|
537 |
+
self.attn2 = Attention(hidden_size, 2048)
|
538 |
+
self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
|
539 |
+
self.ff = FeedForward(hidden_size, hidden_size)
|
540 |
+
|
541 |
+
def forward(self, x, encoder_hidden_states=None):
|
542 |
+
residual = x
|
543 |
+
|
544 |
+
x = self.norm1(x)
|
545 |
+
x = self.attn1(x)
|
546 |
+
x = x + residual
|
547 |
+
|
548 |
+
residual = x
|
549 |
+
|
550 |
+
x = self.norm2(x)
|
551 |
+
if encoder_hidden_states is not None:
|
552 |
+
x = self.attn2(x, encoder_hidden_states)
|
553 |
+
else:
|
554 |
+
x = self.attn2(x)
|
555 |
+
x = x + residual
|
556 |
+
|
557 |
+
residual = x
|
558 |
+
|
559 |
+
x = self.norm3(x)
|
560 |
+
x = self.ff(x)
|
561 |
+
x = x + residual
|
562 |
+
return x
|
563 |
+
|
564 |
+
|
565 |
+
class Transformer2DModel(nn.Module):
|
566 |
+
def __init__(self, in_channels, out_channels, n_layers):
|
567 |
+
super(Transformer2DModel, self).__init__()
|
568 |
+
self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True)
|
569 |
+
self.proj_in = nn.Linear(in_channels, out_channels, bias=True)
|
570 |
+
self.transformer_blocks = nn.ModuleList(
|
571 |
+
[BasicTransformerBlock(out_channels) for _ in range(n_layers)]
|
572 |
+
)
|
573 |
+
self.proj_out = nn.Linear(out_channels, out_channels, bias=True)
|
574 |
+
|
575 |
+
def forward(self, hidden_states, encoder_hidden_states=None):
|
576 |
+
batch, _, height, width = hidden_states.shape
|
577 |
+
res = hidden_states
|
578 |
+
hidden_states = self.norm(hidden_states)
|
579 |
+
inner_dim = hidden_states.shape[1]
|
580 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
581 |
+
batch, height * width, inner_dim
|
582 |
+
)
|
583 |
+
hidden_states = self.proj_in(hidden_states)
|
584 |
+
|
585 |
+
for block in self.transformer_blocks:
|
586 |
+
hidden_states = block(hidden_states, encoder_hidden_states)
|
587 |
+
|
588 |
+
hidden_states = self.proj_out(hidden_states)
|
589 |
+
hidden_states = (
|
590 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
591 |
+
.permute(0, 3, 1, 2)
|
592 |
+
.contiguous()
|
593 |
+
)
|
594 |
+
|
595 |
+
return hidden_states + res
|
596 |
+
|
597 |
+
|
598 |
+
class Downsample2D(nn.Module):
|
599 |
+
def __init__(self, in_channels, out_channels):
|
600 |
+
super(Downsample2D, self).__init__()
|
601 |
+
self.conv = nn.Conv2d(
|
602 |
+
in_channels, out_channels, kernel_size=3, stride=2, padding=1
|
603 |
+
)
|
604 |
+
|
605 |
+
def forward(self, x):
|
606 |
+
return self.conv(x)
|
607 |
+
|
608 |
+
|
609 |
+
class Upsample2D(nn.Module):
|
610 |
+
def __init__(self, in_channels, out_channels):
|
611 |
+
super(Upsample2D, self).__init__()
|
612 |
+
self.conv = nn.Conv2d(
|
613 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
614 |
+
)
|
615 |
+
|
616 |
+
def forward(self, x):
|
617 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
618 |
+
return self.conv(x)
|
619 |
+
|
620 |
+
|
621 |
+
class DownBlock2D(nn.Module):
|
622 |
+
def __init__(self, in_channels, out_channels):
|
623 |
+
super(DownBlock2D, self).__init__()
|
624 |
+
self.resnets = nn.ModuleList(
|
625 |
+
[
|
626 |
+
ResnetBlock2D(in_channels, out_channels, conv_shortcut=False),
|
627 |
+
ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
|
628 |
+
]
|
629 |
+
)
|
630 |
+
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
|
631 |
+
|
632 |
+
def forward(self, hidden_states, temb):
|
633 |
+
output_states = []
|
634 |
+
for module in self.resnets:
|
635 |
+
hidden_states = module(hidden_states, temb)
|
636 |
+
output_states.append(hidden_states)
|
637 |
+
|
638 |
+
hidden_states = self.downsamplers[0](hidden_states)
|
639 |
+
output_states.append(hidden_states)
|
640 |
+
|
641 |
+
return hidden_states, output_states
|
642 |
+
|
643 |
+
|
644 |
+
class CrossAttnDownBlock2D(nn.Module):
|
645 |
+
def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True):
|
646 |
+
super(CrossAttnDownBlock2D, self).__init__()
|
647 |
+
self.attentions = nn.ModuleList(
|
648 |
+
[
|
649 |
+
Transformer2DModel(out_channels, out_channels, n_layers),
|
650 |
+
Transformer2DModel(out_channels, out_channels, n_layers),
|
651 |
+
]
|
652 |
+
)
|
653 |
+
self.resnets = nn.ModuleList(
|
654 |
+
[
|
655 |
+
ResnetBlock2D(in_channels, out_channels),
|
656 |
+
ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
|
657 |
+
]
|
658 |
+
)
|
659 |
+
self.downsamplers = None
|
660 |
+
if has_downsamplers:
|
661 |
+
self.downsamplers = nn.ModuleList(
|
662 |
+
[Downsample2D(out_channels, out_channels)]
|
663 |
+
)
|
664 |
+
|
665 |
+
def forward(self, hidden_states, temb, encoder_hidden_states):
|
666 |
+
output_states = []
|
667 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
668 |
+
hidden_states = resnet(hidden_states, temb)
|
669 |
+
hidden_states = attn(
|
670 |
+
hidden_states,
|
671 |
+
encoder_hidden_states=encoder_hidden_states,
|
672 |
+
)
|
673 |
+
output_states.append(hidden_states)
|
674 |
+
|
675 |
+
if self.downsamplers is not None:
|
676 |
+
hidden_states = self.downsamplers[0](hidden_states)
|
677 |
+
output_states.append(hidden_states)
|
678 |
+
|
679 |
+
return hidden_states, output_states
|
680 |
+
|
681 |
+
|
682 |
+
class CrossAttnUpBlock2D(nn.Module):
|
683 |
+
def __init__(self, in_channels, out_channels, prev_output_channel, n_layers):
|
684 |
+
super(CrossAttnUpBlock2D, self).__init__()
|
685 |
+
self.attentions = nn.ModuleList(
|
686 |
+
[
|
687 |
+
Transformer2DModel(out_channels, out_channels, n_layers),
|
688 |
+
Transformer2DModel(out_channels, out_channels, n_layers),
|
689 |
+
Transformer2DModel(out_channels, out_channels, n_layers),
|
690 |
+
]
|
691 |
+
)
|
692 |
+
self.resnets = nn.ModuleList(
|
693 |
+
[
|
694 |
+
ResnetBlock2D(prev_output_channel + out_channels, out_channels),
|
695 |
+
ResnetBlock2D(2 * out_channels, out_channels),
|
696 |
+
ResnetBlock2D(out_channels + in_channels, out_channels),
|
697 |
+
]
|
698 |
+
)
|
699 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
700 |
+
|
701 |
+
def forward(
|
702 |
+
self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states
|
703 |
+
):
|
704 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
705 |
+
# pop res hidden states
|
706 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
707 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
708 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
709 |
+
hidden_states = resnet(hidden_states, temb)
|
710 |
+
hidden_states = attn(
|
711 |
+
hidden_states,
|
712 |
+
encoder_hidden_states=encoder_hidden_states,
|
713 |
+
)
|
714 |
+
|
715 |
+
if self.upsamplers is not None:
|
716 |
+
for upsampler in self.upsamplers:
|
717 |
+
hidden_states = upsampler(hidden_states)
|
718 |
+
|
719 |
+
return hidden_states
|
720 |
+
|
721 |
+
|
722 |
+
class UpBlock2D(nn.Module):
|
723 |
+
def __init__(self, in_channels, out_channels, prev_output_channel):
|
724 |
+
super(UpBlock2D, self).__init__()
|
725 |
+
self.resnets = nn.ModuleList(
|
726 |
+
[
|
727 |
+
ResnetBlock2D(out_channels + prev_output_channel, out_channels),
|
728 |
+
ResnetBlock2D(out_channels * 2, out_channels),
|
729 |
+
ResnetBlock2D(out_channels + in_channels, out_channels),
|
730 |
+
]
|
731 |
+
)
|
732 |
+
|
733 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
734 |
+
|
735 |
+
is_freeu_enabled = (
|
736 |
+
getattr(self, "s1", None)
|
737 |
+
and getattr(self, "s2", None)
|
738 |
+
and getattr(self, "b1", None)
|
739 |
+
and getattr(self, "b2", None)
|
740 |
+
and getattr(self, "resolution_idx", None)
|
741 |
+
)
|
742 |
+
|
743 |
+
for resnet in self.resnets:
|
744 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
745 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
746 |
+
|
747 |
+
|
748 |
+
if is_freeu_enabled:
|
749 |
+
hidden_states, res_hidden_states = apply_freeu(
|
750 |
+
self.resolution_idx,
|
751 |
+
hidden_states,
|
752 |
+
res_hidden_states,
|
753 |
+
s1=self.s1,
|
754 |
+
s2=self.s2,
|
755 |
+
b1=self.b1,
|
756 |
+
b2=self.b2,
|
757 |
+
)
|
758 |
+
|
759 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
760 |
+
hidden_states = resnet(hidden_states, temb)
|
761 |
+
|
762 |
+
return hidden_states
|
763 |
+
|
764 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
765 |
+
def __init__(self, in_features):
|
766 |
+
super(UNetMidBlock2DCrossAttn, self).__init__()
|
767 |
+
self.attentions = nn.ModuleList(
|
768 |
+
[Transformer2DModel(in_features, in_features, n_layers=10)]
|
769 |
+
)
|
770 |
+
self.resnets = nn.ModuleList(
|
771 |
+
[
|
772 |
+
ResnetBlock2D(in_features, in_features, conv_shortcut=False),
|
773 |
+
ResnetBlock2D(in_features, in_features, conv_shortcut=False),
|
774 |
+
]
|
775 |
+
)
|
776 |
+
|
777 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
778 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
779 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
780 |
+
hidden_states = attn(
|
781 |
+
hidden_states,
|
782 |
+
encoder_hidden_states=encoder_hidden_states,
|
783 |
+
)
|
784 |
+
hidden_states = resnet(hidden_states, temb)
|
785 |
+
|
786 |
+
return hidden_states
|
787 |
+
|
788 |
+
|
789 |
+
class UNet2DConditionModel(nn.Module):
|
790 |
+
def __init__(self):
|
791 |
+
super(UNet2DConditionModel, self).__init__()
|
792 |
+
|
793 |
+
# This is needed to imitate huggingface config behavior
|
794 |
+
# has nothing to do with the model itself
|
795 |
+
# remove this if you don't use diffuser's pipeline
|
796 |
+
self.config = namedtuple(
|
797 |
+
"config", "in_channels addition_time_embed_dim sample_size"
|
798 |
+
)
|
799 |
+
self.config.in_channels = 4
|
800 |
+
self.config.addition_time_embed_dim = 256
|
801 |
+
self.config.sample_size = 128
|
802 |
+
|
803 |
+
self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)
|
804 |
+
self.time_proj = Timesteps()
|
805 |
+
self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280)
|
806 |
+
self.add_time_proj = Timesteps(256)
|
807 |
+
self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280)
|
808 |
+
self.down_blocks = nn.ModuleList(
|
809 |
+
[
|
810 |
+
DownBlock2D(in_channels=320, out_channels=320),
|
811 |
+
CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2),
|
812 |
+
CrossAttnDownBlock2D(
|
813 |
+
in_channels=640,
|
814 |
+
out_channels=1280,
|
815 |
+
n_layers=10,
|
816 |
+
has_downsamplers=False,
|
817 |
+
),
|
818 |
+
]
|
819 |
+
)
|
820 |
+
self.up_blocks = nn.ModuleList(
|
821 |
+
[
|
822 |
+
CrossAttnUpBlock2D(
|
823 |
+
in_channels=640,
|
824 |
+
out_channels=1280,
|
825 |
+
prev_output_channel=1280,
|
826 |
+
n_layers=10,
|
827 |
+
),
|
828 |
+
CrossAttnUpBlock2D(
|
829 |
+
in_channels=320,
|
830 |
+
out_channels=640,
|
831 |
+
prev_output_channel=1280,
|
832 |
+
n_layers=2,
|
833 |
+
),
|
834 |
+
UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640),
|
835 |
+
]
|
836 |
+
)
|
837 |
+
self.mid_block = UNetMidBlock2DCrossAttn(1280)
|
838 |
+
self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True)
|
839 |
+
self.conv_act = nn.SiLU()
|
840 |
+
self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1)
|
841 |
+
|
842 |
+
def forward(
|
843 |
+
self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs
|
844 |
+
):
|
845 |
+
# Implement the forward pass through the model
|
846 |
+
timesteps = timesteps.expand(sample.shape[0])
|
847 |
+
t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
|
848 |
+
|
849 |
+
emb = self.time_embedding(t_emb)
|
850 |
+
|
851 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
852 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
853 |
+
|
854 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
855 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
856 |
+
|
857 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
858 |
+
add_embeds = add_embeds.to(emb.dtype)
|
859 |
+
aug_emb = self.add_embedding(add_embeds)
|
860 |
+
|
861 |
+
emb = emb + aug_emb
|
862 |
+
|
863 |
+
sample = self.conv_in(sample)
|
864 |
+
|
865 |
+
# 3. down
|
866 |
+
s0 = sample
|
867 |
+
sample, [s1, s2, s3] = self.down_blocks[0](
|
868 |
+
sample,
|
869 |
+
temb=emb,
|
870 |
+
)
|
871 |
+
|
872 |
+
sample, [s4, s5, s6] = self.down_blocks[1](
|
873 |
+
sample,
|
874 |
+
temb=emb,
|
875 |
+
encoder_hidden_states=encoder_hidden_states,
|
876 |
+
)
|
877 |
+
|
878 |
+
sample, [s7, s8] = self.down_blocks[2](
|
879 |
+
sample,
|
880 |
+
temb=emb,
|
881 |
+
encoder_hidden_states=encoder_hidden_states,
|
882 |
+
)
|
883 |
+
|
884 |
+
# 4. mid
|
885 |
+
sample = self.mid_block(
|
886 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states
|
887 |
+
)
|
888 |
+
|
889 |
+
# 5. up
|
890 |
+
sample = self.up_blocks[0](
|
891 |
+
hidden_states=sample,
|
892 |
+
temb=emb,
|
893 |
+
res_hidden_states_tuple=[s6, s7, s8],
|
894 |
+
encoder_hidden_states=encoder_hidden_states,
|
895 |
+
)
|
896 |
+
|
897 |
+
sample = self.up_blocks[1](
|
898 |
+
hidden_states=sample,
|
899 |
+
temb=emb,
|
900 |
+
res_hidden_states_tuple=[s3, s4, s5],
|
901 |
+
encoder_hidden_states=encoder_hidden_states,
|
902 |
+
)
|
903 |
+
|
904 |
+
sample = self.up_blocks[2](
|
905 |
+
hidden_states=sample,
|
906 |
+
temb=emb,
|
907 |
+
res_hidden_states_tuple=[s0, s1, s2],
|
908 |
+
)
|
909 |
+
|
910 |
+
# 6. post-process
|
911 |
+
sample = self.conv_norm_out(sample)
|
912 |
+
sample = self.conv_act(sample)
|
913 |
+
sample = self.conv_out(sample)
|
914 |
+
|
915 |
+
return [sample]
|
module/unet/unet_2d_ZeroSFT.py
ADDED
@@ -0,0 +1,1397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copy from diffusers.models.unets.unet_2d_condition.py
|
2 |
+
|
3 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
25 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
26 |
+
from diffusers.models.activations import get_activation
|
27 |
+
from diffusers.models.attention_processor import (
|
28 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
29 |
+
CROSS_ATTENTION_PROCESSORS,
|
30 |
+
Attention,
|
31 |
+
AttentionProcessor,
|
32 |
+
AttnAddedKVProcessor,
|
33 |
+
AttnProcessor,
|
34 |
+
)
|
35 |
+
from diffusers.models.embeddings import (
|
36 |
+
GaussianFourierProjection,
|
37 |
+
GLIGENTextBoundingboxProjection,
|
38 |
+
ImageHintTimeEmbedding,
|
39 |
+
ImageProjection,
|
40 |
+
ImageTimeEmbedding,
|
41 |
+
TextImageProjection,
|
42 |
+
TextImageTimeEmbedding,
|
43 |
+
TextTimeEmbedding,
|
44 |
+
TimestepEmbedding,
|
45 |
+
Timesteps,
|
46 |
+
)
|
47 |
+
from diffusers.models.modeling_utils import ModelMixin
|
48 |
+
from .unet_2d_ZeroSFT_blocks import (
|
49 |
+
get_down_block,
|
50 |
+
get_mid_block,
|
51 |
+
get_up_block,
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
56 |
+
|
57 |
+
|
58 |
+
def zero_module(module):
|
59 |
+
for p in module.parameters():
|
60 |
+
nn.init.zeros_(p)
|
61 |
+
return module
|
62 |
+
|
63 |
+
|
64 |
+
class ZeroConv(nn.Module):
|
65 |
+
def __init__(self, label_nc, norm_nc, mask=False):
|
66 |
+
super().__init__()
|
67 |
+
self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
|
68 |
+
self.mask = mask
|
69 |
+
|
70 |
+
def forward(self, c, h, h_ori=None):
|
71 |
+
# with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
72 |
+
if not self.mask:
|
73 |
+
h = h + self.zero_conv(c)
|
74 |
+
else:
|
75 |
+
h = h + self.zero_conv(c) * torch.zeros_like(h)
|
76 |
+
if h_ori is not None:
|
77 |
+
h = torch.cat([h_ori, h], dim=1)
|
78 |
+
return h
|
79 |
+
|
80 |
+
|
81 |
+
class ZeroSFT(nn.Module):
|
82 |
+
def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
# param_free_norm_type = str(parsed.group(1))
|
86 |
+
ks = 3
|
87 |
+
pw = ks // 2
|
88 |
+
|
89 |
+
self.mask = mask
|
90 |
+
self.norm = norm
|
91 |
+
self.pre_concat = bool(concat_channels != 0)
|
92 |
+
if self.norm:
|
93 |
+
self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels)
|
94 |
+
else:
|
95 |
+
self.param_free_norm = nn.Identity()
|
96 |
+
|
97 |
+
nhidden = 128
|
98 |
+
|
99 |
+
self.mlp_shared = nn.Sequential(
|
100 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
|
101 |
+
nn.SiLU()
|
102 |
+
)
|
103 |
+
self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
|
104 |
+
self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
|
105 |
+
|
106 |
+
self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
|
107 |
+
|
108 |
+
def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False):
|
109 |
+
mask = mask or self.mask
|
110 |
+
assert mask is False
|
111 |
+
if self.pre_concat:
|
112 |
+
assert h_ori is not None
|
113 |
+
|
114 |
+
c,h = down_block_res_samples
|
115 |
+
if h_ori is not None:
|
116 |
+
h_raw = torch.cat([h_ori, h], dim=1)
|
117 |
+
else:
|
118 |
+
h_raw = h
|
119 |
+
|
120 |
+
if self.mask:
|
121 |
+
h = h + self.zero_conv(c) * torch.zeros_like(h)
|
122 |
+
else:
|
123 |
+
h = h + self.zero_conv(c)
|
124 |
+
if h_ori is not None and self.pre_concat:
|
125 |
+
h = torch.cat([h_ori, h], dim=1)
|
126 |
+
actv = self.mlp_shared(c)
|
127 |
+
gamma = self.zero_mul(actv)
|
128 |
+
beta = self.zero_add(actv)
|
129 |
+
if self.mask:
|
130 |
+
gamma = gamma * torch.zeros_like(gamma)
|
131 |
+
beta = beta * torch.zeros_like(beta)
|
132 |
+
# h = h + self.param_free_norm(h) * gamma + beta
|
133 |
+
h = self.param_free_norm(h) * (gamma + 1) + beta
|
134 |
+
if h_ori is not None and not self.pre_concat:
|
135 |
+
h = torch.cat([h_ori, h], dim=1)
|
136 |
+
return h * control_scale + h_raw * (1 - control_scale)
|
137 |
+
|
138 |
+
|
139 |
+
@dataclass
|
140 |
+
class UNet2DConditionOutput(BaseOutput):
|
141 |
+
"""
|
142 |
+
The output of [`UNet2DConditionModel`].
|
143 |
+
|
144 |
+
Args:
|
145 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
146 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
147 |
+
"""
|
148 |
+
|
149 |
+
sample: torch.FloatTensor = None
|
150 |
+
|
151 |
+
|
152 |
+
class UNet2DZeroSFTModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
153 |
+
r"""
|
154 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
155 |
+
shaped output.
|
156 |
+
|
157 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
158 |
+
for all models (such as downloading or saving).
|
159 |
+
|
160 |
+
Parameters:
|
161 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
162 |
+
Height and width of input/output sample.
|
163 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
164 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
165 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
166 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
167 |
+
Whether to flip the sin to cos in the time embedding.
|
168 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
169 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
170 |
+
The tuple of downsample blocks to use.
|
171 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
172 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
173 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
174 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
175 |
+
The tuple of upsample blocks to use.
|
176 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
177 |
+
Whether to include self-attention in the basic transformer blocks, see
|
178 |
+
[`~models.attention.BasicTransformerBlock`].
|
179 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
180 |
+
The tuple of output channels for each block.
|
181 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
182 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
183 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
184 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
185 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
186 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
187 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
188 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
189 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
190 |
+
The dimension of the cross attention features.
|
191 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
192 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
193 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
194 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
195 |
+
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
196 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
197 |
+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
198 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
199 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
200 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
201 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
202 |
+
dimension to `cross_attention_dim`.
|
203 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
204 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
205 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
206 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
207 |
+
num_attention_heads (`int`, *optional*):
|
208 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
209 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
210 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
211 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
212 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
213 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
214 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
215 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
216 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
217 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
218 |
+
Dimension for the timestep embeddings.
|
219 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
220 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
221 |
+
class conditioning with `class_embed_type` equal to `None`.
|
222 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
223 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
224 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
225 |
+
An optional override for the dimension of the projected time embedding.
|
226 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
227 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
228 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
229 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
230 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
231 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
232 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
233 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
234 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
235 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
236 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
237 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
238 |
+
embeddings with the class embeddings.
|
239 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
240 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
241 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
242 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
243 |
+
otherwise.
|
244 |
+
"""
|
245 |
+
|
246 |
+
_supports_gradient_checkpointing = True
|
247 |
+
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
248 |
+
|
249 |
+
@register_to_config
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
sample_size: Optional[int] = None,
|
253 |
+
in_channels: int = 4,
|
254 |
+
out_channels: int = 4,
|
255 |
+
center_input_sample: bool = False,
|
256 |
+
flip_sin_to_cos: bool = True,
|
257 |
+
freq_shift: int = 0,
|
258 |
+
down_block_types: Tuple[str] = (
|
259 |
+
"CrossAttnDownBlock2D",
|
260 |
+
"CrossAttnDownBlock2D",
|
261 |
+
"CrossAttnDownBlock2D",
|
262 |
+
"DownBlock2D",
|
263 |
+
),
|
264 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
265 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
266 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
267 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
268 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
269 |
+
downsample_padding: int = 1,
|
270 |
+
mid_block_scale_factor: float = 1,
|
271 |
+
dropout: float = 0.0,
|
272 |
+
act_fn: str = "silu",
|
273 |
+
norm_num_groups: Optional[int] = 32,
|
274 |
+
norm_eps: float = 1e-5,
|
275 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
276 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
277 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
278 |
+
encoder_hid_dim: Optional[int] = None,
|
279 |
+
encoder_hid_dim_type: Optional[str] = None,
|
280 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
281 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
282 |
+
dual_cross_attention: bool = False,
|
283 |
+
use_linear_projection: bool = False,
|
284 |
+
class_embed_type: Optional[str] = None,
|
285 |
+
addition_embed_type: Optional[str] = None,
|
286 |
+
addition_time_embed_dim: Optional[int] = None,
|
287 |
+
num_class_embeds: Optional[int] = None,
|
288 |
+
upcast_attention: bool = False,
|
289 |
+
resnet_time_scale_shift: str = "default",
|
290 |
+
resnet_skip_time_act: bool = False,
|
291 |
+
resnet_out_scale_factor: float = 1.0,
|
292 |
+
time_embedding_type: str = "positional",
|
293 |
+
time_embedding_dim: Optional[int] = None,
|
294 |
+
time_embedding_act_fn: Optional[str] = None,
|
295 |
+
timestep_post_act: Optional[str] = None,
|
296 |
+
time_cond_proj_dim: Optional[int] = None,
|
297 |
+
conv_in_kernel: int = 3,
|
298 |
+
conv_out_kernel: int = 3,
|
299 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
300 |
+
attention_type: str = "default",
|
301 |
+
class_embeddings_concat: bool = False,
|
302 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
303 |
+
cross_attention_norm: Optional[str] = None,
|
304 |
+
addition_embed_type_num_heads: int = 64,
|
305 |
+
):
|
306 |
+
super().__init__()
|
307 |
+
|
308 |
+
self.sample_size = sample_size
|
309 |
+
|
310 |
+
if num_attention_heads is not None:
|
311 |
+
raise ValueError(
|
312 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
313 |
+
)
|
314 |
+
|
315 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
316 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
317 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
318 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
319 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
320 |
+
# which is why we correct for the naming here.
|
321 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
322 |
+
|
323 |
+
# Check inputs
|
324 |
+
self._check_config(
|
325 |
+
down_block_types=down_block_types,
|
326 |
+
up_block_types=up_block_types,
|
327 |
+
only_cross_attention=only_cross_attention,
|
328 |
+
block_out_channels=block_out_channels,
|
329 |
+
layers_per_block=layers_per_block,
|
330 |
+
cross_attention_dim=cross_attention_dim,
|
331 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
332 |
+
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
|
333 |
+
attention_head_dim=attention_head_dim,
|
334 |
+
num_attention_heads=num_attention_heads,
|
335 |
+
)
|
336 |
+
|
337 |
+
# input
|
338 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
339 |
+
self.conv_in = nn.Conv2d(
|
340 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
341 |
+
)
|
342 |
+
|
343 |
+
# time
|
344 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
345 |
+
time_embedding_type,
|
346 |
+
block_out_channels=block_out_channels,
|
347 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
348 |
+
freq_shift=freq_shift,
|
349 |
+
time_embedding_dim=time_embedding_dim,
|
350 |
+
)
|
351 |
+
|
352 |
+
self.time_embedding = TimestepEmbedding(
|
353 |
+
timestep_input_dim,
|
354 |
+
time_embed_dim,
|
355 |
+
act_fn=act_fn,
|
356 |
+
post_act_fn=timestep_post_act,
|
357 |
+
cond_proj_dim=time_cond_proj_dim,
|
358 |
+
)
|
359 |
+
|
360 |
+
self._set_encoder_hid_proj(
|
361 |
+
encoder_hid_dim_type,
|
362 |
+
cross_attention_dim=cross_attention_dim,
|
363 |
+
encoder_hid_dim=encoder_hid_dim,
|
364 |
+
)
|
365 |
+
|
366 |
+
# class embedding
|
367 |
+
self._set_class_embedding(
|
368 |
+
class_embed_type,
|
369 |
+
act_fn=act_fn,
|
370 |
+
num_class_embeds=num_class_embeds,
|
371 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
372 |
+
time_embed_dim=time_embed_dim,
|
373 |
+
timestep_input_dim=timestep_input_dim,
|
374 |
+
)
|
375 |
+
|
376 |
+
self._set_add_embedding(
|
377 |
+
addition_embed_type,
|
378 |
+
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
379 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
380 |
+
cross_attention_dim=cross_attention_dim,
|
381 |
+
encoder_hid_dim=encoder_hid_dim,
|
382 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
383 |
+
freq_shift=freq_shift,
|
384 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
385 |
+
time_embed_dim=time_embed_dim,
|
386 |
+
)
|
387 |
+
|
388 |
+
if time_embedding_act_fn is None:
|
389 |
+
self.time_embed_act = None
|
390 |
+
else:
|
391 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
392 |
+
|
393 |
+
self.down_blocks = nn.ModuleList([])
|
394 |
+
self.up_blocks = nn.ModuleList([])
|
395 |
+
|
396 |
+
if isinstance(only_cross_attention, bool):
|
397 |
+
if mid_block_only_cross_attention is None:
|
398 |
+
mid_block_only_cross_attention = only_cross_attention
|
399 |
+
|
400 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
401 |
+
|
402 |
+
if mid_block_only_cross_attention is None:
|
403 |
+
mid_block_only_cross_attention = False
|
404 |
+
|
405 |
+
if isinstance(num_attention_heads, int):
|
406 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
407 |
+
|
408 |
+
if isinstance(attention_head_dim, int):
|
409 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
410 |
+
|
411 |
+
if isinstance(cross_attention_dim, int):
|
412 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
413 |
+
|
414 |
+
if isinstance(layers_per_block, int):
|
415 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
416 |
+
|
417 |
+
if isinstance(transformer_layers_per_block, int):
|
418 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
419 |
+
|
420 |
+
if class_embeddings_concat:
|
421 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
422 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
423 |
+
# regular time embeddings
|
424 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
425 |
+
else:
|
426 |
+
blocks_time_embed_dim = time_embed_dim
|
427 |
+
|
428 |
+
# down
|
429 |
+
output_channel = block_out_channels[0]
|
430 |
+
for i, down_block_type in enumerate(down_block_types):
|
431 |
+
input_channel = output_channel
|
432 |
+
output_channel = block_out_channels[i]
|
433 |
+
is_final_block = i == len(block_out_channels) - 1
|
434 |
+
|
435 |
+
down_block = get_down_block(
|
436 |
+
down_block_type,
|
437 |
+
num_layers=layers_per_block[i],
|
438 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
439 |
+
in_channels=input_channel,
|
440 |
+
out_channels=output_channel,
|
441 |
+
temb_channels=blocks_time_embed_dim,
|
442 |
+
add_downsample=not is_final_block,
|
443 |
+
resnet_eps=norm_eps,
|
444 |
+
resnet_act_fn=act_fn,
|
445 |
+
resnet_groups=norm_num_groups,
|
446 |
+
cross_attention_dim=cross_attention_dim[i],
|
447 |
+
num_attention_heads=num_attention_heads[i],
|
448 |
+
downsample_padding=downsample_padding,
|
449 |
+
dual_cross_attention=dual_cross_attention,
|
450 |
+
use_linear_projection=use_linear_projection,
|
451 |
+
only_cross_attention=only_cross_attention[i],
|
452 |
+
upcast_attention=upcast_attention,
|
453 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
454 |
+
attention_type=attention_type,
|
455 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
456 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
457 |
+
cross_attention_norm=cross_attention_norm,
|
458 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
459 |
+
dropout=dropout,
|
460 |
+
)
|
461 |
+
self.down_blocks.append(down_block)
|
462 |
+
|
463 |
+
# mid
|
464 |
+
self.mid_block = get_mid_block(
|
465 |
+
mid_block_type,
|
466 |
+
temb_channels=blocks_time_embed_dim,
|
467 |
+
in_channels=block_out_channels[-1],
|
468 |
+
resnet_eps=norm_eps,
|
469 |
+
resnet_act_fn=act_fn,
|
470 |
+
resnet_groups=norm_num_groups,
|
471 |
+
output_scale_factor=mid_block_scale_factor,
|
472 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
473 |
+
num_attention_heads=num_attention_heads[-1],
|
474 |
+
cross_attention_dim=cross_attention_dim[-1],
|
475 |
+
dual_cross_attention=dual_cross_attention,
|
476 |
+
use_linear_projection=use_linear_projection,
|
477 |
+
mid_block_only_cross_attention=mid_block_only_cross_attention,
|
478 |
+
upcast_attention=upcast_attention,
|
479 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
480 |
+
attention_type=attention_type,
|
481 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
482 |
+
cross_attention_norm=cross_attention_norm,
|
483 |
+
attention_head_dim=attention_head_dim[-1],
|
484 |
+
dropout=dropout,
|
485 |
+
)
|
486 |
+
self.mid_zero_SFT = ZeroSFT(block_out_channels[-1],block_out_channels[-1],0)
|
487 |
+
|
488 |
+
# count how many layers upsample the images
|
489 |
+
self.num_upsamplers = 0
|
490 |
+
|
491 |
+
# up
|
492 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
493 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
494 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
495 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
496 |
+
reversed_transformer_layers_per_block = (
|
497 |
+
list(reversed(transformer_layers_per_block))
|
498 |
+
if reverse_transformer_layers_per_block is None
|
499 |
+
else reverse_transformer_layers_per_block
|
500 |
+
)
|
501 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
502 |
+
|
503 |
+
output_channel = reversed_block_out_channels[0]
|
504 |
+
for i, up_block_type in enumerate(up_block_types):
|
505 |
+
is_final_block = i == len(block_out_channels) - 1
|
506 |
+
|
507 |
+
prev_output_channel = output_channel
|
508 |
+
output_channel = reversed_block_out_channels[i]
|
509 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
510 |
+
|
511 |
+
# add upsample block for all BUT final layer
|
512 |
+
if not is_final_block:
|
513 |
+
add_upsample = True
|
514 |
+
self.num_upsamplers += 1
|
515 |
+
else:
|
516 |
+
add_upsample = False
|
517 |
+
|
518 |
+
up_block = get_up_block(
|
519 |
+
up_block_type,
|
520 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
521 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
522 |
+
in_channels=input_channel,
|
523 |
+
out_channels=output_channel,
|
524 |
+
prev_output_channel=prev_output_channel,
|
525 |
+
temb_channels=blocks_time_embed_dim,
|
526 |
+
add_upsample=add_upsample,
|
527 |
+
resnet_eps=norm_eps,
|
528 |
+
resnet_act_fn=act_fn,
|
529 |
+
resolution_idx=i,
|
530 |
+
resnet_groups=norm_num_groups,
|
531 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
532 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
533 |
+
dual_cross_attention=dual_cross_attention,
|
534 |
+
use_linear_projection=use_linear_projection,
|
535 |
+
only_cross_attention=only_cross_attention[i],
|
536 |
+
upcast_attention=upcast_attention,
|
537 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
538 |
+
attention_type=attention_type,
|
539 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
540 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
541 |
+
cross_attention_norm=cross_attention_norm,
|
542 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
543 |
+
dropout=dropout,
|
544 |
+
)
|
545 |
+
self.up_blocks.append(up_block)
|
546 |
+
prev_output_channel = output_channel
|
547 |
+
|
548 |
+
# out
|
549 |
+
if norm_num_groups is not None:
|
550 |
+
self.conv_norm_out = nn.GroupNorm(
|
551 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
552 |
+
)
|
553 |
+
|
554 |
+
self.conv_act = get_activation(act_fn)
|
555 |
+
|
556 |
+
else:
|
557 |
+
self.conv_norm_out = None
|
558 |
+
self.conv_act = None
|
559 |
+
|
560 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
561 |
+
self.conv_out = nn.Conv2d(
|
562 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
563 |
+
)
|
564 |
+
|
565 |
+
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
|
566 |
+
|
567 |
+
def _check_config(
|
568 |
+
self,
|
569 |
+
down_block_types: Tuple[str],
|
570 |
+
up_block_types: Tuple[str],
|
571 |
+
only_cross_attention: Union[bool, Tuple[bool]],
|
572 |
+
block_out_channels: Tuple[int],
|
573 |
+
layers_per_block: Union[int, Tuple[int]],
|
574 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
575 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
576 |
+
reverse_transformer_layers_per_block: bool,
|
577 |
+
attention_head_dim: int,
|
578 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
579 |
+
):
|
580 |
+
if len(down_block_types) != len(up_block_types):
|
581 |
+
raise ValueError(
|
582 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
583 |
+
)
|
584 |
+
|
585 |
+
if len(block_out_channels) != len(down_block_types):
|
586 |
+
raise ValueError(
|
587 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
588 |
+
)
|
589 |
+
|
590 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
591 |
+
raise ValueError(
|
592 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
593 |
+
)
|
594 |
+
|
595 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
596 |
+
raise ValueError(
|
597 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
598 |
+
)
|
599 |
+
|
600 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
601 |
+
raise ValueError(
|
602 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
603 |
+
)
|
604 |
+
|
605 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
606 |
+
raise ValueError(
|
607 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
608 |
+
)
|
609 |
+
|
610 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
611 |
+
raise ValueError(
|
612 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
613 |
+
)
|
614 |
+
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
615 |
+
for layer_number_per_block in transformer_layers_per_block:
|
616 |
+
if isinstance(layer_number_per_block, list):
|
617 |
+
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
618 |
+
|
619 |
+
def _set_time_proj(
|
620 |
+
self,
|
621 |
+
time_embedding_type: str,
|
622 |
+
block_out_channels: int,
|
623 |
+
flip_sin_to_cos: bool,
|
624 |
+
freq_shift: float,
|
625 |
+
time_embedding_dim: int,
|
626 |
+
) -> Tuple[int, int]:
|
627 |
+
if time_embedding_type == "fourier":
|
628 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
629 |
+
if time_embed_dim % 2 != 0:
|
630 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
631 |
+
self.time_proj = GaussianFourierProjection(
|
632 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
633 |
+
)
|
634 |
+
timestep_input_dim = time_embed_dim
|
635 |
+
elif time_embedding_type == "positional":
|
636 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
637 |
+
|
638 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
639 |
+
timestep_input_dim = block_out_channels[0]
|
640 |
+
else:
|
641 |
+
raise ValueError(
|
642 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
643 |
+
)
|
644 |
+
|
645 |
+
return time_embed_dim, timestep_input_dim
|
646 |
+
|
647 |
+
def _set_encoder_hid_proj(
|
648 |
+
self,
|
649 |
+
encoder_hid_dim_type: Optional[str],
|
650 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
651 |
+
encoder_hid_dim: Optional[int],
|
652 |
+
):
|
653 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
654 |
+
encoder_hid_dim_type = "text_proj"
|
655 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
656 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
657 |
+
|
658 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
659 |
+
raise ValueError(
|
660 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
661 |
+
)
|
662 |
+
|
663 |
+
if encoder_hid_dim_type == "text_proj":
|
664 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
665 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
666 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
667 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
668 |
+
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
|
669 |
+
self.encoder_hid_proj = TextImageProjection(
|
670 |
+
text_embed_dim=encoder_hid_dim,
|
671 |
+
image_embed_dim=cross_attention_dim,
|
672 |
+
cross_attention_dim=cross_attention_dim,
|
673 |
+
)
|
674 |
+
elif encoder_hid_dim_type == "image_proj":
|
675 |
+
# Kandinsky 2.2
|
676 |
+
self.encoder_hid_proj = ImageProjection(
|
677 |
+
image_embed_dim=encoder_hid_dim,
|
678 |
+
cross_attention_dim=cross_attention_dim,
|
679 |
+
)
|
680 |
+
elif encoder_hid_dim_type is not None:
|
681 |
+
raise ValueError(
|
682 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
683 |
+
)
|
684 |
+
else:
|
685 |
+
self.encoder_hid_proj = None
|
686 |
+
|
687 |
+
def _set_class_embedding(
|
688 |
+
self,
|
689 |
+
class_embed_type: Optional[str],
|
690 |
+
act_fn: str,
|
691 |
+
num_class_embeds: Optional[int],
|
692 |
+
projection_class_embeddings_input_dim: Optional[int],
|
693 |
+
time_embed_dim: int,
|
694 |
+
timestep_input_dim: int,
|
695 |
+
):
|
696 |
+
if class_embed_type is None and num_class_embeds is not None:
|
697 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
698 |
+
elif class_embed_type == "timestep":
|
699 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
700 |
+
elif class_embed_type == "identity":
|
701 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
702 |
+
elif class_embed_type == "projection":
|
703 |
+
if projection_class_embeddings_input_dim is None:
|
704 |
+
raise ValueError(
|
705 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
706 |
+
)
|
707 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
708 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
709 |
+
# 2. it projects from an arbitrary input dimension.
|
710 |
+
#
|
711 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
712 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
713 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
714 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
715 |
+
elif class_embed_type == "simple_projection":
|
716 |
+
if projection_class_embeddings_input_dim is None:
|
717 |
+
raise ValueError(
|
718 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
719 |
+
)
|
720 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
721 |
+
else:
|
722 |
+
self.class_embedding = None
|
723 |
+
|
724 |
+
def _set_add_embedding(
|
725 |
+
self,
|
726 |
+
addition_embed_type: str,
|
727 |
+
addition_embed_type_num_heads: int,
|
728 |
+
addition_time_embed_dim: Optional[int],
|
729 |
+
flip_sin_to_cos: bool,
|
730 |
+
freq_shift: float,
|
731 |
+
cross_attention_dim: Optional[int],
|
732 |
+
encoder_hid_dim: Optional[int],
|
733 |
+
projection_class_embeddings_input_dim: Optional[int],
|
734 |
+
time_embed_dim: int,
|
735 |
+
):
|
736 |
+
if addition_embed_type == "text":
|
737 |
+
if encoder_hid_dim is not None:
|
738 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
739 |
+
else:
|
740 |
+
text_time_embedding_from_dim = cross_attention_dim
|
741 |
+
|
742 |
+
self.add_embedding = TextTimeEmbedding(
|
743 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
744 |
+
)
|
745 |
+
elif addition_embed_type == "text_image":
|
746 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
747 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
748 |
+
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
749 |
+
self.add_embedding = TextImageTimeEmbedding(
|
750 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
751 |
+
)
|
752 |
+
elif addition_embed_type == "text_time":
|
753 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
754 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
755 |
+
elif addition_embed_type == "image":
|
756 |
+
# Kandinsky 2.2
|
757 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
758 |
+
elif addition_embed_type == "image_hint":
|
759 |
+
# Kandinsky 2.2 ControlNet
|
760 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
761 |
+
elif addition_embed_type is not None:
|
762 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
763 |
+
|
764 |
+
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
|
765 |
+
if attention_type in ["gated", "gated-text-image"]:
|
766 |
+
positive_len = 768
|
767 |
+
if isinstance(cross_attention_dim, int):
|
768 |
+
positive_len = cross_attention_dim
|
769 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
770 |
+
positive_len = cross_attention_dim[0]
|
771 |
+
|
772 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
773 |
+
self.position_net = GLIGENTextBoundingboxProjection(
|
774 |
+
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
775 |
+
)
|
776 |
+
|
777 |
+
@property
|
778 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
779 |
+
r"""
|
780 |
+
Returns:
|
781 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
782 |
+
indexed by its weight name.
|
783 |
+
"""
|
784 |
+
# set recursively
|
785 |
+
processors = {}
|
786 |
+
|
787 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
788 |
+
if hasattr(module, "get_processor"):
|
789 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
790 |
+
|
791 |
+
for sub_name, child in module.named_children():
|
792 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
793 |
+
|
794 |
+
return processors
|
795 |
+
|
796 |
+
for name, module in self.named_children():
|
797 |
+
fn_recursive_add_processors(name, module, processors)
|
798 |
+
|
799 |
+
return processors
|
800 |
+
|
801 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
802 |
+
r"""
|
803 |
+
Sets the attention processor to use to compute attention.
|
804 |
+
|
805 |
+
Parameters:
|
806 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
807 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
808 |
+
for **all** `Attention` layers.
|
809 |
+
|
810 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
811 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
812 |
+
|
813 |
+
"""
|
814 |
+
count = len(self.attn_processors.keys())
|
815 |
+
|
816 |
+
if isinstance(processor, dict) and len(processor) != count:
|
817 |
+
raise ValueError(
|
818 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
819 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
820 |
+
)
|
821 |
+
|
822 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
823 |
+
if hasattr(module, "set_processor"):
|
824 |
+
if not isinstance(processor, dict):
|
825 |
+
module.set_processor(processor)
|
826 |
+
else:
|
827 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
828 |
+
|
829 |
+
for sub_name, child in module.named_children():
|
830 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
831 |
+
|
832 |
+
for name, module in self.named_children():
|
833 |
+
fn_recursive_attn_processor(name, module, processor)
|
834 |
+
|
835 |
+
def set_default_attn_processor(self):
|
836 |
+
"""
|
837 |
+
Disables custom attention processors and sets the default attention implementation.
|
838 |
+
"""
|
839 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
840 |
+
processor = AttnAddedKVProcessor()
|
841 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
842 |
+
processor = AttnProcessor()
|
843 |
+
else:
|
844 |
+
raise ValueError(
|
845 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
846 |
+
)
|
847 |
+
|
848 |
+
self.set_attn_processor(processor)
|
849 |
+
|
850 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
|
851 |
+
r"""
|
852 |
+
Enable sliced attention computation.
|
853 |
+
|
854 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
855 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
856 |
+
|
857 |
+
Args:
|
858 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
859 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
860 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
861 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
862 |
+
must be a multiple of `slice_size`.
|
863 |
+
"""
|
864 |
+
sliceable_head_dims = []
|
865 |
+
|
866 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
867 |
+
if hasattr(module, "set_attention_slice"):
|
868 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
869 |
+
|
870 |
+
for child in module.children():
|
871 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
872 |
+
|
873 |
+
# retrieve number of attention layers
|
874 |
+
for module in self.children():
|
875 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
876 |
+
|
877 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
878 |
+
|
879 |
+
if slice_size == "auto":
|
880 |
+
# half the attention head size is usually a good trade-off between
|
881 |
+
# speed and memory
|
882 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
883 |
+
elif slice_size == "max":
|
884 |
+
# make smallest slice possible
|
885 |
+
slice_size = num_sliceable_layers * [1]
|
886 |
+
|
887 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
888 |
+
|
889 |
+
if len(slice_size) != len(sliceable_head_dims):
|
890 |
+
raise ValueError(
|
891 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
892 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
893 |
+
)
|
894 |
+
|
895 |
+
for i in range(len(slice_size)):
|
896 |
+
size = slice_size[i]
|
897 |
+
dim = sliceable_head_dims[i]
|
898 |
+
if size is not None and size > dim:
|
899 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
900 |
+
|
901 |
+
# Recursively walk through all the children.
|
902 |
+
# Any children which exposes the set_attention_slice method
|
903 |
+
# gets the message
|
904 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
905 |
+
if hasattr(module, "set_attention_slice"):
|
906 |
+
module.set_attention_slice(slice_size.pop())
|
907 |
+
|
908 |
+
for child in module.children():
|
909 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
910 |
+
|
911 |
+
reversed_slice_size = list(reversed(slice_size))
|
912 |
+
for module in self.children():
|
913 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
914 |
+
|
915 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
916 |
+
if hasattr(module, "gradient_checkpointing"):
|
917 |
+
module.gradient_checkpointing = value
|
918 |
+
|
919 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
920 |
+
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
921 |
+
|
922 |
+
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
923 |
+
|
924 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
925 |
+
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
926 |
+
|
927 |
+
Args:
|
928 |
+
s1 (`float`):
|
929 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
930 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
931 |
+
s2 (`float`):
|
932 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
933 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
934 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
935 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
936 |
+
"""
|
937 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
938 |
+
setattr(upsample_block, "s1", s1)
|
939 |
+
setattr(upsample_block, "s2", s2)
|
940 |
+
setattr(upsample_block, "b1", b1)
|
941 |
+
setattr(upsample_block, "b2", b2)
|
942 |
+
|
943 |
+
def disable_freeu(self):
|
944 |
+
"""Disables the FreeU mechanism."""
|
945 |
+
freeu_keys = {"s1", "s2", "b1", "b2"}
|
946 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
947 |
+
for k in freeu_keys:
|
948 |
+
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
949 |
+
setattr(upsample_block, k, None)
|
950 |
+
|
951 |
+
def fuse_qkv_projections(self):
|
952 |
+
"""
|
953 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
954 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
955 |
+
|
956 |
+
<Tip warning={true}>
|
957 |
+
|
958 |
+
This API is 🧪 experimental.
|
959 |
+
|
960 |
+
</Tip>
|
961 |
+
"""
|
962 |
+
self.original_attn_processors = None
|
963 |
+
|
964 |
+
for _, attn_processor in self.attn_processors.items():
|
965 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
966 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
967 |
+
|
968 |
+
self.original_attn_processors = self.attn_processors
|
969 |
+
|
970 |
+
for module in self.modules():
|
971 |
+
if isinstance(module, Attention):
|
972 |
+
module.fuse_projections(fuse=True)
|
973 |
+
|
974 |
+
def unfuse_qkv_projections(self):
|
975 |
+
"""Disables the fused QKV projection if enabled.
|
976 |
+
|
977 |
+
<Tip warning={true}>
|
978 |
+
|
979 |
+
This API is 🧪 experimental.
|
980 |
+
|
981 |
+
</Tip>
|
982 |
+
|
983 |
+
"""
|
984 |
+
if self.original_attn_processors is not None:
|
985 |
+
self.set_attn_processor(self.original_attn_processors)
|
986 |
+
|
987 |
+
def unload_lora(self):
|
988 |
+
"""Unloads LoRA weights."""
|
989 |
+
deprecate(
|
990 |
+
"unload_lora",
|
991 |
+
"0.28.0",
|
992 |
+
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
993 |
+
)
|
994 |
+
for module in self.modules():
|
995 |
+
if hasattr(module, "set_lora_layer"):
|
996 |
+
module.set_lora_layer(None)
|
997 |
+
|
998 |
+
def get_time_embed(
|
999 |
+
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
|
1000 |
+
) -> Optional[torch.Tensor]:
|
1001 |
+
timesteps = timestep
|
1002 |
+
if not torch.is_tensor(timesteps):
|
1003 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
1004 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
1005 |
+
is_mps = sample.device.type == "mps"
|
1006 |
+
if isinstance(timestep, float):
|
1007 |
+
dtype = torch.float32 if is_mps else torch.float64
|
1008 |
+
else:
|
1009 |
+
dtype = torch.int32 if is_mps else torch.int64
|
1010 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
1011 |
+
elif len(timesteps.shape) == 0:
|
1012 |
+
timesteps = timesteps[None].to(sample.device)
|
1013 |
+
|
1014 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1015 |
+
timesteps = timesteps.expand(sample.shape[0])
|
1016 |
+
|
1017 |
+
t_emb = self.time_proj(timesteps)
|
1018 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
1019 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1020 |
+
# there might be better ways to encapsulate this.
|
1021 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
1022 |
+
return t_emb
|
1023 |
+
|
1024 |
+
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
1025 |
+
class_emb = None
|
1026 |
+
if self.class_embedding is not None:
|
1027 |
+
if class_labels is None:
|
1028 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
1029 |
+
|
1030 |
+
if self.config.class_embed_type == "timestep":
|
1031 |
+
class_labels = self.time_proj(class_labels)
|
1032 |
+
|
1033 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
1034 |
+
# there might be better ways to encapsulate this.
|
1035 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
1036 |
+
|
1037 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
1038 |
+
return class_emb
|
1039 |
+
|
1040 |
+
def get_aug_embed(
|
1041 |
+
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
1042 |
+
) -> Optional[torch.Tensor]:
|
1043 |
+
aug_emb = None
|
1044 |
+
if self.config.addition_embed_type == "text":
|
1045 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
1046 |
+
elif self.config.addition_embed_type == "text_image":
|
1047 |
+
# Kandinsky 2.1 - style
|
1048 |
+
if "image_embeds" not in added_cond_kwargs:
|
1049 |
+
raise ValueError(
|
1050 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1051 |
+
)
|
1052 |
+
|
1053 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1054 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
1055 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
1056 |
+
elif self.config.addition_embed_type == "text_time":
|
1057 |
+
# SDXL - style
|
1058 |
+
if "text_embeds" not in added_cond_kwargs:
|
1059 |
+
raise ValueError(
|
1060 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
1061 |
+
)
|
1062 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
1063 |
+
if "time_ids" not in added_cond_kwargs:
|
1064 |
+
raise ValueError(
|
1065 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
1066 |
+
)
|
1067 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
1068 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
1069 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
1070 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
1071 |
+
add_embeds = add_embeds.to(emb.dtype)
|
1072 |
+
aug_emb = self.add_embedding(add_embeds)
|
1073 |
+
elif self.config.addition_embed_type == "image":
|
1074 |
+
# Kandinsky 2.2 - style
|
1075 |
+
if "image_embeds" not in added_cond_kwargs:
|
1076 |
+
raise ValueError(
|
1077 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1078 |
+
)
|
1079 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1080 |
+
aug_emb = self.add_embedding(image_embs)
|
1081 |
+
elif self.config.addition_embed_type == "image_hint":
|
1082 |
+
# Kandinsky 2.2 - style
|
1083 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
1084 |
+
raise ValueError(
|
1085 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1086 |
+
)
|
1087 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
1088 |
+
hint = added_cond_kwargs.get("hint")
|
1089 |
+
aug_emb = self.add_embedding(image_embs, hint)
|
1090 |
+
return aug_emb
|
1091 |
+
|
1092 |
+
def process_encoder_hidden_states(
|
1093 |
+
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
1094 |
+
) -> torch.Tensor:
|
1095 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
1096 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
1097 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
1098 |
+
# Kandinsky 2.1 - style
|
1099 |
+
if "image_embeds" not in added_cond_kwargs:
|
1100 |
+
raise ValueError(
|
1101 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1102 |
+
)
|
1103 |
+
|
1104 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1105 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
1106 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
1107 |
+
# Kandinsky 2.2 - style
|
1108 |
+
if "image_embeds" not in added_cond_kwargs:
|
1109 |
+
raise ValueError(
|
1110 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1111 |
+
)
|
1112 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1113 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1114 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
1115 |
+
if "image_embeds" not in added_cond_kwargs:
|
1116 |
+
raise ValueError(
|
1117 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1118 |
+
)
|
1119 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1120 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
1121 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
1122 |
+
return encoder_hidden_states
|
1123 |
+
|
1124 |
+
def forward(
|
1125 |
+
self,
|
1126 |
+
sample: torch.FloatTensor,
|
1127 |
+
timestep: Union[torch.Tensor, float, int],
|
1128 |
+
encoder_hidden_states: torch.Tensor,
|
1129 |
+
class_labels: Optional[torch.Tensor] = None,
|
1130 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
1131 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1132 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1133 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
1134 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1135 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1136 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1137 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
1138 |
+
return_dict: bool = True,
|
1139 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
1140 |
+
r"""
|
1141 |
+
The [`UNet2DConditionModel`] forward method.
|
1142 |
+
|
1143 |
+
Args:
|
1144 |
+
sample (`torch.FloatTensor`):
|
1145 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
1146 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
1147 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
1148 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
1149 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
1150 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
1151 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
1152 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
1153 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
1154 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
1155 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
1156 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
1157 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
1158 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1159 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1160 |
+
`self.processor` in
|
1161 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1162 |
+
added_cond_kwargs: (`dict`, *optional*):
|
1163 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
1164 |
+
are passed along to the UNet blocks.
|
1165 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
1166 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
1167 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
1168 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
1169 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
1170 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
1171 |
+
encoder_attention_mask (`torch.Tensor`):
|
1172 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
1173 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
1174 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
1175 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1176 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
1177 |
+
tuple.
|
1178 |
+
|
1179 |
+
Returns:
|
1180 |
+
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
1181 |
+
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
|
1182 |
+
otherwise a `tuple` is returned where the first element is the sample tensor.
|
1183 |
+
"""
|
1184 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1185 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
1186 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1187 |
+
# on the fly if necessary.
|
1188 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1189 |
+
|
1190 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1191 |
+
forward_upsample_size = False
|
1192 |
+
upsample_size = None
|
1193 |
+
|
1194 |
+
for dim in sample.shape[-2:]:
|
1195 |
+
if dim % default_overall_up_factor != 0:
|
1196 |
+
# Forward upsample size to force interpolation output size.
|
1197 |
+
forward_upsample_size = True
|
1198 |
+
break
|
1199 |
+
|
1200 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
1201 |
+
# expects mask of shape:
|
1202 |
+
# [batch, key_tokens]
|
1203 |
+
# adds singleton query_tokens dimension:
|
1204 |
+
# [batch, 1, key_tokens]
|
1205 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
1206 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
1207 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
1208 |
+
if attention_mask is not None:
|
1209 |
+
# assume that mask is expressed as:
|
1210 |
+
# (1 = keep, 0 = discard)
|
1211 |
+
# convert mask into a bias that can be added to attention scores:
|
1212 |
+
# (keep = +0, discard = -10000.0)
|
1213 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
1214 |
+
attention_mask = attention_mask.unsqueeze(1)
|
1215 |
+
|
1216 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
1217 |
+
if encoder_attention_mask is not None:
|
1218 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
1219 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
1220 |
+
|
1221 |
+
# 0. center input if necessary
|
1222 |
+
if self.config.center_input_sample:
|
1223 |
+
sample = 2 * sample - 1.0
|
1224 |
+
|
1225 |
+
# 1. time
|
1226 |
+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
1227 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
1228 |
+
aug_emb = None
|
1229 |
+
|
1230 |
+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
1231 |
+
if class_emb is not None:
|
1232 |
+
if self.config.class_embeddings_concat:
|
1233 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
1234 |
+
else:
|
1235 |
+
emb = emb + class_emb
|
1236 |
+
|
1237 |
+
aug_emb = self.get_aug_embed(
|
1238 |
+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1239 |
+
)
|
1240 |
+
if self.config.addition_embed_type == "image_hint":
|
1241 |
+
aug_emb, hint = aug_emb
|
1242 |
+
sample = torch.cat([sample, hint], dim=1)
|
1243 |
+
|
1244 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
1245 |
+
|
1246 |
+
if self.time_embed_act is not None:
|
1247 |
+
emb = self.time_embed_act(emb)
|
1248 |
+
|
1249 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
1250 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
# 2. pre-process
|
1254 |
+
sample = self.conv_in(sample)
|
1255 |
+
|
1256 |
+
# 2.5 GLIGEN position net
|
1257 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
1258 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1259 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
1260 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
1261 |
+
|
1262 |
+
# 3. down
|
1263 |
+
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
1264 |
+
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
1265 |
+
if cross_attention_kwargs is not None:
|
1266 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
1267 |
+
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
1268 |
+
else:
|
1269 |
+
lora_scale = 1.0
|
1270 |
+
|
1271 |
+
if USE_PEFT_BACKEND:
|
1272 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1273 |
+
scale_lora_layers(self, lora_scale)
|
1274 |
+
|
1275 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
1276 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
1277 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
1278 |
+
# maintain backward compatibility for legacy usage, where
|
1279 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
1280 |
+
# but can only use one or the other
|
1281 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
1282 |
+
deprecate(
|
1283 |
+
"T2I should not use down_block_additional_residuals",
|
1284 |
+
"1.3.0",
|
1285 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1286 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1287 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
1288 |
+
standard_warn=False,
|
1289 |
+
)
|
1290 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
1291 |
+
is_adapter = True
|
1292 |
+
|
1293 |
+
down_block_res_samples = (sample,)
|
1294 |
+
for downsample_block in self.down_blocks:
|
1295 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
1296 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
1297 |
+
additional_residuals = {}
|
1298 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1299 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
1300 |
+
|
1301 |
+
sample, res_samples = downsample_block(
|
1302 |
+
hidden_states=sample,
|
1303 |
+
temb=emb,
|
1304 |
+
encoder_hidden_states=encoder_hidden_states,
|
1305 |
+
attention_mask=attention_mask,
|
1306 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1307 |
+
encoder_attention_mask=encoder_attention_mask,
|
1308 |
+
**additional_residuals,
|
1309 |
+
)
|
1310 |
+
else:
|
1311 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1312 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1313 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1314 |
+
|
1315 |
+
down_block_res_samples += res_samples
|
1316 |
+
|
1317 |
+
if is_controlnet:
|
1318 |
+
new_down_block_res_samples = ()
|
1319 |
+
|
1320 |
+
for down_block_additional_residual, down_block_res_sample in zip(
|
1321 |
+
down_block_additional_residuals, down_block_res_samples
|
1322 |
+
):
|
1323 |
+
down_block_res_sample_tuple = (down_block_additional_residual, down_block_res_sample)
|
1324 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample_tuple,)
|
1325 |
+
|
1326 |
+
down_block_res_samples = new_down_block_res_samples
|
1327 |
+
|
1328 |
+
# 4. mid
|
1329 |
+
if self.mid_block is not None:
|
1330 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
1331 |
+
sample = self.mid_block(
|
1332 |
+
sample,
|
1333 |
+
emb,
|
1334 |
+
encoder_hidden_states=encoder_hidden_states,
|
1335 |
+
attention_mask=attention_mask,
|
1336 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1337 |
+
encoder_attention_mask=encoder_attention_mask,
|
1338 |
+
)
|
1339 |
+
else:
|
1340 |
+
sample = self.mid_block(sample, emb)
|
1341 |
+
|
1342 |
+
# To support T2I-Adapter-XL
|
1343 |
+
if (
|
1344 |
+
is_adapter
|
1345 |
+
and len(down_intrablock_additional_residuals) > 0
|
1346 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
1347 |
+
):
|
1348 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
1349 |
+
|
1350 |
+
if is_controlnet:
|
1351 |
+
sample = self.mid_zero_SFT((mid_block_additional_residual, sample),)
|
1352 |
+
|
1353 |
+
# 5. up
|
1354 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1355 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1356 |
+
|
1357 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1358 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
1359 |
+
|
1360 |
+
# if we have not reached the final block and need to forward the
|
1361 |
+
# upsample size, we do it here
|
1362 |
+
if not is_final_block and forward_upsample_size:
|
1363 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1364 |
+
|
1365 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
1366 |
+
sample = upsample_block(
|
1367 |
+
hidden_states=sample,
|
1368 |
+
temb=emb,
|
1369 |
+
res_hidden_states_tuple=res_samples,
|
1370 |
+
encoder_hidden_states=encoder_hidden_states,
|
1371 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1372 |
+
upsample_size=upsample_size,
|
1373 |
+
attention_mask=attention_mask,
|
1374 |
+
encoder_attention_mask=encoder_attention_mask,
|
1375 |
+
)
|
1376 |
+
else:
|
1377 |
+
sample = upsample_block(
|
1378 |
+
hidden_states=sample,
|
1379 |
+
temb=emb,
|
1380 |
+
res_hidden_states_tuple=res_samples,
|
1381 |
+
upsample_size=upsample_size,
|
1382 |
+
)
|
1383 |
+
|
1384 |
+
# 6. post-process
|
1385 |
+
if self.conv_norm_out:
|
1386 |
+
sample = self.conv_norm_out(sample)
|
1387 |
+
sample = self.conv_act(sample)
|
1388 |
+
sample = self.conv_out(sample)
|
1389 |
+
|
1390 |
+
if USE_PEFT_BACKEND:
|
1391 |
+
# remove `lora_scale` from each PEFT layer
|
1392 |
+
unscale_lora_layers(self, lora_scale)
|
1393 |
+
|
1394 |
+
if not return_dict:
|
1395 |
+
return (sample,)
|
1396 |
+
|
1397 |
+
return UNet2DConditionOutput(sample=sample)
|
module/unet/unet_2d_ZeroSFT_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pipelines/sdxl_instantir.py
ADDED
@@ -0,0 +1,1740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import inspect
|
17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import PIL.Image
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from transformers import (
|
24 |
+
CLIPImageProcessor,
|
25 |
+
CLIPTextModel,
|
26 |
+
CLIPTextModelWithProjection,
|
27 |
+
CLIPTokenizer,
|
28 |
+
CLIPVisionModelWithProjection,
|
29 |
+
)
|
30 |
+
|
31 |
+
from diffusers.utils.import_utils import is_invisible_watermark_available
|
32 |
+
|
33 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
34 |
+
from diffusers.loaders import (
|
35 |
+
FromSingleFileMixin,
|
36 |
+
IPAdapterMixin,
|
37 |
+
StableDiffusionXLLoraLoaderMixin,
|
38 |
+
TextualInversionLoaderMixin,
|
39 |
+
)
|
40 |
+
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
41 |
+
from diffusers.models.attention_processor import (
|
42 |
+
AttnProcessor2_0,
|
43 |
+
LoRAAttnProcessor2_0,
|
44 |
+
LoRAXFormersAttnProcessor,
|
45 |
+
XFormersAttnProcessor,
|
46 |
+
)
|
47 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
48 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers, LCMScheduler
|
49 |
+
from diffusers.utils import (
|
50 |
+
USE_PEFT_BACKEND,
|
51 |
+
deprecate,
|
52 |
+
logging,
|
53 |
+
replace_example_docstring,
|
54 |
+
scale_lora_layers,
|
55 |
+
unscale_lora_layers,
|
56 |
+
convert_unet_state_dict_to_peft
|
57 |
+
)
|
58 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
59 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
60 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
61 |
+
|
62 |
+
|
63 |
+
if is_invisible_watermark_available():
|
64 |
+
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
65 |
+
|
66 |
+
from peft import LoraConfig, set_peft_model_state_dict
|
67 |
+
from module.aggregator import Aggregator
|
68 |
+
|
69 |
+
|
70 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
71 |
+
|
72 |
+
|
73 |
+
EXAMPLE_DOC_STRING = """
|
74 |
+
Examples:
|
75 |
+
```py
|
76 |
+
>>> # !pip install diffusers pillow transformers accelerate
|
77 |
+
>>> import torch
|
78 |
+
>>> from PIL import Image
|
79 |
+
|
80 |
+
>>> from diffusers import DDPMScheduler
|
81 |
+
>>> from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
82 |
+
|
83 |
+
>>> from module.ip_adapter.utils import load_adapter_to_pipe
|
84 |
+
>>> from pipelines.sdxl_instantir import InstantIRPipeline
|
85 |
+
|
86 |
+
>>> # download models under ./models
|
87 |
+
>>> dcp_adapter = f'./models/adapter.pt'
|
88 |
+
>>> previewer_lora_path = f'./models'
|
89 |
+
>>> instantir_path = f'./models/aggregator.pt'
|
90 |
+
|
91 |
+
>>> # load pretrained models
|
92 |
+
>>> pipe = InstantIRPipeline.from_pretrained(
|
93 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
|
94 |
+
... )
|
95 |
+
>>> # load adapter
|
96 |
+
>>> load_adapter_to_pipe(
|
97 |
+
... pipe,
|
98 |
+
... dcp_adapter,
|
99 |
+
... image_encoder_or_path = 'facebook/dinov2-large',
|
100 |
+
... )
|
101 |
+
>>> # load previewer lora
|
102 |
+
>>> pipe.prepare_previewers(previewer_lora_path)
|
103 |
+
>>> pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
|
104 |
+
>>> lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
|
105 |
+
|
106 |
+
>>> # load aggregator weights
|
107 |
+
>>> pretrained_state_dict = torch.load(instantir_path)
|
108 |
+
>>> pipe.aggregator.load_state_dict(pretrained_state_dict)
|
109 |
+
|
110 |
+
>>> # send to GPU and fp16
|
111 |
+
>>> pipe.to(device="cuda", dtype=torch.float16)
|
112 |
+
>>> pipe.aggregator.to(device="cuda", dtype=torch.float16)
|
113 |
+
>>> pipe.enable_model_cpu_offload()
|
114 |
+
|
115 |
+
>>> # load a broken image
|
116 |
+
>>> low_quality_image = Image.open('path/to/your-image').convert("RGB")
|
117 |
+
|
118 |
+
>>> # restoration
|
119 |
+
>>> image = pipe(
|
120 |
+
... image=low_quality_image,
|
121 |
+
... previewer_scheduler=lcm_scheduler,
|
122 |
+
... ).images[0]
|
123 |
+
```
|
124 |
+
"""
|
125 |
+
|
126 |
+
LCM_LORA_MODULES = [
|
127 |
+
"to_q",
|
128 |
+
"to_k",
|
129 |
+
"to_v",
|
130 |
+
"to_out.0",
|
131 |
+
"proj_in",
|
132 |
+
"proj_out",
|
133 |
+
"ff.net.0.proj",
|
134 |
+
"ff.net.2",
|
135 |
+
"conv1",
|
136 |
+
"conv2",
|
137 |
+
"conv_shortcut",
|
138 |
+
"downsamplers.0.conv",
|
139 |
+
"upsamplers.0.conv",
|
140 |
+
"time_emb_proj",
|
141 |
+
]
|
142 |
+
PREVIEWER_LORA_MODULES = [
|
143 |
+
"to_q",
|
144 |
+
"to_kv",
|
145 |
+
"0.to_out",
|
146 |
+
"attn1.to_k",
|
147 |
+
"attn1.to_v",
|
148 |
+
"to_k_ip",
|
149 |
+
"to_v_ip",
|
150 |
+
"ln_k_ip.linear",
|
151 |
+
"ln_v_ip.linear",
|
152 |
+
"to_out.0",
|
153 |
+
"proj_in",
|
154 |
+
"proj_out",
|
155 |
+
"ff.net.0.proj",
|
156 |
+
"ff.net.2",
|
157 |
+
"conv1",
|
158 |
+
"conv2",
|
159 |
+
"conv_shortcut",
|
160 |
+
"downsamplers.0.conv",
|
161 |
+
"upsamplers.0.conv",
|
162 |
+
"time_emb_proj",
|
163 |
+
]
|
164 |
+
|
165 |
+
|
166 |
+
def remove_attn2(model):
|
167 |
+
def recursive_find_module(name, module):
|
168 |
+
if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
|
169 |
+
elif "resnets" in name: return
|
170 |
+
if hasattr(module, "attn2"):
|
171 |
+
setattr(module, "attn2", None)
|
172 |
+
setattr(module, "norm2", None)
|
173 |
+
return
|
174 |
+
for sub_name, sub_module in module.named_children():
|
175 |
+
recursive_find_module(f"{name}.{sub_name}", sub_module)
|
176 |
+
|
177 |
+
for name, module in model.named_children():
|
178 |
+
recursive_find_module(name, module)
|
179 |
+
|
180 |
+
|
181 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
182 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
183 |
+
"""
|
184 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
185 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
186 |
+
"""
|
187 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
188 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
189 |
+
# rescale the results from guidance (fixes overexposure)
|
190 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
191 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
192 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
193 |
+
return noise_cfg
|
194 |
+
|
195 |
+
|
196 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
197 |
+
def retrieve_timesteps(
|
198 |
+
scheduler,
|
199 |
+
num_inference_steps: Optional[int] = None,
|
200 |
+
device: Optional[Union[str, torch.device]] = None,
|
201 |
+
timesteps: Optional[List[int]] = None,
|
202 |
+
**kwargs,
|
203 |
+
):
|
204 |
+
"""
|
205 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
206 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
scheduler (`SchedulerMixin`):
|
210 |
+
The scheduler to get timesteps from.
|
211 |
+
num_inference_steps (`int`):
|
212 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
213 |
+
must be `None`.
|
214 |
+
device (`str` or `torch.device`, *optional*):
|
215 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
216 |
+
timesteps (`List[int]`, *optional*):
|
217 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
218 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
219 |
+
must be `None`.
|
220 |
+
|
221 |
+
Returns:
|
222 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
223 |
+
second element is the number of inference steps.
|
224 |
+
"""
|
225 |
+
if timesteps is not None:
|
226 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
227 |
+
if not accepts_timesteps:
|
228 |
+
raise ValueError(
|
229 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
230 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
231 |
+
)
|
232 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
233 |
+
timesteps = scheduler.timesteps
|
234 |
+
num_inference_steps = len(timesteps)
|
235 |
+
else:
|
236 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
237 |
+
timesteps = scheduler.timesteps
|
238 |
+
return timesteps, num_inference_steps
|
239 |
+
|
240 |
+
|
241 |
+
class InstantIRPipeline(
|
242 |
+
DiffusionPipeline,
|
243 |
+
StableDiffusionMixin,
|
244 |
+
TextualInversionLoaderMixin,
|
245 |
+
StableDiffusionXLLoraLoaderMixin,
|
246 |
+
IPAdapterMixin,
|
247 |
+
FromSingleFileMixin,
|
248 |
+
):
|
249 |
+
r"""
|
250 |
+
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
|
251 |
+
|
252 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
253 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
254 |
+
|
255 |
+
The pipeline also inherits the following loading methods:
|
256 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
257 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
258 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
259 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
260 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
261 |
+
|
262 |
+
Args:
|
263 |
+
vae ([`AutoencoderKL`]):
|
264 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
265 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
266 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
267 |
+
text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
|
268 |
+
Second frozen text-encoder
|
269 |
+
([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
|
270 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
271 |
+
A `CLIPTokenizer` to tokenize text.
|
272 |
+
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
|
273 |
+
A `CLIPTokenizer` to tokenize text.
|
274 |
+
unet ([`UNet2DConditionModel`]):
|
275 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
276 |
+
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
|
277 |
+
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
|
278 |
+
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
|
279 |
+
additional conditioning.
|
280 |
+
scheduler ([`SchedulerMixin`]):
|
281 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
282 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
283 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
284 |
+
Whether the negative prompt embeddings should always be set to 0. Also see the config of
|
285 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
286 |
+
add_watermarker (`bool`, *optional*):
|
287 |
+
Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
|
288 |
+
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
|
289 |
+
watermarker is used.
|
290 |
+
"""
|
291 |
+
|
292 |
+
# leave controlnet out on purpose because it iterates with unet
|
293 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
|
294 |
+
_optional_components = [
|
295 |
+
"tokenizer",
|
296 |
+
"tokenizer_2",
|
297 |
+
"text_encoder",
|
298 |
+
"text_encoder_2",
|
299 |
+
"feature_extractor",
|
300 |
+
"image_encoder",
|
301 |
+
]
|
302 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
303 |
+
|
304 |
+
def __init__(
|
305 |
+
self,
|
306 |
+
vae: AutoencoderKL,
|
307 |
+
text_encoder: CLIPTextModel,
|
308 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
309 |
+
tokenizer: CLIPTokenizer,
|
310 |
+
tokenizer_2: CLIPTokenizer,
|
311 |
+
unet: UNet2DConditionModel,
|
312 |
+
scheduler: KarrasDiffusionSchedulers,
|
313 |
+
aggregator: Aggregator = None,
|
314 |
+
force_zeros_for_empty_prompt: bool = True,
|
315 |
+
add_watermarker: Optional[bool] = None,
|
316 |
+
feature_extractor: CLIPImageProcessor = None,
|
317 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
|
321 |
+
if aggregator is None:
|
322 |
+
aggregator = Aggregator.from_unet(unet)
|
323 |
+
remove_attn2(aggregator)
|
324 |
+
|
325 |
+
self.register_modules(
|
326 |
+
vae=vae,
|
327 |
+
text_encoder=text_encoder,
|
328 |
+
text_encoder_2=text_encoder_2,
|
329 |
+
tokenizer=tokenizer,
|
330 |
+
tokenizer_2=tokenizer_2,
|
331 |
+
unet=unet,
|
332 |
+
aggregator=aggregator,
|
333 |
+
scheduler=scheduler,
|
334 |
+
feature_extractor=feature_extractor,
|
335 |
+
image_encoder=image_encoder,
|
336 |
+
)
|
337 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
338 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
339 |
+
self.control_image_processor = VaeImageProcessor(
|
340 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True
|
341 |
+
)
|
342 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
343 |
+
|
344 |
+
if add_watermarker:
|
345 |
+
self.watermark = StableDiffusionXLWatermarker()
|
346 |
+
else:
|
347 |
+
self.watermark = None
|
348 |
+
|
349 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
350 |
+
|
351 |
+
def prepare_previewers(self, previewer_lora_path: str, use_lcm=False):
|
352 |
+
if use_lcm:
|
353 |
+
lora_state_dict, alpha_dict = self.lora_state_dict(
|
354 |
+
previewer_lora_path,
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
lora_state_dict, alpha_dict = self.lora_state_dict(
|
358 |
+
previewer_lora_path,
|
359 |
+
weight_name="previewer_lora_weights.bin"
|
360 |
+
)
|
361 |
+
unet_state_dict = {
|
362 |
+
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
363 |
+
}
|
364 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
365 |
+
lora_state_dict = dict()
|
366 |
+
for k, v in unet_state_dict.items():
|
367 |
+
if "ip" in k:
|
368 |
+
k = k.replace("attn2", "attn2.processor")
|
369 |
+
lora_state_dict[k] = v
|
370 |
+
else:
|
371 |
+
lora_state_dict[k] = v
|
372 |
+
if alpha_dict:
|
373 |
+
lora_alpha = next(iter(alpha_dict.values()))
|
374 |
+
else:
|
375 |
+
lora_alpha = 1
|
376 |
+
logger.info(f"use lora alpha {lora_alpha}")
|
377 |
+
lora_config = LoraConfig(
|
378 |
+
r=64,
|
379 |
+
target_modules=LCM_LORA_MODULES if use_lcm else PREVIEWER_LORA_MODULES,
|
380 |
+
lora_alpha=lora_alpha,
|
381 |
+
lora_dropout=0.0,
|
382 |
+
)
|
383 |
+
|
384 |
+
adapter_name = "lcm" if use_lcm else "previewer"
|
385 |
+
self.unet.add_adapter(lora_config, adapter_name)
|
386 |
+
incompatible_keys = set_peft_model_state_dict(self.unet, lora_state_dict, adapter_name=adapter_name)
|
387 |
+
if incompatible_keys is not None:
|
388 |
+
# check only for unexpected keys
|
389 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
390 |
+
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
391 |
+
if unexpected_keys:
|
392 |
+
raise ValueError(
|
393 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
394 |
+
f" {unexpected_keys}. "
|
395 |
+
)
|
396 |
+
self.unet.disable_adapters()
|
397 |
+
|
398 |
+
return lora_alpha
|
399 |
+
|
400 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
401 |
+
def encode_prompt(
|
402 |
+
self,
|
403 |
+
prompt: str,
|
404 |
+
prompt_2: Optional[str] = None,
|
405 |
+
device: Optional[torch.device] = None,
|
406 |
+
num_images_per_prompt: int = 1,
|
407 |
+
do_classifier_free_guidance: bool = True,
|
408 |
+
negative_prompt: Optional[str] = None,
|
409 |
+
negative_prompt_2: Optional[str] = None,
|
410 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
411 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
412 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
413 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
414 |
+
lora_scale: Optional[float] = None,
|
415 |
+
clip_skip: Optional[int] = None,
|
416 |
+
):
|
417 |
+
r"""
|
418 |
+
Encodes the prompt into text encoder hidden states.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
prompt (`str` or `List[str]`, *optional*):
|
422 |
+
prompt to be encoded
|
423 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
424 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
425 |
+
used in both text-encoders
|
426 |
+
device: (`torch.device`):
|
427 |
+
torch device
|
428 |
+
num_images_per_prompt (`int`):
|
429 |
+
number of images that should be generated per prompt
|
430 |
+
do_classifier_free_guidance (`bool`):
|
431 |
+
whether to use classifier free guidance or not
|
432 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
433 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
434 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
435 |
+
less than `1`).
|
436 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
437 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
438 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
439 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
440 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
441 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
442 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
443 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
444 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
445 |
+
argument.
|
446 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
447 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
448 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
449 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
450 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
451 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
452 |
+
input argument.
|
453 |
+
lora_scale (`float`, *optional*):
|
454 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
455 |
+
clip_skip (`int`, *optional*):
|
456 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
457 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
458 |
+
"""
|
459 |
+
device = device or self._execution_device
|
460 |
+
|
461 |
+
# set lora scale so that monkey patched LoRA
|
462 |
+
# function of text encoder can correctly access it
|
463 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
464 |
+
self._lora_scale = lora_scale
|
465 |
+
|
466 |
+
# dynamically adjust the LoRA scale
|
467 |
+
if self.text_encoder is not None:
|
468 |
+
if not USE_PEFT_BACKEND:
|
469 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
470 |
+
else:
|
471 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
472 |
+
|
473 |
+
if self.text_encoder_2 is not None:
|
474 |
+
if not USE_PEFT_BACKEND:
|
475 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
476 |
+
else:
|
477 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
478 |
+
|
479 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
480 |
+
|
481 |
+
if prompt is not None:
|
482 |
+
batch_size = len(prompt)
|
483 |
+
else:
|
484 |
+
batch_size = prompt_embeds.shape[0]
|
485 |
+
|
486 |
+
# Define tokenizers and text encoders
|
487 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
488 |
+
text_encoders = (
|
489 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
490 |
+
)
|
491 |
+
|
492 |
+
if prompt_embeds is None:
|
493 |
+
prompt_2 = prompt_2 or prompt
|
494 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
495 |
+
|
496 |
+
# textual inversion: process multi-vector tokens if necessary
|
497 |
+
prompt_embeds_list = []
|
498 |
+
prompts = [prompt, prompt_2]
|
499 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
500 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
501 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
502 |
+
|
503 |
+
text_inputs = tokenizer(
|
504 |
+
prompt,
|
505 |
+
padding="max_length",
|
506 |
+
max_length=tokenizer.model_max_length,
|
507 |
+
truncation=True,
|
508 |
+
return_tensors="pt",
|
509 |
+
)
|
510 |
+
|
511 |
+
text_input_ids = text_inputs.input_ids
|
512 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
513 |
+
|
514 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
515 |
+
text_input_ids, untruncated_ids
|
516 |
+
):
|
517 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
518 |
+
logger.warning(
|
519 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
520 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
521 |
+
)
|
522 |
+
|
523 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
524 |
+
|
525 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
526 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
527 |
+
if clip_skip is None:
|
528 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
529 |
+
else:
|
530 |
+
# "2" because SDXL always indexes from the penultimate layer.
|
531 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
532 |
+
|
533 |
+
prompt_embeds_list.append(prompt_embeds)
|
534 |
+
|
535 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
536 |
+
|
537 |
+
# get unconditional embeddings for classifier free guidance
|
538 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
539 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
540 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
541 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
542 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
543 |
+
negative_prompt = negative_prompt or ""
|
544 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
545 |
+
|
546 |
+
# normalize str to list
|
547 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
548 |
+
negative_prompt_2 = (
|
549 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
550 |
+
)
|
551 |
+
|
552 |
+
uncond_tokens: List[str]
|
553 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
554 |
+
raise TypeError(
|
555 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
556 |
+
f" {type(prompt)}."
|
557 |
+
)
|
558 |
+
elif batch_size != len(negative_prompt):
|
559 |
+
raise ValueError(
|
560 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
561 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
562 |
+
" the batch size of `prompt`."
|
563 |
+
)
|
564 |
+
else:
|
565 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
566 |
+
|
567 |
+
negative_prompt_embeds_list = []
|
568 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
569 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
570 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
571 |
+
|
572 |
+
max_length = prompt_embeds.shape[1]
|
573 |
+
uncond_input = tokenizer(
|
574 |
+
negative_prompt,
|
575 |
+
padding="max_length",
|
576 |
+
max_length=max_length,
|
577 |
+
truncation=True,
|
578 |
+
return_tensors="pt",
|
579 |
+
)
|
580 |
+
|
581 |
+
negative_prompt_embeds = text_encoder(
|
582 |
+
uncond_input.input_ids.to(device),
|
583 |
+
output_hidden_states=True,
|
584 |
+
)
|
585 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
586 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
587 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
588 |
+
|
589 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
590 |
+
|
591 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
592 |
+
|
593 |
+
if self.text_encoder_2 is not None:
|
594 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
595 |
+
else:
|
596 |
+
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
597 |
+
|
598 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
599 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
600 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
601 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
602 |
+
|
603 |
+
if do_classifier_free_guidance:
|
604 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
605 |
+
seq_len = negative_prompt_embeds.shape[1]
|
606 |
+
|
607 |
+
if self.text_encoder_2 is not None:
|
608 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
609 |
+
else:
|
610 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
611 |
+
|
612 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
613 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
614 |
+
|
615 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
616 |
+
bs_embed * num_images_per_prompt, -1
|
617 |
+
)
|
618 |
+
if do_classifier_free_guidance:
|
619 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
620 |
+
bs_embed * num_images_per_prompt, -1
|
621 |
+
)
|
622 |
+
|
623 |
+
if self.text_encoder is not None:
|
624 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
625 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
626 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
627 |
+
|
628 |
+
if self.text_encoder_2 is not None:
|
629 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
630 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
631 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
632 |
+
|
633 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
634 |
+
|
635 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
636 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
637 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
638 |
+
|
639 |
+
if not isinstance(image, torch.Tensor):
|
640 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
641 |
+
|
642 |
+
image = image.to(device=device, dtype=dtype)
|
643 |
+
if output_hidden_states:
|
644 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
645 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
646 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
647 |
+
torch.zeros_like(image), output_hidden_states=True
|
648 |
+
).hidden_states[-2]
|
649 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
650 |
+
num_images_per_prompt, dim=0
|
651 |
+
)
|
652 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
653 |
+
else:
|
654 |
+
if isinstance(self.image_encoder, CLIPVisionModelWithProjection):
|
655 |
+
# CLIP image encoder.
|
656 |
+
image_embeds = self.image_encoder(image).image_embeds
|
657 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
658 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
659 |
+
else:
|
660 |
+
# DINO image encoder.
|
661 |
+
image_embeds = self.image_encoder(image).last_hidden_state
|
662 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
663 |
+
uncond_image_embeds = self.image_encoder(
|
664 |
+
torch.zeros_like(image)
|
665 |
+
).last_hidden_state
|
666 |
+
uncond_image_embeds = uncond_image_embeds.repeat_interleave(
|
667 |
+
num_images_per_prompt, dim=0
|
668 |
+
)
|
669 |
+
|
670 |
+
return image_embeds, uncond_image_embeds
|
671 |
+
|
672 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
673 |
+
def prepare_ip_adapter_image_embeds(
|
674 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
675 |
+
):
|
676 |
+
if ip_adapter_image_embeds is None:
|
677 |
+
if not isinstance(ip_adapter_image, list):
|
678 |
+
ip_adapter_image = [ip_adapter_image]
|
679 |
+
|
680 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
681 |
+
if isinstance(ip_adapter_image[0], list):
|
682 |
+
raise ValueError(
|
683 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
684 |
+
)
|
685 |
+
else:
|
686 |
+
logger.warning(
|
687 |
+
f"Got {len(ip_adapter_image)} images for {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
688 |
+
" By default, these images will be sent to each IP-Adapter. If this is not your use-case, please specify `ip_adapter_image` as a list of image-list, with"
|
689 |
+
f" length equals to the number of IP-Adapters."
|
690 |
+
)
|
691 |
+
ip_adapter_image = [ip_adapter_image] * len(self.unet.encoder_hid_proj.image_projection_layers)
|
692 |
+
|
693 |
+
image_embeds = []
|
694 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
695 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
696 |
+
):
|
697 |
+
output_hidden_state = isinstance(self.image_encoder, CLIPVisionModelWithProjection) and not isinstance(image_proj_layer, ImageProjection)
|
698 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
699 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
700 |
+
)
|
701 |
+
single_image_embeds = torch.stack([single_image_embeds] * (num_images_per_prompt//single_image_embeds.shape[0]), dim=0)
|
702 |
+
single_negative_image_embeds = torch.stack(
|
703 |
+
[single_negative_image_embeds] * (num_images_per_prompt//single_negative_image_embeds.shape[0]), dim=0
|
704 |
+
)
|
705 |
+
|
706 |
+
if do_classifier_free_guidance:
|
707 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
708 |
+
single_image_embeds = single_image_embeds.to(device)
|
709 |
+
|
710 |
+
image_embeds.append(single_image_embeds)
|
711 |
+
else:
|
712 |
+
repeat_dims = [1]
|
713 |
+
image_embeds = []
|
714 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
715 |
+
if do_classifier_free_guidance:
|
716 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
717 |
+
single_image_embeds = single_image_embeds.repeat(
|
718 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
719 |
+
)
|
720 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
721 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
722 |
+
)
|
723 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
724 |
+
else:
|
725 |
+
single_image_embeds = single_image_embeds.repeat(
|
726 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
727 |
+
)
|
728 |
+
image_embeds.append(single_image_embeds)
|
729 |
+
|
730 |
+
return image_embeds
|
731 |
+
|
732 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
733 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
734 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
735 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
736 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
737 |
+
# and should be between [0, 1]
|
738 |
+
|
739 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
740 |
+
extra_step_kwargs = {}
|
741 |
+
if accepts_eta:
|
742 |
+
extra_step_kwargs["eta"] = eta
|
743 |
+
|
744 |
+
# check if the scheduler accepts generator
|
745 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
746 |
+
if accepts_generator:
|
747 |
+
extra_step_kwargs["generator"] = generator
|
748 |
+
return extra_step_kwargs
|
749 |
+
|
750 |
+
def check_inputs(
|
751 |
+
self,
|
752 |
+
prompt,
|
753 |
+
prompt_2,
|
754 |
+
image,
|
755 |
+
callback_steps,
|
756 |
+
negative_prompt=None,
|
757 |
+
negative_prompt_2=None,
|
758 |
+
prompt_embeds=None,
|
759 |
+
negative_prompt_embeds=None,
|
760 |
+
pooled_prompt_embeds=None,
|
761 |
+
ip_adapter_image=None,
|
762 |
+
ip_adapter_image_embeds=None,
|
763 |
+
negative_pooled_prompt_embeds=None,
|
764 |
+
controlnet_conditioning_scale=1.0,
|
765 |
+
control_guidance_start=0.0,
|
766 |
+
control_guidance_end=1.0,
|
767 |
+
callback_on_step_end_tensor_inputs=None,
|
768 |
+
):
|
769 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
770 |
+
raise ValueError(
|
771 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
772 |
+
f" {type(callback_steps)}."
|
773 |
+
)
|
774 |
+
|
775 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
776 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
777 |
+
):
|
778 |
+
raise ValueError(
|
779 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
780 |
+
)
|
781 |
+
|
782 |
+
if prompt is not None and prompt_embeds is not None:
|
783 |
+
raise ValueError(
|
784 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
785 |
+
" only forward one of the two."
|
786 |
+
)
|
787 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
788 |
+
raise ValueError(
|
789 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
790 |
+
" only forward one of the two."
|
791 |
+
)
|
792 |
+
elif prompt is None and prompt_embeds is None:
|
793 |
+
raise ValueError(
|
794 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
795 |
+
)
|
796 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
797 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
798 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
799 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
800 |
+
|
801 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
802 |
+
raise ValueError(
|
803 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
804 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
805 |
+
)
|
806 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
807 |
+
raise ValueError(
|
808 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
809 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
810 |
+
)
|
811 |
+
|
812 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
813 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
814 |
+
raise ValueError(
|
815 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
816 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
817 |
+
f" {negative_prompt_embeds.shape}."
|
818 |
+
)
|
819 |
+
|
820 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
821 |
+
raise ValueError(
|
822 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
823 |
+
)
|
824 |
+
|
825 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
826 |
+
raise ValueError(
|
827 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
828 |
+
)
|
829 |
+
|
830 |
+
# Check `image`
|
831 |
+
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
832 |
+
self.aggregator, torch._dynamo.eval_frame.OptimizedModule
|
833 |
+
)
|
834 |
+
if (
|
835 |
+
isinstance(self.aggregator, Aggregator)
|
836 |
+
or is_compiled
|
837 |
+
and isinstance(self.aggregator._orig_mod, Aggregator)
|
838 |
+
):
|
839 |
+
self.check_image(image, prompt, prompt_embeds)
|
840 |
+
else:
|
841 |
+
assert False
|
842 |
+
|
843 |
+
if control_guidance_start >= control_guidance_end:
|
844 |
+
raise ValueError(
|
845 |
+
f"control guidance start: {control_guidance_start} cannot be larger or equal to control guidance end: {control_guidance_end}."
|
846 |
+
)
|
847 |
+
if control_guidance_start < 0.0:
|
848 |
+
raise ValueError(f"control guidance start: {control_guidance_start} can't be smaller than 0.")
|
849 |
+
if control_guidance_end > 1.0:
|
850 |
+
raise ValueError(f"control guidance end: {control_guidance_end} can't be larger than 1.0.")
|
851 |
+
|
852 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
853 |
+
raise ValueError(
|
854 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
855 |
+
)
|
856 |
+
|
857 |
+
if ip_adapter_image_embeds is not None:
|
858 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
859 |
+
raise ValueError(
|
860 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
861 |
+
)
|
862 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
863 |
+
raise ValueError(
|
864 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
865 |
+
)
|
866 |
+
|
867 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
|
868 |
+
def check_image(self, image, prompt, prompt_embeds):
|
869 |
+
image_is_pil = isinstance(image, PIL.Image.Image)
|
870 |
+
image_is_tensor = isinstance(image, torch.Tensor)
|
871 |
+
image_is_np = isinstance(image, np.ndarray)
|
872 |
+
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
873 |
+
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
874 |
+
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
875 |
+
|
876 |
+
if (
|
877 |
+
not image_is_pil
|
878 |
+
and not image_is_tensor
|
879 |
+
and not image_is_np
|
880 |
+
and not image_is_pil_list
|
881 |
+
and not image_is_tensor_list
|
882 |
+
and not image_is_np_list
|
883 |
+
):
|
884 |
+
raise TypeError(
|
885 |
+
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
886 |
+
)
|
887 |
+
|
888 |
+
if image_is_pil:
|
889 |
+
image_batch_size = 1
|
890 |
+
else:
|
891 |
+
image_batch_size = len(image)
|
892 |
+
|
893 |
+
if prompt is not None and isinstance(prompt, str):
|
894 |
+
prompt_batch_size = 1
|
895 |
+
elif prompt is not None and isinstance(prompt, list):
|
896 |
+
prompt_batch_size = len(prompt)
|
897 |
+
elif prompt_embeds is not None:
|
898 |
+
prompt_batch_size = prompt_embeds.shape[0]
|
899 |
+
|
900 |
+
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
901 |
+
raise ValueError(
|
902 |
+
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
903 |
+
)
|
904 |
+
|
905 |
+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
906 |
+
def prepare_image(
|
907 |
+
self,
|
908 |
+
image,
|
909 |
+
width,
|
910 |
+
height,
|
911 |
+
batch_size,
|
912 |
+
num_images_per_prompt,
|
913 |
+
device,
|
914 |
+
dtype,
|
915 |
+
do_classifier_free_guidance=False,
|
916 |
+
):
|
917 |
+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
918 |
+
image_batch_size = image.shape[0]
|
919 |
+
|
920 |
+
if image_batch_size == 1:
|
921 |
+
repeat_by = batch_size
|
922 |
+
else:
|
923 |
+
# image batch size is the same as prompt batch size
|
924 |
+
repeat_by = num_images_per_prompt
|
925 |
+
|
926 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
927 |
+
|
928 |
+
image = image.to(device=device, dtype=dtype)
|
929 |
+
|
930 |
+
return image
|
931 |
+
|
932 |
+
@torch.no_grad()
|
933 |
+
def init_latents(self, latents, generator, timestep):
|
934 |
+
noise = torch.randn(latents.shape, generator=generator, device=self.vae.device, dtype=self.vae.dtype, layout=torch.strided)
|
935 |
+
bsz = latents.shape[0]
|
936 |
+
print(f"init latent at {timestep}")
|
937 |
+
timestep = torch.tensor([timestep]*bsz, device=self.vae.device)
|
938 |
+
# Note that the latents will be scaled aleady by scheduler.add_noise
|
939 |
+
latents = self.scheduler.add_noise(latents, noise, timestep)
|
940 |
+
return latents
|
941 |
+
|
942 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
943 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
944 |
+
shape = (
|
945 |
+
batch_size,
|
946 |
+
num_channels_latents,
|
947 |
+
int(height) // self.vae_scale_factor,
|
948 |
+
int(width) // self.vae_scale_factor,
|
949 |
+
)
|
950 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
951 |
+
raise ValueError(
|
952 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
953 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
954 |
+
)
|
955 |
+
|
956 |
+
if latents is None:
|
957 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
958 |
+
else:
|
959 |
+
latents = latents.to(device)
|
960 |
+
|
961 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
962 |
+
latents = latents * self.scheduler.init_noise_sigma
|
963 |
+
return latents
|
964 |
+
|
965 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
966 |
+
def _get_add_time_ids(
|
967 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
968 |
+
):
|
969 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
970 |
+
|
971 |
+
passed_add_embed_dim = (
|
972 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
973 |
+
)
|
974 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
975 |
+
|
976 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
977 |
+
raise ValueError(
|
978 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
979 |
+
)
|
980 |
+
|
981 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
982 |
+
return add_time_ids
|
983 |
+
|
984 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
985 |
+
def upcast_vae(self):
|
986 |
+
dtype = self.vae.dtype
|
987 |
+
self.vae.to(dtype=torch.float32)
|
988 |
+
use_torch_2_0_or_xformers = isinstance(
|
989 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
990 |
+
(
|
991 |
+
AttnProcessor2_0,
|
992 |
+
XFormersAttnProcessor,
|
993 |
+
LoRAXFormersAttnProcessor,
|
994 |
+
LoRAAttnProcessor2_0,
|
995 |
+
),
|
996 |
+
)
|
997 |
+
# if xformers or torch_2_0 is used attention block does not need
|
998 |
+
# to be in float32 which can save lots of memory
|
999 |
+
if use_torch_2_0_or_xformers:
|
1000 |
+
self.vae.post_quant_conv.to(dtype)
|
1001 |
+
self.vae.decoder.conv_in.to(dtype)
|
1002 |
+
self.vae.decoder.mid_block.to(dtype)
|
1003 |
+
|
1004 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
1005 |
+
def get_guidance_scale_embedding(
|
1006 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
1007 |
+
) -> torch.FloatTensor:
|
1008 |
+
"""
|
1009 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
1010 |
+
|
1011 |
+
Args:
|
1012 |
+
w (`torch.Tensor`):
|
1013 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
1014 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
1015 |
+
Dimension of the embeddings to generate.
|
1016 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
1017 |
+
Data type of the generated embeddings.
|
1018 |
+
|
1019 |
+
Returns:
|
1020 |
+
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
1021 |
+
"""
|
1022 |
+
assert len(w.shape) == 1
|
1023 |
+
w = w * 1000.0
|
1024 |
+
|
1025 |
+
half_dim = embedding_dim // 2
|
1026 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
1027 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
1028 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
1029 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
1030 |
+
if embedding_dim % 2 == 1: # zero pad
|
1031 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
1032 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
1033 |
+
return emb
|
1034 |
+
|
1035 |
+
@property
|
1036 |
+
def guidance_scale(self):
|
1037 |
+
return self._guidance_scale
|
1038 |
+
|
1039 |
+
@property
|
1040 |
+
def guidance_rescale(self):
|
1041 |
+
return self._guidance_rescale
|
1042 |
+
|
1043 |
+
@property
|
1044 |
+
def clip_skip(self):
|
1045 |
+
return self._clip_skip
|
1046 |
+
|
1047 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1048 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1049 |
+
# corresponds to doing no classifier free guidance.
|
1050 |
+
@property
|
1051 |
+
def do_classifier_free_guidance(self):
|
1052 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
1053 |
+
|
1054 |
+
@property
|
1055 |
+
def cross_attention_kwargs(self):
|
1056 |
+
return self._cross_attention_kwargs
|
1057 |
+
|
1058 |
+
@property
|
1059 |
+
def denoising_end(self):
|
1060 |
+
return self._denoising_end
|
1061 |
+
|
1062 |
+
@property
|
1063 |
+
def num_timesteps(self):
|
1064 |
+
return self._num_timesteps
|
1065 |
+
|
1066 |
+
@torch.no_grad()
|
1067 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
1068 |
+
def __call__(
|
1069 |
+
self,
|
1070 |
+
prompt: Union[str, List[str]] = None,
|
1071 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
1072 |
+
image: PipelineImageInput = None,
|
1073 |
+
height: Optional[int] = None,
|
1074 |
+
width: Optional[int] = None,
|
1075 |
+
num_inference_steps: int = 30,
|
1076 |
+
timesteps: List[int] = None,
|
1077 |
+
denoising_end: Optional[float] = None,
|
1078 |
+
guidance_scale: float = 7.0,
|
1079 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1080 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
1081 |
+
num_images_per_prompt: Optional[int] = 1,
|
1082 |
+
eta: float = 0.0,
|
1083 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1084 |
+
latents: Optional[torch.FloatTensor] = None,
|
1085 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1086 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1087 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1088 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1089 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
1090 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
1091 |
+
output_type: Optional[str] = "pil",
|
1092 |
+
return_dict: bool = True,
|
1093 |
+
save_preview_row: bool = False,
|
1094 |
+
init_latents_with_lq: bool = True,
|
1095 |
+
multistep_restore: bool = False,
|
1096 |
+
adastep_restore: bool = False,
|
1097 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1098 |
+
guidance_rescale: float = 0.0,
|
1099 |
+
controlnet_conditioning_scale: float = 1.0,
|
1100 |
+
control_guidance_start: float = 0.0,
|
1101 |
+
control_guidance_end: float = 1.0,
|
1102 |
+
preview_start: float = 0.0,
|
1103 |
+
preview_end: float = 1.0,
|
1104 |
+
original_size: Tuple[int, int] = None,
|
1105 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1106 |
+
target_size: Tuple[int, int] = None,
|
1107 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
1108 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
1109 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
1110 |
+
clip_skip: Optional[int] = None,
|
1111 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
1112 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
1113 |
+
previewer_scheduler: KarrasDiffusionSchedulers = None,
|
1114 |
+
reference_latents: Optional[torch.FloatTensor] = None,
|
1115 |
+
**kwargs,
|
1116 |
+
):
|
1117 |
+
r"""
|
1118 |
+
The call function to the pipeline for generation.
|
1119 |
+
|
1120 |
+
Args:
|
1121 |
+
prompt (`str` or `List[str]`, *optional*):
|
1122 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
1123 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
1124 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
1125 |
+
used in both text-encoders.
|
1126 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
1127 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
1128 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
1129 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
1130 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
1131 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
1132 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
1133 |
+
input to a single ControlNet.
|
1134 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
1135 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
1136 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1137 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1138 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
1139 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
1140 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
1141 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
1142 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1143 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1144 |
+
expense of slower inference.
|
1145 |
+
timesteps (`List[int]`, *optional*):
|
1146 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
1147 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
1148 |
+
passed will be used. Must be in descending order.
|
1149 |
+
denoising_end (`float`, *optional*):
|
1150 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
1151 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
1152 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
1153 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
1154 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
1155 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
1156 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
1157 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
1158 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
1159 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1160 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
1161 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
1162 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
1163 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
1164 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
1165 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1166 |
+
The number of images to generate per prompt.
|
1167 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1168 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
1169 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
1170 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1171 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
1172 |
+
generation deterministic.
|
1173 |
+
latents (`torch.FloatTensor`, *optional*):
|
1174 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
1175 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1176 |
+
tensor is generated by sampling using the supplied random `generator`.
|
1177 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1178 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
1179 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
1180 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1181 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
1182 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
1183 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1184 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
1185 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
1186 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1187 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
1188 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
1189 |
+
argument.
|
1190 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
1191 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
1192 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
1193 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
1194 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
1195 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
1196 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1197 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
1198 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1199 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1200 |
+
plain tuple.
|
1201 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1202 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
1203 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1204 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
1205 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
1206 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
1207 |
+
the corresponding scale as a list.
|
1208 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
1209 |
+
The percentage of total steps at which the ControlNet starts applying.
|
1210 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
1211 |
+
The percentage of total steps at which the ControlNet stops applying.
|
1212 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1213 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
1214 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
1215 |
+
explained in section 2.2 of
|
1216 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1217 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1218 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
1219 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
1220 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1221 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1222 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1223 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
1224 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
1225 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
1226 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1227 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
1228 |
+
micro-conditioning as explained in section 2.2 of
|
1229 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1230 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1231 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
1232 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
1233 |
+
micro-conditioning as explained in section 2.2 of
|
1234 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1235 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1236 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
1237 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
1238 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
1239 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
1240 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
1241 |
+
clip_skip (`int`, *optional*):
|
1242 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
1243 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
1244 |
+
callback_on_step_end (`Callable`, *optional*):
|
1245 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
1246 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1247 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1248 |
+
`callback_on_step_end_tensor_inputs`.
|
1249 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1250 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1251 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1252 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1253 |
+
|
1254 |
+
Examples:
|
1255 |
+
|
1256 |
+
Returns:
|
1257 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1258 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
1259 |
+
otherwise a `tuple` is returned containing the output images.
|
1260 |
+
"""
|
1261 |
+
|
1262 |
+
callback = kwargs.pop("callback", None)
|
1263 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
1264 |
+
|
1265 |
+
if callback is not None:
|
1266 |
+
deprecate(
|
1267 |
+
"callback",
|
1268 |
+
"1.0.0",
|
1269 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
1270 |
+
)
|
1271 |
+
if callback_steps is not None:
|
1272 |
+
deprecate(
|
1273 |
+
"callback_steps",
|
1274 |
+
"1.0.0",
|
1275 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
1276 |
+
)
|
1277 |
+
|
1278 |
+
aggregator = self.aggregator._orig_mod if is_compiled_module(self.aggregator) else self.aggregator
|
1279 |
+
if not isinstance(ip_adapter_image, list):
|
1280 |
+
ip_adapter_image = [ip_adapter_image] if ip_adapter_image is not None else [image]
|
1281 |
+
|
1282 |
+
# 1. Check inputs. Raise error if not correct
|
1283 |
+
self.check_inputs(
|
1284 |
+
prompt,
|
1285 |
+
prompt_2,
|
1286 |
+
image,
|
1287 |
+
callback_steps,
|
1288 |
+
negative_prompt,
|
1289 |
+
negative_prompt_2,
|
1290 |
+
prompt_embeds,
|
1291 |
+
negative_prompt_embeds,
|
1292 |
+
pooled_prompt_embeds,
|
1293 |
+
ip_adapter_image,
|
1294 |
+
ip_adapter_image_embeds,
|
1295 |
+
negative_pooled_prompt_embeds,
|
1296 |
+
controlnet_conditioning_scale,
|
1297 |
+
control_guidance_start,
|
1298 |
+
control_guidance_end,
|
1299 |
+
callback_on_step_end_tensor_inputs,
|
1300 |
+
)
|
1301 |
+
|
1302 |
+
self._guidance_scale = guidance_scale
|
1303 |
+
self._guidance_rescale = guidance_rescale
|
1304 |
+
self._clip_skip = clip_skip
|
1305 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
1306 |
+
self._denoising_end = denoising_end
|
1307 |
+
|
1308 |
+
# 2. Define call parameters
|
1309 |
+
if prompt is not None and isinstance(prompt, str):
|
1310 |
+
if not isinstance(image, PIL.Image.Image):
|
1311 |
+
batch_size = len(image)
|
1312 |
+
else:
|
1313 |
+
batch_size = 1
|
1314 |
+
prompt = [prompt] * batch_size
|
1315 |
+
elif prompt is not None and isinstance(prompt, list):
|
1316 |
+
batch_size = len(prompt)
|
1317 |
+
assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
|
1318 |
+
else:
|
1319 |
+
batch_size = prompt_embeds.shape[0]
|
1320 |
+
assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
|
1321 |
+
|
1322 |
+
device = self._execution_device
|
1323 |
+
|
1324 |
+
# 3.1 Encode input prompt
|
1325 |
+
text_encoder_lora_scale = (
|
1326 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
1327 |
+
)
|
1328 |
+
(
|
1329 |
+
prompt_embeds,
|
1330 |
+
negative_prompt_embeds,
|
1331 |
+
pooled_prompt_embeds,
|
1332 |
+
negative_pooled_prompt_embeds,
|
1333 |
+
) = self.encode_prompt(
|
1334 |
+
prompt=prompt,
|
1335 |
+
prompt_2=prompt_2,
|
1336 |
+
device=device,
|
1337 |
+
num_images_per_prompt=num_images_per_prompt,
|
1338 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1339 |
+
negative_prompt=negative_prompt,
|
1340 |
+
negative_prompt_2=negative_prompt_2,
|
1341 |
+
prompt_embeds=prompt_embeds,
|
1342 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1343 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1344 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1345 |
+
lora_scale=text_encoder_lora_scale,
|
1346 |
+
clip_skip=self.clip_skip,
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
# 3.2 Encode ip_adapter_image
|
1350 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1351 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
1352 |
+
ip_adapter_image,
|
1353 |
+
ip_adapter_image_embeds,
|
1354 |
+
device,
|
1355 |
+
batch_size * num_images_per_prompt,
|
1356 |
+
self.do_classifier_free_guidance,
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
# 4. Prepare image
|
1360 |
+
image = self.prepare_image(
|
1361 |
+
image=image,
|
1362 |
+
width=width,
|
1363 |
+
height=height,
|
1364 |
+
batch_size=batch_size * num_images_per_prompt,
|
1365 |
+
num_images_per_prompt=num_images_per_prompt,
|
1366 |
+
device=device,
|
1367 |
+
dtype=aggregator.dtype,
|
1368 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1369 |
+
)
|
1370 |
+
height, width = image.shape[-2:]
|
1371 |
+
if image.shape[1] != 4:
|
1372 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1373 |
+
if needs_upcasting:
|
1374 |
+
image = image.float()
|
1375 |
+
self.vae.to(dtype=torch.float32)
|
1376 |
+
image = self.vae.encode(image).latent_dist.sample()
|
1377 |
+
image = image * self.vae.config.scaling_factor
|
1378 |
+
if needs_upcasting:
|
1379 |
+
self.vae.to(dtype=torch.float16)
|
1380 |
+
image = image.to(dtype=torch.float16)
|
1381 |
+
else:
|
1382 |
+
height = int(height * self.vae_scale_factor)
|
1383 |
+
width = int(width * self.vae_scale_factor)
|
1384 |
+
|
1385 |
+
# 5. Prepare timesteps
|
1386 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
1387 |
+
|
1388 |
+
# 6. Prepare latent variables
|
1389 |
+
if init_latents_with_lq:
|
1390 |
+
latents = self.init_latents(image, generator, timesteps[0])
|
1391 |
+
else:
|
1392 |
+
num_channels_latents = self.unet.config.in_channels
|
1393 |
+
latents = self.prepare_latents(
|
1394 |
+
batch_size * num_images_per_prompt,
|
1395 |
+
num_channels_latents,
|
1396 |
+
height,
|
1397 |
+
width,
|
1398 |
+
prompt_embeds.dtype,
|
1399 |
+
device,
|
1400 |
+
generator,
|
1401 |
+
latents,
|
1402 |
+
)
|
1403 |
+
|
1404 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
1405 |
+
timestep_cond = None
|
1406 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
1407 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
1408 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
1409 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
1410 |
+
).to(device=device, dtype=latents.dtype)
|
1411 |
+
|
1412 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1413 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1414 |
+
|
1415 |
+
# 7.1 Create tensor stating which controlnets to keep
|
1416 |
+
controlnet_keep = []
|
1417 |
+
previewing = []
|
1418 |
+
for i in range(len(timesteps)):
|
1419 |
+
keeps = 1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
|
1420 |
+
controlnet_keep.append(keeps)
|
1421 |
+
use_preview = 1.0 - float(i / len(timesteps) < preview_start or (i + 1) / len(timesteps) > preview_end)
|
1422 |
+
previewing.append(use_preview)
|
1423 |
+
if isinstance(controlnet_conditioning_scale, list):
|
1424 |
+
assert len(controlnet_conditioning_scale) == len(timesteps), f"{len(controlnet_conditioning_scale)} controlnet scales do not match number of sampling steps {len(timesteps)}"
|
1425 |
+
else:
|
1426 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet_keep)
|
1427 |
+
|
1428 |
+
# 7.2 Prepare added time ids & embeddings
|
1429 |
+
original_size = original_size or (height, width)
|
1430 |
+
target_size = target_size or (height, width)
|
1431 |
+
|
1432 |
+
add_text_embeds = pooled_prompt_embeds
|
1433 |
+
if self.text_encoder_2 is None:
|
1434 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1435 |
+
else:
|
1436 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1437 |
+
|
1438 |
+
add_time_ids = self._get_add_time_ids(
|
1439 |
+
original_size,
|
1440 |
+
crops_coords_top_left,
|
1441 |
+
target_size,
|
1442 |
+
dtype=prompt_embeds.dtype,
|
1443 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1444 |
+
)
|
1445 |
+
|
1446 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1447 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1448 |
+
negative_original_size,
|
1449 |
+
negative_crops_coords_top_left,
|
1450 |
+
negative_target_size,
|
1451 |
+
dtype=prompt_embeds.dtype,
|
1452 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1453 |
+
)
|
1454 |
+
else:
|
1455 |
+
negative_add_time_ids = add_time_ids
|
1456 |
+
|
1457 |
+
if self.do_classifier_free_guidance:
|
1458 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1459 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1460 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1461 |
+
image = torch.cat([image] * 2, dim=0)
|
1462 |
+
|
1463 |
+
prompt_embeds = prompt_embeds.to(device)
|
1464 |
+
add_text_embeds = add_text_embeds.to(device)
|
1465 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1466 |
+
|
1467 |
+
# 8. Denoising loop
|
1468 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1469 |
+
|
1470 |
+
# 8.1 Apply denoising_end
|
1471 |
+
if (
|
1472 |
+
self.denoising_end is not None
|
1473 |
+
and isinstance(self.denoising_end, float)
|
1474 |
+
and self.denoising_end > 0
|
1475 |
+
and self.denoising_end < 1
|
1476 |
+
):
|
1477 |
+
discrete_timestep_cutoff = int(
|
1478 |
+
round(
|
1479 |
+
self.scheduler.config.num_train_timesteps
|
1480 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1481 |
+
)
|
1482 |
+
)
|
1483 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1484 |
+
timesteps = timesteps[:num_inference_steps]
|
1485 |
+
|
1486 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
1487 |
+
is_aggregator_compiled = is_compiled_module(self.aggregator)
|
1488 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
1489 |
+
previewer_mean = torch.zeros_like(latents)
|
1490 |
+
unet_mean = torch.zeros_like(latents)
|
1491 |
+
preview_factor = torch.ones(
|
1492 |
+
(latents.shape[0], *((1,) * (len(latents.shape) - 1))), dtype=latents.dtype, device=latents.device
|
1493 |
+
)
|
1494 |
+
|
1495 |
+
self._num_timesteps = len(timesteps)
|
1496 |
+
preview_row = []
|
1497 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1498 |
+
for i, t in enumerate(timesteps):
|
1499 |
+
# Relevant thread:
|
1500 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
1501 |
+
if (is_unet_compiled and is_aggregator_compiled) and is_torch_higher_equal_2_1:
|
1502 |
+
torch._inductor.cudagraph_mark_step_begin()
|
1503 |
+
# expand the latents if we are doing classifier free guidance
|
1504 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1505 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1506 |
+
prev_t = t
|
1507 |
+
unet_model_input = latent_model_input
|
1508 |
+
|
1509 |
+
added_cond_kwargs = {
|
1510 |
+
"text_embeds": add_text_embeds,
|
1511 |
+
"time_ids": add_time_ids,
|
1512 |
+
"image_embeds": image_embeds
|
1513 |
+
}
|
1514 |
+
aggregator_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1515 |
+
|
1516 |
+
# prepare time_embeds in advance as adapter input
|
1517 |
+
cross_attention_t_emb = self.unet.get_time_embed(sample=latent_model_input, timestep=t)
|
1518 |
+
cross_attention_emb = self.unet.time_embedding(cross_attention_t_emb, timestep_cond)
|
1519 |
+
cross_attention_aug_emb = None
|
1520 |
+
|
1521 |
+
cross_attention_aug_emb = self.unet.get_aug_embed(
|
1522 |
+
emb=cross_attention_emb,
|
1523 |
+
encoder_hidden_states=prompt_embeds,
|
1524 |
+
added_cond_kwargs=added_cond_kwargs
|
1525 |
+
)
|
1526 |
+
|
1527 |
+
cross_attention_emb = cross_attention_emb + cross_attention_aug_emb if cross_attention_aug_emb is not None else cross_attention_emb
|
1528 |
+
|
1529 |
+
if self.unet.time_embed_act is not None:
|
1530 |
+
cross_attention_emb = self.unet.time_embed_act(cross_attention_emb)
|
1531 |
+
|
1532 |
+
current_cross_attention_kwargs = {"temb": cross_attention_emb}
|
1533 |
+
if cross_attention_kwargs is not None:
|
1534 |
+
for k,v in cross_attention_kwargs.items():
|
1535 |
+
current_cross_attention_kwargs[k] = v
|
1536 |
+
self._cross_attention_kwargs = current_cross_attention_kwargs
|
1537 |
+
|
1538 |
+
# adaptive restoration factors
|
1539 |
+
adaRes_scale = preview_factor.to(latent_model_input.dtype).clamp(0.0, controlnet_conditioning_scale[i])
|
1540 |
+
cond_scale = adaRes_scale * controlnet_keep[i]
|
1541 |
+
cond_scale = torch.cat([cond_scale] * 2) if self.do_classifier_free_guidance else cond_scale
|
1542 |
+
|
1543 |
+
if (cond_scale>0.1).sum().item() > 0:
|
1544 |
+
if previewing[i] > 0:
|
1545 |
+
# preview with LCM
|
1546 |
+
self.unet.enable_adapters()
|
1547 |
+
preview_noise = self.unet(
|
1548 |
+
latent_model_input,
|
1549 |
+
t,
|
1550 |
+
encoder_hidden_states=prompt_embeds,
|
1551 |
+
timestep_cond=timestep_cond,
|
1552 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1553 |
+
added_cond_kwargs=added_cond_kwargs,
|
1554 |
+
return_dict=False,
|
1555 |
+
)[0]
|
1556 |
+
preview_latent = previewer_scheduler.step(
|
1557 |
+
preview_noise,
|
1558 |
+
t.to(dtype=torch.int64),
|
1559 |
+
# torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
1560 |
+
latent_model_input, # scaled latents here for compatibility
|
1561 |
+
return_dict=False
|
1562 |
+
)[0]
|
1563 |
+
self.unet.disable_adapters()
|
1564 |
+
|
1565 |
+
if self.do_classifier_free_guidance:
|
1566 |
+
preview_row.append(preview_latent.chunk(2)[1].to('cpu'))
|
1567 |
+
else:
|
1568 |
+
preview_row.append(preview_latent.to('cpu'))
|
1569 |
+
# Prepare 2nd order step.
|
1570 |
+
if multistep_restore and i+1 < len(timesteps):
|
1571 |
+
noise_preview = preview_noise.chunk(2)[1] if self.do_classifier_free_guidance else preview_noise
|
1572 |
+
first_step = self.scheduler.step(
|
1573 |
+
noise_preview, t, latents,
|
1574 |
+
**extra_step_kwargs, return_dict=True, step_forward=False
|
1575 |
+
)
|
1576 |
+
prev_t = timesteps[i + 1]
|
1577 |
+
unet_model_input = torch.cat([first_step.prev_sample] * 2) if self.do_classifier_free_guidance else first_step.prev_sample
|
1578 |
+
unet_model_input = self.scheduler.scale_model_input(unet_model_input, prev_t, heun_step=True)
|
1579 |
+
|
1580 |
+
elif reference_latents is not None:
|
1581 |
+
preview_latent = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents
|
1582 |
+
else:
|
1583 |
+
preview_latent = image
|
1584 |
+
|
1585 |
+
# Add fresh noise
|
1586 |
+
# preview_noise = torch.randn_like(preview_latent)
|
1587 |
+
# preview_latent = self.scheduler.add_noise(preview_latent, preview_noise, t)
|
1588 |
+
|
1589 |
+
preview_latent=preview_latent.to(dtype=next(aggregator.parameters()).dtype)
|
1590 |
+
|
1591 |
+
# Aggregator inference
|
1592 |
+
down_block_res_samples, mid_block_res_sample = aggregator(
|
1593 |
+
image,
|
1594 |
+
prev_t,
|
1595 |
+
encoder_hidden_states=prompt_embeds,
|
1596 |
+
controlnet_cond=preview_latent,
|
1597 |
+
# conditioning_scale=cond_scale,
|
1598 |
+
added_cond_kwargs=aggregator_added_cond_kwargs,
|
1599 |
+
return_dict=False,
|
1600 |
+
)
|
1601 |
+
|
1602 |
+
# aggregator features scaling
|
1603 |
+
down_block_res_samples = [sample*cond_scale for sample in down_block_res_samples]
|
1604 |
+
mid_block_res_sample = mid_block_res_sample*cond_scale
|
1605 |
+
|
1606 |
+
# predict the noise residual
|
1607 |
+
noise_pred = self.unet(
|
1608 |
+
unet_model_input,
|
1609 |
+
prev_t,
|
1610 |
+
encoder_hidden_states=prompt_embeds,
|
1611 |
+
timestep_cond=timestep_cond,
|
1612 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1613 |
+
down_block_additional_residuals=down_block_res_samples,
|
1614 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1615 |
+
added_cond_kwargs=added_cond_kwargs,
|
1616 |
+
return_dict=False,
|
1617 |
+
)[0]
|
1618 |
+
|
1619 |
+
# perform guidance
|
1620 |
+
if self.do_classifier_free_guidance:
|
1621 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1622 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1623 |
+
|
1624 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1625 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1626 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1627 |
+
|
1628 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1629 |
+
latents_dtype = latents.dtype
|
1630 |
+
unet_step = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
|
1631 |
+
latents = unet_step.prev_sample
|
1632 |
+
|
1633 |
+
# Update adaRes factors
|
1634 |
+
unet_pred_latent = unet_step.pred_original_sample
|
1635 |
+
|
1636 |
+
# Adaptive restoration.
|
1637 |
+
if adastep_restore:
|
1638 |
+
pred_x0_l2 = ((preview_latent[latents.shape[0]:].float()-unet_pred_latent.float())).pow(2).sum(dim=(1,2,3))
|
1639 |
+
previewer_l2 = ((preview_latent[latents.shape[0]:].float()-previewer_mean.float())).pow(2).sum(dim=(1,2,3))
|
1640 |
+
# unet_l2 = ((unet_pred_latent.float()-unet_mean.float())).pow(2).sum(dim=(1,2,3)).sqrt()
|
1641 |
+
# l2_error = (((preview_latent[latents.shape[0]:]-previewer_mean) - (unet_pred_latent-unet_mean))).pow(2).mean(dim=(1,2,3))
|
1642 |
+
# preview_error = torch.nn.functional.cosine_similarity(preview_latent[latents.shape[0]:].reshape(latents.shape[0], -1), unet_pred_latent.reshape(latents.shape[0],-1))
|
1643 |
+
previewer_mean = preview_latent[latents.shape[0]:]
|
1644 |
+
unet_mean = unet_pred_latent
|
1645 |
+
preview_factor = (pred_x0_l2 / previewer_l2).reshape(-1, 1, 1, 1)
|
1646 |
+
|
1647 |
+
if latents.dtype != latents_dtype:
|
1648 |
+
if torch.backends.mps.is_available():
|
1649 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1650 |
+
latents = latents.to(latents_dtype)
|
1651 |
+
|
1652 |
+
if callback_on_step_end is not None:
|
1653 |
+
callback_kwargs = {}
|
1654 |
+
for k in callback_on_step_end_tensor_inputs:
|
1655 |
+
callback_kwargs[k] = locals()[k]
|
1656 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1657 |
+
|
1658 |
+
latents = callback_outputs.pop("latents", latents)
|
1659 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1660 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1661 |
+
|
1662 |
+
# call the callback, if provided
|
1663 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1664 |
+
progress_bar.update()
|
1665 |
+
if callback is not None and i % callback_steps == 0:
|
1666 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1667 |
+
callback(step_idx, t, latents)
|
1668 |
+
|
1669 |
+
if not output_type == "latent":
|
1670 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1671 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1672 |
+
|
1673 |
+
if needs_upcasting:
|
1674 |
+
self.upcast_vae()
|
1675 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1676 |
+
|
1677 |
+
# unscale/denormalize the latents
|
1678 |
+
# denormalize with the mean and std if available and not None
|
1679 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
1680 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
1681 |
+
if has_latents_mean and has_latents_std:
|
1682 |
+
latents_mean = (
|
1683 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1684 |
+
)
|
1685 |
+
latents_std = (
|
1686 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1687 |
+
)
|
1688 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
1689 |
+
else:
|
1690 |
+
latents = latents / self.vae.config.scaling_factor
|
1691 |
+
|
1692 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1693 |
+
|
1694 |
+
# cast back to fp16 if needed
|
1695 |
+
if needs_upcasting:
|
1696 |
+
self.vae.to(dtype=torch.float16)
|
1697 |
+
else:
|
1698 |
+
image = latents
|
1699 |
+
|
1700 |
+
if not output_type == "latent":
|
1701 |
+
# apply watermark if available
|
1702 |
+
if self.watermark is not None:
|
1703 |
+
image = self.watermark.apply_watermark(image)
|
1704 |
+
|
1705 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1706 |
+
|
1707 |
+
if save_preview_row:
|
1708 |
+
preview_image_row = []
|
1709 |
+
if needs_upcasting:
|
1710 |
+
self.upcast_vae()
|
1711 |
+
for preview_latents in preview_row:
|
1712 |
+
preview_latents = preview_latents.to(device=self.device, dtype=next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1713 |
+
if has_latents_mean and has_latents_std:
|
1714 |
+
latents_mean = (
|
1715 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
|
1716 |
+
)
|
1717 |
+
latents_std = (
|
1718 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
|
1719 |
+
)
|
1720 |
+
preview_latents = preview_latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
1721 |
+
else:
|
1722 |
+
preview_latents = preview_latents / self.vae.config.scaling_factor
|
1723 |
+
|
1724 |
+
preview_image = self.vae.decode(preview_latents, return_dict=False)[0]
|
1725 |
+
preview_image = self.image_processor.postprocess(preview_image, output_type=output_type)
|
1726 |
+
preview_image_row.append(preview_image)
|
1727 |
+
|
1728 |
+
# cast back to fp16 if needed
|
1729 |
+
if needs_upcasting:
|
1730 |
+
self.vae.to(dtype=torch.float16)
|
1731 |
+
|
1732 |
+
# Offload all models
|
1733 |
+
self.maybe_free_model_hooks()
|
1734 |
+
|
1735 |
+
if not return_dict:
|
1736 |
+
if save_preview_row:
|
1737 |
+
return (image, preview_image_row)
|
1738 |
+
return (image,)
|
1739 |
+
|
1740 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
pipelines/stage1_sdxl_pipeline.py
ADDED
@@ -0,0 +1,1283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from transformers import (
|
20 |
+
CLIPImageProcessor,
|
21 |
+
CLIPTextModel,
|
22 |
+
CLIPTextModelWithProjection,
|
23 |
+
CLIPTokenizer,
|
24 |
+
CLIPVisionModelWithProjection,
|
25 |
+
)
|
26 |
+
|
27 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
28 |
+
from ...loaders import (
|
29 |
+
FromSingleFileMixin,
|
30 |
+
IPAdapterMixin,
|
31 |
+
StableDiffusionXLLoraLoaderMixin,
|
32 |
+
TextualInversionLoaderMixin,
|
33 |
+
)
|
34 |
+
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
35 |
+
from ...models.attention_processor import (
|
36 |
+
AttnProcessor2_0,
|
37 |
+
FusedAttnProcessor2_0,
|
38 |
+
LoRAAttnProcessor2_0,
|
39 |
+
LoRAXFormersAttnProcessor,
|
40 |
+
XFormersAttnProcessor,
|
41 |
+
)
|
42 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
43 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
44 |
+
from ...utils import (
|
45 |
+
USE_PEFT_BACKEND,
|
46 |
+
deprecate,
|
47 |
+
is_invisible_watermark_available,
|
48 |
+
is_torch_xla_available,
|
49 |
+
logging,
|
50 |
+
replace_example_docstring,
|
51 |
+
scale_lora_layers,
|
52 |
+
unscale_lora_layers,
|
53 |
+
)
|
54 |
+
from ...utils.torch_utils import randn_tensor
|
55 |
+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
56 |
+
from .pipeline_output import StableDiffusionXLPipelineOutput
|
57 |
+
|
58 |
+
|
59 |
+
if is_invisible_watermark_available():
|
60 |
+
from .watermark import StableDiffusionXLWatermarker
|
61 |
+
|
62 |
+
if is_torch_xla_available():
|
63 |
+
import torch_xla.core.xla_model as xm
|
64 |
+
|
65 |
+
XLA_AVAILABLE = True
|
66 |
+
else:
|
67 |
+
XLA_AVAILABLE = False
|
68 |
+
|
69 |
+
|
70 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
71 |
+
|
72 |
+
EXAMPLE_DOC_STRING = """
|
73 |
+
Examples:
|
74 |
+
```py
|
75 |
+
>>> import torch
|
76 |
+
>>> from diffusers import StableDiffusionXLPipeline
|
77 |
+
|
78 |
+
>>> pipe = StableDiffusionXLPipeline.from_pretrained(
|
79 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
80 |
+
... )
|
81 |
+
>>> pipe = pipe.to("cuda")
|
82 |
+
|
83 |
+
>>> prompt = "a photo of an astronaut riding a horse on mars"
|
84 |
+
>>> image = pipe(prompt).images[0]
|
85 |
+
```
|
86 |
+
"""
|
87 |
+
|
88 |
+
|
89 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
90 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
91 |
+
"""
|
92 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
93 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
94 |
+
"""
|
95 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
96 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
97 |
+
# rescale the results from guidance (fixes overexposure)
|
98 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
99 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
100 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
101 |
+
return noise_cfg
|
102 |
+
|
103 |
+
|
104 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
105 |
+
def retrieve_timesteps(
|
106 |
+
scheduler,
|
107 |
+
num_inference_steps: Optional[int] = None,
|
108 |
+
device: Optional[Union[str, torch.device]] = None,
|
109 |
+
timesteps: Optional[List[int]] = None,
|
110 |
+
**kwargs,
|
111 |
+
):
|
112 |
+
"""
|
113 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
114 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
scheduler (`SchedulerMixin`):
|
118 |
+
The scheduler to get timesteps from.
|
119 |
+
num_inference_steps (`int`):
|
120 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
121 |
+
must be `None`.
|
122 |
+
device (`str` or `torch.device`, *optional*):
|
123 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
124 |
+
timesteps (`List[int]`, *optional*):
|
125 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
126 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
127 |
+
must be `None`.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
131 |
+
second element is the number of inference steps.
|
132 |
+
"""
|
133 |
+
if timesteps is not None:
|
134 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
135 |
+
if not accepts_timesteps:
|
136 |
+
raise ValueError(
|
137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
138 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
139 |
+
)
|
140 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
141 |
+
timesteps = scheduler.timesteps
|
142 |
+
num_inference_steps = len(timesteps)
|
143 |
+
else:
|
144 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
145 |
+
timesteps = scheduler.timesteps
|
146 |
+
return timesteps, num_inference_steps
|
147 |
+
|
148 |
+
|
149 |
+
class StableDiffusionXLPipeline(
|
150 |
+
DiffusionPipeline,
|
151 |
+
StableDiffusionMixin,
|
152 |
+
FromSingleFileMixin,
|
153 |
+
StableDiffusionXLLoraLoaderMixin,
|
154 |
+
TextualInversionLoaderMixin,
|
155 |
+
IPAdapterMixin,
|
156 |
+
):
|
157 |
+
r"""
|
158 |
+
Pipeline for text-to-image generation using Stable Diffusion XL.
|
159 |
+
|
160 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
161 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
162 |
+
|
163 |
+
The pipeline also inherits the following loading methods:
|
164 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
165 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
166 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
167 |
+
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
168 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
169 |
+
|
170 |
+
Args:
|
171 |
+
vae ([`AutoencoderKL`]):
|
172 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
173 |
+
text_encoder ([`CLIPTextModel`]):
|
174 |
+
Frozen text-encoder. Stable Diffusion XL uses the text portion of
|
175 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
176 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
177 |
+
text_encoder_2 ([` CLIPTextModelWithProjection`]):
|
178 |
+
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
|
179 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
180 |
+
specifically the
|
181 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
182 |
+
variant.
|
183 |
+
tokenizer (`CLIPTokenizer`):
|
184 |
+
Tokenizer of class
|
185 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
186 |
+
tokenizer_2 (`CLIPTokenizer`):
|
187 |
+
Second Tokenizer of class
|
188 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
189 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
190 |
+
scheduler ([`SchedulerMixin`]):
|
191 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
192 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
193 |
+
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
|
194 |
+
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
|
195 |
+
`stabilityai/stable-diffusion-xl-base-1-0`.
|
196 |
+
add_watermarker (`bool`, *optional*):
|
197 |
+
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
|
198 |
+
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
|
199 |
+
watermarker will be used.
|
200 |
+
"""
|
201 |
+
|
202 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
|
203 |
+
_optional_components = [
|
204 |
+
"tokenizer",
|
205 |
+
"tokenizer_2",
|
206 |
+
"text_encoder",
|
207 |
+
"text_encoder_2",
|
208 |
+
"image_encoder",
|
209 |
+
"feature_extractor",
|
210 |
+
]
|
211 |
+
_callback_tensor_inputs = [
|
212 |
+
"latents",
|
213 |
+
"prompt_embeds",
|
214 |
+
"negative_prompt_embeds",
|
215 |
+
"add_text_embeds",
|
216 |
+
"add_time_ids",
|
217 |
+
"negative_pooled_prompt_embeds",
|
218 |
+
"negative_add_time_ids",
|
219 |
+
]
|
220 |
+
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
vae: AutoencoderKL,
|
224 |
+
text_encoder: CLIPTextModel,
|
225 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
226 |
+
tokenizer: CLIPTokenizer,
|
227 |
+
tokenizer_2: CLIPTokenizer,
|
228 |
+
unet: UNet2DConditionModel,
|
229 |
+
scheduler: KarrasDiffusionSchedulers,
|
230 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
231 |
+
feature_extractor: CLIPImageProcessor = None,
|
232 |
+
force_zeros_for_empty_prompt: bool = True,
|
233 |
+
add_watermarker: Optional[bool] = None,
|
234 |
+
):
|
235 |
+
super().__init__()
|
236 |
+
|
237 |
+
self.register_modules(
|
238 |
+
vae=vae,
|
239 |
+
text_encoder=text_encoder,
|
240 |
+
text_encoder_2=text_encoder_2,
|
241 |
+
tokenizer=tokenizer,
|
242 |
+
tokenizer_2=tokenizer_2,
|
243 |
+
unet=unet,
|
244 |
+
scheduler=scheduler,
|
245 |
+
image_encoder=image_encoder,
|
246 |
+
feature_extractor=feature_extractor,
|
247 |
+
)
|
248 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
249 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
250 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
251 |
+
|
252 |
+
self.default_sample_size = self.unet.config.sample_size
|
253 |
+
|
254 |
+
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
255 |
+
|
256 |
+
if add_watermarker:
|
257 |
+
self.watermark = StableDiffusionXLWatermarker()
|
258 |
+
else:
|
259 |
+
self.watermark = None
|
260 |
+
|
261 |
+
def encode_prompt(
|
262 |
+
self,
|
263 |
+
prompt: str,
|
264 |
+
prompt_2: Optional[str] = None,
|
265 |
+
device: Optional[torch.device] = None,
|
266 |
+
num_images_per_prompt: int = 1,
|
267 |
+
do_classifier_free_guidance: bool = True,
|
268 |
+
negative_prompt: Optional[str] = None,
|
269 |
+
negative_prompt_2: Optional[str] = None,
|
270 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
271 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
272 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
273 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
274 |
+
lora_scale: Optional[float] = None,
|
275 |
+
clip_skip: Optional[int] = None,
|
276 |
+
):
|
277 |
+
r"""
|
278 |
+
Encodes the prompt into text encoder hidden states.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
prompt (`str` or `List[str]`, *optional*):
|
282 |
+
prompt to be encoded
|
283 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
284 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
285 |
+
used in both text-encoders
|
286 |
+
device: (`torch.device`):
|
287 |
+
torch device
|
288 |
+
num_images_per_prompt (`int`):
|
289 |
+
number of images that should be generated per prompt
|
290 |
+
do_classifier_free_guidance (`bool`):
|
291 |
+
whether to use classifier free guidance or not
|
292 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
293 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
294 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
295 |
+
less than `1`).
|
296 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
297 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
298 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
299 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
300 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
301 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
302 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
303 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
304 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
305 |
+
argument.
|
306 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
307 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
308 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
309 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
310 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
311 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
312 |
+
input argument.
|
313 |
+
lora_scale (`float`, *optional*):
|
314 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
315 |
+
clip_skip (`int`, *optional*):
|
316 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
317 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
318 |
+
"""
|
319 |
+
device = device or self._execution_device
|
320 |
+
|
321 |
+
# set lora scale so that monkey patched LoRA
|
322 |
+
# function of text encoder can correctly access it
|
323 |
+
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
|
324 |
+
self._lora_scale = lora_scale
|
325 |
+
|
326 |
+
# dynamically adjust the LoRA scale
|
327 |
+
if self.text_encoder is not None:
|
328 |
+
if not USE_PEFT_BACKEND:
|
329 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
330 |
+
else:
|
331 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
332 |
+
|
333 |
+
if self.text_encoder_2 is not None:
|
334 |
+
if not USE_PEFT_BACKEND:
|
335 |
+
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
336 |
+
else:
|
337 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
338 |
+
|
339 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
340 |
+
|
341 |
+
if prompt is not None:
|
342 |
+
batch_size = len(prompt)
|
343 |
+
else:
|
344 |
+
batch_size = prompt_embeds.shape[0]
|
345 |
+
|
346 |
+
# Define tokenizers and text encoders
|
347 |
+
tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
|
348 |
+
text_encoders = (
|
349 |
+
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
|
350 |
+
)
|
351 |
+
|
352 |
+
if prompt_embeds is None:
|
353 |
+
prompt_2 = prompt_2 or prompt
|
354 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
355 |
+
|
356 |
+
# textual inversion: process multi-vector tokens if necessary
|
357 |
+
prompt_embeds_list = []
|
358 |
+
prompts = [prompt, prompt_2]
|
359 |
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
360 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
361 |
+
prompt = self.maybe_convert_prompt(prompt, tokenizer)
|
362 |
+
|
363 |
+
text_inputs = tokenizer(
|
364 |
+
prompt,
|
365 |
+
padding="max_length",
|
366 |
+
max_length=tokenizer.model_max_length,
|
367 |
+
truncation=True,
|
368 |
+
return_tensors="pt",
|
369 |
+
)
|
370 |
+
|
371 |
+
text_input_ids = text_inputs.input_ids
|
372 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
373 |
+
|
374 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
375 |
+
text_input_ids, untruncated_ids
|
376 |
+
):
|
377 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
378 |
+
logger.warning(
|
379 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
380 |
+
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
381 |
+
)
|
382 |
+
|
383 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
384 |
+
|
385 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
386 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
387 |
+
if clip_skip is None:
|
388 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
389 |
+
else:
|
390 |
+
# "2" because SDXL always indexes from the penultimate layer.
|
391 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
392 |
+
|
393 |
+
prompt_embeds_list.append(prompt_embeds)
|
394 |
+
|
395 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
396 |
+
|
397 |
+
# get unconditional embeddings for classifier free guidance
|
398 |
+
zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
|
399 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
|
400 |
+
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
401 |
+
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
402 |
+
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
403 |
+
negative_prompt = negative_prompt or ""
|
404 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
405 |
+
|
406 |
+
# normalize str to list
|
407 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
408 |
+
negative_prompt_2 = (
|
409 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
410 |
+
)
|
411 |
+
|
412 |
+
uncond_tokens: List[str]
|
413 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
414 |
+
raise TypeError(
|
415 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
416 |
+
f" {type(prompt)}."
|
417 |
+
)
|
418 |
+
elif batch_size != len(negative_prompt):
|
419 |
+
raise ValueError(
|
420 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
421 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
422 |
+
" the batch size of `prompt`."
|
423 |
+
)
|
424 |
+
else:
|
425 |
+
uncond_tokens = [negative_prompt, negative_prompt_2]
|
426 |
+
|
427 |
+
negative_prompt_embeds_list = []
|
428 |
+
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
|
429 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
430 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
|
431 |
+
|
432 |
+
max_length = prompt_embeds.shape[1]
|
433 |
+
uncond_input = tokenizer(
|
434 |
+
negative_prompt,
|
435 |
+
padding="max_length",
|
436 |
+
max_length=max_length,
|
437 |
+
truncation=True,
|
438 |
+
return_tensors="pt",
|
439 |
+
)
|
440 |
+
|
441 |
+
negative_prompt_embeds = text_encoder(
|
442 |
+
uncond_input.input_ids.to(device),
|
443 |
+
output_hidden_states=True,
|
444 |
+
)
|
445 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
446 |
+
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
|
447 |
+
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
|
448 |
+
|
449 |
+
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
450 |
+
|
451 |
+
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
|
452 |
+
|
453 |
+
if self.text_encoder_2 is not None:
|
454 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
455 |
+
else:
|
456 |
+
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
457 |
+
|
458 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
459 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
460 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
461 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
462 |
+
|
463 |
+
if do_classifier_free_guidance:
|
464 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
465 |
+
seq_len = negative_prompt_embeds.shape[1]
|
466 |
+
|
467 |
+
if self.text_encoder_2 is not None:
|
468 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
|
469 |
+
else:
|
470 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
|
471 |
+
|
472 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
473 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
474 |
+
|
475 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
476 |
+
bs_embed * num_images_per_prompt, -1
|
477 |
+
)
|
478 |
+
if do_classifier_free_guidance:
|
479 |
+
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
|
480 |
+
bs_embed * num_images_per_prompt, -1
|
481 |
+
)
|
482 |
+
|
483 |
+
if self.text_encoder is not None:
|
484 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
485 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
486 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
487 |
+
|
488 |
+
if self.text_encoder_2 is not None:
|
489 |
+
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
|
490 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
491 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
492 |
+
|
493 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
494 |
+
|
495 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
496 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
497 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
498 |
+
|
499 |
+
if not isinstance(image, torch.Tensor):
|
500 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
501 |
+
|
502 |
+
image = image.to(device=device, dtype=dtype)
|
503 |
+
if output_hidden_states:
|
504 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
505 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
506 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
507 |
+
torch.zeros_like(image), output_hidden_states=True
|
508 |
+
).hidden_states[-2]
|
509 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
510 |
+
num_images_per_prompt, dim=0
|
511 |
+
)
|
512 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
513 |
+
else:
|
514 |
+
image_embeds = self.image_encoder(image).image_embeds
|
515 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
516 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
517 |
+
|
518 |
+
return image_embeds, uncond_image_embeds
|
519 |
+
|
520 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
521 |
+
def prepare_ip_adapter_image_embeds(
|
522 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
523 |
+
):
|
524 |
+
if ip_adapter_image_embeds is None:
|
525 |
+
if not isinstance(ip_adapter_image, list):
|
526 |
+
ip_adapter_image = [ip_adapter_image]
|
527 |
+
|
528 |
+
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
529 |
+
raise ValueError(
|
530 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
531 |
+
)
|
532 |
+
|
533 |
+
image_embeds = []
|
534 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
535 |
+
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
536 |
+
):
|
537 |
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
538 |
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
539 |
+
single_ip_adapter_image, device, 1, output_hidden_state
|
540 |
+
)
|
541 |
+
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
542 |
+
single_negative_image_embeds = torch.stack(
|
543 |
+
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
544 |
+
)
|
545 |
+
|
546 |
+
if do_classifier_free_guidance:
|
547 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
548 |
+
single_image_embeds = single_image_embeds.to(device)
|
549 |
+
|
550 |
+
image_embeds.append(single_image_embeds)
|
551 |
+
else:
|
552 |
+
repeat_dims = [1]
|
553 |
+
image_embeds = []
|
554 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
555 |
+
if do_classifier_free_guidance:
|
556 |
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
557 |
+
single_image_embeds = single_image_embeds.repeat(
|
558 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
559 |
+
)
|
560 |
+
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
561 |
+
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
562 |
+
)
|
563 |
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
564 |
+
else:
|
565 |
+
single_image_embeds = single_image_embeds.repeat(
|
566 |
+
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
567 |
+
)
|
568 |
+
image_embeds.append(single_image_embeds)
|
569 |
+
|
570 |
+
return image_embeds
|
571 |
+
|
572 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
573 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
574 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
575 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
576 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
577 |
+
# and should be between [0, 1]
|
578 |
+
|
579 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
580 |
+
extra_step_kwargs = {}
|
581 |
+
if accepts_eta:
|
582 |
+
extra_step_kwargs["eta"] = eta
|
583 |
+
|
584 |
+
# check if the scheduler accepts generator
|
585 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
586 |
+
if accepts_generator:
|
587 |
+
extra_step_kwargs["generator"] = generator
|
588 |
+
return extra_step_kwargs
|
589 |
+
|
590 |
+
def check_inputs(
|
591 |
+
self,
|
592 |
+
prompt,
|
593 |
+
prompt_2,
|
594 |
+
height,
|
595 |
+
width,
|
596 |
+
callback_steps,
|
597 |
+
negative_prompt=None,
|
598 |
+
negative_prompt_2=None,
|
599 |
+
prompt_embeds=None,
|
600 |
+
negative_prompt_embeds=None,
|
601 |
+
pooled_prompt_embeds=None,
|
602 |
+
negative_pooled_prompt_embeds=None,
|
603 |
+
ip_adapter_image=None,
|
604 |
+
ip_adapter_image_embeds=None,
|
605 |
+
callback_on_step_end_tensor_inputs=None,
|
606 |
+
):
|
607 |
+
if height % 8 != 0 or width % 8 != 0:
|
608 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
609 |
+
|
610 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
611 |
+
raise ValueError(
|
612 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
613 |
+
f" {type(callback_steps)}."
|
614 |
+
)
|
615 |
+
|
616 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
617 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
618 |
+
):
|
619 |
+
raise ValueError(
|
620 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
621 |
+
)
|
622 |
+
|
623 |
+
if prompt is not None and prompt_embeds is not None:
|
624 |
+
raise ValueError(
|
625 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
626 |
+
" only forward one of the two."
|
627 |
+
)
|
628 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
629 |
+
raise ValueError(
|
630 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
631 |
+
" only forward one of the two."
|
632 |
+
)
|
633 |
+
elif prompt is None and prompt_embeds is None:
|
634 |
+
raise ValueError(
|
635 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
636 |
+
)
|
637 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
638 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
639 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
640 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
641 |
+
|
642 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
643 |
+
raise ValueError(
|
644 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
645 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
646 |
+
)
|
647 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
648 |
+
raise ValueError(
|
649 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
650 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
651 |
+
)
|
652 |
+
|
653 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
654 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
655 |
+
raise ValueError(
|
656 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
657 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
658 |
+
f" {negative_prompt_embeds.shape}."
|
659 |
+
)
|
660 |
+
|
661 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
662 |
+
raise ValueError(
|
663 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
664 |
+
)
|
665 |
+
|
666 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
667 |
+
raise ValueError(
|
668 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
669 |
+
)
|
670 |
+
|
671 |
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
672 |
+
raise ValueError(
|
673 |
+
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
674 |
+
)
|
675 |
+
|
676 |
+
if ip_adapter_image_embeds is not None:
|
677 |
+
if not isinstance(ip_adapter_image_embeds, list):
|
678 |
+
raise ValueError(
|
679 |
+
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
680 |
+
)
|
681 |
+
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
682 |
+
raise ValueError(
|
683 |
+
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
684 |
+
)
|
685 |
+
|
686 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
687 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
688 |
+
shape = (
|
689 |
+
batch_size,
|
690 |
+
num_channels_latents,
|
691 |
+
int(height) // self.vae_scale_factor,
|
692 |
+
int(width) // self.vae_scale_factor,
|
693 |
+
)
|
694 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
695 |
+
raise ValueError(
|
696 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
697 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
698 |
+
)
|
699 |
+
|
700 |
+
if latents is None:
|
701 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
702 |
+
else:
|
703 |
+
latents = latents.to(device)
|
704 |
+
|
705 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
706 |
+
latents = latents * self.scheduler.init_noise_sigma
|
707 |
+
return latents
|
708 |
+
|
709 |
+
def _get_add_time_ids(
|
710 |
+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
711 |
+
):
|
712 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
713 |
+
|
714 |
+
passed_add_embed_dim = (
|
715 |
+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
716 |
+
)
|
717 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
718 |
+
|
719 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
720 |
+
raise ValueError(
|
721 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
722 |
+
)
|
723 |
+
|
724 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
725 |
+
return add_time_ids
|
726 |
+
|
727 |
+
def upcast_vae(self):
|
728 |
+
dtype = self.vae.dtype
|
729 |
+
self.vae.to(dtype=torch.float32)
|
730 |
+
use_torch_2_0_or_xformers = isinstance(
|
731 |
+
self.vae.decoder.mid_block.attentions[0].processor,
|
732 |
+
(
|
733 |
+
AttnProcessor2_0,
|
734 |
+
XFormersAttnProcessor,
|
735 |
+
LoRAXFormersAttnProcessor,
|
736 |
+
LoRAAttnProcessor2_0,
|
737 |
+
FusedAttnProcessor2_0,
|
738 |
+
),
|
739 |
+
)
|
740 |
+
# if xformers or torch_2_0 is used attention block does not need
|
741 |
+
# to be in float32 which can save lots of memory
|
742 |
+
if use_torch_2_0_or_xformers:
|
743 |
+
self.vae.post_quant_conv.to(dtype)
|
744 |
+
self.vae.decoder.conv_in.to(dtype)
|
745 |
+
self.vae.decoder.mid_block.to(dtype)
|
746 |
+
|
747 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
748 |
+
def get_guidance_scale_embedding(
|
749 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
750 |
+
) -> torch.FloatTensor:
|
751 |
+
"""
|
752 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
753 |
+
|
754 |
+
Args:
|
755 |
+
w (`torch.Tensor`):
|
756 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
757 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
758 |
+
Dimension of the embeddings to generate.
|
759 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
760 |
+
Data type of the generated embeddings.
|
761 |
+
|
762 |
+
Returns:
|
763 |
+
`torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
764 |
+
"""
|
765 |
+
assert len(w.shape) == 1
|
766 |
+
w = w * 1000.0
|
767 |
+
|
768 |
+
half_dim = embedding_dim // 2
|
769 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
770 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
771 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
772 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
773 |
+
if embedding_dim % 2 == 1: # zero pad
|
774 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
775 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
776 |
+
return emb
|
777 |
+
|
778 |
+
@property
|
779 |
+
def guidance_scale(self):
|
780 |
+
return self._guidance_scale
|
781 |
+
|
782 |
+
@property
|
783 |
+
def guidance_rescale(self):
|
784 |
+
return self._guidance_rescale
|
785 |
+
|
786 |
+
@property
|
787 |
+
def clip_skip(self):
|
788 |
+
return self._clip_skip
|
789 |
+
|
790 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
791 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
792 |
+
# corresponds to doing no classifier free guidance.
|
793 |
+
@property
|
794 |
+
def do_classifier_free_guidance(self):
|
795 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
796 |
+
|
797 |
+
@property
|
798 |
+
def cross_attention_kwargs(self):
|
799 |
+
return self._cross_attention_kwargs
|
800 |
+
|
801 |
+
@property
|
802 |
+
def denoising_end(self):
|
803 |
+
return self._denoising_end
|
804 |
+
|
805 |
+
@property
|
806 |
+
def num_timesteps(self):
|
807 |
+
return self._num_timesteps
|
808 |
+
|
809 |
+
@property
|
810 |
+
def interrupt(self):
|
811 |
+
return self._interrupt
|
812 |
+
|
813 |
+
@torch.no_grad()
|
814 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
815 |
+
def __call__(
|
816 |
+
self,
|
817 |
+
prompt: Union[str, List[str]] = None,
|
818 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
819 |
+
height: Optional[int] = None,
|
820 |
+
width: Optional[int] = None,
|
821 |
+
num_inference_steps: int = 50,
|
822 |
+
timesteps: List[int] = None,
|
823 |
+
denoising_end: Optional[float] = None,
|
824 |
+
guidance_scale: float = 5.0,
|
825 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
826 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
827 |
+
num_images_per_prompt: Optional[int] = 1,
|
828 |
+
eta: float = 0.0,
|
829 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
830 |
+
latents: Optional[torch.FloatTensor] = None,
|
831 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
832 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
833 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
834 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
835 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
836 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
837 |
+
output_type: Optional[str] = "pil",
|
838 |
+
return_dict: bool = True,
|
839 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
840 |
+
guidance_rescale: float = 0.0,
|
841 |
+
original_size: Optional[Tuple[int, int]] = None,
|
842 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
843 |
+
target_size: Optional[Tuple[int, int]] = None,
|
844 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
845 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
846 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
847 |
+
clip_skip: Optional[int] = None,
|
848 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
849 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
850 |
+
**kwargs,
|
851 |
+
):
|
852 |
+
r"""
|
853 |
+
Function invoked when calling the pipeline for generation.
|
854 |
+
|
855 |
+
Args:
|
856 |
+
prompt (`str` or `List[str]`, *optional*):
|
857 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
858 |
+
instead.
|
859 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
860 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
861 |
+
used in both text-encoders
|
862 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
863 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
864 |
+
Anything below 512 pixels won't work well for
|
865 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
866 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
867 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
868 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
869 |
+
Anything below 512 pixels won't work well for
|
870 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
871 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
872 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
873 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
874 |
+
expense of slower inference.
|
875 |
+
timesteps (`List[int]`, *optional*):
|
876 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
877 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
878 |
+
passed will be used. Must be in descending order.
|
879 |
+
denoising_end (`float`, *optional*):
|
880 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
881 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
882 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
883 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
884 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
885 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
886 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
887 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
888 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
889 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
890 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
891 |
+
usually at the expense of lower image quality.
|
892 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
893 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
894 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
895 |
+
less than `1`).
|
896 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
897 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
898 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
899 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
900 |
+
The number of images to generate per prompt.
|
901 |
+
eta (`float`, *optional*, defaults to 0.0):
|
902 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
903 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
904 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
905 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
906 |
+
to make generation deterministic.
|
907 |
+
latents (`torch.FloatTensor`, *optional*):
|
908 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
909 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
910 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
911 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
912 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
913 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
914 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
915 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
916 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
917 |
+
argument.
|
918 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
919 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
920 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
921 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
922 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
923 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
924 |
+
input argument.
|
925 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
926 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
927 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
928 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
929 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
930 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
931 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
932 |
+
The output format of the generate image. Choose between
|
933 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
934 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
935 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
936 |
+
of a plain tuple.
|
937 |
+
cross_attention_kwargs (`dict`, *optional*):
|
938 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
939 |
+
`self.processor` in
|
940 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
941 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
942 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
943 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
944 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
945 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
946 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
947 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
948 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
949 |
+
explained in section 2.2 of
|
950 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
951 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
952 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
953 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
954 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
955 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
956 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
957 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
958 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
959 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
960 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
961 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
962 |
+
micro-conditioning as explained in section 2.2 of
|
963 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
964 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
965 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
966 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
967 |
+
micro-conditioning as explained in section 2.2 of
|
968 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
969 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
970 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
971 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
972 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
973 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
974 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
975 |
+
callback_on_step_end (`Callable`, *optional*):
|
976 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
977 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
978 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
979 |
+
`callback_on_step_end_tensor_inputs`.
|
980 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
981 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
982 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
983 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
984 |
+
|
985 |
+
Examples:
|
986 |
+
|
987 |
+
Returns:
|
988 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
989 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
990 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
991 |
+
"""
|
992 |
+
|
993 |
+
callback = kwargs.pop("callback", None)
|
994 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
995 |
+
|
996 |
+
if callback is not None:
|
997 |
+
deprecate(
|
998 |
+
"callback",
|
999 |
+
"1.0.0",
|
1000 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1001 |
+
)
|
1002 |
+
if callback_steps is not None:
|
1003 |
+
deprecate(
|
1004 |
+
"callback_steps",
|
1005 |
+
"1.0.0",
|
1006 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
# 0. Default height and width to unet
|
1010 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1011 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1012 |
+
|
1013 |
+
original_size = original_size or (height, width)
|
1014 |
+
target_size = target_size or (height, width)
|
1015 |
+
|
1016 |
+
# 1. Check inputs. Raise error if not correct
|
1017 |
+
self.check_inputs(
|
1018 |
+
prompt,
|
1019 |
+
prompt_2,
|
1020 |
+
height,
|
1021 |
+
width,
|
1022 |
+
callback_steps,
|
1023 |
+
negative_prompt,
|
1024 |
+
negative_prompt_2,
|
1025 |
+
prompt_embeds,
|
1026 |
+
negative_prompt_embeds,
|
1027 |
+
pooled_prompt_embeds,
|
1028 |
+
negative_pooled_prompt_embeds,
|
1029 |
+
ip_adapter_image,
|
1030 |
+
ip_adapter_image_embeds,
|
1031 |
+
callback_on_step_end_tensor_inputs,
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
self._guidance_scale = guidance_scale
|
1035 |
+
self._guidance_rescale = guidance_rescale
|
1036 |
+
self._clip_skip = clip_skip
|
1037 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
1038 |
+
self._denoising_end = denoising_end
|
1039 |
+
self._interrupt = False
|
1040 |
+
|
1041 |
+
# 2. Define call parameters
|
1042 |
+
if prompt is not None and isinstance(prompt, str):
|
1043 |
+
batch_size = 1
|
1044 |
+
elif prompt is not None and isinstance(prompt, list):
|
1045 |
+
batch_size = len(prompt)
|
1046 |
+
else:
|
1047 |
+
batch_size = prompt_embeds.shape[0]
|
1048 |
+
|
1049 |
+
device = self._execution_device
|
1050 |
+
|
1051 |
+
# 3. Encode input prompt
|
1052 |
+
lora_scale = (
|
1053 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
1054 |
+
)
|
1055 |
+
|
1056 |
+
(
|
1057 |
+
prompt_embeds,
|
1058 |
+
negative_prompt_embeds,
|
1059 |
+
pooled_prompt_embeds,
|
1060 |
+
negative_pooled_prompt_embeds,
|
1061 |
+
) = self.encode_prompt(
|
1062 |
+
prompt=prompt,
|
1063 |
+
prompt_2=prompt_2,
|
1064 |
+
device=device,
|
1065 |
+
num_images_per_prompt=num_images_per_prompt,
|
1066 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1067 |
+
negative_prompt=negative_prompt,
|
1068 |
+
negative_prompt_2=negative_prompt_2,
|
1069 |
+
prompt_embeds=prompt_embeds,
|
1070 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1071 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1072 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1073 |
+
lora_scale=lora_scale,
|
1074 |
+
clip_skip=self.clip_skip,
|
1075 |
+
)
|
1076 |
+
|
1077 |
+
# 4. Prepare timesteps
|
1078 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
1079 |
+
|
1080 |
+
# 5. Prepare latent variables
|
1081 |
+
num_channels_latents = self.unet.config.in_channels
|
1082 |
+
latents = self.prepare_latents(
|
1083 |
+
batch_size * num_images_per_prompt,
|
1084 |
+
num_channels_latents,
|
1085 |
+
height,
|
1086 |
+
width,
|
1087 |
+
prompt_embeds.dtype,
|
1088 |
+
device,
|
1089 |
+
generator,
|
1090 |
+
latents,
|
1091 |
+
)
|
1092 |
+
|
1093 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1094 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1095 |
+
|
1096 |
+
# 7. Prepare added time ids & embeddings
|
1097 |
+
add_text_embeds = pooled_prompt_embeds
|
1098 |
+
if self.text_encoder_2 is None:
|
1099 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
1100 |
+
else:
|
1101 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
1102 |
+
|
1103 |
+
add_time_ids = self._get_add_time_ids(
|
1104 |
+
original_size,
|
1105 |
+
crops_coords_top_left,
|
1106 |
+
target_size,
|
1107 |
+
dtype=prompt_embeds.dtype,
|
1108 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1109 |
+
)
|
1110 |
+
if negative_original_size is not None and negative_target_size is not None:
|
1111 |
+
negative_add_time_ids = self._get_add_time_ids(
|
1112 |
+
negative_original_size,
|
1113 |
+
negative_crops_coords_top_left,
|
1114 |
+
negative_target_size,
|
1115 |
+
dtype=prompt_embeds.dtype,
|
1116 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
1117 |
+
)
|
1118 |
+
else:
|
1119 |
+
negative_add_time_ids = add_time_ids
|
1120 |
+
|
1121 |
+
if self.do_classifier_free_guidance:
|
1122 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1123 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
1124 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
1125 |
+
|
1126 |
+
prompt_embeds = prompt_embeds.to(device)
|
1127 |
+
add_text_embeds = add_text_embeds.to(device)
|
1128 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
1129 |
+
|
1130 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1131 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
1132 |
+
ip_adapter_image,
|
1133 |
+
ip_adapter_image_embeds,
|
1134 |
+
device,
|
1135 |
+
batch_size * num_images_per_prompt,
|
1136 |
+
self.do_classifier_free_guidance,
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
# 8. Denoising loop
|
1140 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1141 |
+
|
1142 |
+
# 8.1 Apply denoising_end
|
1143 |
+
if (
|
1144 |
+
self.denoising_end is not None
|
1145 |
+
and isinstance(self.denoising_end, float)
|
1146 |
+
and self.denoising_end > 0
|
1147 |
+
and self.denoising_end < 1
|
1148 |
+
):
|
1149 |
+
discrete_timestep_cutoff = int(
|
1150 |
+
round(
|
1151 |
+
self.scheduler.config.num_train_timesteps
|
1152 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
1153 |
+
)
|
1154 |
+
)
|
1155 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
1156 |
+
timesteps = timesteps[:num_inference_steps]
|
1157 |
+
|
1158 |
+
# 9. Optionally get Guidance Scale Embedding
|
1159 |
+
timestep_cond = None
|
1160 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
1161 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
1162 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
1163 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
1164 |
+
).to(device=device, dtype=latents.dtype)
|
1165 |
+
|
1166 |
+
self._num_timesteps = len(timesteps)
|
1167 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1168 |
+
for i, t in enumerate(timesteps):
|
1169 |
+
if self.interrupt:
|
1170 |
+
continue
|
1171 |
+
|
1172 |
+
# expand the latents if we are doing classifier free guidance
|
1173 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1174 |
+
|
1175 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1176 |
+
|
1177 |
+
# predict the noise residual
|
1178 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1179 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
1180 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
1181 |
+
|
1182 |
+
noise_pred = self.unet(
|
1183 |
+
latent_model_input,
|
1184 |
+
t,
|
1185 |
+
encoder_hidden_states=prompt_embeds, # [B, 77, 2048]
|
1186 |
+
timestep_cond=timestep_cond, # None
|
1187 |
+
cross_attention_kwargs=self.cross_attention_kwargs, # None
|
1188 |
+
added_cond_kwargs=added_cond_kwargs, # {[B, 1280], [B, 6]}
|
1189 |
+
return_dict=False,
|
1190 |
+
)[0]
|
1191 |
+
|
1192 |
+
# perform guidance
|
1193 |
+
if self.do_classifier_free_guidance:
|
1194 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1195 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1196 |
+
|
1197 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1198 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1199 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
1200 |
+
|
1201 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1202 |
+
latents_dtype = latents.dtype
|
1203 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1204 |
+
if latents.dtype != latents_dtype:
|
1205 |
+
if torch.backends.mps.is_available():
|
1206 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1207 |
+
latents = latents.to(latents_dtype)
|
1208 |
+
|
1209 |
+
if callback_on_step_end is not None:
|
1210 |
+
callback_kwargs = {}
|
1211 |
+
for k in callback_on_step_end_tensor_inputs:
|
1212 |
+
callback_kwargs[k] = locals()[k]
|
1213 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1214 |
+
|
1215 |
+
latents = callback_outputs.pop("latents", latents)
|
1216 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1217 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1218 |
+
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
1219 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1220 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1221 |
+
)
|
1222 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
1223 |
+
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
1224 |
+
|
1225 |
+
# call the callback, if provided
|
1226 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1227 |
+
progress_bar.update()
|
1228 |
+
if callback is not None and i % callback_steps == 0:
|
1229 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1230 |
+
callback(step_idx, t, latents)
|
1231 |
+
|
1232 |
+
if XLA_AVAILABLE:
|
1233 |
+
xm.mark_step()
|
1234 |
+
|
1235 |
+
if not output_type == "latent":
|
1236 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
1237 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
1238 |
+
|
1239 |
+
if needs_upcasting:
|
1240 |
+
self.upcast_vae()
|
1241 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
1242 |
+
elif latents.dtype != self.vae.dtype:
|
1243 |
+
if torch.backends.mps.is_available():
|
1244 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1245 |
+
self.vae = self.vae.to(latents.dtype)
|
1246 |
+
|
1247 |
+
# unscale/denormalize the latents
|
1248 |
+
# denormalize with the mean and std if available and not None
|
1249 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
1250 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
1251 |
+
if has_latents_mean and has_latents_std:
|
1252 |
+
latents_mean = (
|
1253 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1254 |
+
)
|
1255 |
+
latents_std = (
|
1256 |
+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
1257 |
+
)
|
1258 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
1259 |
+
else:
|
1260 |
+
latents = latents / self.vae.config.scaling_factor
|
1261 |
+
|
1262 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1263 |
+
|
1264 |
+
# cast back to fp16 if needed
|
1265 |
+
if needs_upcasting:
|
1266 |
+
self.vae.to(dtype=torch.float16)
|
1267 |
+
else:
|
1268 |
+
image = latents
|
1269 |
+
|
1270 |
+
if not output_type == "latent":
|
1271 |
+
# apply watermark if available
|
1272 |
+
if self.watermark is not None:
|
1273 |
+
image = self.watermark.apply_watermark(image)
|
1274 |
+
|
1275 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1276 |
+
|
1277 |
+
# Offload all models
|
1278 |
+
self.maybe_free_model_hooks()
|
1279 |
+
|
1280 |
+
if not return_dict:
|
1281 |
+
return (image,)
|
1282 |
+
|
1283 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers==0.28.1
|
2 |
+
accelerate==0.25.0
|
3 |
+
datasets==2.19.1
|
4 |
+
einops==0.8.0
|
5 |
+
kornia==0.7.2
|
6 |
+
numpy==1.26.4
|
7 |
+
opencv-python==4.9.0.80
|
8 |
+
peft==0.10.0
|
9 |
+
pyrallis==0.3.1
|
10 |
+
tokenizers==0.15.2
|
11 |
+
torch==2.0.1
|
12 |
+
torchvision==0.15.2
|
13 |
+
transformers==4.36.2
|
14 |
+
gradio==4.44.1
|
schedulers/lcm_single_step_scheduler.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, logging
|
27 |
+
from diffusers.utils.torch_utils import randn_tensor
|
28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class LCMSingleStepSchedulerOutput(BaseOutput):
|
36 |
+
"""
|
37 |
+
Output class for the scheduler's `step` function output.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
42 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
43 |
+
"""
|
44 |
+
|
45 |
+
denoised: Optional[torch.FloatTensor] = None
|
46 |
+
|
47 |
+
|
48 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
49 |
+
def betas_for_alpha_bar(
|
50 |
+
num_diffusion_timesteps,
|
51 |
+
max_beta=0.999,
|
52 |
+
alpha_transform_type="cosine",
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
56 |
+
(1-beta) over time from t = [0,1].
|
57 |
+
|
58 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
59 |
+
to that part of the diffusion process.
|
60 |
+
|
61 |
+
|
62 |
+
Args:
|
63 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
64 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
65 |
+
prevent singularities.
|
66 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
67 |
+
Choose from `cosine` or `exp`
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
71 |
+
"""
|
72 |
+
if alpha_transform_type == "cosine":
|
73 |
+
|
74 |
+
def alpha_bar_fn(t):
|
75 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
76 |
+
|
77 |
+
elif alpha_transform_type == "exp":
|
78 |
+
|
79 |
+
def alpha_bar_fn(t):
|
80 |
+
return math.exp(t * -12.0)
|
81 |
+
|
82 |
+
else:
|
83 |
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
84 |
+
|
85 |
+
betas = []
|
86 |
+
for i in range(num_diffusion_timesteps):
|
87 |
+
t1 = i / num_diffusion_timesteps
|
88 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
89 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
90 |
+
return torch.tensor(betas, dtype=torch.float32)
|
91 |
+
|
92 |
+
|
93 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
94 |
+
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
|
95 |
+
"""
|
96 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
97 |
+
|
98 |
+
|
99 |
+
Args:
|
100 |
+
betas (`torch.FloatTensor`):
|
101 |
+
the betas that the scheduler is being initialized with.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
105 |
+
"""
|
106 |
+
# Convert betas to alphas_bar_sqrt
|
107 |
+
alphas = 1.0 - betas
|
108 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
109 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
110 |
+
|
111 |
+
# Store old values.
|
112 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
113 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
114 |
+
|
115 |
+
# Shift so the last timestep is zero.
|
116 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
117 |
+
|
118 |
+
# Scale so the first timestep is back to the old value.
|
119 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
120 |
+
|
121 |
+
# Convert alphas_bar_sqrt to betas
|
122 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
123 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
124 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
125 |
+
betas = 1 - alphas
|
126 |
+
|
127 |
+
return betas
|
128 |
+
|
129 |
+
|
130 |
+
class LCMSingleStepScheduler(SchedulerMixin, ConfigMixin):
|
131 |
+
"""
|
132 |
+
`LCMSingleStepScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
133 |
+
non-Markovian guidance.
|
134 |
+
|
135 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
|
136 |
+
attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
|
137 |
+
accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
|
138 |
+
functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
num_train_timesteps (`int`, defaults to 1000):
|
142 |
+
The number of diffusion steps to train the model.
|
143 |
+
beta_start (`float`, defaults to 0.0001):
|
144 |
+
The starting `beta` value of inference.
|
145 |
+
beta_end (`float`, defaults to 0.02):
|
146 |
+
The final `beta` value.
|
147 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
148 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
149 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
150 |
+
trained_betas (`np.ndarray`, *optional*):
|
151 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
152 |
+
original_inference_steps (`int`, *optional*, defaults to 50):
|
153 |
+
The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
|
154 |
+
will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
|
155 |
+
clip_sample (`bool`, defaults to `True`):
|
156 |
+
Clip the predicted sample for numerical stability.
|
157 |
+
clip_sample_range (`float`, defaults to 1.0):
|
158 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
159 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
160 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
161 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
162 |
+
otherwise it uses the alpha value at step 0.
|
163 |
+
steps_offset (`int`, defaults to 0):
|
164 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
165 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
166 |
+
Diffusion.
|
167 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
168 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
169 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
170 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
171 |
+
thresholding (`bool`, defaults to `False`):
|
172 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
173 |
+
as Stable Diffusion.
|
174 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
175 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
176 |
+
sample_max_value (`float`, defaults to 1.0):
|
177 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
178 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
179 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
180 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
181 |
+
timestep_scaling (`float`, defaults to 10.0):
|
182 |
+
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
|
183 |
+
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
|
184 |
+
error at the default of `10.0` is already pretty small).
|
185 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
186 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
187 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
188 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
189 |
+
"""
|
190 |
+
|
191 |
+
order = 1
|
192 |
+
|
193 |
+
@register_to_config
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
num_train_timesteps: int = 1000,
|
197 |
+
beta_start: float = 0.00085,
|
198 |
+
beta_end: float = 0.012,
|
199 |
+
beta_schedule: str = "scaled_linear",
|
200 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
201 |
+
original_inference_steps: int = 50,
|
202 |
+
clip_sample: bool = False,
|
203 |
+
clip_sample_range: float = 1.0,
|
204 |
+
set_alpha_to_one: bool = True,
|
205 |
+
steps_offset: int = 0,
|
206 |
+
prediction_type: str = "epsilon",
|
207 |
+
thresholding: bool = False,
|
208 |
+
dynamic_thresholding_ratio: float = 0.995,
|
209 |
+
sample_max_value: float = 1.0,
|
210 |
+
timestep_spacing: str = "leading",
|
211 |
+
timestep_scaling: float = 10.0,
|
212 |
+
rescale_betas_zero_snr: bool = False,
|
213 |
+
):
|
214 |
+
if trained_betas is not None:
|
215 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
216 |
+
elif beta_schedule == "linear":
|
217 |
+
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
218 |
+
elif beta_schedule == "scaled_linear":
|
219 |
+
# this schedule is very specific to the latent diffusion model.
|
220 |
+
self.betas = (
|
221 |
+
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
222 |
+
)
|
223 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
224 |
+
# Glide cosine schedule
|
225 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
226 |
+
else:
|
227 |
+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
228 |
+
|
229 |
+
# Rescale for zero SNR
|
230 |
+
if rescale_betas_zero_snr:
|
231 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
232 |
+
|
233 |
+
self.alphas = 1.0 - self.betas
|
234 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
235 |
+
|
236 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
237 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
238 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
239 |
+
# whether we use the final alpha of the "non-previous" one.
|
240 |
+
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
241 |
+
|
242 |
+
# standard deviation of the initial noise distribution
|
243 |
+
self.init_noise_sigma = 1.0
|
244 |
+
|
245 |
+
# setable values
|
246 |
+
self.num_inference_steps = None
|
247 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
248 |
+
|
249 |
+
self._step_index = None
|
250 |
+
|
251 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
252 |
+
def _init_step_index(self, timestep):
|
253 |
+
if isinstance(timestep, torch.Tensor):
|
254 |
+
timestep = timestep.to(self.timesteps.device)
|
255 |
+
|
256 |
+
index_candidates = (self.timesteps == timestep).nonzero()
|
257 |
+
|
258 |
+
# The sigma index that is taken for the **very** first `step`
|
259 |
+
# is always the second index (or the last index if there is only 1)
|
260 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
261 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
262 |
+
if len(index_candidates) > 1:
|
263 |
+
step_index = index_candidates[1]
|
264 |
+
else:
|
265 |
+
step_index = index_candidates[0]
|
266 |
+
|
267 |
+
self._step_index = step_index.item()
|
268 |
+
|
269 |
+
@property
|
270 |
+
def step_index(self):
|
271 |
+
return self._step_index
|
272 |
+
|
273 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
274 |
+
"""
|
275 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
276 |
+
current timestep.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
sample (`torch.FloatTensor`):
|
280 |
+
The input sample.
|
281 |
+
timestep (`int`, *optional*):
|
282 |
+
The current timestep in the diffusion chain.
|
283 |
+
Returns:
|
284 |
+
`torch.FloatTensor`:
|
285 |
+
A scaled input sample.
|
286 |
+
"""
|
287 |
+
return sample
|
288 |
+
|
289 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
290 |
+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
291 |
+
"""
|
292 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
293 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
294 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
295 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
296 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
297 |
+
|
298 |
+
https://arxiv.org/abs/2205.11487
|
299 |
+
"""
|
300 |
+
dtype = sample.dtype
|
301 |
+
batch_size, channels, *remaining_dims = sample.shape
|
302 |
+
|
303 |
+
if dtype not in (torch.float32, torch.float64):
|
304 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
305 |
+
|
306 |
+
# Flatten sample for doing quantile calculation along each image
|
307 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
308 |
+
|
309 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
310 |
+
|
311 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
312 |
+
s = torch.clamp(
|
313 |
+
s, min=1, max=self.config.sample_max_value
|
314 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
315 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
316 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
317 |
+
|
318 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
319 |
+
sample = sample.to(dtype)
|
320 |
+
|
321 |
+
return sample
|
322 |
+
|
323 |
+
def set_timesteps(
|
324 |
+
self,
|
325 |
+
num_inference_steps: int = None,
|
326 |
+
device: Union[str, torch.device] = None,
|
327 |
+
original_inference_steps: Optional[int] = None,
|
328 |
+
strength: int = 1.0,
|
329 |
+
timesteps: Optional[list] = None,
|
330 |
+
):
|
331 |
+
"""
|
332 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
333 |
+
|
334 |
+
Args:
|
335 |
+
num_inference_steps (`int`):
|
336 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
337 |
+
device (`str` or `torch.device`, *optional*):
|
338 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
339 |
+
original_inference_steps (`int`, *optional*):
|
340 |
+
The original number of inference steps, which will be used to generate a linearly-spaced timestep
|
341 |
+
schedule (which is different from the standard `diffusers` implementation). We will then take
|
342 |
+
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
|
343 |
+
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
|
344 |
+
"""
|
345 |
+
|
346 |
+
if num_inference_steps is not None and timesteps is not None:
|
347 |
+
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
348 |
+
|
349 |
+
if timesteps is not None:
|
350 |
+
for i in range(1, len(timesteps)):
|
351 |
+
if timesteps[i] >= timesteps[i - 1]:
|
352 |
+
raise ValueError("`custom_timesteps` must be in descending order.")
|
353 |
+
|
354 |
+
if timesteps[0] >= self.config.num_train_timesteps:
|
355 |
+
raise ValueError(
|
356 |
+
f"`timesteps` must start before `self.config.train_timesteps`:"
|
357 |
+
f" {self.config.num_train_timesteps}."
|
358 |
+
)
|
359 |
+
|
360 |
+
timesteps = np.array(timesteps, dtype=np.int64)
|
361 |
+
else:
|
362 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
363 |
+
raise ValueError(
|
364 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
365 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
366 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
367 |
+
)
|
368 |
+
|
369 |
+
self.num_inference_steps = num_inference_steps
|
370 |
+
original_steps = (
|
371 |
+
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
|
372 |
+
)
|
373 |
+
|
374 |
+
if original_steps > self.config.num_train_timesteps:
|
375 |
+
raise ValueError(
|
376 |
+
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
|
377 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
378 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
379 |
+
)
|
380 |
+
|
381 |
+
if num_inference_steps > original_steps:
|
382 |
+
raise ValueError(
|
383 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
|
384 |
+
f" {original_steps} because the final timestep schedule will be a subset of the"
|
385 |
+
f" `original_inference_steps`-sized initial timestep schedule."
|
386 |
+
)
|
387 |
+
|
388 |
+
# LCM Timesteps Setting
|
389 |
+
# Currently, only linear spacing is supported.
|
390 |
+
c = self.config.num_train_timesteps // original_steps
|
391 |
+
# LCM Training Steps Schedule
|
392 |
+
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
|
393 |
+
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
394 |
+
# LCM Inference Steps Schedule
|
395 |
+
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
|
396 |
+
|
397 |
+
self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
|
398 |
+
|
399 |
+
self._step_index = None
|
400 |
+
|
401 |
+
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
402 |
+
self.sigma_data = 0.5 # Default: 0.5
|
403 |
+
scaled_timestep = timestep * self.config.timestep_scaling
|
404 |
+
|
405 |
+
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
|
406 |
+
c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
|
407 |
+
return c_skip, c_out
|
408 |
+
|
409 |
+
def append_dims(self, x, target_dims):
|
410 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
411 |
+
dims_to_append = target_dims - x.ndim
|
412 |
+
if dims_to_append < 0:
|
413 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
414 |
+
return x[(...,) + (None,) * dims_to_append]
|
415 |
+
|
416 |
+
def extract_into_tensor(self, a, t, x_shape):
|
417 |
+
b, *_ = t.shape
|
418 |
+
out = a.gather(-1, t)
|
419 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
420 |
+
|
421 |
+
def step(
|
422 |
+
self,
|
423 |
+
model_output: torch.FloatTensor,
|
424 |
+
timestep: torch.Tensor,
|
425 |
+
sample: torch.FloatTensor,
|
426 |
+
generator: Optional[torch.Generator] = None,
|
427 |
+
return_dict: bool = True,
|
428 |
+
) -> Union[LCMSingleStepSchedulerOutput, Tuple]:
|
429 |
+
"""
|
430 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
431 |
+
process from the learned model outputs (most often the predicted noise).
|
432 |
+
|
433 |
+
Args:
|
434 |
+
model_output (`torch.FloatTensor`):
|
435 |
+
The direct output from learned diffusion model.
|
436 |
+
timestep (`float`):
|
437 |
+
The current discrete timestep in the diffusion chain.
|
438 |
+
sample (`torch.FloatTensor`):
|
439 |
+
A current instance of a sample created by the diffusion process.
|
440 |
+
generator (`torch.Generator`, *optional*):
|
441 |
+
A random number generator.
|
442 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
443 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
444 |
+
Returns:
|
445 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
446 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
447 |
+
tuple is returned where the first element is the sample tensor.
|
448 |
+
"""
|
449 |
+
# 0. make sure everything is on the same device
|
450 |
+
alphas_cumprod = self.alphas_cumprod.to(sample.device)
|
451 |
+
|
452 |
+
# 1. compute alphas, betas
|
453 |
+
if timestep.ndim == 0:
|
454 |
+
timestep = timestep.unsqueeze(0)
|
455 |
+
alpha_prod_t = self.extract_into_tensor(alphas_cumprod, timestep, sample.shape)
|
456 |
+
beta_prod_t = 1 - alpha_prod_t
|
457 |
+
|
458 |
+
# 2. Get scalings for boundary conditions
|
459 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
|
460 |
+
c_skip, c_out = [self.append_dims(x, sample.ndim) for x in [c_skip, c_out]]
|
461 |
+
|
462 |
+
# 3. Compute the predicted original sample x_0 based on the model parameterization
|
463 |
+
if self.config.prediction_type == "epsilon": # noise-prediction
|
464 |
+
predicted_original_sample = (sample - torch.sqrt(beta_prod_t) * model_output) / torch.sqrt(alpha_prod_t)
|
465 |
+
elif self.config.prediction_type == "sample": # x-prediction
|
466 |
+
predicted_original_sample = model_output
|
467 |
+
elif self.config.prediction_type == "v_prediction": # v-prediction
|
468 |
+
predicted_original_sample = torch.sqrt(alpha_prod_t) * sample - torch.sqrt(beta_prod_t) * model_output
|
469 |
+
else:
|
470 |
+
raise ValueError(
|
471 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
472 |
+
" `v_prediction` for `LCMScheduler`."
|
473 |
+
)
|
474 |
+
|
475 |
+
# 4. Clip or threshold "predicted x_0"
|
476 |
+
if self.config.thresholding:
|
477 |
+
predicted_original_sample = self._threshold_sample(predicted_original_sample)
|
478 |
+
elif self.config.clip_sample:
|
479 |
+
predicted_original_sample = predicted_original_sample.clamp(
|
480 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
481 |
+
)
|
482 |
+
|
483 |
+
# 5. Denoise model output using boundary conditions
|
484 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
485 |
+
|
486 |
+
if not return_dict:
|
487 |
+
return (denoised, )
|
488 |
+
|
489 |
+
return LCMSingleStepSchedulerOutput(denoised=denoised)
|
490 |
+
|
491 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
492 |
+
def add_noise(
|
493 |
+
self,
|
494 |
+
original_samples: torch.FloatTensor,
|
495 |
+
noise: torch.FloatTensor,
|
496 |
+
timesteps: torch.IntTensor,
|
497 |
+
) -> torch.FloatTensor:
|
498 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
499 |
+
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
500 |
+
timesteps = timesteps.to(original_samples.device)
|
501 |
+
|
502 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
503 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
504 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
505 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
506 |
+
|
507 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
508 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
509 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
510 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
511 |
+
|
512 |
+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
513 |
+
return noisy_samples
|
514 |
+
|
515 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
516 |
+
def get_velocity(
|
517 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
518 |
+
) -> torch.FloatTensor:
|
519 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
520 |
+
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
521 |
+
timesteps = timesteps.to(sample.device)
|
522 |
+
|
523 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
524 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
525 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
526 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
527 |
+
|
528 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
529 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
530 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
531 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
532 |
+
|
533 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
534 |
+
return velocity
|
535 |
+
|
536 |
+
def __len__(self):
|
537 |
+
return self.config.num_train_timesteps
|
train_previewer_lora.py
ADDED
@@ -0,0 +1,1712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import copy
|
18 |
+
import functools
|
19 |
+
import gc
|
20 |
+
import logging
|
21 |
+
import pyrallis
|
22 |
+
import math
|
23 |
+
import os
|
24 |
+
import random
|
25 |
+
import shutil
|
26 |
+
from contextlib import nullcontext
|
27 |
+
from pathlib import Path
|
28 |
+
|
29 |
+
import accelerate
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn.functional as F
|
33 |
+
import torch.utils.checkpoint
|
34 |
+
import transformers
|
35 |
+
from PIL import Image
|
36 |
+
from accelerate import Accelerator
|
37 |
+
from accelerate.logging import get_logger
|
38 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
39 |
+
from datasets import load_dataset
|
40 |
+
from huggingface_hub import create_repo, upload_folder
|
41 |
+
from packaging import version
|
42 |
+
from collections import namedtuple
|
43 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
44 |
+
from torchvision import transforms
|
45 |
+
from torchvision.transforms.functional import crop
|
46 |
+
from tqdm.auto import tqdm
|
47 |
+
from transformers import (
|
48 |
+
AutoTokenizer,
|
49 |
+
PretrainedConfig,
|
50 |
+
CLIPImageProcessor, CLIPVisionModelWithProjection,
|
51 |
+
AutoImageProcessor, AutoModel
|
52 |
+
)
|
53 |
+
|
54 |
+
import diffusers
|
55 |
+
from diffusers import (
|
56 |
+
AutoencoderKL,
|
57 |
+
DDPMScheduler,
|
58 |
+
LCMScheduler,
|
59 |
+
StableDiffusionXLPipeline,
|
60 |
+
UNet2DConditionModel,
|
61 |
+
)
|
62 |
+
from diffusers.optimization import get_scheduler
|
63 |
+
from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
|
64 |
+
from diffusers.utils import (
|
65 |
+
check_min_version,
|
66 |
+
convert_state_dict_to_diffusers,
|
67 |
+
convert_unet_state_dict_to_peft,
|
68 |
+
is_wandb_available,
|
69 |
+
)
|
70 |
+
from diffusers.utils.import_utils import is_xformers_available
|
71 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
72 |
+
|
73 |
+
from basicsr.utils.degradation_pipeline import RealESRGANDegradation
|
74 |
+
from utils.train_utils import (
|
75 |
+
seperate_ip_params_from_unet,
|
76 |
+
import_model_class_from_model_name_or_path,
|
77 |
+
tensor_to_pil,
|
78 |
+
get_train_dataset, prepare_train_dataset, collate_fn,
|
79 |
+
encode_prompt, importance_sampling_fn, extract_into_tensor
|
80 |
+
|
81 |
+
)
|
82 |
+
from data.data_config import DataConfig
|
83 |
+
from losses.loss_config import LossesConfig
|
84 |
+
from losses.losses import *
|
85 |
+
|
86 |
+
from module.ip_adapter.resampler import Resampler
|
87 |
+
from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
|
88 |
+
|
89 |
+
|
90 |
+
if is_wandb_available():
|
91 |
+
import wandb
|
92 |
+
|
93 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
94 |
+
|
95 |
+
logger = get_logger(__name__)
|
96 |
+
|
97 |
+
|
98 |
+
def prepare_latents(lq, vae, scheduler, generator, timestep):
|
99 |
+
transform = transforms.Compose([
|
100 |
+
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
101 |
+
transforms.CenterCrop(args.resolution),
|
102 |
+
transforms.ToTensor(),
|
103 |
+
])
|
104 |
+
lq_pt = [transform(lq_pil.convert("RGB")) for lq_pil in lq]
|
105 |
+
img_pt = torch.stack(lq_pt).to(vae.device, dtype=vae.dtype)
|
106 |
+
img_pt = img_pt * 2.0 - 1.0
|
107 |
+
with torch.no_grad():
|
108 |
+
latents = vae.encode(img_pt).latent_dist.sample()
|
109 |
+
latents = latents * vae.config.scaling_factor
|
110 |
+
noise = torch.randn(latents.shape, generator=generator, device=vae.device, dtype=vae.dtype, layout=torch.strided).to(vae.device)
|
111 |
+
bsz = latents.shape[0]
|
112 |
+
print(f"init latent at {timestep}")
|
113 |
+
timestep = torch.tensor([timestep]*bsz, device=vae.device, dtype=torch.int64)
|
114 |
+
latents = scheduler.add_noise(latents, noise, timestep)
|
115 |
+
return latents
|
116 |
+
|
117 |
+
|
118 |
+
def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
119 |
+
scheduler, image_encoder, image_processor,
|
120 |
+
args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
|
121 |
+
logger.info("Running validation... ")
|
122 |
+
|
123 |
+
image_logs = []
|
124 |
+
|
125 |
+
lq = [Image.open(lq_example) for lq_example in args.validation_image]
|
126 |
+
|
127 |
+
pipe = StableDiffusionXLPipeline(
|
128 |
+
vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
129 |
+
unet, scheduler, image_encoder, image_processor,
|
130 |
+
).to(accelerator.device)
|
131 |
+
|
132 |
+
timesteps = [args.num_train_timesteps - 1]
|
133 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
134 |
+
latents = prepare_latents(lq, vae, scheduler, generator, timesteps[-1])
|
135 |
+
image = pipe(
|
136 |
+
prompt=[""]*len(lq),
|
137 |
+
ip_adapter_image=[lq],
|
138 |
+
num_inference_steps=1,
|
139 |
+
timesteps=timesteps,
|
140 |
+
generator=generator,
|
141 |
+
guidance_scale=1.0,
|
142 |
+
height=args.resolution,
|
143 |
+
width=args.resolution,
|
144 |
+
latents=latents,
|
145 |
+
).images
|
146 |
+
|
147 |
+
if log_local:
|
148 |
+
# for i, img in enumerate(tensor_to_pil(lq_img)):
|
149 |
+
# img.save(f"./lq_{i}.png")
|
150 |
+
# for i, img in enumerate(tensor_to_pil(gt_img)):
|
151 |
+
# img.save(f"./gt_{i}.png")
|
152 |
+
for i, img in enumerate(image):
|
153 |
+
img.save(f"./lq_IPA_{i}.png")
|
154 |
+
return
|
155 |
+
|
156 |
+
tracker_key = "test" if is_final_validation else "validation"
|
157 |
+
for tracker in accelerator.trackers:
|
158 |
+
if tracker.name == "tensorboard":
|
159 |
+
images = [np.asarray(pil_img) for pil_img in image]
|
160 |
+
images = np.stack(images, axis=0)
|
161 |
+
if lq_img is not None and gt_img is not None:
|
162 |
+
input_lq = lq_img.detach().cpu()
|
163 |
+
input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
|
164 |
+
input_gt = gt_img.detach().cpu()
|
165 |
+
input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
|
166 |
+
tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
|
167 |
+
tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
|
168 |
+
tracker.writer.add_images("rec", images, step, dataformats="NHWC")
|
169 |
+
elif tracker.name == "wandb":
|
170 |
+
raise NotImplementedError("Wandb logging not implemented for validation.")
|
171 |
+
formatted_images = []
|
172 |
+
|
173 |
+
for log in image_logs:
|
174 |
+
images = log["images"]
|
175 |
+
validation_prompt = log["validation_prompt"]
|
176 |
+
validation_image = log["validation_image"]
|
177 |
+
|
178 |
+
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
179 |
+
|
180 |
+
for image in images:
|
181 |
+
image = wandb.Image(image, caption=validation_prompt)
|
182 |
+
formatted_images.append(image)
|
183 |
+
|
184 |
+
tracker.log({tracker_key: formatted_images})
|
185 |
+
else:
|
186 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
187 |
+
|
188 |
+
gc.collect()
|
189 |
+
torch.cuda.empty_cache()
|
190 |
+
|
191 |
+
return image_logs
|
192 |
+
|
193 |
+
|
194 |
+
class DDIMSolver:
|
195 |
+
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
|
196 |
+
# DDIM sampling parameters
|
197 |
+
step_ratio = timesteps // ddim_timesteps
|
198 |
+
|
199 |
+
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
|
200 |
+
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
|
201 |
+
self.ddim_alpha_cumprods_prev = np.asarray(
|
202 |
+
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
|
203 |
+
)
|
204 |
+
# convert to torch tensors
|
205 |
+
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
|
206 |
+
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
|
207 |
+
self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
|
208 |
+
|
209 |
+
def to(self, device):
|
210 |
+
self.ddim_timesteps = self.ddim_timesteps.to(device)
|
211 |
+
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
|
212 |
+
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
|
213 |
+
return self
|
214 |
+
|
215 |
+
def ddim_step(self, pred_x0, pred_noise, timestep_index):
|
216 |
+
alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
|
217 |
+
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
|
218 |
+
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
|
219 |
+
return x_prev
|
220 |
+
|
221 |
+
|
222 |
+
def append_dims(x, target_dims):
|
223 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
224 |
+
dims_to_append = target_dims - x.ndim
|
225 |
+
if dims_to_append < 0:
|
226 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
227 |
+
return x[(...,) + (None,) * dims_to_append]
|
228 |
+
|
229 |
+
|
230 |
+
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
231 |
+
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
232 |
+
scaled_timestep = timestep_scaling * timestep
|
233 |
+
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
234 |
+
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
235 |
+
return c_skip, c_out
|
236 |
+
|
237 |
+
|
238 |
+
# Compare LCMScheduler.step, Step 4
|
239 |
+
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
240 |
+
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
241 |
+
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
242 |
+
if prediction_type == "epsilon":
|
243 |
+
pred_x_0 = (sample - sigmas * model_output) / alphas
|
244 |
+
elif prediction_type == "sample":
|
245 |
+
pred_x_0 = model_output
|
246 |
+
elif prediction_type == "v_prediction":
|
247 |
+
pred_x_0 = alphas * sample - sigmas * model_output
|
248 |
+
else:
|
249 |
+
raise ValueError(
|
250 |
+
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
251 |
+
f" are supported."
|
252 |
+
)
|
253 |
+
|
254 |
+
return pred_x_0
|
255 |
+
|
256 |
+
|
257 |
+
# Based on step 4 in DDIMScheduler.step
|
258 |
+
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
259 |
+
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
260 |
+
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
261 |
+
if prediction_type == "epsilon":
|
262 |
+
pred_epsilon = model_output
|
263 |
+
elif prediction_type == "sample":
|
264 |
+
pred_epsilon = (sample - alphas * model_output) / sigmas
|
265 |
+
elif prediction_type == "v_prediction":
|
266 |
+
pred_epsilon = alphas * model_output + sigmas * sample
|
267 |
+
else:
|
268 |
+
raise ValueError(
|
269 |
+
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
270 |
+
f" are supported."
|
271 |
+
)
|
272 |
+
|
273 |
+
return pred_epsilon
|
274 |
+
|
275 |
+
|
276 |
+
def extract_into_tensor(a, t, x_shape):
|
277 |
+
b, *_ = t.shape
|
278 |
+
out = a.gather(-1, t)
|
279 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
280 |
+
|
281 |
+
|
282 |
+
def parse_args():
|
283 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
284 |
+
# ----------Model Checkpoint Loading Arguments----------
|
285 |
+
parser.add_argument(
|
286 |
+
"--pretrained_model_name_or_path",
|
287 |
+
type=str,
|
288 |
+
default=None,
|
289 |
+
required=True,
|
290 |
+
help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--pretrained_vae_model_name_or_path",
|
294 |
+
type=str,
|
295 |
+
default=None,
|
296 |
+
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--teacher_revision",
|
300 |
+
type=str,
|
301 |
+
default=None,
|
302 |
+
required=False,
|
303 |
+
help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
|
304 |
+
)
|
305 |
+
parser.add_argument(
|
306 |
+
"--revision",
|
307 |
+
type=str,
|
308 |
+
default=None,
|
309 |
+
required=False,
|
310 |
+
help="Revision of pretrained LDM model identifier from huggingface.co/models.",
|
311 |
+
)
|
312 |
+
parser.add_argument(
|
313 |
+
"--pretrained_lcm_lora_path",
|
314 |
+
type=str,
|
315 |
+
default=None,
|
316 |
+
help="Path to LCM lora or model identifier from huggingface.co/models.",
|
317 |
+
)
|
318 |
+
parser.add_argument(
|
319 |
+
"--feature_extractor_path",
|
320 |
+
type=str,
|
321 |
+
default=None,
|
322 |
+
help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--pretrained_adapter_model_path",
|
326 |
+
type=str,
|
327 |
+
default=None,
|
328 |
+
help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
|
329 |
+
)
|
330 |
+
parser.add_argument(
|
331 |
+
"--adapter_tokens",
|
332 |
+
type=int,
|
333 |
+
default=64,
|
334 |
+
help="Number of tokens to use in IP-adapter cross attention mechanism.",
|
335 |
+
)
|
336 |
+
parser.add_argument(
|
337 |
+
"--use_clip_encoder",
|
338 |
+
action="store_true",
|
339 |
+
help="Whether or not to use DINO as image encoder, else CLIP encoder.",
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--image_encoder_hidden_feature",
|
343 |
+
action="store_true",
|
344 |
+
help="Whether or not to use the penultimate hidden states as image embeddings.",
|
345 |
+
)
|
346 |
+
# ----------Training Arguments----------
|
347 |
+
# ----General Training Arguments----
|
348 |
+
parser.add_argument(
|
349 |
+
"--output_dir",
|
350 |
+
type=str,
|
351 |
+
default="lcm-xl-distilled",
|
352 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
353 |
+
)
|
354 |
+
parser.add_argument(
|
355 |
+
"--cache_dir",
|
356 |
+
type=str,
|
357 |
+
default=None,
|
358 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
359 |
+
)
|
360 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
361 |
+
# ----Logging----
|
362 |
+
parser.add_argument(
|
363 |
+
"--logging_dir",
|
364 |
+
type=str,
|
365 |
+
default="logs",
|
366 |
+
help=(
|
367 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
368 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
369 |
+
),
|
370 |
+
)
|
371 |
+
parser.add_argument(
|
372 |
+
"--report_to",
|
373 |
+
type=str,
|
374 |
+
default="tensorboard",
|
375 |
+
help=(
|
376 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
377 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
378 |
+
),
|
379 |
+
)
|
380 |
+
# ----Checkpointing----
|
381 |
+
parser.add_argument(
|
382 |
+
"--checkpointing_steps",
|
383 |
+
type=int,
|
384 |
+
default=4000,
|
385 |
+
help=(
|
386 |
+
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
387 |
+
" training using `--resume_from_checkpoint`."
|
388 |
+
),
|
389 |
+
)
|
390 |
+
parser.add_argument(
|
391 |
+
"--checkpoints_total_limit",
|
392 |
+
type=int,
|
393 |
+
default=5,
|
394 |
+
help=("Max number of checkpoints to store."),
|
395 |
+
)
|
396 |
+
parser.add_argument(
|
397 |
+
"--resume_from_checkpoint",
|
398 |
+
type=str,
|
399 |
+
default=None,
|
400 |
+
help=(
|
401 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
402 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
403 |
+
),
|
404 |
+
)
|
405 |
+
parser.add_argument(
|
406 |
+
"--save_only_adapter",
|
407 |
+
action="store_true",
|
408 |
+
help="Only save extra adapter to save space.",
|
409 |
+
)
|
410 |
+
# ----Image Processing----
|
411 |
+
parser.add_argument(
|
412 |
+
"--data_config_path",
|
413 |
+
type=str,
|
414 |
+
default=None,
|
415 |
+
help=("A folder containing the training data. "),
|
416 |
+
)
|
417 |
+
parser.add_argument(
|
418 |
+
"--train_data_dir",
|
419 |
+
type=str,
|
420 |
+
default=None,
|
421 |
+
help=(
|
422 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
423 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
424 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
425 |
+
),
|
426 |
+
)
|
427 |
+
parser.add_argument(
|
428 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
429 |
+
)
|
430 |
+
parser.add_argument(
|
431 |
+
"--conditioning_image_column",
|
432 |
+
type=str,
|
433 |
+
default="conditioning_image",
|
434 |
+
help="The column of the dataset containing the controlnet conditioning image.",
|
435 |
+
)
|
436 |
+
parser.add_argument(
|
437 |
+
"--caption_column",
|
438 |
+
type=str,
|
439 |
+
default="text",
|
440 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
441 |
+
)
|
442 |
+
parser.add_argument(
|
443 |
+
"--text_drop_rate",
|
444 |
+
type=float,
|
445 |
+
default=0,
|
446 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
447 |
+
)
|
448 |
+
parser.add_argument(
|
449 |
+
"--image_drop_rate",
|
450 |
+
type=float,
|
451 |
+
default=0,
|
452 |
+
help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
|
453 |
+
)
|
454 |
+
parser.add_argument(
|
455 |
+
"--cond_drop_rate",
|
456 |
+
type=float,
|
457 |
+
default=0,
|
458 |
+
help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
|
459 |
+
)
|
460 |
+
parser.add_argument(
|
461 |
+
"--resolution",
|
462 |
+
type=int,
|
463 |
+
default=1024,
|
464 |
+
help=(
|
465 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
466 |
+
" resolution"
|
467 |
+
),
|
468 |
+
)
|
469 |
+
parser.add_argument(
|
470 |
+
"--interpolation_type",
|
471 |
+
type=str,
|
472 |
+
default="bilinear",
|
473 |
+
help=(
|
474 |
+
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
|
475 |
+
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
|
476 |
+
),
|
477 |
+
)
|
478 |
+
parser.add_argument(
|
479 |
+
"--center_crop",
|
480 |
+
default=False,
|
481 |
+
action="store_true",
|
482 |
+
help=(
|
483 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
484 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
485 |
+
),
|
486 |
+
)
|
487 |
+
parser.add_argument(
|
488 |
+
"--random_flip",
|
489 |
+
action="store_true",
|
490 |
+
help="whether to randomly flip images horizontally",
|
491 |
+
)
|
492 |
+
parser.add_argument(
|
493 |
+
"--encode_batch_size",
|
494 |
+
type=int,
|
495 |
+
default=8,
|
496 |
+
help="Batch size to use for VAE encoding of the images for efficient processing.",
|
497 |
+
)
|
498 |
+
# ----Dataloader----
|
499 |
+
parser.add_argument(
|
500 |
+
"--dataloader_num_workers",
|
501 |
+
type=int,
|
502 |
+
default=0,
|
503 |
+
help=(
|
504 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
505 |
+
),
|
506 |
+
)
|
507 |
+
# ----Batch Size and Training Steps----
|
508 |
+
parser.add_argument(
|
509 |
+
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
510 |
+
)
|
511 |
+
parser.add_argument("--num_train_epochs", type=int, default=100)
|
512 |
+
parser.add_argument(
|
513 |
+
"--max_train_steps",
|
514 |
+
type=int,
|
515 |
+
default=None,
|
516 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
517 |
+
)
|
518 |
+
parser.add_argument(
|
519 |
+
"--max_train_samples",
|
520 |
+
type=int,
|
521 |
+
default=None,
|
522 |
+
help=(
|
523 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
524 |
+
"value if set."
|
525 |
+
),
|
526 |
+
)
|
527 |
+
# ----Learning Rate----
|
528 |
+
parser.add_argument(
|
529 |
+
"--learning_rate",
|
530 |
+
type=float,
|
531 |
+
default=1e-6,
|
532 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
533 |
+
)
|
534 |
+
parser.add_argument(
|
535 |
+
"--scale_lr",
|
536 |
+
action="store_true",
|
537 |
+
default=False,
|
538 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
539 |
+
)
|
540 |
+
parser.add_argument(
|
541 |
+
"--lr_scheduler",
|
542 |
+
type=str,
|
543 |
+
default="constant",
|
544 |
+
help=(
|
545 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
546 |
+
' "constant", "constant_with_warmup"]'
|
547 |
+
),
|
548 |
+
)
|
549 |
+
parser.add_argument(
|
550 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
551 |
+
)
|
552 |
+
parser.add_argument(
|
553 |
+
"--lr_num_cycles",
|
554 |
+
type=int,
|
555 |
+
default=1,
|
556 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
557 |
+
)
|
558 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
559 |
+
parser.add_argument(
|
560 |
+
"--gradient_accumulation_steps",
|
561 |
+
type=int,
|
562 |
+
default=1,
|
563 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
564 |
+
)
|
565 |
+
# ----Optimizer (Adam)----
|
566 |
+
parser.add_argument(
|
567 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
568 |
+
)
|
569 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
570 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
571 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
572 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
573 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
574 |
+
# ----Diffusion Training Arguments----
|
575 |
+
# ----Latent Consistency Distillation (LCD) Specific Arguments----
|
576 |
+
parser.add_argument(
|
577 |
+
"--w_min",
|
578 |
+
type=float,
|
579 |
+
default=3.0,
|
580 |
+
required=False,
|
581 |
+
help=(
|
582 |
+
"The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
|
583 |
+
" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
|
584 |
+
" compared to the original paper."
|
585 |
+
),
|
586 |
+
)
|
587 |
+
parser.add_argument(
|
588 |
+
"--w_max",
|
589 |
+
type=float,
|
590 |
+
default=15.0,
|
591 |
+
required=False,
|
592 |
+
help=(
|
593 |
+
"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
|
594 |
+
" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
|
595 |
+
" compared to the original paper."
|
596 |
+
),
|
597 |
+
)
|
598 |
+
parser.add_argument(
|
599 |
+
"--num_train_timesteps",
|
600 |
+
type=int,
|
601 |
+
default=1000,
|
602 |
+
help="The number of timesteps to use for DDIM sampling.",
|
603 |
+
)
|
604 |
+
parser.add_argument(
|
605 |
+
"--num_ddim_timesteps",
|
606 |
+
type=int,
|
607 |
+
default=50,
|
608 |
+
help="The number of timesteps to use for DDIM sampling.",
|
609 |
+
)
|
610 |
+
parser.add_argument(
|
611 |
+
"--losses_config_path",
|
612 |
+
type=str,
|
613 |
+
default='config_files/losses.yaml',
|
614 |
+
required=True,
|
615 |
+
help=("A yaml file containing losses to use and their weights."),
|
616 |
+
)
|
617 |
+
parser.add_argument(
|
618 |
+
"--loss_type",
|
619 |
+
type=str,
|
620 |
+
default="l2",
|
621 |
+
choices=["l2", "huber"],
|
622 |
+
help="The type of loss to use for the LCD loss.",
|
623 |
+
)
|
624 |
+
parser.add_argument(
|
625 |
+
"--huber_c",
|
626 |
+
type=float,
|
627 |
+
default=0.001,
|
628 |
+
help="The huber loss parameter. Only used if `--loss_type=huber`.",
|
629 |
+
)
|
630 |
+
parser.add_argument(
|
631 |
+
"--lora_rank",
|
632 |
+
type=int,
|
633 |
+
default=64,
|
634 |
+
help="The rank of the LoRA projection matrix.",
|
635 |
+
)
|
636 |
+
parser.add_argument(
|
637 |
+
"--lora_alpha",
|
638 |
+
type=int,
|
639 |
+
default=64,
|
640 |
+
help=(
|
641 |
+
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
|
642 |
+
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
|
643 |
+
),
|
644 |
+
)
|
645 |
+
parser.add_argument(
|
646 |
+
"--lora_dropout",
|
647 |
+
type=float,
|
648 |
+
default=0.0,
|
649 |
+
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
|
650 |
+
)
|
651 |
+
parser.add_argument(
|
652 |
+
"--lora_target_modules",
|
653 |
+
type=str,
|
654 |
+
default=None,
|
655 |
+
help=(
|
656 |
+
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
|
657 |
+
" be used. By default, LoRA will be applied to all conv and linear layers."
|
658 |
+
),
|
659 |
+
)
|
660 |
+
parser.add_argument(
|
661 |
+
"--vae_encode_batch_size",
|
662 |
+
type=int,
|
663 |
+
default=8,
|
664 |
+
required=False,
|
665 |
+
help=(
|
666 |
+
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
|
667 |
+
" Encoding or decoding the whole batch at once may run into OOM issues."
|
668 |
+
),
|
669 |
+
)
|
670 |
+
parser.add_argument(
|
671 |
+
"--timestep_scaling_factor",
|
672 |
+
type=float,
|
673 |
+
default=10.0,
|
674 |
+
help=(
|
675 |
+
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
|
676 |
+
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
|
677 |
+
" suffice."
|
678 |
+
),
|
679 |
+
)
|
680 |
+
# ----Mixed Precision----
|
681 |
+
parser.add_argument(
|
682 |
+
"--mixed_precision",
|
683 |
+
type=str,
|
684 |
+
default=None,
|
685 |
+
choices=["no", "fp16", "bf16"],
|
686 |
+
help=(
|
687 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
688 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
689 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
690 |
+
),
|
691 |
+
)
|
692 |
+
parser.add_argument(
|
693 |
+
"--allow_tf32",
|
694 |
+
action="store_true",
|
695 |
+
help=(
|
696 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
697 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
698 |
+
),
|
699 |
+
)
|
700 |
+
# ----Training Optimizations----
|
701 |
+
parser.add_argument(
|
702 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
703 |
+
)
|
704 |
+
parser.add_argument(
|
705 |
+
"--gradient_checkpointing",
|
706 |
+
action="store_true",
|
707 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
708 |
+
)
|
709 |
+
# ----Distributed Training----
|
710 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
711 |
+
# ----------Validation Arguments----------
|
712 |
+
parser.add_argument(
|
713 |
+
"--validation_steps",
|
714 |
+
type=int,
|
715 |
+
default=3000,
|
716 |
+
help="Run validation every X steps.",
|
717 |
+
)
|
718 |
+
parser.add_argument(
|
719 |
+
"--validation_image",
|
720 |
+
type=str,
|
721 |
+
default=None,
|
722 |
+
nargs="+",
|
723 |
+
help=(
|
724 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
725 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
726 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
727 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
728 |
+
),
|
729 |
+
)
|
730 |
+
parser.add_argument(
|
731 |
+
"--validation_prompt",
|
732 |
+
type=str,
|
733 |
+
default=None,
|
734 |
+
nargs="+",
|
735 |
+
help=(
|
736 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
737 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
738 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
739 |
+
),
|
740 |
+
)
|
741 |
+
parser.add_argument(
|
742 |
+
"--sanity_check",
|
743 |
+
action="store_true",
|
744 |
+
help=(
|
745 |
+
"sanity check"
|
746 |
+
),
|
747 |
+
)
|
748 |
+
# ----------Huggingface Hub Arguments-----------
|
749 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
750 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
751 |
+
parser.add_argument(
|
752 |
+
"--hub_model_id",
|
753 |
+
type=str,
|
754 |
+
default=None,
|
755 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
756 |
+
)
|
757 |
+
# ----------Accelerate Arguments----------
|
758 |
+
parser.add_argument(
|
759 |
+
"--tracker_project_name",
|
760 |
+
type=str,
|
761 |
+
default="trian",
|
762 |
+
help=(
|
763 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
764 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
765 |
+
),
|
766 |
+
)
|
767 |
+
|
768 |
+
args = parser.parse_args()
|
769 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
770 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
771 |
+
args.local_rank = env_local_rank
|
772 |
+
|
773 |
+
return args
|
774 |
+
|
775 |
+
|
776 |
+
def main(args):
|
777 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
778 |
+
raise ValueError(
|
779 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
780 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
781 |
+
)
|
782 |
+
|
783 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
784 |
+
|
785 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
786 |
+
|
787 |
+
accelerator = Accelerator(
|
788 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
789 |
+
mixed_precision=args.mixed_precision,
|
790 |
+
log_with=args.report_to,
|
791 |
+
project_config=accelerator_project_config,
|
792 |
+
)
|
793 |
+
|
794 |
+
# Make one log on every process with the configuration for debugging.
|
795 |
+
logging.basicConfig(
|
796 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
797 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
798 |
+
level=logging.INFO,
|
799 |
+
)
|
800 |
+
logger.info(accelerator.state, main_process_only=False)
|
801 |
+
if accelerator.is_local_main_process:
|
802 |
+
transformers.utils.logging.set_verbosity_warning()
|
803 |
+
diffusers.utils.logging.set_verbosity_info()
|
804 |
+
else:
|
805 |
+
transformers.utils.logging.set_verbosity_error()
|
806 |
+
diffusers.utils.logging.set_verbosity_error()
|
807 |
+
|
808 |
+
# If passed along, set the training seed now.
|
809 |
+
if args.seed is not None:
|
810 |
+
set_seed(args.seed)
|
811 |
+
|
812 |
+
# Handle the repository creation.
|
813 |
+
if accelerator.is_main_process:
|
814 |
+
if args.output_dir is not None:
|
815 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
816 |
+
|
817 |
+
# 1. Create the noise scheduler and the desired noise schedule.
|
818 |
+
noise_scheduler = DDPMScheduler.from_pretrained(
|
819 |
+
args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.teacher_revision
|
820 |
+
)
|
821 |
+
noise_scheduler.config.num_train_timesteps = args.num_train_timesteps
|
822 |
+
lcm_scheduler = LCMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
823 |
+
|
824 |
+
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
825 |
+
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
826 |
+
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
827 |
+
# Initialize the DDIM ODE solver for distillation.
|
828 |
+
solver = DDIMSolver(
|
829 |
+
noise_scheduler.alphas_cumprod.numpy(),
|
830 |
+
timesteps=noise_scheduler.config.num_train_timesteps,
|
831 |
+
ddim_timesteps=args.num_ddim_timesteps,
|
832 |
+
)
|
833 |
+
|
834 |
+
# 2. Load tokenizers from SDXL checkpoint.
|
835 |
+
tokenizer_one = AutoTokenizer.from_pretrained(
|
836 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
837 |
+
)
|
838 |
+
tokenizer_two = AutoTokenizer.from_pretrained(
|
839 |
+
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
|
840 |
+
)
|
841 |
+
|
842 |
+
# 3. Load text encoders from SDXL checkpoint.
|
843 |
+
# import correct text encoder classes
|
844 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
845 |
+
args.pretrained_model_name_or_path, args.teacher_revision
|
846 |
+
)
|
847 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
848 |
+
args.pretrained_model_name_or_path, args.teacher_revision, subfolder="text_encoder_2"
|
849 |
+
)
|
850 |
+
|
851 |
+
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
852 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.teacher_revision
|
853 |
+
)
|
854 |
+
text_encoder_two = text_encoder_cls_two.from_pretrained(
|
855 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.teacher_revision
|
856 |
+
)
|
857 |
+
|
858 |
+
if args.use_clip_encoder:
|
859 |
+
image_processor = CLIPImageProcessor()
|
860 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
|
861 |
+
else:
|
862 |
+
image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
|
863 |
+
image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
|
864 |
+
|
865 |
+
# 4. Load VAE from SDXL checkpoint (or more stable VAE)
|
866 |
+
vae_path = (
|
867 |
+
args.pretrained_model_name_or_path
|
868 |
+
if args.pretrained_vae_model_name_or_path is None
|
869 |
+
else args.pretrained_vae_model_name_or_path
|
870 |
+
)
|
871 |
+
vae = AutoencoderKL.from_pretrained(
|
872 |
+
vae_path,
|
873 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
874 |
+
revision=args.teacher_revision,
|
875 |
+
)
|
876 |
+
|
877 |
+
# 7. Create online student U-Net.
|
878 |
+
unet = UNet2DConditionModel.from_pretrained(
|
879 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.teacher_revision
|
880 |
+
)
|
881 |
+
|
882 |
+
# Resampler for project model in IP-Adapter
|
883 |
+
image_proj_model = Resampler(
|
884 |
+
dim=1280,
|
885 |
+
depth=4,
|
886 |
+
dim_head=64,
|
887 |
+
heads=20,
|
888 |
+
num_queries=args.adapter_tokens,
|
889 |
+
embedding_dim=image_encoder.config.hidden_size,
|
890 |
+
output_dim=unet.config.cross_attention_dim,
|
891 |
+
ff_mult=4
|
892 |
+
)
|
893 |
+
|
894 |
+
# Load the same adapter in both unet.
|
895 |
+
init_adapter_in_unet(
|
896 |
+
unet,
|
897 |
+
image_proj_model,
|
898 |
+
os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
|
899 |
+
adapter_tokens=args.adapter_tokens,
|
900 |
+
)
|
901 |
+
|
902 |
+
# Check that all trainable models are in full precision
|
903 |
+
low_precision_error_string = (
|
904 |
+
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
|
905 |
+
" doing mixed precision training, copy of the weights should still be float32."
|
906 |
+
)
|
907 |
+
|
908 |
+
def unwrap_model(model):
|
909 |
+
model = accelerator.unwrap_model(model)
|
910 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
911 |
+
return model
|
912 |
+
|
913 |
+
if unwrap_model(unet).dtype != torch.float32:
|
914 |
+
raise ValueError(
|
915 |
+
f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
|
916 |
+
)
|
917 |
+
|
918 |
+
if args.pretrained_lcm_lora_path is not None:
|
919 |
+
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
|
920 |
+
unet_state_dict = {
|
921 |
+
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
922 |
+
}
|
923 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
924 |
+
lora_state_dict = dict()
|
925 |
+
for k, v in unet_state_dict.items():
|
926 |
+
if "ip" in k:
|
927 |
+
k = k.replace("attn2", "attn2.processor")
|
928 |
+
lora_state_dict[k] = v
|
929 |
+
else:
|
930 |
+
lora_state_dict[k] = v
|
931 |
+
if alpha_dict:
|
932 |
+
args.lora_alpha = next(iter(alpha_dict.values()))
|
933 |
+
else:
|
934 |
+
args.lora_alpha = 1
|
935 |
+
# 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
|
936 |
+
if args.lora_target_modules is not None:
|
937 |
+
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
|
938 |
+
else:
|
939 |
+
lora_target_modules = [
|
940 |
+
"to_q",
|
941 |
+
"to_kv",
|
942 |
+
"0.to_out",
|
943 |
+
"attn1.to_k",
|
944 |
+
"attn1.to_v",
|
945 |
+
"to_k_ip",
|
946 |
+
"to_v_ip",
|
947 |
+
"ln_k_ip.linear",
|
948 |
+
"ln_v_ip.linear",
|
949 |
+
"to_out.0",
|
950 |
+
"proj_in",
|
951 |
+
"proj_out",
|
952 |
+
"ff.net.0.proj",
|
953 |
+
"ff.net.2",
|
954 |
+
"conv1",
|
955 |
+
"conv2",
|
956 |
+
"conv_shortcut",
|
957 |
+
"downsamplers.0.conv",
|
958 |
+
"upsamplers.0.conv",
|
959 |
+
"time_emb_proj",
|
960 |
+
]
|
961 |
+
lora_config = LoraConfig(
|
962 |
+
r=args.lora_rank,
|
963 |
+
target_modules=lora_target_modules,
|
964 |
+
lora_alpha=args.lora_alpha,
|
965 |
+
lora_dropout=args.lora_dropout,
|
966 |
+
)
|
967 |
+
|
968 |
+
# Legacy
|
969 |
+
# for k, v in lcm_pipe.unet.state_dict().items():
|
970 |
+
# if "lora" in k or "base_layer" in k:
|
971 |
+
# lcm_dict[k.replace("default_0", "default")] = v
|
972 |
+
|
973 |
+
unet.add_adapter(lora_config)
|
974 |
+
if args.pretrained_lcm_lora_path is not None:
|
975 |
+
incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
|
976 |
+
if incompatible_keys is not None:
|
977 |
+
# check only for unexpected keys
|
978 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
979 |
+
if unexpected_keys:
|
980 |
+
logger.warning(
|
981 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
982 |
+
f" {unexpected_keys}. "
|
983 |
+
)
|
984 |
+
|
985 |
+
# 6. Freeze unet, vae, text_encoders.
|
986 |
+
vae.requires_grad_(False)
|
987 |
+
text_encoder_one.requires_grad_(False)
|
988 |
+
text_encoder_two.requires_grad_(False)
|
989 |
+
image_encoder.requires_grad_(False)
|
990 |
+
unet.requires_grad_(False)
|
991 |
+
|
992 |
+
# 10. Handle saving and loading of checkpoints
|
993 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
994 |
+
if args.save_only_adapter:
|
995 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
996 |
+
def save_model_hook(models, weights, output_dir):
|
997 |
+
if accelerator.is_main_process:
|
998 |
+
for model in models:
|
999 |
+
if isinstance(model, type(unwrap_model(unet))): # save adapter only
|
1000 |
+
unet_ = unwrap_model(model)
|
1001 |
+
# also save the checkpoints in native `diffusers` format so that it can be easily
|
1002 |
+
# be independently loaded via `load_lora_weights()`.
|
1003 |
+
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
|
1004 |
+
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict, safe_serialization=False)
|
1005 |
+
|
1006 |
+
weights.pop()
|
1007 |
+
|
1008 |
+
def load_model_hook(models, input_dir):
|
1009 |
+
|
1010 |
+
while len(models) > 0:
|
1011 |
+
# pop models so that they are not loaded again
|
1012 |
+
model = models.pop()
|
1013 |
+
|
1014 |
+
if isinstance(model, type(unwrap_model(unet))):
|
1015 |
+
unet_ = unwrap_model(model)
|
1016 |
+
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
|
1017 |
+
unet_state_dict = {
|
1018 |
+
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
1019 |
+
}
|
1020 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
1021 |
+
lora_state_dict = dict()
|
1022 |
+
for k, v in unet_state_dict.items():
|
1023 |
+
if "ip" in k:
|
1024 |
+
k = k.replace("attn2", "attn2.processor")
|
1025 |
+
lora_state_dict[k] = v
|
1026 |
+
else:
|
1027 |
+
lora_state_dict[k] = v
|
1028 |
+
incompatible_keys = set_peft_model_state_dict(unet_, lora_state_dict, adapter_name="default")
|
1029 |
+
if incompatible_keys is not None:
|
1030 |
+
# check only for unexpected keys
|
1031 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
1032 |
+
if unexpected_keys:
|
1033 |
+
logger.warning(
|
1034 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
1035 |
+
f" {unexpected_keys}. "
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
1039 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
1040 |
+
|
1041 |
+
# 11. Enable optimizations
|
1042 |
+
if args.enable_xformers_memory_efficient_attention:
|
1043 |
+
if is_xformers_available():
|
1044 |
+
import xformers
|
1045 |
+
|
1046 |
+
xformers_version = version.parse(xformers.__version__)
|
1047 |
+
if xformers_version == version.parse("0.0.16"):
|
1048 |
+
logger.warning(
|
1049 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
1050 |
+
)
|
1051 |
+
unet.enable_xformers_memory_efficient_attention()
|
1052 |
+
else:
|
1053 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
1054 |
+
|
1055 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
1056 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
1057 |
+
if args.allow_tf32:
|
1058 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
1059 |
+
|
1060 |
+
if args.gradient_checkpointing:
|
1061 |
+
unet.enable_gradient_checkpointing()
|
1062 |
+
vae.enable_gradient_checkpointing()
|
1063 |
+
|
1064 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
1065 |
+
if args.use_8bit_adam:
|
1066 |
+
try:
|
1067 |
+
import bitsandbytes as bnb
|
1068 |
+
except ImportError:
|
1069 |
+
raise ImportError(
|
1070 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
1071 |
+
)
|
1072 |
+
|
1073 |
+
optimizer_class = bnb.optim.AdamW8bit
|
1074 |
+
else:
|
1075 |
+
optimizer_class = torch.optim.AdamW
|
1076 |
+
|
1077 |
+
# 12. Optimizer creation
|
1078 |
+
lora_params, non_lora_params = seperate_lora_params_from_unet(unet)
|
1079 |
+
params_to_optimize = lora_params
|
1080 |
+
optimizer = optimizer_class(
|
1081 |
+
params_to_optimize,
|
1082 |
+
lr=args.learning_rate,
|
1083 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
1084 |
+
weight_decay=args.adam_weight_decay,
|
1085 |
+
eps=args.adam_epsilon,
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
# 13. Dataset creation and data processing
|
1089 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
1090 |
+
# download the dataset.
|
1091 |
+
datasets = []
|
1092 |
+
datasets_name = []
|
1093 |
+
datasets_weights = []
|
1094 |
+
deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
|
1095 |
+
if args.data_config_path is not None:
|
1096 |
+
data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
|
1097 |
+
for single_dataset in data_config.datasets:
|
1098 |
+
datasets_weights.append(single_dataset.dataset_weight)
|
1099 |
+
datasets_name.append(single_dataset.dataset_folder)
|
1100 |
+
dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
|
1101 |
+
image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
|
1102 |
+
image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
|
1103 |
+
datasets.append(image_dataset)
|
1104 |
+
# TODO: Validation dataset
|
1105 |
+
if data_config.val_dataset is not None:
|
1106 |
+
val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
|
1107 |
+
logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
|
1108 |
+
|
1109 |
+
# Mix training datasets.
|
1110 |
+
sampler_train = None
|
1111 |
+
if len(datasets) == 1:
|
1112 |
+
train_dataset = datasets[0]
|
1113 |
+
else:
|
1114 |
+
# Weighted each dataset
|
1115 |
+
train_dataset = torch.utils.data.ConcatDataset(datasets)
|
1116 |
+
dataset_weights = []
|
1117 |
+
for single_dataset, single_weight in zip(datasets, datasets_weights):
|
1118 |
+
dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
|
1119 |
+
sampler_train = torch.utils.data.WeightedRandomSampler(
|
1120 |
+
weights=dataset_weights,
|
1121 |
+
num_samples=len(dataset_weights)
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
# DataLoaders creation:
|
1125 |
+
train_dataloader = torch.utils.data.DataLoader(
|
1126 |
+
train_dataset,
|
1127 |
+
sampler=sampler_train,
|
1128 |
+
shuffle=True if sampler_train is None else False,
|
1129 |
+
collate_fn=collate_fn,
|
1130 |
+
batch_size=args.train_batch_size,
|
1131 |
+
num_workers=args.dataloader_num_workers,
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
# 14. Embeddings for the UNet.
|
1135 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
1136 |
+
def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):
|
1137 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
1138 |
+
target_size = (args.resolution, args.resolution)
|
1139 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
1140 |
+
add_time_ids = torch.tensor([add_time_ids])
|
1141 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
1142 |
+
return add_time_ids
|
1143 |
+
|
1144 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)
|
1145 |
+
add_text_embeds = pooled_prompt_embeds
|
1146 |
+
|
1147 |
+
add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])
|
1148 |
+
|
1149 |
+
prompt_embeds = prompt_embeds.to(accelerator.device)
|
1150 |
+
add_text_embeds = add_text_embeds.to(accelerator.device)
|
1151 |
+
unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
1152 |
+
|
1153 |
+
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
1154 |
+
|
1155 |
+
text_encoders = [text_encoder_one, text_encoder_two]
|
1156 |
+
tokenizers = [tokenizer_one, tokenizer_two]
|
1157 |
+
|
1158 |
+
compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
|
1159 |
+
|
1160 |
+
# Move pixels into latents.
|
1161 |
+
@torch.no_grad()
|
1162 |
+
def convert_to_latent(pixels):
|
1163 |
+
model_input = vae.encode(pixels).latent_dist.sample()
|
1164 |
+
model_input = model_input * vae.config.scaling_factor
|
1165 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1166 |
+
model_input = model_input.to(weight_dtype)
|
1167 |
+
return model_input
|
1168 |
+
|
1169 |
+
# 15. LR Scheduler creation
|
1170 |
+
# Scheduler and math around the number of training steps.
|
1171 |
+
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
1172 |
+
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
1173 |
+
if args.max_train_steps is None:
|
1174 |
+
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
1175 |
+
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
1176 |
+
num_training_steps_for_scheduler = (
|
1177 |
+
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
1178 |
+
)
|
1179 |
+
else:
|
1180 |
+
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
1181 |
+
|
1182 |
+
if args.scale_lr:
|
1183 |
+
args.learning_rate = (
|
1184 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
1185 |
+
)
|
1186 |
+
|
1187 |
+
# Make sure the trainable params are in float32.
|
1188 |
+
if args.mixed_precision == "fp16":
|
1189 |
+
# only upcast trainable parameters (LoRA) into fp32
|
1190 |
+
cast_training_params(unet, dtype=torch.float32)
|
1191 |
+
|
1192 |
+
lr_scheduler = get_scheduler(
|
1193 |
+
args.lr_scheduler,
|
1194 |
+
optimizer=optimizer,
|
1195 |
+
num_warmup_steps=num_warmup_steps_for_scheduler,
|
1196 |
+
num_training_steps=num_training_steps_for_scheduler,
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
# 16. Prepare for training
|
1200 |
+
# Prepare everything with our `accelerator`.
|
1201 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
1202 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
1203 |
+
)
|
1204 |
+
|
1205 |
+
# 8. Handle mixed precision and device placement
|
1206 |
+
# For mixed precision training we cast all non-trainable weigths to half-precision
|
1207 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
1208 |
+
weight_dtype = torch.float32
|
1209 |
+
if accelerator.mixed_precision == "fp16":
|
1210 |
+
weight_dtype = torch.float16
|
1211 |
+
elif accelerator.mixed_precision == "bf16":
|
1212 |
+
weight_dtype = torch.bfloat16
|
1213 |
+
|
1214 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
1215 |
+
# The VAE is in float32 to avoid NaN losses.
|
1216 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1217 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
1218 |
+
else:
|
1219 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
1220 |
+
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
1221 |
+
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
1222 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
1223 |
+
for p in non_lora_params:
|
1224 |
+
p.data = p.data.to(dtype=weight_dtype)
|
1225 |
+
for p in lora_params:
|
1226 |
+
p.requires_grad_(True)
|
1227 |
+
unet.to(accelerator.device)
|
1228 |
+
|
1229 |
+
# Also move the alpha and sigma noise schedules to accelerator.device.
|
1230 |
+
alpha_schedule = alpha_schedule.to(accelerator.device)
|
1231 |
+
sigma_schedule = sigma_schedule.to(accelerator.device)
|
1232 |
+
solver = solver.to(accelerator.device)
|
1233 |
+
|
1234 |
+
# Instantiate Loss.
|
1235 |
+
losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
|
1236 |
+
lcm_losses = list()
|
1237 |
+
for loss_config in losses_configs.lcm_losses:
|
1238 |
+
logger.info(f"Loading lcm loss: {loss_config.name}")
|
1239 |
+
loss = namedtuple("loss", ["loss", "weight"])
|
1240 |
+
loss_class = eval(loss_config.name)
|
1241 |
+
lcm_losses.append(loss(loss_class(
|
1242 |
+
visualize_every_k=loss_config.visualize_every_k,
|
1243 |
+
dtype=weight_dtype,
|
1244 |
+
accelerator=accelerator,
|
1245 |
+
dino_model=image_encoder,
|
1246 |
+
dino_preprocess=image_processor,
|
1247 |
+
huber_c=args.huber_c,
|
1248 |
+
**loss_config.init_params), weight=loss_config.weight))
|
1249 |
+
|
1250 |
+
# Final check.
|
1251 |
+
for n, p in unet.named_parameters():
|
1252 |
+
if p.requires_grad:
|
1253 |
+
assert "lora" in n, n
|
1254 |
+
assert p.dtype == torch.float32, n
|
1255 |
+
else:
|
1256 |
+
assert "lora" not in n, f"{n}"
|
1257 |
+
assert p.dtype == weight_dtype, n
|
1258 |
+
if args.sanity_check:
|
1259 |
+
if args.resume_from_checkpoint:
|
1260 |
+
if args.resume_from_checkpoint != "latest":
|
1261 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1262 |
+
else:
|
1263 |
+
# Get the most recent checkpoint
|
1264 |
+
dirs = os.listdir(args.output_dir)
|
1265 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1266 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1267 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1268 |
+
|
1269 |
+
if path is None:
|
1270 |
+
accelerator.print(
|
1271 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1272 |
+
)
|
1273 |
+
args.resume_from_checkpoint = None
|
1274 |
+
initial_global_step = 0
|
1275 |
+
else:
|
1276 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1277 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1278 |
+
|
1279 |
+
# Check input data
|
1280 |
+
batch = next(iter(train_dataloader))
|
1281 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1282 |
+
out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
|
1283 |
+
lcm_scheduler, image_encoder, image_processor,
|
1284 |
+
args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True)
|
1285 |
+
exit()
|
1286 |
+
|
1287 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
1288 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1289 |
+
if args.max_train_steps is None:
|
1290 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
1291 |
+
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
1292 |
+
logger.warning(
|
1293 |
+
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
1294 |
+
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
1295 |
+
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
1296 |
+
)
|
1297 |
+
# Afterwards we recalculate our number of training epochs
|
1298 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
1299 |
+
|
1300 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
1301 |
+
# The trackers initializes automatically on the main process.
|
1302 |
+
if accelerator.is_main_process:
|
1303 |
+
tracker_config = dict(vars(args))
|
1304 |
+
|
1305 |
+
# tensorboard cannot handle list types for config
|
1306 |
+
tracker_config.pop("validation_prompt")
|
1307 |
+
tracker_config.pop("validation_image")
|
1308 |
+
|
1309 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
1310 |
+
|
1311 |
+
# 17. Train!
|
1312 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1313 |
+
|
1314 |
+
logger.info("***** Running training *****")
|
1315 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
1316 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
1317 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
1318 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
1319 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
1320 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
1321 |
+
global_step = 0
|
1322 |
+
first_epoch = 0
|
1323 |
+
|
1324 |
+
# Potentially load in the weights and states from a previous save
|
1325 |
+
if args.resume_from_checkpoint:
|
1326 |
+
if args.resume_from_checkpoint != "latest":
|
1327 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1328 |
+
else:
|
1329 |
+
# Get the most recent checkpoint
|
1330 |
+
dirs = os.listdir(args.output_dir)
|
1331 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1332 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1333 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1334 |
+
|
1335 |
+
if path is None:
|
1336 |
+
accelerator.print(
|
1337 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1338 |
+
)
|
1339 |
+
args.resume_from_checkpoint = None
|
1340 |
+
initial_global_step = 0
|
1341 |
+
else:
|
1342 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1343 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1344 |
+
global_step = int(path.split("-")[1])
|
1345 |
+
|
1346 |
+
initial_global_step = global_step
|
1347 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
1348 |
+
else:
|
1349 |
+
initial_global_step = 0
|
1350 |
+
|
1351 |
+
progress_bar = tqdm(
|
1352 |
+
range(0, args.max_train_steps),
|
1353 |
+
initial=initial_global_step,
|
1354 |
+
desc="Steps",
|
1355 |
+
# Only show the progress bar once on each machine.
|
1356 |
+
disable=not accelerator.is_local_main_process,
|
1357 |
+
)
|
1358 |
+
|
1359 |
+
unet.train()
|
1360 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
1361 |
+
for step, batch in enumerate(train_dataloader):
|
1362 |
+
with accelerator.accumulate(unet):
|
1363 |
+
total_loss = torch.tensor(0.0)
|
1364 |
+
bsz = batch["images"].shape[0]
|
1365 |
+
|
1366 |
+
# Drop conditions.
|
1367 |
+
rand_tensor = torch.rand(bsz)
|
1368 |
+
drop_image_idx = rand_tensor < args.image_drop_rate
|
1369 |
+
drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
|
1370 |
+
drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
|
1371 |
+
drop_image_idx = drop_image_idx | drop_both_idx
|
1372 |
+
drop_text_idx = drop_text_idx | drop_both_idx
|
1373 |
+
|
1374 |
+
with torch.no_grad():
|
1375 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1376 |
+
lq_pt = image_processor(
|
1377 |
+
images=lq_img*0.5+0.5,
|
1378 |
+
do_rescale=False, return_tensors="pt"
|
1379 |
+
).pixel_values
|
1380 |
+
image_embeds = prepare_training_image_embeds(
|
1381 |
+
image_encoder, image_processor,
|
1382 |
+
ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
|
1383 |
+
device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
|
1384 |
+
idx_to_replace=drop_image_idx
|
1385 |
+
)
|
1386 |
+
uncond_image_embeds = prepare_training_image_embeds(
|
1387 |
+
image_encoder, image_processor,
|
1388 |
+
ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
|
1389 |
+
device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
|
1390 |
+
idx_to_replace=torch.ones_like(drop_image_idx)
|
1391 |
+
)
|
1392 |
+
# 1. Load and process the image and text conditioning
|
1393 |
+
text, orig_size, crop_coords = (
|
1394 |
+
batch["text"],
|
1395 |
+
batch["original_sizes"],
|
1396 |
+
batch["crop_top_lefts"],
|
1397 |
+
)
|
1398 |
+
|
1399 |
+
encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
|
1400 |
+
uncond_encoded_text = compute_embeddings_fn([""]*len(text), orig_size, crop_coords)
|
1401 |
+
|
1402 |
+
# encode pixel values with batch size of at most args.vae_encode_batch_size
|
1403 |
+
gt_img = gt_img.to(dtype=vae.dtype)
|
1404 |
+
latents = []
|
1405 |
+
for i in range(0, gt_img.shape[0], args.vae_encode_batch_size):
|
1406 |
+
latents.append(vae.encode(gt_img[i : i + args.vae_encode_batch_size]).latent_dist.sample())
|
1407 |
+
latents = torch.cat(latents, dim=0)
|
1408 |
+
# latents = convert_to_latent(gt_img)
|
1409 |
+
|
1410 |
+
latents = latents * vae.config.scaling_factor
|
1411 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1412 |
+
latents = latents.to(weight_dtype)
|
1413 |
+
|
1414 |
+
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
1415 |
+
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
1416 |
+
bsz = latents.shape[0]
|
1417 |
+
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
1418 |
+
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
1419 |
+
start_timesteps = solver.ddim_timesteps[index]
|
1420 |
+
timesteps = start_timesteps - topk
|
1421 |
+
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
1422 |
+
|
1423 |
+
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
1424 |
+
c_skip_start, c_out_start = scalings_for_boundary_conditions(
|
1425 |
+
start_timesteps, timestep_scaling=args.timestep_scaling_factor
|
1426 |
+
)
|
1427 |
+
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
1428 |
+
c_skip, c_out = scalings_for_boundary_conditions(
|
1429 |
+
timesteps, timestep_scaling=args.timestep_scaling_factor
|
1430 |
+
)
|
1431 |
+
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
1432 |
+
|
1433 |
+
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
1434 |
+
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
1435 |
+
noise = torch.randn_like(latents)
|
1436 |
+
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
1437 |
+
|
1438 |
+
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
1439 |
+
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
1440 |
+
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
1441 |
+
w = w.reshape(bsz, 1, 1, 1)
|
1442 |
+
w = w.to(device=latents.device, dtype=latents.dtype)
|
1443 |
+
|
1444 |
+
# 6. Prepare prompt embeds and unet_added_conditions
|
1445 |
+
prompt_embeds = encoded_text.pop("prompt_embeds")
|
1446 |
+
encoded_text["image_embeds"] = image_embeds
|
1447 |
+
uncond_prompt_embeds = uncond_encoded_text.pop("prompt_embeds")
|
1448 |
+
uncond_encoded_text["image_embeds"] = image_embeds
|
1449 |
+
|
1450 |
+
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
1451 |
+
noise_pred = unet(
|
1452 |
+
noisy_model_input,
|
1453 |
+
start_timesteps,
|
1454 |
+
encoder_hidden_states=uncond_prompt_embeds,
|
1455 |
+
added_cond_kwargs=uncond_encoded_text,
|
1456 |
+
).sample
|
1457 |
+
pred_x_0 = get_predicted_original_sample(
|
1458 |
+
noise_pred,
|
1459 |
+
start_timesteps,
|
1460 |
+
noisy_model_input,
|
1461 |
+
noise_scheduler.config.prediction_type,
|
1462 |
+
alpha_schedule,
|
1463 |
+
sigma_schedule,
|
1464 |
+
)
|
1465 |
+
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
1466 |
+
|
1467 |
+
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
1468 |
+
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
1469 |
+
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
1470 |
+
# solver timestep.
|
1471 |
+
|
1472 |
+
# With the adapters disabled, the `unet` is the regular teacher model.
|
1473 |
+
accelerator.unwrap_model(unet).disable_adapters()
|
1474 |
+
with torch.no_grad():
|
1475 |
+
|
1476 |
+
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
1477 |
+
teacher_added_cond = dict()
|
1478 |
+
for k,v in encoded_text.items():
|
1479 |
+
if isinstance(v, torch.Tensor):
|
1480 |
+
teacher_added_cond[k] = v.to(weight_dtype)
|
1481 |
+
else:
|
1482 |
+
teacher_image_embeds = []
|
1483 |
+
for img_emb in v:
|
1484 |
+
teacher_image_embeds.append(img_emb.to(weight_dtype))
|
1485 |
+
teacher_added_cond[k] = teacher_image_embeds
|
1486 |
+
cond_teacher_output = unet(
|
1487 |
+
noisy_model_input,
|
1488 |
+
start_timesteps,
|
1489 |
+
encoder_hidden_states=prompt_embeds,
|
1490 |
+
added_cond_kwargs=teacher_added_cond,
|
1491 |
+
).sample
|
1492 |
+
cond_pred_x0 = get_predicted_original_sample(
|
1493 |
+
cond_teacher_output,
|
1494 |
+
start_timesteps,
|
1495 |
+
noisy_model_input,
|
1496 |
+
noise_scheduler.config.prediction_type,
|
1497 |
+
alpha_schedule,
|
1498 |
+
sigma_schedule,
|
1499 |
+
)
|
1500 |
+
cond_pred_noise = get_predicted_noise(
|
1501 |
+
cond_teacher_output,
|
1502 |
+
start_timesteps,
|
1503 |
+
noisy_model_input,
|
1504 |
+
noise_scheduler.config.prediction_type,
|
1505 |
+
alpha_schedule,
|
1506 |
+
sigma_schedule,
|
1507 |
+
)
|
1508 |
+
|
1509 |
+
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
1510 |
+
teacher_added_uncond = dict()
|
1511 |
+
uncond_encoded_text["image_embeds"] = uncond_image_embeds
|
1512 |
+
for k,v in uncond_encoded_text.items():
|
1513 |
+
if isinstance(v, torch.Tensor):
|
1514 |
+
teacher_added_uncond[k] = v.to(weight_dtype)
|
1515 |
+
else:
|
1516 |
+
teacher_uncond_image_embeds = []
|
1517 |
+
for img_emb in v:
|
1518 |
+
teacher_uncond_image_embeds.append(img_emb.to(weight_dtype))
|
1519 |
+
teacher_added_uncond[k] = teacher_uncond_image_embeds
|
1520 |
+
uncond_teacher_output = unet(
|
1521 |
+
noisy_model_input,
|
1522 |
+
start_timesteps,
|
1523 |
+
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
1524 |
+
added_cond_kwargs=teacher_added_uncond,
|
1525 |
+
).sample
|
1526 |
+
uncond_pred_x0 = get_predicted_original_sample(
|
1527 |
+
uncond_teacher_output,
|
1528 |
+
start_timesteps,
|
1529 |
+
noisy_model_input,
|
1530 |
+
noise_scheduler.config.prediction_type,
|
1531 |
+
alpha_schedule,
|
1532 |
+
sigma_schedule,
|
1533 |
+
)
|
1534 |
+
uncond_pred_noise = get_predicted_noise(
|
1535 |
+
uncond_teacher_output,
|
1536 |
+
start_timesteps,
|
1537 |
+
noisy_model_input,
|
1538 |
+
noise_scheduler.config.prediction_type,
|
1539 |
+
alpha_schedule,
|
1540 |
+
sigma_schedule,
|
1541 |
+
)
|
1542 |
+
|
1543 |
+
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
1544 |
+
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
1545 |
+
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
1546 |
+
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
1547 |
+
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
1548 |
+
# augmented PF-ODE trajectory (solving backward in time)
|
1549 |
+
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
1550 |
+
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(weight_dtype)
|
1551 |
+
|
1552 |
+
# re-enable unet adapters to turn the `unet` into a student unet.
|
1553 |
+
accelerator.unwrap_model(unet).enable_adapters()
|
1554 |
+
|
1555 |
+
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
1556 |
+
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
1557 |
+
with torch.no_grad():
|
1558 |
+
uncond_encoded_text["image_embeds"] = image_embeds
|
1559 |
+
target_added_cond = dict()
|
1560 |
+
for k,v in uncond_encoded_text.items():
|
1561 |
+
if isinstance(v, torch.Tensor):
|
1562 |
+
target_added_cond[k] = v.to(weight_dtype)
|
1563 |
+
else:
|
1564 |
+
target_image_embeds = []
|
1565 |
+
for img_emb in v:
|
1566 |
+
target_image_embeds.append(img_emb.to(weight_dtype))
|
1567 |
+
target_added_cond[k] = target_image_embeds
|
1568 |
+
target_noise_pred = unet(
|
1569 |
+
x_prev,
|
1570 |
+
timesteps,
|
1571 |
+
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
1572 |
+
added_cond_kwargs=target_added_cond,
|
1573 |
+
).sample
|
1574 |
+
pred_x_0 = get_predicted_original_sample(
|
1575 |
+
target_noise_pred,
|
1576 |
+
timesteps,
|
1577 |
+
x_prev,
|
1578 |
+
noise_scheduler.config.prediction_type,
|
1579 |
+
alpha_schedule,
|
1580 |
+
sigma_schedule,
|
1581 |
+
)
|
1582 |
+
target = c_skip * x_prev + c_out * pred_x_0
|
1583 |
+
|
1584 |
+
# 10. Calculate loss
|
1585 |
+
lcm_loss_arguments = {
|
1586 |
+
"target": target.float(),
|
1587 |
+
"predict": model_pred.float(),
|
1588 |
+
}
|
1589 |
+
loss_dict = dict()
|
1590 |
+
# total_loss = total_loss + torch.mean(
|
1591 |
+
# torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
1592 |
+
# )
|
1593 |
+
# loss_dict["L2Loss"] = total_loss.item()
|
1594 |
+
for loss_config in lcm_losses:
|
1595 |
+
if loss_config.loss.__class__.__name__=="DINOLoss":
|
1596 |
+
with torch.no_grad():
|
1597 |
+
pixel_target = []
|
1598 |
+
latent_target = target.to(dtype=vae.dtype)
|
1599 |
+
for i in range(0, latent_target.shape[0], args.vae_encode_batch_size):
|
1600 |
+
pixel_target.append(
|
1601 |
+
vae.decode(
|
1602 |
+
latent_target[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
|
1603 |
+
return_dict=False
|
1604 |
+
)[0]
|
1605 |
+
)
|
1606 |
+
pixel_target = torch.cat(pixel_target, dim=0)
|
1607 |
+
pixel_pred = []
|
1608 |
+
latent_pred = model_pred.to(dtype=vae.dtype)
|
1609 |
+
for i in range(0, latent_pred.shape[0], args.vae_encode_batch_size):
|
1610 |
+
pixel_pred.append(
|
1611 |
+
vae.decode(
|
1612 |
+
latent_pred[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
|
1613 |
+
return_dict=False
|
1614 |
+
)[0]
|
1615 |
+
)
|
1616 |
+
pixel_pred = torch.cat(pixel_pred, dim=0)
|
1617 |
+
dino_loss_arguments = {
|
1618 |
+
"target": pixel_target,
|
1619 |
+
"predict": pixel_pred,
|
1620 |
+
}
|
1621 |
+
non_weighted_loss = loss_config.loss(**dino_loss_arguments, accelerator=accelerator)
|
1622 |
+
loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
|
1623 |
+
total_loss = total_loss + non_weighted_loss * loss_config.weight
|
1624 |
+
else:
|
1625 |
+
non_weighted_loss = loss_config.loss(**lcm_loss_arguments, accelerator=accelerator)
|
1626 |
+
total_loss = total_loss + non_weighted_loss * loss_config.weight
|
1627 |
+
loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
|
1628 |
+
|
1629 |
+
# 11. Backpropagate on the online student model (`unet`) (only LoRA)
|
1630 |
+
accelerator.backward(total_loss)
|
1631 |
+
if accelerator.sync_gradients:
|
1632 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
1633 |
+
optimizer.step()
|
1634 |
+
lr_scheduler.step()
|
1635 |
+
optimizer.zero_grad(set_to_none=True)
|
1636 |
+
|
1637 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1638 |
+
if accelerator.sync_gradients:
|
1639 |
+
progress_bar.update(1)
|
1640 |
+
global_step += 1
|
1641 |
+
|
1642 |
+
if accelerator.is_main_process:
|
1643 |
+
if global_step % args.checkpointing_steps == 0:
|
1644 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
1645 |
+
if args.checkpoints_total_limit is not None:
|
1646 |
+
checkpoints = os.listdir(args.output_dir)
|
1647 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
1648 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
1649 |
+
|
1650 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
1651 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
1652 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
1653 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
1654 |
+
|
1655 |
+
logger.info(
|
1656 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
1657 |
+
)
|
1658 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
1659 |
+
|
1660 |
+
for removing_checkpoint in removing_checkpoints:
|
1661 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
1662 |
+
shutil.rmtree(removing_checkpoint)
|
1663 |
+
|
1664 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
1665 |
+
accelerator.save_state(save_path)
|
1666 |
+
logger.info(f"Saved state to {save_path}")
|
1667 |
+
|
1668 |
+
if global_step % args.validation_steps == 0:
|
1669 |
+
out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
|
1670 |
+
lcm_scheduler, image_encoder, image_processor,
|
1671 |
+
args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False, log_local=False)
|
1672 |
+
|
1673 |
+
logs = dict()
|
1674 |
+
# logs.update({"loss": loss.detach().item()})
|
1675 |
+
logs.update(loss_dict)
|
1676 |
+
logs.update({"lr": lr_scheduler.get_last_lr()[0]})
|
1677 |
+
progress_bar.set_postfix(**logs)
|
1678 |
+
accelerator.log(logs, step=global_step)
|
1679 |
+
|
1680 |
+
if global_step >= args.max_train_steps:
|
1681 |
+
break
|
1682 |
+
|
1683 |
+
# Create the pipeline using using the trained modules and save it.
|
1684 |
+
accelerator.wait_for_everyone()
|
1685 |
+
if accelerator.is_main_process:
|
1686 |
+
unet = accelerator.unwrap_model(unet)
|
1687 |
+
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
1688 |
+
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
|
1689 |
+
|
1690 |
+
if args.push_to_hub:
|
1691 |
+
upload_folder(
|
1692 |
+
repo_id=repo_id,
|
1693 |
+
folder_path=args.output_dir,
|
1694 |
+
commit_message="End of training",
|
1695 |
+
ignore_patterns=["step_*", "epoch_*"],
|
1696 |
+
)
|
1697 |
+
|
1698 |
+
del unet
|
1699 |
+
torch.cuda.empty_cache()
|
1700 |
+
|
1701 |
+
# Final inference.
|
1702 |
+
if args.validation_steps is not None:
|
1703 |
+
log_validation(unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1704 |
+
lcm_scheduler, image_encoder=None, image_processor=None,
|
1705 |
+
args=args, accelerator=accelerator, weight_dtype=weight_dtype, step=0, is_final_validation=False, log_local=True)
|
1706 |
+
|
1707 |
+
accelerator.end_training()
|
1708 |
+
|
1709 |
+
|
1710 |
+
if __name__ == "__main__":
|
1711 |
+
args = parse_args()
|
1712 |
+
main(args)
|
train_previewer_lora.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# After DCP training, distill the Previewer with DCP in `train_previewer_lora.py`:
|
2 |
+
accelerate launch --num_processes <num_of_gpus> train_previewer_lora.py \
|
3 |
+
--output_dir <your/output/path> \
|
4 |
+
--train_data_dir <your/data/path> \
|
5 |
+
--logging_dir <your/logging/path> \
|
6 |
+
--pretrained_model_name_or_path <your/sdxl/path> \
|
7 |
+
--feature_extractor_path <your/dinov2/path> \
|
8 |
+
--pretrained_adapter_model_path <your/dcp/path> \
|
9 |
+
--losses_config_path config_files/losses.yaml \
|
10 |
+
--data_config_path config_files/IR_dataset.yaml \
|
11 |
+
--save_only_adapter \
|
12 |
+
--gradient_checkpointing \
|
13 |
+
--num_train_timesteps 1000 \
|
14 |
+
--num_ddim_timesteps 50 \
|
15 |
+
--lora_alpha 1 \
|
16 |
+
--mixed_precision fp16 \
|
17 |
+
--train_batch_size 32 \
|
18 |
+
--vae_encode_batch_size 16 \
|
19 |
+
--gradient_accumulation_steps 1 \
|
20 |
+
--learning_rate 1e-4 \
|
21 |
+
--lr_warmup_steps 1000 \
|
22 |
+
--lr_scheduler cosine \
|
23 |
+
--lr_num_cycles 1 \
|
24 |
+
--resume_from_checkpoint latest
|
train_stage1_adapter.py
ADDED
@@ -0,0 +1,1259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
import contextlib
|
18 |
+
import time
|
19 |
+
import gc
|
20 |
+
import logging
|
21 |
+
import math
|
22 |
+
import os
|
23 |
+
import random
|
24 |
+
import jsonlines
|
25 |
+
import functools
|
26 |
+
import shutil
|
27 |
+
import pyrallis
|
28 |
+
import itertools
|
29 |
+
from pathlib import Path
|
30 |
+
from collections import namedtuple, OrderedDict
|
31 |
+
|
32 |
+
import accelerate
|
33 |
+
import numpy as np
|
34 |
+
import torch
|
35 |
+
import torch.nn.functional as F
|
36 |
+
import torch.utils.checkpoint
|
37 |
+
import transformers
|
38 |
+
from accelerate import Accelerator
|
39 |
+
from accelerate.logging import get_logger
|
40 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
41 |
+
from datasets import load_dataset
|
42 |
+
from packaging import version
|
43 |
+
from PIL import Image
|
44 |
+
from data.data_config import DataConfig
|
45 |
+
from basicsr.utils.degradation_pipeline import RealESRGANDegradation
|
46 |
+
from losses.loss_config import LossesConfig
|
47 |
+
from losses.losses import *
|
48 |
+
from torchvision import transforms
|
49 |
+
from torchvision.transforms.functional import crop
|
50 |
+
from tqdm.auto import tqdm
|
51 |
+
from transformers import (
|
52 |
+
AutoTokenizer,
|
53 |
+
PretrainedConfig,
|
54 |
+
CLIPImageProcessor, CLIPVisionModelWithProjection,
|
55 |
+
AutoImageProcessor, AutoModel)
|
56 |
+
|
57 |
+
import diffusers
|
58 |
+
from diffusers import (
|
59 |
+
AutoencoderKL,
|
60 |
+
AutoencoderTiny,
|
61 |
+
DDPMScheduler,
|
62 |
+
StableDiffusionXLPipeline,
|
63 |
+
UNet2DConditionModel,
|
64 |
+
)
|
65 |
+
from diffusers.optimization import get_scheduler
|
66 |
+
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
67 |
+
from diffusers.utils.import_utils import is_xformers_available
|
68 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
69 |
+
|
70 |
+
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
71 |
+
from utils.train_utils import (
|
72 |
+
seperate_ip_params_from_unet,
|
73 |
+
import_model_class_from_model_name_or_path,
|
74 |
+
tensor_to_pil,
|
75 |
+
get_train_dataset, prepare_train_dataset, collate_fn,
|
76 |
+
encode_prompt, importance_sampling_fn, extract_into_tensor
|
77 |
+
)
|
78 |
+
from module.ip_adapter.resampler import Resampler
|
79 |
+
from module.ip_adapter.attention_processor import init_attn_proc
|
80 |
+
from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
|
81 |
+
|
82 |
+
|
83 |
+
if is_wandb_available():
|
84 |
+
import wandb
|
85 |
+
|
86 |
+
|
87 |
+
logger = get_logger(__name__)
|
88 |
+
|
89 |
+
|
90 |
+
def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
91 |
+
scheduler, image_encoder, image_processor, deg_pipeline,
|
92 |
+
args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
|
93 |
+
logger.info("Running validation... ")
|
94 |
+
|
95 |
+
image_logs = []
|
96 |
+
|
97 |
+
lq = [Image.open(lq_example) for lq_example in args.validation_image]
|
98 |
+
|
99 |
+
pipe = StableDiffusionXLPipeline(
|
100 |
+
vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
101 |
+
unet, scheduler, image_encoder, image_processor,
|
102 |
+
).to(accelerator.device)
|
103 |
+
|
104 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
105 |
+
image = pipe(
|
106 |
+
prompt=[""]*len(lq),
|
107 |
+
ip_adapter_image=[lq],
|
108 |
+
num_inference_steps=20,
|
109 |
+
generator=generator,
|
110 |
+
guidance_scale=5.0,
|
111 |
+
height=args.resolution,
|
112 |
+
width=args.resolution,
|
113 |
+
).images
|
114 |
+
|
115 |
+
if log_local:
|
116 |
+
for i, img in enumerate(tensor_to_pil(lq_img)):
|
117 |
+
img.save(f"./lq_{i}.png")
|
118 |
+
for i, img in enumerate(tensor_to_pil(gt_img)):
|
119 |
+
img.save(f"./gt_{i}.png")
|
120 |
+
for i, img in enumerate(image):
|
121 |
+
img.save(f"./lq_IPA_{i}.png")
|
122 |
+
return
|
123 |
+
|
124 |
+
tracker_key = "test" if is_final_validation else "validation"
|
125 |
+
for tracker in accelerator.trackers:
|
126 |
+
if tracker.name == "tensorboard":
|
127 |
+
images = [np.asarray(pil_img) for pil_img in image]
|
128 |
+
images = np.stack(images, axis=0)
|
129 |
+
if lq_img is not None and gt_img is not None:
|
130 |
+
input_lq = lq_img.detach().cpu()
|
131 |
+
input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
|
132 |
+
input_gt = gt_img.detach().cpu()
|
133 |
+
input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
|
134 |
+
tracker.writer.add_images("lq", input_lq[0], step, dataformats="CHW")
|
135 |
+
tracker.writer.add_images("gt", input_gt[0], step, dataformats="CHW")
|
136 |
+
tracker.writer.add_images("rec", images, step, dataformats="NHWC")
|
137 |
+
elif tracker.name == "wandb":
|
138 |
+
raise NotImplementedError("Wandb logging not implemented for validation.")
|
139 |
+
formatted_images = []
|
140 |
+
|
141 |
+
for log in image_logs:
|
142 |
+
images = log["images"]
|
143 |
+
validation_prompt = log["validation_prompt"]
|
144 |
+
validation_image = log["validation_image"]
|
145 |
+
|
146 |
+
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
147 |
+
|
148 |
+
for image in images:
|
149 |
+
image = wandb.Image(image, caption=validation_prompt)
|
150 |
+
formatted_images.append(image)
|
151 |
+
|
152 |
+
tracker.log({tracker_key: formatted_images})
|
153 |
+
else:
|
154 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
155 |
+
|
156 |
+
gc.collect()
|
157 |
+
torch.cuda.empty_cache()
|
158 |
+
|
159 |
+
return image_logs
|
160 |
+
|
161 |
+
|
162 |
+
def parse_args(input_args=None):
|
163 |
+
parser = argparse.ArgumentParser(description="InstantIR stage-1 training.")
|
164 |
+
parser.add_argument(
|
165 |
+
"--pretrained_model_name_or_path",
|
166 |
+
type=str,
|
167 |
+
default=None,
|
168 |
+
required=True,
|
169 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--pretrained_vae_model_name_or_path",
|
173 |
+
type=str,
|
174 |
+
default=None,
|
175 |
+
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--feature_extractor_path",
|
179 |
+
type=str,
|
180 |
+
default=None,
|
181 |
+
help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--pretrained_adapter_model_path",
|
185 |
+
type=str,
|
186 |
+
default=None,
|
187 |
+
help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--adapter_tokens",
|
191 |
+
type=int,
|
192 |
+
default=64,
|
193 |
+
help="Number of tokens to use in IP-adapter cross attention mechanism.",
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--use_clip_encoder",
|
197 |
+
action="store_true",
|
198 |
+
help="Whether or not to use DINO as image encoder, else CLIP encoder.",
|
199 |
+
)
|
200 |
+
parser.add_argument(
|
201 |
+
"--image_encoder_hidden_feature",
|
202 |
+
action="store_true",
|
203 |
+
help="Whether or not to use the penultimate hidden states as image embeddings.",
|
204 |
+
)
|
205 |
+
parser.add_argument(
|
206 |
+
"--losses_config_path",
|
207 |
+
type=str,
|
208 |
+
required=True,
|
209 |
+
default='config_files/losses.yaml'
|
210 |
+
help=("A yaml file containing losses to use and their weights."),
|
211 |
+
)
|
212 |
+
parser.add_argument(
|
213 |
+
"--data_config_path",
|
214 |
+
type=str,
|
215 |
+
default='config_files/IR_dataset.yaml',
|
216 |
+
help=("A folder containing the training data. "),
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--variant",
|
220 |
+
type=str,
|
221 |
+
default=None,
|
222 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--revision",
|
226 |
+
type=str,
|
227 |
+
default=None,
|
228 |
+
required=False,
|
229 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
230 |
+
)
|
231 |
+
parser.add_argument(
|
232 |
+
"--tokenizer_name",
|
233 |
+
type=str,
|
234 |
+
default=None,
|
235 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--output_dir",
|
239 |
+
type=str,
|
240 |
+
default="stage1_model",
|
241 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
242 |
+
)
|
243 |
+
parser.add_argument(
|
244 |
+
"--cache_dir",
|
245 |
+
type=str,
|
246 |
+
default=None,
|
247 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
248 |
+
)
|
249 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
250 |
+
parser.add_argument(
|
251 |
+
"--resolution",
|
252 |
+
type=int,
|
253 |
+
default=512,
|
254 |
+
help=(
|
255 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
256 |
+
" resolution"
|
257 |
+
),
|
258 |
+
)
|
259 |
+
parser.add_argument(
|
260 |
+
"--crops_coords_top_left_h",
|
261 |
+
type=int,
|
262 |
+
default=0,
|
263 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
264 |
+
)
|
265 |
+
parser.add_argument(
|
266 |
+
"--crops_coords_top_left_w",
|
267 |
+
type=int,
|
268 |
+
default=0,
|
269 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
270 |
+
)
|
271 |
+
parser.add_argument(
|
272 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
273 |
+
)
|
274 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
275 |
+
parser.add_argument(
|
276 |
+
"--max_train_steps",
|
277 |
+
type=int,
|
278 |
+
default=None,
|
279 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--checkpointing_steps",
|
283 |
+
type=int,
|
284 |
+
default=2000,
|
285 |
+
help=(
|
286 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
287 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
288 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
289 |
+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
290 |
+
"instructions."
|
291 |
+
),
|
292 |
+
)
|
293 |
+
parser.add_argument(
|
294 |
+
"--checkpoints_total_limit",
|
295 |
+
type=int,
|
296 |
+
default=5,
|
297 |
+
help=("Max number of checkpoints to store."),
|
298 |
+
)
|
299 |
+
parser.add_argument(
|
300 |
+
"--resume_from_checkpoint",
|
301 |
+
type=str,
|
302 |
+
default=None,
|
303 |
+
help=(
|
304 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
305 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
306 |
+
),
|
307 |
+
)
|
308 |
+
parser.add_argument(
|
309 |
+
"--gradient_accumulation_steps",
|
310 |
+
type=int,
|
311 |
+
default=1,
|
312 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
313 |
+
)
|
314 |
+
parser.add_argument(
|
315 |
+
"--gradient_checkpointing",
|
316 |
+
action="store_true",
|
317 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
318 |
+
)
|
319 |
+
parser.add_argument(
|
320 |
+
"--save_only_adapter",
|
321 |
+
action="store_true",
|
322 |
+
help="Only save extra adapter to save space.",
|
323 |
+
)
|
324 |
+
parser.add_argument(
|
325 |
+
"--importance_sampling",
|
326 |
+
action="store_true",
|
327 |
+
help="Whether or not to use importance sampling.",
|
328 |
+
)
|
329 |
+
parser.add_argument(
|
330 |
+
"--learning_rate",
|
331 |
+
type=float,
|
332 |
+
default=1e-4,
|
333 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
334 |
+
)
|
335 |
+
parser.add_argument(
|
336 |
+
"--scale_lr",
|
337 |
+
action="store_true",
|
338 |
+
default=False,
|
339 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--lr_scheduler",
|
343 |
+
type=str,
|
344 |
+
default="constant",
|
345 |
+
help=(
|
346 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
347 |
+
' "constant", "constant_with_warmup"]'
|
348 |
+
),
|
349 |
+
)
|
350 |
+
parser.add_argument(
|
351 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--lr_num_cycles",
|
355 |
+
type=int,
|
356 |
+
default=1,
|
357 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
358 |
+
)
|
359 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
360 |
+
parser.add_argument(
|
361 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
362 |
+
)
|
363 |
+
parser.add_argument(
|
364 |
+
"--dataloader_num_workers",
|
365 |
+
type=int,
|
366 |
+
default=0,
|
367 |
+
help=(
|
368 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
369 |
+
),
|
370 |
+
)
|
371 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
372 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
373 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
374 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
375 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
376 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
377 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
378 |
+
parser.add_argument(
|
379 |
+
"--hub_model_id",
|
380 |
+
type=str,
|
381 |
+
default=None,
|
382 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
383 |
+
)
|
384 |
+
parser.add_argument(
|
385 |
+
"--logging_dir",
|
386 |
+
type=str,
|
387 |
+
default="logs",
|
388 |
+
help=(
|
389 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
390 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
391 |
+
),
|
392 |
+
)
|
393 |
+
parser.add_argument(
|
394 |
+
"--allow_tf32",
|
395 |
+
action="store_true",
|
396 |
+
help=(
|
397 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
398 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
399 |
+
),
|
400 |
+
)
|
401 |
+
parser.add_argument(
|
402 |
+
"--report_to",
|
403 |
+
type=str,
|
404 |
+
default="tensorboard",
|
405 |
+
help=(
|
406 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
407 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
408 |
+
),
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--mixed_precision",
|
412 |
+
type=str,
|
413 |
+
default=None,
|
414 |
+
choices=["no", "fp16", "bf16"],
|
415 |
+
help=(
|
416 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
417 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
418 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
419 |
+
),
|
420 |
+
)
|
421 |
+
parser.add_argument(
|
422 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
423 |
+
)
|
424 |
+
parser.add_argument(
|
425 |
+
"--set_grads_to_none",
|
426 |
+
action="store_true",
|
427 |
+
help=(
|
428 |
+
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
429 |
+
" behaviors, so disable this argument if it causes any problems. More info:"
|
430 |
+
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
431 |
+
),
|
432 |
+
)
|
433 |
+
parser.add_argument(
|
434 |
+
"--dataset_name",
|
435 |
+
type=str,
|
436 |
+
default=None,
|
437 |
+
help=(
|
438 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
439 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
440 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
441 |
+
),
|
442 |
+
)
|
443 |
+
parser.add_argument(
|
444 |
+
"--dataset_config_name",
|
445 |
+
type=str,
|
446 |
+
default=None,
|
447 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
448 |
+
)
|
449 |
+
parser.add_argument(
|
450 |
+
"--train_data_dir",
|
451 |
+
type=str,
|
452 |
+
default=None,
|
453 |
+
help=(
|
454 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
455 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
456 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
457 |
+
),
|
458 |
+
)
|
459 |
+
parser.add_argument(
|
460 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing the target image."
|
461 |
+
)
|
462 |
+
parser.add_argument(
|
463 |
+
"--conditioning_image_column",
|
464 |
+
type=str,
|
465 |
+
default="conditioning_image",
|
466 |
+
help="The column of the dataset containing the controlnet conditioning image.",
|
467 |
+
)
|
468 |
+
parser.add_argument(
|
469 |
+
"--caption_column",
|
470 |
+
type=str,
|
471 |
+
default="text",
|
472 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
473 |
+
)
|
474 |
+
parser.add_argument(
|
475 |
+
"--max_train_samples",
|
476 |
+
type=int,
|
477 |
+
default=None,
|
478 |
+
help=(
|
479 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
480 |
+
"value if set."
|
481 |
+
),
|
482 |
+
)
|
483 |
+
parser.add_argument(
|
484 |
+
"--text_drop_rate",
|
485 |
+
type=float,
|
486 |
+
default=0.05,
|
487 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
488 |
+
)
|
489 |
+
parser.add_argument(
|
490 |
+
"--image_drop_rate",
|
491 |
+
type=float,
|
492 |
+
default=0.05,
|
493 |
+
help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
|
494 |
+
)
|
495 |
+
parser.add_argument(
|
496 |
+
"--cond_drop_rate",
|
497 |
+
type=float,
|
498 |
+
default=0.05,
|
499 |
+
help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
|
500 |
+
)
|
501 |
+
parser.add_argument(
|
502 |
+
"--sanity_check",
|
503 |
+
action="store_true",
|
504 |
+
help=(
|
505 |
+
"sanity check"
|
506 |
+
),
|
507 |
+
)
|
508 |
+
parser.add_argument(
|
509 |
+
"--validation_prompt",
|
510 |
+
type=str,
|
511 |
+
default=None,
|
512 |
+
nargs="+",
|
513 |
+
help=(
|
514 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
515 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
516 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
517 |
+
),
|
518 |
+
)
|
519 |
+
parser.add_argument(
|
520 |
+
"--validation_image",
|
521 |
+
type=str,
|
522 |
+
default=None,
|
523 |
+
nargs="+",
|
524 |
+
help=(
|
525 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
526 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
527 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
528 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
529 |
+
),
|
530 |
+
)
|
531 |
+
parser.add_argument(
|
532 |
+
"--num_validation_images",
|
533 |
+
type=int,
|
534 |
+
default=4,
|
535 |
+
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
|
536 |
+
)
|
537 |
+
parser.add_argument(
|
538 |
+
"--validation_steps",
|
539 |
+
type=int,
|
540 |
+
default=3000,
|
541 |
+
help=(
|
542 |
+
"Run validation every X steps. Validation consists of running the prompt"
|
543 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
544 |
+
" and logging the images."
|
545 |
+
),
|
546 |
+
)
|
547 |
+
parser.add_argument(
|
548 |
+
"--tracker_project_name",
|
549 |
+
type=str,
|
550 |
+
default="instantir_stage1",
|
551 |
+
help=(
|
552 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
553 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
554 |
+
),
|
555 |
+
)
|
556 |
+
|
557 |
+
if input_args is not None:
|
558 |
+
args = parser.parse_args(input_args)
|
559 |
+
else:
|
560 |
+
args = parser.parse_args()
|
561 |
+
|
562 |
+
# if args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
|
563 |
+
# raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
|
564 |
+
|
565 |
+
if args.dataset_name is not None and args.train_data_dir is not None:
|
566 |
+
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
|
567 |
+
|
568 |
+
if args.text_drop_rate < 0 or args.text_drop_rate > 1:
|
569 |
+
raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
|
570 |
+
|
571 |
+
if args.validation_prompt is not None and args.validation_image is None:
|
572 |
+
raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
|
573 |
+
|
574 |
+
if args.validation_prompt is None and args.validation_image is not None:
|
575 |
+
raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
|
576 |
+
|
577 |
+
if (
|
578 |
+
args.validation_image is not None
|
579 |
+
and args.validation_prompt is not None
|
580 |
+
and len(args.validation_image) != 1
|
581 |
+
and len(args.validation_prompt) != 1
|
582 |
+
and len(args.validation_image) != len(args.validation_prompt)
|
583 |
+
):
|
584 |
+
raise ValueError(
|
585 |
+
"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
|
586 |
+
" or the same number of `--validation_prompt`s and `--validation_image`s"
|
587 |
+
)
|
588 |
+
|
589 |
+
if args.resolution % 8 != 0:
|
590 |
+
raise ValueError(
|
591 |
+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
592 |
+
)
|
593 |
+
|
594 |
+
return args
|
595 |
+
|
596 |
+
|
597 |
+
def main(args):
|
598 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
599 |
+
raise ValueError(
|
600 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
601 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
602 |
+
)
|
603 |
+
|
604 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
605 |
+
|
606 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
607 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
608 |
+
raise ValueError(
|
609 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
610 |
+
)
|
611 |
+
|
612 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
613 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
614 |
+
accelerator = Accelerator(
|
615 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
616 |
+
mixed_precision=args.mixed_precision,
|
617 |
+
log_with=args.report_to,
|
618 |
+
project_config=accelerator_project_config,
|
619 |
+
# kwargs_handlers=[kwargs],
|
620 |
+
)
|
621 |
+
|
622 |
+
# Make one log on every process with the configuration for debugging.
|
623 |
+
logging.basicConfig(
|
624 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
625 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
626 |
+
level=logging.INFO,
|
627 |
+
)
|
628 |
+
logger.info(accelerator.state, main_process_only=False)
|
629 |
+
if accelerator.is_local_main_process:
|
630 |
+
transformers.utils.logging.set_verbosity_warning()
|
631 |
+
diffusers.utils.logging.set_verbosity_info()
|
632 |
+
else:
|
633 |
+
transformers.utils.logging.set_verbosity_error()
|
634 |
+
diffusers.utils.logging.set_verbosity_error()
|
635 |
+
|
636 |
+
# If passed along, set the training seed now.
|
637 |
+
if args.seed is not None:
|
638 |
+
set_seed(args.seed)
|
639 |
+
|
640 |
+
# Handle the repository creation.
|
641 |
+
if accelerator.is_main_process:
|
642 |
+
if args.output_dir is not None:
|
643 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
644 |
+
|
645 |
+
# Load scheduler and models
|
646 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
647 |
+
# Importance sampling.
|
648 |
+
list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
|
649 |
+
prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
|
650 |
+
importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
|
651 |
+
importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
|
652 |
+
|
653 |
+
# Load the tokenizers
|
654 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
655 |
+
args.pretrained_model_name_or_path,
|
656 |
+
subfolder="tokenizer",
|
657 |
+
revision=args.revision,
|
658 |
+
use_fast=False,
|
659 |
+
)
|
660 |
+
tokenizer_2 = AutoTokenizer.from_pretrained(
|
661 |
+
args.pretrained_model_name_or_path,
|
662 |
+
subfolder="tokenizer_2",
|
663 |
+
revision=args.revision,
|
664 |
+
use_fast=False,
|
665 |
+
)
|
666 |
+
|
667 |
+
# Text encoder and image encoder.
|
668 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
669 |
+
args.pretrained_model_name_or_path, args.revision
|
670 |
+
)
|
671 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
672 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
673 |
+
)
|
674 |
+
text_encoder = text_encoder_cls_one.from_pretrained(
|
675 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
676 |
+
)
|
677 |
+
text_encoder_2 = text_encoder_cls_two.from_pretrained(
|
678 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
679 |
+
)
|
680 |
+
if args.use_clip_encoder:
|
681 |
+
image_processor = CLIPImageProcessor()
|
682 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
|
683 |
+
else:
|
684 |
+
image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
|
685 |
+
image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
|
686 |
+
|
687 |
+
# VAE.
|
688 |
+
vae_path = (
|
689 |
+
args.pretrained_model_name_or_path
|
690 |
+
if args.pretrained_vae_model_name_or_path is None
|
691 |
+
else args.pretrained_vae_model_name_or_path
|
692 |
+
)
|
693 |
+
vae = AutoencoderKL.from_pretrained(
|
694 |
+
vae_path,
|
695 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
696 |
+
revision=args.revision,
|
697 |
+
variant=args.variant,
|
698 |
+
)
|
699 |
+
|
700 |
+
# UNet.
|
701 |
+
unet = UNet2DConditionModel.from_pretrained(
|
702 |
+
args.pretrained_model_name_or_path,
|
703 |
+
subfolder="unet",
|
704 |
+
revision=args.revision,
|
705 |
+
variant=args.variant
|
706 |
+
)
|
707 |
+
|
708 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
709 |
+
args.pretrained_model_name_or_path,
|
710 |
+
unet=unet,
|
711 |
+
text_encoder=text_encoder,
|
712 |
+
text_encoder_2=text_encoder_2,
|
713 |
+
vae=vae,
|
714 |
+
tokenizer=tokenizer,
|
715 |
+
tokenizer_2=tokenizer_2,
|
716 |
+
variant=args.variant
|
717 |
+
)
|
718 |
+
|
719 |
+
# Resampler for project model in IP-Adapter
|
720 |
+
image_proj_model = Resampler(
|
721 |
+
dim=1280,
|
722 |
+
depth=4,
|
723 |
+
dim_head=64,
|
724 |
+
heads=20,
|
725 |
+
num_queries=args.adapter_tokens,
|
726 |
+
embedding_dim=image_encoder.config.hidden_size,
|
727 |
+
output_dim=unet.config.cross_attention_dim,
|
728 |
+
ff_mult=4
|
729 |
+
)
|
730 |
+
|
731 |
+
init_adapter_in_unet(
|
732 |
+
unet,
|
733 |
+
image_proj_model,
|
734 |
+
os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
|
735 |
+
adapter_tokens=args.adapter_tokens,
|
736 |
+
)
|
737 |
+
|
738 |
+
# Initialize training state.
|
739 |
+
vae.requires_grad_(False)
|
740 |
+
text_encoder.requires_grad_(False)
|
741 |
+
text_encoder_2.requires_grad_(False)
|
742 |
+
unet.requires_grad_(False)
|
743 |
+
image_encoder.requires_grad_(False)
|
744 |
+
|
745 |
+
def unwrap_model(model):
|
746 |
+
model = accelerator.unwrap_model(model)
|
747 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
748 |
+
return model
|
749 |
+
|
750 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
751 |
+
if args.save_only_adapter:
|
752 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
753 |
+
def save_model_hook(models, weights, output_dir):
|
754 |
+
if accelerator.is_main_process:
|
755 |
+
for model in models:
|
756 |
+
if isinstance(model, type(unwrap_model(unet))): # save adapter only
|
757 |
+
adapter_state_dict = OrderedDict()
|
758 |
+
adapter_state_dict["image_proj"] = model.encoder_hid_proj.image_projection_layers[0].state_dict()
|
759 |
+
adapter_state_dict["ip_adapter"] = torch.nn.ModuleList(model.attn_processors.values()).state_dict()
|
760 |
+
torch.save(adapter_state_dict, os.path.join(output_dir, "adapter_ckpt.pt"))
|
761 |
+
|
762 |
+
weights.pop()
|
763 |
+
|
764 |
+
def load_model_hook(models, input_dir):
|
765 |
+
|
766 |
+
while len(models) > 0:
|
767 |
+
# pop models so that they are not loaded again
|
768 |
+
model = models.pop()
|
769 |
+
|
770 |
+
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
771 |
+
adapter_state_dict = torch.load(os.path.join(input_dir, "adapter_ckpt.pt"), map_location="cpu")
|
772 |
+
if list(adapter_state_dict.keys()) != ["image_proj", "ip_adapter"]:
|
773 |
+
from module.ip_adapter.utils import revise_state_dict
|
774 |
+
adapter_state_dict = revise_state_dict(adapter_state_dict)
|
775 |
+
model.encoder_hid_proj.image_projection_layers[0].load_state_dict(adapter_state_dict["image_proj"], strict=True)
|
776 |
+
missing, unexpected = torch.nn.ModuleList(model.attn_processors.values()).load_state_dict(adapter_state_dict["ip_adapter"], strict=False)
|
777 |
+
if len(unexpected) > 0:
|
778 |
+
raise ValueError(f"Unexpected keys: {unexpected}")
|
779 |
+
if len(missing) > 0:
|
780 |
+
for mk in missing:
|
781 |
+
if "ln" not in mk:
|
782 |
+
raise ValueError(f"Missing keys: {missing}")
|
783 |
+
|
784 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
785 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
786 |
+
|
787 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
788 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
789 |
+
weight_dtype = torch.float32
|
790 |
+
if accelerator.mixed_precision == "fp16":
|
791 |
+
weight_dtype = torch.float16
|
792 |
+
elif accelerator.mixed_precision == "bf16":
|
793 |
+
weight_dtype = torch.bfloat16
|
794 |
+
|
795 |
+
if args.enable_xformers_memory_efficient_attention:
|
796 |
+
if is_xformers_available():
|
797 |
+
import xformers
|
798 |
+
|
799 |
+
xformers_version = version.parse(xformers.__version__)
|
800 |
+
if xformers_version == version.parse("0.0.16"):
|
801 |
+
logger.warning(
|
802 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
803 |
+
)
|
804 |
+
unet.enable_xformers_memory_efficient_attention()
|
805 |
+
else:
|
806 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
807 |
+
|
808 |
+
if args.gradient_checkpointing:
|
809 |
+
unet.enable_gradient_checkpointing()
|
810 |
+
vae.enable_gradient_checkpointing()
|
811 |
+
|
812 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
813 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
814 |
+
if args.allow_tf32:
|
815 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
816 |
+
|
817 |
+
if args.scale_lr:
|
818 |
+
args.learning_rate = (
|
819 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
820 |
+
)
|
821 |
+
|
822 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
823 |
+
if args.use_8bit_adam:
|
824 |
+
try:
|
825 |
+
import bitsandbytes as bnb
|
826 |
+
except ImportError:
|
827 |
+
raise ImportError(
|
828 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
829 |
+
)
|
830 |
+
|
831 |
+
optimizer_class = bnb.optim.AdamW8bit
|
832 |
+
else:
|
833 |
+
optimizer_class = torch.optim.AdamW
|
834 |
+
|
835 |
+
# Optimizer creation.
|
836 |
+
ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
|
837 |
+
params_to_optimize = ip_params
|
838 |
+
optimizer = optimizer_class(
|
839 |
+
params_to_optimize,
|
840 |
+
lr=args.learning_rate,
|
841 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
842 |
+
weight_decay=args.adam_weight_decay,
|
843 |
+
eps=args.adam_epsilon,
|
844 |
+
)
|
845 |
+
|
846 |
+
# Instantiate Loss.
|
847 |
+
losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
|
848 |
+
diffusion_losses = list()
|
849 |
+
for loss_config in losses_configs.diffusion_losses:
|
850 |
+
logger.info(f"Loading diffusion loss: {loss_config.name}")
|
851 |
+
loss = namedtuple("loss", ["loss", "weight"])
|
852 |
+
loss_class = eval(loss_config.name)
|
853 |
+
diffusion_losses.append(loss(loss_class(visualize_every_k=loss_config.visualize_every_k,
|
854 |
+
dtype=weight_dtype,
|
855 |
+
accelerator=accelerator,
|
856 |
+
**loss_config.init_params), weight=loss_config.weight))
|
857 |
+
|
858 |
+
# SDXL additional condition that will be added to time embedding.
|
859 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
860 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
861 |
+
target_size = (args.resolution, args.resolution)
|
862 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
863 |
+
add_time_ids = torch.tensor([add_time_ids])
|
864 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
865 |
+
return add_time_ids
|
866 |
+
|
867 |
+
# Text prompt embeddings.
|
868 |
+
@torch.no_grad()
|
869 |
+
def compute_embeddings(batch, text_encoders, tokenizers, drop_idx=None, is_train=True):
|
870 |
+
prompt_batch = batch[args.caption_column]
|
871 |
+
if drop_idx is not None:
|
872 |
+
for i in range(len(prompt_batch)):
|
873 |
+
prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
|
874 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
875 |
+
prompt_batch, text_encoders, tokenizers, is_train
|
876 |
+
)
|
877 |
+
|
878 |
+
add_time_ids = torch.cat(
|
879 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
880 |
+
)
|
881 |
+
|
882 |
+
prompt_embeds = prompt_embeds.to(accelerator.device)
|
883 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
884 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
|
885 |
+
sdxl_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
886 |
+
|
887 |
+
return prompt_embeds, sdxl_added_cond_kwargs
|
888 |
+
|
889 |
+
# Move pixels into latents.
|
890 |
+
@torch.no_grad()
|
891 |
+
def convert_to_latent(pixels):
|
892 |
+
model_input = vae.encode(pixels).latent_dist.sample()
|
893 |
+
model_input = model_input * vae.config.scaling_factor
|
894 |
+
if args.pretrained_vae_model_name_or_path is None:
|
895 |
+
model_input = model_input.to(weight_dtype)
|
896 |
+
return model_input
|
897 |
+
|
898 |
+
# Datasets and other data moduels.
|
899 |
+
deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
|
900 |
+
compute_embeddings_fn = functools.partial(
|
901 |
+
compute_embeddings,
|
902 |
+
text_encoders=[text_encoder, text_encoder_2],
|
903 |
+
tokenizers=[tokenizer, tokenizer_2],
|
904 |
+
is_train=True,
|
905 |
+
)
|
906 |
+
|
907 |
+
datasets = []
|
908 |
+
datasets_name = []
|
909 |
+
datasets_weights = []
|
910 |
+
if args.data_config_path is not None:
|
911 |
+
data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
|
912 |
+
for single_dataset in data_config.datasets:
|
913 |
+
datasets_weights.append(single_dataset.dataset_weight)
|
914 |
+
datasets_name.append(single_dataset.dataset_folder)
|
915 |
+
dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
|
916 |
+
image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
|
917 |
+
image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
|
918 |
+
datasets.append(image_dataset)
|
919 |
+
# TODO: Validation dataset
|
920 |
+
if data_config.val_dataset is not None:
|
921 |
+
val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
|
922 |
+
logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
|
923 |
+
|
924 |
+
# Mix training datasets.
|
925 |
+
sampler_train = None
|
926 |
+
if len(datasets) == 1:
|
927 |
+
train_dataset = datasets[0]
|
928 |
+
else:
|
929 |
+
# Weighted each dataset
|
930 |
+
train_dataset = torch.utils.data.ConcatDataset(datasets)
|
931 |
+
dataset_weights = []
|
932 |
+
for single_dataset, single_weight in zip(datasets, datasets_weights):
|
933 |
+
dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
|
934 |
+
sampler_train = torch.utils.data.WeightedRandomSampler(
|
935 |
+
weights=dataset_weights,
|
936 |
+
num_samples=len(dataset_weights)
|
937 |
+
)
|
938 |
+
|
939 |
+
train_dataloader = torch.utils.data.DataLoader(
|
940 |
+
train_dataset,
|
941 |
+
batch_size=args.train_batch_size,
|
942 |
+
sampler=sampler_train,
|
943 |
+
shuffle=True if sampler_train is None else False,
|
944 |
+
collate_fn=collate_fn,
|
945 |
+
num_workers=args.dataloader_num_workers
|
946 |
+
)
|
947 |
+
|
948 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
949 |
+
# The trackers initializes automatically on the main process.
|
950 |
+
if accelerator.is_main_process:
|
951 |
+
tracker_config = dict(vars(args))
|
952 |
+
|
953 |
+
# tensorboard cannot handle list types for config
|
954 |
+
tracker_config.pop("validation_prompt")
|
955 |
+
tracker_config.pop("validation_image")
|
956 |
+
|
957 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
958 |
+
|
959 |
+
# Scheduler and math around the number of training steps.
|
960 |
+
overrode_max_train_steps = False
|
961 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
962 |
+
if args.max_train_steps is None:
|
963 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
964 |
+
overrode_max_train_steps = True
|
965 |
+
|
966 |
+
lr_scheduler = get_scheduler(
|
967 |
+
args.lr_scheduler,
|
968 |
+
optimizer=optimizer,
|
969 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
970 |
+
num_training_steps=args.max_train_steps,
|
971 |
+
num_cycles=args.lr_num_cycles,
|
972 |
+
power=args.lr_power,
|
973 |
+
)
|
974 |
+
|
975 |
+
# Prepare everything with our `accelerator`.
|
976 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
977 |
+
unet, optimizer, train_dataloader, lr_scheduler
|
978 |
+
)
|
979 |
+
|
980 |
+
# Move vae, unet and text_encoder to device and cast to weight_dtype
|
981 |
+
if args.pretrained_vae_model_name_or_path is None:
|
982 |
+
# The VAE is fp32 to avoid NaN losses.
|
983 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
984 |
+
else:
|
985 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
986 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
987 |
+
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
|
988 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
989 |
+
importance_ratio = importance_ratio.to(accelerator.device)
|
990 |
+
for non_ip_param in non_ip_params:
|
991 |
+
non_ip_param.data = non_ip_param.data.to(dtype=weight_dtype)
|
992 |
+
for ip_param in ip_params:
|
993 |
+
ip_param.requires_grad_(True)
|
994 |
+
unet.to(accelerator.device)
|
995 |
+
|
996 |
+
# Final check.
|
997 |
+
for n, p in unet.named_parameters():
|
998 |
+
if p.requires_grad: assert p.dtype == torch.float32, n
|
999 |
+
else: assert p.dtype == weight_dtype, n
|
1000 |
+
if args.sanity_check:
|
1001 |
+
if args.resume_from_checkpoint:
|
1002 |
+
if args.resume_from_checkpoint != "latest":
|
1003 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1004 |
+
else:
|
1005 |
+
# Get the most recent checkpoint
|
1006 |
+
dirs = os.listdir(args.output_dir)
|
1007 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1008 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1009 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1010 |
+
|
1011 |
+
if path is None:
|
1012 |
+
accelerator.print(
|
1013 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1014 |
+
)
|
1015 |
+
args.resume_from_checkpoint = None
|
1016 |
+
initial_global_step = 0
|
1017 |
+
else:
|
1018 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1019 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1020 |
+
|
1021 |
+
# Check input data
|
1022 |
+
batch = next(iter(train_dataloader))
|
1023 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1024 |
+
images_log = log_validation(
|
1025 |
+
unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1026 |
+
noise_scheduler, image_encoder, image_processor, deg_pipeline,
|
1027 |
+
args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True
|
1028 |
+
)
|
1029 |
+
exit()
|
1030 |
+
|
1031 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
1032 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1033 |
+
if overrode_max_train_steps:
|
1034 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
1035 |
+
# Afterwards we recalculate our number of training epochs
|
1036 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
1037 |
+
|
1038 |
+
# Train!
|
1039 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1040 |
+
|
1041 |
+
logger.info("***** Running training *****")
|
1042 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
1043 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
1044 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
1045 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
1046 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
1047 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
1048 |
+
logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
|
1049 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
1050 |
+
global_step = 0
|
1051 |
+
first_epoch = 0
|
1052 |
+
|
1053 |
+
# Potentially load in the weights and states from a previous save
|
1054 |
+
if args.resume_from_checkpoint:
|
1055 |
+
if args.resume_from_checkpoint != "latest":
|
1056 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1057 |
+
else:
|
1058 |
+
# Get the most recent checkpoint
|
1059 |
+
dirs = os.listdir(args.output_dir)
|
1060 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1061 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1062 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1063 |
+
|
1064 |
+
if path is None:
|
1065 |
+
accelerator.print(
|
1066 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1067 |
+
)
|
1068 |
+
args.resume_from_checkpoint = None
|
1069 |
+
initial_global_step = 0
|
1070 |
+
else:
|
1071 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1072 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1073 |
+
global_step = int(path.split("-")[1])
|
1074 |
+
|
1075 |
+
initial_global_step = global_step
|
1076 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
1077 |
+
else:
|
1078 |
+
initial_global_step = 0
|
1079 |
+
|
1080 |
+
progress_bar = tqdm(
|
1081 |
+
range(0, args.max_train_steps),
|
1082 |
+
initial=initial_global_step,
|
1083 |
+
desc="Steps",
|
1084 |
+
# Only show the progress bar once on each machine.
|
1085 |
+
disable=not accelerator.is_local_main_process,
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
trainable_models = [unet]
|
1089 |
+
|
1090 |
+
if args.gradient_checkpointing:
|
1091 |
+
checkpoint_models = []
|
1092 |
+
else:
|
1093 |
+
checkpoint_models = []
|
1094 |
+
|
1095 |
+
image_logs = None
|
1096 |
+
tic = time.time()
|
1097 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
1098 |
+
for step, batch in enumerate(train_dataloader):
|
1099 |
+
toc = time.time()
|
1100 |
+
io_time = toc - tic
|
1101 |
+
tic = toc
|
1102 |
+
for model in trainable_models + checkpoint_models:
|
1103 |
+
model.train()
|
1104 |
+
with accelerator.accumulate(*trainable_models):
|
1105 |
+
loss = torch.tensor(0.0)
|
1106 |
+
|
1107 |
+
# Drop conditions.
|
1108 |
+
rand_tensor = torch.rand(batch["images"].shape[0])
|
1109 |
+
drop_image_idx = rand_tensor < args.image_drop_rate
|
1110 |
+
drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
|
1111 |
+
drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
|
1112 |
+
drop_image_idx = drop_image_idx | drop_both_idx
|
1113 |
+
drop_text_idx = drop_text_idx | drop_both_idx
|
1114 |
+
|
1115 |
+
# Get LQ embeddings
|
1116 |
+
with torch.no_grad():
|
1117 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1118 |
+
lq_pt = image_processor(
|
1119 |
+
images=lq_img*0.5+0.5,
|
1120 |
+
do_rescale=False, return_tensors="pt"
|
1121 |
+
).pixel_values
|
1122 |
+
image_embeds = prepare_training_image_embeds(
|
1123 |
+
image_encoder, image_processor,
|
1124 |
+
ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
|
1125 |
+
device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
|
1126 |
+
idx_to_replace=drop_image_idx
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
# Process text inputs.
|
1130 |
+
prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
|
1131 |
+
added_conditions["image_embeds"] = image_embeds
|
1132 |
+
|
1133 |
+
# Move inputs to latent space.
|
1134 |
+
gt_img = gt_img.to(dtype=vae.dtype)
|
1135 |
+
model_input = convert_to_latent(gt_img)
|
1136 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1137 |
+
model_input = model_input.to(weight_dtype)
|
1138 |
+
|
1139 |
+
# Sample noise that we'll add to the latents.
|
1140 |
+
noise = torch.randn_like(model_input)
|
1141 |
+
bsz = model_input.shape[0]
|
1142 |
+
|
1143 |
+
# Sample a random timestep for each image.
|
1144 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
|
1145 |
+
|
1146 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
1147 |
+
# (this is the forward diffusion process)
|
1148 |
+
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
1149 |
+
loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
|
1150 |
+
|
1151 |
+
toc = time.time()
|
1152 |
+
prepare_time = toc - tic
|
1153 |
+
tic = time.time()
|
1154 |
+
|
1155 |
+
model_pred = unet(
|
1156 |
+
noisy_model_input, timesteps,
|
1157 |
+
encoder_hidden_states=prompt_embeds_input,
|
1158 |
+
added_cond_kwargs=added_conditions,
|
1159 |
+
return_dict=False
|
1160 |
+
)[0]
|
1161 |
+
|
1162 |
+
diffusion_loss_arguments = {
|
1163 |
+
"target": noise,
|
1164 |
+
"predict": model_pred,
|
1165 |
+
"prompt_embeddings_input": prompt_embeds_input,
|
1166 |
+
"timesteps": timesteps,
|
1167 |
+
"weights": loss_weights,
|
1168 |
+
}
|
1169 |
+
|
1170 |
+
loss_dict = dict()
|
1171 |
+
for loss_config in diffusion_losses:
|
1172 |
+
non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
|
1173 |
+
loss = loss + non_weighted_loss * loss_config.weight
|
1174 |
+
loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
|
1175 |
+
|
1176 |
+
accelerator.backward(loss)
|
1177 |
+
if accelerator.sync_gradients:
|
1178 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
1179 |
+
optimizer.step()
|
1180 |
+
lr_scheduler.step()
|
1181 |
+
optimizer.zero_grad()
|
1182 |
+
|
1183 |
+
toc = time.time()
|
1184 |
+
forward_time = toc - tic
|
1185 |
+
tic = toc
|
1186 |
+
|
1187 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1188 |
+
if accelerator.sync_gradients:
|
1189 |
+
progress_bar.update(1)
|
1190 |
+
global_step += 1
|
1191 |
+
|
1192 |
+
if accelerator.is_main_process:
|
1193 |
+
if global_step % args.checkpointing_steps == 0:
|
1194 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
1195 |
+
if args.checkpoints_total_limit is not None:
|
1196 |
+
checkpoints = os.listdir(args.output_dir)
|
1197 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
1198 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
1199 |
+
|
1200 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
1201 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
1202 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
1203 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
1204 |
+
|
1205 |
+
logger.info(
|
1206 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
1207 |
+
)
|
1208 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
1209 |
+
|
1210 |
+
for removing_checkpoint in removing_checkpoints:
|
1211 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
1212 |
+
shutil.rmtree(removing_checkpoint)
|
1213 |
+
|
1214 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
1215 |
+
accelerator.save_state(save_path)
|
1216 |
+
logger.info(f"Saved state to {save_path}")
|
1217 |
+
|
1218 |
+
if global_step % args.validation_steps == 0:
|
1219 |
+
image_logs = log_validation(unwrap_model(unet), vae,
|
1220 |
+
text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1221 |
+
noise_scheduler, image_encoder, image_processor, deg_pipeline,
|
1222 |
+
args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False)
|
1223 |
+
|
1224 |
+
logs = {}
|
1225 |
+
logs.update(loss_dict)
|
1226 |
+
logs.update({
|
1227 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
1228 |
+
"io_time": io_time,
|
1229 |
+
"prepare_time": prepare_time,
|
1230 |
+
"forward_time": forward_time
|
1231 |
+
})
|
1232 |
+
progress_bar.set_postfix(**logs)
|
1233 |
+
accelerator.log(logs, step=global_step)
|
1234 |
+
tic = time.time()
|
1235 |
+
|
1236 |
+
if global_step >= args.max_train_steps:
|
1237 |
+
break
|
1238 |
+
|
1239 |
+
# Create the pipeline using using the trained modules and save it.
|
1240 |
+
accelerator.wait_for_everyone()
|
1241 |
+
if accelerator.is_main_process:
|
1242 |
+
accelerator.save_state(os.path.join(args.output_dir, "last"), safe_serialization=False)
|
1243 |
+
# Run a final round of validation.
|
1244 |
+
# Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
|
1245 |
+
image_logs = None
|
1246 |
+
if args.validation_image is not None:
|
1247 |
+
image_logs = log_validation(
|
1248 |
+
unwrap_model(unet), vae,
|
1249 |
+
text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1250 |
+
noise_scheduler, image_encoder, image_processor, deg_pipeline,
|
1251 |
+
args, accelerator, weight_dtype, global_step,
|
1252 |
+
)
|
1253 |
+
|
1254 |
+
accelerator.end_training()
|
1255 |
+
|
1256 |
+
|
1257 |
+
if __name__ == "__main__":
|
1258 |
+
args = parse_args()
|
1259 |
+
main(args)
|
train_stage1_adapter.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stage 1: training lq adapter
|
2 |
+
accelerate launch --num_processes <num_of_gpus> train_stage1_adapter.py \
|
3 |
+
--output_dir <your/output/path> \
|
4 |
+
--train_data_dir <your/data/path> \
|
5 |
+
--logging_dir <your/logging/path> \
|
6 |
+
--pretrained_model_name_or_path <your/sdxl/path> \
|
7 |
+
--feature_extractor_path <your/dinov2/path> \
|
8 |
+
--save_only_adapter \
|
9 |
+
--gradient_checkpointing \
|
10 |
+
--mixed_precision fp16 \
|
11 |
+
--train_batch_size 96 \
|
12 |
+
--gradient_accumulation_steps 1 \
|
13 |
+
--learning_rate 1e-4 \
|
14 |
+
--lr_warmup_steps 1000 \
|
15 |
+
--lr_scheduler cosine \
|
16 |
+
--lr_num_cycles 1 \
|
17 |
+
--resume_from_checkpoint latest
|
train_stage2_aggregator.py
ADDED
@@ -0,0 +1,1698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import os
|
17 |
+
import argparse
|
18 |
+
import time
|
19 |
+
import gc
|
20 |
+
import logging
|
21 |
+
import math
|
22 |
+
import copy
|
23 |
+
import random
|
24 |
+
import yaml
|
25 |
+
import functools
|
26 |
+
import shutil
|
27 |
+
import pyrallis
|
28 |
+
from pathlib import Path
|
29 |
+
from collections import namedtuple, OrderedDict
|
30 |
+
|
31 |
+
import accelerate
|
32 |
+
import numpy as np
|
33 |
+
import torch
|
34 |
+
from safetensors import safe_open
|
35 |
+
import torch.nn.functional as F
|
36 |
+
import torch.utils.checkpoint
|
37 |
+
import transformers
|
38 |
+
from accelerate import Accelerator
|
39 |
+
from accelerate.logging import get_logger
|
40 |
+
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
41 |
+
from datasets import load_dataset
|
42 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
43 |
+
from huggingface_hub import create_repo, upload_folder
|
44 |
+
from packaging import version
|
45 |
+
from PIL import Image
|
46 |
+
from data.data_config import DataConfig
|
47 |
+
from basicsr.utils.degradation_pipeline import RealESRGANDegradation
|
48 |
+
from losses.loss_config import LossesConfig
|
49 |
+
from losses.losses import *
|
50 |
+
from torchvision import transforms
|
51 |
+
from torchvision.transforms.functional import crop
|
52 |
+
from tqdm.auto import tqdm
|
53 |
+
from transformers import (
|
54 |
+
AutoTokenizer,
|
55 |
+
PretrainedConfig,
|
56 |
+
CLIPImageProcessor, CLIPVisionModelWithProjection,
|
57 |
+
AutoImageProcessor, AutoModel
|
58 |
+
)
|
59 |
+
|
60 |
+
import diffusers
|
61 |
+
from diffusers import (
|
62 |
+
AutoencoderKL,
|
63 |
+
DDPMScheduler,
|
64 |
+
StableDiffusionXLPipeline,
|
65 |
+
UNet2DConditionModel,
|
66 |
+
)
|
67 |
+
from diffusers.optimization import get_scheduler
|
68 |
+
from diffusers.utils import (
|
69 |
+
check_min_version,
|
70 |
+
convert_unet_state_dict_to_peft,
|
71 |
+
is_wandb_available,
|
72 |
+
)
|
73 |
+
from diffusers.utils.import_utils import is_xformers_available
|
74 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
75 |
+
|
76 |
+
from module.aggregator import Aggregator
|
77 |
+
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
|
78 |
+
from module.ip_adapter.ip_adapter import MultiIPAdapterImageProjection
|
79 |
+
from module.ip_adapter.resampler import Resampler
|
80 |
+
from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
|
81 |
+
from module.ip_adapter.attention_processor import init_attn_proc
|
82 |
+
from utils.train_utils import (
|
83 |
+
seperate_ip_params_from_unet,
|
84 |
+
import_model_class_from_model_name_or_path,
|
85 |
+
tensor_to_pil,
|
86 |
+
get_train_dataset, prepare_train_dataset, collate_fn,
|
87 |
+
encode_prompt, importance_sampling_fn, extract_into_tensor
|
88 |
+
)
|
89 |
+
from pipelines.sdxl_instantir import InstantIRPipeline
|
90 |
+
|
91 |
+
|
92 |
+
if is_wandb_available():
|
93 |
+
import wandb
|
94 |
+
|
95 |
+
|
96 |
+
logger = get_logger(__name__)
|
97 |
+
|
98 |
+
|
99 |
+
def log_validation(unet, aggregator, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
100 |
+
scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
|
101 |
+
args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
|
102 |
+
logger.info("Running validation... ")
|
103 |
+
|
104 |
+
image_logs = []
|
105 |
+
|
106 |
+
# validation_batch = batchify_pil(args.validation_image, args.validation_prompt, deg_pipeline, image_processor)
|
107 |
+
lq = [Image.open(lq_example).convert("RGB") for lq_example in args.validation_image]
|
108 |
+
|
109 |
+
pipe = InstantIRPipeline(
|
110 |
+
vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
111 |
+
unet, scheduler, aggregator, feature_extractor=image_processor, image_encoder=image_encoder,
|
112 |
+
).to(accelerator.device)
|
113 |
+
|
114 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
115 |
+
if lq_img is not None and gt_img is not None:
|
116 |
+
lq_img = lq_img[:len(args.validation_image)]
|
117 |
+
lq_pt = image_processor(
|
118 |
+
images=lq_img*0.5+0.5,
|
119 |
+
do_rescale=False, return_tensors="pt"
|
120 |
+
).pixel_values
|
121 |
+
image = pipe(
|
122 |
+
prompt=[""]*len(lq_img),
|
123 |
+
image=lq_img,
|
124 |
+
ip_adapter_image=lq_pt,
|
125 |
+
num_inference_steps=20,
|
126 |
+
generator=generator,
|
127 |
+
controlnet_conditioning_scale=1.0,
|
128 |
+
negative_prompt=[""]*len(lq),
|
129 |
+
guidance_scale=5.0,
|
130 |
+
height=args.resolution,
|
131 |
+
width=args.resolution,
|
132 |
+
lcm_scheduler=lcm_scheduler,
|
133 |
+
).images
|
134 |
+
else:
|
135 |
+
image = pipe(
|
136 |
+
prompt=[""]*len(lq),
|
137 |
+
image=lq,
|
138 |
+
ip_adapter_image=lq,
|
139 |
+
num_inference_steps=20,
|
140 |
+
generator=generator,
|
141 |
+
controlnet_conditioning_scale=1.0,
|
142 |
+
negative_prompt=[""]*len(lq),
|
143 |
+
guidance_scale=5.0,
|
144 |
+
height=args.resolution,
|
145 |
+
width=args.resolution,
|
146 |
+
lcm_scheduler=lcm_scheduler,
|
147 |
+
).images
|
148 |
+
|
149 |
+
if log_local:
|
150 |
+
for i, rec_image in enumerate(image):
|
151 |
+
rec_image.save(f"./instantid_{i}.png")
|
152 |
+
return
|
153 |
+
|
154 |
+
tracker_key = "test" if is_final_validation else "validation"
|
155 |
+
for tracker in accelerator.trackers:
|
156 |
+
if tracker.name == "tensorboard":
|
157 |
+
images = [np.asarray(pil_img) for pil_img in image]
|
158 |
+
images = np.stack(images, axis=0)
|
159 |
+
if lq_img is not None and gt_img is not None:
|
160 |
+
input_lq = lq_img.cpu()
|
161 |
+
input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
|
162 |
+
input_gt = gt_img.cpu()
|
163 |
+
input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
|
164 |
+
tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
|
165 |
+
tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
|
166 |
+
tracker.writer.add_images("rec", images, step, dataformats="NHWC")
|
167 |
+
elif tracker.name == "wandb":
|
168 |
+
raise NotImplementedError("Wandb logging not implemented for validation.")
|
169 |
+
formatted_images = []
|
170 |
+
|
171 |
+
for log in image_logs:
|
172 |
+
images = log["images"]
|
173 |
+
validation_prompt = log["validation_prompt"]
|
174 |
+
validation_image = log["validation_image"]
|
175 |
+
|
176 |
+
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
|
177 |
+
|
178 |
+
for image in images:
|
179 |
+
image = wandb.Image(image, caption=validation_prompt)
|
180 |
+
formatted_images.append(image)
|
181 |
+
|
182 |
+
tracker.log({tracker_key: formatted_images})
|
183 |
+
else:
|
184 |
+
logger.warning(f"image logging not implemented for {tracker.name}")
|
185 |
+
|
186 |
+
gc.collect()
|
187 |
+
torch.cuda.empty_cache()
|
188 |
+
|
189 |
+
return image_logs
|
190 |
+
|
191 |
+
|
192 |
+
def remove_attn2(model):
|
193 |
+
def recursive_find_module(name, module):
|
194 |
+
if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
|
195 |
+
elif "resnets" in name: return
|
196 |
+
if hasattr(module, "attn2"):
|
197 |
+
setattr(module, "attn2", None)
|
198 |
+
setattr(module, "norm2", None)
|
199 |
+
return
|
200 |
+
for sub_name, sub_module in module.named_children():
|
201 |
+
recursive_find_module(f"{name}.{sub_name}", sub_module)
|
202 |
+
|
203 |
+
for name, module in model.named_children():
|
204 |
+
recursive_find_module(name, module)
|
205 |
+
|
206 |
+
|
207 |
+
def parse_args(input_args=None):
|
208 |
+
parser = argparse.ArgumentParser(description="Simple example of a IP-Adapter training script.")
|
209 |
+
parser.add_argument(
|
210 |
+
"--pretrained_model_name_or_path",
|
211 |
+
type=str,
|
212 |
+
default=None,
|
213 |
+
required=True,
|
214 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
215 |
+
)
|
216 |
+
parser.add_argument(
|
217 |
+
"--pretrained_vae_model_name_or_path",
|
218 |
+
type=str,
|
219 |
+
default=None,
|
220 |
+
help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
|
221 |
+
)
|
222 |
+
parser.add_argument(
|
223 |
+
"--controlnet_model_name_or_path",
|
224 |
+
type=str,
|
225 |
+
default=None,
|
226 |
+
help="Path to an pretrained controlnet model like tile-controlnet.",
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--use_lcm",
|
230 |
+
action="store_true",
|
231 |
+
help="Whether or not to use lcm unet.",
|
232 |
+
)
|
233 |
+
parser.add_argument(
|
234 |
+
"--pretrained_lcm_lora_path",
|
235 |
+
type=str,
|
236 |
+
default=None,
|
237 |
+
help="Path to LCM lora or model identifier from huggingface.co/models.",
|
238 |
+
)
|
239 |
+
parser.add_argument(
|
240 |
+
"--lora_rank",
|
241 |
+
type=int,
|
242 |
+
default=64,
|
243 |
+
help="The rank of the LoRA projection matrix.",
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--lora_alpha",
|
247 |
+
type=int,
|
248 |
+
default=64,
|
249 |
+
help=(
|
250 |
+
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
|
251 |
+
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
|
252 |
+
),
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--lora_dropout",
|
256 |
+
type=float,
|
257 |
+
default=0.0,
|
258 |
+
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
|
259 |
+
)
|
260 |
+
parser.add_argument(
|
261 |
+
"--lora_target_modules",
|
262 |
+
type=str,
|
263 |
+
default=None,
|
264 |
+
help=(
|
265 |
+
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
|
266 |
+
" be used. By default, LoRA will be applied to all conv and linear layers."
|
267 |
+
),
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--feature_extractor_path",
|
271 |
+
type=str,
|
272 |
+
default=None,
|
273 |
+
help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--pretrained_adapter_model_path",
|
277 |
+
type=str,
|
278 |
+
default=None,
|
279 |
+
help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--adapter_tokens",
|
283 |
+
type=int,
|
284 |
+
default=64,
|
285 |
+
help="Number of tokens to use in IP-adapter cross attention mechanism.",
|
286 |
+
)
|
287 |
+
parser.add_argument(
|
288 |
+
"--aggregator_adapter",
|
289 |
+
action="store_true",
|
290 |
+
help="Whether or not to add adapter on aggregator.",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--optimize_adapter",
|
294 |
+
action="store_true",
|
295 |
+
help="Whether or not to optimize IP-Adapter.",
|
296 |
+
)
|
297 |
+
parser.add_argument(
|
298 |
+
"--image_encoder_hidden_feature",
|
299 |
+
action="store_true",
|
300 |
+
help="Whether or not to use the penultimate hidden states as image embeddings.",
|
301 |
+
)
|
302 |
+
parser.add_argument(
|
303 |
+
"--losses_config_path",
|
304 |
+
type=str,
|
305 |
+
required=True,
|
306 |
+
help=("A yaml file containing losses to use and their weights."),
|
307 |
+
)
|
308 |
+
parser.add_argument(
|
309 |
+
"--data_config_path",
|
310 |
+
type=str,
|
311 |
+
default=None,
|
312 |
+
help=("A folder containing the training data. "),
|
313 |
+
)
|
314 |
+
parser.add_argument(
|
315 |
+
"--variant",
|
316 |
+
type=str,
|
317 |
+
default=None,
|
318 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
319 |
+
)
|
320 |
+
parser.add_argument(
|
321 |
+
"--revision",
|
322 |
+
type=str,
|
323 |
+
default=None,
|
324 |
+
required=False,
|
325 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
326 |
+
)
|
327 |
+
parser.add_argument(
|
328 |
+
"--tokenizer_name",
|
329 |
+
type=str,
|
330 |
+
default=None,
|
331 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--output_dir",
|
335 |
+
type=str,
|
336 |
+
default="stage1_model",
|
337 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
338 |
+
)
|
339 |
+
parser.add_argument(
|
340 |
+
"--cache_dir",
|
341 |
+
type=str,
|
342 |
+
default=None,
|
343 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
344 |
+
)
|
345 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
346 |
+
parser.add_argument(
|
347 |
+
"--resolution",
|
348 |
+
type=int,
|
349 |
+
default=512,
|
350 |
+
help=(
|
351 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
352 |
+
" resolution"
|
353 |
+
),
|
354 |
+
)
|
355 |
+
parser.add_argument(
|
356 |
+
"--crops_coords_top_left_h",
|
357 |
+
type=int,
|
358 |
+
default=0,
|
359 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
360 |
+
)
|
361 |
+
parser.add_argument(
|
362 |
+
"--crops_coords_top_left_w",
|
363 |
+
type=int,
|
364 |
+
default=0,
|
365 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
366 |
+
)
|
367 |
+
parser.add_argument(
|
368 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
369 |
+
)
|
370 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
371 |
+
parser.add_argument(
|
372 |
+
"--max_train_steps",
|
373 |
+
type=int,
|
374 |
+
default=None,
|
375 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
376 |
+
)
|
377 |
+
parser.add_argument(
|
378 |
+
"--checkpointing_steps",
|
379 |
+
type=int,
|
380 |
+
default=3000,
|
381 |
+
help=(
|
382 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
383 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
384 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
385 |
+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
386 |
+
"instructions."
|
387 |
+
),
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--checkpoints_total_limit",
|
391 |
+
type=int,
|
392 |
+
default=5,
|
393 |
+
help=("Max number of checkpoints to store."),
|
394 |
+
)
|
395 |
+
parser.add_argument(
|
396 |
+
"--previous_ckpt",
|
397 |
+
type=str,
|
398 |
+
default=None,
|
399 |
+
help=(
|
400 |
+
"Whether training should be initialized from a previous checkpoint."
|
401 |
+
),
|
402 |
+
)
|
403 |
+
parser.add_argument(
|
404 |
+
"--resume_from_checkpoint",
|
405 |
+
type=str,
|
406 |
+
default=None,
|
407 |
+
help=(
|
408 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
409 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
410 |
+
),
|
411 |
+
)
|
412 |
+
parser.add_argument(
|
413 |
+
"--gradient_accumulation_steps",
|
414 |
+
type=int,
|
415 |
+
default=1,
|
416 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
417 |
+
)
|
418 |
+
parser.add_argument(
|
419 |
+
"--gradient_checkpointing",
|
420 |
+
action="store_true",
|
421 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
422 |
+
)
|
423 |
+
parser.add_argument(
|
424 |
+
"--save_only_adapter",
|
425 |
+
action="store_true",
|
426 |
+
help="Only save extra adapter to save space.",
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
"--cache_prompt_embeds",
|
430 |
+
action="store_true",
|
431 |
+
help="Whether or not to cache prompt embeds to save memory.",
|
432 |
+
)
|
433 |
+
parser.add_argument(
|
434 |
+
"--importance_sampling",
|
435 |
+
action="store_true",
|
436 |
+
help="Whether or not to use importance sampling.",
|
437 |
+
)
|
438 |
+
parser.add_argument(
|
439 |
+
"--CFG_scale",
|
440 |
+
type=float,
|
441 |
+
default=1.0,
|
442 |
+
help="CFG for previewer.",
|
443 |
+
)
|
444 |
+
parser.add_argument(
|
445 |
+
"--learning_rate",
|
446 |
+
type=float,
|
447 |
+
default=1e-4,
|
448 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
449 |
+
)
|
450 |
+
parser.add_argument(
|
451 |
+
"--scale_lr",
|
452 |
+
action="store_true",
|
453 |
+
default=False,
|
454 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
455 |
+
)
|
456 |
+
parser.add_argument(
|
457 |
+
"--lr_scheduler",
|
458 |
+
type=str,
|
459 |
+
default="constant",
|
460 |
+
help=(
|
461 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
462 |
+
' "constant", "constant_with_warmup"]'
|
463 |
+
),
|
464 |
+
)
|
465 |
+
parser.add_argument(
|
466 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
467 |
+
)
|
468 |
+
parser.add_argument(
|
469 |
+
"--lr_num_cycles",
|
470 |
+
type=int,
|
471 |
+
default=1,
|
472 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
473 |
+
)
|
474 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
475 |
+
parser.add_argument(
|
476 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
477 |
+
)
|
478 |
+
parser.add_argument(
|
479 |
+
"--dataloader_num_workers",
|
480 |
+
type=int,
|
481 |
+
default=0,
|
482 |
+
help=(
|
483 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
484 |
+
),
|
485 |
+
)
|
486 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
487 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
488 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
489 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
490 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
491 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
492 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
493 |
+
parser.add_argument(
|
494 |
+
"--hub_model_id",
|
495 |
+
type=str,
|
496 |
+
default=None,
|
497 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
498 |
+
)
|
499 |
+
parser.add_argument(
|
500 |
+
"--logging_dir",
|
501 |
+
type=str,
|
502 |
+
default="logs",
|
503 |
+
help=(
|
504 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
505 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
506 |
+
),
|
507 |
+
)
|
508 |
+
parser.add_argument(
|
509 |
+
"--allow_tf32",
|
510 |
+
action="store_true",
|
511 |
+
help=(
|
512 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
513 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
514 |
+
),
|
515 |
+
)
|
516 |
+
parser.add_argument(
|
517 |
+
"--report_to",
|
518 |
+
type=str,
|
519 |
+
default="tensorboard",
|
520 |
+
help=(
|
521 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
522 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
523 |
+
),
|
524 |
+
)
|
525 |
+
parser.add_argument(
|
526 |
+
"--mixed_precision",
|
527 |
+
type=str,
|
528 |
+
default=None,
|
529 |
+
choices=["no", "fp16", "bf16"],
|
530 |
+
help=(
|
531 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
532 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
533 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
534 |
+
),
|
535 |
+
)
|
536 |
+
parser.add_argument(
|
537 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
538 |
+
)
|
539 |
+
parser.add_argument(
|
540 |
+
"--set_grads_to_none",
|
541 |
+
action="store_true",
|
542 |
+
help=(
|
543 |
+
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
544 |
+
" behaviors, so disable this argument if it causes any problems. More info:"
|
545 |
+
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
546 |
+
),
|
547 |
+
)
|
548 |
+
parser.add_argument(
|
549 |
+
"--dataset_name",
|
550 |
+
type=str,
|
551 |
+
default=None,
|
552 |
+
help=(
|
553 |
+
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
554 |
+
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
555 |
+
" or to a folder containing files that 🤗 Datasets can understand."
|
556 |
+
),
|
557 |
+
)
|
558 |
+
parser.add_argument(
|
559 |
+
"--dataset_config_name",
|
560 |
+
type=str,
|
561 |
+
default=None,
|
562 |
+
help="The config of the Dataset, leave as None if there's only one config.",
|
563 |
+
)
|
564 |
+
parser.add_argument(
|
565 |
+
"--train_data_dir",
|
566 |
+
type=str,
|
567 |
+
default=None,
|
568 |
+
help=(
|
569 |
+
"A folder containing the training data. Folder contents must follow the structure described in"
|
570 |
+
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
571 |
+
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
572 |
+
),
|
573 |
+
)
|
574 |
+
parser.add_argument(
|
575 |
+
"--image_column", type=str, default="image", help="The column of the dataset containing the target image."
|
576 |
+
)
|
577 |
+
parser.add_argument(
|
578 |
+
"--conditioning_image_column",
|
579 |
+
type=str,
|
580 |
+
default="conditioning_image",
|
581 |
+
help="The column of the dataset containing the controlnet conditioning image.",
|
582 |
+
)
|
583 |
+
parser.add_argument(
|
584 |
+
"--caption_column",
|
585 |
+
type=str,
|
586 |
+
default="text",
|
587 |
+
help="The column of the dataset containing a caption or a list of captions.",
|
588 |
+
)
|
589 |
+
parser.add_argument(
|
590 |
+
"--max_train_samples",
|
591 |
+
type=int,
|
592 |
+
default=None,
|
593 |
+
help=(
|
594 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
595 |
+
"value if set."
|
596 |
+
),
|
597 |
+
)
|
598 |
+
parser.add_argument(
|
599 |
+
"--text_drop_rate",
|
600 |
+
type=float,
|
601 |
+
default=0,
|
602 |
+
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
|
603 |
+
)
|
604 |
+
parser.add_argument(
|
605 |
+
"--image_drop_rate",
|
606 |
+
type=float,
|
607 |
+
default=0,
|
608 |
+
help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
|
609 |
+
)
|
610 |
+
parser.add_argument(
|
611 |
+
"--cond_drop_rate",
|
612 |
+
type=float,
|
613 |
+
default=0,
|
614 |
+
help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
|
615 |
+
)
|
616 |
+
parser.add_argument(
|
617 |
+
"--use_ema_adapter",
|
618 |
+
action="store_true",
|
619 |
+
help=(
|
620 |
+
"use ema ip-adapter for LCM preview"
|
621 |
+
),
|
622 |
+
)
|
623 |
+
parser.add_argument(
|
624 |
+
"--sanity_check",
|
625 |
+
action="store_true",
|
626 |
+
help=(
|
627 |
+
"sanity check"
|
628 |
+
),
|
629 |
+
)
|
630 |
+
parser.add_argument(
|
631 |
+
"--validation_prompt",
|
632 |
+
type=str,
|
633 |
+
default=None,
|
634 |
+
nargs="+",
|
635 |
+
help=(
|
636 |
+
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
|
637 |
+
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
|
638 |
+
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
|
639 |
+
),
|
640 |
+
)
|
641 |
+
parser.add_argument(
|
642 |
+
"--validation_image",
|
643 |
+
type=str,
|
644 |
+
default=None,
|
645 |
+
nargs="+",
|
646 |
+
help=(
|
647 |
+
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
|
648 |
+
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
|
649 |
+
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
|
650 |
+
" `--validation_image` that will be used with all `--validation_prompt`s."
|
651 |
+
),
|
652 |
+
)
|
653 |
+
parser.add_argument(
|
654 |
+
"--num_validation_images",
|
655 |
+
type=int,
|
656 |
+
default=4,
|
657 |
+
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
|
658 |
+
)
|
659 |
+
parser.add_argument(
|
660 |
+
"--validation_steps",
|
661 |
+
type=int,
|
662 |
+
default=4000,
|
663 |
+
help=(
|
664 |
+
"Run validation every X steps. Validation consists of running the prompt"
|
665 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
666 |
+
" and logging the images."
|
667 |
+
),
|
668 |
+
)
|
669 |
+
parser.add_argument(
|
670 |
+
"--tracker_project_name",
|
671 |
+
type=str,
|
672 |
+
default='train',
|
673 |
+
help=(
|
674 |
+
"The `project_name` argument passed to Accelerator.init_trackers for"
|
675 |
+
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
676 |
+
),
|
677 |
+
)
|
678 |
+
|
679 |
+
if input_args is not None:
|
680 |
+
args = parser.parse_args(input_args)
|
681 |
+
else:
|
682 |
+
args = parser.parse_args()
|
683 |
+
|
684 |
+
if not args.sanity_check and args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
|
685 |
+
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
|
686 |
+
|
687 |
+
if args.dataset_name is not None and args.train_data_dir is not None:
|
688 |
+
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
|
689 |
+
|
690 |
+
if args.text_drop_rate < 0 or args.text_drop_rate > 1:
|
691 |
+
raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
|
692 |
+
|
693 |
+
if args.validation_prompt is not None and args.validation_image is None:
|
694 |
+
raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
|
695 |
+
|
696 |
+
if args.validation_prompt is None and args.validation_image is not None:
|
697 |
+
raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
|
698 |
+
|
699 |
+
if (
|
700 |
+
args.validation_image is not None
|
701 |
+
and args.validation_prompt is not None
|
702 |
+
and len(args.validation_image) != 1
|
703 |
+
and len(args.validation_prompt) != 1
|
704 |
+
and len(args.validation_image) != len(args.validation_prompt)
|
705 |
+
):
|
706 |
+
raise ValueError(
|
707 |
+
"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
|
708 |
+
" or the same number of `--validation_prompt`s and `--validation_image`s"
|
709 |
+
)
|
710 |
+
|
711 |
+
if args.resolution % 8 != 0:
|
712 |
+
raise ValueError(
|
713 |
+
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
714 |
+
)
|
715 |
+
|
716 |
+
return args
|
717 |
+
|
718 |
+
|
719 |
+
def update_ema_model(ema_model, model, ema_beta):
|
720 |
+
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
|
721 |
+
ema_param.copy_(param.detach().lerp(ema_param, ema_beta))
|
722 |
+
|
723 |
+
|
724 |
+
def copy_dict(dict):
|
725 |
+
new_dict = {}
|
726 |
+
for key, value in dict.items():
|
727 |
+
new_dict[key] = value
|
728 |
+
return new_dict
|
729 |
+
|
730 |
+
|
731 |
+
def main(args):
|
732 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
733 |
+
raise ValueError(
|
734 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
735 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
736 |
+
)
|
737 |
+
|
738 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
739 |
+
|
740 |
+
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
741 |
+
# due to pytorch#99272, MPS does not yet support bfloat16.
|
742 |
+
raise ValueError(
|
743 |
+
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
744 |
+
)
|
745 |
+
|
746 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
747 |
+
|
748 |
+
accelerator = Accelerator(
|
749 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
750 |
+
mixed_precision=args.mixed_precision,
|
751 |
+
log_with=args.report_to,
|
752 |
+
project_config=accelerator_project_config,
|
753 |
+
)
|
754 |
+
|
755 |
+
# Make one log on every process with the configuration for debugging.
|
756 |
+
logging.basicConfig(
|
757 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
758 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
759 |
+
level=logging.INFO,
|
760 |
+
)
|
761 |
+
logger.info(accelerator.state, main_process_only=False)
|
762 |
+
if accelerator.is_local_main_process:
|
763 |
+
transformers.utils.logging.set_verbosity_warning()
|
764 |
+
diffusers.utils.logging.set_verbosity_info()
|
765 |
+
else:
|
766 |
+
transformers.utils.logging.set_verbosity_error()
|
767 |
+
diffusers.utils.logging.set_verbosity_error()
|
768 |
+
|
769 |
+
# If passed along, set the training seed now.
|
770 |
+
if args.seed is not None:
|
771 |
+
set_seed(args.seed)
|
772 |
+
|
773 |
+
# Handle the repository creation.
|
774 |
+
if accelerator.is_main_process:
|
775 |
+
if args.output_dir is not None:
|
776 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
777 |
+
|
778 |
+
# Load scheduler and models
|
779 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
780 |
+
|
781 |
+
# Importance sampling.
|
782 |
+
list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
|
783 |
+
prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
|
784 |
+
importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
|
785 |
+
importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
|
786 |
+
|
787 |
+
# Load the tokenizers
|
788 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
789 |
+
args.pretrained_model_name_or_path,
|
790 |
+
subfolder="tokenizer",
|
791 |
+
revision=args.revision,
|
792 |
+
use_fast=False,
|
793 |
+
)
|
794 |
+
tokenizer_2 = AutoTokenizer.from_pretrained(
|
795 |
+
args.pretrained_model_name_or_path,
|
796 |
+
subfolder="tokenizer_2",
|
797 |
+
revision=args.revision,
|
798 |
+
use_fast=False,
|
799 |
+
)
|
800 |
+
|
801 |
+
# Text encoder and image encoder.
|
802 |
+
text_encoder_cls_one = import_model_class_from_model_name_or_path(
|
803 |
+
args.pretrained_model_name_or_path, args.revision
|
804 |
+
)
|
805 |
+
text_encoder_cls_two = import_model_class_from_model_name_or_path(
|
806 |
+
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
|
807 |
+
)
|
808 |
+
text_encoder = text_encoder_cls_one.from_pretrained(
|
809 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
810 |
+
)
|
811 |
+
text_encoder_2 = text_encoder_cls_two.from_pretrained(
|
812 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
|
813 |
+
)
|
814 |
+
|
815 |
+
# Image processor and image encoder.
|
816 |
+
if args.use_clip_encoder:
|
817 |
+
image_processor = CLIPImageProcessor()
|
818 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
|
819 |
+
else:
|
820 |
+
image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
|
821 |
+
image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
|
822 |
+
|
823 |
+
# VAE.
|
824 |
+
vae_path = (
|
825 |
+
args.pretrained_model_name_or_path
|
826 |
+
if args.pretrained_vae_model_name_or_path is None
|
827 |
+
else args.pretrained_vae_model_name_or_path
|
828 |
+
)
|
829 |
+
vae = AutoencoderKL.from_pretrained(
|
830 |
+
vae_path,
|
831 |
+
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
832 |
+
revision=args.revision,
|
833 |
+
variant=args.variant,
|
834 |
+
)
|
835 |
+
|
836 |
+
# UNet.
|
837 |
+
unet = UNet2DConditionModel.from_pretrained(
|
838 |
+
args.pretrained_model_name_or_path,
|
839 |
+
subfolder="unet",
|
840 |
+
revision=args.revision,
|
841 |
+
variant=args.variant
|
842 |
+
)
|
843 |
+
|
844 |
+
# Aggregator.
|
845 |
+
aggregator = Aggregator.from_unet(unet)
|
846 |
+
remove_attn2(aggregator)
|
847 |
+
if args.controlnet_model_name_or_path:
|
848 |
+
logger.info("Loading existing controlnet weights")
|
849 |
+
if args.controlnet_model_name_or_path.endswith(".safetensors"):
|
850 |
+
pretrained_cn_state_dict = {}
|
851 |
+
with safe_open(args.controlnet_model_name_or_path, framework="pt", device='cpu') as f:
|
852 |
+
for key in f.keys():
|
853 |
+
pretrained_cn_state_dict[key] = f.get_tensor(key)
|
854 |
+
else:
|
855 |
+
pretrained_cn_state_dict = torch.load(os.path.join(args.controlnet_model_name_or_path, "aggregator_ckpt.pt"), map_location="cpu")
|
856 |
+
aggregator.load_state_dict(pretrained_cn_state_dict, strict=True)
|
857 |
+
else:
|
858 |
+
logger.info("Initializing aggregator weights from unet.")
|
859 |
+
|
860 |
+
# Create image embedding projector for IP-Adapters.
|
861 |
+
if args.pretrained_adapter_model_path is not None:
|
862 |
+
if args.pretrained_adapter_model_path.endswith(".safetensors"):
|
863 |
+
pretrained_adapter_state_dict = {"image_proj": {}, "ip_adapter": {}}
|
864 |
+
with safe_open(args.pretrained_adapter_model_path, framework="pt", device="cpu") as f:
|
865 |
+
for key in f.keys():
|
866 |
+
if key.startswith("image_proj."):
|
867 |
+
pretrained_adapter_state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
868 |
+
elif key.startswith("ip_adapter."):
|
869 |
+
pretrained_adapter_state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
870 |
+
else:
|
871 |
+
pretrained_adapter_state_dict = torch.load(args.pretrained_adapter_model_path, map_location="cpu")
|
872 |
+
|
873 |
+
# Image embedding Projector.
|
874 |
+
image_proj_model = Resampler(
|
875 |
+
dim=1280,
|
876 |
+
depth=4,
|
877 |
+
dim_head=64,
|
878 |
+
heads=20,
|
879 |
+
num_queries=args.adapter_tokens,
|
880 |
+
embedding_dim=image_encoder.config.hidden_size,
|
881 |
+
output_dim=unet.config.cross_attention_dim,
|
882 |
+
ff_mult=4
|
883 |
+
)
|
884 |
+
|
885 |
+
init_adapter_in_unet(
|
886 |
+
unet,
|
887 |
+
image_proj_model,
|
888 |
+
pretrained_adapter_state_dict,
|
889 |
+
adapter_tokens=args.adapter_tokens,
|
890 |
+
)
|
891 |
+
|
892 |
+
# EMA adapter for LCM preview.
|
893 |
+
if args.use_ema_adapter:
|
894 |
+
assert args.optimize_adapter, "No need for EMA with frozen adapter."
|
895 |
+
ema_image_proj_model = Resampler(
|
896 |
+
dim=1280,
|
897 |
+
depth=4,
|
898 |
+
dim_head=64,
|
899 |
+
heads=20,
|
900 |
+
num_queries=args.adapter_tokens,
|
901 |
+
embedding_dim=image_encoder.config.hidden_size,
|
902 |
+
output_dim=unet.config.cross_attention_dim,
|
903 |
+
ff_mult=4
|
904 |
+
)
|
905 |
+
orig_encoder_hid_proj = unet.encoder_hid_proj
|
906 |
+
ema_encoder_hid_proj = MultiIPAdapterImageProjection([ema_image_proj_model])
|
907 |
+
orig_attn_procs = unet.attn_processors
|
908 |
+
orig_attn_procs_list = torch.nn.ModuleList(orig_attn_procs.values())
|
909 |
+
ema_attn_procs = init_attn_proc(unet, args.adapter_tokens, True, True, False)
|
910 |
+
ema_attn_procs_list = torch.nn.ModuleList(ema_attn_procs.values())
|
911 |
+
ema_attn_procs_list.requires_grad_(False)
|
912 |
+
ema_encoder_hid_proj.requires_grad_(False)
|
913 |
+
|
914 |
+
# Initialize EMA state.
|
915 |
+
ema_beta = 0.5 ** (args.ema_update_steps / max(args.ema_halflife_steps, 1e-8))
|
916 |
+
logger.info(f"Using EMA with beta: {ema_beta}")
|
917 |
+
ema_encoder_hid_proj.load_state_dict(orig_encoder_hid_proj.state_dict())
|
918 |
+
ema_attn_procs_list.load_state_dict(orig_attn_procs_list.state_dict())
|
919 |
+
|
920 |
+
# Projector for aggregator.
|
921 |
+
if args.aggregator_adapter:
|
922 |
+
image_proj_model = Resampler(
|
923 |
+
dim=1280,
|
924 |
+
depth=4,
|
925 |
+
dim_head=64,
|
926 |
+
heads=20,
|
927 |
+
num_queries=args.adapter_tokens,
|
928 |
+
embedding_dim=image_encoder.config.hidden_size,
|
929 |
+
output_dim=unet.config.cross_attention_dim,
|
930 |
+
ff_mult=4
|
931 |
+
)
|
932 |
+
|
933 |
+
init_adapter_in_unet(
|
934 |
+
aggregator,
|
935 |
+
image_proj_model,
|
936 |
+
pretrained_adapter_state_dict,
|
937 |
+
adapter_tokens=args.adapter_tokens,
|
938 |
+
)
|
939 |
+
del pretrained_adapter_state_dict
|
940 |
+
|
941 |
+
# Load LCM LoRA into unet.
|
942 |
+
if args.pretrained_lcm_lora_path is not None:
|
943 |
+
lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
|
944 |
+
unet_state_dict = {
|
945 |
+
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
946 |
+
}
|
947 |
+
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
948 |
+
lora_state_dict = dict()
|
949 |
+
for k, v in unet_state_dict.items():
|
950 |
+
if "ip" in k:
|
951 |
+
k = k.replace("attn2", "attn2.processor")
|
952 |
+
lora_state_dict[k] = v
|
953 |
+
else:
|
954 |
+
lora_state_dict[k] = v
|
955 |
+
if alpha_dict:
|
956 |
+
args.lora_alpha = next(iter(alpha_dict.values()))
|
957 |
+
else:
|
958 |
+
args.lora_alpha = 1
|
959 |
+
logger.info(f"Loaded LCM LoRA with alpha: {args.lora_alpha}")
|
960 |
+
# Create LoRA config, FIXME: now hard-coded.
|
961 |
+
lora_target_modules = [
|
962 |
+
"to_q",
|
963 |
+
"to_kv",
|
964 |
+
"0.to_out",
|
965 |
+
"attn1.to_k",
|
966 |
+
"attn1.to_v",
|
967 |
+
"to_k_ip",
|
968 |
+
"to_v_ip",
|
969 |
+
"ln_k_ip.linear",
|
970 |
+
"ln_v_ip.linear",
|
971 |
+
"to_out.0",
|
972 |
+
"proj_in",
|
973 |
+
"proj_out",
|
974 |
+
"ff.net.0.proj",
|
975 |
+
"ff.net.2",
|
976 |
+
"conv1",
|
977 |
+
"conv2",
|
978 |
+
"conv_shortcut",
|
979 |
+
"downsamplers.0.conv",
|
980 |
+
"upsamplers.0.conv",
|
981 |
+
"time_emb_proj",
|
982 |
+
]
|
983 |
+
lora_config = LoraConfig(
|
984 |
+
r=args.lora_rank,
|
985 |
+
target_modules=lora_target_modules,
|
986 |
+
lora_alpha=args.lora_alpha,
|
987 |
+
lora_dropout=args.lora_dropout,
|
988 |
+
)
|
989 |
+
|
990 |
+
unet.add_adapter(lora_config)
|
991 |
+
if args.pretrained_lcm_lora_path is not None:
|
992 |
+
incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
|
993 |
+
if incompatible_keys is not None:
|
994 |
+
# check only for unexpected keys
|
995 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
996 |
+
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
997 |
+
if unexpected_keys:
|
998 |
+
raise ValueError(
|
999 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
1000 |
+
f" {unexpected_keys}. "
|
1001 |
+
)
|
1002 |
+
for k in missing_keys:
|
1003 |
+
if "lora" in k:
|
1004 |
+
raise ValueError(
|
1005 |
+
f"Loading adapter weights from state_dict led to missing keys: {missing_keys}. "
|
1006 |
+
)
|
1007 |
+
lcm_scheduler = LCMSingleStepScheduler.from_config(noise_scheduler.config)
|
1008 |
+
|
1009 |
+
# Initialize training state.
|
1010 |
+
vae.requires_grad_(False)
|
1011 |
+
image_encoder.requires_grad_(False)
|
1012 |
+
text_encoder.requires_grad_(False)
|
1013 |
+
text_encoder_2.requires_grad_(False)
|
1014 |
+
unet.requires_grad_(False)
|
1015 |
+
aggregator.requires_grad_(False)
|
1016 |
+
|
1017 |
+
def unwrap_model(model):
|
1018 |
+
model = accelerator.unwrap_model(model)
|
1019 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
1020 |
+
return model
|
1021 |
+
|
1022 |
+
# `accelerate` 0.16.0 will have better support for customized saving
|
1023 |
+
if args.save_only_adapter:
|
1024 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
1025 |
+
def save_model_hook(models, weights, output_dir):
|
1026 |
+
if accelerator.is_main_process:
|
1027 |
+
for model in models:
|
1028 |
+
if isinstance(model, Aggregator):
|
1029 |
+
torch.save(model.state_dict(), os.path.join(output_dir, "aggregator_ckpt.pt"))
|
1030 |
+
weights.pop()
|
1031 |
+
|
1032 |
+
def load_model_hook(models, input_dir):
|
1033 |
+
|
1034 |
+
while len(models) > 0:
|
1035 |
+
# pop models so that they are not loaded again
|
1036 |
+
model = models.pop()
|
1037 |
+
|
1038 |
+
if isinstance(model, Aggregator):
|
1039 |
+
aggregator_state_dict = torch.load(os.path.join(input_dir, "aggregator_ckpt.pt"), map_location="cpu")
|
1040 |
+
model.load_state_dict(aggregator_state_dict)
|
1041 |
+
|
1042 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
1043 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
1044 |
+
|
1045 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
1046 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
1047 |
+
weight_dtype = torch.float32
|
1048 |
+
if accelerator.mixed_precision == "fp16":
|
1049 |
+
weight_dtype = torch.float16
|
1050 |
+
elif accelerator.mixed_precision == "bf16":
|
1051 |
+
weight_dtype = torch.bfloat16
|
1052 |
+
|
1053 |
+
if args.enable_xformers_memory_efficient_attention:
|
1054 |
+
if is_xformers_available():
|
1055 |
+
import xformers
|
1056 |
+
|
1057 |
+
xformers_version = version.parse(xformers.__version__)
|
1058 |
+
if xformers_version == version.parse("0.0.16"):
|
1059 |
+
logger.warning(
|
1060 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
1061 |
+
)
|
1062 |
+
unet.enable_xformers_memory_efficient_attention()
|
1063 |
+
else:
|
1064 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
1065 |
+
|
1066 |
+
if args.gradient_checkpointing:
|
1067 |
+
aggregator.enable_gradient_checkpointing()
|
1068 |
+
unet.enable_gradient_checkpointing()
|
1069 |
+
|
1070 |
+
# Check that all trainable models are in full precision
|
1071 |
+
low_precision_error_string = (
|
1072 |
+
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
|
1073 |
+
" doing mixed precision training, copy of the weights should still be float32."
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
if unwrap_model(aggregator).dtype != torch.float32:
|
1077 |
+
raise ValueError(
|
1078 |
+
f"aggregator loaded as datatype {unwrap_model(aggregator).dtype}. {low_precision_error_string}"
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
1082 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
1083 |
+
if args.allow_tf32:
|
1084 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
1085 |
+
|
1086 |
+
if args.scale_lr:
|
1087 |
+
args.learning_rate = (
|
1088 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
1092 |
+
if args.use_8bit_adam:
|
1093 |
+
try:
|
1094 |
+
import bitsandbytes as bnb
|
1095 |
+
except ImportError:
|
1096 |
+
raise ImportError(
|
1097 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
optimizer_class = bnb.optim.AdamW8bit
|
1101 |
+
else:
|
1102 |
+
optimizer_class = torch.optim.AdamW
|
1103 |
+
|
1104 |
+
# Optimizer creation
|
1105 |
+
ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
|
1106 |
+
if args.optimize_adapter:
|
1107 |
+
unet_params = ip_params
|
1108 |
+
unet_frozen_params = non_ip_params
|
1109 |
+
else: # freeze all unet params
|
1110 |
+
unet_params = []
|
1111 |
+
unet_frozen_params = ip_params + non_ip_params
|
1112 |
+
assert len(unet_frozen_params) == len(list(unet.parameters()))
|
1113 |
+
params_to_optimize = [p for p in aggregator.parameters()]
|
1114 |
+
params_to_optimize += unet_params
|
1115 |
+
optimizer = optimizer_class(
|
1116 |
+
params_to_optimize,
|
1117 |
+
lr=args.learning_rate,
|
1118 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
1119 |
+
weight_decay=args.adam_weight_decay,
|
1120 |
+
eps=args.adam_epsilon,
|
1121 |
+
)
|
1122 |
+
|
1123 |
+
# Instantiate Loss.
|
1124 |
+
losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
|
1125 |
+
diffusion_losses = list()
|
1126 |
+
lcm_losses = list()
|
1127 |
+
for loss_config in losses_configs.diffusion_losses:
|
1128 |
+
logger.info(f"Using diffusion loss: {loss_config.name}")
|
1129 |
+
loss = namedtuple("loss", ["loss", "weight"])
|
1130 |
+
diffusion_losses.append(
|
1131 |
+
loss(loss=eval(loss_config.name)(
|
1132 |
+
visualize_every_k=loss_config.visualize_every_k,
|
1133 |
+
dtype=weight_dtype,
|
1134 |
+
accelerator=accelerator,
|
1135 |
+
**loss_config.init_params), weight=loss_config.weight)
|
1136 |
+
)
|
1137 |
+
for loss_config in losses_configs.lcm_losses:
|
1138 |
+
logger.info(f"Using lcm loss: {loss_config.name}")
|
1139 |
+
loss = namedtuple("loss", ["loss", "weight"])
|
1140 |
+
loss_class = eval(loss_config.name)
|
1141 |
+
lcm_losses.append(loss(loss=loss_class(visualize_every_k=loss_config.visualize_every_k,
|
1142 |
+
dtype=weight_dtype,
|
1143 |
+
accelerator=accelerator,
|
1144 |
+
dino_model=image_encoder,
|
1145 |
+
dino_preprocess=image_processor,
|
1146 |
+
**loss_config.init_params), weight=loss_config.weight))
|
1147 |
+
|
1148 |
+
# SDXL additional condition that will be added to time embedding.
|
1149 |
+
def compute_time_ids(original_size, crops_coords_top_left):
|
1150 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
1151 |
+
target_size = (args.resolution, args.resolution)
|
1152 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
1153 |
+
add_time_ids = torch.tensor([add_time_ids])
|
1154 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
1155 |
+
return add_time_ids
|
1156 |
+
|
1157 |
+
# Text prompt embeddings.
|
1158 |
+
@torch.no_grad()
|
1159 |
+
def compute_embeddings(batch, text_encoders, tokenizers, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
|
1160 |
+
prompt_batch = batch[args.caption_column]
|
1161 |
+
if drop_idx is not None:
|
1162 |
+
for i in range(len(prompt_batch)):
|
1163 |
+
prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
|
1164 |
+
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
1165 |
+
prompt_batch, text_encoders, tokenizers, is_train
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
add_time_ids = torch.cat(
|
1169 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
prompt_embeds = prompt_embeds.to(accelerator.device)
|
1173 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
1174 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
|
1175 |
+
unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
1176 |
+
|
1177 |
+
return prompt_embeds, unet_added_cond_kwargs
|
1178 |
+
|
1179 |
+
@torch.no_grad()
|
1180 |
+
def get_added_cond(batch, prompt_embeds, pooled_prompt_embeds, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
|
1181 |
+
|
1182 |
+
add_time_ids = torch.cat(
|
1183 |
+
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
1184 |
+
)
|
1185 |
+
|
1186 |
+
prompt_embeds = prompt_embeds.to(accelerator.device)
|
1187 |
+
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
1188 |
+
add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
|
1189 |
+
unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
|
1190 |
+
|
1191 |
+
return prompt_embeds, unet_added_cond_kwargs
|
1192 |
+
|
1193 |
+
# Move pixels into latents.
|
1194 |
+
@torch.no_grad()
|
1195 |
+
def convert_to_latent(pixels):
|
1196 |
+
model_input = vae.encode(pixels).latent_dist.sample()
|
1197 |
+
model_input = model_input * vae.config.scaling_factor
|
1198 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1199 |
+
model_input = model_input.to(weight_dtype)
|
1200 |
+
return model_input
|
1201 |
+
|
1202 |
+
# Helper functions for training loop.
|
1203 |
+
# if args.degradation_config_path is not None:
|
1204 |
+
# with open(args.degradation_config_path) as stream:
|
1205 |
+
# degradation_configs = yaml.safe_load(stream)
|
1206 |
+
# deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
|
1207 |
+
# else:
|
1208 |
+
deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
|
1209 |
+
compute_embeddings_fn = functools.partial(
|
1210 |
+
compute_embeddings,
|
1211 |
+
text_encoders=[text_encoder, text_encoder_2],
|
1212 |
+
tokenizers=[tokenizer, tokenizer_2],
|
1213 |
+
is_train=True,
|
1214 |
+
)
|
1215 |
+
|
1216 |
+
datasets = []
|
1217 |
+
datasets_name = []
|
1218 |
+
datasets_weights = []
|
1219 |
+
if args.data_config_path is not None:
|
1220 |
+
data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
|
1221 |
+
for single_dataset in data_config.datasets:
|
1222 |
+
datasets_weights.append(single_dataset.dataset_weight)
|
1223 |
+
datasets_name.append(single_dataset.dataset_folder)
|
1224 |
+
dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
|
1225 |
+
image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
|
1226 |
+
image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
|
1227 |
+
datasets.append(image_dataset)
|
1228 |
+
# TODO: Validation dataset
|
1229 |
+
if data_config.val_dataset is not None:
|
1230 |
+
val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
|
1231 |
+
logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
|
1232 |
+
|
1233 |
+
# Mix training datasets.
|
1234 |
+
sampler_train = None
|
1235 |
+
if len(datasets) == 1:
|
1236 |
+
train_dataset = datasets[0]
|
1237 |
+
else:
|
1238 |
+
# Weighted each dataset
|
1239 |
+
train_dataset = torch.utils.data.ConcatDataset(datasets)
|
1240 |
+
dataset_weights = []
|
1241 |
+
for single_dataset, single_weight in zip(datasets, datasets_weights):
|
1242 |
+
dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
|
1243 |
+
sampler_train = torch.utils.data.WeightedRandomSampler(
|
1244 |
+
weights=dataset_weights,
|
1245 |
+
num_samples=len(dataset_weights)
|
1246 |
+
)
|
1247 |
+
|
1248 |
+
train_dataloader = torch.utils.data.DataLoader(
|
1249 |
+
train_dataset,
|
1250 |
+
batch_size=args.train_batch_size,
|
1251 |
+
sampler=sampler_train,
|
1252 |
+
shuffle=True if sampler_train is None else False,
|
1253 |
+
collate_fn=collate_fn,
|
1254 |
+
num_workers=args.dataloader_num_workers
|
1255 |
+
)
|
1256 |
+
|
1257 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
1258 |
+
# The trackers initializes automatically on the main process.
|
1259 |
+
if accelerator.is_main_process:
|
1260 |
+
tracker_config = dict(vars(args))
|
1261 |
+
|
1262 |
+
# tensorboard cannot handle list types for config
|
1263 |
+
tracker_config.pop("validation_prompt")
|
1264 |
+
tracker_config.pop("validation_image")
|
1265 |
+
|
1266 |
+
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
|
1267 |
+
|
1268 |
+
# Scheduler and math around the number of training steps.
|
1269 |
+
overrode_max_train_steps = False
|
1270 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1271 |
+
if args.max_train_steps is None:
|
1272 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
1273 |
+
overrode_max_train_steps = True
|
1274 |
+
|
1275 |
+
lr_scheduler = get_scheduler(
|
1276 |
+
args.lr_scheduler,
|
1277 |
+
optimizer=optimizer,
|
1278 |
+
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
1279 |
+
num_training_steps=args.max_train_steps,
|
1280 |
+
num_cycles=args.lr_num_cycles,
|
1281 |
+
power=args.lr_power,
|
1282 |
+
)
|
1283 |
+
|
1284 |
+
# Prepare everything with our `accelerator`.
|
1285 |
+
aggregator, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
1286 |
+
aggregator, unet, optimizer, train_dataloader, lr_scheduler
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
1290 |
+
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
|
1291 |
+
|
1292 |
+
# # cache empty prompts and move text encoders to cpu
|
1293 |
+
# empty_prompt_embeds, empty_pooled_prompt_embeds = encode_prompt(
|
1294 |
+
# [""]*args.train_batch_size, [text_encoder, text_encoder_2], [tokenizer, tokenizer_2], True
|
1295 |
+
# )
|
1296 |
+
# compute_embeddings_fn = functools.partial(
|
1297 |
+
# get_added_cond,
|
1298 |
+
# prompt_embeds=empty_prompt_embeds,
|
1299 |
+
# pooled_prompt_embeds=empty_pooled_prompt_embeds,
|
1300 |
+
# is_train=True,
|
1301 |
+
# )
|
1302 |
+
# text_encoder.to("cpu")
|
1303 |
+
# text_encoder_2.to("cpu")
|
1304 |
+
|
1305 |
+
# Move vae, unet and text_encoder to device and cast to `weight_dtype`.
|
1306 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1307 |
+
# The VAE is fp32 to avoid NaN losses.
|
1308 |
+
vae.to(accelerator.device, dtype=torch.float32)
|
1309 |
+
else:
|
1310 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
1311 |
+
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
1312 |
+
if args.use_ema_adapter:
|
1313 |
+
# FIXME: prepare ema model
|
1314 |
+
# ema_encoder_hid_proj, ema_attn_procs_list = accelerator.prepare(ema_encoder_hid_proj, ema_attn_procs_list)
|
1315 |
+
ema_encoder_hid_proj.to(accelerator.device)
|
1316 |
+
ema_attn_procs_list.to(accelerator.device)
|
1317 |
+
for param in unet_frozen_params:
|
1318 |
+
param.data = param.data.to(dtype=weight_dtype)
|
1319 |
+
for param in unet_params:
|
1320 |
+
param.requires_grad_(True)
|
1321 |
+
unet.to(accelerator.device)
|
1322 |
+
aggregator.requires_grad_(True)
|
1323 |
+
aggregator.to(accelerator.device)
|
1324 |
+
importance_ratio = importance_ratio.to(accelerator.device)
|
1325 |
+
|
1326 |
+
# Final sanity check.
|
1327 |
+
for n, p in unet.named_parameters():
|
1328 |
+
assert not p.requires_grad, n
|
1329 |
+
if p.requires_grad:
|
1330 |
+
assert p.dtype == torch.float32, n
|
1331 |
+
else:
|
1332 |
+
assert p.dtype == weight_dtype, n
|
1333 |
+
for n, p in aggregator.named_parameters():
|
1334 |
+
if p.requires_grad: assert p.dtype == torch.float32, n
|
1335 |
+
else:
|
1336 |
+
raise RuntimeError(f"All parameters in aggregator should require grad. {n}")
|
1337 |
+
assert p.dtype == weight_dtype, n
|
1338 |
+
unwrap_model(unet).disable_adapters()
|
1339 |
+
|
1340 |
+
if args.sanity_check:
|
1341 |
+
if args.resume_from_checkpoint:
|
1342 |
+
if args.resume_from_checkpoint != "latest":
|
1343 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1344 |
+
else:
|
1345 |
+
# Get the most recent checkpoint
|
1346 |
+
dirs = os.listdir(args.output_dir)
|
1347 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1348 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1349 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1350 |
+
|
1351 |
+
if path is None:
|
1352 |
+
accelerator.print(
|
1353 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1354 |
+
)
|
1355 |
+
args.resume_from_checkpoint = None
|
1356 |
+
initial_global_step = 0
|
1357 |
+
else:
|
1358 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1359 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1360 |
+
|
1361 |
+
if args.use_ema_adapter:
|
1362 |
+
unwrap_model(unet).set_attn_processor(ema_attn_procs)
|
1363 |
+
unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
|
1364 |
+
batch = next(iter(train_dataloader))
|
1365 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1366 |
+
log_validation(
|
1367 |
+
unwrap_model(unet), unwrap_model(aggregator), vae,
|
1368 |
+
text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1369 |
+
noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
|
1370 |
+
args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, log_local=True
|
1371 |
+
)
|
1372 |
+
if args.use_ema_adapter:
|
1373 |
+
unwrap_model(unet).set_attn_processor(orig_attn_procs)
|
1374 |
+
unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
|
1375 |
+
for n, p in unet.named_parameters():
|
1376 |
+
if p.requires_grad: assert p.dtype == torch.float32, n
|
1377 |
+
else: assert p.dtype == weight_dtype, n
|
1378 |
+
for n, p in aggregator.named_parameters():
|
1379 |
+
if p.requires_grad: assert p.dtype == torch.float32, n
|
1380 |
+
else: assert p.dtype == weight_dtype, n
|
1381 |
+
print("PASS")
|
1382 |
+
exit()
|
1383 |
+
|
1384 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
1385 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
1386 |
+
if overrode_max_train_steps:
|
1387 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
1388 |
+
# Afterwards we recalculate our number of training epochs
|
1389 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
1390 |
+
|
1391 |
+
# Train!
|
1392 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
1393 |
+
|
1394 |
+
logger.info("***** Running training *****")
|
1395 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
1396 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
1397 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
1398 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
1399 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
1400 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
1401 |
+
logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
|
1402 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
1403 |
+
global_step = 0
|
1404 |
+
first_epoch = 0
|
1405 |
+
|
1406 |
+
# Potentially load in the weights and states from a previous save
|
1407 |
+
if args.resume_from_checkpoint:
|
1408 |
+
if args.resume_from_checkpoint != "latest":
|
1409 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
1410 |
+
else:
|
1411 |
+
# Get the most recent checkpoint
|
1412 |
+
dirs = os.listdir(args.output_dir)
|
1413 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
1414 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
1415 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
1416 |
+
|
1417 |
+
if path is None:
|
1418 |
+
accelerator.print(
|
1419 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
1420 |
+
)
|
1421 |
+
args.resume_from_checkpoint = None
|
1422 |
+
initial_global_step = 0
|
1423 |
+
else:
|
1424 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
1425 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
1426 |
+
global_step = int(path.split("-")[1])
|
1427 |
+
|
1428 |
+
initial_global_step = global_step
|
1429 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
1430 |
+
else:
|
1431 |
+
initial_global_step = 0
|
1432 |
+
|
1433 |
+
progress_bar = tqdm(
|
1434 |
+
range(0, args.max_train_steps),
|
1435 |
+
initial=initial_global_step,
|
1436 |
+
desc="Steps",
|
1437 |
+
# Only show the progress bar once on each machine.
|
1438 |
+
disable=not accelerator.is_local_main_process,
|
1439 |
+
)
|
1440 |
+
|
1441 |
+
trainable_models = [aggregator, unet]
|
1442 |
+
|
1443 |
+
if args.gradient_checkpointing:
|
1444 |
+
# TODO: add vae
|
1445 |
+
checkpoint_models = []
|
1446 |
+
else:
|
1447 |
+
checkpoint_models = []
|
1448 |
+
|
1449 |
+
image_logs = None
|
1450 |
+
tic = time.time()
|
1451 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
1452 |
+
for step, batch in enumerate(train_dataloader):
|
1453 |
+
toc = time.time()
|
1454 |
+
io_time = toc - tic
|
1455 |
+
tic = time.time()
|
1456 |
+
for model in trainable_models + checkpoint_models:
|
1457 |
+
model.train()
|
1458 |
+
with accelerator.accumulate(*trainable_models):
|
1459 |
+
loss = torch.tensor(0.0)
|
1460 |
+
|
1461 |
+
# Drop conditions.
|
1462 |
+
rand_tensor = torch.rand(batch["images"].shape[0])
|
1463 |
+
drop_image_idx = rand_tensor < args.image_drop_rate
|
1464 |
+
drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
|
1465 |
+
drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
|
1466 |
+
drop_image_idx = drop_image_idx | drop_both_idx
|
1467 |
+
drop_text_idx = drop_text_idx | drop_both_idx
|
1468 |
+
|
1469 |
+
# Get LQ embeddings
|
1470 |
+
with torch.no_grad():
|
1471 |
+
lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
|
1472 |
+
lq_pt = image_processor(
|
1473 |
+
images=lq_img*0.5+0.5,
|
1474 |
+
do_rescale=False, return_tensors="pt"
|
1475 |
+
).pixel_values
|
1476 |
+
|
1477 |
+
# Move inputs to latent space.
|
1478 |
+
gt_img = gt_img.to(dtype=vae.dtype)
|
1479 |
+
lq_img = lq_img.to(dtype=vae.dtype)
|
1480 |
+
model_input = convert_to_latent(gt_img)
|
1481 |
+
lq_latent = convert_to_latent(lq_img)
|
1482 |
+
if args.pretrained_vae_model_name_or_path is None:
|
1483 |
+
model_input = model_input.to(weight_dtype)
|
1484 |
+
lq_latent = lq_latent.to(weight_dtype)
|
1485 |
+
|
1486 |
+
# Process conditions.
|
1487 |
+
image_embeds = prepare_training_image_embeds(
|
1488 |
+
image_encoder, image_processor,
|
1489 |
+
ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
|
1490 |
+
device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
|
1491 |
+
idx_to_replace=drop_image_idx
|
1492 |
+
)
|
1493 |
+
prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
|
1494 |
+
|
1495 |
+
# Sample noise that we'll add to the latents.
|
1496 |
+
noise = torch.randn_like(model_input)
|
1497 |
+
bsz = model_input.shape[0]
|
1498 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
|
1499 |
+
|
1500 |
+
# Add noise to the model input according to the noise magnitude at each timestep
|
1501 |
+
# (this is the forward diffusion process)
|
1502 |
+
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
1503 |
+
loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
|
1504 |
+
|
1505 |
+
if args.CFG_scale > 1.0:
|
1506 |
+
# Process negative conditions.
|
1507 |
+
drop_idx = torch.ones_like(drop_image_idx)
|
1508 |
+
neg_image_embeds = prepare_training_image_embeds(
|
1509 |
+
image_encoder, image_processor,
|
1510 |
+
ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
|
1511 |
+
device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
|
1512 |
+
idx_to_replace=drop_idx
|
1513 |
+
)
|
1514 |
+
neg_prompt_embeds_input, neg_added_conditions = compute_embeddings_fn(batch, drop_idx=drop_idx)
|
1515 |
+
previewer_model_input = torch.cat([noisy_model_input] * 2)
|
1516 |
+
previewer_timesteps = torch.cat([timesteps] * 2)
|
1517 |
+
previewer_prompt_embeds = torch.cat([neg_prompt_embeds_input, prompt_embeds_input], dim=0)
|
1518 |
+
previewer_added_conditions = {}
|
1519 |
+
for k in added_conditions.keys():
|
1520 |
+
previewer_added_conditions[k] = torch.cat([neg_added_conditions[k], added_conditions[k]], dim=0)
|
1521 |
+
previewer_image_embeds = []
|
1522 |
+
for neg_image_embed, image_embed in zip(neg_image_embeds, image_embeds):
|
1523 |
+
previewer_image_embeds.append(torch.cat([neg_image_embed, image_embed], dim=0))
|
1524 |
+
previewer_added_conditions["image_embeds"] = previewer_image_embeds
|
1525 |
+
else:
|
1526 |
+
previewer_model_input = noisy_model_input
|
1527 |
+
previewer_timesteps = timesteps
|
1528 |
+
previewer_prompt_embeds = prompt_embeds_input
|
1529 |
+
previewer_added_conditions = {}
|
1530 |
+
for k in added_conditions.keys():
|
1531 |
+
previewer_added_conditions[k] = added_conditions[k]
|
1532 |
+
previewer_added_conditions["image_embeds"] = image_embeds
|
1533 |
+
|
1534 |
+
# Get LCM preview latent
|
1535 |
+
if args.use_ema_adapter:
|
1536 |
+
orig_encoder_hid_proj = unet.encoder_hid_proj
|
1537 |
+
orig_attn_procs = unet.attn_processors
|
1538 |
+
_ema_attn_procs = copy_dict(ema_attn_procs)
|
1539 |
+
unwrap_model(unet).set_attn_processor(_ema_attn_procs)
|
1540 |
+
unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
|
1541 |
+
unwrap_model(unet).enable_adapters()
|
1542 |
+
preview_noise = unet(
|
1543 |
+
previewer_model_input, previewer_timesteps,
|
1544 |
+
encoder_hidden_states=previewer_prompt_embeds,
|
1545 |
+
added_cond_kwargs=previewer_added_conditions,
|
1546 |
+
return_dict=False
|
1547 |
+
)[0]
|
1548 |
+
if args.CFG_scale > 1.0:
|
1549 |
+
preview_noise_uncond, preview_noise_cond = preview_noise.chunk(2)
|
1550 |
+
cfg_scale = 1.0 + torch.rand_like(timesteps, dtype=weight_dtype) * (args.CFG_scale-1.0)
|
1551 |
+
cfg_scale = cfg_scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
1552 |
+
preview_noise = preview_noise_uncond + cfg_scale * (preview_noise_cond - preview_noise_uncond)
|
1553 |
+
preview_latents = lcm_scheduler.step(
|
1554 |
+
preview_noise,
|
1555 |
+
timesteps,
|
1556 |
+
noisy_model_input,
|
1557 |
+
return_dict=False
|
1558 |
+
)[0]
|
1559 |
+
unwrap_model(unet).disable_adapters()
|
1560 |
+
if args.use_ema_adapter:
|
1561 |
+
unwrap_model(unet).set_attn_processor(orig_attn_procs)
|
1562 |
+
unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
|
1563 |
+
preview_error_latent = F.mse_loss(preview_latents, model_input).cpu().item()
|
1564 |
+
preview_error_noise = F.mse_loss(preview_noise, noise).cpu().item()
|
1565 |
+
|
1566 |
+
# # Add fresh noise
|
1567 |
+
# if args.noisy_encoder_input:
|
1568 |
+
# aggregator_noise = torch.randn_like(preview_latents)
|
1569 |
+
# aggregator_input = noise_scheduler.add_noise(preview_latents, aggregator_noise, timesteps)
|
1570 |
+
|
1571 |
+
down_block_res_samples, mid_block_res_sample = aggregator(
|
1572 |
+
lq_latent,
|
1573 |
+
timesteps,
|
1574 |
+
encoder_hidden_states=prompt_embeds_input,
|
1575 |
+
added_cond_kwargs=added_conditions,
|
1576 |
+
controlnet_cond=preview_latents,
|
1577 |
+
conditioning_scale=1.0,
|
1578 |
+
return_dict=False,
|
1579 |
+
)
|
1580 |
+
|
1581 |
+
# UNet denoise.
|
1582 |
+
added_conditions["image_embeds"] = image_embeds
|
1583 |
+
model_pred = unet(
|
1584 |
+
noisy_model_input,
|
1585 |
+
timesteps,
|
1586 |
+
encoder_hidden_states=prompt_embeds_input,
|
1587 |
+
added_cond_kwargs=added_conditions,
|
1588 |
+
down_block_additional_residuals=[
|
1589 |
+
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
1590 |
+
],
|
1591 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
1592 |
+
return_dict=False
|
1593 |
+
)[0]
|
1594 |
+
|
1595 |
+
diffusion_loss_arguments = {
|
1596 |
+
"target": noise,
|
1597 |
+
"predict": model_pred,
|
1598 |
+
"prompt_embeddings_input": prompt_embeds_input,
|
1599 |
+
"timesteps": timesteps,
|
1600 |
+
"weights": loss_weights,
|
1601 |
+
}
|
1602 |
+
|
1603 |
+
loss_dict = dict()
|
1604 |
+
for loss_config in diffusion_losses:
|
1605 |
+
non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
|
1606 |
+
loss = loss + non_weighted_loss * loss_config.weight
|
1607 |
+
loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
|
1608 |
+
|
1609 |
+
accelerator.backward(loss)
|
1610 |
+
if accelerator.sync_gradients:
|
1611 |
+
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
|
1612 |
+
optimizer.step()
|
1613 |
+
lr_scheduler.step()
|
1614 |
+
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
1615 |
+
|
1616 |
+
toc = time.time()
|
1617 |
+
forward_time = toc - tic
|
1618 |
+
tic = toc
|
1619 |
+
|
1620 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1621 |
+
if accelerator.sync_gradients:
|
1622 |
+
progress_bar.update(1)
|
1623 |
+
global_step += 1
|
1624 |
+
|
1625 |
+
if global_step % args.ema_update_steps == 0:
|
1626 |
+
if args.use_ema_adapter:
|
1627 |
+
update_ema_model(ema_encoder_hid_proj, orig_encoder_hid_proj, ema_beta)
|
1628 |
+
update_ema_model(ema_attn_procs_list, orig_attn_procs_list, ema_beta)
|
1629 |
+
|
1630 |
+
if accelerator.is_main_process:
|
1631 |
+
if global_step % args.checkpointing_steps == 0:
|
1632 |
+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
1633 |
+
if args.checkpoints_total_limit is not None:
|
1634 |
+
checkpoints = os.listdir(args.output_dir)
|
1635 |
+
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
1636 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
1637 |
+
|
1638 |
+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
1639 |
+
if len(checkpoints) >= args.checkpoints_total_limit:
|
1640 |
+
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
1641 |
+
removing_checkpoints = checkpoints[0:num_to_remove]
|
1642 |
+
|
1643 |
+
logger.info(
|
1644 |
+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
1645 |
+
)
|
1646 |
+
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
1647 |
+
|
1648 |
+
for removing_checkpoint in removing_checkpoints:
|
1649 |
+
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
1650 |
+
shutil.rmtree(removing_checkpoint)
|
1651 |
+
|
1652 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
1653 |
+
accelerator.save_state(save_path)
|
1654 |
+
logger.info(f"Saved state to {save_path}")
|
1655 |
+
|
1656 |
+
if global_step % args.validation_steps == 0:
|
1657 |
+
image_logs = log_validation(
|
1658 |
+
unwrap_model(unet), unwrap_model(aggregator), vae,
|
1659 |
+
text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1660 |
+
noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
|
1661 |
+
args, accelerator, weight_dtype, global_step, lq_img.detach().clone(), gt_img.detach().clone()
|
1662 |
+
)
|
1663 |
+
|
1664 |
+
logs = {}
|
1665 |
+
logs.update(loss_dict)
|
1666 |
+
logs.update(
|
1667 |
+
{"preview_error_latent": preview_error_latent, "preview_error_noise": preview_error_noise,
|
1668 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
1669 |
+
"forward_time": forward_time, "io_time": io_time}
|
1670 |
+
)
|
1671 |
+
progress_bar.set_postfix(**logs)
|
1672 |
+
accelerator.log(logs, step=global_step)
|
1673 |
+
tic = time.time()
|
1674 |
+
|
1675 |
+
if global_step >= args.max_train_steps:
|
1676 |
+
break
|
1677 |
+
|
1678 |
+
# Create the pipeline using using the trained modules and save it.
|
1679 |
+
accelerator.wait_for_everyone()
|
1680 |
+
if accelerator.is_main_process:
|
1681 |
+
accelerator.save_state(save_path, safe_serialization=False)
|
1682 |
+
# Run a final round of validation.
|
1683 |
+
# Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
|
1684 |
+
image_logs = None
|
1685 |
+
if args.validation_image is not None:
|
1686 |
+
image_logs = log_validation(
|
1687 |
+
unwrap_model(unet), unwrap_model(aggregator), vae,
|
1688 |
+
text_encoder, text_encoder_2, tokenizer, tokenizer_2,
|
1689 |
+
noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
|
1690 |
+
args, accelerator, weight_dtype, global_step,
|
1691 |
+
)
|
1692 |
+
|
1693 |
+
accelerator.end_training()
|
1694 |
+
|
1695 |
+
|
1696 |
+
if __name__ == "__main__":
|
1697 |
+
args = parse_args()
|
1698 |
+
main(args)
|
train_stage2_aggregator.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stage 2: train aggregator
|
2 |
+
accelerate launch --num_processes <num_of_gpus> train_stage2_aggregator.py \
|
3 |
+
--output_dir <your/output/path> \
|
4 |
+
--train_data_dir <your/data/path> \
|
5 |
+
--logging_dir <your/logging/path> \
|
6 |
+
--pretrained_model_name_or_path <your/sdxl/path> \
|
7 |
+
--feature_extractor_path <your/dinov2/path> \
|
8 |
+
--pretrained_adapter_model_path <your/dcp/path> \
|
9 |
+
--pretrained_lcm_lora_path <your/previewer_lora/path> \
|
10 |
+
--losses_config_path config_files/losses.yaml \
|
11 |
+
--data_config_path config_files/IR_dataset.yaml \
|
12 |
+
--image_drop_rate 0.0 \
|
13 |
+
--text_drop_rate 0.85 \
|
14 |
+
--cond_drop_rate 0.15 \
|
15 |
+
--save_only_adapter \
|
16 |
+
--gradient_checkpointing \
|
17 |
+
--mixed_precision fp16 \
|
18 |
+
--train_batch_size 6 \
|
19 |
+
--gradient_accumulation_steps 2 \
|
20 |
+
--learning_rate 1e-4 \
|
21 |
+
--lr_warmup_steps 1000 \
|
22 |
+
--lr_scheduler cosine \
|
23 |
+
--lr_num_cycles 1 \
|
24 |
+
--resume_from_checkpoint latest
|
utils/degradation_pipeline.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
from torch.utils import data as data
|
7 |
+
|
8 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
9 |
+
from basicsr.data.transforms import augment
|
10 |
+
from basicsr.utils import img2tensor, DiffJPEG, USMSharp
|
11 |
+
from basicsr.utils.img_process_util import filter2D
|
12 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
13 |
+
from basicsr.data.transforms import paired_random_crop
|
14 |
+
|
15 |
+
AUGMENT_OPT = {
|
16 |
+
'use_hflip': False,
|
17 |
+
'use_rot': False
|
18 |
+
}
|
19 |
+
|
20 |
+
KERNEL_OPT = {
|
21 |
+
'blur_kernel_size': 21,
|
22 |
+
'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
|
23 |
+
'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
|
24 |
+
'sinc_prob': 0.1,
|
25 |
+
'blur_sigma': [0.2, 3],
|
26 |
+
'betag_range': [0.5, 4],
|
27 |
+
'betap_range': [1, 2],
|
28 |
+
|
29 |
+
'blur_kernel_size2': 21,
|
30 |
+
'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
|
31 |
+
'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
|
32 |
+
'sinc_prob2': 0.1,
|
33 |
+
'blur_sigma2': [0.2, 1.5],
|
34 |
+
'betag_range2': [0.5, 4],
|
35 |
+
'betap_range2': [1, 2],
|
36 |
+
'final_sinc_prob': 0.8,
|
37 |
+
}
|
38 |
+
|
39 |
+
DEGRADE_OPT = {
|
40 |
+
'resize_prob': [0.2, 0.7, 0.1], # up, down, keep
|
41 |
+
'resize_range': [0.15, 1.5],
|
42 |
+
'gaussian_noise_prob': 0.5,
|
43 |
+
'noise_range': [1, 30],
|
44 |
+
'poisson_scale_range': [0.05, 3],
|
45 |
+
'gray_noise_prob': 0.4,
|
46 |
+
'jpeg_range': [30, 95],
|
47 |
+
|
48 |
+
# the second degradation process
|
49 |
+
'second_blur_prob': 0.8,
|
50 |
+
'resize_prob2': [0.3, 0.4, 0.3], # up, down, keep
|
51 |
+
'resize_range2': [0.3, 1.2],
|
52 |
+
'gaussian_noise_prob2': 0.5,
|
53 |
+
'noise_range2': [1, 25],
|
54 |
+
'poisson_scale_range2': [0.05, 2.5],
|
55 |
+
'gray_noise_prob2': 0.4,
|
56 |
+
'jpeg_range2': [30, 95],
|
57 |
+
|
58 |
+
'gt_size': 512,
|
59 |
+
'no_degradation_prob': 0.01,
|
60 |
+
'use_usm': True,
|
61 |
+
'sf': 4,
|
62 |
+
'random_size': False,
|
63 |
+
'resize_lq': True
|
64 |
+
}
|
65 |
+
|
66 |
+
class RealESRGANDegradation:
|
67 |
+
|
68 |
+
def __init__(self, augment_opt=None, kernel_opt=None, degrade_opt=None, device='cuda', resolution=None):
|
69 |
+
if augment_opt is None:
|
70 |
+
augment_opt = AUGMENT_OPT
|
71 |
+
self.augment_opt = augment_opt
|
72 |
+
if kernel_opt is None:
|
73 |
+
kernel_opt = KERNEL_OPT
|
74 |
+
self.kernel_opt = kernel_opt
|
75 |
+
if degrade_opt is None:
|
76 |
+
degrade_opt = DEGRADE_OPT
|
77 |
+
self.degrade_opt = degrade_opt
|
78 |
+
if resolution is not None:
|
79 |
+
self.degrade_opt['gt_size'] = resolution
|
80 |
+
self.device = device
|
81 |
+
|
82 |
+
self.jpeger = DiffJPEG(differentiable=False).to(self.device)
|
83 |
+
self.usm_sharpener = USMSharp().to(self.device)
|
84 |
+
|
85 |
+
# blur settings for the first degradation
|
86 |
+
self.blur_kernel_size = kernel_opt['blur_kernel_size']
|
87 |
+
self.kernel_list = kernel_opt['kernel_list']
|
88 |
+
self.kernel_prob = kernel_opt['kernel_prob'] # a list for each kernel probability
|
89 |
+
self.blur_sigma = kernel_opt['blur_sigma']
|
90 |
+
self.betag_range = kernel_opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
91 |
+
self.betap_range = kernel_opt['betap_range'] # betap used in plateau blur kernels
|
92 |
+
self.sinc_prob = kernel_opt['sinc_prob'] # the probability for sinc filters
|
93 |
+
|
94 |
+
# blur settings for the second degradation
|
95 |
+
self.blur_kernel_size2 = kernel_opt['blur_kernel_size2']
|
96 |
+
self.kernel_list2 = kernel_opt['kernel_list2']
|
97 |
+
self.kernel_prob2 = kernel_opt['kernel_prob2']
|
98 |
+
self.blur_sigma2 = kernel_opt['blur_sigma2']
|
99 |
+
self.betag_range2 = kernel_opt['betag_range2']
|
100 |
+
self.betap_range2 = kernel_opt['betap_range2']
|
101 |
+
self.sinc_prob2 = kernel_opt['sinc_prob2']
|
102 |
+
|
103 |
+
# a final sinc filter
|
104 |
+
self.final_sinc_prob = kernel_opt['final_sinc_prob']
|
105 |
+
|
106 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
107 |
+
# TODO: kernel range is now hard-coded, should be in the configure file
|
108 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
109 |
+
self.pulse_tensor[10, 10] = 1
|
110 |
+
|
111 |
+
def get_kernel(self):
|
112 |
+
|
113 |
+
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
|
114 |
+
kernel_size = random.choice(self.kernel_range)
|
115 |
+
if np.random.uniform() < self.kernel_opt['sinc_prob']:
|
116 |
+
# this sinc filter setting is for kernels ranging from [7, 21]
|
117 |
+
if kernel_size < 13:
|
118 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
119 |
+
else:
|
120 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
121 |
+
kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
122 |
+
else:
|
123 |
+
kernel = random_mixed_kernels(
|
124 |
+
self.kernel_list,
|
125 |
+
self.kernel_prob,
|
126 |
+
kernel_size,
|
127 |
+
self.blur_sigma,
|
128 |
+
self.blur_sigma, [-math.pi, math.pi],
|
129 |
+
self.betag_range,
|
130 |
+
self.betap_range,
|
131 |
+
noise_range=None)
|
132 |
+
# pad kernel
|
133 |
+
pad_size = (21 - kernel_size) // 2
|
134 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
135 |
+
|
136 |
+
# ------------------------ Generate kernels (used in the second degradation) ------------------------ #
|
137 |
+
kernel_size = random.choice(self.kernel_range)
|
138 |
+
if np.random.uniform() < self.kernel_opt['sinc_prob2']:
|
139 |
+
if kernel_size < 13:
|
140 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
141 |
+
else:
|
142 |
+
omega_c = np.random.uniform(np.pi / 5, np.pi)
|
143 |
+
kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
|
144 |
+
else:
|
145 |
+
kernel2 = random_mixed_kernels(
|
146 |
+
self.kernel_list2,
|
147 |
+
self.kernel_prob2,
|
148 |
+
kernel_size,
|
149 |
+
self.blur_sigma2,
|
150 |
+
self.blur_sigma2, [-math.pi, math.pi],
|
151 |
+
self.betag_range2,
|
152 |
+
self.betap_range2,
|
153 |
+
noise_range=None)
|
154 |
+
|
155 |
+
# pad kernel
|
156 |
+
pad_size = (21 - kernel_size) // 2
|
157 |
+
kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
|
158 |
+
|
159 |
+
# ------------------------------------- the final sinc kernel ------------------------------------- #
|
160 |
+
if np.random.uniform() < self.kernel_opt['final_sinc_prob']:
|
161 |
+
kernel_size = random.choice(self.kernel_range)
|
162 |
+
omega_c = np.random.uniform(np.pi / 3, np.pi)
|
163 |
+
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
|
164 |
+
sinc_kernel = torch.FloatTensor(sinc_kernel)
|
165 |
+
else:
|
166 |
+
sinc_kernel = self.pulse_tensor
|
167 |
+
|
168 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
169 |
+
kernel = torch.FloatTensor(kernel)
|
170 |
+
kernel2 = torch.FloatTensor(kernel2)
|
171 |
+
|
172 |
+
return (kernel, kernel2, sinc_kernel)
|
173 |
+
|
174 |
+
@torch.no_grad()
|
175 |
+
def __call__(self, img_gt, kernels=None):
|
176 |
+
'''
|
177 |
+
:param: img_gt: BCHW, RGB, [0, 1] float32 tensor
|
178 |
+
'''
|
179 |
+
if kernels is None:
|
180 |
+
kernel = []
|
181 |
+
kernel2 = []
|
182 |
+
sinc_kernel = []
|
183 |
+
for _ in range(img_gt.shape[0]):
|
184 |
+
k, k2, sk = self.get_kernel()
|
185 |
+
kernel.append(k)
|
186 |
+
kernel2.append(k2)
|
187 |
+
sinc_kernel.append(sk)
|
188 |
+
kernel = torch.stack(kernel)
|
189 |
+
kernel2 = torch.stack(kernel2)
|
190 |
+
sinc_kernel = torch.stack(sinc_kernel)
|
191 |
+
else:
|
192 |
+
# kernels created in dataset.
|
193 |
+
kernel, kernel2, sinc_kernel = kernels
|
194 |
+
|
195 |
+
# ----------------------- Pre-process ----------------------- #
|
196 |
+
im_gt = img_gt.to(self.device)
|
197 |
+
if self.degrade_opt['use_usm']:
|
198 |
+
im_gt = self.usm_sharpener(im_gt)
|
199 |
+
im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
|
200 |
+
kernel = kernel.to(self.device)
|
201 |
+
kernel2 = kernel2.to(self.device)
|
202 |
+
sinc_kernel = sinc_kernel.to(self.device)
|
203 |
+
ori_h, ori_w = im_gt.size()[2:4]
|
204 |
+
|
205 |
+
# ----------------------- The first degradation process ----------------------- #
|
206 |
+
# blur
|
207 |
+
out = filter2D(im_gt, kernel)
|
208 |
+
# random resize
|
209 |
+
updown_type = random.choices(
|
210 |
+
['up', 'down', 'keep'],
|
211 |
+
self.degrade_opt['resize_prob'],
|
212 |
+
)[0]
|
213 |
+
if updown_type == 'up':
|
214 |
+
scale = random.uniform(1, self.degrade_opt['resize_range'][1])
|
215 |
+
elif updown_type == 'down':
|
216 |
+
scale = random.uniform(self.degrade_opt['resize_range'][0], 1)
|
217 |
+
else:
|
218 |
+
scale = 1
|
219 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
220 |
+
out = torch.nn.functional.interpolate(out, scale_factor=scale, mode=mode)
|
221 |
+
# add noise
|
222 |
+
gray_noise_prob = self.degrade_opt['gray_noise_prob']
|
223 |
+
if random.random() < self.degrade_opt['gaussian_noise_prob']:
|
224 |
+
out = random_add_gaussian_noise_pt(
|
225 |
+
out,
|
226 |
+
sigma_range=self.degrade_opt['noise_range'],
|
227 |
+
clip=True,
|
228 |
+
rounds=False,
|
229 |
+
gray_prob=gray_noise_prob,
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
out = random_add_poisson_noise_pt(
|
233 |
+
out,
|
234 |
+
scale_range=self.degrade_opt['poisson_scale_range'],
|
235 |
+
gray_prob=gray_noise_prob,
|
236 |
+
clip=True,
|
237 |
+
rounds=False)
|
238 |
+
# JPEG compression
|
239 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range'])
|
240 |
+
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
241 |
+
out = self.jpeger(out, quality=jpeg_p)
|
242 |
+
|
243 |
+
# ----------------------- The second degradation process ----------------------- #
|
244 |
+
# blur
|
245 |
+
if random.random() < self.degrade_opt['second_blur_prob']:
|
246 |
+
out = out.contiguous()
|
247 |
+
out = filter2D(out, kernel2)
|
248 |
+
# random resize
|
249 |
+
updown_type = random.choices(
|
250 |
+
['up', 'down', 'keep'],
|
251 |
+
self.degrade_opt['resize_prob2'],
|
252 |
+
)[0]
|
253 |
+
if updown_type == 'up':
|
254 |
+
scale = random.uniform(1, self.degrade_opt['resize_range2'][1])
|
255 |
+
elif updown_type == 'down':
|
256 |
+
scale = random.uniform(self.degrade_opt['resize_range2'][0], 1)
|
257 |
+
else:
|
258 |
+
scale = 1
|
259 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
260 |
+
out = torch.nn.functional.interpolate(
|
261 |
+
out,
|
262 |
+
size=(int(ori_h / self.degrade_opt['sf'] * scale),
|
263 |
+
int(ori_w / self.degrade_opt['sf'] * scale)),
|
264 |
+
mode=mode,
|
265 |
+
)
|
266 |
+
# add noise
|
267 |
+
gray_noise_prob = self.degrade_opt['gray_noise_prob2']
|
268 |
+
if random.random() < self.degrade_opt['gaussian_noise_prob2']:
|
269 |
+
out = random_add_gaussian_noise_pt(
|
270 |
+
out,
|
271 |
+
sigma_range=self.degrade_opt['noise_range2'],
|
272 |
+
clip=True,
|
273 |
+
rounds=False,
|
274 |
+
gray_prob=gray_noise_prob,
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
out = random_add_poisson_noise_pt(
|
278 |
+
out,
|
279 |
+
scale_range=self.degrade_opt['poisson_scale_range2'],
|
280 |
+
gray_prob=gray_noise_prob,
|
281 |
+
clip=True,
|
282 |
+
rounds=False,
|
283 |
+
)
|
284 |
+
|
285 |
+
# JPEG compression + the final sinc filter
|
286 |
+
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
287 |
+
# as one operation.
|
288 |
+
# We consider two orders:
|
289 |
+
# 1. [resize back + sinc filter] + JPEG compression
|
290 |
+
# 2. JPEG compression + [resize back + sinc filter]
|
291 |
+
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
292 |
+
if random.random() < 0.5:
|
293 |
+
# resize back + the final sinc filter
|
294 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
295 |
+
out = torch.nn.functional.interpolate(
|
296 |
+
out,
|
297 |
+
size=(ori_h // self.degrade_opt['sf'],
|
298 |
+
ori_w // self.degrade_opt['sf']),
|
299 |
+
mode=mode,
|
300 |
+
)
|
301 |
+
out = out.contiguous()
|
302 |
+
out = filter2D(out, sinc_kernel)
|
303 |
+
# JPEG compression
|
304 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
|
305 |
+
out = torch.clamp(out, 0, 1)
|
306 |
+
out = self.jpeger(out, quality=jpeg_p)
|
307 |
+
else:
|
308 |
+
# JPEG compression
|
309 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
|
310 |
+
out = torch.clamp(out, 0, 1)
|
311 |
+
out = self.jpeger(out, quality=jpeg_p)
|
312 |
+
# resize back + the final sinc filter
|
313 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
314 |
+
out = torch.nn.functional.interpolate(
|
315 |
+
out,
|
316 |
+
size=(ori_h // self.degrade_opt['sf'],
|
317 |
+
ori_w // self.degrade_opt['sf']),
|
318 |
+
mode=mode,
|
319 |
+
)
|
320 |
+
out = out.contiguous()
|
321 |
+
out = filter2D(out, sinc_kernel)
|
322 |
+
|
323 |
+
# clamp and round
|
324 |
+
im_lq = torch.clamp(out, 0, 1.0)
|
325 |
+
|
326 |
+
# random crop
|
327 |
+
gt_size = self.degrade_opt['gt_size']
|
328 |
+
im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.degrade_opt['sf'])
|
329 |
+
|
330 |
+
if self.degrade_opt['resize_lq']:
|
331 |
+
im_lq = torch.nn.functional.interpolate(
|
332 |
+
im_lq,
|
333 |
+
size=(im_gt.size(-2),
|
334 |
+
im_gt.size(-1)),
|
335 |
+
mode='bicubic',
|
336 |
+
)
|
337 |
+
|
338 |
+
if random.random() < self.degrade_opt['no_degradation_prob'] or torch.isnan(im_lq).any():
|
339 |
+
im_lq = im_gt
|
340 |
+
|
341 |
+
# sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
|
342 |
+
im_lq = im_lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
343 |
+
im_lq = im_lq*2 - 1.0
|
344 |
+
im_gt = im_gt*2 - 1.0
|
345 |
+
|
346 |
+
if self.degrade_opt['random_size']:
|
347 |
+
raise NotImplementedError
|
348 |
+
im_lq, im_gt = self.randn_cropinput(im_lq, im_gt)
|
349 |
+
|
350 |
+
im_lq = torch.clamp(im_lq, -1.0, 1.0)
|
351 |
+
im_gt = torch.clamp(im_gt, -1.0, 1.0)
|
352 |
+
|
353 |
+
return (im_lq, im_gt)
|
utils/matlab_cp2tform.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Jul 11 06:54:28 2017
|
4 |
+
|
5 |
+
@author: zhaoyafei
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from numpy.linalg import inv, norm, lstsq
|
10 |
+
from numpy.linalg import matrix_rank as rank
|
11 |
+
|
12 |
+
class MatlabCp2tormException(Exception):
|
13 |
+
def __str__(self):
|
14 |
+
return 'In File {}:{}'.format(
|
15 |
+
__file__, super.__str__(self))
|
16 |
+
|
17 |
+
def tformfwd(trans, uv):
|
18 |
+
"""
|
19 |
+
Function:
|
20 |
+
----------
|
21 |
+
apply affine transform 'trans' to uv
|
22 |
+
|
23 |
+
Parameters:
|
24 |
+
----------
|
25 |
+
@trans: 3x3 np.array
|
26 |
+
transform matrix
|
27 |
+
@uv: Kx2 np.array
|
28 |
+
each row is a pair of coordinates (x, y)
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
----------
|
32 |
+
@xy: Kx2 np.array
|
33 |
+
each row is a pair of transformed coordinates (x, y)
|
34 |
+
"""
|
35 |
+
uv = np.hstack((
|
36 |
+
uv, np.ones((uv.shape[0], 1))
|
37 |
+
))
|
38 |
+
xy = np.dot(uv, trans)
|
39 |
+
xy = xy[:, 0:-1]
|
40 |
+
return xy
|
41 |
+
|
42 |
+
|
43 |
+
def tforminv(trans, uv):
|
44 |
+
"""
|
45 |
+
Function:
|
46 |
+
----------
|
47 |
+
apply the inverse of affine transform 'trans' to uv
|
48 |
+
|
49 |
+
Parameters:
|
50 |
+
----------
|
51 |
+
@trans: 3x3 np.array
|
52 |
+
transform matrix
|
53 |
+
@uv: Kx2 np.array
|
54 |
+
each row is a pair of coordinates (x, y)
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
----------
|
58 |
+
@xy: Kx2 np.array
|
59 |
+
each row is a pair of inverse-transformed coordinates (x, y)
|
60 |
+
"""
|
61 |
+
Tinv = inv(trans)
|
62 |
+
xy = tformfwd(Tinv, uv)
|
63 |
+
return xy
|
64 |
+
|
65 |
+
|
66 |
+
def findNonreflectiveSimilarity(uv, xy, options=None):
|
67 |
+
|
68 |
+
options = {'K': 2}
|
69 |
+
|
70 |
+
K = options['K']
|
71 |
+
M = xy.shape[0]
|
72 |
+
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
73 |
+
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
74 |
+
# print('--->x, y:\n', x, y
|
75 |
+
|
76 |
+
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
|
77 |
+
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
|
78 |
+
X = np.vstack((tmp1, tmp2))
|
79 |
+
# print('--->X.shape: ', X.shape
|
80 |
+
# print('X:\n', X
|
81 |
+
|
82 |
+
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
83 |
+
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
84 |
+
U = np.vstack((u, v))
|
85 |
+
# print('--->U.shape: ', U.shape
|
86 |
+
# print('U:\n', U
|
87 |
+
|
88 |
+
# We know that X * r = U
|
89 |
+
if rank(X) >= 2 * K:
|
90 |
+
r, _, _, _ = lstsq(X, U)
|
91 |
+
r = np.squeeze(r)
|
92 |
+
else:
|
93 |
+
raise Exception('cp2tform:twoUniquePointsReq')
|
94 |
+
|
95 |
+
# print('--->r:\n', r
|
96 |
+
|
97 |
+
sc = r[0]
|
98 |
+
ss = r[1]
|
99 |
+
tx = r[2]
|
100 |
+
ty = r[3]
|
101 |
+
|
102 |
+
Tinv = np.array([
|
103 |
+
[sc, -ss, 0],
|
104 |
+
[ss, sc, 0],
|
105 |
+
[tx, ty, 1]
|
106 |
+
])
|
107 |
+
|
108 |
+
# print('--->Tinv:\n', Tinv
|
109 |
+
|
110 |
+
T = inv(Tinv)
|
111 |
+
# print('--->T:\n', T
|
112 |
+
|
113 |
+
T[:, 2] = np.array([0, 0, 1])
|
114 |
+
|
115 |
+
return T, Tinv
|
116 |
+
|
117 |
+
|
118 |
+
def findSimilarity(uv, xy, options=None):
|
119 |
+
|
120 |
+
options = {'K': 2}
|
121 |
+
|
122 |
+
# uv = np.array(uv)
|
123 |
+
# xy = np.array(xy)
|
124 |
+
|
125 |
+
# Solve for trans1
|
126 |
+
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
|
127 |
+
|
128 |
+
# Solve for trans2
|
129 |
+
|
130 |
+
# manually reflect the xy data across the Y-axis
|
131 |
+
xyR = xy
|
132 |
+
xyR[:, 0] = -1 * xyR[:, 0]
|
133 |
+
|
134 |
+
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
|
135 |
+
|
136 |
+
# manually reflect the tform to undo the reflection done on xyR
|
137 |
+
TreflectY = np.array([
|
138 |
+
[-1, 0, 0],
|
139 |
+
[0, 1, 0],
|
140 |
+
[0, 0, 1]
|
141 |
+
])
|
142 |
+
|
143 |
+
trans2 = np.dot(trans2r, TreflectY)
|
144 |
+
|
145 |
+
# Figure out if trans1 or trans2 is better
|
146 |
+
xy1 = tformfwd(trans1, uv)
|
147 |
+
norm1 = norm(xy1 - xy)
|
148 |
+
|
149 |
+
xy2 = tformfwd(trans2, uv)
|
150 |
+
norm2 = norm(xy2 - xy)
|
151 |
+
|
152 |
+
if norm1 <= norm2:
|
153 |
+
return trans1, trans1_inv
|
154 |
+
else:
|
155 |
+
trans2_inv = inv(trans2)
|
156 |
+
return trans2, trans2_inv
|
157 |
+
|
158 |
+
|
159 |
+
def get_similarity_transform(src_pts, dst_pts, reflective=True):
|
160 |
+
"""
|
161 |
+
Function:
|
162 |
+
----------
|
163 |
+
Find Similarity Transform Matrix 'trans':
|
164 |
+
u = src_pts[:, 0]
|
165 |
+
v = src_pts[:, 1]
|
166 |
+
x = dst_pts[:, 0]
|
167 |
+
y = dst_pts[:, 1]
|
168 |
+
[x, y, 1] = [u, v, 1] * trans
|
169 |
+
|
170 |
+
Parameters:
|
171 |
+
----------
|
172 |
+
@src_pts: Kx2 np.array
|
173 |
+
source points, each row is a pair of coordinates (x, y)
|
174 |
+
@dst_pts: Kx2 np.array
|
175 |
+
destination points, each row is a pair of transformed
|
176 |
+
coordinates (x, y)
|
177 |
+
@reflective: True or False
|
178 |
+
if True:
|
179 |
+
use reflective similarity transform
|
180 |
+
else:
|
181 |
+
use non-reflective similarity transform
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
----------
|
185 |
+
@trans: 3x3 np.array
|
186 |
+
transform matrix from uv to xy
|
187 |
+
trans_inv: 3x3 np.array
|
188 |
+
inverse of trans, transform matrix from xy to uv
|
189 |
+
"""
|
190 |
+
|
191 |
+
if reflective:
|
192 |
+
trans, trans_inv = findSimilarity(src_pts, dst_pts)
|
193 |
+
else:
|
194 |
+
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
|
195 |
+
|
196 |
+
return trans, trans_inv
|
197 |
+
|
198 |
+
|
199 |
+
def cvt_tform_mat_for_cv2(trans):
|
200 |
+
"""
|
201 |
+
Function:
|
202 |
+
----------
|
203 |
+
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
|
204 |
+
directly used by cv2.warpAffine():
|
205 |
+
u = src_pts[:, 0]
|
206 |
+
v = src_pts[:, 1]
|
207 |
+
x = dst_pts[:, 0]
|
208 |
+
y = dst_pts[:, 1]
|
209 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
----------
|
213 |
+
@trans: 3x3 np.array
|
214 |
+
transform matrix from uv to xy
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
----------
|
218 |
+
@cv2_trans: 2x3 np.array
|
219 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
220 |
+
for cv2.warpAffine()
|
221 |
+
"""
|
222 |
+
cv2_trans = trans[:, 0:2].T
|
223 |
+
|
224 |
+
return cv2_trans
|
225 |
+
|
226 |
+
|
227 |
+
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
|
228 |
+
"""
|
229 |
+
Function:
|
230 |
+
----------
|
231 |
+
Find Similarity Transform Matrix 'cv2_trans' which could be
|
232 |
+
directly used by cv2.warpAffine():
|
233 |
+
u = src_pts[:, 0]
|
234 |
+
v = src_pts[:, 1]
|
235 |
+
x = dst_pts[:, 0]
|
236 |
+
y = dst_pts[:, 1]
|
237 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
238 |
+
|
239 |
+
Parameters:
|
240 |
+
----------
|
241 |
+
@src_pts: Kx2 np.array
|
242 |
+
source points, each row is a pair of coordinates (x, y)
|
243 |
+
@dst_pts: Kx2 np.array
|
244 |
+
destination points, each row is a pair of transformed
|
245 |
+
coordinates (x, y)
|
246 |
+
reflective: True or False
|
247 |
+
if True:
|
248 |
+
use reflective similarity transform
|
249 |
+
else:
|
250 |
+
use non-reflective similarity transform
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
----------
|
254 |
+
@cv2_trans: 2x3 np.array
|
255 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
256 |
+
for cv2.warpAffine()
|
257 |
+
"""
|
258 |
+
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
|
259 |
+
cv2_trans = cvt_tform_mat_for_cv2(trans)
|
260 |
+
|
261 |
+
return cv2_trans
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == '__main__':
|
265 |
+
"""
|
266 |
+
u = [0, 6, -2]
|
267 |
+
v = [0, 3, 5]
|
268 |
+
x = [-1, 0, 4]
|
269 |
+
y = [-1, -10, 4]
|
270 |
+
|
271 |
+
# In Matlab, run:
|
272 |
+
#
|
273 |
+
# uv = [u'; v'];
|
274 |
+
# xy = [x'; y'];
|
275 |
+
# tform_sim=cp2tform(uv,xy,'similarity');
|
276 |
+
#
|
277 |
+
# trans = tform_sim.tdata.T
|
278 |
+
# ans =
|
279 |
+
# -0.0764 -1.6190 0
|
280 |
+
# 1.6190 -0.0764 0
|
281 |
+
# -3.2156 0.0290 1.0000
|
282 |
+
# trans_inv = tform_sim.tdata.Tinv
|
283 |
+
# ans =
|
284 |
+
#
|
285 |
+
# -0.0291 0.6163 0
|
286 |
+
# -0.6163 -0.0291 0
|
287 |
+
# -0.0756 1.9826 1.0000
|
288 |
+
# xy_m=tformfwd(tform_sim, u,v)
|
289 |
+
#
|
290 |
+
# xy_m =
|
291 |
+
#
|
292 |
+
# -3.2156 0.0290
|
293 |
+
# 1.1833 -9.9143
|
294 |
+
# 5.0323 2.8853
|
295 |
+
# uv_m=tforminv(tform_sim, x,y)
|
296 |
+
#
|
297 |
+
# uv_m =
|
298 |
+
#
|
299 |
+
# 0.5698 1.3953
|
300 |
+
# 6.0872 2.2733
|
301 |
+
# -2.6570 4.3314
|
302 |
+
"""
|
303 |
+
u = [0, 6, -2]
|
304 |
+
v = [0, 3, 5]
|
305 |
+
x = [-1, 0, 4]
|
306 |
+
y = [-1, -10, 4]
|
307 |
+
|
308 |
+
uv = np.array((u, v)).T
|
309 |
+
xy = np.array((x, y)).T
|
310 |
+
|
311 |
+
print('\n--->uv:')
|
312 |
+
print(uv)
|
313 |
+
print('\n--->xy:')
|
314 |
+
print(xy)
|
315 |
+
|
316 |
+
trans, trans_inv = get_similarity_transform(uv, xy)
|
317 |
+
|
318 |
+
print('\n--->trans matrix:')
|
319 |
+
print(trans)
|
320 |
+
|
321 |
+
print('\n--->trans_inv matrix:')
|
322 |
+
print(trans_inv)
|
323 |
+
|
324 |
+
print('\n---> apply transform to uv')
|
325 |
+
print('\nxy_m = uv_augmented * trans')
|
326 |
+
uv_aug = np.hstack((
|
327 |
+
uv, np.ones((uv.shape[0], 1))
|
328 |
+
))
|
329 |
+
xy_m = np.dot(uv_aug, trans)
|
330 |
+
print(xy_m)
|
331 |
+
|
332 |
+
print('\nxy_m = tformfwd(trans, uv)')
|
333 |
+
xy_m = tformfwd(trans, uv)
|
334 |
+
print(xy_m)
|
335 |
+
|
336 |
+
print('\n---> apply inverse transform to xy')
|
337 |
+
print('\nuv_m = xy_augmented * trans_inv')
|
338 |
+
xy_aug = np.hstack((
|
339 |
+
xy, np.ones((xy.shape[0], 1))
|
340 |
+
))
|
341 |
+
uv_m = np.dot(xy_aug, trans_inv)
|
342 |
+
print(uv_m)
|
343 |
+
|
344 |
+
print('\nuv_m = tformfwd(trans_inv, xy)')
|
345 |
+
uv_m = tformfwd(trans_inv, xy)
|
346 |
+
print(uv_m)
|
347 |
+
|
348 |
+
uv_m = tforminv(trans, xy)
|
349 |
+
print('\nuv_m = tforminv(trans, xy)')
|
350 |
+
print(uv_m)
|
utils/parser.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
def parse_args(input_args=None):
|
5 |
+
parser = argparse.ArgumentParser(description="Train Consistency Encoder.")
|
6 |
+
parser.add_argument(
|
7 |
+
"--pretrained_model_name_or_path",
|
8 |
+
type=str,
|
9 |
+
default=None,
|
10 |
+
required=True,
|
11 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
12 |
+
)
|
13 |
+
parser.add_argument(
|
14 |
+
"--pretrained_vae_model_name_or_path",
|
15 |
+
type=str,
|
16 |
+
default=None,
|
17 |
+
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--revision",
|
21 |
+
type=str,
|
22 |
+
default=None,
|
23 |
+
required=False,
|
24 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--variant",
|
28 |
+
type=str,
|
29 |
+
default=None,
|
30 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
31 |
+
)
|
32 |
+
|
33 |
+
# parser.add_argument(
|
34 |
+
# "--instance_data_dir",
|
35 |
+
# type=str,
|
36 |
+
# required=True,
|
37 |
+
# help=("A folder containing the training data. "),
|
38 |
+
# )
|
39 |
+
|
40 |
+
parser.add_argument(
|
41 |
+
"--data_config_path",
|
42 |
+
type=str,
|
43 |
+
required=True,
|
44 |
+
help=("A folder containing the training data. "),
|
45 |
+
)
|
46 |
+
|
47 |
+
parser.add_argument(
|
48 |
+
"--cache_dir",
|
49 |
+
type=str,
|
50 |
+
default=None,
|
51 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
52 |
+
)
|
53 |
+
|
54 |
+
parser.add_argument(
|
55 |
+
"--image_column",
|
56 |
+
type=str,
|
57 |
+
default="image",
|
58 |
+
help="The column of the dataset containing the target image. By "
|
59 |
+
"default, the standard Image Dataset maps out 'file_name' "
|
60 |
+
"to 'image'.",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--caption_column",
|
64 |
+
type=str,
|
65 |
+
default=None,
|
66 |
+
help="The column of the dataset containing the instance prompt for each image",
|
67 |
+
)
|
68 |
+
|
69 |
+
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
|
70 |
+
|
71 |
+
parser.add_argument(
|
72 |
+
"--instance_prompt",
|
73 |
+
type=str,
|
74 |
+
default=None,
|
75 |
+
required=True,
|
76 |
+
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
77 |
+
)
|
78 |
+
|
79 |
+
parser.add_argument(
|
80 |
+
"--validation_prompt",
|
81 |
+
type=str,
|
82 |
+
default=None,
|
83 |
+
help="A prompt that is used during validation to verify that the model is learning.",
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--num_train_vis_images",
|
87 |
+
type=int,
|
88 |
+
default=2,
|
89 |
+
help="Number of images that should be generated during validation with `validation_prompt`.",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--num_validation_images",
|
93 |
+
type=int,
|
94 |
+
default=2,
|
95 |
+
help="Number of images that should be generated during validation with `validation_prompt`.",
|
96 |
+
)
|
97 |
+
|
98 |
+
parser.add_argument(
|
99 |
+
"--validation_vis_steps",
|
100 |
+
type=int,
|
101 |
+
default=500,
|
102 |
+
help=(
|
103 |
+
"Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
|
104 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
105 |
+
),
|
106 |
+
)
|
107 |
+
|
108 |
+
parser.add_argument(
|
109 |
+
"--train_vis_steps",
|
110 |
+
type=int,
|
111 |
+
default=500,
|
112 |
+
help=(
|
113 |
+
"Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
|
114 |
+
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
115 |
+
),
|
116 |
+
)
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--vis_lcm",
|
120 |
+
type=bool,
|
121 |
+
default=True,
|
122 |
+
help=(
|
123 |
+
"Also log results of LCM inference",
|
124 |
+
),
|
125 |
+
)
|
126 |
+
|
127 |
+
parser.add_argument(
|
128 |
+
"--output_dir",
|
129 |
+
type=str,
|
130 |
+
default="lora-dreambooth-model",
|
131 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
132 |
+
)
|
133 |
+
|
134 |
+
parser.add_argument("--save_only_encoder", action="store_true", help="Only save the encoder and not the full accelerator state")
|
135 |
+
|
136 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
137 |
+
|
138 |
+
parser.add_argument("--freeze_encoder_unet", action="store_true", help="Don't train encoder unet")
|
139 |
+
parser.add_argument("--predict_word_embedding", action="store_true", help="Predict word embeddings in addition to KV features")
|
140 |
+
parser.add_argument("--ip_adapter_feature_extractor_path", type=str, help="Path to pre-trained feature extractor for IP-adapter")
|
141 |
+
parser.add_argument("--ip_adapter_model_path", type=str, help="Path to pre-trained IP-adapter.")
|
142 |
+
parser.add_argument("--ip_adapter_tokens", type=int, default=16, help="Number of tokens to use in IP-adapter cross attention mechanism")
|
143 |
+
parser.add_argument("--optimize_adapter", action="store_true", help="Optimize IP-adapter parameters (projector + cross-attention layers)")
|
144 |
+
parser.add_argument("--adapter_attention_scale", type=float, default=1.0, help="Relative strength of the adapter cross attention layers")
|
145 |
+
parser.add_argument("--adapter_lr", type=float, help="Learning rate for the adapter parameters. Defaults to the global LR if not provided")
|
146 |
+
|
147 |
+
parser.add_argument("--noisy_encoder_input", action="store_true", help="Noise the encoder input to the same step as the decoder?")
|
148 |
+
|
149 |
+
# related to CFG:
|
150 |
+
parser.add_argument("--adapter_drop_chance", type=float, default=0.0, help="Chance to drop adapter condition input during training")
|
151 |
+
parser.add_argument("--text_drop_chance", type=float, default=0.0, help="Chance to drop text condition during training")
|
152 |
+
parser.add_argument("--kv_drop_chance", type=float, default=0.0, help="Chance to drop KV condition during training")
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
parser.add_argument(
|
157 |
+
"--resolution",
|
158 |
+
type=int,
|
159 |
+
default=1024,
|
160 |
+
help=(
|
161 |
+
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
162 |
+
" resolution"
|
163 |
+
),
|
164 |
+
)
|
165 |
+
|
166 |
+
parser.add_argument(
|
167 |
+
"--crops_coords_top_left_h",
|
168 |
+
type=int,
|
169 |
+
default=0,
|
170 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
171 |
+
)
|
172 |
+
|
173 |
+
parser.add_argument(
|
174 |
+
"--crops_coords_top_left_w",
|
175 |
+
type=int,
|
176 |
+
default=0,
|
177 |
+
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
178 |
+
)
|
179 |
+
|
180 |
+
parser.add_argument(
|
181 |
+
"--center_crop",
|
182 |
+
default=False,
|
183 |
+
action="store_true",
|
184 |
+
help=(
|
185 |
+
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
186 |
+
" cropped. The images will be resized to the resolution first before cropping."
|
187 |
+
),
|
188 |
+
)
|
189 |
+
|
190 |
+
parser.add_argument(
|
191 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
192 |
+
)
|
193 |
+
|
194 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
195 |
+
|
196 |
+
parser.add_argument(
|
197 |
+
"--max_train_steps",
|
198 |
+
type=int,
|
199 |
+
default=None,
|
200 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
201 |
+
)
|
202 |
+
|
203 |
+
parser.add_argument(
|
204 |
+
"--checkpointing_steps",
|
205 |
+
type=int,
|
206 |
+
default=500,
|
207 |
+
help=(
|
208 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
209 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
210 |
+
" training using `--resume_from_checkpoint`."
|
211 |
+
),
|
212 |
+
)
|
213 |
+
|
214 |
+
parser.add_argument(
|
215 |
+
"--checkpoints_total_limit",
|
216 |
+
type=int,
|
217 |
+
default=5,
|
218 |
+
help=("Max number of checkpoints to store."),
|
219 |
+
)
|
220 |
+
|
221 |
+
parser.add_argument(
|
222 |
+
"--resume_from_checkpoint",
|
223 |
+
type=str,
|
224 |
+
default=None,
|
225 |
+
help=(
|
226 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
227 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
228 |
+
),
|
229 |
+
)
|
230 |
+
|
231 |
+
parser.add_argument("--max_timesteps_for_x0_loss", type=int, default=1001)
|
232 |
+
|
233 |
+
parser.add_argument(
|
234 |
+
"--gradient_accumulation_steps",
|
235 |
+
type=int,
|
236 |
+
default=1,
|
237 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
238 |
+
)
|
239 |
+
|
240 |
+
parser.add_argument(
|
241 |
+
"--gradient_checkpointing",
|
242 |
+
action="store_true",
|
243 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
244 |
+
)
|
245 |
+
|
246 |
+
parser.add_argument(
|
247 |
+
"--learning_rate",
|
248 |
+
type=float,
|
249 |
+
default=1e-4,
|
250 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
251 |
+
)
|
252 |
+
|
253 |
+
parser.add_argument(
|
254 |
+
"--scale_lr",
|
255 |
+
action="store_true",
|
256 |
+
default=False,
|
257 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
258 |
+
)
|
259 |
+
|
260 |
+
parser.add_argument(
|
261 |
+
"--lr_scheduler",
|
262 |
+
type=str,
|
263 |
+
default="constant",
|
264 |
+
help=(
|
265 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
266 |
+
' "constant", "constant_with_warmup"]'
|
267 |
+
),
|
268 |
+
)
|
269 |
+
|
270 |
+
parser.add_argument(
|
271 |
+
"--snr_gamma",
|
272 |
+
type=float,
|
273 |
+
default=None,
|
274 |
+
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
275 |
+
"More details here: https://arxiv.org/abs/2303.09556.",
|
276 |
+
)
|
277 |
+
|
278 |
+
parser.add_argument(
|
279 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
280 |
+
)
|
281 |
+
|
282 |
+
parser.add_argument(
|
283 |
+
"--lr_num_cycles",
|
284 |
+
type=int,
|
285 |
+
default=1,
|
286 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
287 |
+
)
|
288 |
+
|
289 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
290 |
+
|
291 |
+
parser.add_argument(
|
292 |
+
"--dataloader_num_workers",
|
293 |
+
type=int,
|
294 |
+
default=0,
|
295 |
+
help=(
|
296 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
297 |
+
),
|
298 |
+
)
|
299 |
+
|
300 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
301 |
+
|
302 |
+
parser.add_argument(
|
303 |
+
"--adam_epsilon",
|
304 |
+
type=float,
|
305 |
+
default=1e-08,
|
306 |
+
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
|
307 |
+
)
|
308 |
+
|
309 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
310 |
+
|
311 |
+
parser.add_argument(
|
312 |
+
"--logging_dir",
|
313 |
+
type=str,
|
314 |
+
default="logs",
|
315 |
+
help=(
|
316 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
317 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
318 |
+
),
|
319 |
+
)
|
320 |
+
parser.add_argument(
|
321 |
+
"--allow_tf32",
|
322 |
+
action="store_true",
|
323 |
+
help=(
|
324 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
325 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
326 |
+
),
|
327 |
+
)
|
328 |
+
|
329 |
+
parser.add_argument(
|
330 |
+
"--report_to",
|
331 |
+
type=str,
|
332 |
+
default="wandb",
|
333 |
+
help=(
|
334 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
335 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
336 |
+
),
|
337 |
+
)
|
338 |
+
|
339 |
+
parser.add_argument(
|
340 |
+
"--mixed_precision",
|
341 |
+
type=str,
|
342 |
+
default=None,
|
343 |
+
choices=["no", "fp16", "bf16"],
|
344 |
+
help=(
|
345 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
346 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
347 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
348 |
+
),
|
349 |
+
)
|
350 |
+
|
351 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
352 |
+
|
353 |
+
parser.add_argument(
|
354 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
355 |
+
)
|
356 |
+
|
357 |
+
parser.add_argument(
|
358 |
+
"--rank",
|
359 |
+
type=int,
|
360 |
+
default=4,
|
361 |
+
help=("The dimension of the LoRA update matrices."),
|
362 |
+
)
|
363 |
+
|
364 |
+
parser.add_argument(
|
365 |
+
"--pretrained_lcm_lora_path",
|
366 |
+
type=str,
|
367 |
+
default="latent-consistency/lcm-lora-sdxl",
|
368 |
+
help=("Path for lcm lora pretrained"),
|
369 |
+
)
|
370 |
+
|
371 |
+
parser.add_argument(
|
372 |
+
"--losses_config_path",
|
373 |
+
type=str,
|
374 |
+
required=True,
|
375 |
+
help=("A yaml file containing losses to use and their weights."),
|
376 |
+
)
|
377 |
+
|
378 |
+
parser.add_argument(
|
379 |
+
"--lcm_every_k_steps",
|
380 |
+
type=int,
|
381 |
+
default=-1,
|
382 |
+
help="How often to run lcm. If -1, lcm is not run."
|
383 |
+
)
|
384 |
+
|
385 |
+
parser.add_argument(
|
386 |
+
"--lcm_batch_size",
|
387 |
+
type=int,
|
388 |
+
default=1,
|
389 |
+
help="Batch size for lcm."
|
390 |
+
)
|
391 |
+
parser.add_argument(
|
392 |
+
"--lcm_max_timestep",
|
393 |
+
type=int,
|
394 |
+
default=1000,
|
395 |
+
help="Max timestep to use with LCM."
|
396 |
+
)
|
397 |
+
|
398 |
+
parser.add_argument(
|
399 |
+
"--lcm_sample_scale_every_k_steps",
|
400 |
+
type=int,
|
401 |
+
default=-1,
|
402 |
+
help="How often to change lcm scale. If -1, scale is fixed at 1."
|
403 |
+
)
|
404 |
+
|
405 |
+
parser.add_argument(
|
406 |
+
"--lcm_min_scale",
|
407 |
+
type=float,
|
408 |
+
default=0.1,
|
409 |
+
help="When sampling lcm scale, the minimum scale to use."
|
410 |
+
)
|
411 |
+
|
412 |
+
parser.add_argument(
|
413 |
+
"--scale_lcm_by_max_step",
|
414 |
+
action="store_true",
|
415 |
+
help="scale LCM lora alpha linearly by the maximal timestep sampled that iteration"
|
416 |
+
)
|
417 |
+
|
418 |
+
parser.add_argument(
|
419 |
+
"--lcm_sample_full_lcm_prob",
|
420 |
+
type=float,
|
421 |
+
default=0.2,
|
422 |
+
help="When sampling lcm scale, the probability of using full lcm (scale of 1)."
|
423 |
+
)
|
424 |
+
|
425 |
+
parser.add_argument(
|
426 |
+
"--run_on_cpu",
|
427 |
+
action="store_true",
|
428 |
+
help="whether to run on cpu or not"
|
429 |
+
)
|
430 |
+
|
431 |
+
parser.add_argument(
|
432 |
+
"--experiment_name",
|
433 |
+
type=str,
|
434 |
+
help=("A short description of the experiment to add to the wand run log. "),
|
435 |
+
)
|
436 |
+
parser.add_argument("--encoder_lora_rank", type=int, default=0, help="Rank of Lora in unet encoder. 0 means no lora")
|
437 |
+
|
438 |
+
parser.add_argument("--kvcopy_lora_rank", type=int, default=0, help="Rank of lora in the kvcopy modules. 0 means no lora")
|
439 |
+
|
440 |
+
|
441 |
+
if input_args is not None:
|
442 |
+
args = parser.parse_args(input_args)
|
443 |
+
else:
|
444 |
+
args = parser.parse_args()
|
445 |
+
|
446 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
447 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
448 |
+
args.local_rank = env_local_rank
|
449 |
+
|
450 |
+
args.optimizer = "AdamW"
|
451 |
+
|
452 |
+
return args
|