svjack committed (verified)
Commit f070657 · Parent(s): 968b4fb

Upload 23 files
LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
README.md CHANGED
@@ -1,12 +1,305 @@
  ---
- title: Genshin Impact XL MasaCtrl
- emoji: 🐠
- colorFrom: blue
- colorTo: pink
- sdk: gradio
- sdk_version: 5.6.0
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
# MotionCtrl and MasaCtrl: Genshin Impact Character Synthesis

This repository provides a guide to setting up and running the MotionCtrl and MasaCtrl projects for synthesizing Genshin Impact characters with machine learning models. The process involves cloning the repositories, installing dependencies, and running scripts to generate and visualize character images.

## Table of Contents
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Running the Synthesis](#running-the-synthesis)
- [Gradio Interface](#gradio-interface)
- [Example Prompts](#example-prompts)

## Prerequisites

Before you begin, ensure you have the following installed:
- Python 3.10
- Conda (for environment management)
- Git
- Git LFS (Large File Storage)
- FFmpeg

## Installation

### Step 1: Clone the MotionCtrl Repository
Clone the MotionCtrl repository and install its dependencies:

```bash
git clone https://huggingface.co/spaces/svjack/MotionCtrl
cd MotionCtrl
pip install -r requirements.txt
```

### Step 2: Install System Dependencies
Update your package list and install the necessary system packages:

```bash
sudo apt-get update
sudo apt-get install cbm git-lfs ffmpeg
```

### Step 3: Set Up Python Environment
Create a Conda environment with Python 3.10, activate it, and install the IPython kernel:

```bash
conda create -n py310 python=3.10
conda activate py310
pip install ipykernel
python -m ipykernel install --user --name py310 --display-name "py310"
```

### Step 4: Clone the MasaCtrl Repository
Clone the MasaCtrl repository and install its dependencies:

```bash
git clone https://github.com/svjack/MasaCtrl
cd MasaCtrl
pip install -r requirements.txt
```

## Running the Synthesis

### Command Line Interface
Run the synthesis script to generate images of Genshin Impact characters:

```bash
python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
    --prompt1 "solo,ZHONGLI\(genshin impact\),1boy,highres," \
    --prompt2 "solo,ZHONGLI drink tea use chinese cup \(genshin impact\),1boy,highres," --guidance_scale 5
```

### Gradio Interface
Alternatively, you can use the Gradio interface for a more interactive experience:

```bash
python run_synthesis_genshin_impact_xl_app.py
```

## Example Prompts

Here are some example prompt pairs you can use to generate different character edits. (The image produced with MasaCtrl stays closer to the source image, for instance in the background and other aspects.)

- **Zhongli Drinking Tea:**
```
"solo,ZHONGLI(genshin impact),1boy,highres," -> "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,"
```
![Screenshot 2024-11-17 132742](https://github.com/user-attachments/assets/00451728-f2d5-4009-afa8-23baaabdc223)

- **Kamisato Ayato Smiling:**
```
"solo,KAMISATO AYATO(genshin impact),1boy,highres," -> "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,"
```

![Screenshot 2024-11-17 133421](https://github.com/user-attachments/assets/7a920f4c-8a3a-4387-98d6-381a798566ef)

## MasaCtrl: Tuning-free <span style="text-decoration: underline"><font color="Tomato">M</font></span>utu<span style="text-decoration: underline"><font color="Tomato">a</font></span>l <span style="text-decoration: underline"><font color="Tomato">S</font></span>elf-<span style="text-decoration: underline"><font color="Tomato">A</font></span>ttention <span style="text-decoration: underline"><font color="Tomato">Control</font></span> for Consistent Image Synthesis and Editing

Pytorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control for **Consistent Image Synthesis and Editing**](https://arxiv.org/abs/2304.08465)

[Mingdeng Cao](https://github.com/ljzycmd),
[Xintao Wang](https://xinntao.github.io/),
[Zhongang Qi](https://scholar.google.com/citations?user=zJvrrusAAAAJ),
[Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ),
[Xiaohu Qie](https://scholar.google.com/citations?user=mk-F69UAAAAJ),
[Yinqiang Zheng](https://scholar.google.com/citations?user=JD-5DKcAAAAJ)

[![arXiv](https://img.shields.io/badge/ArXiv-2304.08465-brightgreen)](https://arxiv.org/abs/2304.08465)
[![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://ljzycmd.github.io/projects/MasaCtrl/)
[![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl)
[![demo](https://img.shields.io/badge/Demo-Colab-brightgreen)](https://colab.research.google.com/drive/1DZeQn2WvRBsNg4feS1bJrwWnIzw1zLJq?usp=sharing)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/MingDengCao/MasaCtrl)

---

<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/overview.gif">
<i> MasaCtrl enables performing various consistent non-rigid image synthesis and editing without fine-tuning and optimization. </i>
</div>


## Updates
- [2024/8/17] We add an AttnProcessor-based MasaCtrlProcessor; please check `masactrl/masactrl_processor.py` and `run_synthesis_sdxl_processor.py`. You can integrate MasaCtrl into an official Diffusers pipeline by registering the attention processor (see the sketch below).
- [2023/8/20] MasaCtrl supports SDXL (and other variants) now. ![sdxl_example](https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/sdxl_example.jpg)
- [2023/5/13] The inference code of MasaCtrl with T2I-Adapter is available.
- [2023/4/28] [Hugging Face demo](https://huggingface.co/spaces/TencentARC/MasaCtrl) released.
- [2023/4/25] Code released.
- [2023/4/17] Paper is available [here](https://arxiv.org/abs/2304.08465).

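The 2024/8/17 update above mentions registering MasaCtrl as an attention processor. The following is a rough, hypothetical sketch of that integration: the processor class name, its constructor arguments, and the `start_step`/`start_layer` values are assumptions for illustration (the actual definitions live in `masactrl/masactrl_processor.py` and `run_synthesis_sdxl_processor.py`), while `set_attn_processor` is the standard Diffusers hook for swapping attention processors.

```python
# Hypothetical sketch: wiring an AttnProcessor-based MasaCtrl into a Diffusers SDXL pipeline.
# The MasaCtrlProcessor name and signature are assumed from the update note above.
import torch
from diffusers import StableDiffusionXLPipeline

from masactrl.masactrl_processor import MasaCtrlProcessor  # assumed module/class name

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Replace the U-Net attention processors with the MasaCtrl variant.
pipe.unet.set_attn_processor(MasaCtrlProcessor(start_step=4, start_layer=64))

# Source and target prompts share the same starting latents, as in app.py below.
prompts = [
    "solo,ZHONGLI(genshin impact),1boy,highres,",
    "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,",
]
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.float16).expand(2, -1, -1, -1)
images = pipe(prompts, latents=latents, guidance_scale=5.0).images
```
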
---

## Introduction

We propose MasaCtrl, a tuning-free method for non-rigid, consistent image synthesis and editing. The key idea is to combine the `contents` of the *source image* with the `layout` synthesized from the *text prompt and additional controls* in the desired synthesized or edited image, by querying semantically correlated features with **Mutual Self-Attention Control**.

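In practice, mutual self-attention control means that, at selected denoising steps and in selected self-attention layers, the target branch keeps its own queries but attends to the keys and values computed for the source branch. Below is a minimal, self-contained sketch of that core operation; it is independent of the repository's `MutualSelfAttentionControl` class, and all tensor names are invented here for illustration.

```python
import torch

def mutual_self_attention(q_tgt, k_src, v_src, k_tgt, v_tgt, use_source_kv: bool, scale: float):
    """One self-attention call for the target branch.

    When use_source_kv is True (i.e. the current step/layer is under MasaCtrl),
    the target queries attend to the source keys/values, so the target image
    inherits the source content while keeping its own layout from the queries.
    All tensors have shape (heads, tokens, dim).
    """
    k = k_src if use_source_kv else k_tgt
    v = v_src if use_source_kv else v_tgt
    attn = (q_tgt @ k.transpose(-1, -2) * scale).softmax(dim=-1)
    return attn @ v

# Toy usage with random features standing in for U-Net self-attention states.
heads, tokens, dim = 8, 64, 40
q_tgt, k_tgt, v_tgt = (torch.randn(heads, tokens, dim) for _ in range(3))
k_src, v_src = (torch.randn(heads, tokens, dim) for _ in range(2))
out = mutual_self_attention(q_tgt, k_src, v_src, k_tgt, v_tgt, use_source_kv=True, scale=dim ** -0.5)
print(out.shape)  # torch.Size([8, 64, 40])
```
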
## Main Features

### 1 Consistent Image Synthesis and Editing

MasaCtrl can perform prompt-based image synthesis and editing that changes the layout while maintaining the contents of the source image.

>*The target layout is synthesized directly from the target prompt.*

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_synthetic.png">
<i>Consistent synthesis results</i>

<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_real.png">
<i>Real image editing results</i>
</div>
</details>



### 2 Integration to Controllable Diffusion Models

Directly modifying the text prompt often cannot produce the target layout of the desired image, so we further integrate our method into existing controllable diffusion pipelines (such as T2I-Adapter and ControlNet) to obtain stable synthesis and editing results.

>*The target layout is controlled by additional guidance.*

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter.png">
<i>Synthesis (left part) and editing (right part) results with T2I-Adapter</i>
</div>
</details>

### 3 Generalization to Other Models: Anything-V4

Our method also generalizes well to other Stable-Diffusion-based models.

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/anythingv4_synthetic.png">
<i>Results on Anything-V4</i>
</div>
</details>


### 4 Extension to Video Synthesis

With dense consistent guidance, MasaCtrl enables video synthesis.

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter_consistent.png">
<i>Video synthesis results (with keypose and canny guidance)</i>
</div>
</details>


## Usage

### Requirements
We implement our method on the [diffusers](https://github.com/huggingface/diffusers) code base, with a code structure similar to [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt). The code runs on Python 3.8.5 with PyTorch 1.11. A Conda environment is highly recommended.

```bash
pip install -r requirements.txt
```

### Checkpoints

**Stable Diffusion:**
We mainly conduct experiments on Stable Diffusion v1-4, while our method can generalize to other versions (like v1-5). You can download these checkpoints from their official repositories and [Hugging Face](https://huggingface.co/).

**Personalized Models:**
You can download personalized models from [CIVITAI](https://civitai.com/) or train your own customized models.

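As a minimal sketch of how a downloaded checkpoint is typically loaded in this repository's style (it mirrors `gradio_app/app_utils.py` further down this commit; the model id used here is only an example and can be any Stable-Diffusion-based checkpoint in the diffusers format):

```python
import torch
from diffusers import DDIMScheduler
from masactrl.diffuser_utils import MasaCtrlPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example model id; swap in another Stable-Diffusion-based checkpoint
# (e.g. a personalized model converted to the diffusers format) if desired.
model_path = "CompVis/stable-diffusion-v1-4"

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
```
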
### Demos

**Notebook demos**

To run the synthesis with MasaCtrl, a single GPU with at least 16 GB of VRAM is required.

The notebooks `playground.ipynb` and `playground_real.ipynb` provide the synthesis and real-image editing samples, respectively.

**Online demos**

We provide an online demo [![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl) built with Gradio. Note that you may copy the demo into your own Space to use the GPU. An online Colab demo [![demo](https://img.shields.io/badge/Demo-Colab-brightgreen)](https://colab.research.google.com/drive/1DZeQn2WvRBsNg4feS1bJrwWnIzw1zLJq?usp=sharing) is also available.

**Local Gradio demo**

You can launch the provided Gradio demo locally with

```bash
CUDA_VISIBLE_DEVICES=0 python app.py
```


### MasaCtrl with T2I-Adapter

Install [T2I-Adapter](https://github.com/TencentARC/T2I-Adapter) and prepare the checkpoints following their tutorial. Assume it has been successfully installed and that its root directory is `T2I-Adapter`.

Then copy the core `masactrl` package and the inference code `masactrl_w_adapter.py` to the root directory of T2I-Adapter:

```bash
cp -r MasaCtrl/masactrl T2I-Adapter/
cp MasaCtrl/masactrl_w_adapter/masactrl_w_adapter.py T2I-Adapter/
```

**[Updates]** Alternatively, you can clone the repo [MasaCtrl-w-T2I-Adapter](https://github.com/ljzycmd/T2I-Adapter-w-MasaCtrl) directly to your local machine.

Finally, you can run inference with the following command (with the sketch adapter):

```bash
python masactrl_w_adapter.py \
--which_cond sketch \
--cond_path_src SOURCE_CONDITION_PATH \
--cond_path CONDITION_PATH \
--cond_inp_type sketch \
--prompt_src "A bear walking in the forest" \
--prompt "A bear standing in the forest" \
--sd_ckpt models/sd-v1-4.ckpt \
--resize_short_edge 512 \
--cond_tau 1.0 \
--cond_weight 1.0 \
--n_samples 1 \
--adapter_ckpt models/t2iadapter_sketch_sd14v1.pth
```

NOTE: You can download the sketch examples [here](https://huggingface.co/TencentARC/MasaCtrl/tree/main/sketch_example).

For real images, DDIM inversion is performed to invert the image into a noise map, so we add the inversion process to the original DDIM sampler. **You should replace the original file `T2I-Adapter/ldm/models/diffusion/ddim.py` with the extended version `MasaCtrl/masactrl_w_adapter/ddim.py` to enable the inversion function.** Then you can edit a real image with the following command (with the sketch adapter):

```bash
python masactrl_w_adapter.py \
--src_img_path SOURCE_IMAGE_PATH \
--cond_path CONDITION_PATH \
--cond_inp_type image \
--prompt_src "" \
--prompt "a photo of a man wearing black t-shirt, giving a thumbs up" \
--sd_ckpt models/sd-v1-4.ckpt \
--resize_short_edge 512 \
--cond_tau 1.0 \
--cond_weight 1.0 \
--n_samples 1 \
--which_cond sketch \
--adapter_ckpt models/t2iadapter_sketch_sd14v1.pth \
--outdir ./workdir/masactrl_w_adapter_inversion/black-shirt
```

NOTE: You can download the real image editing example [here](https://huggingface.co/TencentARC/MasaCtrl/tree/main/black_shirt_example).

## Acknowledgements

We thank the awesome research works [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt) and [T2I-Adapter](https://github.com/TencentARC/T2I-Adapter).


## Citation

```bibtex
@InProceedings{cao_2023_masactrl,
    author    = {Cao, Mingdeng and Wang, Xintao and Qi, Zhongang and Shan, Ying and Qie, Xiaohu and Zheng, Yinqiang},
    title     = {MasaCtrl: Tuning-Free Mutual Self-Attention Control for Consistent Image Synthesis and Editing},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {22560-22570}
}
```


## Contact

If you have any comments or questions, please [open a new issue](https://github.com/TencentARC/MasaCtrl/issues/new/choose) or feel free to contact [Mingdeng Cao](https://github.com/ljzycmd) and [Xintao Wang](https://xinntao.github.io/).
app.py ADDED
@@ -0,0 +1,97 @@
import gradio as gr
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
from masactrl.masactrl import MutualSelfAttentionControl
from pytorch_lightning import seed_everything
import os
import re

# Initialize the device and the model
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = DiffusionPipeline.from_pretrained("svjack/GenshinImpact_XL_Base", scheduler=scheduler).to(device)

def pathify(s):
    return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())

def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
    seed_everything(seed)

    # Create the output directory
    out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2))
    os.makedirs(out_dir_ori, exist_ok=True)

    prompts = [prompt1, prompt2]

    # Initialize the noise map (the same starting latents are shared by both prompts)
    start_code = torch.randn([1, 4, 128, 128], device=device)
    start_code = start_code.expand(len(prompts), -1, -1, -1)

    # Inference without MasaCtrl
    editor = AttentionBase()
    regiter_attention_editor_diffusers(model, editor)
    image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    images = []
    # Hijack the attention modules
    editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
    regiter_attention_editor_diffusers(model, editor)

    # Inference with MasaCtrl
    image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    sample_count = len(os.listdir(out_dir_ori))
    out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
    os.makedirs(out_dir, exist_ok=True)
    image_ori[0].save(os.path.join(out_dir, f"source_step{starting_step}_layer{starting_layer}.png"))
    image_ori[1].save(os.path.join(out_dir, f"without_step{starting_step}_layer{starting_layer}.png"))
    image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{starting_step}_layer{starting_layer}.png"))
    with open(os.path.join(out_dir, "prompts.txt"), "w") as f:
        for p in prompts:
            f.write(p + "\n")
        f.write(f"seed: {seed}\n")
        f.write(f"starting_step: {starting_step}\n")
        f.write(f"starting_layer: {starting_layer}\n")
    print("Synthesized images are saved in", out_dir)

    return [image_ori[0], image_ori[1], image_masactrl[-1]]

def create_demo_synthesis():
    with gr.Blocks() as demo:
        gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**")  # Page title
        gr.Markdown("## **Input Settings**")
        with gr.Row():
            with gr.Column():
                prompt1 = gr.Textbox(label="Prompt 1", value="solo,ZHONGLI(genshin impact),1boy,highres,")
                prompt2 = gr.Textbox(label="Prompt 2", value="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,")
                with gr.Row():
                    starting_step = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
                    starting_layer = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
                run_btn = gr.Button("Run")
            with gr.Column():
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)

        gr.Markdown("## **Output**")
        with gr.Row():
            image_source = gr.Image(label="Source Image")
            image_without_masactrl = gr.Image(label="Image without MasaCtrl")
            image_with_masactrl = gr.Image(label="Image with MasaCtrl")

        inputs = [prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer]
        run_btn.click(consistent_synthesis, inputs, [image_source, image_without_masactrl, image_with_masactrl])

        gr.Examples(
            [
                ["solo,ZHONGLI(genshin impact),1boy,highres,", "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,", 42, 4, 64],
                ["solo,KAMISATO AYATO(genshin impact),1boy,highres,", "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,", 42, 4, 55]
            ],
            [prompt1, prompt2, seed, starting_step, starting_layer],
        )
    return demo

if __name__ == "__main__":
    demo_synthesis = create_demo_synthesis()
    demo_synthesis.launch(share=True)
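For reference, the synthesis function in `app.py` can also be called without launching the Gradio UI. A minimal sketch, assuming the repository root is on the Python path so that `app.py` is importable (the output file name below is invented for illustration):

```python
# Minimal programmatic use of the app's synthesis function (no Gradio UI).
from app import consistent_synthesis

source_img, plain_img, masactrl_img = consistent_synthesis(
    prompt1="solo,ZHONGLI(genshin impact),1boy,highres,",
    prompt2="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,",
    guidance_scale=5.0,
    seed=42,
    starting_step=4,
    starting_layer=64,
)
masactrl_img.save("zhongli_masactrl.png")  # hypothetical output path
```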
gradio_app/app_utils.py ADDED
@@ -0,0 +1,30 @@
import gradio as gr
import numpy as np
import torch
from diffusers import DDIMScheduler
from pytorch_lightning import seed_everything

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import (AttentionBase,
                                     regiter_attention_editor_diffusers)


torch.set_grad_enabled(False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_path = "xyn-ai/anything-v4.0"
scheduler = DDIMScheduler(beta_start=0.00085,
                          beta_end=0.012,
                          beta_schedule="scaled_linear",
                          clip_sample=False,
                          set_alpha_to_one=False)
model = MasaCtrlPipeline.from_pretrained(model_path,
                                         scheduler=scheduler).to(device)

global_context = {
    "model_path": model_path,
    "scheduler": scheduler,
    "model": model,
    "device": device
}
gradio_app/image_synthesis_app.py ADDED
@@ -0,0 +1,166 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from pytorch_lightning import seed_everything
6
+
7
+ from masactrl.diffuser_utils import MasaCtrlPipeline
8
+ from masactrl.masactrl_utils import (AttentionBase,
9
+ regiter_attention_editor_diffusers)
10
+
11
+ from .app_utils import global_context
12
+
13
+ torch.set_grad_enabled(False)
14
+
15
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
16
+ # "cpu")
17
+ # model_path = "andite/anything-v4.0"
18
+ # scheduler = DDIMScheduler(beta_start=0.00085,
19
+ # beta_end=0.012,
20
+ # beta_schedule="scaled_linear",
21
+ # clip_sample=False,
22
+ # set_alpha_to_one=False)
23
+ # model = MasaCtrlPipeline.from_pretrained(model_path,
24
+ # scheduler=scheduler).to(device)
25
+
26
+
27
+ def consistent_synthesis(source_prompt, target_prompt, starting_step,
28
+ starting_layer, image_resolution, ddim_steps, scale,
29
+ seed, appended_prompt, negative_prompt):
30
+ from masactrl.masactrl import MutualSelfAttentionControl
31
+
32
+ model = global_context["model"]
33
+ device = global_context["device"]
34
+
35
+ seed_everything(seed)
36
+
37
+ with torch.no_grad():
38
+ if appended_prompt is not None:
39
+ source_prompt += appended_prompt
40
+ target_prompt += appended_prompt
41
+ prompts = [source_prompt, target_prompt]
42
+
43
+ # initialize the noise map
44
+ start_code = torch.randn([1, 4, 64, 64], device=device)
45
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
46
+
47
+ # inference the synthesized image without MasaCtrl
48
+ editor = AttentionBase()
49
+ regiter_attention_editor_diffusers(model, editor)
50
+ target_image_ori = model([target_prompt],
51
+ latents=start_code[-1:],
52
+ guidance_scale=7.5)
53
+ target_image_ori = target_image_ori.cpu().permute(0, 2, 3, 1).numpy()
54
+
55
+ # inference the synthesized image with MasaCtrl
56
+ # hijack the attention module
57
+ controller = MutualSelfAttentionControl(starting_step, starting_layer)
58
+ regiter_attention_editor_diffusers(model, controller)
59
+
60
+ # inference the synthesized image
61
+ image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)
62
+ image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
63
+
64
+ return [image_masactrl[0], target_image_ori[0],
65
+ image_masactrl[1]] # source, fixed seed, masactrl
66
+
67
+
68
+ def create_demo_synthesis():
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("## **Input Settings**")
71
+ with gr.Row():
72
+ with gr.Column():
73
+ source_prompt = gr.Textbox(
74
+ label="Source Prompt",
75
+ value='1boy, casual, outdoors, sitting',
76
+ interactive=True)
77
+ target_prompt = gr.Textbox(
78
+ label="Target Prompt",
79
+ value='1boy, casual, outdoors, standing',
80
+ interactive=True)
81
+ with gr.Row():
82
+ ddim_steps = gr.Slider(label="DDIM Steps",
83
+ minimum=1,
84
+ maximum=999,
85
+ value=50,
86
+ step=1)
87
+ starting_step = gr.Slider(
88
+ label="Step of MasaCtrl",
89
+ minimum=0,
90
+ maximum=999,
91
+ value=4,
92
+ step=1)
93
+ starting_layer = gr.Slider(label="Layer of MasaCtrl",
94
+ minimum=0,
95
+ maximum=16,
96
+ value=10,
97
+ step=1)
98
+ run_btn = gr.Button(label="Run")
99
+ with gr.Column():
100
+ appended_prompt = gr.Textbox(label="Appended Prompt", value='')
101
+ negative_prompt = gr.Textbox(label="Negative Prompt", value='')
102
+ with gr.Row():
103
+ image_resolution = gr.Slider(label="Image Resolution",
104
+ minimum=256,
105
+ maximum=768,
106
+ value=512,
107
+ step=64)
108
+ scale = gr.Slider(label="CFG Scale",
109
+ minimum=0.1,
110
+ maximum=30.0,
111
+ value=7.5,
112
+ step=0.1)
113
+ seed = gr.Slider(label="Seed",
114
+ minimum=-1,
115
+ maximum=2147483647,
116
+ value=42,
117
+ step=1)
118
+
119
+ gr.Markdown("## **Output**")
120
+ with gr.Row():
121
+ image_source = gr.Image(label="Source Image")
122
+ image_fixed = gr.Image(label="Image with Fixed Seed")
123
+ image_masactrl = gr.Image(label="Image with MasaCtrl")
124
+
125
+ inputs = [
126
+ source_prompt, target_prompt, starting_step, starting_layer,
127
+ image_resolution, ddim_steps, scale, seed, appended_prompt,
128
+ negative_prompt
129
+ ]
130
+ run_btn.click(consistent_synthesis, inputs,
131
+ [image_source, image_fixed, image_masactrl])
132
+
133
+ gr.Examples(
134
+ [[
135
+ "1boy, bishounen, casual, indoors, sitting, coffee shop, bokeh",
136
+ "1boy, bishounen, casual, indoors, standing, coffee shop, bokeh",
137
+ 42
138
+ ],
139
+ [
140
+ "1boy, casual, outdoors, sitting",
141
+ "1boy, casual, outdoors, sitting, side view", 42
142
+ ],
143
+ [
144
+ "1boy, casual, outdoors, sitting",
145
+ "1boy, casual, outdoors, standing, clapping hands", 42
146
+ ],
147
+ [
148
+ "1boy, casual, outdoors, sitting",
149
+ "1boy, casual, outdoors, sitting, shows thumbs up", 42
150
+ ],
151
+ [
152
+ "1boy, casual, outdoors, sitting",
153
+ "1boy, casual, outdoors, sitting, with crossed arms", 42
154
+ ],
155
+ [
156
+ "1boy, casual, outdoors, sitting",
157
+ "1boy, casual, outdoors, sitting, rasing hands", 42
158
+ ]],
159
+ [source_prompt, target_prompt, seed],
160
+ )
161
+ return demo
162
+
163
+
164
+ if __name__ == "__main__":
165
+ demo_synthesis = create_demo_synthesis()
166
+ demo_synthesis.launch()
gradio_app/images/corgi.jpg ADDED
gradio_app/images/person.png ADDED
gradio_app/real_image_editing_app.py ADDED
@@ -0,0 +1,162 @@
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from diffusers import DDIMScheduler
7
+ from torchvision.io import read_image
8
+ from pytorch_lightning import seed_everything
9
+
10
+ from masactrl.diffuser_utils import MasaCtrlPipeline
11
+ from masactrl.masactrl_utils import (AttentionBase,
12
+ regiter_attention_editor_diffusers)
13
+
14
+ from .app_utils import global_context
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
19
+ # "cpu")
20
+
21
+ # model_path = "CompVis/stable-diffusion-v1-4"
22
+ # scheduler = DDIMScheduler(beta_start=0.00085,
23
+ # beta_end=0.012,
24
+ # beta_schedule="scaled_linear",
25
+ # clip_sample=False,
26
+ # set_alpha_to_one=False)
27
+ # model = MasaCtrlPipeline.from_pretrained(model_path,
28
+ # scheduler=scheduler).to(device)
29
+
30
+
31
+ def load_image(image_path):
32
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
33
+ image = read_image(image_path)
34
+ image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]
35
+ image = F.interpolate(image, (512, 512))
36
+ image = image.to(device)
+ return image
37
+
38
+
39
+ def real_image_editing(source_image, target_prompt,
40
+ starting_step, starting_layer, ddim_steps, scale, seed,
41
+ appended_prompt, negative_prompt):
42
+ from masactrl.masactrl import MutualSelfAttentionControl
43
+
44
+ model = global_context["model"]
45
+ device = global_context["device"]
46
+
47
+ seed_everything(seed)
48
+
49
+ with torch.no_grad():
50
+ if appended_prompt is not None:
51
+ target_prompt += appended_prompt
52
+ ref_prompt = ""
53
+ prompts = [ref_prompt, target_prompt]
54
+
55
+ # invert the image into noise map
56
+ if isinstance(source_image, np.ndarray):
57
+ source_image = torch.from_numpy(source_image).to(device) / 127.5 - 1.
58
+ source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
59
+ source_image = F.interpolate(source_image, (512, 512))
60
+
61
+ start_code, latents_list = model.invert(source_image,
62
+ ref_prompt,
63
+ guidance_scale=scale,
64
+ num_inference_steps=ddim_steps,
65
+ return_intermediates=True)
66
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
67
+
68
+ # reconstruct the image with the inverted DDIM noise map
69
+ editor = AttentionBase()
70
+ regiter_attention_editor_diffusers(model, editor)
71
+ image_fixed = model([target_prompt],
72
+ latents=start_code[-1:],
73
+ num_inference_steps=ddim_steps,
74
+ guidance_scale=scale)
75
+ image_fixed = image_fixed.cpu().permute(0, 2, 3, 1).numpy()
76
+
77
+ # inference the synthesized image with MasaCtrl
78
+ # hijack the attention module
79
+ controller = MutualSelfAttentionControl(starting_step, starting_layer)
80
+ regiter_attention_editor_diffusers(model, controller)
81
+
82
+ # inference the synthesized image
83
+ image_masactrl = model(prompts,
84
+ latents=start_code,
85
+ guidance_scale=scale)
86
+ image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
87
+
88
+ return [
89
+ image_masactrl[0],
90
+ image_fixed[0],
91
+ image_masactrl[1]
92
+ ] # source, fixed seed, masactrl
93
+
94
+
95
+ def create_demo_editing():
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("## **Input Settings**")
98
+ with gr.Row():
99
+ with gr.Column():
100
+ source_image = gr.Image(label="Source Image", value=os.path.join(os.path.dirname(__file__), "images/corgi.jpg"), interactive=True)
101
+ target_prompt = gr.Textbox(label="Target Prompt",
102
+ value='A photo of a running corgi',
103
+ interactive=True)
104
+ with gr.Row():
105
+ ddim_steps = gr.Slider(label="DDIM Steps",
106
+ minimum=1,
107
+ maximum=999,
108
+ value=50,
109
+ step=1)
110
+ starting_step = gr.Slider(label="Step of MasaCtrl",
111
+ minimum=0,
112
+ maximum=999,
113
+ value=4,
114
+ step=1)
115
+ starting_layer = gr.Slider(label="Layer of MasaCtrl",
116
+ minimum=0,
117
+ maximum=16,
118
+ value=10,
119
+ step=1)
120
+ run_btn = gr.Button(label="Run")
121
+ with gr.Column():
122
+ appended_prompt = gr.Textbox(label="Appended Prompt", value='')
123
+ negative_prompt = gr.Textbox(label="Negative Prompt", value='')
124
+ with gr.Row():
125
+ scale = gr.Slider(label="CFG Scale",
126
+ minimum=0.1,
127
+ maximum=30.0,
128
+ value=7.5,
129
+ step=0.1)
130
+ seed = gr.Slider(label="Seed",
131
+ minimum=-1,
132
+ maximum=2147483647,
133
+ value=42,
134
+ step=1)
135
+
136
+ gr.Markdown("## **Output**")
137
+ with gr.Row():
138
+ image_recons = gr.Image(label="Source Image")
139
+ image_fixed = gr.Image(label="Image with Fixed Seed")
140
+ image_masactrl = gr.Image(label="Image with MasaCtrl")
141
+
142
+ inputs = [
143
+ source_image, target_prompt, starting_step, starting_layer, ddim_steps,
144
+ scale, seed, appended_prompt, negative_prompt
145
+ ]
146
+ run_btn.click(real_image_editing, inputs,
147
+ [image_recons, image_fixed, image_masactrl])
148
+
149
+ gr.Examples(
150
+ [[os.path.join(os.path.dirname(__file__), "images/corgi.jpg"),
151
+ "A photo of a running corgi"],
152
+ [os.path.join(os.path.dirname(__file__), "images/person.png"),
153
+ "A photo of a person, black t-shirt, raising hand"],
154
+ ],
155
+ [source_image, target_prompt]
156
+ )
157
+ return demo
158
+
159
+
160
+ if __name__ == "__main__":
161
+ demo_editing = create_demo_editing()
162
+ demo_editing.launch()
masactrl/__init__.py ADDED
File without changes
masactrl/diffuser_utils.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Util functions based on Diffuser framework.
3
+ """
4
+
5
+
6
+ import os
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from torchvision.utils import save_image
15
+ from torchvision.io import read_image
16
+
17
+ from diffusers import StableDiffusionPipeline
18
+
19
+ from pytorch_lightning import seed_everything
20
+
21
+
22
+ class MasaCtrlPipeline(StableDiffusionPipeline):
23
+
24
+ def next_step(
25
+ self,
26
+ model_output: torch.FloatTensor,
27
+ timestep: int,
28
+ x: torch.FloatTensor,
29
+ eta=0.,
30
+ verbose=False
31
+ ):
32
+ """
33
+ Inverse sampling for DDIM Inversion
34
+ """
35
+ if verbose:
36
+ print("timestep: ", timestep)
37
+ next_step = timestep
38
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
39
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
40
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
41
+ beta_prod_t = 1 - alpha_prod_t
42
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
43
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
44
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
45
+ return x_next, pred_x0
46
+
47
+ def step(
48
+ self,
49
+ model_output: torch.FloatTensor,
50
+ timestep: int,
51
+ x: torch.FloatTensor,
52
+ eta: float=0.0,
53
+ verbose=False,
54
+ ):
55
+ """
56
+ predict the sample at the next step in the denoising process.
57
+ """
58
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
59
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
60
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
61
+ beta_prod_t = 1 - alpha_prod_t
62
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
63
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
64
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
65
+ return x_prev, pred_x0
66
+
67
+ @torch.no_grad()
68
+ def image2latent(self, image):
69
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
70
+ if type(image) is Image:
71
+ image = np.array(image)
72
+ image = torch.from_numpy(image).float() / 127.5 - 1
73
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
74
+ # input image density range [-1, 1]
75
+ latents = self.vae.encode(image)['latent_dist'].mean
76
+ latents = latents * 0.18215
77
+ return latents
78
+
79
+ @torch.no_grad()
80
+ def latent2image(self, latents, return_type='np'):
81
+ latents = 1 / 0.18215 * latents.detach()
82
+ image = self.vae.decode(latents)['sample']
83
+ if return_type == 'np':
84
+ image = (image / 2 + 0.5).clamp(0, 1)
85
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
86
+ image = (image * 255).astype(np.uint8)
87
+ elif return_type == "pt":
88
+ image = (image / 2 + 0.5).clamp(0, 1)
89
+
90
+ return image
91
+
92
+ def latent2image_grad(self, latents):
93
+ latents = 1 / 0.18215 * latents
94
+ image = self.vae.decode(latents)['sample']
95
+
96
+ return image # range [-1, 1]
97
+
98
+ @torch.no_grad()
99
+ def __call__(
100
+ self,
101
+ prompt,
102
+ batch_size=1,
103
+ height=512,
104
+ width=512,
105
+ num_inference_steps=50,
106
+ guidance_scale=7.5,
107
+ eta=0.0,
108
+ latents=None,
109
+ unconditioning=None,
110
+ neg_prompt=None,
111
+ ref_intermediate_latents=None,
112
+ return_intermediates=False,
113
+ **kwds):
114
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
115
+ if isinstance(prompt, list):
116
+ batch_size = len(prompt)
117
+ elif isinstance(prompt, str):
118
+ if batch_size > 1:
119
+ prompt = [prompt] * batch_size
120
+
121
+ # text embeddings
122
+ text_input = self.tokenizer(
123
+ prompt,
124
+ padding="max_length",
125
+ max_length=77,
126
+ return_tensors="pt"
127
+ )
128
+
129
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
130
+ print("input text embeddings :", text_embeddings.shape)
131
+ if kwds.get("dir"):
132
+ dir = text_embeddings[-2] - text_embeddings[-1]
133
+ u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
134
+ text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
135
+ print(u.shape)
136
+ print(v.shape)
137
+
138
+ # define initial latents
139
+ latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
140
+ if latents is None:
141
+ latents = torch.randn(latents_shape, device=DEVICE)
142
+ else:
143
+ assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."
144
+
145
+ # unconditional embedding for classifier free guidance
146
+ if guidance_scale > 1.:
147
+ max_length = text_input.input_ids.shape[-1]
148
+ if neg_prompt:
149
+ uc_text = neg_prompt
150
+ else:
151
+ uc_text = ""
152
+ # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
153
+ unconditional_input = self.tokenizer(
154
+ [uc_text] * batch_size,
155
+ padding="max_length",
156
+ max_length=77,
157
+ return_tensors="pt"
158
+ )
159
+ # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
160
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
161
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
162
+
163
+ print("latents shape: ", latents.shape)
164
+ # iterative sampling
165
+ self.scheduler.set_timesteps(num_inference_steps)
166
+ # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
167
+ latents_list = [latents]
168
+ pred_x0_list = [latents]
169
+ for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
170
+ if ref_intermediate_latents is not None:
171
+ # note that the batch_size >= 2
172
+ latents_ref = ref_intermediate_latents[-1 - i]
173
+ _, latents_cur = latents.chunk(2)
174
+ latents = torch.cat([latents_ref, latents_cur])
175
+
176
+ if guidance_scale > 1.:
177
+ model_inputs = torch.cat([latents] * 2)
178
+ else:
179
+ model_inputs = latents
180
+ if unconditioning is not None and isinstance(unconditioning, list):
181
+ _, text_embeddings = text_embeddings.chunk(2)
182
+ text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
183
+ # predict the noise
184
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
185
+ if guidance_scale > 1.:
186
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
187
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
188
+ # compute the previous noise sample x_t -> x_t-1
189
+ latents, pred_x0 = self.step(noise_pred, t, latents)
190
+ latents_list.append(latents)
191
+ pred_x0_list.append(pred_x0)
192
+
193
+ image = self.latent2image(latents, return_type="pt")
194
+ if return_intermediates:
195
+ pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
196
+ latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
197
+ return image, pred_x0_list, latents_list
198
+ return image
199
+
200
+ @torch.no_grad()
201
+ def invert(
202
+ self,
203
+ image: torch.Tensor,
204
+ prompt,
205
+ num_inference_steps=50,
206
+ guidance_scale=7.5,
207
+ eta=0.0,
208
+ return_intermediates=False,
209
+ **kwds):
210
+ """
211
+ invert a real image into a noise map with deterministic DDIM inversion
212
+ """
213
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
214
+ batch_size = image.shape[0]
215
+ if isinstance(prompt, list):
216
+ if batch_size == 1:
217
+ image = image.expand(len(prompt), -1, -1, -1)
218
+ elif isinstance(prompt, str):
219
+ if batch_size > 1:
220
+ prompt = [prompt] * batch_size
221
+
222
+ # text embeddings
223
+ text_input = self.tokenizer(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=77,
227
+ return_tensors="pt"
228
+ )
229
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
230
+ print("input text embeddings :", text_embeddings.shape)
231
+ # define initial latents
232
+ latents = self.image2latent(image)
233
+ start_latents = latents
234
+ # print(latents)
235
+ # exit()
236
+ # unconditional embedding for classifier free guidance
237
+ if guidance_scale > 1.:
238
+ max_length = text_input.input_ids.shape[-1]
239
+ unconditional_input = self.tokenizer(
240
+ [""] * batch_size,
241
+ padding="max_length",
242
+ max_length=77,
243
+ return_tensors="pt"
244
+ )
245
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
246
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
247
+
248
+ print("latents shape: ", latents.shape)
249
+ # iterative sampling
250
+ self.scheduler.set_timesteps(num_inference_steps)
251
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
252
+ # print("attributes: ", self.scheduler.__dict__)
253
+ latents_list = [latents]
254
+ pred_x0_list = [latents]
255
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
256
+ if guidance_scale > 1.:
257
+ model_inputs = torch.cat([latents] * 2)
258
+ else:
259
+ model_inputs = latents
260
+
261
+ # predict the noise
262
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
263
+ if guidance_scale > 1.:
264
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
265
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
266
+ # compute the previous noise sample x_t-1 -> x_t
267
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
268
+ latents_list.append(latents)
269
+ pred_x0_list.append(pred_x0)
270
+
271
+ if return_intermediates:
272
+ # return the intermediate latents recorded during inversion
273
+ # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
274
+ return latents, latents_list
275
+ return latents, start_latents
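The `step` and `next_step` methods of `MasaCtrlPipeline` above implement the deterministic (eta = 0) DDIM update and its inversion. As a reference sketch of what they compute, in standard DDIM notation where the cumulative alphas correspond to `alphas_cumprod` in the code:

```latex
% Deterministic DDIM step (MasaCtrlPipeline.step):
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}},
\qquad
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}}\,\epsilon_\theta(x_t, t)

% DDIM inversion (MasaCtrlPipeline.next_step) applies the same update in the
% opposite direction, mapping x_t to x_{t+1} using \bar{\alpha}_{t+1}:
x_{t+1} = \sqrt{\bar{\alpha}_{t+1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t+1}}\,\epsilon_\theta(x_t, t)
```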
masactrl/masactrl.py ADDED
@@ -0,0 +1,334 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ from einops import rearrange
8
+
9
+ from .masactrl_utils import AttentionBase
10
+
11
+ from torchvision.utils import save_image
12
+
13
+
14
+ class MutualSelfAttentionControl(AttentionBase):
15
+ MODEL_TYPE = {
16
+ "SD": 16,
17
+ "SDXL": 70
18
+ }
19
+
20
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
21
+ """
22
+ Mutual self-attention control for Stable-Diffusion model
23
+ Args:
24
+ start_step: the step to start mutual self-attention control
25
+ start_layer: the layer to start mutual self-attention control
26
+ layer_idx: list of the layers to apply mutual self-attention control
27
+ step_idx: list of the steps to apply mutual self-attention control
28
+ total_steps: the total number of steps
29
+ model_type: the model type, SD or SDXL
30
+ """
31
+ super().__init__()
32
+ self.total_steps = total_steps
33
+ self.total_layers = self.MODEL_TYPE.get(model_type, 16)
34
+ self.start_step = start_step
35
+ self.start_layer = start_layer
36
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
37
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
38
+ print("MasaCtrl at denoising steps: ", self.step_idx)
39
+ print("MasaCtrl at U-Net layers: ", self.layer_idx)
40
+
41
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
42
+ """
43
+ Performing attention for a batch of queries, keys, and values
44
+ """
45
+ b = q.shape[0] // num_heads
46
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
47
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
48
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
49
+
50
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
51
+ attn = sim.softmax(-1)
52
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
53
+ out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
54
+ return out
55
+
56
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
57
+ """
58
+ Attention forward function
59
+ """
60
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
61
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
62
+
63
+ qu, qc = q.chunk(2)
64
+ ku, kc = k.chunk(2)
65
+ vu, vc = v.chunk(2)
66
+ attnu, attnc = attn.chunk(2)
67
+
68
+ out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
69
+ out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
70
+ out = torch.cat([out_u, out_c], dim=0)
71
+
72
+ return out
73
+
74
+
75
+ class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
76
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
77
+ """
78
+ Mutual self-attention control for Stable-Diffusion model with a union of source and target [K, V]
79
+ Args:
80
+ start_step: the step to start mutual self-attention control
81
+ start_layer: the layer to start mutual self-attention control
82
+ layer_idx: list of the layers to apply mutual self-attention control
83
+ step_idx: list of the steps to apply mutual self-attention control
84
+ total_steps: the total number of steps
85
+ model_type: the model type, SD or SDXL
86
+ """
87
+ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
88
+
89
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
90
+ """
91
+ Attention forward function
92
+ """
93
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
94
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
95
+
96
+ qu_s, qu_t, qc_s, qc_t = q.chunk(4)
97
+ ku_s, ku_t, kc_s, kc_t = k.chunk(4)
98
+ vu_s, vu_t, vc_s, vc_t = v.chunk(4)
99
+ attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4)
100
+
101
+ # source image branch
102
+ out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs)
103
+ out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs)
104
+
105
+ # target image branch, concatenating source and target [K, V]
106
+ out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs)
107
+ out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs)
108
+
109
+ out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)
110
+
111
+ return out
112
+
113
+
114
+ class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
115
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None, model_type="SD"):
116
+ """
117
+ Mask-guided MasaCtrl to alleviate the problem of foreground and background confusion
118
+ Args:
119
+ start_step: the step to start mutual self-attention control
120
+ start_layer: the layer to start mutual self-attention control
121
+ layer_idx: list of the layers to apply mutual self-attention control
122
+ step_idx: list of the steps to apply mutual self-attention control
123
+ total_steps: the total number of steps
124
+ mask_s: source mask with shape (h, w)
125
+ mask_t: target mask with same shape as source mask
126
+ mask_save_dir: the path to save the mask image
127
+ model_type: the model type, SD or SDXL
128
+ """
129
+ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
130
+ self.mask_s = mask_s # source mask with shape (h, w)
131
+ self.mask_t = mask_t # target mask with same shape as source mask
132
+ print("Using mask-guided MasaCtrl")
133
+ if mask_save_dir is not None:
134
+ os.makedirs(mask_save_dir, exist_ok=True)
135
+ save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png"))
136
+ save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png"))
137
+
138
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
139
+ B = q.shape[0] // num_heads
140
+ H = W = int(np.sqrt(q.shape[1]))
141
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
142
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
143
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
144
+
145
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
146
+ if kwargs.get("is_mask_attn") and self.mask_s is not None:
147
+ print("masked attention")
148
+ mask = self.mask_s.unsqueeze(0).unsqueeze(0)
149
+ mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0)
150
+ mask = mask.flatten()
151
+ # background
152
+ sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
153
+ # object
154
+ sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
155
+ sim = torch.cat([sim_fg, sim_bg], dim=0)
156
+ attn = sim.softmax(-1)
157
+ if len(attn) == 2 * len(v):
158
+ v = torch.cat([v] * 2)
159
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
160
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
161
+ return out
162
+
163
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
164
+ """
165
+ Attention forward function
166
+ """
167
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
168
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
169
+
170
+ B = q.shape[0] // num_heads // 2
171
+ H = W = int(np.sqrt(q.shape[1]))
172
+ qu, qc = q.chunk(2)
173
+ ku, kc = k.chunk(2)
174
+ vu, vc = v.chunk(2)
175
+ attnu, attnc = attn.chunk(2)
176
+
177
+ out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
178
+ out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
179
+
180
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
181
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
182
+
183
+ if self.mask_s is not None and self.mask_t is not None:
184
+ out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0)
185
+ out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0)
186
+
187
+ mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W))
188
+ mask = mask.reshape(-1, 1) # (hw, 1)
189
+ out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask)
190
+ out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask)
191
+
192
+ out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
193
+ return out
194
+
195
+
196
+ class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
197
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type="SD"):
198
+ """
199
+ MasaCtrl with mask auto generation from cross-attention map
200
+ Args:
201
+ start_step: the step to start mutual self-attention control
202
+ start_layer: the layer to start mutual self-attention control
203
+ layer_idx: list of the layers to apply mutual self-attention control
204
+ step_idx: list of the steps to apply mutual self-attention control
205
+ total_steps: the total number of steps
206
+ thres: the threshold for mask binarization
207
+ ref_token_idx: the token index list for cross-attention map aggregation
208
+ cur_token_idx: the token index list for cross-attention map aggregation
209
+ mask_save_dir: the path to save the mask image
210
+ """
211
+ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
212
+ print("Using MutualSelfAttentionControlMaskAuto")
213
+ self.thres = thres
214
+ self.ref_token_idx = ref_token_idx
215
+ self.cur_token_idx = cur_token_idx
216
+
217
+ self.self_attns = []
218
+ self.cross_attns = []
219
+
220
+ self.cross_attns_mask = None
221
+ self.self_attns_mask = None
222
+
223
+ self.mask_save_dir = mask_save_dir
224
+ if self.mask_save_dir is not None:
225
+ os.makedirs(self.mask_save_dir, exist_ok=True)
226
+
227
+ def after_step(self):
228
+ self.self_attns = []
229
+ self.cross_attns = []
230
+
231
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
232
+ """
233
+ Performing attention for a batch of queries, keys, and values
234
+ """
235
+ B = q.shape[0] // num_heads
236
+ H = W = int(np.sqrt(q.shape[1]))
237
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
238
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
239
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
240
+
241
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
242
+ if self.self_attns_mask is not None:
243
+ # binarize the mask
244
+ mask = self.self_attns_mask
245
+ thres = self.thres
246
+ mask[mask >= thres] = 1
247
+ mask[mask < thres] = 0
248
+ sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
249
+ sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
250
+ sim = torch.cat([sim_fg, sim_bg])
251
+
252
+ attn = sim.softmax(-1)
253
+
254
+ if len(attn) == 2 * len(v):
255
+ v = torch.cat([v] * 2)
256
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
257
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
258
+ return out
259
+
260
+ def aggregate_cross_attn_map(self, idx):
261
+ attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim)
262
+ B = attn_map.shape[0]
263
+ res = int(np.sqrt(attn_map.shape[-2]))
264
+ attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1])
265
+ image = attn_map[..., idx]
266
+ if isinstance(idx, list):
267
+ image = image.sum(-1)
268
+ image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
269
+ image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]
270
+ image = (image - image_min) / (image_max - image_min)
271
+ return image
272
+
273
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
274
+ """
275
+ Attention forward function
276
+ """
277
+ if is_cross:
278
+ # save cross attention map with res 16 * 16
279
+ if attn.shape[1] == 16 * 16:
280
+ self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1))
281
+
282
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
283
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
284
+
285
+ B = q.shape[0] // num_heads // 2
286
+ H = W = int(np.sqrt(q.shape[1]))
287
+ qu, qc = q.chunk(2)
288
+ ku, kc = k.chunk(2)
289
+ vu, vc = v.chunk(2)
290
+ attnu, attnc = attn.chunk(2)
291
+
292
+ out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
293
+ out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
294
+
295
+ if len(self.cross_attns) == 0:
296
+ self.self_attns_mask = None
297
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
298
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
299
+ else:
300
+ mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W)
301
+ mask_source = mask[-2] # (H, W)
302
+ res = int(np.sqrt(q.shape[1]))
303
+ self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten()
304
+ if self.mask_save_dir is not None:
305
+ H = W = int(np.sqrt(self.self_attns_mask.shape[0]))
306
+ mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0)
307
+ save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png"))
308
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
309
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
310
+
311
+ if self.self_attns_mask is not None:
312
+ mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W)
313
+ mask_target = mask[-1] # (H, W)
314
+ res = int(np.sqrt(q.shape[1]))
315
+ spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1)
316
+ if self.mask_save_dir is not None:
317
+ H = W = int(np.sqrt(spatial_mask.shape[0]))
318
+ mask_image = spatial_mask.reshape(H, W).unsqueeze(0)
319
+ save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png"))
320
+ # binarize the mask
321
+ thres = self.thres
322
+ spatial_mask[spatial_mask >= thres] = 1
323
+ spatial_mask[spatial_mask < thres] = 0
324
+ out_u_target_fg, out_u_target_bg = out_u_target.chunk(2)
325
+ out_c_target_fg, out_c_target_bg = out_c_target.chunk(2)
326
+
327
+ out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask)
328
+ out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask)
329
+
330
+ # reset the self-attention mask to None
331
+ self.self_attns_mask = None
332
+
333
+ out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
334
+ return out
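`MutualSelfAttentionControl` is driven entirely through the editor hook defined in `masactrl_utils.py` below: once registered, every image in the batch queries the source image's self-attention keys and values during the chosen steps and layers, which keeps the target consistent with the source. A minimal sketch of the intended call pattern, assuming `model` is a `MasaCtrlPipeline` (see `playground.ipynb` below) whose `__call__` accepts a prompt list and a shared start code via `latents`; those parameter names, the prompts, and the latent shape are assumptions here.

```python
# Minimal sketch, assuming `model` is a MasaCtrlPipeline and its __call__ accepts
# a list of prompts plus a shared start code via `latents` (names are assumptions).
import torch
from masactrl.masactrl import MutualSelfAttentionControl
from masactrl.masactrl_utils import regiter_attention_editor_diffusers

prompts = [
    "1boy, casting spell, anime style",            # source prompt (hypothetical)
    "1boy, casting spell, sitting, anime style",   # target prompt (hypothetical)
]
# One shared start code, expanded over the [source, target] batch.
start_code = torch.randn([1, 4, 64, 64], device="cuda").expand(len(prompts), -1, -1, -1)

# Defaults mirror the constructor above: control from step 4 and layer 10 onward.
editor = MutualSelfAttentionControl(start_step=4, start_layer=10, total_steps=50)
regiter_attention_editor_diffusers(model, editor)   # hooks every Attention layer in model.unet

images = model(prompts, latents=start_code, guidance_scale=7.5)
```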
masactrl/masactrl_processor.py ADDED
@@ -0,0 +1,259 @@
1
+ from importlib import import_module
2
+ from typing import Callable, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from diffusers.utils import deprecate, logging
9
+ from diffusers.utils.import_utils import is_xformers_available
10
+ from diffusers.models.attention import Attention
11
+
12
+
13
+ def register_attention_processor(
14
+ model: Optional[nn.Module] = None,
15
+ processor_type: str = "MasaCtrlProcessor",
16
+ **attn_args,
17
+ ):
18
+ """
19
+ Args:
20
+ model: a unet model or a list of unet models
21
+ processor_type: the type of the processor
22
+ """
23
+ if not isinstance(model, (list, tuple)):
24
+ model = [model]
25
+
26
+ if processor_type == "MasaCtrlProcessor":
27
+ processor = MasaCtrlProcessor(**attn_args)
28
+ else:
29
+ processor = AttnProcessor()
30
+
31
+ for m in model:
32
+ m.set_attn_processor(processor)
33
+ print(f"Model {m.__class__.__name__} is registered with attention processor: {processor_type}")
34
+
35
+
36
+ class AttnProcessor:
37
+ r"""
38
+ Default processor for performing attention-related computations.
39
+ """
40
+
41
+ def __call__(
42
+ self,
43
+ attn: Attention,
44
+ hidden_states: torch.Tensor,
45
+ encoder_hidden_states: Optional[torch.Tensor] = None,
46
+ attention_mask: Optional[torch.Tensor] = None,
47
+ temb: Optional[torch.Tensor] = None,
48
+ *args,
49
+ **kwargs,
50
+ ) -> torch.Tensor:
51
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
52
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
53
+ deprecate("scale", "1.0.0", deprecation_message)
54
+
55
+ residual = hidden_states
56
+
57
+ if attn.spatial_norm is not None:
58
+ hidden_states = attn.spatial_norm(hidden_states, temb)
59
+
60
+ input_ndim = hidden_states.ndim
61
+
62
+ if input_ndim == 4:
63
+ batch_size, channel, height, width = hidden_states.shape
64
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
65
+
66
+ batch_size, sequence_length, _ = (
67
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
68
+ )
69
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
70
+
71
+ if attn.group_norm is not None:
72
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
73
+
74
+ query = attn.to_q(hidden_states)
75
+
76
+ if encoder_hidden_states is None:
77
+ encoder_hidden_states = hidden_states
78
+ elif attn.norm_cross:
79
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
80
+
81
+ key = attn.to_k(encoder_hidden_states)
82
+ value = attn.to_v(encoder_hidden_states)
83
+
84
+ query = attn.head_to_batch_dim(query)
85
+ key = attn.head_to_batch_dim(key)
86
+ value = attn.head_to_batch_dim(value)
87
+
88
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
89
+ hidden_states = torch.bmm(attention_probs, value)
90
+ hidden_states = attn.batch_to_head_dim(hidden_states)
91
+
92
+ # linear proj
93
+ hidden_states = attn.to_out[0](hidden_states)
94
+ # dropout
95
+ hidden_states = attn.to_out[1](hidden_states)
96
+
97
+ if input_ndim == 4:
98
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
99
+
100
+ if attn.residual_connection:
101
+ hidden_states = hidden_states + residual
102
+
103
+ hidden_states = hidden_states / attn.rescale_output_factor
104
+
105
+ return hidden_states
106
+
107
+
108
+ class MasaCtrlProcessor(nn.Module):
109
+ """
110
+ Mutual Self-attention Processor for diffusers library.
111
+ Note that all attention layers should register the same processor.
112
+ """
113
+ MODEL_TYPE = {
114
+ "SD": 16,
115
+ "SDXL": 70
116
+ }
117
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_layers=32, total_steps=50, model_type="SD"):
118
+ """
119
+ Mutual self-attention control for Stable-Diffusion model
120
+ Args:
121
+ start_step: the step to start mutual self-attention control
122
+ start_layer: the layer to start mutual self-attention control
123
+ layer_idx: list of the layers to apply mutual self-attention control
124
+ step_idx: list of the steps to apply mutual self-attention control
125
+ total_steps: the total number of steps; must be the same as the number of denoising steps used by the scheduler
126
+ model_type: the model type, SD or SDXL
127
+ """
128
+ super().__init__()
129
+ self.total_steps = total_steps
130
+ self.total_layers = self.MODEL_TYPE.get(model_type, 16)
131
+ self.start_step = start_step
132
+ self.start_layer = start_layer
133
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
134
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
135
+ print("MasaCtrl at denoising steps: ", self.step_idx)
136
+ print("MasaCtrl at U-Net layers: ", self.layer_idx)
137
+
138
+ self.cur_step = 0
139
+ self.cur_att_layer = 0
140
+ self.num_attn_layers = total_layers
141
+
142
+ def after_step(self):
143
+ pass
144
+
145
+ def __call__(
146
+ self,
147
+ attn: Attention,
148
+ hidden_states: torch.FloatTensor,
149
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
150
+ attention_mask: Optional[torch.FloatTensor] = None,
151
+ temb: Optional[torch.FloatTensor] = None,
152
+ scale: float = 1.0,
153
+ ):
154
+ out = self.attn_forward(
155
+ attn,
156
+ hidden_states,
157
+ encoder_hidden_states,
158
+ attention_mask,
159
+ temb,
160
+ scale,
161
+ )
162
+ self.cur_att_layer += 1
163
+ if self.cur_att_layer == self.num_attn_layers:
164
+ self.cur_att_layer = 0
165
+ self.cur_step += 1
166
+ self.cur_step %= self.total_steps
167
+ # after step
168
+ self.after_step()
169
+ return out
170
+
171
+ def masactrl_forward(
172
+ self,
173
+ query,
174
+ key,
175
+ value,
176
+ ):
177
+ """
178
+ Rearrange the key and value for mutual self-attention control
179
+ """
180
+ ku_src, ku_tgt, kc_src, kc_tgt = key.chunk(4)
181
+ vu_src, vu_tgt, vc_src, vc_tgt = value.chunk(4)
182
+
183
+ k_rearranged = torch.cat([ku_src, ku_src, kc_src, kc_src])
184
+ v_rearranged = torch.cat([vu_src, vu_src, vc_src, vc_src])
185
+
186
+ return query, k_rearranged, v_rearranged
187
+
188
+ def attn_forward(
189
+ self,
190
+ attn: Attention,
191
+ hidden_states: torch.Tensor,
192
+ encoder_hidden_states: Optional[torch.Tensor] = None,
193
+ attention_mask: Optional[torch.Tensor] = None,
194
+ temb: Optional[torch.Tensor] = None,
195
+ *args,
196
+ **kwargs,
197
+ ):
198
+ cur_transformer_layer = self.cur_att_layer // 2
199
+ residual = hidden_states
200
+
201
+ is_cross = True if encoder_hidden_states is not None else False
202
+
203
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
204
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
205
+ deprecate("scale", "1.0.0", deprecation_message)
206
+
207
+ if attn.spatial_norm is not None:
208
+ hidden_states = attn.spatial_norm(hidden_states, temb)
209
+
210
+ input_ndim = hidden_states.ndim
211
+
212
+ if input_ndim == 4:
213
+ batch_size, channel, height, width = hidden_states.shape
214
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
215
+
216
+ batch_size, sequence_length, _ = (
217
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
218
+ )
219
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
220
+
221
+ if attn.group_norm is not None:
222
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
223
+
224
+ query = attn.to_q(hidden_states, *args)
225
+
226
+ if encoder_hidden_states is None:
227
+ encoder_hidden_states = hidden_states
228
+ elif attn.norm_cross:
229
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
230
+
231
+ key = attn.to_k(encoder_hidden_states, *args)
232
+ value = attn.to_v(encoder_hidden_states, *args)
233
+
234
+ # mutual self-attention control
235
+ if not is_cross and self.cur_step in self.step_idx and cur_transformer_layer in self.layer_idx:
236
+ query, key, value = self.masactrl_forward(query, key, value)
237
+
238
+ query = attn.head_to_batch_dim(query)
239
+ key = attn.head_to_batch_dim(key)
240
+ value = attn.head_to_batch_dim(value)
241
+
242
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
243
+ hidden_states = torch.bmm(attention_probs, value)
244
+ hidden_states = attn.batch_to_head_dim(hidden_states)
245
+
246
+ # linear proj
247
+ hidden_states = attn.to_out[0](hidden_states, *args)
248
+ # dropout
249
+ hidden_states = attn.to_out[1](hidden_states)
250
+
251
+ if input_ndim == 4:
252
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
253
+
254
+ if attn.residual_connection:
255
+ hidden_states = hidden_states + residual
256
+
257
+ hidden_states = hidden_states / attn.rescale_output_factor
258
+
259
+ return hidden_states
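`MasaCtrlProcessor` offers the same control through the standard diffusers attention-processor interface instead of monkey-patching `forward`. Because `masactrl_forward` chunks keys and values into four groups, the batch fed to the U-Net must be ordered `[uncond_source, uncond_target, cond_source, cond_target]`, which is exactly what a two-prompt classifier-free-guidance call produces. A minimal sketch, assuming a stock `StableDiffusionPipeline` named `pipe`; the prompts and latent shape are illustrative, while the model id and scheduler settings are taken from `playground.ipynb` below.

```python
# Minimal sketch of the processor route; `pipe`, the prompts and the latent shape
# are assumptions. register_attention_processor is defined above.
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from masactrl.masactrl_processor import register_attention_processor

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)
pipe = StableDiffusionPipeline.from_pretrained("xyn-ai/anything-v4.0", scheduler=scheduler).to("cuda")

# start_step/start_layer mirror MutualSelfAttentionControl; total_steps must match
# the num_inference_steps used at sampling time.
register_attention_processor(pipe.unet, processor_type="MasaCtrlProcessor",
                             start_step=4, start_layer=10, total_steps=50)

prompts = ["a photo of a corgi", "a photo of a corgi, side view"]   # hypothetical prompts
start_code = torch.randn([1, 4, 64, 64], device="cuda").expand(2, -1, -1, -1)
images = pipe(prompts, latents=start_code, num_inference_steps=50).images
```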
masactrl/masactrl_utils.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from typing import Optional, Union, Tuple, List, Callable, Dict
9
+
10
+ from torchvision.utils import save_image
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ class AttentionBase:
15
+ def __init__(self):
16
+ self.cur_step = 0
17
+ self.num_att_layers = -1
18
+ self.cur_att_layer = 0
19
+
20
+ def after_step(self):
21
+ pass
22
+
23
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
24
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
25
+ self.cur_att_layer += 1
26
+ if self.cur_att_layer == self.num_att_layers:
27
+ self.cur_att_layer = 0
28
+ self.cur_step += 1
29
+ # after step
30
+ self.after_step()
31
+ return out
32
+
33
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
35
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
36
+ return out
37
+
38
+ def reset(self):
39
+ self.cur_step = 0
40
+ self.cur_att_layer = 0
41
+
42
+
43
+ class AttentionStore(AttentionBase):
44
+ def __init__(self, res=[32], min_step=0, max_step=1000):
45
+ super().__init__()
46
+ self.res = res
47
+ self.min_step = min_step
48
+ self.max_step = max_step
49
+ self.valid_steps = 0
50
+
51
+ self.self_attns = [] # store all the attns
52
+ self.cross_attns = []
53
+
54
+ self.self_attns_step = [] # store the attns in each step
55
+ self.cross_attns_step = []
56
+
57
+ def after_step(self):
58
+ if self.cur_step > self.min_step and self.cur_step < self.max_step:
59
+ self.valid_steps += 1
60
+ if len(self.self_attns) == 0:
61
+ self.self_attns = self.self_attns_step
62
+ self.cross_attns = self.cross_attns_step
63
+ else:
64
+ for i in range(len(self.self_attns)):
65
+ self.self_attns[i] += self.self_attns_step[i]
66
+ self.cross_attns[i] += self.cross_attns_step[i]
67
+ self.self_attns_step.clear()
68
+ self.cross_attns_step.clear()
69
+
70
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
71
+ if attn.shape[1] <= 64 ** 2: # avoid OOM
72
+ if is_cross:
73
+ self.cross_attns_step.append(attn)
74
+ else:
75
+ self.self_attns_step.append(attn)
76
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
77
+
78
+
79
+ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
80
+ """
81
+ Register an attention editor to a Diffusers pipeline, adapted from [Prompt-to-Prompt]
82
+ """
83
+ def ca_forward(self, place_in_unet):
84
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
85
+ """
86
+ The attention is similar to the original implementation of LDM CrossAttention class
87
+ except adding some modifications on the attention
88
+ """
89
+ if encoder_hidden_states is not None:
90
+ context = encoder_hidden_states
91
+ if attention_mask is not None:
92
+ mask = attention_mask
93
+
94
+ to_out = self.to_out
95
+ if isinstance(to_out, nn.modules.container.ModuleList):
96
+ to_out = self.to_out[0]
97
+ else:
98
+ to_out = self.to_out
99
+
100
+ h = self.heads
101
+ q = self.to_q(x)
102
+ is_cross = context is not None
103
+ context = context if is_cross else x
104
+ k = self.to_k(context)
105
+ v = self.to_v(context)
106
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
107
+
108
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
109
+
110
+ if mask is not None:
111
+ mask = rearrange(mask, 'b ... -> b (...)')
112
+ max_neg_value = -torch.finfo(sim.dtype).max
113
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
114
+ mask = mask[:, None, :].repeat(h, 1, 1)
115
+ sim.masked_fill_(~mask, max_neg_value)
116
+
117
+ attn = sim.softmax(dim=-1)
118
+ # the only difference
119
+ out = editor(
120
+ q, k, v, sim, attn, is_cross, place_in_unet,
121
+ self.heads, scale=self.scale)
122
+
123
+ return to_out(out)
124
+
125
+ return forward
126
+
127
+ def register_editor(net, count, place_in_unet):
128
+ for name, subnet in net.named_children():
129
+ if net.__class__.__name__ == 'Attention': # spatial Transformer layer
130
+ net.forward = ca_forward(net, place_in_unet)
131
+ return count + 1
132
+ elif hasattr(net, 'children'):
133
+ count = register_editor(subnet, count, place_in_unet)
134
+ return count
135
+
136
+ cross_att_count = 0
137
+ for net_name, net in model.unet.named_children():
138
+ if "down" in net_name:
139
+ cross_att_count += register_editor(net, 0, "down")
140
+ elif "mid" in net_name:
141
+ cross_att_count += register_editor(net, 0, "mid")
142
+ elif "up" in net_name:
143
+ cross_att_count += register_editor(net, 0, "up")
144
+ editor.num_att_layers = cross_att_count
145
+
146
+
147
+ def regiter_attention_editor_ldm(model, editor: AttentionBase):
148
+ """
149
+ Register an attention editor to a Stable Diffusion (LDM) model, adapted from [Prompt-to-Prompt]
150
+ """
151
+ def ca_forward(self, place_in_unet):
152
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
153
+ """
154
+ The attention is similar to the original implementation of LDM CrossAttention class
155
+ except adding some modifications on the attention
156
+ """
157
+ if encoder_hidden_states is not None:
158
+ context = encoder_hidden_states
159
+ if attention_mask is not None:
160
+ mask = attention_mask
161
+
162
+ to_out = self.to_out
163
+ if isinstance(to_out, nn.modules.container.ModuleList):
164
+ to_out = self.to_out[0]
165
+ else:
166
+ to_out = self.to_out
167
+
168
+ h = self.heads
169
+ q = self.to_q(x)
170
+ is_cross = context is not None
171
+ context = context if is_cross else x
172
+ k = self.to_k(context)
173
+ v = self.to_v(context)
174
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
175
+
176
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
177
+
178
+ if mask is not None:
179
+ mask = rearrange(mask, 'b ... -> b (...)')
180
+ max_neg_value = -torch.finfo(sim.dtype).max
181
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
182
+ mask = mask[:, None, :].repeat(h, 1, 1)
183
+ sim.masked_fill_(~mask, max_neg_value)
184
+
185
+ attn = sim.softmax(dim=-1)
186
+ # the only difference
187
+ out = editor(
188
+ q, k, v, sim, attn, is_cross, place_in_unet,
189
+ self.heads, scale=self.scale)
190
+
191
+ return to_out(out)
192
+
193
+ return forward
194
+
195
+ def register_editor(net, count, place_in_unet):
196
+ for name, subnet in net.named_children():
197
+ if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
198
+ net.forward = ca_forward(net, place_in_unet)
199
+ return count + 1
200
+ elif hasattr(net, 'children'):
201
+ count = register_editor(subnet, count, place_in_unet)
202
+ return count
203
+
204
+ cross_att_count = 0
205
+ for net_name, net in model.model.diffusion_model.named_children():
206
+ if "input" in net_name:
207
+ cross_att_count += register_editor(net, 0, "input")
208
+ elif "middle" in net_name:
209
+ cross_att_count += register_editor(net, 0, "middle")
210
+ elif "output" in net_name:
211
+ cross_att_count += register_editor(net, 0, "output")
212
+ editor.num_att_layers = cross_att_count
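Besides the editors, `masactrl_utils.py` also ships `AttentionStore`, which accumulates self- and cross-attention maps per denoising step in `after_step` and counts the steps inside `(min_step, max_step)` in `valid_steps`. A minimal sketch of collecting averaged cross-attention maps; `model` and the call signature are assumptions, as is the prompt.

```python
# Minimal sketch: collect and average attention maps during one sampling run.
# `model` and the call signature are assumptions (see playground.ipynb below).
from masactrl.masactrl_utils import AttentionStore, regiter_attention_editor_diffusers

store = AttentionStore(res=[32], min_step=0, max_step=1000)
regiter_attention_editor_diffusers(model, store)

image = model("a photo of a corgi", num_inference_steps=50)   # hypothetical prompt

# Average the per-step accumulation over the counted steps;
# each list entry corresponds to one attention layer.
avg_cross_attns = [attn / store.valid_steps for attn in store.cross_attns]
avg_self_attns = [attn / store.valid_steps for attn in store.self_attns]
```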
masactrl_w_adapter/ddim.py ADDED
@@ -0,0 +1,375 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
8
+ extract_into_tensor
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
26
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
27
+ alphas_cumprod = self.model.alphas_cumprod
28
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
29
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
30
+
31
+ self.register_buffer('betas', to_torch(self.model.betas))
32
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
33
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
34
+
35
+ # calculations for diffusion q(x_t | x_{t-1}) and others
36
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
37
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
41
+
42
+ # ddim sampling parameters
43
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
44
+ ddim_timesteps=self.ddim_timesteps,
45
+ eta=ddim_eta, verbose=verbose)
46
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
47
+ self.register_buffer('ddim_alphas', ddim_alphas)
48
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
49
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
50
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
51
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
52
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
53
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
54
+
55
+ @torch.no_grad()
56
+ def sample(self,
57
+ S,
58
+ batch_size,
59
+ shape,
60
+ conditioning=None,
61
+ callback=None,
62
+ normals_sequence=None,
63
+ img_callback=None,
64
+ quantize_x0=False,
65
+ eta=0.,
66
+ mask=None,
67
+ x0=None,
68
+ temperature=1.,
69
+ noise_dropout=0.,
70
+ score_corrector=None,
71
+ corrector_kwargs=None,
72
+ verbose=True,
73
+ x_T=None,
74
+ log_every_t=100,
75
+ unconditional_guidance_scale=1.,
76
+ unconditional_conditioning=None,
77
+ features_adapter=None,
78
+ append_to_context=None,
79
+ cond_tau=0.4,
80
+ style_cond_tau=1.0,
81
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
82
+ **kwargs
83
+ ):
84
+ if conditioning is not None:
85
+ if isinstance(conditioning, dict):
86
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+ else:
90
+ if conditioning.shape[0] != batch_size:
91
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
+
93
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
+ # sampling
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
98
+
99
+ samples, intermediates = self.ddim_sampling(conditioning, size,
100
+ callback=callback,
101
+ img_callback=img_callback,
102
+ quantize_denoised=quantize_x0,
103
+ mask=mask, x0=x0,
104
+ ddim_use_original_steps=False,
105
+ noise_dropout=noise_dropout,
106
+ temperature=temperature,
107
+ score_corrector=score_corrector,
108
+ corrector_kwargs=corrector_kwargs,
109
+ x_T=x_T,
110
+ log_every_t=log_every_t,
111
+ unconditional_guidance_scale=unconditional_guidance_scale,
112
+ unconditional_conditioning=unconditional_conditioning,
113
+ features_adapter=features_adapter,
114
+ append_to_context=append_to_context,
115
+ cond_tau=cond_tau,
116
+ style_cond_tau=style_cond_tau,
117
+ )
118
+ return samples, intermediates
119
+
120
+ @torch.no_grad()
121
+ def ddim_sampling(self, cond, shape,
122
+ x_T=None, ddim_use_original_steps=False,
123
+ callback=None, timesteps=None, quantize_denoised=False,
124
+ mask=None, x0=None, img_callback=None, log_every_t=100,
125
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
126
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
127
+ append_to_context=None, cond_tau=0.4, style_cond_tau=1.0):
128
+ device = self.model.betas.device
129
+ b = shape[0]
130
+ if x_T is None:
131
+ img = torch.randn(shape, device=device)
132
+ else:
133
+ img = x_T
134
+
135
+ if timesteps is None:
136
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
137
+ elif timesteps is not None and not ddim_use_original_steps:
138
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
139
+ timesteps = self.ddim_timesteps[:subset_end]
140
+
141
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
142
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
143
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
144
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
145
+
146
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
147
+
148
+ for i, step in enumerate(iterator):
149
+ index = total_steps - i - 1
150
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
151
+
152
+ if mask is not None:
153
+ assert x0 is not None
154
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
155
+ img = img_orig * mask + (1. - mask) * img
156
+
157
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
158
+ quantize_denoised=quantize_denoised, temperature=temperature,
159
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
160
+ corrector_kwargs=corrector_kwargs,
161
+ unconditional_guidance_scale=unconditional_guidance_scale,
162
+ unconditional_conditioning=unconditional_conditioning,
163
+ features_adapter=None if index < int(
164
+ (1 - cond_tau) * total_steps) else features_adapter,
165
+ append_to_context=None if index < int(
166
+ (1 - style_cond_tau) * total_steps) else append_to_context,
167
+ )
168
+ img, pred_x0 = outs
169
+ if callback: callback(i)
170
+ if img_callback: img_callback(pred_x0, i)
171
+
172
+ if index % log_every_t == 0 or index == total_steps - 1:
173
+ intermediates['x_inter'].append(img)
174
+ intermediates['pred_x0'].append(pred_x0)
175
+
176
+ return img, intermediates
177
+
178
+ @torch.no_grad()
179
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
180
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
181
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
182
+ append_to_context=None):
183
+ b, *_, device = *x.shape, x.device
184
+
185
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186
+ if append_to_context is not None:
187
+ model_output = self.model.apply_model(x, t, torch.cat([c, append_to_context], dim=1),
188
+ features_adapter=features_adapter)
189
+ else:
190
+ model_output = self.model.apply_model(x, t, c, features_adapter=features_adapter)
191
+ else:
192
+ x_in = torch.cat([x] * 2)
193
+ t_in = torch.cat([t] * 2)
194
+ if isinstance(c, dict):
195
+ assert isinstance(unconditional_conditioning, dict)
196
+ c_in = dict()
197
+ for k in c:
198
+ if isinstance(c[k], list):
199
+ c_in[k] = [torch.cat([
200
+ unconditional_conditioning[k][i],
201
+ c[k][i]]) for i in range(len(c[k]))]
202
+ else:
203
+ c_in[k] = torch.cat([
204
+ unconditional_conditioning[k],
205
+ c[k]])
206
+ elif isinstance(c, list):
207
+ c_in = list()
208
+ assert isinstance(unconditional_conditioning, list)
209
+ for i in range(len(c)):
210
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
211
+ else:
212
+ if append_to_context is not None:
213
+ pad_len = append_to_context.size(1)
214
+ new_unconditional_conditioning = torch.cat(
215
+ [unconditional_conditioning, unconditional_conditioning[:, -pad_len:, :]], dim=1)
216
+ new_c = torch.cat([c, append_to_context], dim=1)
217
+ c_in = torch.cat([new_unconditional_conditioning, new_c])
218
+ else:
219
+ c_in = torch.cat([unconditional_conditioning, c])
220
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
221
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
222
+
223
+ if self.model.parameterization == "v":
224
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
225
+ else:
226
+ e_t = model_output
227
+
228
+ if score_corrector is not None:
229
+ assert self.model.parameterization == "eps", 'not implemented'
230
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
231
+
232
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
233
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
234
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
235
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
236
+ # select parameters corresponding to the currently considered timestep
237
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
238
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
239
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
240
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
241
+
242
+ # current prediction for x_0
243
+ if self.model.parameterization != "v":
244
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
245
+ else:
246
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
247
+
248
+ if quantize_denoised:
249
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
250
+ # direction pointing to x_t
251
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
252
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
253
+ if noise_dropout > 0.:
254
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
255
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
256
+ return x_prev, pred_x0
257
+
258
+ @torch.no_grad()
259
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
260
+ # fast, but does not allow for exact reconstruction
261
+ # t serves as an index to gather the correct alphas
262
+ if use_original_steps:
263
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
264
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
265
+ else:
266
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
267
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
268
+
269
+ if noise is None:
270
+ noise = torch.randn_like(x0)
271
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
272
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
273
+
274
+ @torch.no_grad()
275
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
276
+ use_original_steps=False):
277
+
278
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
279
+ timesteps = timesteps[:t_start]
280
+
281
+ time_range = np.flip(timesteps)
282
+ total_steps = timesteps.shape[0]
283
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
284
+
285
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
286
+ x_dec = x_latent
287
+ for i, step in enumerate(iterator):
288
+ index = total_steps - i - 1
289
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
290
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
291
+ unconditional_guidance_scale=unconditional_guidance_scale,
292
+ unconditional_conditioning=unconditional_conditioning)
293
+ return x_dec
294
+
295
+ def ddim_sampling_reverse(self,
296
+ num_steps=50,
297
+ x_0=None,
298
+ conditioning=None,
299
+ eta=0.,
300
+ verbose=False,
301
+ unconditional_guidance_scale=7.5,
302
+ unconditional_conditioning=None
303
+ ):
304
+ """
305
+ Obtain the inverted noisy latent x_T via deterministic DDIM inversion
306
+ """
307
+ assert eta == 0., "eta should be 0. for deterministic sampling"
308
+ B = x_0.shape[0]
309
+ # scheduler
310
+ self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta, verbose=verbose)
311
+ self.register_buffer("ddim_sqrt_one_minus_alphas_prev", torch.tensor(1. - self.ddim_alphas_prev).sqrt())
312
+ # sampling
313
+ device = self.model.betas.device
314
+
315
+ intermediates = {"x_inter": [x_0], "pred_x0": []}
316
+
317
+ time_range = self.ddim_timesteps
318
+ print("selected steps for ddim inversion: ", time_range)
319
+ assert len(time_range) == num_steps, "time range should be same as num steps"
320
+
321
+ iterator = tqdm(time_range, desc='DDIM Inversion', total=num_steps)
322
+ x_t = x_0
323
+ for i, step in enumerate(iterator):
324
+ if i == 0:
325
+ step = 1
326
+ else:
327
+ step = time_range[i - 1]
328
+ ts = torch.full((B, ), step, device=device, dtype=torch.long)
329
+ outs = self.p_sample_ddim_reverse(x_t,
330
+ conditioning,
331
+ ts,
332
+ index=i,
333
+ unconditional_guidance_scale=unconditional_guidance_scale,
334
+ unconditional_conditioning=unconditional_conditioning
335
+ )
336
+ x_t, pred_x0 = outs
337
+ intermediates["x_inter"].append(x_t)
338
+ intermediates["pred_x0"].append(pred_x0)
339
+ return x_t, intermediates
340
+
341
+ def p_sample_ddim_reverse(self,
342
+ x, c, t, index,
343
+ unconditional_guidance_scale=1.,
344
+ unconditional_conditioning=None):
345
+ B = x.shape[0]
346
+ device = x.device
347
+
348
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
349
+ e_t = self.model.apply_model(x, t, c)
350
+ else:
351
+ x_in = torch.cat([x] * 2)
352
+ t_in = torch.cat([t] * 2)
353
+ c_in = torch.cat([unconditional_conditioning, c])
354
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
355
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
356
+
357
+ # scheduler parameters
358
+ alphas = self.ddim_alphas
359
+ alphas_prev = self.ddim_alphas_prev
360
+ sqrt_one_minus_alphas_prev = self.ddim_sqrt_one_minus_alphas_prev
361
+
362
+ # select parameters corresponding to the currently considered timestep
363
+ alpha_cumprod_next = torch.full((B, 1, 1, 1), alphas[index], device=device)
364
+ alpha_cumprod_t = torch.full((B, 1, 1, 1), alphas_prev[index], device=device)
365
+ sqrt_one_minus_alpha_t = torch.full((B, 1, 1, 1), sqrt_one_minus_alphas_prev[index], device=device)
366
+
367
+ # current prediction for x_0
368
+ pred_x0 = (x - sqrt_one_minus_alpha_t * e_t) / alpha_cumprod_t.sqrt()
369
+
370
+ # direction pointing to x_t
371
+ dir_xt = (1. - alpha_cumprod_next).sqrt() * e_t
372
+
373
+ x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
374
+
375
+ return x_next, pred_x0
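`p_sample_ddim_reverse` is the deterministic (eta = 0) DDIM update run in the increasing-noise direction: with epsilon_theta the (guidance-combined) noise prediction and the cumulative alphas selected per index, the step implemented above is

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t+1} = \sqrt{\bar\alpha_{t+1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t+1}}\,\epsilon_\theta(x_t, t),
$$

where the code reads the "next" cumulative alpha from `ddim_alphas[index]` and the current one from `ddim_alphas_prev[index]`; iterating this from x_0 yields the inverted start code x_T returned by `ddim_sampling_reverse`.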
masactrl_w_adapter/masactrl_w_adapter.py ADDED
@@ -0,0 +1,217 @@
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from basicsr.utils import tensor2img
7
+ from pytorch_lightning import seed_everything
8
+ from torch import autocast
9
+ from torchvision.io import read_image
10
+
11
+ from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
12
+ from ldm.modules.extra_condition import api
13
+ from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
14
+ from ldm.util import fix_cond_shapes
15
+
16
+ # for masactrl
17
+ from masactrl.masactrl_utils import regiter_attention_editor_ldm
18
+ from masactrl.masactrl import MutualSelfAttentionControl
19
+ from masactrl.masactrl import MutualSelfAttentionControlMask
20
+ from masactrl.masactrl import MutualSelfAttentionControlMaskAuto
21
+
22
+ torch.set_grad_enabled(False)
23
+
24
+
25
+ def main():
26
+ supported_cond = [e.name for e in ExtraCondition]
27
+ parser = get_base_argument_parser()
28
+ parser.add_argument(
29
+ '--which_cond',
30
+ type=str,
31
+ required=True,
32
+ choices=supported_cond,
33
+ help='which condition modality you want to test',
34
+ )
35
+ # [MasaCtrl added] reference cond path
36
+ parser.add_argument(
37
+ "--cond_path_src",
38
+ type=str,
39
+ default=None,
40
+ help="the condition image path to synthesize the source image",
41
+ )
42
+ parser.add_argument(
43
+ "--prompt_src",
44
+ type=str,
45
+ default=None,
46
+ help="the prompt to synthesize the source image",
47
+ )
48
+ parser.add_argument(
49
+ "--src_img_path",
50
+ type=str,
51
+ default=None,
52
+ help="the input real source image path"
53
+ )
54
+ parser.add_argument(
55
+ "--start_code_path",
56
+ type=str,
57
+ default=None,
58
+ help="the inverted start code path to synthesize the source image",
59
+ )
60
+ parser.add_argument(
61
+ "--masa_step",
62
+ type=int,
63
+ default=4,
64
+ help="the starting step for MasaCtrl",
65
+ )
66
+ parser.add_argument(
67
+ "--masa_layer",
68
+ type=int,
69
+ default=10,
70
+ help="the starting layer for MasaCtrl",
71
+ )
72
+
73
+ opt = parser.parse_args()
74
+ which_cond = opt.which_cond
75
+ if opt.outdir is None:
76
+ opt.outdir = f'outputs/test-{which_cond}'
77
+ os.makedirs(opt.outdir, exist_ok=True)
78
+ if opt.resize_short_edge is None:
79
+ print(f"resize_short_edge is not specified, so the maximum resolution is set to {opt.max_resolution}")
80
+ opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
81
+
82
+ if os.path.isdir(opt.cond_path): # for conditioning image folder
83
+ image_paths = [os.path.join(opt.cond_path, f) for f in os.listdir(opt.cond_path)]
84
+ else:
85
+ image_paths = [opt.cond_path]
86
+ print(image_paths)
87
+
88
+ # prepare models
89
+ sd_model, sampler = get_sd_models(opt)
90
+ adapter = get_adapters(opt, getattr(ExtraCondition, which_cond))
91
+ cond_model = None
92
+ if opt.cond_inp_type == 'image':
93
+ cond_model = get_cond_model(opt, getattr(ExtraCondition, which_cond))
94
+
95
+ process_cond_module = getattr(api, f'get_cond_{which_cond}')
96
+
97
+ # [MasaCtrl added] default STEP and LAYER params for MasaCtrl
98
+ STEP = opt.masa_step if opt.masa_step is not None else 4
99
+ LAYER = opt.masa_layer if opt.masa_layer is not None else 10
100
+
101
+ # inference
102
+ with torch.inference_mode(), \
103
+ sd_model.ema_scope(), \
104
+ autocast('cuda'):
105
+ for test_idx, cond_path in enumerate(image_paths):
106
+ seed_everything(opt.seed)
107
+ for v_idx in range(opt.n_samples):
108
+ # seed_everything(opt.seed+v_idx+test_idx)
109
+ if opt.cond_path_src:
110
+ cond_src = process_cond_module(opt, opt.cond_path_src, opt.cond_inp_type, cond_model)
111
+ cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)
112
+
113
+ base_count = len(os.listdir(opt.outdir)) // 2
114
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))
115
+ if opt.cond_path_src:
116
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}_src.png'), tensor2img(cond_src))
117
+
118
+ adapter_features, append_to_context = get_adapter_feature(cond, adapter)
119
+ if opt.cond_path_src:
120
+ adapter_features_src, append_to_context_src = get_adapter_feature(cond_src, adapter)
121
+
122
+ if opt.cond_path_src:
123
+ print("using reference guidance to synthesize image")
124
+ adapter_features = [torch.cat([adapter_features_src[i], adapter_features[i]]) for i in range(len(adapter_features))]
125
+ else:
126
+ adapter_features = [torch.cat([torch.zeros_like(feats), feats]) for feats in adapter_features]
127
+
128
+ if opt.scale > 1.:
129
+ adapter_features = [torch.cat([feats] * 2) for feats in adapter_features]
130
+
131
+ # prepare the batch prompts
132
+ if opt.prompt_src is not None:
133
+ prompts = [opt.prompt_src, opt.prompt]
134
+ else:
135
+ prompts = [opt.prompt] * 2
136
+ print("prompts: ", prompts)
137
+ # get text embedding
138
+ c = sd_model.get_learned_conditioning(prompts)
139
+ if opt.scale != 1.0:
140
+ uc = sd_model.get_learned_conditioning([""] * len(prompts))
141
+ else:
142
+ uc = None
143
+ c, uc = fix_cond_shapes(sd_model, c, uc)
144
+
145
+ if not hasattr(opt, 'H'):
146
+ opt.H = 512
147
+ opt.W = 512
148
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
149
+ if opt.src_img_path: # perform ddim inversion
150
+
151
+ src_img = read_image(opt.src_img_path)
152
+ src_img = src_img.float() / 255. # input normalized image [0, 1]
153
+ src_img = src_img * 2 - 1
154
+ if src_img.dim() == 3:
155
+ src_img = src_img.unsqueeze(0)
156
+ src_img = F.interpolate(src_img, (opt.H, opt.W))
157
+ src_img = src_img.to(opt.device)
158
+ # obtain initial latent
159
+ encoder_posterior = sd_model.encode_first_stage(src_img)
160
+ src_x_0 = sd_model.get_first_stage_encoding(encoder_posterior)
161
+ start_code, latents_dict = sampler.ddim_sampling_reverse(
162
+ num_steps=opt.steps,
163
+ x_0=src_x_0,
164
+ conditioning=uc[:1], # you may change here during inversion
165
+ unconditional_guidance_scale=opt.scale,
166
+ unconditional_conditioning=uc[:1],
167
+ )
168
+ torch.save(
169
+ {
170
+ "start_code": start_code
171
+ },
172
+ os.path.join(opt.outdir, "start_code.pth"),
173
+ )
174
+ elif opt.start_code_path:
175
+ # load the inverted start code
176
+ start_code_dict = torch.load(opt.start_code_path)
177
+ start_code = start_code_dict.get("start_code").to(opt.device)
178
+ else:
179
+ start_code = torch.randn([1, *shape], device=opt.device)
180
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
181
+
182
+ # hijack the attention module
183
+ editor = MutualSelfAttentionControl(STEP, LAYER)
184
+ regiter_attention_editor_ldm(sd_model, editor)
185
+
186
+ samples_latents, _ = sampler.sample(
187
+ S=opt.steps,
188
+ conditioning=c,
189
+ batch_size=len(prompts),
190
+ shape=shape,
191
+ verbose=False,
192
+ unconditional_guidance_scale=opt.scale,
193
+ unconditional_conditioning=uc,
194
+ x_T=start_code,
195
+ features_adapter=adapter_features,
196
+ append_to_context=append_to_context,
197
+ cond_tau=opt.cond_tau,
198
+ style_cond_tau=opt.style_cond_tau,
199
+ )
200
+
201
+ x_samples = sd_model.decode_first_stage(samples_latents)
202
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
203
+
204
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_all_result.png'), tensor2img(x_samples))
205
+ # save the prompts and seed
206
+ with open(os.path.join(opt.outdir, "log.txt"), "w") as f:
207
+ for prom in prompts:
208
+ f.write(prom)
209
+ f.write("\n")
210
+ f.write(f"seed: {opt.seed}")
211
+ for i in range(len(x_samples)):
212
+ base_count += 1
213
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(x_samples[i]))
214
+
215
+
216
+ if __name__ == '__main__':
217
+ main()
playground.ipynb ADDED
@@ -0,0 +1,149 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.nn.functional as F\n",
20
+ "\n",
21
+ "import numpy as np\n",
22
+ "\n",
23
+ "from tqdm import tqdm\n",
24
+ "from einops import rearrange, repeat\n",
25
+ "from omegaconf import OmegaConf\n",
26
+ "\n",
27
+ "from diffusers import DDIMScheduler\n",
28
+ "\n",
29
+ "from masactrl.diffuser_utils import MasaCtrlPipeline\n",
30
+ "from masactrl.masactrl_utils import AttentionBase\n",
31
+ "from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
32
+ "\n",
33
+ "from torchvision.utils import save_image\n",
34
+ "from torchvision.io import read_image\n",
35
+ "from pytorch_lightning import seed_everything\n",
36
+ "\n",
37
+ "torch.cuda.set_device(0) # set the GPU device"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "#### Model Construction"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# Note that you may add your Hugging Face token to get access to the models\n",
54
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
55
+ "model_path = \"xyn-ai/anything-v4.0\"\n",
56
+ "# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
57
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
58
+ "model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler, cross_attention_kwargs={\"scale\": 0.5}).to(device)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "#### Consistent synthesis with MasaCtrl"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "from masactrl.masactrl import MutualSelfAttentionControl\n",
75
+ "\n",
76
+ "\n",
77
+ "seed = 42\n",
78
+ "seed_everything(seed)\n",
79
+ "\n",
80
+ "out_dir = \"./workdir/masactrl_exp/\"\n",
81
+ "os.makedirs(out_dir, exist_ok=True)\n",
82
+ "sample_count = len(os.listdir(out_dir))\n",
83
+ "out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
84
+ "os.makedirs(out_dir, exist_ok=True)\n",
85
+ "\n",
86
+ "prompts = [\n",
87
+ " \"1boy, casual, outdoors, sitting\", # source prompt\n",
88
+ " \"1boy, casual, outdoors, standing\" # target prompt\n",
89
+ "]\n",
90
+ "\n",
91
+ "# initialize the noise map\n",
92
+ "start_code = torch.randn([1, 4, 64, 64], device=device)\n",
93
+ "start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
94
+ "\n",
95
+ "# inference the synthesized image without MasaCtrl\n",
96
+ "editor = AttentionBase()\n",
97
+ "regiter_attention_editor_diffusers(model, editor)\n",
98
+ "image_ori = model(prompts, latents=start_code, guidance_scale=7.5)\n",
99
+ "\n",
100
+ "# inference the synthesized image with MasaCtrl\n",
101
+ "STEP = 4\n",
102
+ "LAYPER = 10\n",
103
+ "\n",
104
+ "# hijack the attention module\n",
105
+ "editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
106
+ "regiter_attention_editor_diffusers(model, editor)\n",
107
+ "\n",
108
+ "# inference the synthesized image\n",
109
+ "image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)[-1:]\n",
110
+ "\n",
111
+ "# save the synthesized image\n",
112
+ "out_image = torch.cat([image_ori, image_masactrl], dim=0)\n",
113
+ "save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
114
+ "save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
115
+ "save_image(out_image[1], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
116
+ "save_image(out_image[2], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
117
+ "\n",
118
+ "print(\"Syntheiszed images are saved in\", out_dir)"
119
+ ]
120
+ }
121
+ ],
122
+ "metadata": {
123
+ "kernelspec": {
124
+ "display_name": "Python 3.8.5 ('ldm')",
125
+ "language": "python",
126
+ "name": "python3"
127
+ },
128
+ "language_info": {
129
+ "codemirror_mode": {
130
+ "name": "ipython",
131
+ "version": 3
132
+ },
133
+ "file_extension": ".py",
134
+ "mimetype": "text/x-python",
135
+ "name": "python",
136
+ "nbconvert_exporter": "python",
137
+ "pygments_lexer": "ipython3",
138
+ "version": "3.8.5"
139
+ },
140
+ "orig_nbformat": 4,
141
+ "vscode": {
142
+ "interpreter": {
143
+ "hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
144
+ }
145
+ }
146
+ },
147
+ "nbformat": 4,
148
+ "nbformat_minor": 2
149
+ }
playground_real.ipynb ADDED
@@ -0,0 +1,188 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.nn.functional as F\n",
20
+ "\n",
21
+ "import numpy as np\n",
22
+ "\n",
23
+ "from tqdm import tqdm\n",
24
+ "from einops import rearrange, repeat\n",
25
+ "from omegaconf import OmegaConf\n",
26
+ "\n",
27
+ "from diffusers import DDIMScheduler\n",
28
+ "\n",
29
+ "from masactrl.diffuser_utils import MasaCtrlPipeline\n",
30
+ "from masactrl.masactrl_utils import AttentionBase\n",
31
+ "from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
32
+ "\n",
33
+ "from torchvision.utils import save_image\n",
34
+ "from torchvision.io import read_image\n",
35
+ "from pytorch_lightning import seed_everything\n",
36
+ "\n",
37
+ "torch.cuda.set_device(0) # set the GPU device"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "#### Model Construction"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# Note that you may add your Hugging Face token to get access to the models\n",
54
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
55
+ "# model_path = \"xyn-ai/anything-v4.0\"\n",
56
+ "model_path = \"CompVis/stable-diffusion-v1-4\"\n",
57
+ "# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
58
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
59
+ "model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "#### Real editing with MasaCtrl"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "from masactrl.masactrl import MutualSelfAttentionControl\n",
76
+ "from torchvision.io import read_image\n",
77
+ "\n",
78
+ "\n",
79
+ "def load_image(image_path, device):\n",
80
+ " image = read_image(image_path)\n",
81
+ " image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]\n",
82
+ " image = F.interpolate(image, (512, 512))\n",
83
+ " image = image.to(device)\n",
84
+ " return image\n",
85
+ "\n",
86
+ "\n",
87
+ "seed = 42\n",
88
+ "seed_everything(seed)\n",
89
+ "\n",
90
+ "out_dir = \"./workdir/masactrl_real_exp/\"\n",
91
+ "os.makedirs(out_dir, exist_ok=True)\n",
92
+ "sample_count = len(os.listdir(out_dir))\n",
93
+ "out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
94
+ "os.makedirs(out_dir, exist_ok=True)\n",
95
+ "\n",
96
+ "# source image\n",
97
+ "SOURCE_IMAGE_PATH = \"./gradio_app/images/corgi.jpg\"\n",
98
+ "source_image = load_image(SOURCE_IMAGE_PATH, device)\n",
99
+ "\n",
100
+ "source_prompt = \"\"\n",
101
+ "target_prompt = \"a photo of a running corgi\"\n",
102
+ "prompts = [source_prompt, target_prompt]\n",
103
+ "\n",
104
+ "# invert the source image\n",
105
+ "start_code, latents_list = model.invert(source_image,\n",
106
+ " source_prompt,\n",
107
+ " guidance_scale=7.5,\n",
108
+ " num_inference_steps=50,\n",
109
+ " return_intermediates=True)\n",
110
+ "start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
111
+ "\n",
112
+ "# results of direct synthesis\n",
113
+ "editor = AttentionBase()\n",
114
+ "regiter_attention_editor_diffusers(model, editor)\n",
115
+ "image_fixed = model([target_prompt],\n",
116
+ " latents=start_code[-1:],\n",
117
+ " num_inference_steps=50,\n",
118
+ " guidance_scale=7.5)\n",
119
+ "\n",
120
+ "# inference the synthesized image with MasaCtrl\n",
121
+ "STEP = 4\n",
122
+ "LAYPER = 10\n",
123
+ "\n",
124
+ "# hijack the attention module\n",
125
+ "editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
126
+ "regiter_attention_editor_diffusers(model, editor)\n",
127
+ "\n",
128
+ "# inference the synthesized image\n",
129
+ "image_masactrl = model(prompts,\n",
130
+ " latents=start_code,\n",
131
+ " guidance_scale=7.5)\n",
132
+ "# Note: querying the inversion intermediate features latents_list\n",
133
+ "# may obtain better reconstruction and editing results\n",
134
+ "# image_masactrl = model(prompts,\n",
135
+ "# latents=start_code,\n",
136
+ "# guidance_scale=7.5,\n",
137
+ "# ref_intermediate_latents=latents_list)\n",
138
+ "\n",
139
+ "# save the synthesized image\n",
140
+ "out_image = torch.cat([source_image * 0.5 + 0.5,\n",
141
+ " image_masactrl[0:1],\n",
142
+ " image_fixed,\n",
143
+ " image_masactrl[-1:]], dim=0)\n",
144
+ "save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
145
+ "save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
146
+ "save_image(out_image[1], os.path.join(out_dir, f\"reconstructed_source_step{STEP}_layer{LAYPER}.png\"))\n",
147
+ "save_image(out_image[2], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
148
+ "save_image(out_image[3], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
149
+ "\n",
150
+ "print(\"Syntheiszed images are saved in\", out_dir)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": []
159
+ }
160
+ ],
161
+ "metadata": {
162
+ "kernelspec": {
163
+ "display_name": "Python 3.8.5 ('ldm')",
164
+ "language": "python",
165
+ "name": "python3"
166
+ },
167
+ "language_info": {
168
+ "codemirror_mode": {
169
+ "name": "ipython",
170
+ "version": 3
171
+ },
172
+ "file_extension": ".py",
173
+ "mimetype": "text/x-python",
174
+ "name": "python",
175
+ "nbconvert_exporter": "python",
176
+ "pygments_lexer": "ipython3",
177
+ "version": "3.8.5"
178
+ },
179
+ "orig_nbformat": 4,
180
+ "vscode": {
181
+ "interpreter": {
182
+ "hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
183
+ }
184
+ }
185
+ },
186
+ "nbformat": 4,
187
+ "nbformat_minor": 2
188
+ }
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ #diffusers==0.15.0
2
+ diffusers
3
+ transformers
4
+ opencv-python
5
+ einops
6
+ omegaconf
7
+ pytorch_lightning
8
+ torch
9
+ torchvision
10
+ gradio
11
+ httpx[socks]
12
+ huggingface_hub==0.25.0
13
+ moviepy
run_synthesis_genshin_impact_xl.py ADDED
@@ -0,0 +1,117 @@
1
+ '''
2
+ python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
3
+ --prompt1 "A portrait of an old man, facing camera, best quality" \
4
+ --prompt2 "A portrait of an old man, facing camera, smiling, best quality" --guidance_scale 3.5
5
+
6
+ python run_synthesis_genshin_impact_xl.py --model_path "svjack/GenshinImpact_XL_Base" \
7
+ --prompt1 "solo,ZHONGLI\(genshin impact\),1boy,highres," \
8
+ --prompt2 "solo,ZHONGLI drink tea use chinese cup \(genshin impact\),1boy,highres," --guidance_scale 5
9
+
10
+ from IPython import display
11
+
12
+ display.Image("masactrl_exp/solo_zhongli_drink_tea_use_chinese_cup___genshin_impact___1boy_highres_/sample_0/masactrl_step4_layer64.png", width=512, height=512)
13
+ display.Image("masactrl_exp/solo_zhongli_drink_tea_use_chinese_cup___genshin_impact___1boy_highres_/sample_1/source_step4_layer74.png", width=512, height=512)
14
+ '''
15
+
16
+ import os
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ import numpy as np
22
+
23
+ from tqdm import tqdm
24
+ from einops import rearrange, repeat
25
+ from omegaconf import OmegaConf
26
+
27
+ from diffusers import DDIMScheduler, DiffusionPipeline
28
+
29
+ from masactrl.diffuser_utils import MasaCtrlPipeline
30
+ from masactrl.masactrl_utils import AttentionBase
31
+ from masactrl.masactrl_utils import regiter_attention_editor_diffusers
32
+ from masactrl.masactrl import MutualSelfAttentionControl
33
+
34
+ from torchvision.utils import save_image
35
+ from torchvision.io import read_image
36
+ from pytorch_lightning import seed_everything
37
+
38
+ import argparse
39
+ import re
40
+
41
+ torch.cuda.set_device(0) # set the GPU device
42
+
43
+ # Note that you may add your Hugging Face token to get access to the models
44
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
45
+
46
+ def pathify(s):
47
+ # Convert to lowercase and replace non-alphanumeric characters with underscores
48
+ return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())
49
+
50
+ def consistent_synthesis(args):
51
+ seed = 42
52
+ seed_everything(seed)
53
+
54
+ # Use --out_dir when provided; otherwise derive the output directory from prompt2
55
+ out_dir_ori = getattr(args, "out_dir", None) or os.path.join("masactrl_exp", pathify(args.prompt2))
56
+ os.makedirs(out_dir_ori, exist_ok=True)
57
+
58
+ prompts = [
59
+ args.prompt1,
60
+ args.prompt2,
61
+ ]
62
+
63
+ # inference the synthesized image with MasaCtrl
64
+ # TODO: note that the hyperparameters of MasaCtrl for SDXL may not be optimal
65
+ STEP = 4
66
+ #LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
67
+ #LAYER_LIST = [64, 74, 84, 94] # run the synthesis with MasaCtrl at three different layer configs
68
+ LAYER_LIST = [64, 74] # run the synthesis with MasaCtrl at two different layer configs
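+ # (SDXL's U-Net has many more attention layers than SD v1.x, so the start-layer
+ # values tried here are much larger than the usual SD v1.x value of 10)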
69
+
70
+ # initialize the noise map
71
+ start_code = torch.randn([1, 4, 128, 128], device=device)
72
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
73
+
74
+ # Load the model
75
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
76
+ model = DiffusionPipeline.from_pretrained(args.model_path, scheduler=scheduler).to(device)
77
+
78
+ # inference the synthesized image without MasaCtrl
79
+ editor = AttentionBase()
80
+ regiter_attention_editor_diffusers(model, editor)
81
+ image_ori = model(prompts, latents=start_code, guidance_scale=args.guidance_scale).images
82
+
83
+ for LAYER in LAYER_LIST:
84
+ # hijack the attention module
85
+ editor = MutualSelfAttentionControl(STEP, LAYER, model_type="SDXL")
86
+ regiter_attention_editor_diffusers(model, editor)
87
+
88
+ # inference the synthesized image
89
+ image_masactrl = model(prompts, latents=start_code, guidance_scale=args.guidance_scale).images
90
+
91
+ sample_count = len(os.listdir(out_dir_ori))
92
+ out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
93
+ os.makedirs(out_dir, exist_ok=True)
94
+ image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
95
+ image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
96
+ image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
97
+ with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
98
+ for p in prompts:
99
+ f.write(p + "\n")
100
+ f.write(f"seed: {seed}\n")
101
+ print("Syntheiszed images are saved in", out_dir)
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser(description="Consistent Synthesis with MasaCtrl")
105
+ parser.add_argument("--model_path", type=str, default="svjack/GenshinImpact_XL_Base", help="Path to the model")
106
+ parser.add_argument("--prompt1", type=str, default="A portrait of an old man, facing camera, best quality", help="First prompt")
107
+ parser.add_argument("--prompt2", type=str, default="A portrait of an old man, facing camera, smiling, best quality", help="Second prompt")
108
+ parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale")
109
+ parser.add_argument("--out_dir", type=str, default=None, help="Output directory")
110
+
111
+ args = parser.parse_args()
112
+
113
+ # If out_dir is not provided, use the default path based on prompt2
114
+ if args.out_dir is None:
115
+ args.out_dir = os.path.join("masactrl_exp", pathify(args.prompt2))
116
+
117
+ consistent_synthesis(args)
run_synthesis_genshin_impact_xl_app.py ADDED
@@ -0,0 +1,97 @@
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import DDIMScheduler, DiffusionPipeline
4
+ from masactrl.diffuser_utils import MasaCtrlPipeline
5
+ from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
6
+ from masactrl.masactrl import MutualSelfAttentionControl
7
+ from pytorch_lightning import seed_everything
8
+ import os
9
+ import re
10
+
11
+ # 初始化设备和模型
12
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
13
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
14
+ model = DiffusionPipeline.from_pretrained("svjack/GenshinImpact_XL_Base", scheduler=scheduler).to(device)
15
+
16
+ def pathify(s):
17
+ return re.sub(r'[^a-zA-Z0-9]', '_', s.lower())
18
+
19
+ def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
20
+ seed_everything(seed)
21
+
22
+ # 创建输出目录
23
+ out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2))
24
+ os.makedirs(out_dir_ori, exist_ok=True)
25
+
26
+ prompts = [prompt1, prompt2]
27
+
28
+ # initialize the noise map
29
+ start_code = torch.randn([1, 4, 128, 128], device=device)
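+ # a 4 x 128 x 128 latent corresponds to 1024 x 1024 output images with the SDXL VAE (downsampling factor 8)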
30
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
31
+
32
+ # inference the synthesized image without MasaCtrl
33
+ editor = AttentionBase()
34
+ regiter_attention_editor_diffusers(model, editor)
35
+ image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
36
+
37
+ images = []
38
+ # hijack the attention module
39
+ editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
40
+ regiter_attention_editor_diffusers(model, editor)
41
+
42
+ # inference the synthesized image with MasaCtrl
43
+ image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images
44
+
45
+ sample_count = len(os.listdir(out_dir_ori))
46
+ out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
47
+ os.makedirs(out_dir, exist_ok=True)
48
+ image_ori[0].save(os.path.join(out_dir, f"source_step{starting_step}_layer{starting_layer}.png"))
49
+ image_ori[1].save(os.path.join(out_dir, f"without_step{starting_step}_layer{starting_layer}.png"))
50
+ image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{starting_step}_layer{starting_layer}.png"))
51
+ with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
52
+ for p in prompts:
53
+ f.write(p + "\n")
54
+ f.write(f"seed: {seed}\n")
55
+ f.write(f"starting_step: {starting_step}\n")
56
+ f.write(f"starting_layer: {starting_layer}\n")
57
+ print("Synthesized images are saved in", out_dir)
58
+
59
+ return [image_ori[0], image_ori[1], image_masactrl[-1]]
60
+
61
+ def create_demo_synthesis():
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**") # 添加标题
64
+ gr.Markdown("## **Input Settings**")
65
+ with gr.Row():
66
+ with gr.Column():
67
+ prompt1 = gr.Textbox(label="Prompt 1", value="solo,ZHONGLI(genshin impact),1boy,highres,")
68
+ prompt2 = gr.Textbox(label="Prompt 2", value="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,")
69
+ with gr.Row():
70
+ starting_step = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
71
+ starting_layer = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
72
+ run_btn = gr.Button("Run")
73
+ with gr.Column():
74
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
75
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
76
+
77
+ gr.Markdown("## **Output**")
78
+ with gr.Row():
79
+ image_source = gr.Image(label="Source Image")
80
+ image_without_masactrl = gr.Image(label="Image without MasaCtrl")
81
+ image_with_masactrl = gr.Image(label="Image with MasaCtrl")
82
+
83
+ inputs = [prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer]
84
+ run_btn.click(consistent_synthesis, inputs, [image_source, image_without_masactrl, image_with_masactrl])
85
+
86
+ gr.Examples(
87
+ [
88
+ ["solo,ZHONGLI(genshin impact),1boy,highres,", "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,", 42, 4, 64],
89
+ ["solo,KAMISATO AYATO(genshin impact),1boy,highres,", "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,", 42, 4, 55]
90
+ ],
91
+ [prompt1, prompt2, seed, starting_step, starting_layer],
92
+ )
93
+ return demo
94
+
95
+ if __name__ == "__main__":
96
+ demo_synthesis = create_demo_synthesis()
97
+ demo_synthesis.launch(share=True)
run_synthesis_sdxl.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from einops import rearrange, repeat
10
+ from omegaconf import OmegaConf
11
+
12
+ from diffusers import DDIMScheduler, DiffusionPipeline
13
+
14
+ from masactrl.diffuser_utils import MasaCtrlPipeline
15
+ from masactrl.masactrl_utils import AttentionBase
16
+ from masactrl.masactrl_utils import regiter_attention_editor_diffusers
17
+ from masactrl.masactrl import MutualSelfAttentionControl
18
+
19
+ from torchvision.utils import save_image
20
+ from torchvision.io import read_image
21
+ from pytorch_lightning import seed_everything
22
+
23
+ torch.cuda.set_device(0) # set the GPU device
24
+
25
+ # Note that you may add your Hugging Face token to get access to the models
26
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27
+
28
+ model_path = "stabilityai/stable-diffusion-xl-base-1.0"
29
+ # model_path = "Linaqruf/animagine-xl"
30
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
31
+ model = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
32
+
33
+
34
+ def consistent_synthesis():
35
+ seed = 42
36
+ seed_everything(seed)
37
+
38
+ out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
39
+ os.makedirs(out_dir_ori, exist_ok=True)
40
+
41
+ prompts = [
42
+ "A portrait of an old man, facing camera, best quality",
43
+ "A portrait of an old man, facing camera, smiling, best quality",
44
+ ]
45
+
46
+ # inference the synthesized image with MasaCtrl
47
+ # TODO: note that the hyperparameters of MasaCtrl for SDXL may not be optimal
48
+ STEP = 4
49
+ LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
50
+
51
+ # initialize the noise map
52
+ start_code = torch.randn([1, 4, 128, 128], device=device)
53
+ # start_code = None
54
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
55
+
56
+ # inference the synthesized image without MasaCtrl
57
+ editor = AttentionBase()
58
+ regiter_attention_editor_diffusers(model, editor)
59
+ image_ori = model(prompts, latents=start_code, guidance_scale=7.5).images
60
+
61
+ for LAYER in LAYER_LIST:
62
+ # hijack the attention module
63
+ editor = MutualSelfAttentionControl(STEP, LAYER, model_type="SDXL")
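+ # model_type="SDXL" appears to switch the editor to SDXL's attention-layer indexing,
+ # so LAYER is interpreted against the (much deeper) SDXL U-Net rather than SD v1.x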
64
+ regiter_attention_editor_diffusers(model, editor)
65
+
66
+ # inference the synthesized image
67
+ image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5).images
68
+
69
+ sample_count = len(os.listdir(out_dir_ori))
70
+ out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
71
+ os.makedirs(out_dir, exist_ok=True)
72
+ image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
73
+ image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
74
+ image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
75
+ with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
76
+ for p in prompts:
77
+ f.write(p + "\n")
78
+ f.write(f"seed: {seed}\n")
79
+ print("Syntheiszed images are saved in", out_dir)
80
+
81
+
82
+ if __name__ == "__main__":
83
+ consistent_synthesis()
run_synthesis_sdxl_processor.py ADDED
@@ -0,0 +1,90 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import numpy as np
7
+
8
+ from tqdm import tqdm
9
+ from einops import rearrange, repeat
10
+ from omegaconf import OmegaConf
11
+ from diffusers import DDIMScheduler, StableDiffusionPipeline, DiffusionPipeline
12
+ from torchvision.utils import save_image
13
+ from torchvision.io import read_image
14
+ from pytorch_lightning import seed_everything
15
+
16
+ from masactrl.masactrl_processor import register_attention_processor
17
+
18
+ torch.cuda.set_device(0) # set the GPU device
19
+
20
+ # Note that you may add your Hugging Face token to get access to the models
21
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
22
+ weight_dtype = torch.float16
23
+ model_path = "stabilityai/stable-diffusion-xl-base-1.0"
24
+ scheduler = DDIMScheduler(
25
+ beta_start=0.00085,
26
+ beta_end=0.012,
27
+ beta_schedule="scaled_linear",
28
+ clip_sample=False,
29
+ set_alpha_to_one=False
30
+ )
31
+ pipe = DiffusionPipeline.from_pretrained(
32
+ model_path,
33
+ scheduler=scheduler,
34
+ torch_dtype=weight_dtype
35
+ ).to(device)
36
+
37
+
38
+ def consistent_synthesis():
39
+ seed = 42
40
+ seed_everything(seed)
41
+
42
+ out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
43
+ os.makedirs(out_dir_ori, exist_ok=True)
44
+
45
+ prompts = [
46
+ "A portrait of an old man, facing camera, best quality",
47
+ "A portrait of an old man, facing camera, smiling, best quality",
48
+ ]
49
+
50
+ # inference the synthesized image with MasaCtrl
51
+ # TODO: note that the hyper paramerter of MasaCtrl for SDXL may be not optimal
52
+ STEP = 4
53
+ LAYER_LIST = [44, 54, 64] # run the synthesis with MasaCtrl at three different layer configs
54
+ MODEL_TYPE = "SDXL"
55
+
56
+ # initialize the noise map
57
+ start_code = torch.randn([1, 4, 128, 128], dtype=weight_dtype, device=device)
58
+ # start_code = None
59
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
60
+
61
+ # inference the synthesized image without MasaCtrl
62
+ image_ori = pipe(prompts, latents=start_code, guidance_scale=7.5).images
63
+
64
+ for LAYER in LAYER_LIST:
65
+ # hijack the attention module with MasaCtrl processor
66
+ processor_args = {
67
+ "start_step": STEP,
68
+ "start_layer": LAYER,
69
+ "model_type": MODEL_TYPE
70
+ }
71
+ register_attention_processor(pipe.unet, processor_type="MasaCtrlProcessor")
72
+
73
+ # inference the synthesized image
74
+ image_masactrl = pipe(prompts, latents=start_code, guidance_scale=7.5).images
75
+
76
+ sample_count = len(os.listdir(out_dir_ori))
77
+ out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
78
+ os.makedirs(out_dir, exist_ok=True)
79
+ image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
80
+ image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
81
+ image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
82
+ with open(os.path.join(out_dir, f"prompts.txt"), "w") as f:
83
+ for p in prompts:
84
+ f.write(p + "\n")
85
+ f.write(f"seed: {seed}\n")
86
+ print("Syntheiszed images are saved in", out_dir)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ consistent_synthesis()
style.css ADDED
@@ -0,0 +1,3 @@
1
+ h1 {
2
+ text-align: center;
3
+ }