Spaces:
Sleeping
Sleeping
Upload 23 files
Browse files- LICENSE +201 -0
- README.md +302 -9
- app.py +97 -0
- gradio_app/app_utils.py +30 -0
- gradio_app/image_synthesis_app.py +166 -0
- gradio_app/images/corgi.jpg +0 -0
- gradio_app/images/person.png +0 -0
- gradio_app/real_image_editing_app.py +162 -0
- masactrl/__init__.py +0 -0
- masactrl/diffuser_utils.py +275 -0
- masactrl/masactrl.py +334 -0
- masactrl/masactrl_processor.py +259 -0
- masactrl/masactrl_utils.py +212 -0
- masactrl_w_adapter/ddim.py +375 -0
- masactrl_w_adapter/masactrl_w_adapter.py +217 -0
- playground.ipynb +149 -0
- playground_real.ipynb +188 -0
- requirements.txt +13 -0
- run_synthesis_genshin_impact_xl.py +117 -0
- run_synthesis_genshin_impact_xl_app.py +97 -0
- run_synthesis_sdxl.py +83 -0
- run_synthesis_sdxl_processor.py +90 -0
- style.css +3 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MotionCtrl and MasaCtrl: Genshin Impact Character Synthesis
|
2 |
+
|
3 |
+
This repository provides a guide to setting up and running the MotionCtrl and MasaCtrl projects for synthesizing Genshin Impact characters using machine learning models. The process involves cloning repositories, installing dependencies, and executing scripts to generate and visualize character images.
|
4 |
+
|
5 |
+
## Table of Contents
|
6 |
+
- [Prerequisites](#prerequisites)
|
7 |
+
- [Installation](#installation)
|
8 |
+
- [Running the Synthesis](#running-the-synthesis)
|
9 |
+
- [Using Gradio Interface](#using-gradio-interface)
|
10 |
+
- [Example Prompts](#example-prompts)
|
11 |
+
|
12 |
+
## Prerequisites
|
13 |
+
|
14 |
+
Before you begin, ensure you have the following installed:
|
15 |
+
- Python 3.10
|
16 |
+
- Conda (for environment management)
|
17 |
+
- Git
|
18 |
+
- Git LFS (Large File Storage)
|
19 |
+
- FFmpeg
|
20 |
+
|
21 |
+
## Installation
|
22 |
+
|
23 |
+
### Step 1: Clone the MotionCtrl Repository
|
24 |
+
Clone the MotionCtrl repository and install its dependencies:
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/spaces/svjack/MotionCtrl
|
28 |
+
cd MotionCtrl
|
29 |
+
pip install -r requirements.txt
|
30 |
+
```
|
31 |
+
|
32 |
+
### Step 2: Install System Dependencies
|
33 |
+
Update your package list and install necessary system packages:
|
34 |
+
|
35 |
+
```bash
|
36 |
+
sudo apt-get update
|
37 |
+
sudo apt-get install cbm git-lfs ffmpeg
|
38 |
+
```
|
39 |
+
|
40 |
+
### Step 3: Set Up Python Environment
|
41 |
+
Create a Conda environment with Python 3.10, activate it, and install the IPython kernel:
|
42 |
+
|
43 |
+
```bash
|
44 |
+
conda create -n py310 python=3.10
|
45 |
+
conda activate py310
|
46 |
+
pip install ipykernel
|
47 |
+
python -m ipykernel install --user --name py310 --display-name "py310"
|
48 |
+
```
|
49 |
+
|
50 |
+
### Step 4: Clone the MasaCtrl Repository
|
51 |
+
Clone the MasaCtrl repository and install its dependencies:
|
52 |
+
|
53 |
+
```bash
|
54 |
+
git clone https://github.com/svjack/MasaCtrl
|
55 |
+
cd MasaCtrl
|
56 |
+
pip install -r requirements.txt
|
57 |
+
```
|
58 |
+
|
59 |
+
## Running the Synthesis
|
60 |
+
|
61 |
+
### Command Line Interface
|
62 |
+
Run the synthesis script to generate images of Genshin Impact characters:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
|
66 |
+
--prompt1 "solo,ZHONGLI\(genshin impact\),1boy,highres," \
|
67 |
+
--prompt2 "solo,ZHONGLI drink tea use chinese cup \(genshin impact\),1boy,highres," --guidance_scale 5
|
68 |
+
```
|
69 |
+
|
70 |
+
### Gradio Interface
|
71 |
+
Alternatively, you can use the Gradio interface for a more interactive experience:
|
72 |
+
|
73 |
+
```bash
|
74 |
+
python run_synthesis_genshin_impact_xl_app.py
|
75 |
+
```
|
76 |
+
|
77 |
+
## Example Prompts
|
78 |
+
|
79 |
+
Here are some example prompts you can use to generate different character images: (Image with MasaCtrl more like Source Image: In terms of background and other aspects)
|
80 |
+
|
81 |
+
- **Zhongli Drinking Tea:**
|
82 |
+
```
|
83 |
+
"solo,ZHONGLI(genshin impact),1boy,highres," -> "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,"
|
84 |
+
```
|
85 |
+
![Screenshot 2024-11-17 132742](https://github.com/user-attachments/assets/00451728-f2d5-4009-afa8-23baaabdc223)
|
86 |
+
|
87 |
+
- **Kamisato Ayato Smiling:**
|
88 |
+
```
|
89 |
+
"solo,KAMISATO AYATO(genshin impact),1boy,highres," -> "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,"
|
90 |
+
```
|
91 |
+
|
92 |
+
![Screenshot 2024-11-17 133421](https://github.com/user-attachments/assets/7a920f4c-8a3a-4387-98d6-381a798566ef)
|
93 |
+
|
94 |
+
## MasaCtrl: Tuning-free <span style="text-decoration: underline"><font color="Tomato">M</font></span>utu<span style="text-decoration: underline"><font color="Tomato">a</font></span>l <span style="text-decoration: underline"><font color="Tomato">S</font></span>elf-<span style="text-decoration: underline"><font color="Tomato">A</font></span>ttention <span style="text-decoration: underline"><font color="Tomato">Control</font></span> for Consistent Image Synthesis and Editing
|
95 |
+
|
96 |
+
Pytorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control for **Consistent Image Synthesis and Editing**](https://arxiv.org/abs/2304.08465)
|
97 |
+
|
98 |
+
[Mingdeng Cao](https://github.com/ljzycmd),
|
99 |
+
[Xintao Wang](https://xinntao.github.io/),
|
100 |
+
[Zhongang Qi](https://scholar.google.com/citations?user=zJvrrusAAAAJ),
|
101 |
+
[Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ),
|
102 |
+
[Xiaohu Qie](https://scholar.google.com/citations?user=mk-F69UAAAAJ),
|
103 |
+
[Yinqiang Zheng](https://scholar.google.com/citations?user=JD-5DKcAAAAJ)
|
104 |
+
|
105 |
+
[![arXiv](https://img.shields.io/badge/ArXiv-2304.08465-brightgreen)](https://arxiv.org/abs/2304.08465)
|
106 |
+
[![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://ljzycmd.github.io/projects/MasaCtrl/)
|
107 |
+
[![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl)
|
108 |
+
[![demo](https://img.shields.io/badge/Demo-Colab-brightgreen)](https://colab.research.google.com/drive/1DZeQn2WvRBsNg4feS1bJrwWnIzw1zLJq?usp=sharing)
|
109 |
+
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/MingDengCao/MasaCtrl)
|
110 |
+
|
111 |
---
|
112 |
+
|
113 |
+
<div align="center">
|
114 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/overview.gif">
|
115 |
+
<i> MasaCtrl enables performing various consistent non-rigid image synthesis and editing without fine-tuning and optimization. </i>
|
116 |
+
</div>
|
117 |
+
|
118 |
+
|
119 |
+
## Updates
|
120 |
+
- [2024/8/17] We add AttnProcessor based MasaCtrlProcessor, please check `masactrl/masactrl_processor.py` and `run_synthesis_sdxl_processor.py`. You can integrate MasaCtrl into official Diffuser pipeline by register the attention processor.
|
121 |
+
- [2023/8/20] MasaCtrl supports SDXL (and other variants) now. ![sdxl_example](https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/sdxl_example.jpg)
|
122 |
+
- [2023/5/13] The inference code of MasaCtrl with T2I-Adapter is available.
|
123 |
+
- [2023/4/28] [Hugging Face demo](https://huggingface.co/spaces/TencentARC/MasaCtrl) released.
|
124 |
+
- [2023/4/25] Code released.
|
125 |
+
- [2023/4/17] Paper is available [here](https://arxiv.org/abs/2304.08465).
|
126 |
+
|
127 |
---
|
128 |
|
129 |
+
## Introduction
|
130 |
+
|
131 |
+
We propose MasaCtrl, a tuning-free method for non-rigid consistent image synthesis and editing. The key idea is to combine the `contents` from the *source image* and the `layout` synthesized from *text prompt and additional controls* into the desired synthesized or edited image, by querying semantically correlated features with **Mutual Self-Attention Control**.
|
132 |
+
|
133 |
+
|
134 |
+
## Main Features
|
135 |
+
|
136 |
+
### 1 Consistent Image Synthesis and Editing
|
137 |
+
|
138 |
+
MasaCtrl can perform prompt-based image synthesis and editing that changes the layout while maintaining contents of source image.
|
139 |
+
|
140 |
+
>*The target layout is synthesized directly from the target prompt.*
|
141 |
+
|
142 |
+
<details><summary>View visual results</summary>
|
143 |
+
<div align="center">
|
144 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_synthetic.png">
|
145 |
+
<i>Consistent synthesis results</i>
|
146 |
+
|
147 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_real.png">
|
148 |
+
<i>Real image editing results</i>
|
149 |
+
</div>
|
150 |
+
</details>
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
### 2 Integration to Controllable Diffusion Models
|
155 |
+
|
156 |
+
Directly modifying the text prompts often cannot generate target layout of desired image, thus we further integrate our method into existing proposed controllable diffusion pipelines (like T2I-Adapter and ControlNet) to obtain stable synthesis and editing results.
|
157 |
+
|
158 |
+
>*The target layout controlled by additional guidance.*
|
159 |
+
|
160 |
+
<details><summary>View visual results</summary>
|
161 |
+
<div align="center">
|
162 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter.png">
|
163 |
+
<i>Synthesis (left part) and editing (right part) results with T2I-Adapter</i>
|
164 |
+
</div>
|
165 |
+
</details>
|
166 |
+
|
167 |
+
### 3 Generalization to Other Models: Anything-V4
|
168 |
+
|
169 |
+
Our method also generalize well to other Stable-Diffusion-based models.
|
170 |
+
|
171 |
+
<details><summary>View visual results</summary>
|
172 |
+
<div align="center">
|
173 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/anythingv4_synthetic.png">
|
174 |
+
<i>Results on Anything-V4</i>
|
175 |
+
</div>
|
176 |
+
</details>
|
177 |
+
|
178 |
+
|
179 |
+
### 4 Extension to Video Synthesis
|
180 |
+
|
181 |
+
With dense consistent guidance, MasaCtrl enables video synthesis
|
182 |
+
|
183 |
+
<details><summary>View visual results</summary>
|
184 |
+
<div align="center">
|
185 |
+
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter_consistent.png">
|
186 |
+
<i>Video Synthesis Results (with keypose and canny guidance)</i>
|
187 |
+
</div>
|
188 |
+
</details>
|
189 |
+
|
190 |
+
|
191 |
+
## Usage
|
192 |
+
|
193 |
+
### Requirements
|
194 |
+
We implement our method with [diffusers](https://github.com/huggingface/diffusers) code base with similar code structure to [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt). The code runs on Python 3.8.5 with Pytorch 1.11. Conda environment is highly recommended.
|
195 |
+
|
196 |
+
```base
|
197 |
+
pip install -r requirements.txt
|
198 |
+
```
|
199 |
+
|
200 |
+
### Checkpoints
|
201 |
+
|
202 |
+
**Stable Diffusion:**
|
203 |
+
We mainly conduct expriemnts on Stable Diffusion v1-4, while our method can generalize to other versions (like v1-5). You can download these checkpoints on their official repository and [Hugging Face](https://huggingface.co/).
|
204 |
+
|
205 |
+
**Personalized Models:**
|
206 |
+
You can download personlized models from [CIVITAI](https://civitai.com/) or train your own customized models.
|
207 |
+
|
208 |
+
|
209 |
+
### Demos
|
210 |
+
|
211 |
+
**Notebook demos**
|
212 |
+
|
213 |
+
To run the synthesis with MasaCtrl, single GPU with at least 16 GB VRAM is required.
|
214 |
+
|
215 |
+
The notebook `playground.ipynb` and `playground_real.ipynb` provide the synthesis and real editing samples, respectively.
|
216 |
+
|
217 |
+
**Online demos**
|
218 |
+
|
219 |
+
We provide [![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl) with Gradio app. Note that you may copy the demo into your own space to use the GPU. Online Colab demo [![demo](https://img.shields.io/badge/Demo-Colab-brightgreen)](https://colab.research.google.com/drive/1DZeQn2WvRBsNg4feS1bJrwWnIzw1zLJq?usp=sharing) is also available.
|
220 |
+
|
221 |
+
**Local Gradio demo**
|
222 |
+
|
223 |
+
You can launch the provided Gradio demo locally with
|
224 |
+
|
225 |
+
```bash
|
226 |
+
CUDA_VISIBLE_DEVICES=0 python app.py
|
227 |
+
```
|
228 |
+
|
229 |
+
|
230 |
+
### MasaCtrl with T2I-Adapter
|
231 |
+
|
232 |
+
Install [T2I-Adapter](https://github.com/TencentARC/T2I-Adapter) and prepare the checkpoints following their provided tutorial. Assuming it has been successfully installed and the root directory is `T2I-Adapter`.
|
233 |
+
|
234 |
+
Thereafter copy the core `masactrl` package and the inference code `masactrl_w_adapter.py` to the root directory of T2I-Adapter
|
235 |
+
|
236 |
+
```bash
|
237 |
+
cp -r MasaCtrl/masactrl T2I-Adapter/
|
238 |
+
cp MasaCtrl/masactrl_w_adapter/masactrl_w_adapter.py T2I-Adapter/
|
239 |
+
```
|
240 |
+
|
241 |
+
**[Updates]** Or you can clone the repo [MasaCtrl-w-T2I-Adapter](https://github.com/ljzycmd/T2I-Adapter-w-MasaCtrl) directly to your local space.
|
242 |
+
|
243 |
+
Last, you can inference the images with following command (with sketch adapter)
|
244 |
+
|
245 |
+
```bash
|
246 |
+
python masactrl_w_adapter.py \
|
247 |
+
--which_cond sketch \
|
248 |
+
--cond_path_src SOURCE_CONDITION_PATH \
|
249 |
+
--cond_path CONDITION_PATH \
|
250 |
+
--cond_inp_type sketch \
|
251 |
+
--prompt_src "A bear walking in the forest" \
|
252 |
+
--prompt "A bear standing in the forest" \
|
253 |
+
--sd_ckpt models/sd-v1-4.ckpt \
|
254 |
+
--resize_short_edge 512 \
|
255 |
+
--cond_tau 1.0 \
|
256 |
+
--cond_weight 1.0 \
|
257 |
+
--n_samples 1 \
|
258 |
+
--adapter_ckpt models/t2iadapter_sketch_sd14v1.pth
|
259 |
+
```
|
260 |
+
|
261 |
+
NOTE: You can download the sketch examples [here](https://huggingface.co/TencentARC/MasaCtrl/tree/main/sketch_example).
|
262 |
+
|
263 |
+
For real image, the DDIM inversion is performed to invert the image into the noise map, thus we add the inversion process into the original DDIM sampler. **You should replace the original file `T2I-Adapter/ldm/models/diffusion/ddim.py` with the exteneded version `MasaCtrl/masactrl_w_adapter/ddim.py` to enable the inversion function**. Then you can edit the real image with following command (with sketch adapter)
|
264 |
+
|
265 |
+
```bash
|
266 |
+
python masactrl_w_adapter.py \
|
267 |
+
--src_img_path SOURCE_IMAGE_PATH \
|
268 |
+
--cond_path CONDITION_PATH \
|
269 |
+
--cond_inp_type image \
|
270 |
+
--prompt_src "" \
|
271 |
+
--prompt "a photo of a man wearing black t-shirt, giving a thumbs up" \
|
272 |
+
--sd_ckpt models/sd-v1-4.ckpt \
|
273 |
+
--resize_short_edge 512 \
|
274 |
+
--cond_tau 1.0 \
|
275 |
+
--cond_weight 1.0 \
|
276 |
+
--n_samples 1 \
|
277 |
+
--which_cond sketch \
|
278 |
+
--adapter_ckpt models/t2iadapter_sketch_sd14v1.pth \
|
279 |
+
--outdir ./workdir/masactrl_w_adapter_inversion/black-shirt
|
280 |
+
```
|
281 |
+
|
282 |
+
NOTE: You can download the real image editing example [here](https://huggingface.co/TencentARC/MasaCtrl/tree/main/black_shirt_example).
|
283 |
+
|
284 |
+
## Acknowledgements
|
285 |
+
|
286 |
+
We thank the awesome research works [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt), [T2I-Adapter](https://github.com/TencentARC/T2I-Adapter).
|
287 |
+
|
288 |
+
|
289 |
+
## Citation
|
290 |
+
|
291 |
+
```bibtex
|
292 |
+
@InProceedings{cao_2023_masactrl,
|
293 |
+
author = {Cao, Mingdeng and Wang, Xintao and Qi, Zhongang and Shan, Ying and Qie, Xiaohu and Zheng, Yinqiang},
|
294 |
+
title = {MasaCtrl: Tuning-Free Mutual Self-Attention Control for Consistent Image Synthesis and Editing},
|
295 |
+
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
296 |
+
month = {October},
|
297 |
+
year = {2023},
|
298 |
+
pages = {22560-22570}
|
299 |
+
}
|
300 |
+
```
|
301 |
+
|
302 |
+
|
303 |
+
## Contact
|
304 |
+
|
305 |
+
If you have any comments or questions, please [open a new issue](https://github.com/TencentARC/MasaCtrl/issues/new/choose) or feel free to contact [Mingdeng Cao](https://github.com/ljzycmd) and [Xintao Wang](https://xinntao.github.io/).
|
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from diffusers import DDIMScheduler, DiffusionPipeline
|
4 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
5 |
+
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
|
6 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
7 |
+
from pytorch_lightning import seed_everything
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
|
11 |
+
# 初始化设备和模型
|
12 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
13 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
14 |
+
model = DiffusionPipeline.from_pretrained("svjack/GenshinImpact_XL_Base", scheduler=scheduler).to(device)
|
15 |
+
|
16 |
+
def pathify(s):
|
17 |
+
return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())
|
18 |
+
|
19 |
+
def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
|
20 |
+
seed_everything(seed)
|
21 |
+
|
22 |
+
# 创建输出目录
|
23 |
+
out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2))
|
24 |
+
os.makedirs(out_dir_ori, exist_ok=True)
|
25 |
+
|
26 |
+
prompts = [prompt1, prompt2]
|
27 |
+
|
28 |
+
# 初始化噪声图
|
29 |
+
start_code = torch.randn([1, 4, 128, 128], device=device)
|
30 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
31 |
+
|
32 |
+
# 推理没有 MasaCtrl 的图像
|
33 |
+
editor = AttentionBase()
|
34 |
+
regiter_attention_editor_diffusers(model, editor)
|
35 |
+
image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
|
36 |
+
|
37 |
+
images = []
|
38 |
+
# 劫持注意力模块
|
39 |
+
editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
|
40 |
+
regiter_attention_editor_diffusers(model, editor)
|
41 |
+
|
42 |
+
# 推理带 MasaCtrl 的图像
|
43 |
+
image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
|
44 |
+
|
45 |
+
sample_count = len(os.listdir(out_dir_ori))
|
46 |
+
out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
|
47 |
+
os.makedirs(out_dir, exist_ok=True)
|
48 |
+
image_ori[0].save(os.path.join(out_dir, f"source_step{starting_step}_layer{starting_layer}.png"))
|
49 |
+
image_ori[1].save(os.path.join(out_dir, f"without_step{starting_step}_layer{starting_layer}.png"))
|
50 |
+
image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{starting_step}_layer{starting_layer}.png"))
|
51 |
+
with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
|
52 |
+
for p in prompts:
|
53 |
+
f.write(p + "\n")
|
54 |
+
f.write(f"seed: {seed}\n")
|
55 |
+
f.write(f"starting_step: {starting_step}\n")
|
56 |
+
f.write(f"starting_layer: {starting_layer}\n")
|
57 |
+
print("Synthesized images are saved in", out_dir)
|
58 |
+
|
59 |
+
return [image_ori[0], image_ori[1], image_masactrl[-1]]
|
60 |
+
|
61 |
+
def create_demo_synthesis():
|
62 |
+
with gr.Blocks() as demo:
|
63 |
+
gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**") # 添加标题
|
64 |
+
gr.Markdown("## **Input Settings**")
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column():
|
67 |
+
prompt1 = gr.Textbox(label="Prompt 1", value="solo,ZHONGLI(genshin impact),1boy,highres,")
|
68 |
+
prompt2 = gr.Textbox(label="Prompt 2", value="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,")
|
69 |
+
with gr.Row():
|
70 |
+
starting_step = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
|
71 |
+
starting_layer = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
|
72 |
+
run_btn = gr.Button("Run")
|
73 |
+
with gr.Column():
|
74 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
75 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
|
76 |
+
|
77 |
+
gr.Markdown("## **Output**")
|
78 |
+
with gr.Row():
|
79 |
+
image_source = gr.Image(label="Source Image")
|
80 |
+
image_without_masactrl = gr.Image(label="Image without MasaCtrl")
|
81 |
+
image_with_masactrl = gr.Image(label="Image with MasaCtrl")
|
82 |
+
|
83 |
+
inputs = [prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer]
|
84 |
+
run_btn.click(consistent_synthesis, inputs, [image_source, image_without_masactrl, image_with_masactrl])
|
85 |
+
|
86 |
+
gr.Examples(
|
87 |
+
[
|
88 |
+
["solo,ZHONGLI(genshin impact),1boy,highres,", "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,", 42, 4, 64],
|
89 |
+
["solo,KAMISATO AYATO(genshin impact),1boy,highres,", "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,", 42, 4, 55]
|
90 |
+
],
|
91 |
+
[prompt1, prompt2, seed, starting_step, starting_layer],
|
92 |
+
)
|
93 |
+
return demo
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
demo_synthesis = create_demo_synthesis()
|
97 |
+
demo_synthesis.launch(share = True)
|
gradio_app/app_utils.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from diffusers import DDIMScheduler
|
5 |
+
from pytorch_lightning import seed_everything
|
6 |
+
|
7 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
8 |
+
from masactrl.masactrl_utils import (AttentionBase,
|
9 |
+
regiter_attention_editor_diffusers)
|
10 |
+
|
11 |
+
|
12 |
+
torch.set_grad_enabled(False)
|
13 |
+
|
14 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
|
15 |
+
"cpu")
|
16 |
+
model_path = "xyn-ai/anything-v4.0"
|
17 |
+
scheduler = DDIMScheduler(beta_start=0.00085,
|
18 |
+
beta_end=0.012,
|
19 |
+
beta_schedule="scaled_linear",
|
20 |
+
clip_sample=False,
|
21 |
+
set_alpha_to_one=False)
|
22 |
+
model = MasaCtrlPipeline.from_pretrained(model_path,
|
23 |
+
scheduler=scheduler).to(device)
|
24 |
+
|
25 |
+
global_context = {
|
26 |
+
"model_path": model_path,
|
27 |
+
"scheduler": scheduler,
|
28 |
+
"model": model,
|
29 |
+
"device": device
|
30 |
+
}
|
gradio_app/image_synthesis_app.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from diffusers import DDIMScheduler
|
5 |
+
from pytorch_lightning import seed_everything
|
6 |
+
|
7 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
8 |
+
from masactrl.masactrl_utils import (AttentionBase,
|
9 |
+
regiter_attention_editor_diffusers)
|
10 |
+
|
11 |
+
from .app_utils import global_context
|
12 |
+
|
13 |
+
torch.set_grad_enabled(False)
|
14 |
+
|
15 |
+
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
|
16 |
+
# "cpu")
|
17 |
+
# model_path = "andite/anything-v4.0"
|
18 |
+
# scheduler = DDIMScheduler(beta_start=0.00085,
|
19 |
+
# beta_end=0.012,
|
20 |
+
# beta_schedule="scaled_linear",
|
21 |
+
# clip_sample=False,
|
22 |
+
# set_alpha_to_one=False)
|
23 |
+
# model = MasaCtrlPipeline.from_pretrained(model_path,
|
24 |
+
# scheduler=scheduler).to(device)
|
25 |
+
|
26 |
+
|
27 |
+
def consistent_synthesis(source_prompt, target_prompt, starting_step,
|
28 |
+
starting_layer, image_resolution, ddim_steps, scale,
|
29 |
+
seed, appended_prompt, negative_prompt):
|
30 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
31 |
+
|
32 |
+
model = global_context["model"]
|
33 |
+
device = global_context["device"]
|
34 |
+
|
35 |
+
seed_everything(seed)
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
if appended_prompt is not None:
|
39 |
+
source_prompt += appended_prompt
|
40 |
+
target_prompt += appended_prompt
|
41 |
+
prompts = [source_prompt, target_prompt]
|
42 |
+
|
43 |
+
# initialize the noise map
|
44 |
+
start_code = torch.randn([1, 4, 64, 64], device=device)
|
45 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
46 |
+
|
47 |
+
# inference the synthesized image without MasaCtrl
|
48 |
+
editor = AttentionBase()
|
49 |
+
regiter_attention_editor_diffusers(model, editor)
|
50 |
+
target_image_ori = model([target_prompt],
|
51 |
+
latents=start_code[-1:],
|
52 |
+
guidance_scale=7.5)
|
53 |
+
target_image_ori = target_image_ori.cpu().permute(0, 2, 3, 1).numpy()
|
54 |
+
|
55 |
+
# inference the synthesized image with MasaCtrl
|
56 |
+
# hijack the attention module
|
57 |
+
controller = MutualSelfAttentionControl(starting_step, starting_layer)
|
58 |
+
regiter_attention_editor_diffusers(model, controller)
|
59 |
+
|
60 |
+
# inference the synthesized image
|
61 |
+
image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)
|
62 |
+
image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
|
63 |
+
|
64 |
+
return [image_masactrl[0], target_image_ori[0],
|
65 |
+
image_masactrl[1]] # source, fixed seed, masactrl
|
66 |
+
|
67 |
+
|
68 |
+
def create_demo_synthesis():
|
69 |
+
with gr.Blocks() as demo:
|
70 |
+
gr.Markdown("## **Input Settings**")
|
71 |
+
with gr.Row():
|
72 |
+
with gr.Column():
|
73 |
+
source_prompt = gr.Textbox(
|
74 |
+
label="Source Prompt",
|
75 |
+
value='1boy, casual, outdoors, sitting',
|
76 |
+
interactive=True)
|
77 |
+
target_prompt = gr.Textbox(
|
78 |
+
label="Target Prompt",
|
79 |
+
value='1boy, casual, outdoors, standing',
|
80 |
+
interactive=True)
|
81 |
+
with gr.Row():
|
82 |
+
ddim_steps = gr.Slider(label="DDIM Steps",
|
83 |
+
minimum=1,
|
84 |
+
maximum=999,
|
85 |
+
value=50,
|
86 |
+
step=1)
|
87 |
+
starting_step = gr.Slider(
|
88 |
+
label="Step of MasaCtrl",
|
89 |
+
minimum=0,
|
90 |
+
maximum=999,
|
91 |
+
value=4,
|
92 |
+
step=1)
|
93 |
+
starting_layer = gr.Slider(label="Layer of MasaCtrl",
|
94 |
+
minimum=0,
|
95 |
+
maximum=16,
|
96 |
+
value=10,
|
97 |
+
step=1)
|
98 |
+
run_btn = gr.Button(label="Run")
|
99 |
+
with gr.Column():
|
100 |
+
appended_prompt = gr.Textbox(label="Appended Prompt", value='')
|
101 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value='')
|
102 |
+
with gr.Row():
|
103 |
+
image_resolution = gr.Slider(label="Image Resolution",
|
104 |
+
minimum=256,
|
105 |
+
maximum=768,
|
106 |
+
value=512,
|
107 |
+
step=64)
|
108 |
+
scale = gr.Slider(label="CFG Scale",
|
109 |
+
minimum=0.1,
|
110 |
+
maximum=30.0,
|
111 |
+
value=7.5,
|
112 |
+
step=0.1)
|
113 |
+
seed = gr.Slider(label="Seed",
|
114 |
+
minimum=-1,
|
115 |
+
maximum=2147483647,
|
116 |
+
value=42,
|
117 |
+
step=1)
|
118 |
+
|
119 |
+
gr.Markdown("## **Output**")
|
120 |
+
with gr.Row():
|
121 |
+
image_source = gr.Image(label="Source Image")
|
122 |
+
image_fixed = gr.Image(label="Image with Fixed Seed")
|
123 |
+
image_masactrl = gr.Image(label="Image with MasaCtrl")
|
124 |
+
|
125 |
+
inputs = [
|
126 |
+
source_prompt, target_prompt, starting_step, starting_layer,
|
127 |
+
image_resolution, ddim_steps, scale, seed, appended_prompt,
|
128 |
+
negative_prompt
|
129 |
+
]
|
130 |
+
run_btn.click(consistent_synthesis, inputs,
|
131 |
+
[image_source, image_fixed, image_masactrl])
|
132 |
+
|
133 |
+
gr.Examples(
|
134 |
+
[[
|
135 |
+
"1boy, bishounen, casual, indoors, sitting, coffee shop, bokeh",
|
136 |
+
"1boy, bishounen, casual, indoors, standing, coffee shop, bokeh",
|
137 |
+
42
|
138 |
+
],
|
139 |
+
[
|
140 |
+
"1boy, casual, outdoors, sitting",
|
141 |
+
"1boy, casual, outdoors, sitting, side view", 42
|
142 |
+
],
|
143 |
+
[
|
144 |
+
"1boy, casual, outdoors, sitting",
|
145 |
+
"1boy, casual, outdoors, standing, clapping hands", 42
|
146 |
+
],
|
147 |
+
[
|
148 |
+
"1boy, casual, outdoors, sitting",
|
149 |
+
"1boy, casual, outdoors, sitting, shows thumbs up", 42
|
150 |
+
],
|
151 |
+
[
|
152 |
+
"1boy, casual, outdoors, sitting",
|
153 |
+
"1boy, casual, outdoors, sitting, with crossed arms", 42
|
154 |
+
],
|
155 |
+
[
|
156 |
+
"1boy, casual, outdoors, sitting",
|
157 |
+
"1boy, casual, outdoors, sitting, rasing hands", 42
|
158 |
+
]],
|
159 |
+
[source_prompt, target_prompt, seed],
|
160 |
+
)
|
161 |
+
return demo
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
demo_syntehsis = create_demo_synthesis()
|
166 |
+
demo_synthesis.launch()
|
gradio_app/images/corgi.jpg
ADDED
![]() |
gradio_app/images/person.png
ADDED
![]() |
gradio_app/real_image_editing_app.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from diffusers import DDIMScheduler
|
7 |
+
from torchvision.io import read_image
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
|
10 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
11 |
+
from masactrl.masactrl_utils import (AttentionBase,
|
12 |
+
regiter_attention_editor_diffusers)
|
13 |
+
|
14 |
+
from .app_utils import global_context
|
15 |
+
|
16 |
+
torch.set_grad_enabled(False)
|
17 |
+
|
18 |
+
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
|
19 |
+
# "cpu")
|
20 |
+
|
21 |
+
# model_path = "CompVis/stable-diffusion-v1-4"
|
22 |
+
# scheduler = DDIMScheduler(beta_start=0.00085,
|
23 |
+
# beta_end=0.012,
|
24 |
+
# beta_schedule="scaled_linear",
|
25 |
+
# clip_sample=False,
|
26 |
+
# set_alpha_to_one=False)
|
27 |
+
# model = MasaCtrlPipeline.from_pretrained(model_path,
|
28 |
+
# scheduler=scheduler).to(device)
|
29 |
+
|
30 |
+
|
31 |
+
def load_image(image_path):
|
32 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
33 |
+
image = read_image(image_path)
|
34 |
+
image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]
|
35 |
+
image = F.interpolate(image, (512, 512))
|
36 |
+
image = image.to(device)
|
37 |
+
|
38 |
+
|
39 |
+
def real_image_editing(source_image, target_prompt,
|
40 |
+
starting_step, starting_layer, ddim_steps, scale, seed,
|
41 |
+
appended_prompt, negative_prompt):
|
42 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
43 |
+
|
44 |
+
model = global_context["model"]
|
45 |
+
device = global_context["device"]
|
46 |
+
|
47 |
+
seed_everything(seed)
|
48 |
+
|
49 |
+
with torch.no_grad():
|
50 |
+
if appended_prompt is not None:
|
51 |
+
target_prompt += appended_prompt
|
52 |
+
ref_prompt = ""
|
53 |
+
prompts = [ref_prompt, target_prompt]
|
54 |
+
|
55 |
+
# invert the image into noise map
|
56 |
+
if isinstance(source_image, np.ndarray):
|
57 |
+
source_image = torch.from_numpy(source_image).to(device) / 127.5 - 1.
|
58 |
+
source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
|
59 |
+
source_image = F.interpolate(source_image, (512, 512))
|
60 |
+
|
61 |
+
start_code, latents_list = model.invert(source_image,
|
62 |
+
ref_prompt,
|
63 |
+
guidance_scale=scale,
|
64 |
+
num_inference_steps=ddim_steps,
|
65 |
+
return_intermediates=True)
|
66 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
67 |
+
|
68 |
+
# recontruct the image with inverted DDIM noise map
|
69 |
+
editor = AttentionBase()
|
70 |
+
regiter_attention_editor_diffusers(model, editor)
|
71 |
+
image_fixed = model([target_prompt],
|
72 |
+
latents=start_code[-1:],
|
73 |
+
num_inference_steps=ddim_steps,
|
74 |
+
guidance_scale=scale)
|
75 |
+
image_fixed = image_fixed.cpu().permute(0, 2, 3, 1).numpy()
|
76 |
+
|
77 |
+
# inference the synthesized image with MasaCtrl
|
78 |
+
# hijack the attention module
|
79 |
+
controller = MutualSelfAttentionControl(starting_step, starting_layer)
|
80 |
+
regiter_attention_editor_diffusers(model, controller)
|
81 |
+
|
82 |
+
# inference the synthesized image
|
83 |
+
image_masactrl = model(prompts,
|
84 |
+
latents=start_code,
|
85 |
+
guidance_scale=scale)
|
86 |
+
image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
|
87 |
+
|
88 |
+
return [
|
89 |
+
image_masactrl[0],
|
90 |
+
image_fixed[0],
|
91 |
+
image_masactrl[1]
|
92 |
+
] # source, fixed seed, masactrl
|
93 |
+
|
94 |
+
|
95 |
+
def create_demo_editing():
|
96 |
+
with gr.Blocks() as demo:
|
97 |
+
gr.Markdown("## **Input Settings**")
|
98 |
+
with gr.Row():
|
99 |
+
with gr.Column():
|
100 |
+
source_image = gr.Image(label="Source Image", value=os.path.join(os.path.dirname(__file__), "images/corgi.jpg"), interactive=True)
|
101 |
+
target_prompt = gr.Textbox(label="Target Prompt",
|
102 |
+
value='A photo of a running corgi',
|
103 |
+
interactive=True)
|
104 |
+
with gr.Row():
|
105 |
+
ddim_steps = gr.Slider(label="DDIM Steps",
|
106 |
+
minimum=1,
|
107 |
+
maximum=999,
|
108 |
+
value=50,
|
109 |
+
step=1)
|
110 |
+
starting_step = gr.Slider(label="Step of MasaCtrl",
|
111 |
+
minimum=0,
|
112 |
+
maximum=999,
|
113 |
+
value=4,
|
114 |
+
step=1)
|
115 |
+
starting_layer = gr.Slider(label="Layer of MasaCtrl",
|
116 |
+
minimum=0,
|
117 |
+
maximum=16,
|
118 |
+
value=10,
|
119 |
+
step=1)
|
120 |
+
run_btn = gr.Button(label="Run")
|
121 |
+
with gr.Column():
|
122 |
+
appended_prompt = gr.Textbox(label="Appended Prompt", value='')
|
123 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value='')
|
124 |
+
with gr.Row():
|
125 |
+
scale = gr.Slider(label="CFG Scale",
|
126 |
+
minimum=0.1,
|
127 |
+
maximum=30.0,
|
128 |
+
value=7.5,
|
129 |
+
step=0.1)
|
130 |
+
seed = gr.Slider(label="Seed",
|
131 |
+
minimum=-1,
|
132 |
+
maximum=2147483647,
|
133 |
+
value=42,
|
134 |
+
step=1)
|
135 |
+
|
136 |
+
gr.Markdown("## **Output**")
|
137 |
+
with gr.Row():
|
138 |
+
image_recons = gr.Image(label="Source Image")
|
139 |
+
image_fixed = gr.Image(label="Image with Fixed Seed")
|
140 |
+
image_masactrl = gr.Image(label="Image with MasaCtrl")
|
141 |
+
|
142 |
+
inputs = [
|
143 |
+
source_image, target_prompt, starting_step, starting_layer, ddim_steps,
|
144 |
+
scale, seed, appended_prompt, negative_prompt
|
145 |
+
]
|
146 |
+
run_btn.click(real_image_editing, inputs,
|
147 |
+
[image_recons, image_fixed, image_masactrl])
|
148 |
+
|
149 |
+
gr.Examples(
|
150 |
+
[[os.path.join(os.path.dirname(__file__), "images/corgi.jpg"),
|
151 |
+
"A photo of a running corgi"],
|
152 |
+
[os.path.join(os.path.dirname(__file__), "images/person.png"),
|
153 |
+
"A photo of a person, black t-shirt, raising hand"],
|
154 |
+
],
|
155 |
+
[source_image, target_prompt]
|
156 |
+
)
|
157 |
+
return demo
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
|
161 |
+
demo_editing = create_demo_editing()
|
162 |
+
demo_editing.launch()
|
masactrl/__init__.py
ADDED
File without changes
|
masactrl/diffuser_utils.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Util functions based on Diffuser framework.
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision.utils import save_image
|
15 |
+
from torchvision.io import read_image
|
16 |
+
|
17 |
+
from diffusers import StableDiffusionPipeline
|
18 |
+
|
19 |
+
from pytorch_lightning import seed_everything
|
20 |
+
|
21 |
+
|
22 |
+
class MasaCtrlPipeline(StableDiffusionPipeline):
|
23 |
+
|
24 |
+
def next_step(
|
25 |
+
self,
|
26 |
+
model_output: torch.FloatTensor,
|
27 |
+
timestep: int,
|
28 |
+
x: torch.FloatTensor,
|
29 |
+
eta=0.,
|
30 |
+
verbose=False
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Inverse sampling for DDIM Inversion
|
34 |
+
"""
|
35 |
+
if verbose:
|
36 |
+
print("timestep: ", timestep)
|
37 |
+
next_step = timestep
|
38 |
+
timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
|
39 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
|
40 |
+
alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
|
41 |
+
beta_prod_t = 1 - alpha_prod_t
|
42 |
+
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
43 |
+
pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
|
44 |
+
x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
|
45 |
+
return x_next, pred_x0
|
46 |
+
|
47 |
+
def step(
|
48 |
+
self,
|
49 |
+
model_output: torch.FloatTensor,
|
50 |
+
timestep: int,
|
51 |
+
x: torch.FloatTensor,
|
52 |
+
eta: float=0.0,
|
53 |
+
verbose=False,
|
54 |
+
):
|
55 |
+
"""
|
56 |
+
predict the sampe the next step in the denoise process.
|
57 |
+
"""
|
58 |
+
prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
|
59 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
60 |
+
alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
|
61 |
+
beta_prod_t = 1 - alpha_prod_t
|
62 |
+
pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
|
63 |
+
pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
|
64 |
+
x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
|
65 |
+
return x_prev, pred_x0
|
66 |
+
|
67 |
+
@torch.no_grad()
|
68 |
+
def image2latent(self, image):
|
69 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
70 |
+
if type(image) is Image:
|
71 |
+
image = np.array(image)
|
72 |
+
image = torch.from_numpy(image).float() / 127.5 - 1
|
73 |
+
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
|
74 |
+
# input image density range [-1, 1]
|
75 |
+
latents = self.vae.encode(image)['latent_dist'].mean
|
76 |
+
latents = latents * 0.18215
|
77 |
+
return latents
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def latent2image(self, latents, return_type='np'):
|
81 |
+
latents = 1 / 0.18215 * latents.detach()
|
82 |
+
image = self.vae.decode(latents)['sample']
|
83 |
+
if return_type == 'np':
|
84 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
85 |
+
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
86 |
+
image = (image * 255).astype(np.uint8)
|
87 |
+
elif return_type == "pt":
|
88 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
89 |
+
|
90 |
+
return image
|
91 |
+
|
92 |
+
def latent2image_grad(self, latents):
|
93 |
+
latents = 1 / 0.18215 * latents
|
94 |
+
image = self.vae.decode(latents)['sample']
|
95 |
+
|
96 |
+
return image # range [-1, 1]
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def __call__(
|
100 |
+
self,
|
101 |
+
prompt,
|
102 |
+
batch_size=1,
|
103 |
+
height=512,
|
104 |
+
width=512,
|
105 |
+
num_inference_steps=50,
|
106 |
+
guidance_scale=7.5,
|
107 |
+
eta=0.0,
|
108 |
+
latents=None,
|
109 |
+
unconditioning=None,
|
110 |
+
neg_prompt=None,
|
111 |
+
ref_intermediate_latents=None,
|
112 |
+
return_intermediates=False,
|
113 |
+
**kwds):
|
114 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
115 |
+
if isinstance(prompt, list):
|
116 |
+
batch_size = len(prompt)
|
117 |
+
elif isinstance(prompt, str):
|
118 |
+
if batch_size > 1:
|
119 |
+
prompt = [prompt] * batch_size
|
120 |
+
|
121 |
+
# text embeddings
|
122 |
+
text_input = self.tokenizer(
|
123 |
+
prompt,
|
124 |
+
padding="max_length",
|
125 |
+
max_length=77,
|
126 |
+
return_tensors="pt"
|
127 |
+
)
|
128 |
+
|
129 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
|
130 |
+
print("input text embeddings :", text_embeddings.shape)
|
131 |
+
if kwds.get("dir"):
|
132 |
+
dir = text_embeddings[-2] - text_embeddings[-1]
|
133 |
+
u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
|
134 |
+
text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
|
135 |
+
print(u.shape)
|
136 |
+
print(v.shape)
|
137 |
+
|
138 |
+
# define initial latents
|
139 |
+
latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
|
140 |
+
if latents is None:
|
141 |
+
latents = torch.randn(latents_shape, device=DEVICE)
|
142 |
+
else:
|
143 |
+
assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."
|
144 |
+
|
145 |
+
# unconditional embedding for classifier free guidance
|
146 |
+
if guidance_scale > 1.:
|
147 |
+
max_length = text_input.input_ids.shape[-1]
|
148 |
+
if neg_prompt:
|
149 |
+
uc_text = neg_prompt
|
150 |
+
else:
|
151 |
+
uc_text = ""
|
152 |
+
# uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
|
153 |
+
unconditional_input = self.tokenizer(
|
154 |
+
[uc_text] * batch_size,
|
155 |
+
padding="max_length",
|
156 |
+
max_length=77,
|
157 |
+
return_tensors="pt"
|
158 |
+
)
|
159 |
+
# unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
|
160 |
+
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
|
161 |
+
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
|
162 |
+
|
163 |
+
print("latents shape: ", latents.shape)
|
164 |
+
# iterative sampling
|
165 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
166 |
+
# print("Valid timesteps: ", reversed(self.scheduler.timesteps))
|
167 |
+
latents_list = [latents]
|
168 |
+
pred_x0_list = [latents]
|
169 |
+
for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
|
170 |
+
if ref_intermediate_latents is not None:
|
171 |
+
# note that the batch_size >= 2
|
172 |
+
latents_ref = ref_intermediate_latents[-1 - i]
|
173 |
+
_, latents_cur = latents.chunk(2)
|
174 |
+
latents = torch.cat([latents_ref, latents_cur])
|
175 |
+
|
176 |
+
if guidance_scale > 1.:
|
177 |
+
model_inputs = torch.cat([latents] * 2)
|
178 |
+
else:
|
179 |
+
model_inputs = latents
|
180 |
+
if unconditioning is not None and isinstance(unconditioning, list):
|
181 |
+
_, text_embeddings = text_embeddings.chunk(2)
|
182 |
+
text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
|
183 |
+
# predict tghe noise
|
184 |
+
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
|
185 |
+
if guidance_scale > 1.:
|
186 |
+
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
|
187 |
+
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
|
188 |
+
# compute the previous noise sample x_t -> x_t-1
|
189 |
+
latents, pred_x0 = self.step(noise_pred, t, latents)
|
190 |
+
latents_list.append(latents)
|
191 |
+
pred_x0_list.append(pred_x0)
|
192 |
+
|
193 |
+
image = self.latent2image(latents, return_type="pt")
|
194 |
+
if return_intermediates:
|
195 |
+
pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
|
196 |
+
latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
|
197 |
+
return image, pred_x0_list, latents_list
|
198 |
+
return image
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def invert(
|
202 |
+
self,
|
203 |
+
image: torch.Tensor,
|
204 |
+
prompt,
|
205 |
+
num_inference_steps=50,
|
206 |
+
guidance_scale=7.5,
|
207 |
+
eta=0.0,
|
208 |
+
return_intermediates=False,
|
209 |
+
**kwds):
|
210 |
+
"""
|
211 |
+
invert a real image into noise map with determinisc DDIM inversion
|
212 |
+
"""
|
213 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
214 |
+
batch_size = image.shape[0]
|
215 |
+
if isinstance(prompt, list):
|
216 |
+
if batch_size == 1:
|
217 |
+
image = image.expand(len(prompt), -1, -1, -1)
|
218 |
+
elif isinstance(prompt, str):
|
219 |
+
if batch_size > 1:
|
220 |
+
prompt = [prompt] * batch_size
|
221 |
+
|
222 |
+
# text embeddings
|
223 |
+
text_input = self.tokenizer(
|
224 |
+
prompt,
|
225 |
+
padding="max_length",
|
226 |
+
max_length=77,
|
227 |
+
return_tensors="pt"
|
228 |
+
)
|
229 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
|
230 |
+
print("input text embeddings :", text_embeddings.shape)
|
231 |
+
# define initial latents
|
232 |
+
latents = self.image2latent(image)
|
233 |
+
start_latents = latents
|
234 |
+
# print(latents)
|
235 |
+
# exit()
|
236 |
+
# unconditional embedding for classifier free guidance
|
237 |
+
if guidance_scale > 1.:
|
238 |
+
max_length = text_input.input_ids.shape[-1]
|
239 |
+
unconditional_input = self.tokenizer(
|
240 |
+
[""] * batch_size,
|
241 |
+
padding="max_length",
|
242 |
+
max_length=77,
|
243 |
+
return_tensors="pt"
|
244 |
+
)
|
245 |
+
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
|
246 |
+
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
|
247 |
+
|
248 |
+
print("latents shape: ", latents.shape)
|
249 |
+
# interative sampling
|
250 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
251 |
+
print("Valid timesteps: ", reversed(self.scheduler.timesteps))
|
252 |
+
# print("attributes: ", self.scheduler.__dict__)
|
253 |
+
latents_list = [latents]
|
254 |
+
pred_x0_list = [latents]
|
255 |
+
for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
|
256 |
+
if guidance_scale > 1.:
|
257 |
+
model_inputs = torch.cat([latents] * 2)
|
258 |
+
else:
|
259 |
+
model_inputs = latents
|
260 |
+
|
261 |
+
# predict the noise
|
262 |
+
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
|
263 |
+
if guidance_scale > 1.:
|
264 |
+
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
|
265 |
+
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
|
266 |
+
# compute the previous noise sample x_t-1 -> x_t
|
267 |
+
latents, pred_x0 = self.next_step(noise_pred, t, latents)
|
268 |
+
latents_list.append(latents)
|
269 |
+
pred_x0_list.append(pred_x0)
|
270 |
+
|
271 |
+
if return_intermediates:
|
272 |
+
# return the intermediate laters during inversion
|
273 |
+
# pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
|
274 |
+
return latents, latents_list
|
275 |
+
return latents, start_latents
|
masactrl/masactrl.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
from .masactrl_utils import AttentionBase
|
10 |
+
|
11 |
+
from torchvision.utils import save_image
|
12 |
+
|
13 |
+
|
14 |
+
class MutualSelfAttentionControl(AttentionBase):
|
15 |
+
MODEL_TYPE = {
|
16 |
+
"SD": 16,
|
17 |
+
"SDXL": 70
|
18 |
+
}
|
19 |
+
|
20 |
+
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
|
21 |
+
"""
|
22 |
+
Mutual self-attention control for Stable-Diffusion model
|
23 |
+
Args:
|
24 |
+
start_step: the step to start mutual self-attention control
|
25 |
+
start_layer: the layer to start mutual self-attention control
|
26 |
+
layer_idx: list of the layers to apply mutual self-attention control
|
27 |
+
step_idx: list the steps to apply mutual self-attention control
|
28 |
+
total_steps: the total number of steps
|
29 |
+
model_type: the model type, SD or SDXL
|
30 |
+
"""
|
31 |
+
super().__init__()
|
32 |
+
self.total_steps = total_steps
|
33 |
+
self.total_layers = self.MODEL_TYPE.get(model_type, 16)
|
34 |
+
self.start_step = start_step
|
35 |
+
self.start_layer = start_layer
|
36 |
+
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
|
37 |
+
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
|
38 |
+
print("MasaCtrl at denoising steps: ", self.step_idx)
|
39 |
+
print("MasaCtrl at U-Net layers: ", self.layer_idx)
|
40 |
+
|
41 |
+
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
42 |
+
"""
|
43 |
+
Performing attention for a batch of queries, keys, and values
|
44 |
+
"""
|
45 |
+
b = q.shape[0] // num_heads
|
46 |
+
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
|
47 |
+
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
|
48 |
+
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
|
49 |
+
|
50 |
+
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
|
51 |
+
attn = sim.softmax(-1)
|
52 |
+
out = torch.einsum("h i j, h j d -> h i d", attn, v)
|
53 |
+
out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
|
54 |
+
return out
|
55 |
+
|
56 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
57 |
+
"""
|
58 |
+
Attention forward function
|
59 |
+
"""
|
60 |
+
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
|
61 |
+
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
62 |
+
|
63 |
+
qu, qc = q.chunk(2)
|
64 |
+
ku, kc = k.chunk(2)
|
65 |
+
vu, vc = v.chunk(2)
|
66 |
+
attnu, attnc = attn.chunk(2)
|
67 |
+
|
68 |
+
out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
|
69 |
+
out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
|
70 |
+
out = torch.cat([out_u, out_c], dim=0)
|
71 |
+
|
72 |
+
return out
|
73 |
+
|
74 |
+
|
75 |
+
class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
|
76 |
+
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
|
77 |
+
"""
|
78 |
+
Mutual self-attention control for Stable-Diffusion model with unition source and target [K, V]
|
79 |
+
Args:
|
80 |
+
start_step: the step to start mutual self-attention control
|
81 |
+
start_layer: the layer to start mutual self-attention control
|
82 |
+
layer_idx: list of the layers to apply mutual self-attention control
|
83 |
+
step_idx: list the steps to apply mutual self-attention control
|
84 |
+
total_steps: the total number of steps
|
85 |
+
model_type: the model type, SD or SDXL
|
86 |
+
"""
|
87 |
+
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
|
88 |
+
|
89 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
90 |
+
"""
|
91 |
+
Attention forward function
|
92 |
+
"""
|
93 |
+
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
|
94 |
+
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
95 |
+
|
96 |
+
qu_s, qu_t, qc_s, qc_t = q.chunk(4)
|
97 |
+
ku_s, ku_t, kc_s, kc_t = k.chunk(4)
|
98 |
+
vu_s, vu_t, vc_s, vc_t = v.chunk(4)
|
99 |
+
attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4)
|
100 |
+
|
101 |
+
# source image branch
|
102 |
+
out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs)
|
103 |
+
out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs)
|
104 |
+
|
105 |
+
# target image branch, concatenating source and target [K, V]
|
106 |
+
out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs)
|
107 |
+
out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs)
|
108 |
+
|
109 |
+
out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)
|
110 |
+
|
111 |
+
return out
|
112 |
+
|
113 |
+
|
114 |
+
class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
|
115 |
+
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None, model_type="SD"):
|
116 |
+
"""
|
117 |
+
Maske-guided MasaCtrl to alleviate the problem of fore- and background confusion
|
118 |
+
Args:
|
119 |
+
start_step: the step to start mutual self-attention control
|
120 |
+
start_layer: the layer to start mutual self-attention control
|
121 |
+
layer_idx: list of the layers to apply mutual self-attention control
|
122 |
+
step_idx: list the steps to apply mutual self-attention control
|
123 |
+
total_steps: the total number of steps
|
124 |
+
mask_s: source mask with shape (h, w)
|
125 |
+
mask_t: target mask with same shape as source mask
|
126 |
+
mask_save_dir: the path to save the mask image
|
127 |
+
model_type: the model type, SD or SDXL
|
128 |
+
"""
|
129 |
+
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
|
130 |
+
self.mask_s = mask_s # source mask with shape (h, w)
|
131 |
+
self.mask_t = mask_t # target mask with same shape as source mask
|
132 |
+
print("Using mask-guided MasaCtrl")
|
133 |
+
if mask_save_dir is not None:
|
134 |
+
os.makedirs(mask_save_dir, exist_ok=True)
|
135 |
+
save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png"))
|
136 |
+
save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png"))
|
137 |
+
|
138 |
+
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
139 |
+
B = q.shape[0] // num_heads
|
140 |
+
H = W = int(np.sqrt(q.shape[1]))
|
141 |
+
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
|
142 |
+
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
|
143 |
+
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
|
144 |
+
|
145 |
+
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
|
146 |
+
if kwargs.get("is_mask_attn") and self.mask_s is not None:
|
147 |
+
print("masked attention")
|
148 |
+
mask = self.mask_s.unsqueeze(0).unsqueeze(0)
|
149 |
+
mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0)
|
150 |
+
mask = mask.flatten()
|
151 |
+
# background
|
152 |
+
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
|
153 |
+
# object
|
154 |
+
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
|
155 |
+
sim = torch.cat([sim_fg, sim_bg], dim=0)
|
156 |
+
attn = sim.softmax(-1)
|
157 |
+
if len(attn) == 2 * len(v):
|
158 |
+
v = torch.cat([v] * 2)
|
159 |
+
out = torch.einsum("h i j, h j d -> h i d", attn, v)
|
160 |
+
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
|
161 |
+
return out
|
162 |
+
|
163 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
164 |
+
"""
|
165 |
+
Attention forward function
|
166 |
+
"""
|
167 |
+
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
|
168 |
+
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
169 |
+
|
170 |
+
B = q.shape[0] // num_heads // 2
|
171 |
+
H = W = int(np.sqrt(q.shape[1]))
|
172 |
+
qu, qc = q.chunk(2)
|
173 |
+
ku, kc = k.chunk(2)
|
174 |
+
vu, vc = v.chunk(2)
|
175 |
+
attnu, attnc = attn.chunk(2)
|
176 |
+
|
177 |
+
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
|
178 |
+
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
|
179 |
+
|
180 |
+
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
|
181 |
+
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
|
182 |
+
|
183 |
+
if self.mask_s is not None and self.mask_t is not None:
|
184 |
+
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0)
|
185 |
+
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0)
|
186 |
+
|
187 |
+
mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W))
|
188 |
+
mask = mask.reshape(-1, 1) # (hw, 1)
|
189 |
+
out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask)
|
190 |
+
out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask)
|
191 |
+
|
192 |
+
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
|
193 |
+
return out
|
194 |
+
|
195 |
+
|
196 |
+
class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
|
197 |
+
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type="SD"):
|
198 |
+
"""
|
199 |
+
MasaCtrl with mask auto generation from cross-attention map
|
200 |
+
Args:
|
201 |
+
start_step: the step to start mutual self-attention control
|
202 |
+
start_layer: the layer to start mutual self-attention control
|
203 |
+
layer_idx: list of the layers to apply mutual self-attention control
|
204 |
+
step_idx: list the steps to apply mutual self-attention control
|
205 |
+
total_steps: the total number of steps
|
206 |
+
thres: the thereshold for mask thresholding
|
207 |
+
ref_token_idx: the token index list for cross-attention map aggregation
|
208 |
+
cur_token_idx: the token index list for cross-attention map aggregation
|
209 |
+
mask_save_dir: the path to save the mask image
|
210 |
+
"""
|
211 |
+
super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
|
212 |
+
print("Using MutualSelfAttentionControlMaskAuto")
|
213 |
+
self.thres = thres
|
214 |
+
self.ref_token_idx = ref_token_idx
|
215 |
+
self.cur_token_idx = cur_token_idx
|
216 |
+
|
217 |
+
self.self_attns = []
|
218 |
+
self.cross_attns = []
|
219 |
+
|
220 |
+
self.cross_attns_mask = None
|
221 |
+
self.self_attns_mask = None
|
222 |
+
|
223 |
+
self.mask_save_dir = mask_save_dir
|
224 |
+
if self.mask_save_dir is not None:
|
225 |
+
os.makedirs(self.mask_save_dir, exist_ok=True)
|
226 |
+
|
227 |
+
def after_step(self):
|
228 |
+
self.self_attns = []
|
229 |
+
self.cross_attns = []
|
230 |
+
|
231 |
+
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
232 |
+
"""
|
233 |
+
Performing attention for a batch of queries, keys, and values
|
234 |
+
"""
|
235 |
+
B = q.shape[0] // num_heads
|
236 |
+
H = W = int(np.sqrt(q.shape[1]))
|
237 |
+
q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
|
238 |
+
k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
|
239 |
+
v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
|
240 |
+
|
241 |
+
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
|
242 |
+
if self.self_attns_mask is not None:
|
243 |
+
# binarize the mask
|
244 |
+
mask = self.self_attns_mask
|
245 |
+
thres = self.thres
|
246 |
+
mask[mask >= thres] = 1
|
247 |
+
mask[mask < thres] = 0
|
248 |
+
sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
|
249 |
+
sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
|
250 |
+
sim = torch.cat([sim_fg, sim_bg])
|
251 |
+
|
252 |
+
attn = sim.softmax(-1)
|
253 |
+
|
254 |
+
if len(attn) == 2 * len(v):
|
255 |
+
v = torch.cat([v] * 2)
|
256 |
+
out = torch.einsum("h i j, h j d -> h i d", attn, v)
|
257 |
+
out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
|
258 |
+
return out
|
259 |
+
|
260 |
+
def aggregate_cross_attn_map(self, idx):
|
261 |
+
attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim)
|
262 |
+
B = attn_map.shape[0]
|
263 |
+
res = int(np.sqrt(attn_map.shape[-2]))
|
264 |
+
attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1])
|
265 |
+
image = attn_map[..., idx]
|
266 |
+
if isinstance(idx, list):
|
267 |
+
image = image.sum(-1)
|
268 |
+
image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
|
269 |
+
image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]
|
270 |
+
image = (image - image_min) / (image_max - image_min)
|
271 |
+
return image
|
272 |
+
|
273 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
274 |
+
"""
|
275 |
+
Attention forward function
|
276 |
+
"""
|
277 |
+
if is_cross:
|
278 |
+
# save cross attention map with res 16 * 16
|
279 |
+
if attn.shape[1] == 16 * 16:
|
280 |
+
self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1))
|
281 |
+
|
282 |
+
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
|
283 |
+
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
284 |
+
|
285 |
+
B = q.shape[0] // num_heads // 2
|
286 |
+
H = W = int(np.sqrt(q.shape[1]))
|
287 |
+
qu, qc = q.chunk(2)
|
288 |
+
ku, kc = k.chunk(2)
|
289 |
+
vu, vc = v.chunk(2)
|
290 |
+
attnu, attnc = attn.chunk(2)
|
291 |
+
|
292 |
+
out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
|
293 |
+
out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
|
294 |
+
|
295 |
+
if len(self.cross_attns) == 0:
|
296 |
+
self.self_attns_mask = None
|
297 |
+
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
|
298 |
+
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
|
299 |
+
else:
|
300 |
+
mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W)
|
301 |
+
mask_source = mask[-2] # (H, W)
|
302 |
+
res = int(np.sqrt(q.shape[1]))
|
303 |
+
self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten()
|
304 |
+
if self.mask_save_dir is not None:
|
305 |
+
H = W = int(np.sqrt(self.self_attns_mask.shape[0]))
|
306 |
+
mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0)
|
307 |
+
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png"))
|
308 |
+
out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
|
309 |
+
out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
|
310 |
+
|
311 |
+
if self.self_attns_mask is not None:
|
312 |
+
mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W)
|
313 |
+
mask_target = mask[-1] # (H, W)
|
314 |
+
res = int(np.sqrt(q.shape[1]))
|
315 |
+
spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1)
|
316 |
+
if self.mask_save_dir is not None:
|
317 |
+
H = W = int(np.sqrt(spatial_mask.shape[0]))
|
318 |
+
mask_image = spatial_mask.reshape(H, W).unsqueeze(0)
|
319 |
+
save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png"))
|
320 |
+
# binarize the mask
|
321 |
+
thres = self.thres
|
322 |
+
spatial_mask[spatial_mask >= thres] = 1
|
323 |
+
spatial_mask[spatial_mask < thres] = 0
|
324 |
+
out_u_target_fg, out_u_target_bg = out_u_target.chunk(2)
|
325 |
+
out_c_target_fg, out_c_target_bg = out_c_target.chunk(2)
|
326 |
+
|
327 |
+
out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask)
|
328 |
+
out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask)
|
329 |
+
|
330 |
+
# set self self-attention mask to None
|
331 |
+
self.self_attns_mask = None
|
332 |
+
|
333 |
+
out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
|
334 |
+
return out
|
masactrl/masactrl_processor.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from typing import Callable, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from diffusers.utils import deprecate, logging
|
9 |
+
from diffusers.utils.import_utils import is_xformers_available
|
10 |
+
from diffusers.models.attention import Attention
|
11 |
+
|
12 |
+
|
13 |
+
def register_attention_processor(
|
14 |
+
model: Optional[nn.Module] = None,
|
15 |
+
processor_type: str = "MasaCtrlProcessor",
|
16 |
+
**attn_args,
|
17 |
+
):
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
model: a unet model or a list of unet models
|
21 |
+
processor_type: the type of the processor
|
22 |
+
"""
|
23 |
+
if not isinstance(model, (list, tuple)):
|
24 |
+
model = [model]
|
25 |
+
|
26 |
+
if processor_type == "MasaCtrlProcessor":
|
27 |
+
processor = MasaCtrlProcessor(**attn_args)
|
28 |
+
else:
|
29 |
+
processor = AttnProcessor()
|
30 |
+
|
31 |
+
for m in model:
|
32 |
+
m.set_attn_processor(processor)
|
33 |
+
print(f"Model {m.__class__.__name__} is registered attention processor: {processor_type}")
|
34 |
+
|
35 |
+
|
36 |
+
class AttnProcessor:
|
37 |
+
r"""
|
38 |
+
Default processor for performing attention-related computations.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __call__(
|
42 |
+
self,
|
43 |
+
attn: Attention,
|
44 |
+
hidden_states: torch.Tensor,
|
45 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
46 |
+
attention_mask: Optional[torch.Tensor] = None,
|
47 |
+
temb: Optional[torch.Tensor] = None,
|
48 |
+
*args,
|
49 |
+
**kwargs,
|
50 |
+
) -> torch.Tensor:
|
51 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
52 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
53 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
54 |
+
|
55 |
+
residual = hidden_states
|
56 |
+
|
57 |
+
if attn.spatial_norm is not None:
|
58 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
59 |
+
|
60 |
+
input_ndim = hidden_states.ndim
|
61 |
+
|
62 |
+
if input_ndim == 4:
|
63 |
+
batch_size, channel, height, width = hidden_states.shape
|
64 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
65 |
+
|
66 |
+
batch_size, sequence_length, _ = (
|
67 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
68 |
+
)
|
69 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
70 |
+
|
71 |
+
if attn.group_norm is not None:
|
72 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
73 |
+
|
74 |
+
query = attn.to_q(hidden_states)
|
75 |
+
|
76 |
+
if encoder_hidden_states is None:
|
77 |
+
encoder_hidden_states = hidden_states
|
78 |
+
elif attn.norm_cross:
|
79 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
80 |
+
|
81 |
+
key = attn.to_k(encoder_hidden_states)
|
82 |
+
value = attn.to_v(encoder_hidden_states)
|
83 |
+
|
84 |
+
query = attn.head_to_batch_dim(query)
|
85 |
+
key = attn.head_to_batch_dim(key)
|
86 |
+
value = attn.head_to_batch_dim(value)
|
87 |
+
|
88 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
89 |
+
hidden_states = torch.bmm(attention_probs, value)
|
90 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
91 |
+
|
92 |
+
# linear proj
|
93 |
+
hidden_states = attn.to_out[0](hidden_states)
|
94 |
+
# dropout
|
95 |
+
hidden_states = attn.to_out[1](hidden_states)
|
96 |
+
|
97 |
+
if input_ndim == 4:
|
98 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
99 |
+
|
100 |
+
if attn.residual_connection:
|
101 |
+
hidden_states = hidden_states + residual
|
102 |
+
|
103 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
104 |
+
|
105 |
+
return hidden_states
|
106 |
+
|
107 |
+
|
108 |
+
class MasaCtrlProcessor(nn.Module):
|
109 |
+
"""
|
110 |
+
Mutual Self-attention Processor for diffusers library.
|
111 |
+
Note that the all attention layers should register the same processor.
|
112 |
+
"""
|
113 |
+
MODEL_TYPE = {
|
114 |
+
"SD": 16,
|
115 |
+
"SDXL": 70
|
116 |
+
}
|
117 |
+
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_layers=32, total_steps=50, model_type="SD"):
|
118 |
+
"""
|
119 |
+
Mutual self-attention control for Stable-Diffusion model
|
120 |
+
Args:
|
121 |
+
start_step: the step to start mutual self-attention control
|
122 |
+
start_layer: the layer to start mutual self-attention control
|
123 |
+
layer_idx: list of the layers to apply mutual self-attention control
|
124 |
+
step_idx: list the steps to apply mutual self-attention control
|
125 |
+
total_steps: the total number of steps, must be same to the denoising steps used in denoising scheduler
|
126 |
+
model_type: the model type, SD or SDXL
|
127 |
+
"""
|
128 |
+
super().__init__()
|
129 |
+
self.total_steps = total_steps
|
130 |
+
self.total_layers = self.MODEL_TYPE.get(model_type, 16)
|
131 |
+
self.start_step = start_step
|
132 |
+
self.start_layer = start_layer
|
133 |
+
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
|
134 |
+
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
|
135 |
+
print("MasaCtrl at denoising steps: ", self.step_idx)
|
136 |
+
print("MasaCtrl at U-Net layers: ", self.layer_idx)
|
137 |
+
|
138 |
+
self.cur_step = 0
|
139 |
+
self.cur_att_layer = 0
|
140 |
+
self.num_attn_layers = total_layers
|
141 |
+
|
142 |
+
def after_step(self):
|
143 |
+
pass
|
144 |
+
|
145 |
+
def __call__(
|
146 |
+
self,
|
147 |
+
attn: Attention,
|
148 |
+
hidden_states: torch.FloatTensor,
|
149 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
150 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
151 |
+
temb: Optional[torch.FloatTensor] = None,
|
152 |
+
scale: float = 1.0,
|
153 |
+
):
|
154 |
+
out = self.attn_forward(
|
155 |
+
attn,
|
156 |
+
hidden_states,
|
157 |
+
encoder_hidden_states,
|
158 |
+
attention_mask,
|
159 |
+
temb,
|
160 |
+
scale,
|
161 |
+
)
|
162 |
+
self.cur_att_layer += 1
|
163 |
+
if self.cur_att_layer == self.num_attn_layers:
|
164 |
+
self.cur_att_layer = 0
|
165 |
+
self.cur_step += 1
|
166 |
+
self.cur_step %= self.total_steps
|
167 |
+
# after step
|
168 |
+
self.after_step()
|
169 |
+
return out
|
170 |
+
|
171 |
+
def masactrl_forward(
|
172 |
+
self,
|
173 |
+
query,
|
174 |
+
key,
|
175 |
+
value,
|
176 |
+
):
|
177 |
+
"""
|
178 |
+
Rearrange the key and value for mutual self-attention control
|
179 |
+
"""
|
180 |
+
ku_src, ku_tgt, kc_src, kc_tgt = key.chunk(4)
|
181 |
+
vu_src, vu_tgt, vc_src, vc_tgt = value.chunk(4)
|
182 |
+
|
183 |
+
k_rearranged = torch.cat([ku_src, ku_src, kc_src, kc_src])
|
184 |
+
v_rearranged = torch.cat([vu_src, vu_src, vc_src, vc_src])
|
185 |
+
|
186 |
+
return query, k_rearranged, v_rearranged
|
187 |
+
|
188 |
+
def attn_forward(
|
189 |
+
self,
|
190 |
+
attn: Attention,
|
191 |
+
hidden_states: torch.Tensor,
|
192 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
193 |
+
attention_mask: Optional[torch.Tensor] = None,
|
194 |
+
temb: Optional[torch.Tensor] = None,
|
195 |
+
*args,
|
196 |
+
**kwargs,
|
197 |
+
):
|
198 |
+
cur_transformer_layer = self.cur_att_layer // 2
|
199 |
+
residual = hidden_states
|
200 |
+
|
201 |
+
is_cross = True if encoder_hidden_states is not None else False
|
202 |
+
|
203 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
204 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
205 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
206 |
+
|
207 |
+
if attn.spatial_norm is not None:
|
208 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
209 |
+
|
210 |
+
input_ndim = hidden_states.ndim
|
211 |
+
|
212 |
+
if input_ndim == 4:
|
213 |
+
batch_size, channel, height, width = hidden_states.shape
|
214 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
215 |
+
|
216 |
+
batch_size, sequence_length, _ = (
|
217 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
218 |
+
)
|
219 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
220 |
+
|
221 |
+
if attn.group_norm is not None:
|
222 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
223 |
+
|
224 |
+
query = attn.to_q(hidden_states, *args)
|
225 |
+
|
226 |
+
if encoder_hidden_states is None:
|
227 |
+
encoder_hidden_states = hidden_states
|
228 |
+
elif attn.norm_cross:
|
229 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
230 |
+
|
231 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
232 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
233 |
+
|
234 |
+
# mutual self-attention control
|
235 |
+
if not is_cross and self.cur_step in self.step_idx and cur_transformer_layer in self.layer_idx:
|
236 |
+
query, key, value = self.masactrl_forward(query, key, value)
|
237 |
+
|
238 |
+
query = attn.head_to_batch_dim(query)
|
239 |
+
key = attn.head_to_batch_dim(key)
|
240 |
+
value = attn.head_to_batch_dim(value)
|
241 |
+
|
242 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
243 |
+
hidden_states = torch.bmm(attention_probs, value)
|
244 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
245 |
+
|
246 |
+
# linear proj
|
247 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
248 |
+
# dropout
|
249 |
+
hidden_states = attn.to_out[1](hidden_states)
|
250 |
+
|
251 |
+
if input_ndim == 4:
|
252 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
253 |
+
|
254 |
+
if attn.residual_connection:
|
255 |
+
hidden_states = hidden_states + residual
|
256 |
+
|
257 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
258 |
+
|
259 |
+
return hidden_states
|
masactrl/masactrl_utils.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from typing import Optional, Union, Tuple, List, Callable, Dict
|
9 |
+
|
10 |
+
from torchvision.utils import save_image
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
|
14 |
+
class AttentionBase:
|
15 |
+
def __init__(self):
|
16 |
+
self.cur_step = 0
|
17 |
+
self.num_att_layers = -1
|
18 |
+
self.cur_att_layer = 0
|
19 |
+
|
20 |
+
def after_step(self):
|
21 |
+
pass
|
22 |
+
|
23 |
+
def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
24 |
+
out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
25 |
+
self.cur_att_layer += 1
|
26 |
+
if self.cur_att_layer == self.num_att_layers:
|
27 |
+
self.cur_att_layer = 0
|
28 |
+
self.cur_step += 1
|
29 |
+
# after step
|
30 |
+
self.after_step()
|
31 |
+
return out
|
32 |
+
|
33 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
34 |
+
out = torch.einsum('b i j, b j d -> b i d', attn, v)
|
35 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
|
36 |
+
return out
|
37 |
+
|
38 |
+
def reset(self):
|
39 |
+
self.cur_step = 0
|
40 |
+
self.cur_att_layer = 0
|
41 |
+
|
42 |
+
|
43 |
+
class AttentionStore(AttentionBase):
|
44 |
+
def __init__(self, res=[32], min_step=0, max_step=1000):
|
45 |
+
super().__init__()
|
46 |
+
self.res = res
|
47 |
+
self.min_step = min_step
|
48 |
+
self.max_step = max_step
|
49 |
+
self.valid_steps = 0
|
50 |
+
|
51 |
+
self.self_attns = [] # store the all attns
|
52 |
+
self.cross_attns = []
|
53 |
+
|
54 |
+
self.self_attns_step = [] # store the attns in each step
|
55 |
+
self.cross_attns_step = []
|
56 |
+
|
57 |
+
def after_step(self):
|
58 |
+
if self.cur_step > self.min_step and self.cur_step < self.max_step:
|
59 |
+
self.valid_steps += 1
|
60 |
+
if len(self.self_attns) == 0:
|
61 |
+
self.self_attns = self.self_attns_step
|
62 |
+
self.cross_attns = self.cross_attns_step
|
63 |
+
else:
|
64 |
+
for i in range(len(self.self_attns)):
|
65 |
+
self.self_attns[i] += self.self_attns_step[i]
|
66 |
+
self.cross_attns[i] += self.cross_attns_step[i]
|
67 |
+
self.self_attns_step.clear()
|
68 |
+
self.cross_attns_step.clear()
|
69 |
+
|
70 |
+
def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
71 |
+
if attn.shape[1] <= 64 ** 2: # avoid OOM
|
72 |
+
if is_cross:
|
73 |
+
self.cross_attns_step.append(attn)
|
74 |
+
else:
|
75 |
+
self.self_attns_step.append(attn)
|
76 |
+
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
|
77 |
+
|
78 |
+
|
79 |
+
def regiter_attention_editor_diffusers(model, editor: AttentionBase):
|
80 |
+
"""
|
81 |
+
Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
|
82 |
+
"""
|
83 |
+
def ca_forward(self, place_in_unet):
|
84 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
|
85 |
+
"""
|
86 |
+
The attention is similar to the original implementation of LDM CrossAttention class
|
87 |
+
except adding some modifications on the attention
|
88 |
+
"""
|
89 |
+
if encoder_hidden_states is not None:
|
90 |
+
context = encoder_hidden_states
|
91 |
+
if attention_mask is not None:
|
92 |
+
mask = attention_mask
|
93 |
+
|
94 |
+
to_out = self.to_out
|
95 |
+
if isinstance(to_out, nn.modules.container.ModuleList):
|
96 |
+
to_out = self.to_out[0]
|
97 |
+
else:
|
98 |
+
to_out = self.to_out
|
99 |
+
|
100 |
+
h = self.heads
|
101 |
+
q = self.to_q(x)
|
102 |
+
is_cross = context is not None
|
103 |
+
context = context if is_cross else x
|
104 |
+
k = self.to_k(context)
|
105 |
+
v = self.to_v(context)
|
106 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
107 |
+
|
108 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
109 |
+
|
110 |
+
if mask is not None:
|
111 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
112 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
113 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
114 |
+
mask = mask[:, None, :].repeat(h, 1, 1)
|
115 |
+
sim.masked_fill_(~mask, max_neg_value)
|
116 |
+
|
117 |
+
attn = sim.softmax(dim=-1)
|
118 |
+
# the only difference
|
119 |
+
out = editor(
|
120 |
+
q, k, v, sim, attn, is_cross, place_in_unet,
|
121 |
+
self.heads, scale=self.scale)
|
122 |
+
|
123 |
+
return to_out(out)
|
124 |
+
|
125 |
+
return forward
|
126 |
+
|
127 |
+
def register_editor(net, count, place_in_unet):
|
128 |
+
for name, subnet in net.named_children():
|
129 |
+
if net.__class__.__name__ == 'Attention': # spatial Transformer layer
|
130 |
+
net.forward = ca_forward(net, place_in_unet)
|
131 |
+
return count + 1
|
132 |
+
elif hasattr(net, 'children'):
|
133 |
+
count = register_editor(subnet, count, place_in_unet)
|
134 |
+
return count
|
135 |
+
|
136 |
+
cross_att_count = 0
|
137 |
+
for net_name, net in model.unet.named_children():
|
138 |
+
if "down" in net_name:
|
139 |
+
cross_att_count += register_editor(net, 0, "down")
|
140 |
+
elif "mid" in net_name:
|
141 |
+
cross_att_count += register_editor(net, 0, "mid")
|
142 |
+
elif "up" in net_name:
|
143 |
+
cross_att_count += register_editor(net, 0, "up")
|
144 |
+
editor.num_att_layers = cross_att_count
|
145 |
+
|
146 |
+
|
147 |
+
def regiter_attention_editor_ldm(model, editor: AttentionBase):
|
148 |
+
"""
|
149 |
+
Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
|
150 |
+
"""
|
151 |
+
def ca_forward(self, place_in_unet):
|
152 |
+
def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
|
153 |
+
"""
|
154 |
+
The attention is similar to the original implementation of LDM CrossAttention class
|
155 |
+
except adding some modifications on the attention
|
156 |
+
"""
|
157 |
+
if encoder_hidden_states is not None:
|
158 |
+
context = encoder_hidden_states
|
159 |
+
if attention_mask is not None:
|
160 |
+
mask = attention_mask
|
161 |
+
|
162 |
+
to_out = self.to_out
|
163 |
+
if isinstance(to_out, nn.modules.container.ModuleList):
|
164 |
+
to_out = self.to_out[0]
|
165 |
+
else:
|
166 |
+
to_out = self.to_out
|
167 |
+
|
168 |
+
h = self.heads
|
169 |
+
q = self.to_q(x)
|
170 |
+
is_cross = context is not None
|
171 |
+
context = context if is_cross else x
|
172 |
+
k = self.to_k(context)
|
173 |
+
v = self.to_v(context)
|
174 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
175 |
+
|
176 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
177 |
+
|
178 |
+
if mask is not None:
|
179 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
180 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
181 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
182 |
+
mask = mask[:, None, :].repeat(h, 1, 1)
|
183 |
+
sim.masked_fill_(~mask, max_neg_value)
|
184 |
+
|
185 |
+
attn = sim.softmax(dim=-1)
|
186 |
+
# the only difference
|
187 |
+
out = editor(
|
188 |
+
q, k, v, sim, attn, is_cross, place_in_unet,
|
189 |
+
self.heads, scale=self.scale)
|
190 |
+
|
191 |
+
return to_out(out)
|
192 |
+
|
193 |
+
return forward
|
194 |
+
|
195 |
+
def register_editor(net, count, place_in_unet):
|
196 |
+
for name, subnet in net.named_children():
|
197 |
+
if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
|
198 |
+
net.forward = ca_forward(net, place_in_unet)
|
199 |
+
return count + 1
|
200 |
+
elif hasattr(net, 'children'):
|
201 |
+
count = register_editor(subnet, count, place_in_unet)
|
202 |
+
return count
|
203 |
+
|
204 |
+
cross_att_count = 0
|
205 |
+
for net_name, net in model.model.diffusion_model.named_children():
|
206 |
+
if "input" in net_name:
|
207 |
+
cross_att_count += register_editor(net, 0, "input")
|
208 |
+
elif "middle" in net_name:
|
209 |
+
cross_att_count += register_editor(net, 0, "middle")
|
210 |
+
elif "output" in net_name:
|
211 |
+
cross_att_count += register_editor(net, 0, "output")
|
212 |
+
editor.num_att_layers = cross_att_count
|
masactrl_w_adapter/ddim.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""SAMPLING ONLY."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
8 |
+
extract_into_tensor
|
9 |
+
|
10 |
+
|
11 |
+
class DDIMSampler(object):
|
12 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.model = model
|
15 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
16 |
+
self.schedule = schedule
|
17 |
+
|
18 |
+
def register_buffer(self, name, attr):
|
19 |
+
if type(attr) == torch.Tensor:
|
20 |
+
if attr.device != torch.device("cuda"):
|
21 |
+
attr = attr.to(torch.device("cuda"))
|
22 |
+
setattr(self, name, attr)
|
23 |
+
|
24 |
+
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
25 |
+
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
26 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
|
27 |
+
alphas_cumprod = self.model.alphas_cumprod
|
28 |
+
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
29 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
30 |
+
|
31 |
+
self.register_buffer('betas', to_torch(self.model.betas))
|
32 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
33 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
34 |
+
|
35 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
36 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
37 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
38 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
39 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
40 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
41 |
+
|
42 |
+
# ddim sampling parameters
|
43 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
44 |
+
ddim_timesteps=self.ddim_timesteps,
|
45 |
+
eta=ddim_eta, verbose=verbose)
|
46 |
+
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
47 |
+
self.register_buffer('ddim_alphas', ddim_alphas)
|
48 |
+
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
49 |
+
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
50 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
51 |
+
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
52 |
+
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
53 |
+
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
54 |
+
|
55 |
+
@torch.no_grad()
|
56 |
+
def sample(self,
|
57 |
+
S,
|
58 |
+
batch_size,
|
59 |
+
shape,
|
60 |
+
conditioning=None,
|
61 |
+
callback=None,
|
62 |
+
normals_sequence=None,
|
63 |
+
img_callback=None,
|
64 |
+
quantize_x0=False,
|
65 |
+
eta=0.,
|
66 |
+
mask=None,
|
67 |
+
x0=None,
|
68 |
+
temperature=1.,
|
69 |
+
noise_dropout=0.,
|
70 |
+
score_corrector=None,
|
71 |
+
corrector_kwargs=None,
|
72 |
+
verbose=True,
|
73 |
+
x_T=None,
|
74 |
+
log_every_t=100,
|
75 |
+
unconditional_guidance_scale=1.,
|
76 |
+
unconditional_conditioning=None,
|
77 |
+
features_adapter=None,
|
78 |
+
append_to_context=None,
|
79 |
+
cond_tau=0.4,
|
80 |
+
style_cond_tau=1.0,
|
81 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
82 |
+
**kwargs
|
83 |
+
):
|
84 |
+
if conditioning is not None:
|
85 |
+
if isinstance(conditioning, dict):
|
86 |
+
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
87 |
+
if cbs != batch_size:
|
88 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
89 |
+
else:
|
90 |
+
if conditioning.shape[0] != batch_size:
|
91 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
92 |
+
|
93 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
94 |
+
# sampling
|
95 |
+
C, H, W = shape
|
96 |
+
size = (batch_size, C, H, W)
|
97 |
+
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
98 |
+
|
99 |
+
samples, intermediates = self.ddim_sampling(conditioning, size,
|
100 |
+
callback=callback,
|
101 |
+
img_callback=img_callback,
|
102 |
+
quantize_denoised=quantize_x0,
|
103 |
+
mask=mask, x0=x0,
|
104 |
+
ddim_use_original_steps=False,
|
105 |
+
noise_dropout=noise_dropout,
|
106 |
+
temperature=temperature,
|
107 |
+
score_corrector=score_corrector,
|
108 |
+
corrector_kwargs=corrector_kwargs,
|
109 |
+
x_T=x_T,
|
110 |
+
log_every_t=log_every_t,
|
111 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
112 |
+
unconditional_conditioning=unconditional_conditioning,
|
113 |
+
features_adapter=features_adapter,
|
114 |
+
append_to_context=append_to_context,
|
115 |
+
cond_tau=cond_tau,
|
116 |
+
style_cond_tau=style_cond_tau,
|
117 |
+
)
|
118 |
+
return samples, intermediates
|
119 |
+
|
120 |
+
@torch.no_grad()
|
121 |
+
def ddim_sampling(self, cond, shape,
|
122 |
+
x_T=None, ddim_use_original_steps=False,
|
123 |
+
callback=None, timesteps=None, quantize_denoised=False,
|
124 |
+
mask=None, x0=None, img_callback=None, log_every_t=100,
|
125 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
126 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
|
127 |
+
append_to_context=None, cond_tau=0.4, style_cond_tau=1.0):
|
128 |
+
device = self.model.betas.device
|
129 |
+
b = shape[0]
|
130 |
+
if x_T is None:
|
131 |
+
img = torch.randn(shape, device=device)
|
132 |
+
else:
|
133 |
+
img = x_T
|
134 |
+
|
135 |
+
if timesteps is None:
|
136 |
+
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
137 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
138 |
+
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
139 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
140 |
+
|
141 |
+
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
142 |
+
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
143 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
144 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
145 |
+
|
146 |
+
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
147 |
+
|
148 |
+
for i, step in enumerate(iterator):
|
149 |
+
index = total_steps - i - 1
|
150 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
151 |
+
|
152 |
+
if mask is not None:
|
153 |
+
assert x0 is not None
|
154 |
+
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
155 |
+
img = img_orig * mask + (1. - mask) * img
|
156 |
+
|
157 |
+
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
158 |
+
quantize_denoised=quantize_denoised, temperature=temperature,
|
159 |
+
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
160 |
+
corrector_kwargs=corrector_kwargs,
|
161 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
162 |
+
unconditional_conditioning=unconditional_conditioning,
|
163 |
+
features_adapter=None if index < int(
|
164 |
+
(1 - cond_tau) * total_steps) else features_adapter,
|
165 |
+
append_to_context=None if index < int(
|
166 |
+
(1 - style_cond_tau) * total_steps) else append_to_context,
|
167 |
+
)
|
168 |
+
img, pred_x0 = outs
|
169 |
+
if callback: callback(i)
|
170 |
+
if img_callback: img_callback(pred_x0, i)
|
171 |
+
|
172 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
173 |
+
intermediates['x_inter'].append(img)
|
174 |
+
intermediates['pred_x0'].append(pred_x0)
|
175 |
+
|
176 |
+
return img, intermediates
|
177 |
+
|
178 |
+
@torch.no_grad()
|
179 |
+
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
180 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
181 |
+
unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
|
182 |
+
append_to_context=None):
|
183 |
+
b, *_, device = *x.shape, x.device
|
184 |
+
|
185 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
186 |
+
if append_to_context is not None:
|
187 |
+
model_output = self.model.apply_model(x, t, torch.cat([c, append_to_context], dim=1),
|
188 |
+
features_adapter=features_adapter)
|
189 |
+
else:
|
190 |
+
model_output = self.model.apply_model(x, t, c, features_adapter=features_adapter)
|
191 |
+
else:
|
192 |
+
x_in = torch.cat([x] * 2)
|
193 |
+
t_in = torch.cat([t] * 2)
|
194 |
+
if isinstance(c, dict):
|
195 |
+
assert isinstance(unconditional_conditioning, dict)
|
196 |
+
c_in = dict()
|
197 |
+
for k in c:
|
198 |
+
if isinstance(c[k], list):
|
199 |
+
c_in[k] = [torch.cat([
|
200 |
+
unconditional_conditioning[k][i],
|
201 |
+
c[k][i]]) for i in range(len(c[k]))]
|
202 |
+
else:
|
203 |
+
c_in[k] = torch.cat([
|
204 |
+
unconditional_conditioning[k],
|
205 |
+
c[k]])
|
206 |
+
elif isinstance(c, list):
|
207 |
+
c_in = list()
|
208 |
+
assert isinstance(unconditional_conditioning, list)
|
209 |
+
for i in range(len(c)):
|
210 |
+
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
211 |
+
else:
|
212 |
+
if append_to_context is not None:
|
213 |
+
pad_len = append_to_context.size(1)
|
214 |
+
new_unconditional_conditioning = torch.cat(
|
215 |
+
[unconditional_conditioning, unconditional_conditioning[:, -pad_len:, :]], dim=1)
|
216 |
+
new_c = torch.cat([c, append_to_context], dim=1)
|
217 |
+
c_in = torch.cat([new_unconditional_conditioning, new_c])
|
218 |
+
else:
|
219 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
220 |
+
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
|
221 |
+
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
222 |
+
|
223 |
+
if self.model.parameterization == "v":
|
224 |
+
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
225 |
+
else:
|
226 |
+
e_t = model_output
|
227 |
+
|
228 |
+
if score_corrector is not None:
|
229 |
+
assert self.model.parameterization == "eps", 'not implemented'
|
230 |
+
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
231 |
+
|
232 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
233 |
+
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
234 |
+
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
235 |
+
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
236 |
+
# select parameters corresponding to the currently considered timestep
|
237 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
238 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
239 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
240 |
+
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
|
241 |
+
|
242 |
+
# current prediction for x_0
|
243 |
+
if self.model.parameterization != "v":
|
244 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
245 |
+
else:
|
246 |
+
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
247 |
+
|
248 |
+
if quantize_denoised:
|
249 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
250 |
+
# direction pointing to x_t
|
251 |
+
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
252 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
253 |
+
if noise_dropout > 0.:
|
254 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
255 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
256 |
+
return x_prev, pred_x0
|
257 |
+
|
258 |
+
@torch.no_grad()
|
259 |
+
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
260 |
+
# fast, but does not allow for exact reconstruction
|
261 |
+
# t serves as an index to gather the correct alphas
|
262 |
+
if use_original_steps:
|
263 |
+
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
264 |
+
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
265 |
+
else:
|
266 |
+
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
267 |
+
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
268 |
+
|
269 |
+
if noise is None:
|
270 |
+
noise = torch.randn_like(x0)
|
271 |
+
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
272 |
+
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
273 |
+
|
274 |
+
@torch.no_grad()
|
275 |
+
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
276 |
+
use_original_steps=False):
|
277 |
+
|
278 |
+
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
279 |
+
timesteps = timesteps[:t_start]
|
280 |
+
|
281 |
+
time_range = np.flip(timesteps)
|
282 |
+
total_steps = timesteps.shape[0]
|
283 |
+
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
284 |
+
|
285 |
+
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
286 |
+
x_dec = x_latent
|
287 |
+
for i, step in enumerate(iterator):
|
288 |
+
index = total_steps - i - 1
|
289 |
+
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
290 |
+
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
291 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
292 |
+
unconditional_conditioning=unconditional_conditioning)
|
293 |
+
return x_dec
|
294 |
+
|
295 |
+
def ddim_sampling_reverse(self,
|
296 |
+
num_steps=50,
|
297 |
+
x_0=None,
|
298 |
+
conditioning=None,
|
299 |
+
eta=0.,
|
300 |
+
verbose=False,
|
301 |
+
unconditional_guidance_scale=7.5,
|
302 |
+
unconditional_conditioning=None
|
303 |
+
):
|
304 |
+
"""
|
305 |
+
obtain the inverted x_T noisy image
|
306 |
+
"""
|
307 |
+
assert eta == 0., "eta should be 0. for deterministic sampling"
|
308 |
+
B = x_0.shape[0]
|
309 |
+
# scheduler
|
310 |
+
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta, verbose=verbose)
|
311 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas_prev", torch.tensor(1. - self.ddim_alphas_prev).sqrt())
|
312 |
+
# sampling
|
313 |
+
device = self.model.betas.device
|
314 |
+
|
315 |
+
intermediates = {"x_inter": [x_0], "pred_x0": []}
|
316 |
+
|
317 |
+
time_range = self.ddim_timesteps
|
318 |
+
print("selected steps for ddim inversion: ", time_range)
|
319 |
+
assert len(time_range) == num_steps, "time range should be same as num steps"
|
320 |
+
|
321 |
+
iterator = tqdm(time_range, desc='DDIM Inversion', total=num_steps)
|
322 |
+
x_t = x_0
|
323 |
+
for i, step in enumerate(iterator):
|
324 |
+
if i == 0:
|
325 |
+
step = 1
|
326 |
+
else:
|
327 |
+
step = time_range[i - 1]
|
328 |
+
ts = torch.full((B, ), step, device=device, dtype=torch.long)
|
329 |
+
outs = self.p_sample_ddim_reverse(x_t,
|
330 |
+
conditioning,
|
331 |
+
ts,
|
332 |
+
index=i,
|
333 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
334 |
+
unconditional_conditioning=unconditional_conditioning
|
335 |
+
)
|
336 |
+
x_t, pred_x0 = outs
|
337 |
+
intermediates["x_inter"].append(x_t)
|
338 |
+
intermediates["pred_x0"].append(pred_x0)
|
339 |
+
return x_t, intermediates
|
340 |
+
|
341 |
+
def p_sample_ddim_reverse(self,
|
342 |
+
x, c, t, index,
|
343 |
+
unconditional_guidance_scale=1.,
|
344 |
+
unconditional_conditioning=None):
|
345 |
+
B = x.shape[0]
|
346 |
+
device = x.device
|
347 |
+
|
348 |
+
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
349 |
+
e_t = self.model.apply_model(x, t, c)
|
350 |
+
else:
|
351 |
+
x_in = torch.cat([x] * 2)
|
352 |
+
t_in = torch.cat([t] * 2)
|
353 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
354 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
355 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
356 |
+
|
357 |
+
# scheduler parameters
|
358 |
+
alphas = self.ddim_alphas
|
359 |
+
alphas_prev = self.ddim_alphas_prev
|
360 |
+
sqrt_one_minus_alphas_prev = self.ddim_sqrt_one_minus_alphas_prev
|
361 |
+
|
362 |
+
# select parameters corresponding to the currently considered timestep
|
363 |
+
alpha_cumprod_next = torch.full((B, 1, 1, 1), alphas[index], device=device)
|
364 |
+
alpha_cumprod_t = torch.full((B, 1, 1, 1), alphas_prev[index], device=device)
|
365 |
+
sqrt_one_minus_alpha_t = torch.full((B, 1, 1, 1), sqrt_one_minus_alphas_prev[index], device=device)
|
366 |
+
|
367 |
+
# current prediction for x_0
|
368 |
+
pred_x0 = (x - sqrt_one_minus_alpha_t * e_t) / alpha_cumprod_t.sqrt()
|
369 |
+
|
370 |
+
# direction pointing to x_t
|
371 |
+
dir_xt = (1. - alpha_cumprod_next).sqrt() * e_t
|
372 |
+
|
373 |
+
x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
|
374 |
+
|
375 |
+
return x_next, pred_x0
|
masactrl_w_adapter/masactrl_w_adapter.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from basicsr.utils import tensor2img
|
7 |
+
from pytorch_lightning import seed_everything
|
8 |
+
from torch import autocast
|
9 |
+
from torchvision.io import read_image
|
10 |
+
|
11 |
+
from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
|
12 |
+
from ldm.modules.extra_condition import api
|
13 |
+
from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
|
14 |
+
from ldm.util import fix_cond_shapes
|
15 |
+
|
16 |
+
# for masactrl
|
17 |
+
from masactrl.masactrl_utils import regiter_attention_editor_ldm
|
18 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
19 |
+
from masactrl.masactrl import MutualSelfAttentionControlMask
|
20 |
+
from masactrl.masactrl import MutualSelfAttentionControlMaskAuto
|
21 |
+
|
22 |
+
torch.set_grad_enabled(False)
|
23 |
+
|
24 |
+
|
25 |
+
def main():
|
26 |
+
supported_cond = [e.name for e in ExtraCondition]
|
27 |
+
parser = get_base_argument_parser()
|
28 |
+
parser.add_argument(
|
29 |
+
'--which_cond',
|
30 |
+
type=str,
|
31 |
+
required=True,
|
32 |
+
choices=supported_cond,
|
33 |
+
help='which condition modality you want to test',
|
34 |
+
)
|
35 |
+
# [MasaCtrl added] reference cond path
|
36 |
+
parser.add_argument(
|
37 |
+
"--cond_path_src",
|
38 |
+
type=str,
|
39 |
+
default=None,
|
40 |
+
help="the condition image path to synthesize the source image",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--prompt_src",
|
44 |
+
type=str,
|
45 |
+
default=None,
|
46 |
+
help="the prompt to synthesize the source image",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"--src_img_path",
|
50 |
+
type=str,
|
51 |
+
default=None,
|
52 |
+
help="the input real source image path"
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--start_code_path",
|
56 |
+
type=str,
|
57 |
+
default=None,
|
58 |
+
help="the inverted start code path to synthesize the source image",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--masa_step",
|
62 |
+
type=int,
|
63 |
+
default=4,
|
64 |
+
help="the starting step for MasaCtrl",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--masa_layer",
|
68 |
+
type=int,
|
69 |
+
default=10,
|
70 |
+
help="the starting layer for MasaCtrl",
|
71 |
+
)
|
72 |
+
|
73 |
+
opt = parser.parse_args()
|
74 |
+
which_cond = opt.which_cond
|
75 |
+
if opt.outdir is None:
|
76 |
+
opt.outdir = f'outputs/test-{which_cond}'
|
77 |
+
os.makedirs(opt.outdir, exist_ok=True)
|
78 |
+
if opt.resize_short_edge is None:
|
79 |
+
print(f"you don't specify the resize_shot_edge, so the maximum resolution is set to {opt.max_resolution}")
|
80 |
+
opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
81 |
+
|
82 |
+
if os.path.isdir(opt.cond_path): # for conditioning image folder
|
83 |
+
image_paths = [os.path.join(opt.cond_path, f) for f in os.listdir(opt.cond_path)]
|
84 |
+
else:
|
85 |
+
image_paths = [opt.cond_path]
|
86 |
+
print(image_paths)
|
87 |
+
|
88 |
+
# prepare models
|
89 |
+
sd_model, sampler = get_sd_models(opt)
|
90 |
+
adapter = get_adapters(opt, getattr(ExtraCondition, which_cond))
|
91 |
+
cond_model = None
|
92 |
+
if opt.cond_inp_type == 'image':
|
93 |
+
cond_model = get_cond_model(opt, getattr(ExtraCondition, which_cond))
|
94 |
+
|
95 |
+
process_cond_module = getattr(api, f'get_cond_{which_cond}')
|
96 |
+
|
97 |
+
# [MasaCtrl added] default STEP and LAYER params for MasaCtrl
|
98 |
+
STEP = opt.masa_step if opt.masa_step is not None else 4
|
99 |
+
LAYER = opt.masa_layer if opt.masa_layer is not None else 10
|
100 |
+
|
101 |
+
# inference
|
102 |
+
with torch.inference_mode(), \
|
103 |
+
sd_model.ema_scope(), \
|
104 |
+
autocast('cuda'):
|
105 |
+
for test_idx, cond_path in enumerate(image_paths):
|
106 |
+
seed_everything(opt.seed)
|
107 |
+
for v_idx in range(opt.n_samples):
|
108 |
+
# seed_everything(opt.seed+v_idx+test_idx)
|
109 |
+
if opt.cond_path_src:
|
110 |
+
cond_src = process_cond_module(opt, opt.cond_path_src, opt.cond_inp_type, cond_model)
|
111 |
+
cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)
|
112 |
+
|
113 |
+
base_count = len(os.listdir(opt.outdir)) // 2
|
114 |
+
cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))
|
115 |
+
if opt.cond_path_src:
|
116 |
+
cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}_src.png'), tensor2img(cond_src))
|
117 |
+
|
118 |
+
adapter_features, append_to_context = get_adapter_feature(cond, adapter)
|
119 |
+
if opt.cond_path_src:
|
120 |
+
adapter_features_src, append_to_context_src = get_adapter_feature(cond_src, adapter)
|
121 |
+
|
122 |
+
if opt.cond_path_src:
|
123 |
+
print("using reference guidance to synthesize image")
|
124 |
+
adapter_features = [torch.cat([adapter_features_src[i], adapter_features[i]]) for i in range(len(adapter_features))]
|
125 |
+
else:
|
126 |
+
adapter_features = [torch.cat([torch.zeros_like(feats), feats]) for feats in adapter_features]
|
127 |
+
|
128 |
+
if opt.scale > 1.:
|
129 |
+
adapter_features = [torch.cat([feats] * 2) for feats in adapter_features]
|
130 |
+
|
131 |
+
# prepare the batch prompts
|
132 |
+
if opt.prompt_src is not None:
|
133 |
+
prompts = [opt.prompt_src, opt.prompt]
|
134 |
+
else:
|
135 |
+
prompts = [opt.prompt] * 2
|
136 |
+
print("promts: ", prompts)
|
137 |
+
# get text embedding
|
138 |
+
c = sd_model.get_learned_conditioning(prompts)
|
139 |
+
if opt.scale != 1.0:
|
140 |
+
uc = sd_model.get_learned_conditioning([""] * len(prompts))
|
141 |
+
else:
|
142 |
+
uc = None
|
143 |
+
c, uc = fix_cond_shapes(sd_model, c, uc)
|
144 |
+
|
145 |
+
if not hasattr(opt, 'H'):
|
146 |
+
opt.H = 512
|
147 |
+
opt.W = 512
|
148 |
+
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
149 |
+
if opt.src_img_path: # perform ddim inversion
|
150 |
+
|
151 |
+
src_img = read_image(opt.src_img_path)
|
152 |
+
src_img = src_img.float() / 255. # input normalized image [0, 1]
|
153 |
+
src_img = src_img * 2 - 1
|
154 |
+
if src_img.dim() == 3:
|
155 |
+
src_img = src_img.unsqueeze(0)
|
156 |
+
src_img = F.interpolate(src_img, (opt.H, opt.W))
|
157 |
+
src_img = src_img.to(opt.device)
|
158 |
+
# obtain initial latent
|
159 |
+
encoder_posterior = sd_model.encode_first_stage(src_img)
|
160 |
+
src_x_0 = sd_model.get_first_stage_encoding(encoder_posterior)
|
161 |
+
start_code, latents_dict = sampler.ddim_sampling_reverse(
|
162 |
+
num_steps=opt.steps,
|
163 |
+
x_0=src_x_0,
|
164 |
+
conditioning=uc[:1], # you may change here during inversion
|
165 |
+
unconditional_guidance_scale=opt.scale,
|
166 |
+
unconditional_conditioning=uc[:1],
|
167 |
+
)
|
168 |
+
torch.save(
|
169 |
+
{
|
170 |
+
"start_code": start_code
|
171 |
+
},
|
172 |
+
os.path.join(opt.outdir, "start_code.pth"),
|
173 |
+
)
|
174 |
+
elif opt.start_code_path:
|
175 |
+
# load the inverted start code
|
176 |
+
start_code_dict = torch.load(opt.start_code_path)
|
177 |
+
start_code = start_code_dict.get("start_code").to(opt.device)
|
178 |
+
else:
|
179 |
+
start_code = torch.randn([1, *shape], device=opt.device)
|
180 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
181 |
+
|
182 |
+
# hijack the attention module
|
183 |
+
editor = MutualSelfAttentionControl(STEP, LAYER)
|
184 |
+
regiter_attention_editor_ldm(sd_model, editor)
|
185 |
+
|
186 |
+
samples_latents, _ = sampler.sample(
|
187 |
+
S=opt.steps,
|
188 |
+
conditioning=c,
|
189 |
+
batch_size=len(prompts),
|
190 |
+
shape=shape,
|
191 |
+
verbose=False,
|
192 |
+
unconditional_guidance_scale=opt.scale,
|
193 |
+
unconditional_conditioning=uc,
|
194 |
+
x_T=start_code,
|
195 |
+
features_adapter=adapter_features,
|
196 |
+
append_to_context=append_to_context,
|
197 |
+
cond_tau=opt.cond_tau,
|
198 |
+
style_cond_tau=opt.style_cond_tau,
|
199 |
+
)
|
200 |
+
|
201 |
+
x_samples = sd_model.decode_first_stage(samples_latents)
|
202 |
+
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
203 |
+
|
204 |
+
cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_all_result.png'), tensor2img(x_samples))
|
205 |
+
# save the prompts and seed
|
206 |
+
with open(os.path.join(opt.outdir, "log.txt"), "w") as f:
|
207 |
+
for prom in prompts:
|
208 |
+
f.write(prom)
|
209 |
+
f.write("\n")
|
210 |
+
f.write(f"seed: {opt.seed}")
|
211 |
+
for i in range(len(x_samples)):
|
212 |
+
base_count += 1
|
213 |
+
cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(x_samples[i]))
|
214 |
+
|
215 |
+
|
216 |
+
if __name__ == '__main__':
|
217 |
+
main()
|
playground.ipynb
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import os\n",
|
17 |
+
"import torch\n",
|
18 |
+
"import torch.nn as nn\n",
|
19 |
+
"import torch.nn.functional as F\n",
|
20 |
+
"\n",
|
21 |
+
"import numpy as np\n",
|
22 |
+
"\n",
|
23 |
+
"from tqdm import tqdm\n",
|
24 |
+
"from einops import rearrange, repeat\n",
|
25 |
+
"from omegaconf import OmegaConf\n",
|
26 |
+
"\n",
|
27 |
+
"from diffusers import DDIMScheduler\n",
|
28 |
+
"\n",
|
29 |
+
"from masactrl.diffuser_utils import MasaCtrlPipeline\n",
|
30 |
+
"from masactrl.masactrl_utils import AttentionBase\n",
|
31 |
+
"from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
|
32 |
+
"\n",
|
33 |
+
"from torchvision.utils import save_image\n",
|
34 |
+
"from torchvision.io import read_image\n",
|
35 |
+
"from pytorch_lightning import seed_everything\n",
|
36 |
+
"\n",
|
37 |
+
"torch.cuda.set_device(0) # set the GPU device"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "markdown",
|
42 |
+
"metadata": {},
|
43 |
+
"source": [
|
44 |
+
"#### Model Construction"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": null,
|
50 |
+
"metadata": {},
|
51 |
+
"outputs": [],
|
52 |
+
"source": [
|
53 |
+
"# Note that you may add your Hugging Face token to get access to the models\n",
|
54 |
+
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
55 |
+
"model_path = \"xyn-ai/anything-v4.0\"\n",
|
56 |
+
"# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
|
57 |
+
"scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
|
58 |
+
"model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler, cross_attention_kwargs={\"scale\": 0.5}).to(device)"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "markdown",
|
63 |
+
"metadata": {},
|
64 |
+
"source": [
|
65 |
+
"#### Consistent synthesis with MasaCtrl"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"execution_count": null,
|
71 |
+
"metadata": {},
|
72 |
+
"outputs": [],
|
73 |
+
"source": [
|
74 |
+
"from masactrl.masactrl import MutualSelfAttentionControl\n",
|
75 |
+
"\n",
|
76 |
+
"\n",
|
77 |
+
"seed = 42\n",
|
78 |
+
"seed_everything(seed)\n",
|
79 |
+
"\n",
|
80 |
+
"out_dir = \"./workdir/masactrl_exp/\"\n",
|
81 |
+
"os.makedirs(out_dir, exist_ok=True)\n",
|
82 |
+
"sample_count = len(os.listdir(out_dir))\n",
|
83 |
+
"out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
|
84 |
+
"os.makedirs(out_dir, exist_ok=True)\n",
|
85 |
+
"\n",
|
86 |
+
"prompts = [\n",
|
87 |
+
" \"1boy, casual, outdoors, sitting\", # source prompt\n",
|
88 |
+
" \"1boy, casual, outdoors, standing\" # target prompt\n",
|
89 |
+
"]\n",
|
90 |
+
"\n",
|
91 |
+
"# initialize the noise map\n",
|
92 |
+
"start_code = torch.randn([1, 4, 64, 64], device=device)\n",
|
93 |
+
"start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
|
94 |
+
"\n",
|
95 |
+
"# inference the synthesized image without MasaCtrl\n",
|
96 |
+
"editor = AttentionBase()\n",
|
97 |
+
"regiter_attention_editor_diffusers(model, editor)\n",
|
98 |
+
"image_ori = model(prompts, latents=start_code, guidance_scale=7.5)\n",
|
99 |
+
"\n",
|
100 |
+
"# inference the synthesized image with MasaCtrl\n",
|
101 |
+
"STEP = 4\n",
|
102 |
+
"LAYPER = 10\n",
|
103 |
+
"\n",
|
104 |
+
"# hijack the attention module\n",
|
105 |
+
"editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
|
106 |
+
"regiter_attention_editor_diffusers(model, editor)\n",
|
107 |
+
"\n",
|
108 |
+
"# inference the synthesized image\n",
|
109 |
+
"image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)[-1:]\n",
|
110 |
+
"\n",
|
111 |
+
"# save the synthesized image\n",
|
112 |
+
"out_image = torch.cat([image_ori, image_masactrl], dim=0)\n",
|
113 |
+
"save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
|
114 |
+
"save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
|
115 |
+
"save_image(out_image[1], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
|
116 |
+
"save_image(out_image[2], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
|
117 |
+
"\n",
|
118 |
+
"print(\"Syntheiszed images are saved in\", out_dir)"
|
119 |
+
]
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"metadata": {
|
123 |
+
"kernelspec": {
|
124 |
+
"display_name": "Python 3.8.5 ('ldm')",
|
125 |
+
"language": "python",
|
126 |
+
"name": "python3"
|
127 |
+
},
|
128 |
+
"language_info": {
|
129 |
+
"codemirror_mode": {
|
130 |
+
"name": "ipython",
|
131 |
+
"version": 3
|
132 |
+
},
|
133 |
+
"file_extension": ".py",
|
134 |
+
"mimetype": "text/x-python",
|
135 |
+
"name": "python",
|
136 |
+
"nbconvert_exporter": "python",
|
137 |
+
"pygments_lexer": "ipython3",
|
138 |
+
"version": "3.8.5"
|
139 |
+
},
|
140 |
+
"orig_nbformat": 4,
|
141 |
+
"vscode": {
|
142 |
+
"interpreter": {
|
143 |
+
"hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
|
144 |
+
}
|
145 |
+
}
|
146 |
+
},
|
147 |
+
"nbformat": 4,
|
148 |
+
"nbformat_minor": 2
|
149 |
+
}
|
playground_real.ipynb
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": null,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import os\n",
|
17 |
+
"import torch\n",
|
18 |
+
"import torch.nn as nn\n",
|
19 |
+
"import torch.nn.functional as F\n",
|
20 |
+
"\n",
|
21 |
+
"import numpy as np\n",
|
22 |
+
"\n",
|
23 |
+
"from tqdm import tqdm\n",
|
24 |
+
"from einops import rearrange, repeat\n",
|
25 |
+
"from omegaconf import OmegaConf\n",
|
26 |
+
"\n",
|
27 |
+
"from diffusers import DDIMScheduler\n",
|
28 |
+
"\n",
|
29 |
+
"from masactrl.diffuser_utils import MasaCtrlPipeline\n",
|
30 |
+
"from masactrl.masactrl_utils import AttentionBase\n",
|
31 |
+
"from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
|
32 |
+
"\n",
|
33 |
+
"from torchvision.utils import save_image\n",
|
34 |
+
"from torchvision.io import read_image\n",
|
35 |
+
"from pytorch_lightning import seed_everything\n",
|
36 |
+
"\n",
|
37 |
+
"torch.cuda.set_device(0) # set the GPU device"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "markdown",
|
42 |
+
"metadata": {},
|
43 |
+
"source": [
|
44 |
+
"#### Model Construction"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": null,
|
50 |
+
"metadata": {},
|
51 |
+
"outputs": [],
|
52 |
+
"source": [
|
53 |
+
"# Note that you may add your Hugging Face token to get access to the models\n",
|
54 |
+
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
|
55 |
+
"# model_path = \"xyn-ai/anything-v4.0\"\n",
|
56 |
+
"model_path = \"CompVis/stable-diffusion-v1-4\"\n",
|
57 |
+
"# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
|
58 |
+
"scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
|
59 |
+
"model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "markdown",
|
64 |
+
"metadata": {},
|
65 |
+
"source": [
|
66 |
+
"#### Real editing with MasaCtrl"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": null,
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"from masactrl.masactrl import MutualSelfAttentionControl\n",
|
76 |
+
"from torchvision.io import read_image\n",
|
77 |
+
"\n",
|
78 |
+
"\n",
|
79 |
+
"def load_image(image_path, device):\n",
|
80 |
+
" image = read_image(image_path)\n",
|
81 |
+
" image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]\n",
|
82 |
+
" image = F.interpolate(image, (512, 512))\n",
|
83 |
+
" image = image.to(device)\n",
|
84 |
+
" return image\n",
|
85 |
+
"\n",
|
86 |
+
"\n",
|
87 |
+
"seed = 42\n",
|
88 |
+
"seed_everything(seed)\n",
|
89 |
+
"\n",
|
90 |
+
"out_dir = \"./workdir/masactrl_real_exp/\"\n",
|
91 |
+
"os.makedirs(out_dir, exist_ok=True)\n",
|
92 |
+
"sample_count = len(os.listdir(out_dir))\n",
|
93 |
+
"out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
|
94 |
+
"os.makedirs(out_dir, exist_ok=True)\n",
|
95 |
+
"\n",
|
96 |
+
"# source image\n",
|
97 |
+
"SOURCE_IMAGE_PATH = \"./gradio_app/images/corgi.jpg\"\n",
|
98 |
+
"source_image = load_image(SOURCE_IMAGE_PATH, device)\n",
|
99 |
+
"\n",
|
100 |
+
"source_prompt = \"\"\n",
|
101 |
+
"target_prompt = \"a photo of a running corgi\"\n",
|
102 |
+
"prompts = [source_prompt, target_prompt]\n",
|
103 |
+
"\n",
|
104 |
+
"# invert the source image\n",
|
105 |
+
"start_code, latents_list = model.invert(source_image,\n",
|
106 |
+
" source_prompt,\n",
|
107 |
+
" guidance_scale=7.5,\n",
|
108 |
+
" num_inference_steps=50,\n",
|
109 |
+
" return_intermediates=True)\n",
|
110 |
+
"start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
|
111 |
+
"\n",
|
112 |
+
"# results of direct synthesis\n",
|
113 |
+
"editor = AttentionBase()\n",
|
114 |
+
"regiter_attention_editor_diffusers(model, editor)\n",
|
115 |
+
"image_fixed = model([target_prompt],\n",
|
116 |
+
" latents=start_code[-1:],\n",
|
117 |
+
" num_inference_steps=50,\n",
|
118 |
+
" guidance_scale=7.5)\n",
|
119 |
+
"\n",
|
120 |
+
"# inference the synthesized image with MasaCtrl\n",
|
121 |
+
"STEP = 4\n",
|
122 |
+
"LAYPER = 10\n",
|
123 |
+
"\n",
|
124 |
+
"# hijack the attention module\n",
|
125 |
+
"editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
|
126 |
+
"regiter_attention_editor_diffusers(model, editor)\n",
|
127 |
+
"\n",
|
128 |
+
"# inference the synthesized image\n",
|
129 |
+
"image_masactrl = model(prompts,\n",
|
130 |
+
" latents=start_code,\n",
|
131 |
+
" guidance_scale=7.5)\n",
|
132 |
+
"# Note: querying the inversion intermediate features latents_list\n",
|
133 |
+
"# may obtain better reconstruction and editing results\n",
|
134 |
+
"# image_masactrl = model(prompts,\n",
|
135 |
+
"# latents=start_code,\n",
|
136 |
+
"# guidance_scale=7.5,\n",
|
137 |
+
"# ref_intermediate_latents=latents_list)\n",
|
138 |
+
"\n",
|
139 |
+
"# save the synthesized image\n",
|
140 |
+
"out_image = torch.cat([source_image * 0.5 + 0.5,\n",
|
141 |
+
" image_masactrl[0:1],\n",
|
142 |
+
" image_fixed,\n",
|
143 |
+
" image_masactrl[-1:]], dim=0)\n",
|
144 |
+
"save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
|
145 |
+
"save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
|
146 |
+
"save_image(out_image[1], os.path.join(out_dir, f\"reconstructed_source_step{STEP}_layer{LAYPER}.png\"))\n",
|
147 |
+
"save_image(out_image[2], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
|
148 |
+
"save_image(out_image[3], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
|
149 |
+
"\n",
|
150 |
+
"print(\"Syntheiszed images are saved in\", out_dir)"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [],
|
158 |
+
"source": []
|
159 |
+
}
|
160 |
+
],
|
161 |
+
"metadata": {
|
162 |
+
"kernelspec": {
|
163 |
+
"display_name": "Python 3.8.5 ('ldm')",
|
164 |
+
"language": "python",
|
165 |
+
"name": "python3"
|
166 |
+
},
|
167 |
+
"language_info": {
|
168 |
+
"codemirror_mode": {
|
169 |
+
"name": "ipython",
|
170 |
+
"version": 3
|
171 |
+
},
|
172 |
+
"file_extension": ".py",
|
173 |
+
"mimetype": "text/x-python",
|
174 |
+
"name": "python",
|
175 |
+
"nbconvert_exporter": "python",
|
176 |
+
"pygments_lexer": "ipython3",
|
177 |
+
"version": "3.8.5"
|
178 |
+
},
|
179 |
+
"orig_nbformat": 4,
|
180 |
+
"vscode": {
|
181 |
+
"interpreter": {
|
182 |
+
"hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
|
183 |
+
}
|
184 |
+
}
|
185 |
+
},
|
186 |
+
"nbformat": 4,
|
187 |
+
"nbformat_minor": 2
|
188 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#diffusers==0.15.0
|
2 |
+
diffusers
|
3 |
+
transformers
|
4 |
+
opencv-python
|
5 |
+
einops
|
6 |
+
omegaconf
|
7 |
+
pytorch_lightning
|
8 |
+
torch
|
9 |
+
torchvision
|
10 |
+
gradio
|
11 |
+
httpx[socks]
|
12 |
+
huggingface_hub==0.25.0
|
13 |
+
moviepy
|
run_synthesis_genshin_impact_xl.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
|
3 |
+
--prompt1 "A portrait of an old man, facing camera, best quality" \
|
4 |
+
--prompt2 "A portrait of an old man, facing camera, smiling, best quality" --guidance_scale 3.5
|
5 |
+
|
6 |
+
python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
|
7 |
+
--prompt1 "solo,ZHONGLI\(genshin impact\),1boy,highres," \
|
8 |
+
--prompt2 "solo,ZHONGLI drink tea use chinese cup \(genshin impact\),1boy,highres," --guidance_scale 5
|
9 |
+
|
10 |
+
from IPython import display
|
11 |
+
|
12 |
+
display.Image("masactrl_exp/solo_zhongli_drink_tea_use_chinese_cup___genshin_impact___1boy_highres_/sample_0/masactrl_step4_layer64.png", width=512, height=512)
|
13 |
+
display.Image("masactrl_exp/solo_zhongli_drink_tea_use_chinese_cup___genshin_impact___1boy_highres_/sample_1/source_step4_layer74.png", width=512, height=512)
|
14 |
+
'''
|
15 |
+
|
16 |
+
import os
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from tqdm import tqdm
|
24 |
+
from einops import rearrange, repeat
|
25 |
+
from omegaconf import OmegaConf
|
26 |
+
|
27 |
+
from diffusers import DDIMScheduler, DiffusionPipeline
|
28 |
+
|
29 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
30 |
+
from masactrl.masactrl_utils import AttentionBase
|
31 |
+
from masactrl.masactrl_utils import regiter_attention_editor_diffusers
|
32 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
33 |
+
|
34 |
+
from torchvision.utils import save_image
|
35 |
+
from torchvision.io import read_image
|
36 |
+
from pytorch_lightning import seed_everything
|
37 |
+
|
38 |
+
import argparse
|
39 |
+
import re
|
40 |
+
|
41 |
+
torch.cuda.set_device(0) # set the GPU device
|
42 |
+
|
43 |
+
# Note that you may add your Hugging Face token to get access to the models
|
44 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
45 |
+
|
46 |
+
def pathify(s):
|
47 |
+
# Convert to lowercase and replace non-alphanumeric characters with underscores
|
48 |
+
return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())
|
49 |
+
|
50 |
+
def consistent_synthesis(args):
|
51 |
+
seed = 42
|
52 |
+
seed_everything(seed)
|
53 |
+
|
54 |
+
# Create the output directory based on prompt2
|
55 |
+
out_dir_ori = os.path.join("masactrl_exp", pathify(args.prompt2))
|
56 |
+
os.makedirs(out_dir_ori, exist_ok=True)
|
57 |
+
|
58 |
+
prompts = [
|
59 |
+
args.prompt1,
|
60 |
+
args.prompt2,
|
61 |
+
]
|
62 |
+
|
63 |
+
# inference the synthesized image with MasaCtrl
|
64 |
+
# TODO: note that the hyper paramerter of MasaCtrl for SDXL may be not optimal
|
65 |
+
STEP = 4
|
66 |
+
#LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
|
67 |
+
#LAYER_LIST = [64, 74, 84, 94] # run the synthesis with MasaCtrl at three different layer configs
|
68 |
+
LAYER_LIST = [64, 74] # run the synthesis with MasaCtrl at three different layer configs
|
69 |
+
|
70 |
+
# initialize the noise map
|
71 |
+
start_code = torch.randn([1, 4, 128, 128], device=device)
|
72 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
73 |
+
|
74 |
+
# Load the model
|
75 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
76 |
+
model = DiffusionPipeline.from_pretrained(args.model_path, scheduler=scheduler).to(device)
|
77 |
+
|
78 |
+
# inference the synthesized image without MasaCtrl
|
79 |
+
editor = AttentionBase()
|
80 |
+
regiter_attention_editor_diffusers(model, editor)
|
81 |
+
image_ori = model(prompts, latents=start_code, guidance_scale=args.guidance_scale).images
|
82 |
+
|
83 |
+
for LAYER in LAYER_LIST:
|
84 |
+
# hijack the attention module
|
85 |
+
editor = MutualSelfAttentionControl(STEP, LAYER, model_type="SDXL")
|
86 |
+
regiter_attention_editor_diffusers(model, editor)
|
87 |
+
|
88 |
+
# inference the synthesized image
|
89 |
+
image_masactrl = model(prompts, latents=start_code, guidance_scale=args.guidance_scale).images
|
90 |
+
|
91 |
+
sample_count = len(os.listdir(out_dir_ori))
|
92 |
+
out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
|
93 |
+
os.makedirs(out_dir, exist_ok=True)
|
94 |
+
image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
|
95 |
+
image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
|
96 |
+
image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
|
97 |
+
with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
|
98 |
+
for p in prompts:
|
99 |
+
f.write(p + "\n")
|
100 |
+
f.write(f"seed: {seed}\n")
|
101 |
+
print("Syntheiszed images are saved in", out_dir)
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
parser = argparse.ArgumentParser(description="Consistent Synthesis with MasaCtrl")
|
105 |
+
parser.add_argument("--model_path", type=str, default="svjack/GenshinImpact_XL_Base", help="Path to the model")
|
106 |
+
parser.add_argument("--prompt1", type=str, default="A portrait of an old man, facing camera, best quality", help="First prompt")
|
107 |
+
parser.add_argument("--prompt2", type=str, default="A portrait of an old man, facing camera, smiling, best quality", help="Second prompt")
|
108 |
+
parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale")
|
109 |
+
parser.add_argument("--out_dir", type=str, default=None, help="Output directory")
|
110 |
+
|
111 |
+
args = parser.parse_args()
|
112 |
+
|
113 |
+
# If out_dir is not provided, use the default path based on prompt2
|
114 |
+
if args.out_dir is None:
|
115 |
+
args.out_dir = os.path.join("masactrl_exp", pathify(args.prompt2))
|
116 |
+
|
117 |
+
consistent_synthesis(args)
|
run_synthesis_genshin_impact_xl_app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from diffusers import DDIMScheduler, DiffusionPipeline
|
4 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
5 |
+
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
|
6 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
7 |
+
from pytorch_lightning import seed_everything
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
|
11 |
+
# 初始化设备和模型
|
12 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
13 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
14 |
+
model = DiffusionPipeline.from_pretrained("svjack/GenshinImpact_XL_Base", scheduler=scheduler).to(device)
|
15 |
+
|
16 |
+
def pathify(s):
|
17 |
+
return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())
|
18 |
+
|
19 |
+
def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
|
20 |
+
seed_everything(seed)
|
21 |
+
|
22 |
+
# 创建输出目录
|
23 |
+
out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2))
|
24 |
+
os.makedirs(out_dir_ori, exist_ok=True)
|
25 |
+
|
26 |
+
prompts = [prompt1, prompt2]
|
27 |
+
|
28 |
+
# 初始化噪声图
|
29 |
+
start_code = torch.randn([1, 4, 128, 128], device=device)
|
30 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
31 |
+
|
32 |
+
# 推理没有 MasaCtrl 的图像
|
33 |
+
editor = AttentionBase()
|
34 |
+
regiter_attention_editor_diffusers(model, editor)
|
35 |
+
image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
|
36 |
+
|
37 |
+
images = []
|
38 |
+
# 劫持注意力模块
|
39 |
+
editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
|
40 |
+
regiter_attention_editor_diffusers(model, editor)
|
41 |
+
|
42 |
+
# 推理带 MasaCtrl 的图像
|
43 |
+
image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
|
44 |
+
|
45 |
+
sample_count = len(os.listdir(out_dir_ori))
|
46 |
+
out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
|
47 |
+
os.makedirs(out_dir, exist_ok=True)
|
48 |
+
image_ori[0].save(os.path.join(out_dir, f"source_step{starting_step}_layer{starting_layer}.png"))
|
49 |
+
image_ori[1].save(os.path.join(out_dir, f"without_step{starting_step}_layer{starting_layer}.png"))
|
50 |
+
image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{starting_step}_layer{starting_layer}.png"))
|
51 |
+
with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
|
52 |
+
for p in prompts:
|
53 |
+
f.write(p + "\n")
|
54 |
+
f.write(f"seed: {seed}\n")
|
55 |
+
f.write(f"starting_step: {starting_step}\n")
|
56 |
+
f.write(f"starting_layer: {starting_layer}\n")
|
57 |
+
print("Synthesized images are saved in", out_dir)
|
58 |
+
|
59 |
+
return [image_ori[0], image_ori[1], image_masactrl[-1]]
|
60 |
+
|
61 |
+
def create_demo_synthesis():
|
62 |
+
with gr.Blocks() as demo:
|
63 |
+
gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**") # 添加标题
|
64 |
+
gr.Markdown("## **Input Settings**")
|
65 |
+
with gr.Row():
|
66 |
+
with gr.Column():
|
67 |
+
prompt1 = gr.Textbox(label="Prompt 1", value="solo,ZHONGLI(genshin impact),1boy,highres,")
|
68 |
+
prompt2 = gr.Textbox(label="Prompt 2", value="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,")
|
69 |
+
with gr.Row():
|
70 |
+
starting_step = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
|
71 |
+
starting_layer = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
|
72 |
+
run_btn = gr.Button("Run")
|
73 |
+
with gr.Column():
|
74 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
|
75 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
|
76 |
+
|
77 |
+
gr.Markdown("## **Output**")
|
78 |
+
with gr.Row():
|
79 |
+
image_source = gr.Image(label="Source Image")
|
80 |
+
image_without_masactrl = gr.Image(label="Image without MasaCtrl")
|
81 |
+
image_with_masactrl = gr.Image(label="Image with MasaCtrl")
|
82 |
+
|
83 |
+
inputs = [prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer]
|
84 |
+
run_btn.click(consistent_synthesis, inputs, [image_source, image_without_masactrl, image_with_masactrl])
|
85 |
+
|
86 |
+
gr.Examples(
|
87 |
+
[
|
88 |
+
["solo,ZHONGLI(genshin impact),1boy,highres,", "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,", 42, 4, 64],
|
89 |
+
["solo,KAMISATO AYATO(genshin impact),1boy,highres,", "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,", 42, 4, 55]
|
90 |
+
],
|
91 |
+
[prompt1, prompt2, seed, starting_step, starting_layer],
|
92 |
+
)
|
93 |
+
return demo
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
demo_synthesis = create_demo_synthesis()
|
97 |
+
demo_synthesis.launch(share = True)
|
run_synthesis_sdxl.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
|
12 |
+
from diffusers import DDIMScheduler, DiffusionPipeline
|
13 |
+
|
14 |
+
from masactrl.diffuser_utils import MasaCtrlPipeline
|
15 |
+
from masactrl.masactrl_utils import AttentionBase
|
16 |
+
from masactrl.masactrl_utils import regiter_attention_editor_diffusers
|
17 |
+
from masactrl.masactrl import MutualSelfAttentionControl
|
18 |
+
|
19 |
+
from torchvision.utils import save_image
|
20 |
+
from torchvision.io import read_image
|
21 |
+
from pytorch_lightning import seed_everything
|
22 |
+
|
23 |
+
torch.cuda.set_device(0) # set the GPU device
|
24 |
+
|
25 |
+
# Note that you may add your Hugging Face token to get access to the models
|
26 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
27 |
+
|
28 |
+
model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
29 |
+
# model_path = "Linaqruf/animagine-xl"
|
30 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
31 |
+
model = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
|
32 |
+
|
33 |
+
|
34 |
+
def consistent_synthesis():
|
35 |
+
seed = 42
|
36 |
+
seed_everything(seed)
|
37 |
+
|
38 |
+
out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
|
39 |
+
os.makedirs(out_dir_ori, exist_ok=True)
|
40 |
+
|
41 |
+
prompts = [
|
42 |
+
"A portrait of an old man, facing camera, best quality",
|
43 |
+
"A portrait of an old man, facing camera, smiling, best quality",
|
44 |
+
]
|
45 |
+
|
46 |
+
# inference the synthesized image with MasaCtrl
|
47 |
+
# TODO: note that the hyper paramerter of MasaCtrl for SDXL may be not optimal
|
48 |
+
STEP = 4
|
49 |
+
LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
|
50 |
+
|
51 |
+
# initialize the noise map
|
52 |
+
start_code = torch.randn([1, 4, 128, 128], device=device)
|
53 |
+
# start_code = None
|
54 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
55 |
+
|
56 |
+
# inference the synthesized image without MasaCtrl
|
57 |
+
editor = AttentionBase()
|
58 |
+
regiter_attention_editor_diffusers(model, editor)
|
59 |
+
image_ori = model(prompts, latents=start_code, guidance_scale=7.5).images
|
60 |
+
|
61 |
+
for LAYER in LAYER_LIST:
|
62 |
+
# hijack the attention module
|
63 |
+
editor = MutualSelfAttentionControl(STEP, LAYER, model_type="SDXL")
|
64 |
+
regiter_attention_editor_diffusers(model, editor)
|
65 |
+
|
66 |
+
# inference the synthesized image
|
67 |
+
image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5).images
|
68 |
+
|
69 |
+
sample_count = len(os.listdir(out_dir_ori))
|
70 |
+
out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
|
71 |
+
os.makedirs(out_dir, exist_ok=True)
|
72 |
+
image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
|
73 |
+
image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
|
74 |
+
image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
|
75 |
+
with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
|
76 |
+
for p in prompts:
|
77 |
+
f.write(p + "\n")
|
78 |
+
f.write(f"seed: {seed}\n")
|
79 |
+
print("Syntheiszed images are saved in", out_dir)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
consistent_synthesis()
|
run_synthesis_sdxl_processor.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from diffusers import DDIMScheduler, StableDiffusionPipeline, DiffusionPipeline
|
12 |
+
from torchvision.utils import save_image
|
13 |
+
from torchvision.io import read_image
|
14 |
+
from pytorch_lightning import seed_everything
|
15 |
+
|
16 |
+
from masactrl.masactrl_processor import register_attention_processor
|
17 |
+
|
18 |
+
torch.cuda.set_device(0) # set the GPU device
|
19 |
+
|
20 |
+
# Note that you may add your Hugging Face token to get access to the models
|
21 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
22 |
+
weight_dtype = torch.float16
|
23 |
+
model_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
24 |
+
scheduler = DDIMScheduler(
|
25 |
+
beta_start=0.00085,
|
26 |
+
beta_end=0.012,
|
27 |
+
beta_schedule="scaled_linear",
|
28 |
+
clip_sample=False,
|
29 |
+
set_alpha_to_one=False
|
30 |
+
)
|
31 |
+
pipe = DiffusionPipeline.from_pretrained(
|
32 |
+
model_path,
|
33 |
+
scheduler=scheduler,
|
34 |
+
torch_dtype=weight_dtype
|
35 |
+
).to(device)
|
36 |
+
|
37 |
+
|
38 |
+
def consistent_synthesis():
|
39 |
+
seed = 42
|
40 |
+
seed_everything(seed)
|
41 |
+
|
42 |
+
out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
|
43 |
+
os.makedirs(out_dir_ori, exist_ok=True)
|
44 |
+
|
45 |
+
prompts = [
|
46 |
+
"A portrait of an old man, facing camera, best quality",
|
47 |
+
"A portrait of an old man, facing camera, smiling, best quality",
|
48 |
+
]
|
49 |
+
|
50 |
+
# inference the synthesized image with MasaCtrl
|
51 |
+
# TODO: note that the hyper paramerter of MasaCtrl for SDXL may be not optimal
|
52 |
+
STEP = 4
|
53 |
+
LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
|
54 |
+
MODEL_TYPE = "SDXL"
|
55 |
+
|
56 |
+
# initialize the noise map
|
57 |
+
start_code = torch.randn([1, 4, 128, 128], dtype=weight_dtype, device=device)
|
58 |
+
# start_code = None
|
59 |
+
start_code = start_code.expand(len(prompts), -1, -1, -1)
|
60 |
+
|
61 |
+
# inference the synthesized image without MasaCtrl
|
62 |
+
image_ori = pipe(prompts, latents=start_code, guidance_scale=7.5).images
|
63 |
+
|
64 |
+
for LAYER in LAYER_LIST:
|
65 |
+
# hijack the attention module with MasaCtrl processor
|
66 |
+
processor_args = {
|
67 |
+
"start_step": STEP,
|
68 |
+
"start_layer": LAYER,
|
69 |
+
"model_type": MODEL_TYPE
|
70 |
+
}
|
71 |
+
register_attention_processor(pipe.unet, processor_type="MasaCtrlProcessor")
|
72 |
+
|
73 |
+
# inference the synthesized image
|
74 |
+
image_masactrl = pipe(prompts, latents=start_code, guidance_scale=7.5).images
|
75 |
+
|
76 |
+
sample_count = len(os.listdir(out_dir_ori))
|
77 |
+
out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
|
78 |
+
os.makedirs(out_dir, exist_ok=True)
|
79 |
+
image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
|
80 |
+
image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
|
81 |
+
image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
|
82 |
+
with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
|
83 |
+
for p in prompts:
|
84 |
+
f.write(p + "\n")
|
85 |
+
f.write(f"seed: {seed}\n")
|
86 |
+
print("Syntheiszed images are saved in", out_dir)
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
consistent_synthesis()
|
style.css
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
h1 {
|
2 |
+
text-align: center;
|
3 |
+
}
|