Spaces:
Runtime error
Runtime error
kevinwang676
commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- DiffSynth_Studio.py +15 -0
- LICENSE +201 -0
- README.md +117 -13
- diffsynth/__init__.py +6 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +241 -0
- diffsynth/models/__init__.py +814 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +28 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +161 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/sd3_dit.py +797 -0
- diffsynth/models/sd3_text_encoder.py +0 -0
- diffsynth/models/sd3_vae_decoder.py +80 -0
- diffsynth/models/sd3_vae_encoder.py +94 -0
- diffsynth/models/sd_controlnet.py +587 -0
- diffsynth/models/sd_ipadapter.py +56 -0
- diffsynth/models/sd_lora.py +60 -0
- diffsynth/models/sd_motion.py +198 -0
- diffsynth/models/sd_text_encoder.py +320 -0
- diffsynth/models/sd_unet.py +0 -0
- diffsynth/models/sd_vae_decoder.py +332 -0
- diffsynth/models/sd_vae_encoder.py +278 -0
- diffsynth/models/sdxl_ipadapter.py +121 -0
- diffsynth/models/sdxl_motion.py +103 -0
- diffsynth/models/sdxl_text_encoder.py +757 -0
- diffsynth/models/sdxl_unet.py +0 -0
- diffsynth/models/sdxl_vae_decoder.py +15 -0
- diffsynth/models/sdxl_vae_encoder.py +15 -0
- diffsynth/models/svd_image_encoder.py +504 -0
- diffsynth/models/svd_unet.py +0 -0
- diffsynth/models/svd_vae_decoder.py +577 -0
- diffsynth/models/svd_vae_encoder.py +138 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
DiffSynth_Studio.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Set web page format
|
2 |
+
import streamlit as st
|
3 |
+
st.set_page_config(layout="wide")
|
4 |
+
# Diasble virtual VRAM on windows system
|
5 |
+
import torch
|
6 |
+
torch.cuda.set_per_process_memory_fraction(0.999, 0)
|
7 |
+
|
8 |
+
|
9 |
+
st.markdown("""
|
10 |
+
# DiffSynth Studio
|
11 |
+
|
12 |
+
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
|
13 |
+
|
14 |
+
Welcome to DiffSynth Studio.
|
15 |
+
""")
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [2023] [Zhongjie Duan]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,117 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSynth Studio
|
2 |
+
|
3 |
+
|
4 |
+
## Introduction
|
5 |
+
|
6 |
+
DiffSynth Studio is a Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. We provide many interesting features. Enjoy the magic of Diffusion models!
|
7 |
+
|
8 |
+
Until now, DiffSynth Studio has supported the following models:
|
9 |
+
|
10 |
+
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
11 |
+
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
12 |
+
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
13 |
+
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
14 |
+
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
15 |
+
* [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
|
16 |
+
* [ESRGAN](https://github.com/xinntao/ESRGAN)
|
17 |
+
* [Ip-Adapter](https://github.com/tencent-ailab/IP-Adapter)
|
18 |
+
* [AnimateDiff](https://github.com/guoyww/animatediff/)
|
19 |
+
* [ControlNet](https://github.com/lllyasviel/ControlNet)
|
20 |
+
* [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
21 |
+
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
22 |
+
|
23 |
+
## News
|
24 |
+
|
25 |
+
|
26 |
+
- **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to achieve the generation of long videos up to 128 frames.
|
27 |
+
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
28 |
+
- Source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
|
29 |
+
- Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
|
30 |
+
- Technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
|
31 |
+
- You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
|
32 |
+
|
33 |
+
- **June 13, 2024.** DiffSynth Studio is transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
|
34 |
+
|
35 |
+
- **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
|
36 |
+
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
37 |
+
- The source codes are released in this project.
|
38 |
+
- The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
|
39 |
+
|
40 |
+
- **Dec 8, 2023.** We decide to develop a new Project, aiming to release the potential of diffusion models, especially in video synthesis. The development of this project is started.
|
41 |
+
|
42 |
+
- **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
|
43 |
+
- The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
|
44 |
+
- Demo videos are shown on Bilibili, including three tasks.
|
45 |
+
- [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
|
46 |
+
- [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
|
47 |
+
- [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
48 |
+
- The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
|
49 |
+
- An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
|
50 |
+
|
51 |
+
- **Oct 1, 2023.** We release an early version of this project, namely FastSDXL. A try for building a diffusion engine.
|
52 |
+
- The source codes are released on [GitHub](https://github.com/Artiprocher/FastSDXL).
|
53 |
+
- FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
|
54 |
+
- The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
|
55 |
+
- The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
|
56 |
+
- A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
|
57 |
+
- Since OLSS requires additional training, we don't implement it in this project.
|
58 |
+
|
59 |
+
- **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
|
60 |
+
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
|
61 |
+
- The source codes are released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
|
62 |
+
- The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
|
63 |
+
|
64 |
+
|
65 |
+
## Installation
|
66 |
+
|
67 |
+
```
|
68 |
+
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
69 |
+
cd DiffSynth-Studio
|
70 |
+
pip install -e .
|
71 |
+
```
|
72 |
+
|
73 |
+
## Usage (in Python code)
|
74 |
+
|
75 |
+
The Python examples are in [`examples`](./examples/). We provide an overview here.
|
76 |
+
|
77 |
+
### Long Video Synthesis
|
78 |
+
|
79 |
+
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
|
80 |
+
|
81 |
+
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
82 |
+
|
83 |
+
### Image Synthesis
|
84 |
+
|
85 |
+
Generate high-resolution images, by breaking the limitation of diffusion models! [`examples/image_synthesis`](./examples/image_synthesis/).
|
86 |
+
|
87 |
+
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
88 |
+
|
89 |
+
|Model|Example|
|
90 |
+
|-|-|
|
91 |
+
|Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|
|
92 |
+
|Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
|
93 |
+
|Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|
94 |
+
|Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|
|
95 |
+
|Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
|
96 |
+
|
97 |
+
### Toon Shading
|
98 |
+
|
99 |
+
Render realistic videos in a flatten style and enable video editing features. [`examples/Diffutoon`](./examples/Diffutoon/)
|
100 |
+
|
101 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
102 |
+
|
103 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
|
104 |
+
|
105 |
+
### Video Stylization
|
106 |
+
|
107 |
+
Video stylization without video models. [`examples/diffsynth`](./examples/diffsynth/)
|
108 |
+
|
109 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
110 |
+
|
111 |
+
## Usage (in WebUI)
|
112 |
+
|
113 |
+
```
|
114 |
+
python -m streamlit run DiffSynth_Studio.py
|
115 |
+
```
|
116 |
+
|
117 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
|
diffsynth/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .data import *
|
2 |
+
from .models import *
|
3 |
+
from .prompts import *
|
4 |
+
from .schedulers import *
|
5 |
+
from .pipelines import *
|
6 |
+
from .controlnets import *
|
diffsynth/controlnets/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
2 |
+
from .processors import Annotator
|
diffsynth/controlnets/controlnet_unit.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from .processors import Processor_id
|
4 |
+
|
5 |
+
|
6 |
+
class ControlNetConfigUnit:
|
7 |
+
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
|
8 |
+
self.processor_id = processor_id
|
9 |
+
self.model_path = model_path
|
10 |
+
self.scale = scale
|
11 |
+
|
12 |
+
|
13 |
+
class ControlNetUnit:
|
14 |
+
def __init__(self, processor, model, scale=1.0):
|
15 |
+
self.processor = processor
|
16 |
+
self.model = model
|
17 |
+
self.scale = scale
|
18 |
+
|
19 |
+
|
20 |
+
class MultiControlNetManager:
|
21 |
+
def __init__(self, controlnet_units=[]):
|
22 |
+
self.processors = [unit.processor for unit in controlnet_units]
|
23 |
+
self.models = [unit.model for unit in controlnet_units]
|
24 |
+
self.scales = [unit.scale for unit in controlnet_units]
|
25 |
+
|
26 |
+
def process_image(self, image, processor_id=None):
|
27 |
+
if processor_id is None:
|
28 |
+
processed_image = [processor(image) for processor in self.processors]
|
29 |
+
else:
|
30 |
+
processed_image = [self.processors[processor_id](image)]
|
31 |
+
processed_image = torch.concat([
|
32 |
+
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
33 |
+
for image_ in processed_image
|
34 |
+
], dim=0)
|
35 |
+
return processed_image
|
36 |
+
|
37 |
+
def __call__(
|
38 |
+
self,
|
39 |
+
sample, timestep, encoder_hidden_states, conditionings,
|
40 |
+
tiled=False, tile_size=64, tile_stride=32
|
41 |
+
):
|
42 |
+
res_stack = None
|
43 |
+
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
|
44 |
+
res_stack_ = model(
|
45 |
+
sample, timestep, encoder_hidden_states, conditioning,
|
46 |
+
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
47 |
+
)
|
48 |
+
res_stack_ = [res * scale for res in res_stack_]
|
49 |
+
if res_stack is None:
|
50 |
+
res_stack = res_stack_
|
51 |
+
else:
|
52 |
+
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
53 |
+
return res_stack
|
diffsynth/controlnets/processors.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing_extensions import Literal, TypeAlias
|
2 |
+
import warnings
|
3 |
+
with warnings.catch_warnings():
|
4 |
+
warnings.simplefilter("ignore")
|
5 |
+
from controlnet_aux.processor import (
|
6 |
+
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
Processor_id: TypeAlias = Literal[
|
11 |
+
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
12 |
+
]
|
13 |
+
|
14 |
+
class Annotator:
|
15 |
+
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
16 |
+
if processor_id == "canny":
|
17 |
+
self.processor = CannyDetector()
|
18 |
+
elif processor_id == "depth":
|
19 |
+
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
20 |
+
elif processor_id == "softedge":
|
21 |
+
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
22 |
+
elif processor_id == "lineart":
|
23 |
+
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
24 |
+
elif processor_id == "lineart_anime":
|
25 |
+
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
26 |
+
elif processor_id == "openpose":
|
27 |
+
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
28 |
+
elif processor_id == "tile":
|
29 |
+
self.processor = None
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
32 |
+
|
33 |
+
self.processor_id = processor_id
|
34 |
+
self.detect_resolution = detect_resolution
|
35 |
+
|
36 |
+
def __call__(self, image):
|
37 |
+
width, height = image.size
|
38 |
+
if self.processor_id == "openpose":
|
39 |
+
kwargs = {
|
40 |
+
"include_body": True,
|
41 |
+
"include_hand": True,
|
42 |
+
"include_face": True
|
43 |
+
}
|
44 |
+
else:
|
45 |
+
kwargs = {}
|
46 |
+
if self.processor is not None:
|
47 |
+
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
|
48 |
+
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
|
49 |
+
image = image.resize((width, height))
|
50 |
+
return image
|
51 |
+
|
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .video import VideoData, save_video, save_frames
|
diffsynth/data/video.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import imageio, os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
|
7 |
+
class LowMemoryVideo:
|
8 |
+
def __init__(self, file_name):
|
9 |
+
self.reader = imageio.get_reader(file_name)
|
10 |
+
|
11 |
+
def __len__(self):
|
12 |
+
return self.reader.count_frames()
|
13 |
+
|
14 |
+
def __getitem__(self, item):
|
15 |
+
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
|
16 |
+
|
17 |
+
def __del__(self):
|
18 |
+
self.reader.close()
|
19 |
+
|
20 |
+
|
21 |
+
def split_file_name(file_name):
|
22 |
+
result = []
|
23 |
+
number = -1
|
24 |
+
for i in file_name:
|
25 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
26 |
+
if number == -1:
|
27 |
+
number = 0
|
28 |
+
number = number*10 + ord(i) - ord("0")
|
29 |
+
else:
|
30 |
+
if number != -1:
|
31 |
+
result.append(number)
|
32 |
+
number = -1
|
33 |
+
result.append(i)
|
34 |
+
if number != -1:
|
35 |
+
result.append(number)
|
36 |
+
result = tuple(result)
|
37 |
+
return result
|
38 |
+
|
39 |
+
|
40 |
+
def search_for_images(folder):
|
41 |
+
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
42 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
43 |
+
file_list = [i[1] for i in sorted(file_list)]
|
44 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
45 |
+
return file_list
|
46 |
+
|
47 |
+
|
48 |
+
class LowMemoryImageFolder:
|
49 |
+
def __init__(self, folder, file_list=None):
|
50 |
+
if file_list is None:
|
51 |
+
self.file_list = search_for_images(folder)
|
52 |
+
else:
|
53 |
+
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.file_list)
|
57 |
+
|
58 |
+
def __getitem__(self, item):
|
59 |
+
return Image.open(self.file_list[item]).convert("RGB")
|
60 |
+
|
61 |
+
def __del__(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
|
65 |
+
def crop_and_resize(image, height, width):
|
66 |
+
image = np.array(image)
|
67 |
+
image_height, image_width, _ = image.shape
|
68 |
+
if image_height / image_width < height / width:
|
69 |
+
croped_width = int(image_height / height * width)
|
70 |
+
left = (image_width - croped_width) // 2
|
71 |
+
image = image[:, left: left+croped_width]
|
72 |
+
image = Image.fromarray(image).resize((width, height))
|
73 |
+
else:
|
74 |
+
croped_height = int(image_width / width * height)
|
75 |
+
left = (image_height - croped_height) // 2
|
76 |
+
image = image[left: left+croped_height, :]
|
77 |
+
image = Image.fromarray(image).resize((width, height))
|
78 |
+
return image
|
79 |
+
|
80 |
+
|
81 |
+
class VideoData:
|
82 |
+
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
|
83 |
+
if video_file is not None:
|
84 |
+
self.data_type = "video"
|
85 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
86 |
+
elif image_folder is not None:
|
87 |
+
self.data_type = "images"
|
88 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
89 |
+
else:
|
90 |
+
raise ValueError("Cannot open video or image folder")
|
91 |
+
self.length = None
|
92 |
+
self.set_shape(height, width)
|
93 |
+
|
94 |
+
def raw_data(self):
|
95 |
+
frames = []
|
96 |
+
for i in range(self.__len__()):
|
97 |
+
frames.append(self.__getitem__(i))
|
98 |
+
return frames
|
99 |
+
|
100 |
+
def set_length(self, length):
|
101 |
+
self.length = length
|
102 |
+
|
103 |
+
def set_shape(self, height, width):
|
104 |
+
self.height = height
|
105 |
+
self.width = width
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
if self.length is None:
|
109 |
+
return len(self.data)
|
110 |
+
else:
|
111 |
+
return self.length
|
112 |
+
|
113 |
+
def shape(self):
|
114 |
+
if self.height is not None and self.width is not None:
|
115 |
+
return self.height, self.width
|
116 |
+
else:
|
117 |
+
height, width, _ = self.__getitem__(0).shape
|
118 |
+
return height, width
|
119 |
+
|
120 |
+
def __getitem__(self, item):
|
121 |
+
frame = self.data.__getitem__(item)
|
122 |
+
width, height = frame.size
|
123 |
+
if self.height is not None and self.width is not None:
|
124 |
+
if self.height != height or self.width != width:
|
125 |
+
frame = crop_and_resize(frame, self.height, self.width)
|
126 |
+
return frame
|
127 |
+
|
128 |
+
def __del__(self):
|
129 |
+
pass
|
130 |
+
|
131 |
+
def save_images(self, folder):
|
132 |
+
os.makedirs(folder, exist_ok=True)
|
133 |
+
for i in tqdm(range(self.__len__()), desc="Saving images"):
|
134 |
+
frame = self.__getitem__(i)
|
135 |
+
frame.save(os.path.join(folder, f"{i}.png"))
|
136 |
+
|
137 |
+
|
138 |
+
def save_video(frames, save_path, fps, quality=9):
|
139 |
+
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
140 |
+
for frame in tqdm(frames, desc="Saving video"):
|
141 |
+
frame = np.array(frame)
|
142 |
+
writer.append_data(frame)
|
143 |
+
writer.close()
|
144 |
+
|
145 |
+
def save_frames(frames, save_path):
|
146 |
+
os.makedirs(save_path, exist_ok=True)
|
147 |
+
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
148 |
+
frame.save(os.path.join(save_path, f"{i}.png"))
|
diffsynth/extensions/ESRGAN/__init__.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import repeat
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class ResidualDenseBlock(torch.nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
10 |
+
super(ResidualDenseBlock, self).__init__()
|
11 |
+
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
12 |
+
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
13 |
+
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
14 |
+
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
15 |
+
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
16 |
+
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
x1 = self.lrelu(self.conv1(x))
|
20 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
21 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
22 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
23 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
24 |
+
return x5 * 0.2 + x
|
25 |
+
|
26 |
+
|
27 |
+
class RRDB(torch.nn.Module):
|
28 |
+
|
29 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
30 |
+
super(RRDB, self).__init__()
|
31 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
32 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
33 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
out = self.rdb1(x)
|
37 |
+
out = self.rdb2(out)
|
38 |
+
out = self.rdb3(out)
|
39 |
+
return out * 0.2 + x
|
40 |
+
|
41 |
+
|
42 |
+
class RRDBNet(torch.nn.Module):
|
43 |
+
|
44 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
|
45 |
+
super(RRDBNet, self).__init__()
|
46 |
+
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
47 |
+
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
48 |
+
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
49 |
+
# upsample
|
50 |
+
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
51 |
+
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
52 |
+
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
53 |
+
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
54 |
+
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
feat = x
|
58 |
+
feat = self.conv_first(feat)
|
59 |
+
body_feat = self.conv_body(self.body(feat))
|
60 |
+
feat = feat + body_feat
|
61 |
+
# upsample
|
62 |
+
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
63 |
+
feat = self.lrelu(self.conv_up1(feat))
|
64 |
+
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
65 |
+
feat = self.lrelu(self.conv_up2(feat))
|
66 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
67 |
+
return out
|
68 |
+
|
69 |
+
|
70 |
+
class ESRGAN(torch.nn.Module):
|
71 |
+
def __init__(self, model):
|
72 |
+
super().__init__()
|
73 |
+
self.model = model
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def from_pretrained(model_path):
|
77 |
+
model = RRDBNet()
|
78 |
+
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
|
79 |
+
model.load_state_dict(state_dict)
|
80 |
+
model.eval()
|
81 |
+
return ESRGAN(model)
|
82 |
+
|
83 |
+
def process_image(self, image):
|
84 |
+
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
85 |
+
return image
|
86 |
+
|
87 |
+
def process_images(self, images):
|
88 |
+
images = [self.process_image(image) for image in images]
|
89 |
+
images = torch.stack(images)
|
90 |
+
return images
|
91 |
+
|
92 |
+
def decode_images(self, images):
|
93 |
+
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
94 |
+
images = [Image.fromarray(image) for image in images]
|
95 |
+
return images
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
99 |
+
# Preprocess
|
100 |
+
input_tensor = self.process_images(images)
|
101 |
+
|
102 |
+
# Interpolate
|
103 |
+
output_tensor = []
|
104 |
+
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
105 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
106 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
107 |
+
batch_input_tensor = batch_input_tensor.to(
|
108 |
+
device=self.model.conv_first.weight.device,
|
109 |
+
dtype=self.model.conv_first.weight.dtype)
|
110 |
+
batch_output_tensor = self.model(batch_input_tensor)
|
111 |
+
output_tensor.append(batch_output_tensor.cpu())
|
112 |
+
|
113 |
+
# Output
|
114 |
+
output_tensor = torch.concat(output_tensor, dim=0)
|
115 |
+
|
116 |
+
# To images
|
117 |
+
output_images = self.decode_images(output_tensor)
|
118 |
+
return output_images
|
diffsynth/extensions/FastBlend/__init__.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .runners.fast import TableManager, PyramidPatchMatcher
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import cupy as cp
|
5 |
+
|
6 |
+
|
7 |
+
class FastBlendSmoother:
|
8 |
+
def __init__(self):
|
9 |
+
self.batch_size = 8
|
10 |
+
self.window_size = 64
|
11 |
+
self.ebsynth_config = {
|
12 |
+
"minimum_patch_size": 5,
|
13 |
+
"threads_per_block": 8,
|
14 |
+
"num_iter": 5,
|
15 |
+
"gpu_id": 0,
|
16 |
+
"guide_weight": 10.0,
|
17 |
+
"initialize": "identity",
|
18 |
+
"tracking_window_size": 0,
|
19 |
+
}
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def from_model_manager(model_manager):
|
23 |
+
# TODO: fetch GPU ID from model_manager
|
24 |
+
return FastBlendSmoother()
|
25 |
+
|
26 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
|
27 |
+
frames_guide = [np.array(frame) for frame in frames_guide]
|
28 |
+
frames_style = [np.array(frame) for frame in frames_style]
|
29 |
+
table_manager = TableManager()
|
30 |
+
patch_match_engine = PyramidPatchMatcher(
|
31 |
+
image_height=frames_style[0].shape[0],
|
32 |
+
image_width=frames_style[0].shape[1],
|
33 |
+
channel=3,
|
34 |
+
**ebsynth_config
|
35 |
+
)
|
36 |
+
# left part
|
37 |
+
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
|
38 |
+
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
39 |
+
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
|
40 |
+
# right part
|
41 |
+
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
|
42 |
+
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
43 |
+
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
|
44 |
+
# merge
|
45 |
+
frames = []
|
46 |
+
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
47 |
+
weight_m = -1
|
48 |
+
weight = weight_l + weight_m + weight_r
|
49 |
+
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
50 |
+
frames.append(frame)
|
51 |
+
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
|
52 |
+
return frames
|
53 |
+
|
54 |
+
def __call__(self, rendered_frames, original_frames=None, **kwargs):
|
55 |
+
frames = self.run(
|
56 |
+
original_frames, rendered_frames,
|
57 |
+
self.batch_size, self.window_size, self.ebsynth_config
|
58 |
+
)
|
59 |
+
mempool = cp.get_default_memory_pool()
|
60 |
+
pinned_mempool = cp.get_default_pinned_memory_pool()
|
61 |
+
mempool.free_all_blocks()
|
62 |
+
pinned_mempool.free_all_blocks()
|
63 |
+
return frames
|
diffsynth/extensions/FastBlend/api.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
2 |
+
from .data import VideoData, get_video_fps, save_video, search_for_images
|
3 |
+
import os
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
|
8 |
+
frames_guide = VideoData(video_guide, video_guide_folder)
|
9 |
+
frames_style = VideoData(video_style, video_style_folder)
|
10 |
+
message = ""
|
11 |
+
if len(frames_guide) < len(frames_style):
|
12 |
+
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
|
13 |
+
frames_style.set_length(len(frames_guide))
|
14 |
+
elif len(frames_guide) > len(frames_style):
|
15 |
+
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
|
16 |
+
frames_guide.set_length(len(frames_style))
|
17 |
+
height_guide, width_guide = frames_guide.shape()
|
18 |
+
height_style, width_style = frames_style.shape()
|
19 |
+
if height_guide != height_style or width_guide != width_style:
|
20 |
+
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
|
21 |
+
frames_style.set_shape(height_guide, width_guide)
|
22 |
+
return frames_guide, frames_style, message
|
23 |
+
|
24 |
+
|
25 |
+
def smooth_video(
|
26 |
+
video_guide,
|
27 |
+
video_guide_folder,
|
28 |
+
video_style,
|
29 |
+
video_style_folder,
|
30 |
+
mode,
|
31 |
+
window_size,
|
32 |
+
batch_size,
|
33 |
+
tracking_window_size,
|
34 |
+
output_path,
|
35 |
+
fps,
|
36 |
+
minimum_patch_size,
|
37 |
+
num_iter,
|
38 |
+
guide_weight,
|
39 |
+
initialize,
|
40 |
+
progress = None,
|
41 |
+
):
|
42 |
+
# input
|
43 |
+
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
|
44 |
+
if len(message) > 0:
|
45 |
+
print(message)
|
46 |
+
# output
|
47 |
+
if output_path == "":
|
48 |
+
if video_style is None:
|
49 |
+
output_path = os.path.join(video_style_folder, "output")
|
50 |
+
else:
|
51 |
+
output_path = os.path.join(os.path.split(video_style)[0], "output")
|
52 |
+
os.makedirs(output_path, exist_ok=True)
|
53 |
+
print("No valid output_path. Your video will be saved here:", output_path)
|
54 |
+
elif not os.path.exists(output_path):
|
55 |
+
os.makedirs(output_path, exist_ok=True)
|
56 |
+
print("Your video will be saved here:", output_path)
|
57 |
+
frames_path = os.path.join(output_path, "frames")
|
58 |
+
video_path = os.path.join(output_path, "video.mp4")
|
59 |
+
os.makedirs(frames_path, exist_ok=True)
|
60 |
+
# process
|
61 |
+
if mode == "Fast" or mode == "Balanced":
|
62 |
+
tracking_window_size = 0
|
63 |
+
ebsynth_config = {
|
64 |
+
"minimum_patch_size": minimum_patch_size,
|
65 |
+
"threads_per_block": 8,
|
66 |
+
"num_iter": num_iter,
|
67 |
+
"gpu_id": 0,
|
68 |
+
"guide_weight": guide_weight,
|
69 |
+
"initialize": initialize,
|
70 |
+
"tracking_window_size": tracking_window_size,
|
71 |
+
}
|
72 |
+
if mode == "Fast":
|
73 |
+
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
74 |
+
elif mode == "Balanced":
|
75 |
+
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
76 |
+
elif mode == "Accurate":
|
77 |
+
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
78 |
+
# output
|
79 |
+
try:
|
80 |
+
fps = int(fps)
|
81 |
+
except:
|
82 |
+
fps = get_video_fps(video_style) if video_style is not None else 30
|
83 |
+
print("Fps:", fps)
|
84 |
+
print("Saving video...")
|
85 |
+
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
|
86 |
+
print("Success!")
|
87 |
+
print("Your frames are here:", frames_path)
|
88 |
+
print("Your video is here:", video_path)
|
89 |
+
return output_path, fps, video_path
|
90 |
+
|
91 |
+
|
92 |
+
class KeyFrameMatcher:
|
93 |
+
def __init__(self):
|
94 |
+
pass
|
95 |
+
|
96 |
+
def extract_number_from_filename(self, file_name):
|
97 |
+
result = []
|
98 |
+
number = -1
|
99 |
+
for i in file_name:
|
100 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
101 |
+
if number == -1:
|
102 |
+
number = 0
|
103 |
+
number = number*10 + ord(i) - ord("0")
|
104 |
+
else:
|
105 |
+
if number != -1:
|
106 |
+
result.append(number)
|
107 |
+
number = -1
|
108 |
+
if number != -1:
|
109 |
+
result.append(number)
|
110 |
+
result = tuple(result)
|
111 |
+
return result
|
112 |
+
|
113 |
+
def extract_number_from_filenames(self, file_names):
|
114 |
+
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
|
115 |
+
min_length = min(len(i) for i in numbers)
|
116 |
+
for i in range(min_length-1, -1, -1):
|
117 |
+
if len(set(number[i] for number in numbers))==len(file_names):
|
118 |
+
return [number[i] for number in numbers]
|
119 |
+
return list(range(len(file_names)))
|
120 |
+
|
121 |
+
def match_using_filename(self, file_names_a, file_names_b):
|
122 |
+
file_names_b_set = set(file_names_b)
|
123 |
+
matched_file_name = []
|
124 |
+
for file_name in file_names_a:
|
125 |
+
if file_name not in file_names_b_set:
|
126 |
+
matched_file_name.append(None)
|
127 |
+
else:
|
128 |
+
matched_file_name.append(file_name)
|
129 |
+
return matched_file_name
|
130 |
+
|
131 |
+
def match_using_numbers(self, file_names_a, file_names_b):
|
132 |
+
numbers_a = self.extract_number_from_filenames(file_names_a)
|
133 |
+
numbers_b = self.extract_number_from_filenames(file_names_b)
|
134 |
+
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
|
135 |
+
matched_file_name = []
|
136 |
+
for number in numbers_a:
|
137 |
+
if number in numbers_b_dict:
|
138 |
+
matched_file_name.append(numbers_b_dict[number])
|
139 |
+
else:
|
140 |
+
matched_file_name.append(None)
|
141 |
+
return matched_file_name
|
142 |
+
|
143 |
+
def match_filenames(self, file_names_a, file_names_b):
|
144 |
+
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
|
145 |
+
if sum([i is not None for i in matched_file_name]) > 0:
|
146 |
+
return matched_file_name
|
147 |
+
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
|
148 |
+
return matched_file_name
|
149 |
+
|
150 |
+
|
151 |
+
def detect_frames(frames_path, keyframes_path):
|
152 |
+
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
|
153 |
+
return "Please input the directory of guide video and rendered frames"
|
154 |
+
elif not os.path.exists(frames_path):
|
155 |
+
return "Please input the directory of guide video"
|
156 |
+
elif not os.path.exists(keyframes_path):
|
157 |
+
return "Please input the directory of rendered frames"
|
158 |
+
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
159 |
+
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
160 |
+
if len(frames)==0:
|
161 |
+
return f"No images detected in {frames_path}"
|
162 |
+
if len(keyframes)==0:
|
163 |
+
return f"No images detected in {keyframes_path}"
|
164 |
+
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
165 |
+
max_filename_length = max([len(i) for i in frames])
|
166 |
+
if sum([i is not None for i in matched_keyframes])==0:
|
167 |
+
message = ""
|
168 |
+
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
169 |
+
message += frame + " " * (max_filename_length - len(frame) + 1)
|
170 |
+
message += "--> No matched keyframes\n"
|
171 |
+
else:
|
172 |
+
message = ""
|
173 |
+
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
174 |
+
message += frame + " " * (max_filename_length - len(frame) + 1)
|
175 |
+
if matched_keyframe is None:
|
176 |
+
message += "--> [to be rendered]\n"
|
177 |
+
else:
|
178 |
+
message += f"--> {matched_keyframe}\n"
|
179 |
+
return message
|
180 |
+
|
181 |
+
|
182 |
+
def check_input_for_interpolating(frames_path, keyframes_path):
|
183 |
+
# search for images
|
184 |
+
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
185 |
+
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
186 |
+
# match frames
|
187 |
+
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
188 |
+
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
|
189 |
+
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
|
190 |
+
frames_guide = VideoData(None, frames_path)
|
191 |
+
frames_style = VideoData(None, keyframes_path, file_list=file_list)
|
192 |
+
# match shape
|
193 |
+
message = ""
|
194 |
+
height_guide, width_guide = frames_guide.shape()
|
195 |
+
height_style, width_style = frames_style.shape()
|
196 |
+
if height_guide != height_style or width_guide != width_style:
|
197 |
+
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
|
198 |
+
frames_style.set_shape(height_guide, width_guide)
|
199 |
+
return frames_guide, frames_style, index_style, message
|
200 |
+
|
201 |
+
|
202 |
+
def interpolate_video(
|
203 |
+
frames_path,
|
204 |
+
keyframes_path,
|
205 |
+
output_path,
|
206 |
+
fps,
|
207 |
+
batch_size,
|
208 |
+
tracking_window_size,
|
209 |
+
minimum_patch_size,
|
210 |
+
num_iter,
|
211 |
+
guide_weight,
|
212 |
+
initialize,
|
213 |
+
progress = None,
|
214 |
+
):
|
215 |
+
# input
|
216 |
+
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
|
217 |
+
if len(message) > 0:
|
218 |
+
print(message)
|
219 |
+
# output
|
220 |
+
if output_path == "":
|
221 |
+
output_path = os.path.join(keyframes_path, "output")
|
222 |
+
os.makedirs(output_path, exist_ok=True)
|
223 |
+
print("No valid output_path. Your video will be saved here:", output_path)
|
224 |
+
elif not os.path.exists(output_path):
|
225 |
+
os.makedirs(output_path, exist_ok=True)
|
226 |
+
print("Your video will be saved here:", output_path)
|
227 |
+
output_frames_path = os.path.join(output_path, "frames")
|
228 |
+
output_video_path = os.path.join(output_path, "video.mp4")
|
229 |
+
os.makedirs(output_frames_path, exist_ok=True)
|
230 |
+
# process
|
231 |
+
ebsynth_config = {
|
232 |
+
"minimum_patch_size": minimum_patch_size,
|
233 |
+
"threads_per_block": 8,
|
234 |
+
"num_iter": num_iter,
|
235 |
+
"gpu_id": 0,
|
236 |
+
"guide_weight": guide_weight,
|
237 |
+
"initialize": initialize,
|
238 |
+
"tracking_window_size": tracking_window_size
|
239 |
+
}
|
240 |
+
if len(index_style)==1:
|
241 |
+
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
242 |
+
else:
|
243 |
+
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
244 |
+
try:
|
245 |
+
fps = int(fps)
|
246 |
+
except:
|
247 |
+
fps = 30
|
248 |
+
print("Fps:", fps)
|
249 |
+
print("Saving video...")
|
250 |
+
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
|
251 |
+
print("Success!")
|
252 |
+
print("Your frames are here:", output_frames_path)
|
253 |
+
print("Your video is here:", video_path)
|
254 |
+
return output_path, fps, video_path
|
255 |
+
|
256 |
+
|
257 |
+
def on_ui_tabs():
|
258 |
+
with gr.Blocks(analytics_enabled=False) as ui_component:
|
259 |
+
with gr.Tab("Blend"):
|
260 |
+
gr.Markdown("""
|
261 |
+
# Blend
|
262 |
+
|
263 |
+
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
|
264 |
+
""")
|
265 |
+
with gr.Row():
|
266 |
+
with gr.Column():
|
267 |
+
with gr.Tab("Guide video"):
|
268 |
+
video_guide = gr.Video(label="Guide video")
|
269 |
+
with gr.Tab("Guide video (images format)"):
|
270 |
+
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
|
271 |
+
with gr.Column():
|
272 |
+
with gr.Tab("Style video"):
|
273 |
+
video_style = gr.Video(label="Style video")
|
274 |
+
with gr.Tab("Style video (images format)"):
|
275 |
+
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
|
276 |
+
with gr.Column():
|
277 |
+
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
|
278 |
+
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
279 |
+
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
280 |
+
btn = gr.Button(value="Blend")
|
281 |
+
with gr.Row():
|
282 |
+
with gr.Column():
|
283 |
+
gr.Markdown("# Settings")
|
284 |
+
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
|
285 |
+
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
|
286 |
+
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
287 |
+
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
288 |
+
gr.Markdown("## Advanced Settings")
|
289 |
+
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
|
290 |
+
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
291 |
+
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
292 |
+
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
293 |
+
with gr.Column():
|
294 |
+
gr.Markdown("""
|
295 |
+
# Reference
|
296 |
+
|
297 |
+
* Output directory: the directory to save the video.
|
298 |
+
* Inference mode
|
299 |
+
|
300 |
+
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|
301 |
+
|-|-|-|-|-|-|
|
302 |
+
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|
303 |
+
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|
304 |
+
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
|
305 |
+
|
306 |
+
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
|
307 |
+
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
308 |
+
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
309 |
+
* Advanced settings
|
310 |
+
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
|
311 |
+
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
312 |
+
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
313 |
+
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
314 |
+
""")
|
315 |
+
btn.click(
|
316 |
+
smooth_video,
|
317 |
+
inputs=[
|
318 |
+
video_guide,
|
319 |
+
video_guide_folder,
|
320 |
+
video_style,
|
321 |
+
video_style_folder,
|
322 |
+
mode,
|
323 |
+
window_size,
|
324 |
+
batch_size,
|
325 |
+
tracking_window_size,
|
326 |
+
output_path,
|
327 |
+
fps,
|
328 |
+
minimum_patch_size,
|
329 |
+
num_iter,
|
330 |
+
guide_weight,
|
331 |
+
initialize
|
332 |
+
],
|
333 |
+
outputs=[output_path, fps, video_output]
|
334 |
+
)
|
335 |
+
with gr.Tab("Interpolate"):
|
336 |
+
gr.Markdown("""
|
337 |
+
# Interpolate
|
338 |
+
|
339 |
+
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
|
340 |
+
""")
|
341 |
+
with gr.Row():
|
342 |
+
with gr.Column():
|
343 |
+
with gr.Row():
|
344 |
+
with gr.Column():
|
345 |
+
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
|
346 |
+
with gr.Column():
|
347 |
+
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
|
348 |
+
with gr.Row():
|
349 |
+
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
|
350 |
+
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
351 |
+
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
352 |
+
with gr.Column():
|
353 |
+
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
|
354 |
+
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
355 |
+
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
356 |
+
btn_ = gr.Button(value="Interpolate")
|
357 |
+
with gr.Row():
|
358 |
+
with gr.Column():
|
359 |
+
gr.Markdown("# Settings")
|
360 |
+
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
361 |
+
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
362 |
+
gr.Markdown("## Advanced Settings")
|
363 |
+
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
|
364 |
+
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
365 |
+
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
366 |
+
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
367 |
+
with gr.Column():
|
368 |
+
gr.Markdown("""
|
369 |
+
# Reference
|
370 |
+
|
371 |
+
* Output directory: the directory to save the video.
|
372 |
+
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
373 |
+
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
374 |
+
* Advanced settings
|
375 |
+
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
|
376 |
+
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
377 |
+
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
378 |
+
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
379 |
+
""")
|
380 |
+
btn_.click(
|
381 |
+
interpolate_video,
|
382 |
+
inputs=[
|
383 |
+
video_guide_folder_,
|
384 |
+
rendered_keyframes_,
|
385 |
+
output_path_,
|
386 |
+
fps_,
|
387 |
+
batch_size_,
|
388 |
+
tracking_window_size_,
|
389 |
+
minimum_patch_size_,
|
390 |
+
num_iter_,
|
391 |
+
guide_weight_,
|
392 |
+
initialize_,
|
393 |
+
],
|
394 |
+
outputs=[output_path_, fps_, video_output_]
|
395 |
+
)
|
396 |
+
|
397 |
+
return [(ui_component, "FastBlend", "FastBlend_ui")]
|
diffsynth/extensions/FastBlend/cupy_kernels.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cupy as cp
|
2 |
+
|
3 |
+
remapping_kernel = cp.RawKernel(r'''
|
4 |
+
extern "C" __global__
|
5 |
+
void remap(
|
6 |
+
const int height,
|
7 |
+
const int width,
|
8 |
+
const int channel,
|
9 |
+
const int patch_size,
|
10 |
+
const int pad_size,
|
11 |
+
const float* source_style,
|
12 |
+
const int* nnf,
|
13 |
+
float* target_style
|
14 |
+
) {
|
15 |
+
const int r = (patch_size - 1) / 2;
|
16 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
17 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
18 |
+
if (x >= height or y >= width) return;
|
19 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
20 |
+
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
|
21 |
+
const int min_px = x < r ? -x : -r;
|
22 |
+
const int max_px = x + r > height - 1 ? height - 1 - x : r;
|
23 |
+
const int min_py = y < r ? -y : -r;
|
24 |
+
const int max_py = y + r > width - 1 ? width - 1 - y : r;
|
25 |
+
int num = 0;
|
26 |
+
for (int px = min_px; px <= max_px; px++){
|
27 |
+
for (int py = min_py; py <= max_py; py++){
|
28 |
+
const int nid = (x + px) * width + y + py;
|
29 |
+
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
|
30 |
+
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
|
31 |
+
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
|
32 |
+
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
|
33 |
+
num++;
|
34 |
+
for (int c = 0; c < channel; c++){
|
35 |
+
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
|
36 |
+
}
|
37 |
+
}
|
38 |
+
}
|
39 |
+
for (int c = 0; c < channel; c++){
|
40 |
+
target_style[z + pid * channel + c] /= num;
|
41 |
+
}
|
42 |
+
}
|
43 |
+
''', 'remap')
|
44 |
+
|
45 |
+
|
46 |
+
patch_error_kernel = cp.RawKernel(r'''
|
47 |
+
extern "C" __global__
|
48 |
+
void patch_error(
|
49 |
+
const int height,
|
50 |
+
const int width,
|
51 |
+
const int channel,
|
52 |
+
const int patch_size,
|
53 |
+
const int pad_size,
|
54 |
+
const float* source,
|
55 |
+
const int* nnf,
|
56 |
+
const float* target,
|
57 |
+
float* error
|
58 |
+
) {
|
59 |
+
const int r = (patch_size - 1) / 2;
|
60 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
61 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
62 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
63 |
+
if (x >= height or y >= width) return;
|
64 |
+
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
|
65 |
+
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
|
66 |
+
float e = 0;
|
67 |
+
for (int px = -r; px <= r; px++){
|
68 |
+
for (int py = -r; py <= r; py++){
|
69 |
+
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
|
70 |
+
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
|
71 |
+
for (int c = 0; c < channel; c++){
|
72 |
+
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
|
73 |
+
e += diff * diff;
|
74 |
+
}
|
75 |
+
}
|
76 |
+
}
|
77 |
+
error[blockIdx.z * height * width + x * width + y] = e;
|
78 |
+
}
|
79 |
+
''', 'patch_error')
|
80 |
+
|
81 |
+
|
82 |
+
pairwise_patch_error_kernel = cp.RawKernel(r'''
|
83 |
+
extern "C" __global__
|
84 |
+
void pairwise_patch_error(
|
85 |
+
const int height,
|
86 |
+
const int width,
|
87 |
+
const int channel,
|
88 |
+
const int patch_size,
|
89 |
+
const int pad_size,
|
90 |
+
const float* source_a,
|
91 |
+
const int* nnf_a,
|
92 |
+
const float* source_b,
|
93 |
+
const int* nnf_b,
|
94 |
+
float* error
|
95 |
+
) {
|
96 |
+
const int r = (patch_size - 1) / 2;
|
97 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
98 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
99 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
100 |
+
if (x >= height or y >= width) return;
|
101 |
+
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
|
102 |
+
const int x_a = nnf_a[z_nnf + 0];
|
103 |
+
const int y_a = nnf_a[z_nnf + 1];
|
104 |
+
const int x_b = nnf_b[z_nnf + 0];
|
105 |
+
const int y_b = nnf_b[z_nnf + 1];
|
106 |
+
float e = 0;
|
107 |
+
for (int px = -r; px <= r; px++){
|
108 |
+
for (int py = -r; py <= r; py++){
|
109 |
+
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
|
110 |
+
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
|
111 |
+
for (int c = 0; c < channel; c++){
|
112 |
+
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
|
113 |
+
e += diff * diff;
|
114 |
+
}
|
115 |
+
}
|
116 |
+
}
|
117 |
+
error[blockIdx.z * height * width + x * width + y] = e;
|
118 |
+
}
|
119 |
+
''', 'pairwise_patch_error')
|
diffsynth/extensions/FastBlend/data.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import imageio, os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
|
6 |
+
def read_video(file_name):
|
7 |
+
reader = imageio.get_reader(file_name)
|
8 |
+
video = []
|
9 |
+
for frame in reader:
|
10 |
+
frame = np.array(frame)
|
11 |
+
video.append(frame)
|
12 |
+
reader.close()
|
13 |
+
return video
|
14 |
+
|
15 |
+
|
16 |
+
def get_video_fps(file_name):
|
17 |
+
reader = imageio.get_reader(file_name)
|
18 |
+
fps = reader.get_meta_data()["fps"]
|
19 |
+
reader.close()
|
20 |
+
return fps
|
21 |
+
|
22 |
+
|
23 |
+
def save_video(frames_path, video_path, num_frames, fps):
|
24 |
+
writer = imageio.get_writer(video_path, fps=fps, quality=9)
|
25 |
+
for i in range(num_frames):
|
26 |
+
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
|
27 |
+
writer.append_data(frame)
|
28 |
+
writer.close()
|
29 |
+
return video_path
|
30 |
+
|
31 |
+
|
32 |
+
class LowMemoryVideo:
|
33 |
+
def __init__(self, file_name):
|
34 |
+
self.reader = imageio.get_reader(file_name)
|
35 |
+
|
36 |
+
def __len__(self):
|
37 |
+
return self.reader.count_frames()
|
38 |
+
|
39 |
+
def __getitem__(self, item):
|
40 |
+
return np.array(self.reader.get_data(item))
|
41 |
+
|
42 |
+
def __del__(self):
|
43 |
+
self.reader.close()
|
44 |
+
|
45 |
+
|
46 |
+
def split_file_name(file_name):
|
47 |
+
result = []
|
48 |
+
number = -1
|
49 |
+
for i in file_name:
|
50 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
51 |
+
if number == -1:
|
52 |
+
number = 0
|
53 |
+
number = number*10 + ord(i) - ord("0")
|
54 |
+
else:
|
55 |
+
if number != -1:
|
56 |
+
result.append(number)
|
57 |
+
number = -1
|
58 |
+
result.append(i)
|
59 |
+
if number != -1:
|
60 |
+
result.append(number)
|
61 |
+
result = tuple(result)
|
62 |
+
return result
|
63 |
+
|
64 |
+
|
65 |
+
def search_for_images(folder):
|
66 |
+
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
67 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
68 |
+
file_list = [i[1] for i in sorted(file_list)]
|
69 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
70 |
+
return file_list
|
71 |
+
|
72 |
+
|
73 |
+
def read_images(folder):
|
74 |
+
file_list = search_for_images(folder)
|
75 |
+
frames = [np.array(Image.open(i)) for i in file_list]
|
76 |
+
return frames
|
77 |
+
|
78 |
+
|
79 |
+
class LowMemoryImageFolder:
|
80 |
+
def __init__(self, folder, file_list=None):
|
81 |
+
if file_list is None:
|
82 |
+
self.file_list = search_for_images(folder)
|
83 |
+
else:
|
84 |
+
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return len(self.file_list)
|
88 |
+
|
89 |
+
def __getitem__(self, item):
|
90 |
+
return np.array(Image.open(self.file_list[item]))
|
91 |
+
|
92 |
+
def __del__(self):
|
93 |
+
pass
|
94 |
+
|
95 |
+
|
96 |
+
class VideoData:
|
97 |
+
def __init__(self, video_file, image_folder, **kwargs):
|
98 |
+
if video_file is not None:
|
99 |
+
self.data_type = "video"
|
100 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
101 |
+
elif image_folder is not None:
|
102 |
+
self.data_type = "images"
|
103 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
104 |
+
else:
|
105 |
+
raise ValueError("Cannot open video or image folder")
|
106 |
+
self.length = None
|
107 |
+
self.height = None
|
108 |
+
self.width = None
|
109 |
+
|
110 |
+
def raw_data(self):
|
111 |
+
frames = []
|
112 |
+
for i in range(self.__len__()):
|
113 |
+
frames.append(self.__getitem__(i))
|
114 |
+
return frames
|
115 |
+
|
116 |
+
def set_length(self, length):
|
117 |
+
self.length = length
|
118 |
+
|
119 |
+
def set_shape(self, height, width):
|
120 |
+
self.height = height
|
121 |
+
self.width = width
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
if self.length is None:
|
125 |
+
return len(self.data)
|
126 |
+
else:
|
127 |
+
return self.length
|
128 |
+
|
129 |
+
def shape(self):
|
130 |
+
if self.height is not None and self.width is not None:
|
131 |
+
return self.height, self.width
|
132 |
+
else:
|
133 |
+
height, width, _ = self.__getitem__(0).shape
|
134 |
+
return height, width
|
135 |
+
|
136 |
+
def __getitem__(self, item):
|
137 |
+
frame = self.data.__getitem__(item)
|
138 |
+
height, width, _ = frame.shape
|
139 |
+
if self.height is not None and self.width is not None:
|
140 |
+
if self.height != height or self.width != width:
|
141 |
+
frame = Image.fromarray(frame).resize((self.width, self.height))
|
142 |
+
frame = np.array(frame)
|
143 |
+
return frame
|
144 |
+
|
145 |
+
def __del__(self):
|
146 |
+
pass
|
diffsynth/extensions/FastBlend/patch_match.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
2 |
+
import numpy as np
|
3 |
+
import cupy as cp
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
class PatchMatcher:
|
8 |
+
def __init__(
|
9 |
+
self, height, width, channel, minimum_patch_size,
|
10 |
+
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
11 |
+
random_search_steps=3, random_search_range=4,
|
12 |
+
use_mean_target_style=False, use_pairwise_patch_error=False,
|
13 |
+
tracking_window_size=0
|
14 |
+
):
|
15 |
+
self.height = height
|
16 |
+
self.width = width
|
17 |
+
self.channel = channel
|
18 |
+
self.minimum_patch_size = minimum_patch_size
|
19 |
+
self.threads_per_block = threads_per_block
|
20 |
+
self.num_iter = num_iter
|
21 |
+
self.gpu_id = gpu_id
|
22 |
+
self.guide_weight = guide_weight
|
23 |
+
self.random_search_steps = random_search_steps
|
24 |
+
self.random_search_range = random_search_range
|
25 |
+
self.use_mean_target_style = use_mean_target_style
|
26 |
+
self.use_pairwise_patch_error = use_pairwise_patch_error
|
27 |
+
self.tracking_window_size = tracking_window_size
|
28 |
+
|
29 |
+
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
|
30 |
+
self.pad_size = self.patch_size_list[0] // 2
|
31 |
+
self.grid = (
|
32 |
+
(height + threads_per_block - 1) // threads_per_block,
|
33 |
+
(width + threads_per_block - 1) // threads_per_block
|
34 |
+
)
|
35 |
+
self.block = (threads_per_block, threads_per_block)
|
36 |
+
|
37 |
+
def pad_image(self, image):
|
38 |
+
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
|
39 |
+
|
40 |
+
def unpad_image(self, image):
|
41 |
+
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
|
42 |
+
|
43 |
+
def apply_nnf_to_image(self, nnf, source):
|
44 |
+
batch_size = source.shape[0]
|
45 |
+
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
|
46 |
+
remapping_kernel(
|
47 |
+
self.grid + (batch_size,),
|
48 |
+
self.block,
|
49 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
|
50 |
+
)
|
51 |
+
return target
|
52 |
+
|
53 |
+
def get_patch_error(self, source, nnf, target):
|
54 |
+
batch_size = source.shape[0]
|
55 |
+
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
56 |
+
patch_error_kernel(
|
57 |
+
self.grid + (batch_size,),
|
58 |
+
self.block,
|
59 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
|
60 |
+
)
|
61 |
+
return error
|
62 |
+
|
63 |
+
def get_pairwise_patch_error(self, source, nnf):
|
64 |
+
batch_size = source.shape[0]//2
|
65 |
+
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
66 |
+
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
|
67 |
+
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
|
68 |
+
pairwise_patch_error_kernel(
|
69 |
+
self.grid + (batch_size,),
|
70 |
+
self.block,
|
71 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
|
72 |
+
)
|
73 |
+
error = error.repeat(2, axis=0)
|
74 |
+
return error
|
75 |
+
|
76 |
+
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
|
77 |
+
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
|
78 |
+
if self.use_mean_target_style:
|
79 |
+
target_style = self.apply_nnf_to_image(nnf, source_style)
|
80 |
+
target_style = target_style.mean(axis=0, keepdims=True)
|
81 |
+
target_style = target_style.repeat(source_guide.shape[0], axis=0)
|
82 |
+
if self.use_pairwise_patch_error:
|
83 |
+
error_style = self.get_pairwise_patch_error(source_style, nnf)
|
84 |
+
else:
|
85 |
+
error_style = self.get_patch_error(source_style, nnf, target_style)
|
86 |
+
error = error_guide * self.guide_weight + error_style
|
87 |
+
return error
|
88 |
+
|
89 |
+
def clamp_bound(self, nnf):
|
90 |
+
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
|
91 |
+
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
|
92 |
+
return nnf
|
93 |
+
|
94 |
+
def random_step(self, nnf, r):
|
95 |
+
batch_size = nnf.shape[0]
|
96 |
+
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
|
97 |
+
upd_nnf = self.clamp_bound(nnf + step)
|
98 |
+
return upd_nnf
|
99 |
+
|
100 |
+
def neighboor_step(self, nnf, d):
|
101 |
+
if d==0:
|
102 |
+
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
|
103 |
+
upd_nnf[:, :, :, 0] += 1
|
104 |
+
elif d==1:
|
105 |
+
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
|
106 |
+
upd_nnf[:, :, :, 1] += 1
|
107 |
+
elif d==2:
|
108 |
+
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
|
109 |
+
upd_nnf[:, :, :, 0] -= 1
|
110 |
+
elif d==3:
|
111 |
+
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
|
112 |
+
upd_nnf[:, :, :, 1] -= 1
|
113 |
+
upd_nnf = self.clamp_bound(upd_nnf)
|
114 |
+
return upd_nnf
|
115 |
+
|
116 |
+
def shift_nnf(self, nnf, d):
|
117 |
+
if d>0:
|
118 |
+
d = min(nnf.shape[0], d)
|
119 |
+
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
120 |
+
else:
|
121 |
+
d = max(-nnf.shape[0], d)
|
122 |
+
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
123 |
+
return upd_nnf
|
124 |
+
|
125 |
+
def track_step(self, nnf, d):
|
126 |
+
if self.use_pairwise_patch_error:
|
127 |
+
upd_nnf = cp.zeros_like(nnf)
|
128 |
+
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
|
129 |
+
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
|
130 |
+
else:
|
131 |
+
upd_nnf = self.shift_nnf(nnf, d)
|
132 |
+
return upd_nnf
|
133 |
+
|
134 |
+
def C(self, n, m):
|
135 |
+
# not used
|
136 |
+
c = 1
|
137 |
+
for i in range(1, n+1):
|
138 |
+
c *= i
|
139 |
+
for i in range(1, m+1):
|
140 |
+
c //= i
|
141 |
+
for i in range(1, n-m+1):
|
142 |
+
c //= i
|
143 |
+
return c
|
144 |
+
|
145 |
+
def bezier_step(self, nnf, r):
|
146 |
+
# not used
|
147 |
+
n = r * 2 - 1
|
148 |
+
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
|
149 |
+
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
|
150 |
+
if d>0:
|
151 |
+
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
152 |
+
elif d<0:
|
153 |
+
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
154 |
+
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
|
155 |
+
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
|
156 |
+
return upd_nnf
|
157 |
+
|
158 |
+
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
|
159 |
+
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
|
160 |
+
upd_idx = (upd_err < err)
|
161 |
+
nnf[upd_idx] = upd_nnf[upd_idx]
|
162 |
+
err[upd_idx] = upd_err[upd_idx]
|
163 |
+
return nnf, err
|
164 |
+
|
165 |
+
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
166 |
+
for d in cp.random.permutation(4):
|
167 |
+
upd_nnf = self.neighboor_step(nnf, d)
|
168 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
169 |
+
return nnf, err
|
170 |
+
|
171 |
+
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
172 |
+
for i in range(self.random_search_steps):
|
173 |
+
upd_nnf = self.random_step(nnf, self.random_search_range)
|
174 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
175 |
+
return nnf, err
|
176 |
+
|
177 |
+
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
178 |
+
for d in range(1, self.tracking_window_size + 1):
|
179 |
+
upd_nnf = self.track_step(nnf, d)
|
180 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
181 |
+
upd_nnf = self.track_step(nnf, -d)
|
182 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
183 |
+
return nnf, err
|
184 |
+
|
185 |
+
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
186 |
+
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
|
187 |
+
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
|
188 |
+
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
|
189 |
+
return nnf, err
|
190 |
+
|
191 |
+
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
|
192 |
+
with cp.cuda.Device(self.gpu_id):
|
193 |
+
source_guide = self.pad_image(source_guide)
|
194 |
+
target_guide = self.pad_image(target_guide)
|
195 |
+
source_style = self.pad_image(source_style)
|
196 |
+
for it in range(self.num_iter):
|
197 |
+
self.patch_size = self.patch_size_list[it]
|
198 |
+
target_style = self.apply_nnf_to_image(nnf, source_style)
|
199 |
+
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
|
200 |
+
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
|
201 |
+
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
|
202 |
+
return nnf, target_style
|
203 |
+
|
204 |
+
|
205 |
+
class PyramidPatchMatcher:
|
206 |
+
def __init__(
|
207 |
+
self, image_height, image_width, channel, minimum_patch_size,
|
208 |
+
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
209 |
+
use_mean_target_style=False, use_pairwise_patch_error=False,
|
210 |
+
tracking_window_size=0,
|
211 |
+
initialize="identity"
|
212 |
+
):
|
213 |
+
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
|
214 |
+
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
|
215 |
+
self.pyramid_heights = []
|
216 |
+
self.pyramid_widths = []
|
217 |
+
self.patch_matchers = []
|
218 |
+
self.minimum_patch_size = minimum_patch_size
|
219 |
+
self.num_iter = num_iter
|
220 |
+
self.gpu_id = gpu_id
|
221 |
+
self.initialize = initialize
|
222 |
+
for level in range(self.pyramid_level):
|
223 |
+
height = image_height//(2**(self.pyramid_level - 1 - level))
|
224 |
+
width = image_width//(2**(self.pyramid_level - 1 - level))
|
225 |
+
self.pyramid_heights.append(height)
|
226 |
+
self.pyramid_widths.append(width)
|
227 |
+
self.patch_matchers.append(PatchMatcher(
|
228 |
+
height, width, channel, minimum_patch_size=minimum_patch_size,
|
229 |
+
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
|
230 |
+
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
|
231 |
+
tracking_window_size=tracking_window_size
|
232 |
+
))
|
233 |
+
|
234 |
+
def resample_image(self, images, level):
|
235 |
+
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
236 |
+
images = images.get()
|
237 |
+
images_resample = []
|
238 |
+
for image in images:
|
239 |
+
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
240 |
+
images_resample.append(image_resample)
|
241 |
+
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
242 |
+
return images_resample
|
243 |
+
|
244 |
+
def initialize_nnf(self, batch_size):
|
245 |
+
if self.initialize == "random":
|
246 |
+
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
247 |
+
nnf = cp.stack([
|
248 |
+
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
|
249 |
+
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
|
250 |
+
], axis=3)
|
251 |
+
elif self.initialize == "identity":
|
252 |
+
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
253 |
+
nnf = cp.stack([
|
254 |
+
cp.repeat(cp.arange(height), width).reshape(height, width),
|
255 |
+
cp.tile(cp.arange(width), height).reshape(height, width)
|
256 |
+
], axis=2)
|
257 |
+
nnf = cp.stack([nnf] * batch_size)
|
258 |
+
else:
|
259 |
+
raise NotImplementedError()
|
260 |
+
return nnf
|
261 |
+
|
262 |
+
def update_nnf(self, nnf, level):
|
263 |
+
# upscale
|
264 |
+
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
265 |
+
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
266 |
+
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
267 |
+
# check if scale is 2
|
268 |
+
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
269 |
+
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
270 |
+
nnf = nnf.get().astype(np.float32)
|
271 |
+
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
272 |
+
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
273 |
+
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
274 |
+
return nnf
|
275 |
+
|
276 |
+
def apply_nnf_to_image(self, nnf, image):
|
277 |
+
with cp.cuda.Device(self.gpu_id):
|
278 |
+
image = self.patch_matchers[-1].pad_image(image)
|
279 |
+
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
|
280 |
+
return image
|
281 |
+
|
282 |
+
def estimate_nnf(self, source_guide, target_guide, source_style):
|
283 |
+
with cp.cuda.Device(self.gpu_id):
|
284 |
+
if not isinstance(source_guide, cp.ndarray):
|
285 |
+
source_guide = cp.array(source_guide, dtype=cp.float32)
|
286 |
+
if not isinstance(target_guide, cp.ndarray):
|
287 |
+
target_guide = cp.array(target_guide, dtype=cp.float32)
|
288 |
+
if not isinstance(source_style, cp.ndarray):
|
289 |
+
source_style = cp.array(source_style, dtype=cp.float32)
|
290 |
+
for level in range(self.pyramid_level):
|
291 |
+
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
|
292 |
+
source_guide_ = self.resample_image(source_guide, level)
|
293 |
+
target_guide_ = self.resample_image(target_guide, level)
|
294 |
+
source_style_ = self.resample_image(source_style, level)
|
295 |
+
nnf, target_style = self.patch_matchers[level].estimate_nnf(
|
296 |
+
source_guide_, target_guide_, source_style_, nnf
|
297 |
+
)
|
298 |
+
return nnf.get(), target_style.get()
|
diffsynth/extensions/FastBlend/runners/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .accurate import AccurateModeRunner
|
2 |
+
from .fast import FastModeRunner
|
3 |
+
from .balanced import BalancedModeRunner
|
4 |
+
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
diffsynth/extensions/FastBlend/runners/accurate.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..patch_match import PyramidPatchMatcher
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class AccurateModeRunner:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
|
13 |
+
patch_match_engine = PyramidPatchMatcher(
|
14 |
+
image_height=frames_style[0].shape[0],
|
15 |
+
image_width=frames_style[0].shape[1],
|
16 |
+
channel=3,
|
17 |
+
use_mean_target_style=True,
|
18 |
+
**ebsynth_config
|
19 |
+
)
|
20 |
+
# run
|
21 |
+
n = len(frames_style)
|
22 |
+
for target in tqdm(range(n), desc=desc):
|
23 |
+
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
|
24 |
+
remapped_frames = []
|
25 |
+
for i in range(l, r, batch_size):
|
26 |
+
j = min(i + batch_size, r)
|
27 |
+
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
|
28 |
+
target_guide = np.stack([frames_guide[target]] * (j - i))
|
29 |
+
source_style = np.stack([frames_style[source] for source in range(i, j)])
|
30 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
31 |
+
remapped_frames.append(target_style)
|
32 |
+
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
|
33 |
+
frame = frame.clip(0, 255).astype("uint8")
|
34 |
+
if save_path is not None:
|
35 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
diffsynth/extensions/FastBlend/runners/balanced.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..patch_match import PyramidPatchMatcher
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class BalancedModeRunner:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
|
13 |
+
patch_match_engine = PyramidPatchMatcher(
|
14 |
+
image_height=frames_style[0].shape[0],
|
15 |
+
image_width=frames_style[0].shape[1],
|
16 |
+
channel=3,
|
17 |
+
**ebsynth_config
|
18 |
+
)
|
19 |
+
# tasks
|
20 |
+
n = len(frames_style)
|
21 |
+
tasks = []
|
22 |
+
for target in range(n):
|
23 |
+
for source in range(target - window_size, target + window_size + 1):
|
24 |
+
if source >= 0 and source < n and source != target:
|
25 |
+
tasks.append((source, target))
|
26 |
+
# run
|
27 |
+
frames = [(None, 1) for i in range(n)]
|
28 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
29 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
30 |
+
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
|
31 |
+
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
|
32 |
+
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
|
33 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
34 |
+
for (source, target), result in zip(tasks_batch, target_style):
|
35 |
+
frame, weight = frames[target]
|
36 |
+
if frame is None:
|
37 |
+
frame = frames_style[target]
|
38 |
+
frames[target] = (
|
39 |
+
frame * (weight / (weight + 1)) + result / (weight + 1),
|
40 |
+
weight + 1
|
41 |
+
)
|
42 |
+
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
|
43 |
+
frame = frame.clip(0, 255).astype("uint8")
|
44 |
+
if save_path is not None:
|
45 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
46 |
+
frames[target] = (None, 1)
|
diffsynth/extensions/FastBlend/runners/fast.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..patch_match import PyramidPatchMatcher
|
2 |
+
import functools, os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class TableManager:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def task_list(self, n):
|
13 |
+
tasks = []
|
14 |
+
max_level = 1
|
15 |
+
while (1<<max_level)<=n:
|
16 |
+
max_level += 1
|
17 |
+
for i in range(n):
|
18 |
+
j = i
|
19 |
+
for level in range(max_level):
|
20 |
+
if i&(1<<level):
|
21 |
+
continue
|
22 |
+
j |= 1<<level
|
23 |
+
if j>=n:
|
24 |
+
break
|
25 |
+
meta_data = {
|
26 |
+
"source": i,
|
27 |
+
"target": j,
|
28 |
+
"level": level + 1
|
29 |
+
}
|
30 |
+
tasks.append(meta_data)
|
31 |
+
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
|
32 |
+
return tasks
|
33 |
+
|
34 |
+
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
|
35 |
+
n = len(frames_guide)
|
36 |
+
tasks = self.task_list(n)
|
37 |
+
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
|
38 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
39 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
40 |
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
41 |
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
42 |
+
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
|
43 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
44 |
+
for task, result in zip(tasks_batch, target_style):
|
45 |
+
target, level = task["target"], task["level"]
|
46 |
+
if len(remapping_table[target])==level:
|
47 |
+
remapping_table[target].append((result, 1))
|
48 |
+
else:
|
49 |
+
frame, weight = remapping_table[target][level]
|
50 |
+
remapping_table[target][level] = (
|
51 |
+
frame * (weight / (weight + 1)) + result / (weight + 1),
|
52 |
+
weight + 1
|
53 |
+
)
|
54 |
+
return remapping_table
|
55 |
+
|
56 |
+
def remapping_table_to_blending_table(self, table):
|
57 |
+
for i in range(len(table)):
|
58 |
+
for j in range(1, len(table[i])):
|
59 |
+
frame_1, weight_1 = table[i][j-1]
|
60 |
+
frame_2, weight_2 = table[i][j]
|
61 |
+
frame = (frame_1 + frame_2) / 2
|
62 |
+
weight = weight_1 + weight_2
|
63 |
+
table[i][j] = (frame, weight)
|
64 |
+
return table
|
65 |
+
|
66 |
+
def tree_query(self, leftbound, rightbound):
|
67 |
+
node_list = []
|
68 |
+
node_index = rightbound
|
69 |
+
while node_index>=leftbound:
|
70 |
+
node_level = 0
|
71 |
+
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
|
72 |
+
node_level += 1
|
73 |
+
node_list.append((node_index, node_level))
|
74 |
+
node_index -= 1<<node_level
|
75 |
+
return node_list
|
76 |
+
|
77 |
+
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
|
78 |
+
n = len(blending_table)
|
79 |
+
tasks = []
|
80 |
+
frames_result = []
|
81 |
+
for target in range(n):
|
82 |
+
node_list = self.tree_query(max(target-window_size, 0), target)
|
83 |
+
for source, level in node_list:
|
84 |
+
if source!=target:
|
85 |
+
meta_data = {
|
86 |
+
"source": source,
|
87 |
+
"target": target,
|
88 |
+
"level": level
|
89 |
+
}
|
90 |
+
tasks.append(meta_data)
|
91 |
+
else:
|
92 |
+
frames_result.append(blending_table[target][level])
|
93 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
94 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
95 |
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
96 |
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
97 |
+
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
|
98 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
99 |
+
for task, frame_2 in zip(tasks_batch, target_style):
|
100 |
+
source, target, level = task["source"], task["target"], task["level"]
|
101 |
+
frame_1, weight_1 = frames_result[target]
|
102 |
+
weight_2 = blending_table[source][level][1]
|
103 |
+
weight = weight_1 + weight_2
|
104 |
+
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
|
105 |
+
frames_result[target] = (frame, weight)
|
106 |
+
return frames_result
|
107 |
+
|
108 |
+
|
109 |
+
class FastModeRunner:
|
110 |
+
def __init__(self):
|
111 |
+
pass
|
112 |
+
|
113 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
|
114 |
+
frames_guide = frames_guide.raw_data()
|
115 |
+
frames_style = frames_style.raw_data()
|
116 |
+
table_manager = TableManager()
|
117 |
+
patch_match_engine = PyramidPatchMatcher(
|
118 |
+
image_height=frames_style[0].shape[0],
|
119 |
+
image_width=frames_style[0].shape[1],
|
120 |
+
channel=3,
|
121 |
+
**ebsynth_config
|
122 |
+
)
|
123 |
+
# left part
|
124 |
+
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
|
125 |
+
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
126 |
+
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
|
127 |
+
# right part
|
128 |
+
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
|
129 |
+
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
130 |
+
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
|
131 |
+
# merge
|
132 |
+
frames = []
|
133 |
+
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
134 |
+
weight_m = -1
|
135 |
+
weight = weight_l + weight_m + weight_r
|
136 |
+
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
137 |
+
frames.append(frame)
|
138 |
+
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
|
139 |
+
if save_path is not None:
|
140 |
+
for target, frame in enumerate(frames):
|
141 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
diffsynth/extensions/FastBlend/runners/interpolation.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..patch_match import PyramidPatchMatcher
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class InterpolationModeRunner:
|
9 |
+
def __init__(self):
|
10 |
+
pass
|
11 |
+
|
12 |
+
def get_index_dict(self, index_style):
|
13 |
+
index_dict = {}
|
14 |
+
for i, index in enumerate(index_style):
|
15 |
+
index_dict[index] = i
|
16 |
+
return index_dict
|
17 |
+
|
18 |
+
def get_weight(self, l, m, r):
|
19 |
+
weight_l, weight_r = abs(m - r), abs(m - l)
|
20 |
+
if weight_l + weight_r == 0:
|
21 |
+
weight_l, weight_r = 0.5, 0.5
|
22 |
+
else:
|
23 |
+
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
|
24 |
+
return weight_l, weight_r
|
25 |
+
|
26 |
+
def get_task_group(self, index_style, n):
|
27 |
+
task_group = []
|
28 |
+
index_style = sorted(index_style)
|
29 |
+
# first frame
|
30 |
+
if index_style[0]>0:
|
31 |
+
tasks = []
|
32 |
+
for m in range(index_style[0]):
|
33 |
+
tasks.append((index_style[0], m, index_style[0]))
|
34 |
+
task_group.append(tasks)
|
35 |
+
# middle frames
|
36 |
+
for l, r in zip(index_style[:-1], index_style[1:]):
|
37 |
+
tasks = []
|
38 |
+
for m in range(l, r):
|
39 |
+
tasks.append((l, m, r))
|
40 |
+
task_group.append(tasks)
|
41 |
+
# last frame
|
42 |
+
tasks = []
|
43 |
+
for m in range(index_style[-1], n):
|
44 |
+
tasks.append((index_style[-1], m, index_style[-1]))
|
45 |
+
task_group.append(tasks)
|
46 |
+
return task_group
|
47 |
+
|
48 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
49 |
+
patch_match_engine = PyramidPatchMatcher(
|
50 |
+
image_height=frames_style[0].shape[0],
|
51 |
+
image_width=frames_style[0].shape[1],
|
52 |
+
channel=3,
|
53 |
+
use_mean_target_style=False,
|
54 |
+
use_pairwise_patch_error=True,
|
55 |
+
**ebsynth_config
|
56 |
+
)
|
57 |
+
# task
|
58 |
+
index_dict = self.get_index_dict(index_style)
|
59 |
+
task_group = self.get_task_group(index_style, len(frames_guide))
|
60 |
+
# run
|
61 |
+
for tasks in task_group:
|
62 |
+
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
|
63 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
|
64 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
65 |
+
source_guide, target_guide, source_style = [], [], []
|
66 |
+
for l, m, r in tasks_batch:
|
67 |
+
# l -> m
|
68 |
+
source_guide.append(frames_guide[l])
|
69 |
+
target_guide.append(frames_guide[m])
|
70 |
+
source_style.append(frames_style[index_dict[l]])
|
71 |
+
# r -> m
|
72 |
+
source_guide.append(frames_guide[r])
|
73 |
+
target_guide.append(frames_guide[m])
|
74 |
+
source_style.append(frames_style[index_dict[r]])
|
75 |
+
source_guide = np.stack(source_guide)
|
76 |
+
target_guide = np.stack(target_guide)
|
77 |
+
source_style = np.stack(source_style)
|
78 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
79 |
+
if save_path is not None:
|
80 |
+
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
|
81 |
+
weight_l, weight_r = self.get_weight(l, m, r)
|
82 |
+
frame = frame_l * weight_l + frame_r * weight_r
|
83 |
+
frame = frame.clip(0, 255).astype("uint8")
|
84 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
85 |
+
|
86 |
+
|
87 |
+
class InterpolationModeSingleFrameRunner:
|
88 |
+
def __init__(self):
|
89 |
+
pass
|
90 |
+
|
91 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
92 |
+
# check input
|
93 |
+
tracking_window_size = ebsynth_config["tracking_window_size"]
|
94 |
+
if tracking_window_size * 2 >= batch_size:
|
95 |
+
raise ValueError("batch_size should be larger than track_window_size * 2")
|
96 |
+
frame_style = frames_style[0]
|
97 |
+
frame_guide = frames_guide[index_style[0]]
|
98 |
+
patch_match_engine = PyramidPatchMatcher(
|
99 |
+
image_height=frame_style.shape[0],
|
100 |
+
image_width=frame_style.shape[1],
|
101 |
+
channel=3,
|
102 |
+
**ebsynth_config
|
103 |
+
)
|
104 |
+
# run
|
105 |
+
frame_id, n = 0, len(frames_guide)
|
106 |
+
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
|
107 |
+
if i + batch_size > n:
|
108 |
+
l, r = max(n - batch_size, 0), n
|
109 |
+
else:
|
110 |
+
l, r = i, i + batch_size
|
111 |
+
source_guide = np.stack([frame_guide] * (r-l))
|
112 |
+
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
|
113 |
+
source_style = np.stack([frame_style] * (r-l))
|
114 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
115 |
+
for i, frame in zip(range(l, r), target_style):
|
116 |
+
if i==frame_id:
|
117 |
+
frame = frame.clip(0, 255).astype("uint8")
|
118 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
|
119 |
+
frame_id += 1
|
120 |
+
if r < n and r-frame_id <= tracking_window_size:
|
121 |
+
break
|
diffsynth/extensions/RIFE/__init__.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def warp(tenInput, tenFlow, device):
|
9 |
+
backwarp_tenGrid = {}
|
10 |
+
k = (str(tenFlow.device), str(tenFlow.size()))
|
11 |
+
if k not in backwarp_tenGrid:
|
12 |
+
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
|
13 |
+
1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
14 |
+
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
|
15 |
+
1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
16 |
+
backwarp_tenGrid[k] = torch.cat(
|
17 |
+
[tenHorizontal, tenVertical], 1).to(device)
|
18 |
+
|
19 |
+
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
20 |
+
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
|
21 |
+
|
22 |
+
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
23 |
+
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
24 |
+
|
25 |
+
|
26 |
+
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
27 |
+
return nn.Sequential(
|
28 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
29 |
+
padding=padding, dilation=dilation, bias=True),
|
30 |
+
nn.PReLU(out_planes)
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class IFBlock(nn.Module):
|
35 |
+
def __init__(self, in_planes, c=64):
|
36 |
+
super(IFBlock, self).__init__()
|
37 |
+
self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
|
38 |
+
self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
|
39 |
+
self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
|
40 |
+
self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
|
41 |
+
self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
|
42 |
+
self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
|
43 |
+
self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
|
44 |
+
|
45 |
+
def forward(self, x, flow, scale=1):
|
46 |
+
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
47 |
+
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
|
48 |
+
feat = self.conv0(torch.cat((x, flow), 1))
|
49 |
+
feat = self.convblock0(feat) + feat
|
50 |
+
feat = self.convblock1(feat) + feat
|
51 |
+
feat = self.convblock2(feat) + feat
|
52 |
+
feat = self.convblock3(feat) + feat
|
53 |
+
flow = self.conv1(feat)
|
54 |
+
mask = self.conv2(feat)
|
55 |
+
flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
|
56 |
+
mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
57 |
+
return flow, mask
|
58 |
+
|
59 |
+
|
60 |
+
class IFNet(nn.Module):
|
61 |
+
def __init__(self):
|
62 |
+
super(IFNet, self).__init__()
|
63 |
+
self.block0 = IFBlock(7+4, c=90)
|
64 |
+
self.block1 = IFBlock(7+4, c=90)
|
65 |
+
self.block2 = IFBlock(7+4, c=90)
|
66 |
+
self.block_tea = IFBlock(10+4, c=90)
|
67 |
+
|
68 |
+
def forward(self, x, scale_list=[4, 2, 1], training=False):
|
69 |
+
if training == False:
|
70 |
+
channel = x.shape[1] // 2
|
71 |
+
img0 = x[:, :channel]
|
72 |
+
img1 = x[:, channel:]
|
73 |
+
flow_list = []
|
74 |
+
merged = []
|
75 |
+
mask_list = []
|
76 |
+
warped_img0 = img0
|
77 |
+
warped_img1 = img1
|
78 |
+
flow = (x[:, :4]).detach() * 0
|
79 |
+
mask = (x[:, :1]).detach() * 0
|
80 |
+
block = [self.block0, self.block1, self.block2]
|
81 |
+
for i in range(3):
|
82 |
+
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
83 |
+
f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
|
84 |
+
flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
|
85 |
+
mask = mask + (m0 + (-m1)) / 2
|
86 |
+
mask_list.append(mask)
|
87 |
+
flow_list.append(flow)
|
88 |
+
warped_img0 = warp(img0, flow[:, :2], device=x.device)
|
89 |
+
warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
|
90 |
+
merged.append((warped_img0, warped_img1))
|
91 |
+
'''
|
92 |
+
c0 = self.contextnet(img0, flow[:, :2])
|
93 |
+
c1 = self.contextnet(img1, flow[:, 2:4])
|
94 |
+
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
95 |
+
res = tmp[:, 1:4] * 2 - 1
|
96 |
+
'''
|
97 |
+
for i in range(3):
|
98 |
+
mask_list[i] = torch.sigmoid(mask_list[i])
|
99 |
+
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
100 |
+
return flow_list, mask_list[2], merged
|
101 |
+
|
102 |
+
def state_dict_converter(self):
|
103 |
+
return IFNetStateDictConverter()
|
104 |
+
|
105 |
+
|
106 |
+
class IFNetStateDictConverter:
|
107 |
+
def __init__(self):
|
108 |
+
pass
|
109 |
+
|
110 |
+
def from_diffusers(self, state_dict):
|
111 |
+
state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
112 |
+
return state_dict_
|
113 |
+
|
114 |
+
def from_civitai(self, state_dict):
|
115 |
+
return self.from_diffusers(state_dict)
|
116 |
+
|
117 |
+
|
118 |
+
class RIFEInterpolater:
|
119 |
+
def __init__(self, model, device="cuda"):
|
120 |
+
self.model = model
|
121 |
+
self.device = device
|
122 |
+
# IFNet only does not support float16
|
123 |
+
self.torch_dtype = torch.float32
|
124 |
+
|
125 |
+
@staticmethod
|
126 |
+
def from_model_manager(model_manager):
|
127 |
+
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
|
128 |
+
|
129 |
+
def process_image(self, image):
|
130 |
+
width, height = image.size
|
131 |
+
if width % 32 != 0 or height % 32 != 0:
|
132 |
+
width = (width + 31) // 32
|
133 |
+
height = (height + 31) // 32
|
134 |
+
image = image.resize((width, height))
|
135 |
+
image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
|
136 |
+
return image
|
137 |
+
|
138 |
+
def process_images(self, images):
|
139 |
+
images = [self.process_image(image) for image in images]
|
140 |
+
images = torch.stack(images)
|
141 |
+
return images
|
142 |
+
|
143 |
+
def decode_images(self, images):
|
144 |
+
images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
145 |
+
images = [Image.fromarray(image) for image in images]
|
146 |
+
return images
|
147 |
+
|
148 |
+
def add_interpolated_images(self, images, interpolated_images):
|
149 |
+
output_images = []
|
150 |
+
for image, interpolated_image in zip(images, interpolated_images):
|
151 |
+
output_images.append(image)
|
152 |
+
output_images.append(interpolated_image)
|
153 |
+
output_images.append(images[-1])
|
154 |
+
return output_images
|
155 |
+
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def interpolate_(self, images, scale=1.0):
|
159 |
+
input_tensor = self.process_images(images)
|
160 |
+
input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
|
161 |
+
input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
162 |
+
flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
|
163 |
+
output_images = self.decode_images(merged[2].cpu())
|
164 |
+
if output_images[0].size != images[0].size:
|
165 |
+
output_images = [image.resize(images[0].size) for image in output_images]
|
166 |
+
return output_images
|
167 |
+
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
|
171 |
+
# Preprocess
|
172 |
+
processed_images = self.process_images(images)
|
173 |
+
|
174 |
+
for iter in range(num_iter):
|
175 |
+
# Input
|
176 |
+
input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
|
177 |
+
|
178 |
+
# Interpolate
|
179 |
+
output_tensor = []
|
180 |
+
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
181 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
182 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
183 |
+
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
184 |
+
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
185 |
+
output_tensor.append(merged[2].cpu())
|
186 |
+
|
187 |
+
# Output
|
188 |
+
output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
|
189 |
+
processed_images = self.add_interpolated_images(processed_images, output_tensor)
|
190 |
+
processed_images = torch.stack(processed_images)
|
191 |
+
|
192 |
+
# To images
|
193 |
+
output_images = self.decode_images(processed_images)
|
194 |
+
if output_images[0].size != images[0].size:
|
195 |
+
output_images = [image.resize(images[0].size) for image in output_images]
|
196 |
+
return output_images
|
197 |
+
|
198 |
+
|
199 |
+
class RIFESmoother(RIFEInterpolater):
|
200 |
+
def __init__(self, model, device="cuda"):
|
201 |
+
super(RIFESmoother, self).__init__(model, device=device)
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
def from_model_manager(model_manager):
|
205 |
+
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
|
206 |
+
|
207 |
+
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
|
208 |
+
output_tensor = []
|
209 |
+
for batch_id in range(0, input_tensor.shape[0], batch_size):
|
210 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
211 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
212 |
+
batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
|
213 |
+
flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
|
214 |
+
output_tensor.append(merged[2].cpu())
|
215 |
+
output_tensor = torch.concat(output_tensor, dim=0)
|
216 |
+
return output_tensor
|
217 |
+
|
218 |
+
@torch.no_grad()
|
219 |
+
def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
|
220 |
+
# Preprocess
|
221 |
+
processed_images = self.process_images(rendered_frames)
|
222 |
+
|
223 |
+
for iter in range(num_iter):
|
224 |
+
# Input
|
225 |
+
input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
|
226 |
+
|
227 |
+
# Interpolate
|
228 |
+
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
229 |
+
|
230 |
+
# Blend
|
231 |
+
input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
|
232 |
+
output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
|
233 |
+
|
234 |
+
# Add to frames
|
235 |
+
processed_images[1:-1] = output_tensor
|
236 |
+
|
237 |
+
# To images
|
238 |
+
output_images = self.decode_images(processed_images)
|
239 |
+
if output_images[0].size != rendered_frames[0].size:
|
240 |
+
output_images = [image.resize(rendered_frames[0].size) for image in output_images]
|
241 |
+
return output_images
|
diffsynth/models/__init__.py
ADDED
@@ -0,0 +1,814 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, os, json
|
2 |
+
from safetensors import safe_open
|
3 |
+
from typing_extensions import Literal, TypeAlias
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from .downloader import download_from_huggingface, download_from_modelscope
|
7 |
+
|
8 |
+
from .sd_text_encoder import SDTextEncoder
|
9 |
+
from .sd_unet import SDUNet
|
10 |
+
from .sd_vae_encoder import SDVAEEncoder
|
11 |
+
from .sd_vae_decoder import SDVAEDecoder
|
12 |
+
from .sd_lora import SDLoRA
|
13 |
+
|
14 |
+
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
15 |
+
from .sdxl_unet import SDXLUNet
|
16 |
+
from .sdxl_vae_decoder import SDXLVAEDecoder
|
17 |
+
from .sdxl_vae_encoder import SDXLVAEEncoder
|
18 |
+
|
19 |
+
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
20 |
+
from .sd3_dit import SD3DiT
|
21 |
+
from .sd3_vae_decoder import SD3VAEDecoder
|
22 |
+
from .sd3_vae_encoder import SD3VAEEncoder
|
23 |
+
|
24 |
+
from .sd_controlnet import SDControlNet
|
25 |
+
|
26 |
+
from .sd_motion import SDMotionModel
|
27 |
+
from .sdxl_motion import SDXLMotionModel
|
28 |
+
|
29 |
+
from .svd_image_encoder import SVDImageEncoder
|
30 |
+
from .svd_unet import SVDUNet
|
31 |
+
from .svd_vae_decoder import SVDVAEDecoder
|
32 |
+
from .svd_vae_encoder import SVDVAEEncoder
|
33 |
+
|
34 |
+
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
35 |
+
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
36 |
+
|
37 |
+
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
38 |
+
from .hunyuan_dit import HunyuanDiT
|
39 |
+
from .kolors_text_encoder import ChatGLMModel
|
40 |
+
|
41 |
+
|
42 |
+
preset_models_on_huggingface = {
|
43 |
+
"HunyuanDiT": [
|
44 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
45 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
46 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
47 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
48 |
+
],
|
49 |
+
"stable-video-diffusion-img2vid-xt": [
|
50 |
+
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
51 |
+
],
|
52 |
+
"ExVideo-SVD-128f-v1": [
|
53 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
54 |
+
],
|
55 |
+
}
|
56 |
+
preset_models_on_modelscope = {
|
57 |
+
# Hunyuan DiT
|
58 |
+
"HunyuanDiT": [
|
59 |
+
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
60 |
+
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
61 |
+
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
62 |
+
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
63 |
+
],
|
64 |
+
# Stable Video Diffusion
|
65 |
+
"stable-video-diffusion-img2vid-xt": [
|
66 |
+
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
67 |
+
],
|
68 |
+
# ExVideo
|
69 |
+
"ExVideo-SVD-128f-v1": [
|
70 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
71 |
+
],
|
72 |
+
# Stable Diffusion
|
73 |
+
"StableDiffusion_v15": [
|
74 |
+
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
75 |
+
],
|
76 |
+
"DreamShaper_8": [
|
77 |
+
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
78 |
+
],
|
79 |
+
"AingDiffusion_v12": [
|
80 |
+
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
81 |
+
],
|
82 |
+
"Flat2DAnimerge_v45Sharp": [
|
83 |
+
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
84 |
+
],
|
85 |
+
# Textual Inversion
|
86 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
87 |
+
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
88 |
+
],
|
89 |
+
# Stable Diffusion XL
|
90 |
+
"StableDiffusionXL_v1": [
|
91 |
+
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
92 |
+
],
|
93 |
+
"BluePencilXL_v200": [
|
94 |
+
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
95 |
+
],
|
96 |
+
"StableDiffusionXL_Turbo": [
|
97 |
+
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
98 |
+
],
|
99 |
+
# Stable Diffusion 3
|
100 |
+
"StableDiffusion3": [
|
101 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
102 |
+
],
|
103 |
+
"StableDiffusion3_without_T5": [
|
104 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
105 |
+
],
|
106 |
+
# ControlNet
|
107 |
+
"ControlNet_v11f1p_sd15_depth": [
|
108 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
109 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
110 |
+
],
|
111 |
+
"ControlNet_v11p_sd15_softedge": [
|
112 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
113 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
114 |
+
],
|
115 |
+
"ControlNet_v11f1e_sd15_tile": [
|
116 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
117 |
+
],
|
118 |
+
"ControlNet_v11p_sd15_lineart": [
|
119 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
120 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
121 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
122 |
+
],
|
123 |
+
# AnimateDiff
|
124 |
+
"AnimateDiff_v2": [
|
125 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
126 |
+
],
|
127 |
+
"AnimateDiff_xl_beta": [
|
128 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
129 |
+
],
|
130 |
+
# RIFE
|
131 |
+
"RIFE": [
|
132 |
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
133 |
+
],
|
134 |
+
# Beautiful Prompt
|
135 |
+
"BeautifulPrompt": [
|
136 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
137 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
138 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
139 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
140 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
141 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
142 |
+
],
|
143 |
+
# Translator
|
144 |
+
"opus-mt-zh-en": [
|
145 |
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
146 |
+
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
147 |
+
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
148 |
+
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
149 |
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
150 |
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
151 |
+
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
152 |
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
153 |
+
],
|
154 |
+
# IP-Adapter
|
155 |
+
"IP-Adapter-SD": [
|
156 |
+
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
157 |
+
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
158 |
+
],
|
159 |
+
"IP-Adapter-SDXL": [
|
160 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
161 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
162 |
+
],
|
163 |
+
# Kolors
|
164 |
+
"Kolors": [
|
165 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
166 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
167 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
168 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
169 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
170 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
171 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
172 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
173 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
174 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
175 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
176 |
+
],
|
177 |
+
"SDXL-vae-fp16-fix": [
|
178 |
+
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
179 |
+
],
|
180 |
+
}
|
181 |
+
Preset_model_id: TypeAlias = Literal[
|
182 |
+
"HunyuanDiT",
|
183 |
+
"stable-video-diffusion-img2vid-xt",
|
184 |
+
"ExVideo-SVD-128f-v1",
|
185 |
+
"StableDiffusion_v15",
|
186 |
+
"DreamShaper_8",
|
187 |
+
"AingDiffusion_v12",
|
188 |
+
"Flat2DAnimerge_v45Sharp",
|
189 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
190 |
+
"StableDiffusionXL_v1",
|
191 |
+
"BluePencilXL_v200",
|
192 |
+
"StableDiffusionXL_Turbo",
|
193 |
+
"ControlNet_v11f1p_sd15_depth",
|
194 |
+
"ControlNet_v11p_sd15_softedge",
|
195 |
+
"ControlNet_v11f1e_sd15_tile",
|
196 |
+
"ControlNet_v11p_sd15_lineart",
|
197 |
+
"AnimateDiff_v2",
|
198 |
+
"AnimateDiff_xl_beta",
|
199 |
+
"RIFE",
|
200 |
+
"BeautifulPrompt",
|
201 |
+
"opus-mt-zh-en",
|
202 |
+
"IP-Adapter-SD",
|
203 |
+
"IP-Adapter-SDXL",
|
204 |
+
"StableDiffusion3",
|
205 |
+
"StableDiffusion3_without_T5",
|
206 |
+
"Kolors",
|
207 |
+
"SDXL-vae-fp16-fix",
|
208 |
+
]
|
209 |
+
Preset_model_website: TypeAlias = Literal[
|
210 |
+
"HuggingFace",
|
211 |
+
"ModelScope",
|
212 |
+
]
|
213 |
+
website_to_preset_models = {
|
214 |
+
"HuggingFace": preset_models_on_huggingface,
|
215 |
+
"ModelScope": preset_models_on_modelscope,
|
216 |
+
}
|
217 |
+
website_to_download_fn = {
|
218 |
+
"HuggingFace": download_from_huggingface,
|
219 |
+
"ModelScope": download_from_modelscope,
|
220 |
+
}
|
221 |
+
|
222 |
+
|
223 |
+
def download_models(
|
224 |
+
model_id_list: List[Preset_model_id] = [],
|
225 |
+
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
226 |
+
):
|
227 |
+
downloaded_files = []
|
228 |
+
for model_id in model_id_list:
|
229 |
+
for website in downloading_priority:
|
230 |
+
if model_id in website_to_preset_models[website]:
|
231 |
+
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
232 |
+
# Check if the file is downloaded.
|
233 |
+
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
234 |
+
if file_to_download in downloaded_files:
|
235 |
+
continue
|
236 |
+
# Download
|
237 |
+
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
238 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
239 |
+
downloaded_files.append(file_to_download)
|
240 |
+
return downloaded_files
|
241 |
+
|
242 |
+
|
243 |
+
class ModelManager:
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
torch_dtype=torch.float16,
|
247 |
+
device="cuda",
|
248 |
+
model_id_list: List[Preset_model_id] = [],
|
249 |
+
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
250 |
+
file_path_list: List[str] = [],
|
251 |
+
):
|
252 |
+
self.torch_dtype = torch_dtype
|
253 |
+
self.device = device
|
254 |
+
self.model = {}
|
255 |
+
self.model_path = {}
|
256 |
+
self.textual_inversion_dict = {}
|
257 |
+
downloaded_files = download_models(model_id_list, downloading_priority)
|
258 |
+
self.load_models(downloaded_files + file_path_list)
|
259 |
+
|
260 |
+
def load_model_from_origin(
|
261 |
+
self,
|
262 |
+
download_from: Preset_model_website = "ModelScope",
|
263 |
+
model_id = "",
|
264 |
+
origin_file_path = "",
|
265 |
+
local_dir = ""
|
266 |
+
):
|
267 |
+
website_to_download_fn[download_from](model_id, origin_file_path, local_dir)
|
268 |
+
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
269 |
+
self.load_model(file_to_download)
|
270 |
+
|
271 |
+
def is_stable_video_diffusion(self, state_dict):
|
272 |
+
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
273 |
+
return param_name in state_dict
|
274 |
+
|
275 |
+
def is_RIFE(self, state_dict):
|
276 |
+
param_name = "block_tea.convblock3.0.1.weight"
|
277 |
+
return param_name in state_dict or ("module." + param_name) in state_dict
|
278 |
+
|
279 |
+
def is_beautiful_prompt(self, state_dict):
|
280 |
+
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
281 |
+
return param_name in state_dict
|
282 |
+
|
283 |
+
def is_stabe_diffusion_xl(self, state_dict):
|
284 |
+
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
285 |
+
return param_name in state_dict
|
286 |
+
|
287 |
+
def is_stable_diffusion(self, state_dict):
|
288 |
+
if self.is_stabe_diffusion_xl(state_dict):
|
289 |
+
return False
|
290 |
+
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
291 |
+
return param_name in state_dict
|
292 |
+
|
293 |
+
def is_controlnet(self, state_dict):
|
294 |
+
param_name = "control_model.time_embed.0.weight"
|
295 |
+
return param_name in state_dict
|
296 |
+
|
297 |
+
def is_animatediff(self, state_dict):
|
298 |
+
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
299 |
+
return param_name in state_dict
|
300 |
+
|
301 |
+
def is_animatediff_xl(self, state_dict):
|
302 |
+
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
303 |
+
return param_name in state_dict
|
304 |
+
|
305 |
+
def is_sd_lora(self, state_dict):
|
306 |
+
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
307 |
+
return param_name in state_dict
|
308 |
+
|
309 |
+
def is_translator(self, state_dict):
|
310 |
+
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
311 |
+
return param_name in state_dict and len(state_dict) == 258
|
312 |
+
|
313 |
+
def is_ipadapter(self, state_dict):
|
314 |
+
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
|
315 |
+
|
316 |
+
def is_ipadapter_image_encoder(self, state_dict):
|
317 |
+
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
|
318 |
+
return param_name in state_dict and len(state_dict) == 521
|
319 |
+
|
320 |
+
def is_ipadapter_xl(self, state_dict):
|
321 |
+
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
|
322 |
+
|
323 |
+
def is_ipadapter_xl_image_encoder(self, state_dict):
|
324 |
+
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
325 |
+
return param_name in state_dict and len(state_dict) == 777
|
326 |
+
|
327 |
+
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
328 |
+
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
329 |
+
return param_name in state_dict
|
330 |
+
|
331 |
+
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
332 |
+
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
333 |
+
param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
334 |
+
return param_name in state_dict and param_name_ in state_dict
|
335 |
+
|
336 |
+
def is_hunyuan_dit(self, state_dict):
|
337 |
+
param_name = "final_layer.adaLN_modulation.1.weight"
|
338 |
+
return param_name in state_dict
|
339 |
+
|
340 |
+
def is_diffusers_vae(self, state_dict):
|
341 |
+
param_name = "quant_conv.weight"
|
342 |
+
return param_name in state_dict
|
343 |
+
|
344 |
+
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
345 |
+
param_name = "blocks.185.positional_embedding.embeddings"
|
346 |
+
return param_name in state_dict
|
347 |
+
|
348 |
+
def is_stable_diffusion_3(self, state_dict):
|
349 |
+
param_names = [
|
350 |
+
"text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
|
351 |
+
"text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
|
352 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
|
353 |
+
"first_stage_model.encoder.mid.block_2.norm2.weight",
|
354 |
+
"first_stage_model.decoder.mid.block_2.norm2.weight",
|
355 |
+
]
|
356 |
+
for param_name in param_names:
|
357 |
+
if param_name not in state_dict:
|
358 |
+
return False
|
359 |
+
return True
|
360 |
+
|
361 |
+
def is_stable_diffusion_3_t5(self, state_dict):
|
362 |
+
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
363 |
+
return param_name in state_dict
|
364 |
+
|
365 |
+
def is_kolors_text_encoder(self, file_path):
|
366 |
+
file_list = os.listdir(file_path)
|
367 |
+
if "config.json" in file_list:
|
368 |
+
try:
|
369 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
370 |
+
config = json.load(f)
|
371 |
+
if config.get("model_type") == "chatglm":
|
372 |
+
return True
|
373 |
+
except:
|
374 |
+
pass
|
375 |
+
return False
|
376 |
+
|
377 |
+
def is_kolors_unet(self, state_dict):
|
378 |
+
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
|
379 |
+
|
380 |
+
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
381 |
+
component_dict = {
|
382 |
+
"image_encoder": SVDImageEncoder,
|
383 |
+
"unet": SVDUNet,
|
384 |
+
"vae_decoder": SVDVAEDecoder,
|
385 |
+
"vae_encoder": SVDVAEEncoder,
|
386 |
+
}
|
387 |
+
if components is None:
|
388 |
+
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
389 |
+
for component in components:
|
390 |
+
if component == "unet":
|
391 |
+
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
392 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
393 |
+
else:
|
394 |
+
self.model[component] = component_dict[component]()
|
395 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
396 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
397 |
+
self.model_path[component] = file_path
|
398 |
+
|
399 |
+
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
400 |
+
component_dict = {
|
401 |
+
"text_encoder": SDTextEncoder,
|
402 |
+
"unet": SDUNet,
|
403 |
+
"vae_decoder": SDVAEDecoder,
|
404 |
+
"vae_encoder": SDVAEEncoder,
|
405 |
+
"refiner": SDXLUNet,
|
406 |
+
}
|
407 |
+
if components is None:
|
408 |
+
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
409 |
+
for component in components:
|
410 |
+
if component == "text_encoder":
|
411 |
+
# Add additional token embeddings to text encoder
|
412 |
+
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
413 |
+
for keyword in self.textual_inversion_dict:
|
414 |
+
_, embeddings = self.textual_inversion_dict[keyword]
|
415 |
+
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
416 |
+
token_embeddings = torch.concat(token_embeddings, dim=0)
|
417 |
+
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
418 |
+
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
419 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
420 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
421 |
+
else:
|
422 |
+
self.model[component] = component_dict[component]()
|
423 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
424 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
425 |
+
self.model_path[component] = file_path
|
426 |
+
|
427 |
+
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
428 |
+
component_dict = {
|
429 |
+
"text_encoder": SDXLTextEncoder,
|
430 |
+
"text_encoder_2": SDXLTextEncoder2,
|
431 |
+
"unet": SDXLUNet,
|
432 |
+
"vae_decoder": SDXLVAEDecoder,
|
433 |
+
"vae_encoder": SDXLVAEEncoder,
|
434 |
+
}
|
435 |
+
if components is None:
|
436 |
+
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
437 |
+
for component in components:
|
438 |
+
self.model[component] = component_dict[component]()
|
439 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
440 |
+
if component in ["vae_decoder", "vae_encoder"]:
|
441 |
+
# These two model will output nan when float16 is enabled.
|
442 |
+
# The precision problem happens in the last three resnet blocks.
|
443 |
+
# I do not know how to solve this problem.
|
444 |
+
self.model[component].to(torch.float32).to(self.device)
|
445 |
+
else:
|
446 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
447 |
+
self.model_path[component] = file_path
|
448 |
+
|
449 |
+
def load_controlnet(self, state_dict, file_path=""):
|
450 |
+
component = "controlnet"
|
451 |
+
if component not in self.model:
|
452 |
+
self.model[component] = []
|
453 |
+
self.model_path[component] = []
|
454 |
+
model = SDControlNet()
|
455 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
456 |
+
model.to(self.torch_dtype).to(self.device)
|
457 |
+
self.model[component].append(model)
|
458 |
+
self.model_path[component].append(file_path)
|
459 |
+
|
460 |
+
def load_animatediff(self, state_dict, file_path=""):
|
461 |
+
component = "motion_modules"
|
462 |
+
model = SDMotionModel()
|
463 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
464 |
+
model.to(self.torch_dtype).to(self.device)
|
465 |
+
self.model[component] = model
|
466 |
+
self.model_path[component] = file_path
|
467 |
+
|
468 |
+
def load_animatediff_xl(self, state_dict, file_path=""):
|
469 |
+
component = "motion_modules_xl"
|
470 |
+
model = SDXLMotionModel()
|
471 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
472 |
+
model.to(self.torch_dtype).to(self.device)
|
473 |
+
self.model[component] = model
|
474 |
+
self.model_path[component] = file_path
|
475 |
+
|
476 |
+
def load_beautiful_prompt(self, state_dict, file_path=""):
|
477 |
+
component = "beautiful_prompt"
|
478 |
+
from transformers import AutoModelForCausalLM
|
479 |
+
model_folder = os.path.dirname(file_path)
|
480 |
+
model = AutoModelForCausalLM.from_pretrained(
|
481 |
+
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
|
482 |
+
).to(self.device).eval()
|
483 |
+
self.model[component] = model
|
484 |
+
self.model_path[component] = file_path
|
485 |
+
|
486 |
+
def load_RIFE(self, state_dict, file_path=""):
|
487 |
+
component = "RIFE"
|
488 |
+
from ..extensions.RIFE import IFNet
|
489 |
+
model = IFNet().eval()
|
490 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
491 |
+
model.to(torch.float32).to(self.device)
|
492 |
+
self.model[component] = model
|
493 |
+
self.model_path[component] = file_path
|
494 |
+
|
495 |
+
def load_sd_lora(self, state_dict, alpha):
|
496 |
+
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
|
497 |
+
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
|
498 |
+
|
499 |
+
def load_translator(self, state_dict, file_path=""):
|
500 |
+
# This model is lightweight, we do not place it on GPU.
|
501 |
+
component = "translator"
|
502 |
+
from transformers import AutoModelForSeq2SeqLM
|
503 |
+
model_folder = os.path.dirname(file_path)
|
504 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
|
505 |
+
self.model[component] = model
|
506 |
+
self.model_path[component] = file_path
|
507 |
+
|
508 |
+
def load_ipadapter(self, state_dict, file_path=""):
|
509 |
+
component = "ipadapter"
|
510 |
+
model = SDIpAdapter()
|
511 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
512 |
+
model.to(self.torch_dtype).to(self.device)
|
513 |
+
self.model[component] = model
|
514 |
+
self.model_path[component] = file_path
|
515 |
+
|
516 |
+
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
|
517 |
+
component = "ipadapter_image_encoder"
|
518 |
+
model = IpAdapterCLIPImageEmbedder()
|
519 |
+
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
520 |
+
model.to(self.torch_dtype).to(self.device)
|
521 |
+
self.model[component] = model
|
522 |
+
self.model_path[component] = file_path
|
523 |
+
|
524 |
+
def load_ipadapter_xl(self, state_dict, file_path=""):
|
525 |
+
component = "ipadapter_xl"
|
526 |
+
model = SDXLIpAdapter()
|
527 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
528 |
+
model.to(self.torch_dtype).to(self.device)
|
529 |
+
self.model[component] = model
|
530 |
+
self.model_path[component] = file_path
|
531 |
+
|
532 |
+
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
|
533 |
+
component = "ipadapter_xl_image_encoder"
|
534 |
+
model = IpAdapterXLCLIPImageEmbedder()
|
535 |
+
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
536 |
+
model.to(self.torch_dtype).to(self.device)
|
537 |
+
self.model[component] = model
|
538 |
+
self.model_path[component] = file_path
|
539 |
+
|
540 |
+
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
541 |
+
component = "hunyuan_dit_clip_text_encoder"
|
542 |
+
model = HunyuanDiTCLIPTextEncoder()
|
543 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
544 |
+
model.to(self.torch_dtype).to(self.device)
|
545 |
+
self.model[component] = model
|
546 |
+
self.model_path[component] = file_path
|
547 |
+
|
548 |
+
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
549 |
+
component = "hunyuan_dit_t5_text_encoder"
|
550 |
+
model = HunyuanDiTT5TextEncoder()
|
551 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
552 |
+
model.to(self.torch_dtype).to(self.device)
|
553 |
+
self.model[component] = model
|
554 |
+
self.model_path[component] = file_path
|
555 |
+
|
556 |
+
def load_hunyuan_dit(self, state_dict, file_path=""):
|
557 |
+
component = "hunyuan_dit"
|
558 |
+
model = HunyuanDiT()
|
559 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
560 |
+
model.to(self.torch_dtype).to(self.device)
|
561 |
+
self.model[component] = model
|
562 |
+
self.model_path[component] = file_path
|
563 |
+
|
564 |
+
def load_diffusers_vae(self, state_dict, file_path=""):
|
565 |
+
# TODO: detect SD and SDXL
|
566 |
+
component = "vae_encoder"
|
567 |
+
model = SDXLVAEEncoder()
|
568 |
+
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
569 |
+
model.to(torch.float32).to(self.device)
|
570 |
+
self.model[component] = model
|
571 |
+
self.model_path[component] = file_path
|
572 |
+
component = "vae_decoder"
|
573 |
+
model = SDXLVAEDecoder()
|
574 |
+
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
575 |
+
model.to(torch.float32).to(self.device)
|
576 |
+
self.model[component] = model
|
577 |
+
self.model_path[component] = file_path
|
578 |
+
|
579 |
+
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
|
580 |
+
unet_state_dict = self.model["unet"].state_dict()
|
581 |
+
self.model["unet"].to("cpu")
|
582 |
+
del self.model["unet"]
|
583 |
+
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
|
584 |
+
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
|
585 |
+
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
|
586 |
+
self.model["unet"].load_state_dict(state_dict, strict=False)
|
587 |
+
self.model["unet"].to(self.torch_dtype).to(self.device)
|
588 |
+
|
589 |
+
def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
|
590 |
+
component_dict = {
|
591 |
+
"sd3_text_encoder_1": SD3TextEncoder1,
|
592 |
+
"sd3_text_encoder_2": SD3TextEncoder2,
|
593 |
+
"sd3_text_encoder_3": SD3TextEncoder3,
|
594 |
+
"sd3_dit": SD3DiT,
|
595 |
+
"sd3_vae_decoder": SD3VAEDecoder,
|
596 |
+
"sd3_vae_encoder": SD3VAEEncoder,
|
597 |
+
}
|
598 |
+
if components is None:
|
599 |
+
components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
|
600 |
+
for component in components:
|
601 |
+
if component == "sd3_text_encoder_3":
|
602 |
+
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
|
603 |
+
continue
|
604 |
+
if component == "sd3_text_encoder_1":
|
605 |
+
# Add additional token embeddings to text encoder
|
606 |
+
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
|
607 |
+
for keyword in self.textual_inversion_dict:
|
608 |
+
_, embeddings = self.textual_inversion_dict[keyword]
|
609 |
+
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
610 |
+
token_embeddings = torch.concat(token_embeddings, dim=0)
|
611 |
+
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
612 |
+
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
613 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
614 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
615 |
+
else:
|
616 |
+
self.model[component] = component_dict[component]()
|
617 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
618 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
619 |
+
self.model_path[component] = file_path
|
620 |
+
|
621 |
+
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
|
622 |
+
component = "sd3_text_encoder_3"
|
623 |
+
model = SD3TextEncoder3()
|
624 |
+
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
625 |
+
model.to(self.torch_dtype).to(self.device)
|
626 |
+
self.model[component] = model
|
627 |
+
self.model_path[component] = file_path
|
628 |
+
|
629 |
+
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
|
630 |
+
component = "kolors_text_encoder"
|
631 |
+
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
|
632 |
+
model = model.to(dtype=self.torch_dtype, device=self.device)
|
633 |
+
self.model[component] = model
|
634 |
+
self.model_path[component] = file_path
|
635 |
+
|
636 |
+
def load_kolors_unet(self, state_dict, file_path=""):
|
637 |
+
component = "kolors_unet"
|
638 |
+
model = SDXLUNet(is_kolors=True)
|
639 |
+
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
640 |
+
model.to(self.torch_dtype).to(self.device)
|
641 |
+
self.model[component] = model
|
642 |
+
self.model_path[component] = file_path
|
643 |
+
|
644 |
+
def search_for_embeddings(self, state_dict):
|
645 |
+
embeddings = []
|
646 |
+
for k in state_dict:
|
647 |
+
if isinstance(state_dict[k], torch.Tensor):
|
648 |
+
embeddings.append(state_dict[k])
|
649 |
+
elif isinstance(state_dict[k], dict):
|
650 |
+
embeddings += self.search_for_embeddings(state_dict[k])
|
651 |
+
return embeddings
|
652 |
+
|
653 |
+
def load_textual_inversions(self, folder):
|
654 |
+
# Store additional tokens here
|
655 |
+
self.textual_inversion_dict = {}
|
656 |
+
|
657 |
+
# Load every textual inversion file
|
658 |
+
for file_name in os.listdir(folder):
|
659 |
+
if os.path.isdir(os.path.join(folder, file_name)) or \
|
660 |
+
not (file_name.endswith(".bin") or \
|
661 |
+
file_name.endswith(".safetensors") or \
|
662 |
+
file_name.endswith(".pth") or \
|
663 |
+
file_name.endswith(".pt")):
|
664 |
+
continue
|
665 |
+
keyword = os.path.splitext(file_name)[0]
|
666 |
+
state_dict = load_state_dict(os.path.join(folder, file_name))
|
667 |
+
|
668 |
+
# Search for embeddings
|
669 |
+
for embeddings in self.search_for_embeddings(state_dict):
|
670 |
+
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
671 |
+
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
672 |
+
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
673 |
+
break
|
674 |
+
|
675 |
+
def load_model(self, file_path, components=None, lora_alphas=[]):
|
676 |
+
if os.path.isdir(file_path):
|
677 |
+
if self.is_kolors_text_encoder(file_path):
|
678 |
+
self.load_kolors_text_encoder(file_path=file_path)
|
679 |
+
return
|
680 |
+
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
681 |
+
if self.is_stable_video_diffusion(state_dict):
|
682 |
+
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
683 |
+
elif self.is_animatediff(state_dict):
|
684 |
+
self.load_animatediff(state_dict, file_path=file_path)
|
685 |
+
elif self.is_animatediff_xl(state_dict):
|
686 |
+
self.load_animatediff_xl(state_dict, file_path=file_path)
|
687 |
+
elif self.is_controlnet(state_dict):
|
688 |
+
self.load_controlnet(state_dict, file_path=file_path)
|
689 |
+
elif self.is_stabe_diffusion_xl(state_dict):
|
690 |
+
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
691 |
+
elif self.is_stable_diffusion(state_dict):
|
692 |
+
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
693 |
+
elif self.is_sd_lora(state_dict):
|
694 |
+
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
695 |
+
elif self.is_beautiful_prompt(state_dict):
|
696 |
+
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
697 |
+
elif self.is_RIFE(state_dict):
|
698 |
+
self.load_RIFE(state_dict, file_path=file_path)
|
699 |
+
elif self.is_translator(state_dict):
|
700 |
+
self.load_translator(state_dict, file_path=file_path)
|
701 |
+
elif self.is_ipadapter(state_dict):
|
702 |
+
self.load_ipadapter(state_dict, file_path=file_path)
|
703 |
+
elif self.is_ipadapter_image_encoder(state_dict):
|
704 |
+
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
|
705 |
+
elif self.is_ipadapter_xl(state_dict):
|
706 |
+
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
707 |
+
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
708 |
+
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
709 |
+
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
710 |
+
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
711 |
+
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
712 |
+
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
713 |
+
elif self.is_hunyuan_dit(state_dict):
|
714 |
+
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
715 |
+
elif self.is_diffusers_vae(state_dict):
|
716 |
+
self.load_diffusers_vae(state_dict, file_path=file_path)
|
717 |
+
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
718 |
+
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
719 |
+
elif self.is_stable_diffusion_3(state_dict):
|
720 |
+
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
|
721 |
+
elif self.is_stable_diffusion_3_t5(state_dict):
|
722 |
+
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
|
723 |
+
elif self.is_kolors_unet(state_dict):
|
724 |
+
self.load_kolors_unet(state_dict, file_path=file_path)
|
725 |
+
|
726 |
+
def load_models(self, file_path_list, lora_alphas=[]):
|
727 |
+
for file_path in file_path_list:
|
728 |
+
self.load_model(file_path, lora_alphas=lora_alphas)
|
729 |
+
|
730 |
+
def to(self, device):
|
731 |
+
for component in self.model:
|
732 |
+
if isinstance(self.model[component], list):
|
733 |
+
for model in self.model[component]:
|
734 |
+
model.to(device)
|
735 |
+
else:
|
736 |
+
self.model[component].to(device)
|
737 |
+
torch.cuda.empty_cache()
|
738 |
+
|
739 |
+
def get_model_with_model_path(self, model_path):
|
740 |
+
for component in self.model_path:
|
741 |
+
if isinstance(self.model_path[component], str):
|
742 |
+
if os.path.samefile(self.model_path[component], model_path):
|
743 |
+
return self.model[component]
|
744 |
+
elif isinstance(self.model_path[component], list):
|
745 |
+
for i, model_path_ in enumerate(self.model_path[component]):
|
746 |
+
if os.path.samefile(model_path_, model_path):
|
747 |
+
return self.model[component][i]
|
748 |
+
raise ValueError(f"Please load model {model_path} before you use it.")
|
749 |
+
|
750 |
+
def __getattr__(self, __name):
|
751 |
+
if __name in self.model:
|
752 |
+
return self.model[__name]
|
753 |
+
else:
|
754 |
+
return super.__getattribute__(__name)
|
755 |
+
|
756 |
+
|
757 |
+
def load_state_dict(file_path, torch_dtype=None):
|
758 |
+
if file_path.endswith(".safetensors"):
|
759 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
760 |
+
else:
|
761 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
762 |
+
|
763 |
+
|
764 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
765 |
+
state_dict = {}
|
766 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
767 |
+
for k in f.keys():
|
768 |
+
state_dict[k] = f.get_tensor(k)
|
769 |
+
if torch_dtype is not None:
|
770 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
771 |
+
return state_dict
|
772 |
+
|
773 |
+
|
774 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
775 |
+
state_dict = torch.load(file_path, map_location="cpu")
|
776 |
+
if torch_dtype is not None:
|
777 |
+
for i in state_dict:
|
778 |
+
if isinstance(state_dict[i], torch.Tensor):
|
779 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
780 |
+
return state_dict
|
781 |
+
|
782 |
+
|
783 |
+
def search_parameter(param, state_dict):
|
784 |
+
for name, param_ in state_dict.items():
|
785 |
+
if param.numel() == param_.numel():
|
786 |
+
if param.shape == param_.shape:
|
787 |
+
if torch.dist(param, param_) < 1e-6:
|
788 |
+
return name
|
789 |
+
else:
|
790 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
791 |
+
return name
|
792 |
+
return None
|
793 |
+
|
794 |
+
|
795 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
796 |
+
matched_keys = set()
|
797 |
+
with torch.no_grad():
|
798 |
+
for name in source_state_dict:
|
799 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
800 |
+
if rename is not None:
|
801 |
+
print(f'"{name}": "{rename}",')
|
802 |
+
matched_keys.add(rename)
|
803 |
+
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
804 |
+
length = source_state_dict[name].shape[0] // 3
|
805 |
+
rename = []
|
806 |
+
for i in range(3):
|
807 |
+
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
808 |
+
if None not in rename:
|
809 |
+
print(f'"{name}": {rename},')
|
810 |
+
for rename_ in rename:
|
811 |
+
matched_keys.add(rename_)
|
812 |
+
for name in target_state_dict:
|
813 |
+
if name not in matched_keys:
|
814 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
diffsynth/models/attention.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
|
5 |
+
def low_version_attention(query, key, value, attn_bias=None):
|
6 |
+
scale = 1 / query.shape[-1] ** 0.5
|
7 |
+
query = query * scale
|
8 |
+
attn = torch.matmul(query, key.transpose(-2, -1))
|
9 |
+
if attn_bias is not None:
|
10 |
+
attn = attn + attn_bias
|
11 |
+
attn = attn.softmax(-1)
|
12 |
+
return attn @ value
|
13 |
+
|
14 |
+
|
15 |
+
class Attention(torch.nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
18 |
+
super().__init__()
|
19 |
+
dim_inner = head_dim * num_heads
|
20 |
+
kv_dim = kv_dim if kv_dim is not None else q_dim
|
21 |
+
self.num_heads = num_heads
|
22 |
+
self.head_dim = head_dim
|
23 |
+
|
24 |
+
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
25 |
+
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
26 |
+
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
27 |
+
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
28 |
+
|
29 |
+
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
30 |
+
batch_size = q.shape[0]
|
31 |
+
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
32 |
+
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
33 |
+
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
34 |
+
hidden_states = hidden_states + scale * ip_hidden_states
|
35 |
+
return hidden_states
|
36 |
+
|
37 |
+
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
38 |
+
if encoder_hidden_states is None:
|
39 |
+
encoder_hidden_states = hidden_states
|
40 |
+
|
41 |
+
batch_size = encoder_hidden_states.shape[0]
|
42 |
+
|
43 |
+
q = self.to_q(hidden_states)
|
44 |
+
k = self.to_k(encoder_hidden_states)
|
45 |
+
v = self.to_v(encoder_hidden_states)
|
46 |
+
|
47 |
+
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
48 |
+
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
49 |
+
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
50 |
+
|
51 |
+
if qkv_preprocessor is not None:
|
52 |
+
q, k, v = qkv_preprocessor(q, k, v)
|
53 |
+
|
54 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
55 |
+
if ipadapter_kwargs is not None:
|
56 |
+
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
57 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
58 |
+
hidden_states = hidden_states.to(q.dtype)
|
59 |
+
|
60 |
+
hidden_states = self.to_out(hidden_states)
|
61 |
+
|
62 |
+
return hidden_states
|
63 |
+
|
64 |
+
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
65 |
+
if encoder_hidden_states is None:
|
66 |
+
encoder_hidden_states = hidden_states
|
67 |
+
|
68 |
+
q = self.to_q(hidden_states)
|
69 |
+
k = self.to_k(encoder_hidden_states)
|
70 |
+
v = self.to_v(encoder_hidden_states)
|
71 |
+
|
72 |
+
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
73 |
+
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
74 |
+
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
75 |
+
|
76 |
+
if attn_mask is not None:
|
77 |
+
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
78 |
+
else:
|
79 |
+
import xformers.ops as xops
|
80 |
+
hidden_states = xops.memory_efficient_attention(q, k, v)
|
81 |
+
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
82 |
+
|
83 |
+
hidden_states = hidden_states.to(q.dtype)
|
84 |
+
hidden_states = self.to_out(hidden_states)
|
85 |
+
|
86 |
+
return hidden_states
|
87 |
+
|
88 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
89 |
+
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
diffsynth/models/downloader.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import hf_hub_download
|
2 |
+
from modelscope import snapshot_download
|
3 |
+
import os, shutil
|
4 |
+
|
5 |
+
|
6 |
+
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
7 |
+
os.makedirs(local_dir, exist_ok=True)
|
8 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
9 |
+
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
10 |
+
return
|
11 |
+
else:
|
12 |
+
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
13 |
+
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
14 |
+
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
15 |
+
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
16 |
+
if downloaded_file_path != target_file_path:
|
17 |
+
shutil.move(downloaded_file_path, target_file_path)
|
18 |
+
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
19 |
+
|
20 |
+
|
21 |
+
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
22 |
+
os.makedirs(local_dir, exist_ok=True)
|
23 |
+
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
24 |
+
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
25 |
+
return
|
26 |
+
else:
|
27 |
+
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
28 |
+
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
diffsynth/models/hunyuan_dit.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .attention import Attention
|
2 |
+
from .tiler import TileWorker
|
3 |
+
from einops import repeat, rearrange
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
11 |
+
super().__init__()
|
12 |
+
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
13 |
+
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
14 |
+
self.rotary_emb_on_k = rotary_emb_on_k
|
15 |
+
self.k_cache, self.v_cache = [], []
|
16 |
+
|
17 |
+
def reshape_for_broadcast(self, freqs_cis, x):
|
18 |
+
ndim = x.ndim
|
19 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
20 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
21 |
+
|
22 |
+
def rotate_half(self, x):
|
23 |
+
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
24 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
25 |
+
|
26 |
+
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
27 |
+
xk_out = None
|
28 |
+
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
29 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
30 |
+
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
31 |
+
if xk is not None:
|
32 |
+
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
33 |
+
return xq_out, xk_out
|
34 |
+
|
35 |
+
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
36 |
+
# norm
|
37 |
+
q = self.q_norm(q)
|
38 |
+
k = self.k_norm(k)
|
39 |
+
|
40 |
+
# RoPE
|
41 |
+
if self.rotary_emb_on_k:
|
42 |
+
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
43 |
+
else:
|
44 |
+
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
45 |
+
|
46 |
+
if to_cache:
|
47 |
+
self.k_cache.append(k)
|
48 |
+
self.v_cache.append(v)
|
49 |
+
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
50 |
+
k = torch.concat([k] + self.k_cache, dim=2)
|
51 |
+
v = torch.concat([v] + self.v_cache, dim=2)
|
52 |
+
self.k_cache, self.v_cache = [], []
|
53 |
+
return q, k, v
|
54 |
+
|
55 |
+
|
56 |
+
class FP32_Layernorm(torch.nn.LayerNorm):
|
57 |
+
def forward(self, inputs):
|
58 |
+
origin_dtype = inputs.dtype
|
59 |
+
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
60 |
+
|
61 |
+
|
62 |
+
class FP32_SiLU(torch.nn.SiLU):
|
63 |
+
def forward(self, inputs):
|
64 |
+
origin_dtype = inputs.dtype
|
65 |
+
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
66 |
+
|
67 |
+
|
68 |
+
class HunyuanDiTFinalLayer(torch.nn.Module):
|
69 |
+
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
70 |
+
super().__init__()
|
71 |
+
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
72 |
+
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
73 |
+
self.adaLN_modulation = torch.nn.Sequential(
|
74 |
+
FP32_SiLU(),
|
75 |
+
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
76 |
+
)
|
77 |
+
|
78 |
+
def modulate(self, x, shift, scale):
|
79 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
80 |
+
|
81 |
+
def forward(self, hidden_states, condition_emb):
|
82 |
+
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
83 |
+
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
84 |
+
hidden_states = self.linear(hidden_states)
|
85 |
+
return hidden_states
|
86 |
+
|
87 |
+
|
88 |
+
class HunyuanDiTBlock(torch.nn.Module):
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
hidden_dim=1408,
|
93 |
+
condition_dim=1408,
|
94 |
+
num_heads=16,
|
95 |
+
mlp_ratio=4.3637,
|
96 |
+
text_dim=1024,
|
97 |
+
skip_connection=False
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
101 |
+
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
102 |
+
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
103 |
+
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
104 |
+
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
105 |
+
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
106 |
+
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
107 |
+
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
108 |
+
self.mlp = torch.nn.Sequential(
|
109 |
+
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
110 |
+
torch.nn.GELU(approximate="tanh"),
|
111 |
+
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
112 |
+
)
|
113 |
+
if skip_connection:
|
114 |
+
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
115 |
+
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
116 |
+
else:
|
117 |
+
self.skip_norm, self.skip_linear = None, None
|
118 |
+
|
119 |
+
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
120 |
+
# Long Skip Connection
|
121 |
+
if self.skip_norm is not None and self.skip_linear is not None:
|
122 |
+
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
123 |
+
hidden_states = self.skip_norm(hidden_states)
|
124 |
+
hidden_states = self.skip_linear(hidden_states)
|
125 |
+
|
126 |
+
# Self-Attention
|
127 |
+
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
128 |
+
attn_input = self.norm1(hidden_states) + shift_msa
|
129 |
+
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
130 |
+
|
131 |
+
# Cross-Attention
|
132 |
+
attn_input = self.norm3(hidden_states)
|
133 |
+
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
134 |
+
|
135 |
+
# FFN Layer
|
136 |
+
mlp_input = self.norm2(hidden_states)
|
137 |
+
hidden_states = hidden_states + self.mlp(mlp_input)
|
138 |
+
return hidden_states
|
139 |
+
|
140 |
+
|
141 |
+
class AttentionPool(torch.nn.Module):
|
142 |
+
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
143 |
+
super().__init__()
|
144 |
+
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
145 |
+
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
146 |
+
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
147 |
+
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
148 |
+
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
149 |
+
self.num_heads = num_heads
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
x = x.permute(1, 0, 2) # NLC -> LNC
|
153 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
154 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
155 |
+
x, _ = torch.nn.functional.multi_head_attention_forward(
|
156 |
+
query=x[:1], key=x, value=x,
|
157 |
+
embed_dim_to_check=x.shape[-1],
|
158 |
+
num_heads=self.num_heads,
|
159 |
+
q_proj_weight=self.q_proj.weight,
|
160 |
+
k_proj_weight=self.k_proj.weight,
|
161 |
+
v_proj_weight=self.v_proj.weight,
|
162 |
+
in_proj_weight=None,
|
163 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
164 |
+
bias_k=None,
|
165 |
+
bias_v=None,
|
166 |
+
add_zero_attn=False,
|
167 |
+
dropout_p=0,
|
168 |
+
out_proj_weight=self.c_proj.weight,
|
169 |
+
out_proj_bias=self.c_proj.bias,
|
170 |
+
use_separate_proj_weight=True,
|
171 |
+
training=self.training,
|
172 |
+
need_weights=False
|
173 |
+
)
|
174 |
+
return x.squeeze(0)
|
175 |
+
|
176 |
+
|
177 |
+
class PatchEmbed(torch.nn.Module):
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
patch_size=(2, 2),
|
181 |
+
in_chans=4,
|
182 |
+
embed_dim=1408,
|
183 |
+
bias=True,
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
x = self.proj(x)
|
190 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
195 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
196 |
+
if not repeat_only:
|
197 |
+
half = dim // 2
|
198 |
+
freqs = torch.exp(
|
199 |
+
-math.log(max_period)
|
200 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
201 |
+
/ half
|
202 |
+
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
203 |
+
args = t[:, None].float() * freqs[None]
|
204 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
205 |
+
if dim % 2:
|
206 |
+
embedding = torch.cat(
|
207 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
208 |
+
)
|
209 |
+
else:
|
210 |
+
embedding = repeat(t, "b -> b d", d=dim)
|
211 |
+
return embedding
|
212 |
+
|
213 |
+
|
214 |
+
class TimestepEmbedder(torch.nn.Module):
|
215 |
+
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
216 |
+
super().__init__()
|
217 |
+
self.mlp = torch.nn.Sequential(
|
218 |
+
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
219 |
+
torch.nn.SiLU(),
|
220 |
+
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
221 |
+
)
|
222 |
+
self.frequency_embedding_size = frequency_embedding_size
|
223 |
+
|
224 |
+
def forward(self, t):
|
225 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
226 |
+
t_emb = self.mlp(t_freq)
|
227 |
+
return t_emb
|
228 |
+
|
229 |
+
|
230 |
+
class HunyuanDiT(torch.nn.Module):
|
231 |
+
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
232 |
+
super().__init__()
|
233 |
+
|
234 |
+
# Embedders
|
235 |
+
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
236 |
+
self.t5_embedder = torch.nn.Sequential(
|
237 |
+
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
238 |
+
FP32_SiLU(),
|
239 |
+
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
240 |
+
)
|
241 |
+
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
242 |
+
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
243 |
+
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
244 |
+
self.timestep_embedder = TimestepEmbedder()
|
245 |
+
self.extra_embedder = torch.nn.Sequential(
|
246 |
+
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
247 |
+
FP32_SiLU(),
|
248 |
+
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
249 |
+
)
|
250 |
+
|
251 |
+
# Transformer blocks
|
252 |
+
self.num_layers_down = num_layers_down
|
253 |
+
self.num_layers_up = num_layers_up
|
254 |
+
self.blocks = torch.nn.ModuleList(
|
255 |
+
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
256 |
+
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
257 |
+
)
|
258 |
+
|
259 |
+
# Output layers
|
260 |
+
self.final_layer = HunyuanDiTFinalLayer()
|
261 |
+
self.out_channels = out_channels
|
262 |
+
|
263 |
+
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
264 |
+
text_emb_mask = text_emb_mask.bool()
|
265 |
+
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
266 |
+
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
267 |
+
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
268 |
+
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
269 |
+
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
270 |
+
return text_emb
|
271 |
+
|
272 |
+
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
273 |
+
# Text embedding
|
274 |
+
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
275 |
+
|
276 |
+
# Timestep embedding
|
277 |
+
timestep_emb = self.timestep_embedder(timestep)
|
278 |
+
|
279 |
+
# Size embedding
|
280 |
+
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
281 |
+
size_emb = size_emb.view(-1, 6 * 256)
|
282 |
+
|
283 |
+
# Style embedding
|
284 |
+
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
285 |
+
|
286 |
+
# Concatenate all extra vectors
|
287 |
+
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
288 |
+
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
289 |
+
|
290 |
+
return condition_emb
|
291 |
+
|
292 |
+
def unpatchify(self, x, h, w):
|
293 |
+
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
294 |
+
|
295 |
+
def build_mask(self, data, is_bound):
|
296 |
+
_, _, H, W = data.shape
|
297 |
+
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
298 |
+
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
299 |
+
border_width = (H + W) // 4
|
300 |
+
pad = torch.ones_like(h) * border_width
|
301 |
+
mask = torch.stack([
|
302 |
+
pad if is_bound[0] else h + 1,
|
303 |
+
pad if is_bound[1] else H - h,
|
304 |
+
pad if is_bound[2] else w + 1,
|
305 |
+
pad if is_bound[3] else W - w
|
306 |
+
]).min(dim=0).values
|
307 |
+
mask = mask.clip(1, border_width)
|
308 |
+
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
309 |
+
mask = rearrange(mask, "H W -> 1 H W")
|
310 |
+
return mask
|
311 |
+
|
312 |
+
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
313 |
+
B, C, H, W = hidden_states.shape
|
314 |
+
|
315 |
+
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
316 |
+
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
317 |
+
|
318 |
+
# Split tasks
|
319 |
+
tasks = []
|
320 |
+
for h in range(0, H, tile_stride):
|
321 |
+
for w in range(0, W, tile_stride):
|
322 |
+
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
323 |
+
continue
|
324 |
+
h_, w_ = h + tile_size, w + tile_size
|
325 |
+
if h_ > H: h, h_ = H - tile_size, H
|
326 |
+
if w_ > W: w, w_ = W - tile_size, W
|
327 |
+
tasks.append((h, h_, w, w_))
|
328 |
+
|
329 |
+
# Run
|
330 |
+
for hl, hr, wl, wr in tasks:
|
331 |
+
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
332 |
+
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
333 |
+
if residual is not None:
|
334 |
+
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
335 |
+
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
336 |
+
else:
|
337 |
+
residual_batch = None
|
338 |
+
|
339 |
+
# Forward
|
340 |
+
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
341 |
+
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
342 |
+
|
343 |
+
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
344 |
+
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
345 |
+
weight[:, :, hl:hr, wl:wr] += mask
|
346 |
+
values /= weight
|
347 |
+
return values
|
348 |
+
|
349 |
+
def forward(
|
350 |
+
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
351 |
+
tiled=False, tile_size=64, tile_stride=32,
|
352 |
+
to_cache=False,
|
353 |
+
use_gradient_checkpointing=False,
|
354 |
+
):
|
355 |
+
# Embeddings
|
356 |
+
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
357 |
+
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
358 |
+
|
359 |
+
# Input
|
360 |
+
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
361 |
+
hidden_states = self.patch_embedder(hidden_states)
|
362 |
+
|
363 |
+
# Blocks
|
364 |
+
def create_custom_forward(module):
|
365 |
+
def custom_forward(*inputs):
|
366 |
+
return module(*inputs)
|
367 |
+
return custom_forward
|
368 |
+
if tiled:
|
369 |
+
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
370 |
+
residuals = []
|
371 |
+
for block_id, block in enumerate(self.blocks):
|
372 |
+
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
373 |
+
hidden_states = self.tiled_block_forward(
|
374 |
+
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
375 |
+
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
376 |
+
tile_size=tile_size, tile_stride=tile_stride
|
377 |
+
)
|
378 |
+
if block_id < self.num_layers_down - 2:
|
379 |
+
residuals.append(hidden_states)
|
380 |
+
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
381 |
+
else:
|
382 |
+
residuals = []
|
383 |
+
for block_id, block in enumerate(self.blocks):
|
384 |
+
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
385 |
+
if self.training and use_gradient_checkpointing:
|
386 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
387 |
+
create_custom_forward(block),
|
388 |
+
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
389 |
+
use_reentrant=False,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
393 |
+
if block_id < self.num_layers_down - 2:
|
394 |
+
residuals.append(hidden_states)
|
395 |
+
|
396 |
+
# Output
|
397 |
+
hidden_states = self.final_layer(hidden_states, condition_emb)
|
398 |
+
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
399 |
+
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
400 |
+
return hidden_states
|
401 |
+
|
402 |
+
def state_dict_converter(self):
|
403 |
+
return HunyuanDiTStateDictConverter()
|
404 |
+
|
405 |
+
|
406 |
+
|
407 |
+
class HunyuanDiTStateDictConverter():
|
408 |
+
def __init__(self):
|
409 |
+
pass
|
410 |
+
|
411 |
+
def from_diffusers(self, state_dict):
|
412 |
+
state_dict_ = {}
|
413 |
+
for name, param in state_dict.items():
|
414 |
+
name_ = name
|
415 |
+
name_ = name_.replace(".default_modulation.", ".modulation.")
|
416 |
+
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
417 |
+
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
418 |
+
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
419 |
+
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
420 |
+
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
421 |
+
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
422 |
+
name_ = name_.replace(".q_proj.", ".to_q.")
|
423 |
+
name_ = name_.replace(".out_proj.", ".to_out.")
|
424 |
+
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
425 |
+
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
426 |
+
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
427 |
+
name_ = name_.replace("pooler.", "t5_pooler.")
|
428 |
+
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
429 |
+
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
430 |
+
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
431 |
+
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
432 |
+
if ".kv_proj." in name_:
|
433 |
+
param_k = param[:param.shape[0]//2]
|
434 |
+
param_v = param[param.shape[0]//2:]
|
435 |
+
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
436 |
+
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
437 |
+
elif ".Wqkv." in name_:
|
438 |
+
param_q = param[:param.shape[0]//3]
|
439 |
+
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
440 |
+
param_v = param[param.shape[0]//3*2:]
|
441 |
+
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
442 |
+
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
443 |
+
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
444 |
+
elif "style_embedder" in name_:
|
445 |
+
state_dict_[name_] = param.squeeze()
|
446 |
+
else:
|
447 |
+
state_dict_[name_] = param
|
448 |
+
return state_dict_
|
449 |
+
|
450 |
+
def from_civitai(self, state_dict):
|
451 |
+
return self.from_diffusers(state_dict)
|
diffsynth/models/hunyuan_dit_text_encoder.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
class HunyuanDiTCLIPTextEncoder(BertModel):
|
7 |
+
def __init__(self):
|
8 |
+
config = BertConfig(
|
9 |
+
_name_or_path = "",
|
10 |
+
architectures = ["BertModel"],
|
11 |
+
attention_probs_dropout_prob = 0.1,
|
12 |
+
bos_token_id = 0,
|
13 |
+
classifier_dropout = None,
|
14 |
+
directionality = "bidi",
|
15 |
+
eos_token_id = 2,
|
16 |
+
hidden_act = "gelu",
|
17 |
+
hidden_dropout_prob = 0.1,
|
18 |
+
hidden_size = 1024,
|
19 |
+
initializer_range = 0.02,
|
20 |
+
intermediate_size = 4096,
|
21 |
+
layer_norm_eps = 1e-12,
|
22 |
+
max_position_embeddings = 512,
|
23 |
+
model_type = "bert",
|
24 |
+
num_attention_heads = 16,
|
25 |
+
num_hidden_layers = 24,
|
26 |
+
output_past = True,
|
27 |
+
pad_token_id = 0,
|
28 |
+
pooler_fc_size = 768,
|
29 |
+
pooler_num_attention_heads = 12,
|
30 |
+
pooler_num_fc_layers = 3,
|
31 |
+
pooler_size_per_head = 128,
|
32 |
+
pooler_type = "first_token_transform",
|
33 |
+
position_embedding_type = "absolute",
|
34 |
+
torch_dtype = "float32",
|
35 |
+
transformers_version = "4.37.2",
|
36 |
+
type_vocab_size = 2,
|
37 |
+
use_cache = True,
|
38 |
+
vocab_size = 47020
|
39 |
+
)
|
40 |
+
super().__init__(config, add_pooling_layer=False)
|
41 |
+
self.eval()
|
42 |
+
|
43 |
+
def forward(self, input_ids, attention_mask, clip_skip=1):
|
44 |
+
input_shape = input_ids.size()
|
45 |
+
|
46 |
+
batch_size, seq_length = input_shape
|
47 |
+
device = input_ids.device
|
48 |
+
|
49 |
+
past_key_values_length = 0
|
50 |
+
|
51 |
+
if attention_mask is None:
|
52 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
53 |
+
|
54 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
55 |
+
|
56 |
+
embedding_output = self.embeddings(
|
57 |
+
input_ids=input_ids,
|
58 |
+
position_ids=None,
|
59 |
+
token_type_ids=None,
|
60 |
+
inputs_embeds=None,
|
61 |
+
past_key_values_length=0,
|
62 |
+
)
|
63 |
+
encoder_outputs = self.encoder(
|
64 |
+
embedding_output,
|
65 |
+
attention_mask=extended_attention_mask,
|
66 |
+
head_mask=None,
|
67 |
+
encoder_hidden_states=None,
|
68 |
+
encoder_attention_mask=None,
|
69 |
+
past_key_values=None,
|
70 |
+
use_cache=False,
|
71 |
+
output_attentions=False,
|
72 |
+
output_hidden_states=True,
|
73 |
+
return_dict=True,
|
74 |
+
)
|
75 |
+
all_hidden_states = encoder_outputs.hidden_states
|
76 |
+
prompt_emb = all_hidden_states[-clip_skip]
|
77 |
+
if clip_skip > 1:
|
78 |
+
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
79 |
+
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
80 |
+
return prompt_emb
|
81 |
+
|
82 |
+
def state_dict_converter(self):
|
83 |
+
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
88 |
+
def __init__(self):
|
89 |
+
config = T5Config(
|
90 |
+
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
91 |
+
architectures = ["MT5ForConditionalGeneration"],
|
92 |
+
classifier_dropout = 0.0,
|
93 |
+
d_ff = 5120,
|
94 |
+
d_kv = 64,
|
95 |
+
d_model = 2048,
|
96 |
+
decoder_start_token_id = 0,
|
97 |
+
dense_act_fn = "gelu_new",
|
98 |
+
dropout_rate = 0.1,
|
99 |
+
eos_token_id = 1,
|
100 |
+
feed_forward_proj = "gated-gelu",
|
101 |
+
initializer_factor = 1.0,
|
102 |
+
is_encoder_decoder = True,
|
103 |
+
is_gated_act = True,
|
104 |
+
layer_norm_epsilon = 1e-06,
|
105 |
+
model_type = "t5",
|
106 |
+
num_decoder_layers = 24,
|
107 |
+
num_heads = 32,
|
108 |
+
num_layers = 24,
|
109 |
+
output_past = True,
|
110 |
+
pad_token_id = 0,
|
111 |
+
relative_attention_max_distance = 128,
|
112 |
+
relative_attention_num_buckets = 32,
|
113 |
+
tie_word_embeddings = False,
|
114 |
+
tokenizer_class = "T5Tokenizer",
|
115 |
+
transformers_version = "4.37.2",
|
116 |
+
use_cache = True,
|
117 |
+
vocab_size = 250112
|
118 |
+
)
|
119 |
+
super().__init__(config)
|
120 |
+
self.eval()
|
121 |
+
|
122 |
+
def forward(self, input_ids, attention_mask, clip_skip=1):
|
123 |
+
outputs = super().forward(
|
124 |
+
input_ids=input_ids,
|
125 |
+
attention_mask=attention_mask,
|
126 |
+
output_hidden_states=True,
|
127 |
+
)
|
128 |
+
prompt_emb = outputs.hidden_states[-clip_skip]
|
129 |
+
if clip_skip > 1:
|
130 |
+
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
131 |
+
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
132 |
+
return prompt_emb
|
133 |
+
|
134 |
+
def state_dict_converter(self):
|
135 |
+
return HunyuanDiTT5TextEncoderStateDictConverter()
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
140 |
+
def __init__(self):
|
141 |
+
pass
|
142 |
+
|
143 |
+
def from_diffusers(self, state_dict):
|
144 |
+
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
145 |
+
return state_dict_
|
146 |
+
|
147 |
+
def from_civitai(self, state_dict):
|
148 |
+
return self.from_diffusers(state_dict)
|
149 |
+
|
150 |
+
|
151 |
+
class HunyuanDiTT5TextEncoderStateDictConverter():
|
152 |
+
def __init__(self):
|
153 |
+
pass
|
154 |
+
|
155 |
+
def from_diffusers(self, state_dict):
|
156 |
+
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
157 |
+
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
158 |
+
return state_dict_
|
159 |
+
|
160 |
+
def from_civitai(self, state_dict):
|
161 |
+
return self.from_diffusers(state_dict)
|
diffsynth/models/kolors_text_encoder.py
ADDED
@@ -0,0 +1,1363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
|
3 |
+
We didn't modify this model.
|
4 |
+
The tensor operation is performed in the prompter.
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
""" PyTorch ChatGLM model. """
|
9 |
+
|
10 |
+
import math
|
11 |
+
import copy
|
12 |
+
import warnings
|
13 |
+
import re
|
14 |
+
import sys
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
21 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
22 |
+
from torch.nn.utils import skip_init
|
23 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
24 |
+
from copy import deepcopy
|
25 |
+
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutputWithPast,
|
28 |
+
CausalLMOutputWithPast,
|
29 |
+
SequenceClassifierOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import logging
|
33 |
+
from transformers.generation.logits_process import LogitsProcessor
|
34 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
35 |
+
from transformers import PretrainedConfig
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
class ChatGLMConfig(PretrainedConfig):
|
40 |
+
model_type = "chatglm"
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
num_layers=28,
|
44 |
+
padded_vocab_size=65024,
|
45 |
+
hidden_size=4096,
|
46 |
+
ffn_hidden_size=13696,
|
47 |
+
kv_channels=128,
|
48 |
+
num_attention_heads=32,
|
49 |
+
seq_length=2048,
|
50 |
+
hidden_dropout=0.0,
|
51 |
+
classifier_dropout=None,
|
52 |
+
attention_dropout=0.0,
|
53 |
+
layernorm_epsilon=1e-5,
|
54 |
+
rmsnorm=True,
|
55 |
+
apply_residual_connection_post_layernorm=False,
|
56 |
+
post_layer_norm=True,
|
57 |
+
add_bias_linear=False,
|
58 |
+
add_qkv_bias=False,
|
59 |
+
bias_dropout_fusion=True,
|
60 |
+
multi_query_attention=False,
|
61 |
+
multi_query_group_num=1,
|
62 |
+
apply_query_key_layer_scaling=True,
|
63 |
+
attention_softmax_in_fp32=True,
|
64 |
+
fp32_residual_connection=False,
|
65 |
+
quantization_bit=0,
|
66 |
+
pre_seq_len=None,
|
67 |
+
prefix_projection=False,
|
68 |
+
**kwargs
|
69 |
+
):
|
70 |
+
self.num_layers = num_layers
|
71 |
+
self.vocab_size = padded_vocab_size
|
72 |
+
self.padded_vocab_size = padded_vocab_size
|
73 |
+
self.hidden_size = hidden_size
|
74 |
+
self.ffn_hidden_size = ffn_hidden_size
|
75 |
+
self.kv_channels = kv_channels
|
76 |
+
self.num_attention_heads = num_attention_heads
|
77 |
+
self.seq_length = seq_length
|
78 |
+
self.hidden_dropout = hidden_dropout
|
79 |
+
self.classifier_dropout = classifier_dropout
|
80 |
+
self.attention_dropout = attention_dropout
|
81 |
+
self.layernorm_epsilon = layernorm_epsilon
|
82 |
+
self.rmsnorm = rmsnorm
|
83 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
84 |
+
self.post_layer_norm = post_layer_norm
|
85 |
+
self.add_bias_linear = add_bias_linear
|
86 |
+
self.add_qkv_bias = add_qkv_bias
|
87 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
88 |
+
self.multi_query_attention = multi_query_attention
|
89 |
+
self.multi_query_group_num = multi_query_group_num
|
90 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
91 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
92 |
+
self.fp32_residual_connection = fp32_residual_connection
|
93 |
+
self.quantization_bit = quantization_bit
|
94 |
+
self.pre_seq_len = pre_seq_len
|
95 |
+
self.prefix_projection = prefix_projection
|
96 |
+
super().__init__(**kwargs)
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
# flags required to enable jit fusion kernels
|
101 |
+
|
102 |
+
if sys.platform != 'darwin':
|
103 |
+
torch._C._jit_set_profiling_mode(False)
|
104 |
+
torch._C._jit_set_profiling_executor(False)
|
105 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
106 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
107 |
+
|
108 |
+
logger = logging.get_logger(__name__)
|
109 |
+
|
110 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
111 |
+
_CONFIG_FOR_DOC = "ChatGLM6BConfig"
|
112 |
+
|
113 |
+
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
114 |
+
"THUDM/chatglm3-6b-base",
|
115 |
+
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
116 |
+
]
|
117 |
+
|
118 |
+
|
119 |
+
def default_init(cls, *args, **kwargs):
|
120 |
+
return cls(*args, **kwargs)
|
121 |
+
|
122 |
+
|
123 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
124 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
125 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
126 |
+
scores.zero_()
|
127 |
+
scores[..., 5] = 5e4
|
128 |
+
return scores
|
129 |
+
|
130 |
+
|
131 |
+
class PrefixEncoder(torch.nn.Module):
|
132 |
+
"""
|
133 |
+
The torch.nn model to encode the prefix
|
134 |
+
Input shape: (batch-size, prefix-length)
|
135 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, config: ChatGLMConfig):
|
139 |
+
super().__init__()
|
140 |
+
self.prefix_projection = config.prefix_projection
|
141 |
+
if self.prefix_projection:
|
142 |
+
# Use a two-layer MLP to encode the prefix
|
143 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
144 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
145 |
+
self.trans = torch.nn.Sequential(
|
146 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
147 |
+
torch.nn.Tanh(),
|
148 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
152 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
153 |
+
|
154 |
+
def forward(self, prefix: torch.Tensor):
|
155 |
+
if self.prefix_projection:
|
156 |
+
prefix_tokens = self.embedding(prefix)
|
157 |
+
past_key_values = self.trans(prefix_tokens)
|
158 |
+
else:
|
159 |
+
past_key_values = self.embedding(prefix)
|
160 |
+
return past_key_values
|
161 |
+
|
162 |
+
|
163 |
+
def split_tensor_along_last_dim(
|
164 |
+
tensor: torch.Tensor,
|
165 |
+
num_partitions: int,
|
166 |
+
contiguous_split_chunks: bool = False,
|
167 |
+
) -> List[torch.Tensor]:
|
168 |
+
"""Split a tensor along its last dimension.
|
169 |
+
|
170 |
+
Arguments:
|
171 |
+
tensor: input tensor.
|
172 |
+
num_partitions: number of partitions to split the tensor
|
173 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
174 |
+
in memory.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
A list of Tensors
|
178 |
+
"""
|
179 |
+
# Get the size and dimension.
|
180 |
+
last_dim = tensor.dim() - 1
|
181 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
182 |
+
# Split.
|
183 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
184 |
+
# Note: torch.split does not create contiguous tensors by default.
|
185 |
+
if contiguous_split_chunks:
|
186 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
187 |
+
|
188 |
+
return tensor_list
|
189 |
+
|
190 |
+
|
191 |
+
class RotaryEmbedding(nn.Module):
|
192 |
+
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
193 |
+
super().__init__()
|
194 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
195 |
+
self.register_buffer("inv_freq", inv_freq)
|
196 |
+
self.dim = dim
|
197 |
+
self.original_impl = original_impl
|
198 |
+
|
199 |
+
def forward_impl(
|
200 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
201 |
+
):
|
202 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
203 |
+
|
204 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
205 |
+
transformers/rope/__init__.py. MIT License:
|
206 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
207 |
+
"""
|
208 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
209 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
|
210 |
+
|
211 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
212 |
+
seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
|
213 |
+
|
214 |
+
# Calculate the product of position index and $\theta_i$
|
215 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
216 |
+
|
217 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
218 |
+
|
219 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
220 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
221 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
222 |
+
return cache
|
223 |
+
|
224 |
+
def forward(self, max_seq_len, offset=0):
|
225 |
+
return self.forward_impl(
|
226 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
227 |
+
)
|
228 |
+
|
229 |
+
|
230 |
+
@torch.jit.script
|
231 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
232 |
+
# x: [sq, b, np, hn]
|
233 |
+
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
234 |
+
rot_dim = rope_cache.shape[-2] * 2
|
235 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
236 |
+
# truncate to support variable sizes
|
237 |
+
rope_cache = rope_cache[:sq]
|
238 |
+
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
239 |
+
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
240 |
+
x_out2 = torch.stack(
|
241 |
+
[
|
242 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
243 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
244 |
+
],
|
245 |
+
-1,
|
246 |
+
)
|
247 |
+
x_out2 = x_out2.flatten(3)
|
248 |
+
return torch.cat((x_out2, x_pass), dim=-1)
|
249 |
+
|
250 |
+
|
251 |
+
class RMSNorm(torch.nn.Module):
|
252 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
253 |
+
super().__init__()
|
254 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
255 |
+
self.eps = eps
|
256 |
+
|
257 |
+
def forward(self, hidden_states: torch.Tensor):
|
258 |
+
input_dtype = hidden_states.dtype
|
259 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
260 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
261 |
+
|
262 |
+
return (self.weight * hidden_states).to(input_dtype)
|
263 |
+
|
264 |
+
|
265 |
+
class CoreAttention(torch.nn.Module):
|
266 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
267 |
+
super(CoreAttention, self).__init__()
|
268 |
+
|
269 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
270 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
271 |
+
if self.apply_query_key_layer_scaling:
|
272 |
+
self.attention_softmax_in_fp32 = True
|
273 |
+
self.layer_number = max(1, layer_number)
|
274 |
+
|
275 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
276 |
+
|
277 |
+
# Per attention head and per partition values.
|
278 |
+
self.hidden_size_per_partition = projection_size
|
279 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
280 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
281 |
+
|
282 |
+
coeff = None
|
283 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
284 |
+
if self.apply_query_key_layer_scaling:
|
285 |
+
coeff = self.layer_number
|
286 |
+
self.norm_factor *= coeff
|
287 |
+
self.coeff = coeff
|
288 |
+
|
289 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
290 |
+
|
291 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
292 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
293 |
+
if pytorch_major_version >= 2:
|
294 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
295 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
296 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
297 |
+
is_causal=True)
|
298 |
+
else:
|
299 |
+
if attention_mask is not None:
|
300 |
+
attention_mask = ~attention_mask
|
301 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
302 |
+
attention_mask)
|
303 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
304 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
305 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
306 |
+
else:
|
307 |
+
# Raw attention scores
|
308 |
+
|
309 |
+
# [b, np, sq, sk]
|
310 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
311 |
+
|
312 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
313 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
314 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
315 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
316 |
+
|
317 |
+
# preallocting input tensor: [b * np, sq, sk]
|
318 |
+
matmul_input_buffer = torch.empty(
|
319 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
320 |
+
device=query_layer.device
|
321 |
+
)
|
322 |
+
|
323 |
+
# Raw attention scores. [b * np, sq, sk]
|
324 |
+
matmul_result = torch.baddbmm(
|
325 |
+
matmul_input_buffer,
|
326 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
327 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
328 |
+
beta=0.0,
|
329 |
+
alpha=(1.0 / self.norm_factor),
|
330 |
+
)
|
331 |
+
|
332 |
+
# change view to [b, np, sq, sk]
|
333 |
+
attention_scores = matmul_result.view(*output_size)
|
334 |
+
|
335 |
+
# ===========================
|
336 |
+
# Attention probs and dropout
|
337 |
+
# ===========================
|
338 |
+
|
339 |
+
# attention scores and attention mask [b, np, sq, sk]
|
340 |
+
if self.attention_softmax_in_fp32:
|
341 |
+
attention_scores = attention_scores.float()
|
342 |
+
if self.coeff is not None:
|
343 |
+
attention_scores = attention_scores * self.coeff
|
344 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
345 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
346 |
+
device=attention_scores.device, dtype=torch.bool)
|
347 |
+
attention_mask.tril_()
|
348 |
+
attention_mask = ~attention_mask
|
349 |
+
if attention_mask is not None:
|
350 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
351 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
352 |
+
attention_probs = attention_probs.type_as(value_layer)
|
353 |
+
|
354 |
+
# This is actually dropping out entire tokens to attend to, which might
|
355 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
356 |
+
attention_probs = self.attention_dropout(attention_probs)
|
357 |
+
# =========================
|
358 |
+
# Context layer. [sq, b, hp]
|
359 |
+
# =========================
|
360 |
+
|
361 |
+
# value_layer -> context layer.
|
362 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
363 |
+
|
364 |
+
# context layer shape: [b, np, sq, hn]
|
365 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
366 |
+
# change view [sk, b * np, hn]
|
367 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
368 |
+
# change view [b * np, sq, sk]
|
369 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
370 |
+
# matmul: [b * np, sq, hn]
|
371 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
372 |
+
# change view [b, np, sq, hn]
|
373 |
+
context_layer = context_layer.view(*output_size)
|
374 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
375 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
376 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
377 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
378 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
379 |
+
|
380 |
+
return context_layer
|
381 |
+
|
382 |
+
|
383 |
+
class SelfAttention(torch.nn.Module):
|
384 |
+
"""Parallel self-attention layer abstract class.
|
385 |
+
|
386 |
+
Self-attention layer takes input with size [s, b, h]
|
387 |
+
and returns output of the same size.
|
388 |
+
"""
|
389 |
+
|
390 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
391 |
+
super(SelfAttention, self).__init__()
|
392 |
+
self.layer_number = max(1, layer_number)
|
393 |
+
|
394 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
395 |
+
|
396 |
+
# Per attention head and per partition values.
|
397 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
398 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
399 |
+
|
400 |
+
self.multi_query_attention = config.multi_query_attention
|
401 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
402 |
+
if self.multi_query_attention:
|
403 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
404 |
+
self.qkv_hidden_size = (
|
405 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
406 |
+
)
|
407 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
408 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
409 |
+
device=device, **_config_to_kwargs(config)
|
410 |
+
)
|
411 |
+
|
412 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
413 |
+
|
414 |
+
# Output.
|
415 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
416 |
+
device=device, **_config_to_kwargs(config)
|
417 |
+
)
|
418 |
+
|
419 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
420 |
+
if self.multi_query_attention:
|
421 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
422 |
+
else:
|
423 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
424 |
+
return torch.empty(
|
425 |
+
inference_max_sequence_len,
|
426 |
+
batch_size,
|
427 |
+
num_attention_heads,
|
428 |
+
self.hidden_size_per_attention_head,
|
429 |
+
dtype=dtype,
|
430 |
+
device=device,
|
431 |
+
)
|
432 |
+
|
433 |
+
def forward(
|
434 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
435 |
+
):
|
436 |
+
# hidden_states: [sq, b, h]
|
437 |
+
|
438 |
+
# =================================================
|
439 |
+
# Pre-allocate memory for key-values for inference.
|
440 |
+
# =================================================
|
441 |
+
# =====================
|
442 |
+
# Query, Key, and Value
|
443 |
+
# =====================
|
444 |
+
|
445 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
446 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
447 |
+
|
448 |
+
if self.multi_query_attention:
|
449 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
450 |
+
[
|
451 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
452 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
453 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
454 |
+
],
|
455 |
+
dim=-1,
|
456 |
+
)
|
457 |
+
query_layer = query_layer.view(
|
458 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
459 |
+
)
|
460 |
+
key_layer = key_layer.view(
|
461 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
462 |
+
)
|
463 |
+
value_layer = value_layer.view(
|
464 |
+
value_layer.size()[:-1]
|
465 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
466 |
+
)
|
467 |
+
else:
|
468 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
469 |
+
(self.num_attention_heads_per_partition,
|
470 |
+
3 * self.hidden_size_per_attention_head)
|
471 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
472 |
+
|
473 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
474 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
475 |
+
|
476 |
+
# apply relative positional encoding (rotary embedding)
|
477 |
+
if rotary_pos_emb is not None:
|
478 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
479 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
480 |
+
|
481 |
+
# adjust key and value for inference
|
482 |
+
if kv_cache is not None:
|
483 |
+
cache_k, cache_v = kv_cache
|
484 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
485 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
486 |
+
if use_cache:
|
487 |
+
kv_cache = (key_layer, value_layer)
|
488 |
+
else:
|
489 |
+
kv_cache = None
|
490 |
+
|
491 |
+
if self.multi_query_attention:
|
492 |
+
key_layer = key_layer.unsqueeze(-2)
|
493 |
+
key_layer = key_layer.expand(
|
494 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
495 |
+
)
|
496 |
+
key_layer = key_layer.contiguous().view(
|
497 |
+
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
498 |
+
)
|
499 |
+
value_layer = value_layer.unsqueeze(-2)
|
500 |
+
value_layer = value_layer.expand(
|
501 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
502 |
+
)
|
503 |
+
value_layer = value_layer.contiguous().view(
|
504 |
+
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
505 |
+
)
|
506 |
+
|
507 |
+
# ==================================
|
508 |
+
# core attention computation
|
509 |
+
# ==================================
|
510 |
+
|
511 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
512 |
+
|
513 |
+
# =================
|
514 |
+
# Output. [sq, b, h]
|
515 |
+
# =================
|
516 |
+
|
517 |
+
output = self.dense(context_layer)
|
518 |
+
|
519 |
+
return output, kv_cache
|
520 |
+
|
521 |
+
|
522 |
+
def _config_to_kwargs(args):
|
523 |
+
common_kwargs = {
|
524 |
+
"dtype": args.torch_dtype,
|
525 |
+
}
|
526 |
+
return common_kwargs
|
527 |
+
|
528 |
+
|
529 |
+
class MLP(torch.nn.Module):
|
530 |
+
"""MLP.
|
531 |
+
|
532 |
+
MLP will take the input with h hidden state, project it to 4*h
|
533 |
+
hidden dimension, perform nonlinear transformation, and project the
|
534 |
+
state back into h hidden dimension.
|
535 |
+
"""
|
536 |
+
|
537 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
538 |
+
super(MLP, self).__init__()
|
539 |
+
|
540 |
+
self.add_bias = config.add_bias_linear
|
541 |
+
|
542 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
543 |
+
self.dense_h_to_4h = nn.Linear(
|
544 |
+
config.hidden_size,
|
545 |
+
config.ffn_hidden_size * 2,
|
546 |
+
bias=self.add_bias,
|
547 |
+
device=device,
|
548 |
+
**_config_to_kwargs(config)
|
549 |
+
)
|
550 |
+
|
551 |
+
def swiglu(x):
|
552 |
+
x = torch.chunk(x, 2, dim=-1)
|
553 |
+
return F.silu(x[0]) * x[1]
|
554 |
+
|
555 |
+
self.activation_func = swiglu
|
556 |
+
|
557 |
+
# Project back to h.
|
558 |
+
self.dense_4h_to_h = nn.Linear(
|
559 |
+
config.ffn_hidden_size,
|
560 |
+
config.hidden_size,
|
561 |
+
bias=self.add_bias,
|
562 |
+
device=device,
|
563 |
+
**_config_to_kwargs(config)
|
564 |
+
)
|
565 |
+
|
566 |
+
def forward(self, hidden_states):
|
567 |
+
# [s, b, 4hp]
|
568 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
569 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
570 |
+
# [s, b, h]
|
571 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
572 |
+
return output
|
573 |
+
|
574 |
+
|
575 |
+
class GLMBlock(torch.nn.Module):
|
576 |
+
"""A single transformer layer.
|
577 |
+
|
578 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
579 |
+
output of the same size.
|
580 |
+
"""
|
581 |
+
|
582 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
583 |
+
super(GLMBlock, self).__init__()
|
584 |
+
self.layer_number = layer_number
|
585 |
+
|
586 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
587 |
+
|
588 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
589 |
+
|
590 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
591 |
+
# Layernorm on the input data.
|
592 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
593 |
+
dtype=config.torch_dtype)
|
594 |
+
|
595 |
+
# Self attention.
|
596 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
597 |
+
self.hidden_dropout = config.hidden_dropout
|
598 |
+
|
599 |
+
# Layernorm on the attention output
|
600 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
601 |
+
dtype=config.torch_dtype)
|
602 |
+
|
603 |
+
# MLP
|
604 |
+
self.mlp = MLP(config, device=device)
|
605 |
+
|
606 |
+
def forward(
|
607 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
608 |
+
):
|
609 |
+
# hidden_states: [s, b, h]
|
610 |
+
|
611 |
+
# Layer norm at the beginning of the transformer layer.
|
612 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
613 |
+
# Self attention.
|
614 |
+
attention_output, kv_cache = self.self_attention(
|
615 |
+
layernorm_output,
|
616 |
+
attention_mask,
|
617 |
+
rotary_pos_emb,
|
618 |
+
kv_cache=kv_cache,
|
619 |
+
use_cache=use_cache
|
620 |
+
)
|
621 |
+
|
622 |
+
# Residual connection.
|
623 |
+
if self.apply_residual_connection_post_layernorm:
|
624 |
+
residual = layernorm_output
|
625 |
+
else:
|
626 |
+
residual = hidden_states
|
627 |
+
|
628 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
629 |
+
layernorm_input = residual + layernorm_input
|
630 |
+
|
631 |
+
# Layer norm post the self attention.
|
632 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
633 |
+
|
634 |
+
# MLP.
|
635 |
+
mlp_output = self.mlp(layernorm_output)
|
636 |
+
|
637 |
+
# Second residual connection.
|
638 |
+
if self.apply_residual_connection_post_layernorm:
|
639 |
+
residual = layernorm_output
|
640 |
+
else:
|
641 |
+
residual = layernorm_input
|
642 |
+
|
643 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
644 |
+
output = residual + output
|
645 |
+
|
646 |
+
return output, kv_cache
|
647 |
+
|
648 |
+
|
649 |
+
class GLMTransformer(torch.nn.Module):
|
650 |
+
"""Transformer class."""
|
651 |
+
|
652 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
653 |
+
super(GLMTransformer, self).__init__()
|
654 |
+
|
655 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
656 |
+
self.post_layer_norm = config.post_layer_norm
|
657 |
+
|
658 |
+
# Number of layers.
|
659 |
+
self.num_layers = config.num_layers
|
660 |
+
|
661 |
+
# Transformer layers.
|
662 |
+
def build_layer(layer_number):
|
663 |
+
return GLMBlock(config, layer_number, device=device)
|
664 |
+
|
665 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
666 |
+
|
667 |
+
if self.post_layer_norm:
|
668 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
669 |
+
# Final layer norm before output.
|
670 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
671 |
+
dtype=config.torch_dtype)
|
672 |
+
|
673 |
+
self.gradient_checkpointing = False
|
674 |
+
|
675 |
+
def _get_layer(self, layer_number):
|
676 |
+
return self.layers[layer_number]
|
677 |
+
|
678 |
+
def forward(
|
679 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
680 |
+
use_cache: Optional[bool] = True,
|
681 |
+
output_hidden_states: Optional[bool] = False,
|
682 |
+
):
|
683 |
+
if not kv_caches:
|
684 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
685 |
+
presents = () if use_cache else None
|
686 |
+
if self.gradient_checkpointing and self.training:
|
687 |
+
if use_cache:
|
688 |
+
logger.warning_once(
|
689 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
690 |
+
)
|
691 |
+
use_cache = False
|
692 |
+
|
693 |
+
all_self_attentions = None
|
694 |
+
all_hidden_states = () if output_hidden_states else None
|
695 |
+
for index in range(self.num_layers):
|
696 |
+
if output_hidden_states:
|
697 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
698 |
+
|
699 |
+
layer = self._get_layer(index)
|
700 |
+
if self.gradient_checkpointing and self.training:
|
701 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
702 |
+
layer,
|
703 |
+
hidden_states,
|
704 |
+
attention_mask,
|
705 |
+
rotary_pos_emb,
|
706 |
+
kv_caches[index],
|
707 |
+
use_cache
|
708 |
+
)
|
709 |
+
else:
|
710 |
+
layer_ret = layer(
|
711 |
+
hidden_states,
|
712 |
+
attention_mask,
|
713 |
+
rotary_pos_emb,
|
714 |
+
kv_cache=kv_caches[index],
|
715 |
+
use_cache=use_cache
|
716 |
+
)
|
717 |
+
hidden_states, kv_cache = layer_ret
|
718 |
+
if use_cache:
|
719 |
+
presents = presents + (kv_cache,)
|
720 |
+
|
721 |
+
if output_hidden_states:
|
722 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
723 |
+
|
724 |
+
# Final layer norm.
|
725 |
+
if self.post_layer_norm:
|
726 |
+
hidden_states = self.final_layernorm(hidden_states)
|
727 |
+
|
728 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
729 |
+
|
730 |
+
|
731 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
732 |
+
"""
|
733 |
+
An abstract class to handle weights initialization and
|
734 |
+
a simple interface for downloading and loading pretrained models.
|
735 |
+
"""
|
736 |
+
|
737 |
+
is_parallelizable = False
|
738 |
+
supports_gradient_checkpointing = True
|
739 |
+
config_class = ChatGLMConfig
|
740 |
+
base_model_prefix = "transformer"
|
741 |
+
_no_split_modules = ["GLMBlock"]
|
742 |
+
|
743 |
+
def _init_weights(self, module: nn.Module):
|
744 |
+
"""Initialize the weights."""
|
745 |
+
return
|
746 |
+
|
747 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
748 |
+
batch_size, seq_length = input_ids.shape
|
749 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
750 |
+
full_attention_mask.tril_()
|
751 |
+
past_length = 0
|
752 |
+
if past_key_values:
|
753 |
+
past_length = past_key_values[0][0].shape[0]
|
754 |
+
if past_length:
|
755 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
756 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
757 |
+
if padding_mask is not None:
|
758 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
759 |
+
if not past_length and padding_mask is not None:
|
760 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
761 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
762 |
+
full_attention_mask.unsqueeze_(1)
|
763 |
+
return full_attention_mask
|
764 |
+
|
765 |
+
def get_position_ids(self, input_ids, device):
|
766 |
+
batch_size, seq_length = input_ids.shape
|
767 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
768 |
+
return position_ids
|
769 |
+
|
770 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
771 |
+
if isinstance(module, GLMTransformer):
|
772 |
+
module.gradient_checkpointing = value
|
773 |
+
|
774 |
+
|
775 |
+
class Embedding(torch.nn.Module):
|
776 |
+
"""Language model embeddings."""
|
777 |
+
|
778 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
779 |
+
super(Embedding, self).__init__()
|
780 |
+
|
781 |
+
self.hidden_size = config.hidden_size
|
782 |
+
# Word embeddings (parallel).
|
783 |
+
self.word_embeddings = nn.Embedding(
|
784 |
+
config.padded_vocab_size,
|
785 |
+
self.hidden_size,
|
786 |
+
dtype=config.torch_dtype,
|
787 |
+
device=device
|
788 |
+
)
|
789 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
790 |
+
|
791 |
+
def forward(self, input_ids):
|
792 |
+
# Embeddings.
|
793 |
+
words_embeddings = self.word_embeddings(input_ids)
|
794 |
+
embeddings = words_embeddings
|
795 |
+
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
796 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
797 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
798 |
+
if self.fp32_residual_connection:
|
799 |
+
embeddings = embeddings.float()
|
800 |
+
return embeddings
|
801 |
+
|
802 |
+
|
803 |
+
class ChatGLMModel(ChatGLMPreTrainedModel):
|
804 |
+
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
805 |
+
super().__init__(config)
|
806 |
+
if empty_init:
|
807 |
+
init_method = skip_init
|
808 |
+
else:
|
809 |
+
init_method = default_init
|
810 |
+
init_kwargs = {}
|
811 |
+
if device is not None:
|
812 |
+
init_kwargs["device"] = device
|
813 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
814 |
+
self.num_layers = config.num_layers
|
815 |
+
self.multi_query_group_num = config.multi_query_group_num
|
816 |
+
self.kv_channels = config.kv_channels
|
817 |
+
|
818 |
+
# Rotary positional embeddings
|
819 |
+
self.seq_length = config.seq_length
|
820 |
+
rotary_dim = (
|
821 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
822 |
+
)
|
823 |
+
|
824 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
825 |
+
dtype=config.torch_dtype)
|
826 |
+
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
827 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
828 |
+
dtype=config.torch_dtype, **init_kwargs)
|
829 |
+
self.pre_seq_len = config.pre_seq_len
|
830 |
+
self.prefix_projection = config.prefix_projection
|
831 |
+
if self.pre_seq_len is not None:
|
832 |
+
for param in self.parameters():
|
833 |
+
param.requires_grad = False
|
834 |
+
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
835 |
+
self.prefix_encoder = PrefixEncoder(config)
|
836 |
+
self.dropout = torch.nn.Dropout(0.1)
|
837 |
+
|
838 |
+
def get_input_embeddings(self):
|
839 |
+
return self.embedding.word_embeddings
|
840 |
+
|
841 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
842 |
+
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
843 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
844 |
+
past_key_values = past_key_values.view(
|
845 |
+
batch_size,
|
846 |
+
self.pre_seq_len,
|
847 |
+
self.num_layers * 2,
|
848 |
+
self.multi_query_group_num,
|
849 |
+
self.kv_channels
|
850 |
+
)
|
851 |
+
# seq_len, b, nh, hidden_size
|
852 |
+
past_key_values = self.dropout(past_key_values)
|
853 |
+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
854 |
+
return past_key_values
|
855 |
+
|
856 |
+
def forward(
|
857 |
+
self,
|
858 |
+
input_ids,
|
859 |
+
position_ids: Optional[torch.Tensor] = None,
|
860 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
861 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
862 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
863 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
864 |
+
use_cache: Optional[bool] = None,
|
865 |
+
output_hidden_states: Optional[bool] = None,
|
866 |
+
return_dict: Optional[bool] = None,
|
867 |
+
):
|
868 |
+
output_hidden_states = (
|
869 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
870 |
+
)
|
871 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
872 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
873 |
+
|
874 |
+
batch_size, seq_length = input_ids.shape
|
875 |
+
|
876 |
+
if inputs_embeds is None:
|
877 |
+
inputs_embeds = self.embedding(input_ids)
|
878 |
+
|
879 |
+
if self.pre_seq_len is not None:
|
880 |
+
if past_key_values is None:
|
881 |
+
past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
|
882 |
+
dtype=inputs_embeds.dtype)
|
883 |
+
if attention_mask is not None:
|
884 |
+
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
885 |
+
attention_mask], dim=-1)
|
886 |
+
|
887 |
+
if full_attention_mask is None:
|
888 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
889 |
+
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
890 |
+
|
891 |
+
# Rotary positional embeddings
|
892 |
+
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
893 |
+
if position_ids is not None:
|
894 |
+
rotary_pos_emb = rotary_pos_emb[position_ids]
|
895 |
+
else:
|
896 |
+
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
897 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
898 |
+
|
899 |
+
# Run encoder.
|
900 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
901 |
+
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
902 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
903 |
+
)
|
904 |
+
|
905 |
+
if not return_dict:
|
906 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
907 |
+
|
908 |
+
return BaseModelOutputWithPast(
|
909 |
+
last_hidden_state=hidden_states,
|
910 |
+
past_key_values=presents,
|
911 |
+
hidden_states=all_hidden_states,
|
912 |
+
attentions=all_self_attentions,
|
913 |
+
)
|
914 |
+
|
915 |
+
def quantize(self, weight_bit_width: int):
|
916 |
+
from .quantization import quantize
|
917 |
+
quantize(self.encoder, weight_bit_width)
|
918 |
+
return self
|
919 |
+
|
920 |
+
|
921 |
+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
922 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
923 |
+
super().__init__(config)
|
924 |
+
|
925 |
+
self.max_sequence_length = config.max_length
|
926 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
927 |
+
self.config = config
|
928 |
+
self.quantized = False
|
929 |
+
|
930 |
+
if self.config.quantization_bit:
|
931 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
932 |
+
|
933 |
+
def _update_model_kwargs_for_generation(
|
934 |
+
self,
|
935 |
+
outputs: ModelOutput,
|
936 |
+
model_kwargs: Dict[str, Any],
|
937 |
+
is_encoder_decoder: bool = False,
|
938 |
+
standardize_cache_format: bool = False,
|
939 |
+
) -> Dict[str, Any]:
|
940 |
+
# update past_key_values
|
941 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
942 |
+
outputs, standardize_cache_format=standardize_cache_format
|
943 |
+
)
|
944 |
+
|
945 |
+
# update attention mask
|
946 |
+
if "attention_mask" in model_kwargs:
|
947 |
+
attention_mask = model_kwargs["attention_mask"]
|
948 |
+
model_kwargs["attention_mask"] = torch.cat(
|
949 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
950 |
+
)
|
951 |
+
|
952 |
+
# update position ids
|
953 |
+
if "position_ids" in model_kwargs:
|
954 |
+
position_ids = model_kwargs["position_ids"]
|
955 |
+
new_position_id = position_ids[..., -1:].clone()
|
956 |
+
new_position_id += 1
|
957 |
+
model_kwargs["position_ids"] = torch.cat(
|
958 |
+
[position_ids, new_position_id], dim=-1
|
959 |
+
)
|
960 |
+
|
961 |
+
model_kwargs["is_first_forward"] = False
|
962 |
+
return model_kwargs
|
963 |
+
|
964 |
+
def prepare_inputs_for_generation(
|
965 |
+
self,
|
966 |
+
input_ids: torch.LongTensor,
|
967 |
+
past_key_values: Optional[torch.Tensor] = None,
|
968 |
+
attention_mask: Optional[torch.Tensor] = None,
|
969 |
+
position_ids: Optional[torch.Tensor] = None,
|
970 |
+
use_cache: Optional[bool] = None,
|
971 |
+
is_first_forward: bool = True,
|
972 |
+
**kwargs
|
973 |
+
) -> dict:
|
974 |
+
# only last token for input_ids if past is not None
|
975 |
+
if position_ids is None:
|
976 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
977 |
+
if not is_first_forward:
|
978 |
+
if past_key_values is not None:
|
979 |
+
position_ids = position_ids[..., -1:]
|
980 |
+
input_ids = input_ids[:, -1:]
|
981 |
+
return {
|
982 |
+
"input_ids": input_ids,
|
983 |
+
"past_key_values": past_key_values,
|
984 |
+
"position_ids": position_ids,
|
985 |
+
"attention_mask": attention_mask,
|
986 |
+
"return_last_logit": True,
|
987 |
+
"use_cache": use_cache
|
988 |
+
}
|
989 |
+
|
990 |
+
def forward(
|
991 |
+
self,
|
992 |
+
input_ids: Optional[torch.Tensor] = None,
|
993 |
+
position_ids: Optional[torch.Tensor] = None,
|
994 |
+
attention_mask: Optional[torch.Tensor] = None,
|
995 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
996 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
997 |
+
labels: Optional[torch.Tensor] = None,
|
998 |
+
use_cache: Optional[bool] = None,
|
999 |
+
output_attentions: Optional[bool] = None,
|
1000 |
+
output_hidden_states: Optional[bool] = None,
|
1001 |
+
return_dict: Optional[bool] = None,
|
1002 |
+
return_last_logit: Optional[bool] = False,
|
1003 |
+
):
|
1004 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1005 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1006 |
+
|
1007 |
+
transformer_outputs = self.transformer(
|
1008 |
+
input_ids=input_ids,
|
1009 |
+
position_ids=position_ids,
|
1010 |
+
attention_mask=attention_mask,
|
1011 |
+
past_key_values=past_key_values,
|
1012 |
+
inputs_embeds=inputs_embeds,
|
1013 |
+
use_cache=use_cache,
|
1014 |
+
output_hidden_states=output_hidden_states,
|
1015 |
+
return_dict=return_dict,
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
hidden_states = transformer_outputs[0]
|
1019 |
+
if return_last_logit:
|
1020 |
+
hidden_states = hidden_states[-1:]
|
1021 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
1022 |
+
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
1023 |
+
|
1024 |
+
loss = None
|
1025 |
+
if labels is not None:
|
1026 |
+
lm_logits = lm_logits.to(torch.float32)
|
1027 |
+
|
1028 |
+
# Shift so that tokens < n predict n
|
1029 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1030 |
+
shift_labels = labels[..., 1:].contiguous()
|
1031 |
+
# Flatten the tokens
|
1032 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
1033 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1034 |
+
|
1035 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
1036 |
+
loss = loss.to(hidden_states.dtype)
|
1037 |
+
|
1038 |
+
if not return_dict:
|
1039 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1040 |
+
return ((loss,) + output) if loss is not None else output
|
1041 |
+
|
1042 |
+
return CausalLMOutputWithPast(
|
1043 |
+
loss=loss,
|
1044 |
+
logits=lm_logits,
|
1045 |
+
past_key_values=transformer_outputs.past_key_values,
|
1046 |
+
hidden_states=transformer_outputs.hidden_states,
|
1047 |
+
attentions=transformer_outputs.attentions,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
@staticmethod
|
1051 |
+
def _reorder_cache(
|
1052 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
1053 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
1054 |
+
"""
|
1055 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1056 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1057 |
+
beam_idx at every generation step.
|
1058 |
+
|
1059 |
+
Output shares the same memory storage as `past`.
|
1060 |
+
"""
|
1061 |
+
return tuple(
|
1062 |
+
(
|
1063 |
+
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
|
1064 |
+
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
|
1065 |
+
)
|
1066 |
+
for layer_past in past
|
1067 |
+
)
|
1068 |
+
|
1069 |
+
def process_response(self, output, history):
|
1070 |
+
content = ""
|
1071 |
+
history = deepcopy(history)
|
1072 |
+
for response in output.split("<|assistant|>"):
|
1073 |
+
metadata, content = response.split("\n", maxsplit=1)
|
1074 |
+
if not metadata.strip():
|
1075 |
+
content = content.strip()
|
1076 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1077 |
+
content = content.replace("[[训练时间]]", "2023年")
|
1078 |
+
else:
|
1079 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1080 |
+
if history[0]["role"] == "system" and "tools" in history[0]:
|
1081 |
+
content = "\n".join(content.split("\n")[1:-1])
|
1082 |
+
def tool_call(**kwargs):
|
1083 |
+
return kwargs
|
1084 |
+
parameters = eval(content)
|
1085 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
1086 |
+
else:
|
1087 |
+
content = {"name": metadata.strip(), "content": content}
|
1088 |
+
return content, history
|
1089 |
+
|
1090 |
+
@torch.inference_mode()
|
1091 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1092 |
+
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1093 |
+
**kwargs):
|
1094 |
+
if history is None:
|
1095 |
+
history = []
|
1096 |
+
if logits_processor is None:
|
1097 |
+
logits_processor = LogitsProcessorList()
|
1098 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1099 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1100 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1101 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1102 |
+
inputs = inputs.to(self.device)
|
1103 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1104 |
+
tokenizer.get_command("<|observation|>")]
|
1105 |
+
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1106 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1107 |
+
response = tokenizer.decode(outputs)
|
1108 |
+
history.append({"role": role, "content": query})
|
1109 |
+
response, history = self.process_response(response, history)
|
1110 |
+
return response, history
|
1111 |
+
|
1112 |
+
@torch.inference_mode()
|
1113 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
|
1114 |
+
past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
1115 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
1116 |
+
if history is None:
|
1117 |
+
history = []
|
1118 |
+
if logits_processor is None:
|
1119 |
+
logits_processor = LogitsProcessorList()
|
1120 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1121 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
1122 |
+
tokenizer.get_command("<|observation|>")]
|
1123 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1124 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1125 |
+
if past_key_values is None:
|
1126 |
+
inputs = tokenizer.build_chat_input(query, history=history, role=role)
|
1127 |
+
else:
|
1128 |
+
inputs = tokenizer.build_chat_input(query, role=role)
|
1129 |
+
inputs = inputs.to(self.device)
|
1130 |
+
if past_key_values is not None:
|
1131 |
+
past_length = past_key_values[0][0].shape[0]
|
1132 |
+
if self.transformer.pre_seq_len is not None:
|
1133 |
+
past_length -= self.transformer.pre_seq_len
|
1134 |
+
inputs.position_ids += past_length
|
1135 |
+
attention_mask = inputs.attention_mask
|
1136 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1137 |
+
inputs['attention_mask'] = attention_mask
|
1138 |
+
history.append({"role": role, "content": query})
|
1139 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1140 |
+
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
1141 |
+
**gen_kwargs):
|
1142 |
+
if return_past_key_values:
|
1143 |
+
outputs, past_key_values = outputs
|
1144 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1145 |
+
response = tokenizer.decode(outputs)
|
1146 |
+
if response and response[-1] != "�":
|
1147 |
+
response, new_history = self.process_response(response, history)
|
1148 |
+
if return_past_key_values:
|
1149 |
+
yield response, new_history, past_key_values
|
1150 |
+
else:
|
1151 |
+
yield response, new_history
|
1152 |
+
|
1153 |
+
@torch.inference_mode()
|
1154 |
+
def stream_generate(
|
1155 |
+
self,
|
1156 |
+
input_ids,
|
1157 |
+
generation_config: Optional[GenerationConfig] = None,
|
1158 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
1159 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
1160 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
1161 |
+
return_past_key_values=False,
|
1162 |
+
**kwargs,
|
1163 |
+
):
|
1164 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
1165 |
+
|
1166 |
+
if generation_config is None:
|
1167 |
+
generation_config = self.generation_config
|
1168 |
+
generation_config = copy.deepcopy(generation_config)
|
1169 |
+
model_kwargs = generation_config.update(**kwargs)
|
1170 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
1171 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
1172 |
+
|
1173 |
+
if isinstance(eos_token_id, int):
|
1174 |
+
eos_token_id = [eos_token_id]
|
1175 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
1176 |
+
|
1177 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
1178 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
1179 |
+
warnings.warn(
|
1180 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
1181 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
1182 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
1183 |
+
UserWarning,
|
1184 |
+
)
|
1185 |
+
elif generation_config.max_new_tokens is not None:
|
1186 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
1187 |
+
if not has_default_max_length:
|
1188 |
+
logger.warn(
|
1189 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
1190 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
1191 |
+
"Please refer to the documentation for more information. "
|
1192 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
1193 |
+
UserWarning,
|
1194 |
+
)
|
1195 |
+
|
1196 |
+
if input_ids_seq_length >= generation_config.max_length:
|
1197 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
1198 |
+
logger.warning(
|
1199 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
1200 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
1201 |
+
" increasing `max_new_tokens`."
|
1202 |
+
)
|
1203 |
+
|
1204 |
+
# 2. Set generation parameters if not already defined
|
1205 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
1206 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
1207 |
+
|
1208 |
+
logits_processor = self._get_logits_processor(
|
1209 |
+
generation_config=generation_config,
|
1210 |
+
input_ids_seq_length=input_ids_seq_length,
|
1211 |
+
encoder_input_ids=input_ids,
|
1212 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
1213 |
+
logits_processor=logits_processor,
|
1214 |
+
)
|
1215 |
+
|
1216 |
+
stopping_criteria = self._get_stopping_criteria(
|
1217 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
1218 |
+
)
|
1219 |
+
logits_warper = self._get_logits_warper(generation_config)
|
1220 |
+
|
1221 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
1222 |
+
scores = None
|
1223 |
+
while True:
|
1224 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
1225 |
+
# forward pass to get next token
|
1226 |
+
outputs = self(
|
1227 |
+
**model_inputs,
|
1228 |
+
return_dict=True,
|
1229 |
+
output_attentions=False,
|
1230 |
+
output_hidden_states=False,
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
next_token_logits = outputs.logits[:, -1, :]
|
1234 |
+
|
1235 |
+
# pre-process distribution
|
1236 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
1237 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
1238 |
+
|
1239 |
+
# sample
|
1240 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
1241 |
+
if generation_config.do_sample:
|
1242 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
1243 |
+
else:
|
1244 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
1245 |
+
# update generated ids, model inputs, and length for next step
|
1246 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
1247 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
1248 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
1249 |
+
)
|
1250 |
+
unfinished_sequences = unfinished_sequences.mul(
|
1251 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
1252 |
+
)
|
1253 |
+
if return_past_key_values:
|
1254 |
+
yield input_ids, outputs.past_key_values
|
1255 |
+
else:
|
1256 |
+
yield input_ids
|
1257 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
1258 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
1259 |
+
break
|
1260 |
+
|
1261 |
+
def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
|
1262 |
+
if bits == 0:
|
1263 |
+
return
|
1264 |
+
|
1265 |
+
from .quantization import quantize
|
1266 |
+
|
1267 |
+
if self.quantized:
|
1268 |
+
logger.info("Already quantized.")
|
1269 |
+
return self
|
1270 |
+
|
1271 |
+
self.quantized = True
|
1272 |
+
|
1273 |
+
self.config.quantization_bit = bits
|
1274 |
+
|
1275 |
+
self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
|
1276 |
+
**kwargs)
|
1277 |
+
return self
|
1278 |
+
|
1279 |
+
|
1280 |
+
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
1281 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1282 |
+
super().__init__(config)
|
1283 |
+
|
1284 |
+
self.num_labels = config.num_labels
|
1285 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1286 |
+
|
1287 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
1288 |
+
if config.classifier_dropout is not None:
|
1289 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
1290 |
+
else:
|
1291 |
+
self.dropout = None
|
1292 |
+
self.config = config
|
1293 |
+
|
1294 |
+
if self.config.quantization_bit:
|
1295 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
1296 |
+
|
1297 |
+
def forward(
|
1298 |
+
self,
|
1299 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1300 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1301 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1302 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1303 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1304 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
1305 |
+
labels: Optional[torch.LongTensor] = None,
|
1306 |
+
use_cache: Optional[bool] = None,
|
1307 |
+
output_hidden_states: Optional[bool] = None,
|
1308 |
+
return_dict: Optional[bool] = None,
|
1309 |
+
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
1310 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1311 |
+
|
1312 |
+
transformer_outputs = self.transformer(
|
1313 |
+
input_ids=input_ids,
|
1314 |
+
position_ids=position_ids,
|
1315 |
+
attention_mask=attention_mask,
|
1316 |
+
full_attention_mask=full_attention_mask,
|
1317 |
+
past_key_values=past_key_values,
|
1318 |
+
inputs_embeds=inputs_embeds,
|
1319 |
+
use_cache=use_cache,
|
1320 |
+
output_hidden_states=output_hidden_states,
|
1321 |
+
return_dict=return_dict,
|
1322 |
+
)
|
1323 |
+
|
1324 |
+
hidden_states = transformer_outputs[0]
|
1325 |
+
pooled_hidden_states = hidden_states[-1]
|
1326 |
+
if self.dropout is not None:
|
1327 |
+
pooled_hidden_states = self.dropout(pooled_hidden_states)
|
1328 |
+
logits = self.classifier_head(pooled_hidden_states)
|
1329 |
+
|
1330 |
+
loss = None
|
1331 |
+
if labels is not None:
|
1332 |
+
if self.config.problem_type is None:
|
1333 |
+
if self.num_labels == 1:
|
1334 |
+
self.config.problem_type = "regression"
|
1335 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1336 |
+
self.config.problem_type = "single_label_classification"
|
1337 |
+
else:
|
1338 |
+
self.config.problem_type = "multi_label_classification"
|
1339 |
+
|
1340 |
+
if self.config.problem_type == "regression":
|
1341 |
+
loss_fct = MSELoss()
|
1342 |
+
if self.num_labels == 1:
|
1343 |
+
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
|
1344 |
+
else:
|
1345 |
+
loss = loss_fct(logits.float(), labels)
|
1346 |
+
elif self.config.problem_type == "single_label_classification":
|
1347 |
+
loss_fct = CrossEntropyLoss()
|
1348 |
+
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1349 |
+
elif self.config.problem_type == "multi_label_classification":
|
1350 |
+
loss_fct = BCEWithLogitsLoss()
|
1351 |
+
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
|
1352 |
+
|
1353 |
+
if not return_dict:
|
1354 |
+
output = (logits,) + transformer_outputs[1:]
|
1355 |
+
return ((loss,) + output) if loss is not None else output
|
1356 |
+
|
1357 |
+
return SequenceClassifierOutputWithPast(
|
1358 |
+
loss=loss,
|
1359 |
+
logits=logits,
|
1360 |
+
past_key_values=transformer_outputs.past_key_values,
|
1361 |
+
hidden_states=transformer_outputs.hidden_states,
|
1362 |
+
attentions=transformer_outputs.attentions,
|
1363 |
+
)
|
diffsynth/models/sd3_dit.py
ADDED
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from .svd_unet import TemporalTimesteps
|
4 |
+
from .tiler import TileWorker
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class PatchEmbed(torch.nn.Module):
|
9 |
+
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
|
10 |
+
super().__init__()
|
11 |
+
self.pos_embed_max_size = pos_embed_max_size
|
12 |
+
self.patch_size = patch_size
|
13 |
+
|
14 |
+
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
|
15 |
+
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
|
16 |
+
|
17 |
+
def cropped_pos_embed(self, height, width):
|
18 |
+
height = height // self.patch_size
|
19 |
+
width = width // self.patch_size
|
20 |
+
top = (self.pos_embed_max_size - height) // 2
|
21 |
+
left = (self.pos_embed_max_size - width) // 2
|
22 |
+
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
|
23 |
+
return spatial_pos_embed
|
24 |
+
|
25 |
+
def forward(self, latent):
|
26 |
+
height, width = latent.shape[-2:]
|
27 |
+
latent = self.proj(latent)
|
28 |
+
latent = latent.flatten(2).transpose(1, 2)
|
29 |
+
pos_embed = self.cropped_pos_embed(height, width)
|
30 |
+
return latent + pos_embed
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
class TimestepEmbeddings(torch.nn.Module):
|
35 |
+
def __init__(self, dim_in, dim_out):
|
36 |
+
super().__init__()
|
37 |
+
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
38 |
+
self.timestep_embedder = torch.nn.Sequential(
|
39 |
+
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, timestep, dtype):
|
43 |
+
time_emb = self.time_proj(timestep).to(dtype)
|
44 |
+
time_emb = self.timestep_embedder(time_emb)
|
45 |
+
return time_emb
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
class AdaLayerNorm(torch.nn.Module):
|
50 |
+
def __init__(self, dim, single=False):
|
51 |
+
super().__init__()
|
52 |
+
self.single = single
|
53 |
+
self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
|
54 |
+
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
55 |
+
|
56 |
+
def forward(self, x, emb):
|
57 |
+
emb = self.linear(torch.nn.functional.silu(emb))
|
58 |
+
if self.single:
|
59 |
+
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
60 |
+
x = self.norm(x) * (1 + scale) + shift
|
61 |
+
return x
|
62 |
+
else:
|
63 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
64 |
+
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
65 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
class JointAttention(torch.nn.Module):
|
70 |
+
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
|
71 |
+
super().__init__()
|
72 |
+
self.num_heads = num_heads
|
73 |
+
self.head_dim = head_dim
|
74 |
+
self.only_out_a = only_out_a
|
75 |
+
|
76 |
+
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
77 |
+
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
78 |
+
|
79 |
+
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
80 |
+
if not only_out_a:
|
81 |
+
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
82 |
+
|
83 |
+
def forward(self, hidden_states_a, hidden_states_b):
|
84 |
+
batch_size = hidden_states_a.shape[0]
|
85 |
+
|
86 |
+
qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
|
87 |
+
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
88 |
+
q, k, v = qkv.chunk(3, dim=1)
|
89 |
+
|
90 |
+
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
91 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
92 |
+
hidden_states = hidden_states.to(q.dtype)
|
93 |
+
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
|
94 |
+
hidden_states_a = self.a_to_out(hidden_states_a)
|
95 |
+
if self.only_out_a:
|
96 |
+
return hidden_states_a
|
97 |
+
else:
|
98 |
+
hidden_states_b = self.b_to_out(hidden_states_b)
|
99 |
+
return hidden_states_a, hidden_states_b
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
class JointTransformerBlock(torch.nn.Module):
|
104 |
+
def __init__(self, dim, num_attention_heads):
|
105 |
+
super().__init__()
|
106 |
+
self.norm1_a = AdaLayerNorm(dim)
|
107 |
+
self.norm1_b = AdaLayerNorm(dim)
|
108 |
+
|
109 |
+
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
110 |
+
|
111 |
+
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
112 |
+
self.ff_a = torch.nn.Sequential(
|
113 |
+
torch.nn.Linear(dim, dim*4),
|
114 |
+
torch.nn.GELU(approximate="tanh"),
|
115 |
+
torch.nn.Linear(dim*4, dim)
|
116 |
+
)
|
117 |
+
|
118 |
+
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
119 |
+
self.ff_b = torch.nn.Sequential(
|
120 |
+
torch.nn.Linear(dim, dim*4),
|
121 |
+
torch.nn.GELU(approximate="tanh"),
|
122 |
+
torch.nn.Linear(dim*4, dim)
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
def forward(self, hidden_states_a, hidden_states_b, temb):
|
127 |
+
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
128 |
+
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
129 |
+
|
130 |
+
# Attention
|
131 |
+
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
132 |
+
|
133 |
+
# Part A
|
134 |
+
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
135 |
+
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
136 |
+
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
137 |
+
|
138 |
+
# Part B
|
139 |
+
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
140 |
+
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
141 |
+
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
142 |
+
|
143 |
+
return hidden_states_a, hidden_states_b
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
class JointTransformerFinalBlock(torch.nn.Module):
|
148 |
+
def __init__(self, dim, num_attention_heads):
|
149 |
+
super().__init__()
|
150 |
+
self.norm1_a = AdaLayerNorm(dim)
|
151 |
+
self.norm1_b = AdaLayerNorm(dim, single=True)
|
152 |
+
|
153 |
+
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
|
154 |
+
|
155 |
+
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
156 |
+
self.ff_a = torch.nn.Sequential(
|
157 |
+
torch.nn.Linear(dim, dim*4),
|
158 |
+
torch.nn.GELU(approximate="tanh"),
|
159 |
+
torch.nn.Linear(dim*4, dim)
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def forward(self, hidden_states_a, hidden_states_b, temb):
|
164 |
+
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
165 |
+
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
|
166 |
+
|
167 |
+
# Attention
|
168 |
+
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
169 |
+
|
170 |
+
# Part A
|
171 |
+
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
172 |
+
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
173 |
+
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
174 |
+
|
175 |
+
return hidden_states_a, hidden_states_b
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
class SD3DiT(torch.nn.Module):
|
180 |
+
def __init__(self):
|
181 |
+
super().__init__()
|
182 |
+
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
|
183 |
+
self.time_embedder = TimestepEmbeddings(256, 1536)
|
184 |
+
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
|
185 |
+
self.context_embedder = torch.nn.Linear(4096, 1536)
|
186 |
+
self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
|
187 |
+
self.norm_out = AdaLayerNorm(1536, single=True)
|
188 |
+
self.proj_out = torch.nn.Linear(1536, 64)
|
189 |
+
|
190 |
+
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
|
191 |
+
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
192 |
+
hidden_states = TileWorker().tiled_forward(
|
193 |
+
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
|
194 |
+
hidden_states,
|
195 |
+
tile_size,
|
196 |
+
tile_stride,
|
197 |
+
tile_device=hidden_states.device,
|
198 |
+
tile_dtype=hidden_states.dtype
|
199 |
+
)
|
200 |
+
return hidden_states
|
201 |
+
|
202 |
+
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
|
203 |
+
if tiled:
|
204 |
+
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
|
205 |
+
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
206 |
+
prompt_emb = self.context_embedder(prompt_emb)
|
207 |
+
|
208 |
+
height, width = hidden_states.shape[-2:]
|
209 |
+
hidden_states = self.pos_embedder(hidden_states)
|
210 |
+
|
211 |
+
def create_custom_forward(module):
|
212 |
+
def custom_forward(*inputs):
|
213 |
+
return module(*inputs)
|
214 |
+
return custom_forward
|
215 |
+
|
216 |
+
for block in self.blocks:
|
217 |
+
if self.training and use_gradient_checkpointing:
|
218 |
+
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
219 |
+
create_custom_forward(block),
|
220 |
+
hidden_states, prompt_emb, conditioning,
|
221 |
+
use_reentrant=False,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
|
225 |
+
|
226 |
+
hidden_states = self.norm_out(hidden_states, conditioning)
|
227 |
+
hidden_states = self.proj_out(hidden_states)
|
228 |
+
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
229 |
+
return hidden_states
|
230 |
+
|
231 |
+
def state_dict_converter(self):
|
232 |
+
return SD3DiTStateDictConverter()
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
class SD3DiTStateDictConverter:
|
237 |
+
def __init__(self):
|
238 |
+
pass
|
239 |
+
|
240 |
+
def from_diffusers(self, state_dict):
|
241 |
+
rename_dict = {
|
242 |
+
"context_embedder": "context_embedder",
|
243 |
+
"pos_embed.pos_embed": "pos_embedder.pos_embed",
|
244 |
+
"pos_embed.proj": "pos_embedder.proj",
|
245 |
+
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
246 |
+
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
247 |
+
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
248 |
+
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
249 |
+
"norm_out.linear": "norm_out.linear",
|
250 |
+
"proj_out": "proj_out",
|
251 |
+
|
252 |
+
"norm1.linear": "norm1_a.linear",
|
253 |
+
"norm1_context.linear": "norm1_b.linear",
|
254 |
+
"attn.to_q": "attn.a_to_q",
|
255 |
+
"attn.to_k": "attn.a_to_k",
|
256 |
+
"attn.to_v": "attn.a_to_v",
|
257 |
+
"attn.to_out.0": "attn.a_to_out",
|
258 |
+
"attn.add_q_proj": "attn.b_to_q",
|
259 |
+
"attn.add_k_proj": "attn.b_to_k",
|
260 |
+
"attn.add_v_proj": "attn.b_to_v",
|
261 |
+
"attn.to_add_out": "attn.b_to_out",
|
262 |
+
"ff.net.0.proj": "ff_a.0",
|
263 |
+
"ff.net.2": "ff_a.2",
|
264 |
+
"ff_context.net.0.proj": "ff_b.0",
|
265 |
+
"ff_context.net.2": "ff_b.2",
|
266 |
+
}
|
267 |
+
state_dict_ = {}
|
268 |
+
for name, param in state_dict.items():
|
269 |
+
if name in rename_dict:
|
270 |
+
if name == "pos_embed.pos_embed":
|
271 |
+
param = param.reshape((1, 192, 192, 1536))
|
272 |
+
state_dict_[rename_dict[name]] = param
|
273 |
+
elif name.endswith(".weight") or name.endswith(".bias"):
|
274 |
+
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
275 |
+
prefix = name[:-len(suffix)]
|
276 |
+
if prefix in rename_dict:
|
277 |
+
state_dict_[rename_dict[prefix] + suffix] = param
|
278 |
+
elif prefix.startswith("transformer_blocks."):
|
279 |
+
names = prefix.split(".")
|
280 |
+
names[0] = "blocks"
|
281 |
+
middle = ".".join(names[2:])
|
282 |
+
if middle in rename_dict:
|
283 |
+
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
284 |
+
state_dict_[name_] = param
|
285 |
+
return state_dict_
|
286 |
+
|
287 |
+
def from_civitai(self, state_dict):
|
288 |
+
rename_dict = {
|
289 |
+
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
290 |
+
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
291 |
+
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
292 |
+
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
293 |
+
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
|
294 |
+
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
|
295 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
|
296 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
|
297 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
|
298 |
+
"model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
|
299 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
|
300 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
|
301 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
|
302 |
+
"model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
|
303 |
+
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
|
304 |
+
"model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
|
305 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
|
306 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
|
307 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
|
308 |
+
"model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
|
309 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
|
310 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
|
311 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
|
312 |
+
"model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
|
313 |
+
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
|
314 |
+
"model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
|
315 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
|
316 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
|
317 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
|
318 |
+
"model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
|
319 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
|
320 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
|
321 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
|
322 |
+
"model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
|
323 |
+
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
|
324 |
+
"model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
|
325 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
|
326 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
|
327 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
|
328 |
+
"model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
|
329 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
|
330 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
|
331 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
|
332 |
+
"model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
|
333 |
+
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
|
334 |
+
"model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
|
335 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
|
336 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
|
337 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
|
338 |
+
"model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
|
339 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
|
340 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
|
341 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
|
342 |
+
"model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
|
343 |
+
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
|
344 |
+
"model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
|
345 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
|
346 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
|
347 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
|
348 |
+
"model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
|
349 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
|
350 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
|
351 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
|
352 |
+
"model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
|
353 |
+
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
|
354 |
+
"model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
|
355 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
|
356 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
|
357 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
|
358 |
+
"model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
|
359 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
|
360 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
|
361 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
|
362 |
+
"model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
|
363 |
+
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
|
364 |
+
"model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
|
365 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
|
366 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
|
367 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
|
368 |
+
"model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
|
369 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
|
370 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
|
371 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
|
372 |
+
"model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
|
373 |
+
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
|
374 |
+
"model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
|
375 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
|
376 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
|
377 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
|
378 |
+
"model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
|
379 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
|
380 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
|
381 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
|
382 |
+
"model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
|
383 |
+
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
|
384 |
+
"model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
|
385 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
|
386 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
|
387 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
|
388 |
+
"model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
|
389 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
|
390 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
|
391 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
|
392 |
+
"model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
|
393 |
+
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
|
394 |
+
"model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
|
395 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
|
396 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
|
397 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
|
398 |
+
"model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
|
399 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
|
400 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
|
401 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
|
402 |
+
"model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
|
403 |
+
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
|
404 |
+
"model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
|
405 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
|
406 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
|
407 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
|
408 |
+
"model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
|
409 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
|
410 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
|
411 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
|
412 |
+
"model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
|
413 |
+
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
|
414 |
+
"model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
|
415 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
|
416 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
|
417 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
|
418 |
+
"model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
|
419 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
|
420 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
|
421 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
|
422 |
+
"model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
|
423 |
+
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
|
424 |
+
"model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
|
425 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
|
426 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
|
427 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
|
428 |
+
"model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
|
429 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
|
430 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
|
431 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
|
432 |
+
"model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
|
433 |
+
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
|
434 |
+
"model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
|
435 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
|
436 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
|
437 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
|
438 |
+
"model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
|
439 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
|
440 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
|
441 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
|
442 |
+
"model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
|
443 |
+
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
|
444 |
+
"model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
|
445 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
|
446 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
|
447 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
|
448 |
+
"model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
|
449 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
|
450 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
|
451 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
|
452 |
+
"model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
|
453 |
+
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
|
454 |
+
"model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
|
455 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
|
456 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
|
457 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
|
458 |
+
"model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
|
459 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
|
460 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
|
461 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
|
462 |
+
"model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
|
463 |
+
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
|
464 |
+
"model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
|
465 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
|
466 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
|
467 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
|
468 |
+
"model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
|
469 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
|
470 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
|
471 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
|
472 |
+
"model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
|
473 |
+
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
|
474 |
+
"model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
|
475 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
|
476 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
|
477 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
|
478 |
+
"model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
|
479 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
|
480 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
|
481 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
|
482 |
+
"model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
|
483 |
+
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
|
484 |
+
"model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
|
485 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
|
486 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
|
487 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
|
488 |
+
"model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
|
489 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
|
490 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
|
491 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
|
492 |
+
"model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
|
493 |
+
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
|
494 |
+
"model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
|
495 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
|
496 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
|
497 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
|
498 |
+
"model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
|
499 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
|
500 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
|
501 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
|
502 |
+
"model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
|
503 |
+
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
|
504 |
+
"model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
|
505 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
|
506 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
|
507 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
|
508 |
+
"model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
|
509 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
|
510 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
|
511 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
|
512 |
+
"model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
|
513 |
+
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
|
514 |
+
"model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
|
515 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
|
516 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
|
517 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
|
518 |
+
"model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
|
519 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
|
520 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
|
521 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
|
522 |
+
"model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
|
523 |
+
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
|
524 |
+
"model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
|
525 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
|
526 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
|
527 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
|
528 |
+
"model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
|
529 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
|
530 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
|
531 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
|
532 |
+
"model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
|
533 |
+
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
|
534 |
+
"model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
|
535 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
|
536 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
|
537 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
|
538 |
+
"model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
|
539 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
|
540 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
|
541 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
|
542 |
+
"model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
|
543 |
+
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
|
544 |
+
"model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
|
545 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
|
546 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
|
547 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
|
548 |
+
"model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
|
549 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
|
550 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
|
551 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
|
552 |
+
"model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
|
553 |
+
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
|
554 |
+
"model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
|
555 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
|
556 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
|
557 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
|
558 |
+
"model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
|
559 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
|
560 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
|
561 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
|
562 |
+
"model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
|
563 |
+
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
|
564 |
+
"model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
|
565 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
|
566 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
|
567 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
|
568 |
+
"model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
|
569 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
|
570 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
|
571 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
|
572 |
+
"model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
|
573 |
+
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
|
574 |
+
"model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
|
575 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
|
576 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
|
577 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
|
578 |
+
"model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
|
579 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
|
580 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
|
581 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
|
582 |
+
"model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
|
583 |
+
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
|
584 |
+
"model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
|
585 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
|
586 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
|
587 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
|
588 |
+
"model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
|
589 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
|
590 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
|
591 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
|
592 |
+
"model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
|
593 |
+
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
|
594 |
+
"model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
|
595 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
|
596 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
|
597 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
|
598 |
+
"model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
|
599 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
|
600 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
|
601 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
|
602 |
+
"model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
|
603 |
+
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
|
604 |
+
"model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
|
605 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
|
606 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
|
607 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
|
608 |
+
"model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
|
609 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
|
610 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
|
611 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
|
612 |
+
"model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
|
613 |
+
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
|
614 |
+
"model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
|
615 |
+
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
|
616 |
+
"model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
|
617 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
|
618 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
|
619 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
|
620 |
+
"model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
|
621 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
|
622 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
|
623 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
|
624 |
+
"model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
|
625 |
+
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
|
626 |
+
"model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
|
627 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
|
628 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
|
629 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
|
630 |
+
"model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
|
631 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
|
632 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
|
633 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
|
634 |
+
"model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
|
635 |
+
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
|
636 |
+
"model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
|
637 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
|
638 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
|
639 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
|
640 |
+
"model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
|
641 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
|
642 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
|
643 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
|
644 |
+
"model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
|
645 |
+
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
|
646 |
+
"model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
|
647 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
|
648 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
|
649 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
|
650 |
+
"model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
|
651 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
|
652 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
|
653 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
|
654 |
+
"model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
|
655 |
+
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
|
656 |
+
"model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
|
657 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
|
658 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
|
659 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
|
660 |
+
"model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
|
661 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
|
662 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
|
663 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
|
664 |
+
"model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
|
665 |
+
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
|
666 |
+
"model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
|
667 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
|
668 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
|
669 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
|
670 |
+
"model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
|
671 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
|
672 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
|
673 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
|
674 |
+
"model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
|
675 |
+
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
|
676 |
+
"model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
|
677 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
|
678 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
|
679 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
|
680 |
+
"model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
|
681 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
|
682 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
|
683 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
|
684 |
+
"model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
|
685 |
+
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
|
686 |
+
"model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
|
687 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
|
688 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
|
689 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
|
690 |
+
"model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
|
691 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
|
692 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
|
693 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
|
694 |
+
"model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
|
695 |
+
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
|
696 |
+
"model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
|
697 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
|
698 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
|
699 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
|
700 |
+
"model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
|
701 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
|
702 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
|
703 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
|
704 |
+
"model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
|
705 |
+
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
|
706 |
+
"model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
|
707 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
|
708 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
|
709 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
|
710 |
+
"model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
|
711 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
|
712 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
|
713 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
|
714 |
+
"model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
|
715 |
+
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
|
716 |
+
"model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
|
717 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
|
718 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
|
719 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
|
720 |
+
"model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
|
721 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
|
722 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
|
723 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
|
724 |
+
"model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
|
725 |
+
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
|
726 |
+
"model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
|
727 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
|
728 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
|
729 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
|
730 |
+
"model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
|
731 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
|
732 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
|
733 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
|
734 |
+
"model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
|
735 |
+
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
|
736 |
+
"model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
|
737 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
|
738 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
|
739 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
|
740 |
+
"model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
|
741 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
|
742 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
|
743 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
|
744 |
+
"model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
|
745 |
+
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
|
746 |
+
"model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
|
747 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
|
748 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
|
749 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
|
750 |
+
"model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
|
751 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
|
752 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
|
753 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
|
754 |
+
"model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
|
755 |
+
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
|
756 |
+
"model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
|
757 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
|
758 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
|
759 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
|
760 |
+
"model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
|
761 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
|
762 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
|
763 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
|
764 |
+
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
|
765 |
+
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
766 |
+
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
767 |
+
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
768 |
+
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
769 |
+
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
770 |
+
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
771 |
+
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
772 |
+
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
773 |
+
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
774 |
+
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
775 |
+
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
776 |
+
|
777 |
+
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
778 |
+
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
779 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
780 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
781 |
+
}
|
782 |
+
state_dict_ = {}
|
783 |
+
for name in state_dict:
|
784 |
+
if name in rename_dict:
|
785 |
+
param = state_dict[name]
|
786 |
+
if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
|
787 |
+
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
788 |
+
elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
|
789 |
+
param = torch.concat([param[1536:], param[:1536]], axis=0)
|
790 |
+
elif name == "model.diffusion_model.pos_embed":
|
791 |
+
param = param.reshape((1, 192, 192, 1536))
|
792 |
+
if isinstance(rename_dict[name], str):
|
793 |
+
state_dict_[rename_dict[name]] = param
|
794 |
+
else:
|
795 |
+
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
796 |
+
state_dict_[name_] = param
|
797 |
+
return state_dict_
|
diffsynth/models/sd3_text_encoder.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/sd3_vae_decoder.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
3 |
+
from .sd_unet import ResnetBlock, UpSampler
|
4 |
+
from .tiler import TileWorker
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class SD3VAEDecoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 1.5305 # Different from SD 1.x
|
12 |
+
self.shift_factor = 0.0609 # Different from SD 1.x
|
13 |
+
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# UNetMidBlock2D
|
17 |
+
ResnetBlock(512, 512, eps=1e-6),
|
18 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
19 |
+
ResnetBlock(512, 512, eps=1e-6),
|
20 |
+
# UpDecoderBlock2D
|
21 |
+
ResnetBlock(512, 512, eps=1e-6),
|
22 |
+
ResnetBlock(512, 512, eps=1e-6),
|
23 |
+
ResnetBlock(512, 512, eps=1e-6),
|
24 |
+
UpSampler(512),
|
25 |
+
# UpDecoderBlock2D
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
ResnetBlock(512, 512, eps=1e-6),
|
28 |
+
ResnetBlock(512, 512, eps=1e-6),
|
29 |
+
UpSampler(512),
|
30 |
+
# UpDecoderBlock2D
|
31 |
+
ResnetBlock(512, 256, eps=1e-6),
|
32 |
+
ResnetBlock(256, 256, eps=1e-6),
|
33 |
+
ResnetBlock(256, 256, eps=1e-6),
|
34 |
+
UpSampler(256),
|
35 |
+
# UpDecoderBlock2D
|
36 |
+
ResnetBlock(256, 128, eps=1e-6),
|
37 |
+
ResnetBlock(128, 128, eps=1e-6),
|
38 |
+
ResnetBlock(128, 128, eps=1e-6),
|
39 |
+
])
|
40 |
+
|
41 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
42 |
+
self.conv_act = torch.nn.SiLU()
|
43 |
+
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
44 |
+
|
45 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
46 |
+
hidden_states = TileWorker().tiled_forward(
|
47 |
+
lambda x: self.forward(x),
|
48 |
+
sample,
|
49 |
+
tile_size,
|
50 |
+
tile_stride,
|
51 |
+
tile_device=sample.device,
|
52 |
+
tile_dtype=sample.dtype
|
53 |
+
)
|
54 |
+
return hidden_states
|
55 |
+
|
56 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
57 |
+
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
58 |
+
if tiled:
|
59 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
60 |
+
|
61 |
+
# 1. pre-process
|
62 |
+
hidden_states = sample / self.scaling_factor + self.shift_factor
|
63 |
+
hidden_states = self.conv_in(hidden_states)
|
64 |
+
time_emb = None
|
65 |
+
text_emb = None
|
66 |
+
res_stack = None
|
67 |
+
|
68 |
+
# 2. blocks
|
69 |
+
for i, block in enumerate(self.blocks):
|
70 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
71 |
+
|
72 |
+
# 3. output
|
73 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
74 |
+
hidden_states = self.conv_act(hidden_states)
|
75 |
+
hidden_states = self.conv_out(hidden_states)
|
76 |
+
|
77 |
+
return hidden_states
|
78 |
+
|
79 |
+
def state_dict_converter(self):
|
80 |
+
return SDVAEDecoderStateDictConverter()
|
diffsynth/models/sd3_vae_encoder.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_unet import ResnetBlock, DownSampler
|
3 |
+
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
4 |
+
from .tiler import TileWorker
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class SD3VAEEncoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 1.5305 # Different from SD 1.x
|
12 |
+
self.shift_factor = 0.0609 # Different from SD 1.x
|
13 |
+
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# DownEncoderBlock2D
|
17 |
+
ResnetBlock(128, 128, eps=1e-6),
|
18 |
+
ResnetBlock(128, 128, eps=1e-6),
|
19 |
+
DownSampler(128, padding=0, extra_padding=True),
|
20 |
+
# DownEncoderBlock2D
|
21 |
+
ResnetBlock(128, 256, eps=1e-6),
|
22 |
+
ResnetBlock(256, 256, eps=1e-6),
|
23 |
+
DownSampler(256, padding=0, extra_padding=True),
|
24 |
+
# DownEncoderBlock2D
|
25 |
+
ResnetBlock(256, 512, eps=1e-6),
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
DownSampler(512, padding=0, extra_padding=True),
|
28 |
+
# DownEncoderBlock2D
|
29 |
+
ResnetBlock(512, 512, eps=1e-6),
|
30 |
+
ResnetBlock(512, 512, eps=1e-6),
|
31 |
+
# UNetMidBlock2D
|
32 |
+
ResnetBlock(512, 512, eps=1e-6),
|
33 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
34 |
+
ResnetBlock(512, 512, eps=1e-6),
|
35 |
+
])
|
36 |
+
|
37 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
38 |
+
self.conv_act = torch.nn.SiLU()
|
39 |
+
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
40 |
+
|
41 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
42 |
+
hidden_states = TileWorker().tiled_forward(
|
43 |
+
lambda x: self.forward(x),
|
44 |
+
sample,
|
45 |
+
tile_size,
|
46 |
+
tile_stride,
|
47 |
+
tile_device=sample.device,
|
48 |
+
tile_dtype=sample.dtype
|
49 |
+
)
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
53 |
+
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
54 |
+
if tiled:
|
55 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
56 |
+
|
57 |
+
# 1. pre-process
|
58 |
+
hidden_states = self.conv_in(sample)
|
59 |
+
time_emb = None
|
60 |
+
text_emb = None
|
61 |
+
res_stack = None
|
62 |
+
|
63 |
+
# 2. blocks
|
64 |
+
for i, block in enumerate(self.blocks):
|
65 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
66 |
+
|
67 |
+
# 3. output
|
68 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
69 |
+
hidden_states = self.conv_act(hidden_states)
|
70 |
+
hidden_states = self.conv_out(hidden_states)
|
71 |
+
hidden_states = hidden_states[:, :16]
|
72 |
+
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
73 |
+
|
74 |
+
return hidden_states
|
75 |
+
|
76 |
+
def encode_video(self, sample, batch_size=8):
|
77 |
+
B = sample.shape[0]
|
78 |
+
hidden_states = []
|
79 |
+
|
80 |
+
for i in range(0, sample.shape[2], batch_size):
|
81 |
+
|
82 |
+
j = min(i + batch_size, sample.shape[2])
|
83 |
+
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
84 |
+
|
85 |
+
hidden_states_batch = self(sample_batch)
|
86 |
+
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
87 |
+
|
88 |
+
hidden_states.append(hidden_states_batch)
|
89 |
+
|
90 |
+
hidden_states = torch.concat(hidden_states, dim=2)
|
91 |
+
return hidden_states
|
92 |
+
|
93 |
+
def state_dict_converter(self):
|
94 |
+
return SDVAEEncoderStateDictConverter()
|
diffsynth/models/sd_controlnet.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
3 |
+
from .tiler import TileWorker
|
4 |
+
|
5 |
+
|
6 |
+
class ControlNetConditioningLayer(torch.nn.Module):
|
7 |
+
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
|
8 |
+
super().__init__()
|
9 |
+
self.blocks = torch.nn.ModuleList([])
|
10 |
+
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
|
11 |
+
self.blocks.append(torch.nn.SiLU())
|
12 |
+
for i in range(1, len(channels) - 2):
|
13 |
+
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
|
14 |
+
self.blocks.append(torch.nn.SiLU())
|
15 |
+
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
|
16 |
+
self.blocks.append(torch.nn.SiLU())
|
17 |
+
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
|
18 |
+
|
19 |
+
def forward(self, conditioning):
|
20 |
+
for block in self.blocks:
|
21 |
+
conditioning = block(conditioning)
|
22 |
+
return conditioning
|
23 |
+
|
24 |
+
|
25 |
+
class SDControlNet(torch.nn.Module):
|
26 |
+
def __init__(self, global_pool=False):
|
27 |
+
super().__init__()
|
28 |
+
self.time_proj = Timesteps(320)
|
29 |
+
self.time_embedding = torch.nn.Sequential(
|
30 |
+
torch.nn.Linear(320, 1280),
|
31 |
+
torch.nn.SiLU(),
|
32 |
+
torch.nn.Linear(1280, 1280)
|
33 |
+
)
|
34 |
+
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
35 |
+
|
36 |
+
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
37 |
+
|
38 |
+
self.blocks = torch.nn.ModuleList([
|
39 |
+
# CrossAttnDownBlock2D
|
40 |
+
ResnetBlock(320, 320, 1280),
|
41 |
+
AttentionBlock(8, 40, 320, 1, 768),
|
42 |
+
PushBlock(),
|
43 |
+
ResnetBlock(320, 320, 1280),
|
44 |
+
AttentionBlock(8, 40, 320, 1, 768),
|
45 |
+
PushBlock(),
|
46 |
+
DownSampler(320),
|
47 |
+
PushBlock(),
|
48 |
+
# CrossAttnDownBlock2D
|
49 |
+
ResnetBlock(320, 640, 1280),
|
50 |
+
AttentionBlock(8, 80, 640, 1, 768),
|
51 |
+
PushBlock(),
|
52 |
+
ResnetBlock(640, 640, 1280),
|
53 |
+
AttentionBlock(8, 80, 640, 1, 768),
|
54 |
+
PushBlock(),
|
55 |
+
DownSampler(640),
|
56 |
+
PushBlock(),
|
57 |
+
# CrossAttnDownBlock2D
|
58 |
+
ResnetBlock(640, 1280, 1280),
|
59 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
60 |
+
PushBlock(),
|
61 |
+
ResnetBlock(1280, 1280, 1280),
|
62 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
63 |
+
PushBlock(),
|
64 |
+
DownSampler(1280),
|
65 |
+
PushBlock(),
|
66 |
+
# DownBlock2D
|
67 |
+
ResnetBlock(1280, 1280, 1280),
|
68 |
+
PushBlock(),
|
69 |
+
ResnetBlock(1280, 1280, 1280),
|
70 |
+
PushBlock(),
|
71 |
+
# UNetMidBlock2DCrossAttn
|
72 |
+
ResnetBlock(1280, 1280, 1280),
|
73 |
+
AttentionBlock(8, 160, 1280, 1, 768),
|
74 |
+
ResnetBlock(1280, 1280, 1280),
|
75 |
+
PushBlock()
|
76 |
+
])
|
77 |
+
|
78 |
+
self.controlnet_blocks = torch.nn.ModuleList([
|
79 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
80 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
81 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
82 |
+
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
83 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
84 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
85 |
+
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
86 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
87 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
88 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
89 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
90 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
91 |
+
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
92 |
+
])
|
93 |
+
|
94 |
+
self.global_pool = global_pool
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self,
|
98 |
+
sample, timestep, encoder_hidden_states, conditioning,
|
99 |
+
tiled=False, tile_size=64, tile_stride=32,
|
100 |
+
):
|
101 |
+
# 1. time
|
102 |
+
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
103 |
+
time_emb = self.time_embedding(time_emb)
|
104 |
+
time_emb = time_emb.repeat(sample.shape[0], 1)
|
105 |
+
|
106 |
+
# 2. pre-process
|
107 |
+
height, width = sample.shape[2], sample.shape[3]
|
108 |
+
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
109 |
+
text_emb = encoder_hidden_states
|
110 |
+
res_stack = [hidden_states]
|
111 |
+
|
112 |
+
# 3. blocks
|
113 |
+
for i, block in enumerate(self.blocks):
|
114 |
+
if tiled and not isinstance(block, PushBlock):
|
115 |
+
_, _, inter_height, _ = hidden_states.shape
|
116 |
+
resize_scale = inter_height / height
|
117 |
+
hidden_states = TileWorker().tiled_forward(
|
118 |
+
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
119 |
+
hidden_states,
|
120 |
+
int(tile_size * resize_scale),
|
121 |
+
int(tile_stride * resize_scale),
|
122 |
+
tile_device=hidden_states.device,
|
123 |
+
tile_dtype=hidden_states.dtype
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
127 |
+
|
128 |
+
# 4. ControlNet blocks
|
129 |
+
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
130 |
+
|
131 |
+
# pool
|
132 |
+
if self.global_pool:
|
133 |
+
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
134 |
+
|
135 |
+
return controlnet_res_stack
|
136 |
+
|
137 |
+
def state_dict_converter(self):
|
138 |
+
return SDControlNetStateDictConverter()
|
139 |
+
|
140 |
+
|
141 |
+
class SDControlNetStateDictConverter:
|
142 |
+
def __init__(self):
|
143 |
+
pass
|
144 |
+
|
145 |
+
def from_diffusers(self, state_dict):
|
146 |
+
# architecture
|
147 |
+
block_types = [
|
148 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
149 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
150 |
+
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
151 |
+
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
|
152 |
+
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
|
153 |
+
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
|
154 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
155 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
156 |
+
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
|
157 |
+
]
|
158 |
+
|
159 |
+
# controlnet_rename_dict
|
160 |
+
controlnet_rename_dict = {
|
161 |
+
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
162 |
+
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
163 |
+
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
164 |
+
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
165 |
+
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
166 |
+
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
167 |
+
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
168 |
+
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
169 |
+
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
170 |
+
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
171 |
+
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
172 |
+
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
173 |
+
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
174 |
+
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
175 |
+
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
176 |
+
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
177 |
+
}
|
178 |
+
|
179 |
+
# Rename each parameter
|
180 |
+
name_list = sorted([name for name in state_dict])
|
181 |
+
rename_dict = {}
|
182 |
+
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
183 |
+
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
184 |
+
for name in name_list:
|
185 |
+
names = name.split(".")
|
186 |
+
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
187 |
+
pass
|
188 |
+
elif name in controlnet_rename_dict:
|
189 |
+
names = controlnet_rename_dict[name].split(".")
|
190 |
+
elif names[0] == "controlnet_down_blocks":
|
191 |
+
names[0] = "controlnet_blocks"
|
192 |
+
elif names[0] == "controlnet_mid_block":
|
193 |
+
names = ["controlnet_blocks", "12", names[-1]]
|
194 |
+
elif names[0] in ["time_embedding", "add_embedding"]:
|
195 |
+
if names[0] == "add_embedding":
|
196 |
+
names[0] = "add_time_embedding"
|
197 |
+
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
198 |
+
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
199 |
+
if names[0] == "mid_block":
|
200 |
+
names.insert(1, "0")
|
201 |
+
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
202 |
+
block_type_with_id = ".".join(names[:4])
|
203 |
+
if block_type_with_id != last_block_type_with_id[block_type]:
|
204 |
+
block_id[block_type] += 1
|
205 |
+
last_block_type_with_id[block_type] = block_type_with_id
|
206 |
+
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
207 |
+
block_id[block_type] += 1
|
208 |
+
block_type_with_id = ".".join(names[:4])
|
209 |
+
names = ["blocks", str(block_id[block_type])] + names[4:]
|
210 |
+
if "ff" in names:
|
211 |
+
ff_index = names.index("ff")
|
212 |
+
component = ".".join(names[ff_index:ff_index+3])
|
213 |
+
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
214 |
+
names = names[:ff_index] + [component] + names[ff_index+3:]
|
215 |
+
if "to_out" in names:
|
216 |
+
names.pop(names.index("to_out") + 1)
|
217 |
+
else:
|
218 |
+
raise ValueError(f"Unknown parameters: {name}")
|
219 |
+
rename_dict[name] = ".".join(names)
|
220 |
+
|
221 |
+
# Convert state_dict
|
222 |
+
state_dict_ = {}
|
223 |
+
for name, param in state_dict.items():
|
224 |
+
if ".proj_in." in name or ".proj_out." in name:
|
225 |
+
param = param.squeeze()
|
226 |
+
if rename_dict[name] in [
|
227 |
+
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
|
228 |
+
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
|
229 |
+
]:
|
230 |
+
continue
|
231 |
+
state_dict_[rename_dict[name]] = param
|
232 |
+
return state_dict_
|
233 |
+
|
234 |
+
def from_civitai(self, state_dict):
|
235 |
+
if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
|
236 |
+
# For controlnets in diffusers format
|
237 |
+
return self.from_diffusers(state_dict)
|
238 |
+
rename_dict = {
|
239 |
+
"control_model.time_embed.0.weight": "time_embedding.0.weight",
|
240 |
+
"control_model.time_embed.0.bias": "time_embedding.0.bias",
|
241 |
+
"control_model.time_embed.2.weight": "time_embedding.2.weight",
|
242 |
+
"control_model.time_embed.2.bias": "time_embedding.2.bias",
|
243 |
+
"control_model.input_blocks.0.0.weight": "conv_in.weight",
|
244 |
+
"control_model.input_blocks.0.0.bias": "conv_in.bias",
|
245 |
+
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
246 |
+
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
247 |
+
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
248 |
+
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
249 |
+
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
250 |
+
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
251 |
+
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
252 |
+
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
253 |
+
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
254 |
+
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
255 |
+
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
256 |
+
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
257 |
+
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
258 |
+
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
259 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
260 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
261 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
262 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
263 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
264 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
265 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
266 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
267 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
268 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
269 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
270 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
271 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
272 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
273 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
274 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
275 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
276 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
277 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
278 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
279 |
+
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
280 |
+
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
281 |
+
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
282 |
+
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
283 |
+
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
284 |
+
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
285 |
+
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
286 |
+
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
287 |
+
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
288 |
+
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
289 |
+
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
290 |
+
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
291 |
+
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
292 |
+
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
293 |
+
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
294 |
+
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
295 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
296 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
297 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
298 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
299 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
300 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
301 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
302 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
303 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
304 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
305 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
306 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
307 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
308 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
309 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
310 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
311 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
312 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
313 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
314 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
315 |
+
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
316 |
+
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
317 |
+
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
318 |
+
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
319 |
+
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
320 |
+
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
321 |
+
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
322 |
+
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
323 |
+
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
324 |
+
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
325 |
+
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
326 |
+
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
327 |
+
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
328 |
+
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
329 |
+
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
330 |
+
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
331 |
+
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
332 |
+
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
333 |
+
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
334 |
+
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
335 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
336 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
337 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
338 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
339 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
340 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
341 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
342 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
343 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
344 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
345 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
346 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
347 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
348 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
349 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
350 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
351 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
352 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
353 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
354 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
355 |
+
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
356 |
+
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
357 |
+
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
358 |
+
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
359 |
+
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
360 |
+
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
361 |
+
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
362 |
+
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
363 |
+
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
364 |
+
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
365 |
+
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
366 |
+
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
367 |
+
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
368 |
+
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
369 |
+
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
370 |
+
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
371 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
372 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
373 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
374 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
375 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
376 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
377 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
378 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
379 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
380 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
381 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
382 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
383 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
384 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
385 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
386 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
387 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
388 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
389 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
390 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
391 |
+
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
392 |
+
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
393 |
+
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
394 |
+
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
395 |
+
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
396 |
+
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
397 |
+
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
398 |
+
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
399 |
+
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
400 |
+
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
401 |
+
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
402 |
+
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
403 |
+
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
404 |
+
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
405 |
+
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
406 |
+
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
407 |
+
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
408 |
+
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
409 |
+
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
410 |
+
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
411 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
412 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
413 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
414 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
415 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
416 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
417 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
418 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
419 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
420 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
421 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
422 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
423 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
424 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
425 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
426 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
427 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
428 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
429 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
430 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
431 |
+
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
432 |
+
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
433 |
+
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
434 |
+
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
435 |
+
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
436 |
+
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
437 |
+
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
438 |
+
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
439 |
+
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
440 |
+
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
441 |
+
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
442 |
+
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
443 |
+
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
444 |
+
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
445 |
+
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
446 |
+
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
447 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
448 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
449 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
450 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
451 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
452 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
453 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
454 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
455 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
456 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
457 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
458 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
459 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
460 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
461 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
462 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
463 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
464 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
465 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
466 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
467 |
+
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
468 |
+
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
469 |
+
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
470 |
+
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
471 |
+
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
472 |
+
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
473 |
+
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
474 |
+
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
475 |
+
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
476 |
+
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
477 |
+
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
478 |
+
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
479 |
+
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
480 |
+
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
481 |
+
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
482 |
+
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
483 |
+
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
484 |
+
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
485 |
+
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
486 |
+
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
487 |
+
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
488 |
+
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
489 |
+
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
490 |
+
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
491 |
+
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
492 |
+
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
493 |
+
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
494 |
+
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
495 |
+
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
496 |
+
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
497 |
+
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
498 |
+
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
499 |
+
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
500 |
+
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
501 |
+
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
502 |
+
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
503 |
+
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
504 |
+
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
505 |
+
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
506 |
+
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
507 |
+
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
508 |
+
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
509 |
+
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
510 |
+
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
511 |
+
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
512 |
+
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
513 |
+
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
514 |
+
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
515 |
+
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
516 |
+
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
517 |
+
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
518 |
+
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
519 |
+
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
520 |
+
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
521 |
+
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
522 |
+
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
523 |
+
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
524 |
+
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
525 |
+
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
526 |
+
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
527 |
+
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
528 |
+
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
529 |
+
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
530 |
+
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
531 |
+
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
532 |
+
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
533 |
+
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
534 |
+
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
535 |
+
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
536 |
+
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
537 |
+
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
538 |
+
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
539 |
+
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
540 |
+
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
541 |
+
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
542 |
+
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
543 |
+
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
544 |
+
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
545 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
546 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
547 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
548 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
549 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
550 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
551 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
552 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
553 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
554 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
555 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
556 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
557 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
558 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
559 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
560 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
561 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
562 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
563 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
564 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
565 |
+
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
566 |
+
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
567 |
+
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
568 |
+
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
569 |
+
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
570 |
+
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
571 |
+
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
572 |
+
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
573 |
+
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
574 |
+
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
575 |
+
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
576 |
+
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
577 |
+
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
578 |
+
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
579 |
+
}
|
580 |
+
state_dict_ = {}
|
581 |
+
for name in state_dict:
|
582 |
+
if name in rename_dict:
|
583 |
+
param = state_dict[name]
|
584 |
+
if ".proj_in." in name or ".proj_out." in name:
|
585 |
+
param = param.squeeze()
|
586 |
+
state_dict_[rename_dict[name]] = param
|
587 |
+
return state_dict_
|
diffsynth/models/sd_ipadapter.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .svd_image_encoder import SVDImageEncoder
|
2 |
+
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
|
3 |
+
from transformers import CLIPImageProcessor
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
self.image_processor = CLIPImageProcessor()
|
11 |
+
|
12 |
+
def forward(self, image):
|
13 |
+
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
|
14 |
+
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
|
15 |
+
return super().forward(pixel_values)
|
16 |
+
|
17 |
+
|
18 |
+
class SDIpAdapter(torch.nn.Module):
|
19 |
+
def __init__(self):
|
20 |
+
super().__init__()
|
21 |
+
shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
|
22 |
+
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
|
23 |
+
self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
|
24 |
+
self.set_full_adapter()
|
25 |
+
|
26 |
+
def set_full_adapter(self):
|
27 |
+
block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
|
28 |
+
self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
|
29 |
+
|
30 |
+
def set_less_adapter(self):
|
31 |
+
# IP-Adapter for SD v1.5 doesn't support this feature.
|
32 |
+
self.set_full_adapter()
|
33 |
+
|
34 |
+
def forward(self, hidden_states, scale=1.0):
|
35 |
+
hidden_states = self.image_proj(hidden_states)
|
36 |
+
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
37 |
+
ip_kv_dict = {}
|
38 |
+
for (block_id, transformer_id) in self.call_block_id:
|
39 |
+
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
|
40 |
+
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
41 |
+
if block_id not in ip_kv_dict:
|
42 |
+
ip_kv_dict[block_id] = {}
|
43 |
+
ip_kv_dict[block_id][transformer_id] = {
|
44 |
+
"ip_k": ip_k,
|
45 |
+
"ip_v": ip_v,
|
46 |
+
"scale": scale
|
47 |
+
}
|
48 |
+
return ip_kv_dict
|
49 |
+
|
50 |
+
def state_dict_converter(self):
|
51 |
+
return SDIpAdapterStateDictConverter()
|
52 |
+
|
53 |
+
|
54 |
+
class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
|
55 |
+
def __init__(self):
|
56 |
+
pass
|
diffsynth/models/sd_lora.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
3 |
+
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
4 |
+
|
5 |
+
|
6 |
+
class SDLoRA:
|
7 |
+
def __init__(self):
|
8 |
+
pass
|
9 |
+
|
10 |
+
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
11 |
+
special_keys = {
|
12 |
+
"down.blocks": "down_blocks",
|
13 |
+
"up.blocks": "up_blocks",
|
14 |
+
"mid.block": "mid_block",
|
15 |
+
"proj.in": "proj_in",
|
16 |
+
"proj.out": "proj_out",
|
17 |
+
"transformer.blocks": "transformer_blocks",
|
18 |
+
"to.q": "to_q",
|
19 |
+
"to.k": "to_k",
|
20 |
+
"to.v": "to_v",
|
21 |
+
"to.out": "to_out",
|
22 |
+
}
|
23 |
+
state_dict_ = {}
|
24 |
+
for key in state_dict:
|
25 |
+
if ".lora_up" not in key:
|
26 |
+
continue
|
27 |
+
if not key.startswith(lora_prefix):
|
28 |
+
continue
|
29 |
+
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
30 |
+
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
31 |
+
if len(weight_up.shape) == 4:
|
32 |
+
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
33 |
+
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
34 |
+
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
35 |
+
else:
|
36 |
+
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
37 |
+
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
38 |
+
for special_key in special_keys:
|
39 |
+
target_name = target_name.replace(special_key, special_keys[special_key])
|
40 |
+
state_dict_[target_name] = lora_weight.cpu()
|
41 |
+
return state_dict_
|
42 |
+
|
43 |
+
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
44 |
+
state_dict_unet = unet.state_dict()
|
45 |
+
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
46 |
+
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
47 |
+
if len(state_dict_lora) > 0:
|
48 |
+
for name in state_dict_lora:
|
49 |
+
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
50 |
+
unet.load_state_dict(state_dict_unet)
|
51 |
+
|
52 |
+
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
53 |
+
state_dict_text_encoder = text_encoder.state_dict()
|
54 |
+
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
55 |
+
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
56 |
+
if len(state_dict_lora) > 0:
|
57 |
+
for name in state_dict_lora:
|
58 |
+
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
59 |
+
text_encoder.load_state_dict(state_dict_text_encoder)
|
60 |
+
|
diffsynth/models/sd_motion.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sd_unet import SDUNet, Attention, GEGLU
|
2 |
+
import torch
|
3 |
+
from einops import rearrange, repeat
|
4 |
+
|
5 |
+
|
6 |
+
class TemporalTransformerBlock(torch.nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
# 1. Self-Attn
|
12 |
+
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
13 |
+
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
14 |
+
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
15 |
+
|
16 |
+
# 2. Cross-Attn
|
17 |
+
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
18 |
+
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
19 |
+
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
20 |
+
|
21 |
+
# 3. Feed-forward
|
22 |
+
self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
23 |
+
self.act_fn = GEGLU(dim, dim * 4)
|
24 |
+
self.ff = torch.nn.Linear(dim * 4, dim)
|
25 |
+
|
26 |
+
|
27 |
+
def forward(self, hidden_states, batch_size=1):
|
28 |
+
|
29 |
+
# 1. Self-Attention
|
30 |
+
norm_hidden_states = self.norm1(hidden_states)
|
31 |
+
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
32 |
+
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
33 |
+
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
34 |
+
hidden_states = attn_output + hidden_states
|
35 |
+
|
36 |
+
# 2. Cross-Attention
|
37 |
+
norm_hidden_states = self.norm2(hidden_states)
|
38 |
+
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
39 |
+
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
40 |
+
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
41 |
+
hidden_states = attn_output + hidden_states
|
42 |
+
|
43 |
+
# 3. Feed-forward
|
44 |
+
norm_hidden_states = self.norm3(hidden_states)
|
45 |
+
ff_output = self.act_fn(norm_hidden_states)
|
46 |
+
ff_output = self.ff(ff_output)
|
47 |
+
hidden_states = ff_output + hidden_states
|
48 |
+
|
49 |
+
return hidden_states
|
50 |
+
|
51 |
+
|
52 |
+
class TemporalBlock(torch.nn.Module):
|
53 |
+
|
54 |
+
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
55 |
+
super().__init__()
|
56 |
+
inner_dim = num_attention_heads * attention_head_dim
|
57 |
+
|
58 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
59 |
+
self.proj_in = torch.nn.Linear(in_channels, inner_dim)
|
60 |
+
|
61 |
+
self.transformer_blocks = torch.nn.ModuleList([
|
62 |
+
TemporalTransformerBlock(
|
63 |
+
inner_dim,
|
64 |
+
num_attention_heads,
|
65 |
+
attention_head_dim
|
66 |
+
)
|
67 |
+
for d in range(num_layers)
|
68 |
+
])
|
69 |
+
|
70 |
+
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
|
71 |
+
|
72 |
+
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
|
73 |
+
batch, _, height, width = hidden_states.shape
|
74 |
+
residual = hidden_states
|
75 |
+
|
76 |
+
hidden_states = self.norm(hidden_states)
|
77 |
+
inner_dim = hidden_states.shape[1]
|
78 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
79 |
+
hidden_states = self.proj_in(hidden_states)
|
80 |
+
|
81 |
+
for block in self.transformer_blocks:
|
82 |
+
hidden_states = block(
|
83 |
+
hidden_states,
|
84 |
+
batch_size=batch_size
|
85 |
+
)
|
86 |
+
|
87 |
+
hidden_states = self.proj_out(hidden_states)
|
88 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
89 |
+
hidden_states = hidden_states + residual
|
90 |
+
|
91 |
+
return hidden_states, time_emb, text_emb, res_stack
|
92 |
+
|
93 |
+
|
94 |
+
class SDMotionModel(torch.nn.Module):
|
95 |
+
def __init__(self):
|
96 |
+
super().__init__()
|
97 |
+
self.motion_modules = torch.nn.ModuleList([
|
98 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
99 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
100 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
101 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
102 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
103 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
104 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
105 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
106 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
107 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
108 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
109 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
110 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
111 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
112 |
+
TemporalBlock(8, 160, 1280, eps=1e-6),
|
113 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
114 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
115 |
+
TemporalBlock(8, 80, 640, eps=1e-6),
|
116 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
117 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
118 |
+
TemporalBlock(8, 40, 320, eps=1e-6),
|
119 |
+
])
|
120 |
+
self.call_block_id = {
|
121 |
+
1: 0,
|
122 |
+
4: 1,
|
123 |
+
9: 2,
|
124 |
+
12: 3,
|
125 |
+
17: 4,
|
126 |
+
20: 5,
|
127 |
+
24: 6,
|
128 |
+
26: 7,
|
129 |
+
29: 8,
|
130 |
+
32: 9,
|
131 |
+
34: 10,
|
132 |
+
36: 11,
|
133 |
+
40: 12,
|
134 |
+
43: 13,
|
135 |
+
46: 14,
|
136 |
+
50: 15,
|
137 |
+
53: 16,
|
138 |
+
56: 17,
|
139 |
+
60: 18,
|
140 |
+
63: 19,
|
141 |
+
66: 20
|
142 |
+
}
|
143 |
+
|
144 |
+
def forward(self):
|
145 |
+
pass
|
146 |
+
|
147 |
+
def state_dict_converter(self):
|
148 |
+
return SDMotionModelStateDictConverter()
|
149 |
+
|
150 |
+
|
151 |
+
class SDMotionModelStateDictConverter:
|
152 |
+
def __init__(self):
|
153 |
+
pass
|
154 |
+
|
155 |
+
def from_diffusers(self, state_dict):
|
156 |
+
rename_dict = {
|
157 |
+
"norm": "norm",
|
158 |
+
"proj_in": "proj_in",
|
159 |
+
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
160 |
+
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
161 |
+
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
162 |
+
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
163 |
+
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
164 |
+
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
165 |
+
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
166 |
+
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
167 |
+
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
168 |
+
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
169 |
+
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
170 |
+
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
171 |
+
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
172 |
+
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
173 |
+
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
174 |
+
"proj_out": "proj_out",
|
175 |
+
}
|
176 |
+
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
177 |
+
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
178 |
+
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
179 |
+
state_dict_ = {}
|
180 |
+
last_prefix, module_id = "", -1
|
181 |
+
for name in name_list:
|
182 |
+
names = name.split(".")
|
183 |
+
prefix_index = names.index("temporal_transformer") + 1
|
184 |
+
prefix = ".".join(names[:prefix_index])
|
185 |
+
if prefix != last_prefix:
|
186 |
+
last_prefix = prefix
|
187 |
+
module_id += 1
|
188 |
+
middle_name = ".".join(names[prefix_index:-1])
|
189 |
+
suffix = names[-1]
|
190 |
+
if "pos_encoder" in names:
|
191 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
192 |
+
else:
|
193 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
194 |
+
state_dict_[rename] = state_dict[name]
|
195 |
+
return state_dict_
|
196 |
+
|
197 |
+
def from_civitai(self, state_dict):
|
198 |
+
return self.from_diffusers(state_dict)
|
diffsynth/models/sd_text_encoder.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .attention import Attention
|
3 |
+
|
4 |
+
|
5 |
+
class CLIPEncoderLayer(torch.nn.Module):
|
6 |
+
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
7 |
+
super().__init__()
|
8 |
+
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
9 |
+
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
10 |
+
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
11 |
+
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
12 |
+
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
13 |
+
|
14 |
+
self.use_quick_gelu = use_quick_gelu
|
15 |
+
|
16 |
+
def quickGELU(self, x):
|
17 |
+
return x * torch.sigmoid(1.702 * x)
|
18 |
+
|
19 |
+
def forward(self, hidden_states, attn_mask=None):
|
20 |
+
residual = hidden_states
|
21 |
+
|
22 |
+
hidden_states = self.layer_norm1(hidden_states)
|
23 |
+
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
24 |
+
hidden_states = residual + hidden_states
|
25 |
+
|
26 |
+
residual = hidden_states
|
27 |
+
hidden_states = self.layer_norm2(hidden_states)
|
28 |
+
hidden_states = self.fc1(hidden_states)
|
29 |
+
if self.use_quick_gelu:
|
30 |
+
hidden_states = self.quickGELU(hidden_states)
|
31 |
+
else:
|
32 |
+
hidden_states = torch.nn.functional.gelu(hidden_states)
|
33 |
+
hidden_states = self.fc2(hidden_states)
|
34 |
+
hidden_states = residual + hidden_states
|
35 |
+
|
36 |
+
return hidden_states
|
37 |
+
|
38 |
+
|
39 |
+
class SDTextEncoder(torch.nn.Module):
|
40 |
+
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
# token_embedding
|
44 |
+
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
45 |
+
|
46 |
+
# position_embeds (This is a fixed tensor)
|
47 |
+
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
48 |
+
|
49 |
+
# encoders
|
50 |
+
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
51 |
+
|
52 |
+
# attn_mask
|
53 |
+
self.attn_mask = self.attention_mask(max_position_embeddings)
|
54 |
+
|
55 |
+
# final_layer_norm
|
56 |
+
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
57 |
+
|
58 |
+
def attention_mask(self, length):
|
59 |
+
mask = torch.empty(length, length)
|
60 |
+
mask.fill_(float("-inf"))
|
61 |
+
mask.triu_(1)
|
62 |
+
return mask
|
63 |
+
|
64 |
+
def forward(self, input_ids, clip_skip=1):
|
65 |
+
embeds = self.token_embedding(input_ids) + self.position_embeds
|
66 |
+
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
67 |
+
for encoder_id, encoder in enumerate(self.encoders):
|
68 |
+
embeds = encoder(embeds, attn_mask=attn_mask)
|
69 |
+
if encoder_id + clip_skip == len(self.encoders):
|
70 |
+
break
|
71 |
+
embeds = self.final_layer_norm(embeds)
|
72 |
+
return embeds
|
73 |
+
|
74 |
+
def state_dict_converter(self):
|
75 |
+
return SDTextEncoderStateDictConverter()
|
76 |
+
|
77 |
+
|
78 |
+
class SDTextEncoderStateDictConverter:
|
79 |
+
def __init__(self):
|
80 |
+
pass
|
81 |
+
|
82 |
+
def from_diffusers(self, state_dict):
|
83 |
+
rename_dict = {
|
84 |
+
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
85 |
+
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
86 |
+
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
87 |
+
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
88 |
+
}
|
89 |
+
attn_rename_dict = {
|
90 |
+
"self_attn.q_proj": "attn.to_q",
|
91 |
+
"self_attn.k_proj": "attn.to_k",
|
92 |
+
"self_attn.v_proj": "attn.to_v",
|
93 |
+
"self_attn.out_proj": "attn.to_out",
|
94 |
+
"layer_norm1": "layer_norm1",
|
95 |
+
"layer_norm2": "layer_norm2",
|
96 |
+
"mlp.fc1": "fc1",
|
97 |
+
"mlp.fc2": "fc2",
|
98 |
+
}
|
99 |
+
state_dict_ = {}
|
100 |
+
for name in state_dict:
|
101 |
+
if name in rename_dict:
|
102 |
+
param = state_dict[name]
|
103 |
+
if name == "text_model.embeddings.position_embedding.weight":
|
104 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
105 |
+
state_dict_[rename_dict[name]] = param
|
106 |
+
elif name.startswith("text_model.encoder.layers."):
|
107 |
+
param = state_dict[name]
|
108 |
+
names = name.split(".")
|
109 |
+
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
110 |
+
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
111 |
+
state_dict_[name_] = param
|
112 |
+
return state_dict_
|
113 |
+
|
114 |
+
def from_civitai(self, state_dict):
|
115 |
+
rename_dict = {
|
116 |
+
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
117 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
118 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
119 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
120 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
121 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
122 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
123 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
124 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
125 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
126 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
127 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
128 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
129 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
130 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
131 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
132 |
+
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
133 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
134 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
135 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
136 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
137 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
138 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
139 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
140 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
141 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
142 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
143 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
144 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
145 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
146 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
147 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
148 |
+
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
149 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
150 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
151 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
152 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
153 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
154 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
155 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
156 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
157 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
158 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
159 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
160 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
161 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
162 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
163 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
164 |
+
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
165 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
166 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
167 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
168 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
169 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
170 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
171 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
172 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
173 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
174 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
175 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
176 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
177 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
178 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
179 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
180 |
+
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
181 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
182 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
183 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
184 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
185 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
186 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
187 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
188 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
189 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
190 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
191 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
192 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
193 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
194 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
195 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
196 |
+
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
197 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
198 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
199 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
200 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
201 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
202 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
203 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
204 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
205 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
206 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
207 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
208 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
209 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
210 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
211 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
212 |
+
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
213 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
214 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
215 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
216 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
217 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
218 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
219 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
220 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
221 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
222 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
223 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
224 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
225 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
226 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
227 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
228 |
+
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
229 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
230 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
231 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
232 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
233 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
234 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
235 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
236 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
237 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
238 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
239 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
240 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
241 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
242 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
243 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
244 |
+
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
245 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
246 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
247 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
248 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
249 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
250 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
251 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
252 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
253 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
254 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
255 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
256 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
257 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
258 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
259 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
260 |
+
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
261 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
262 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
263 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
264 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
265 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
266 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
267 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
268 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
269 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
270 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
271 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
272 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
273 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
274 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
275 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
276 |
+
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
277 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
278 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
279 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
280 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
281 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
282 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
283 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
284 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
285 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
286 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
287 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
288 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
289 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
290 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
291 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
292 |
+
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
293 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
294 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
295 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
296 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
297 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
298 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
299 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
300 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
301 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
302 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
303 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
304 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
305 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
306 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
307 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
308 |
+
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
309 |
+
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
310 |
+
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
311 |
+
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
312 |
+
}
|
313 |
+
state_dict_ = {}
|
314 |
+
for name in state_dict:
|
315 |
+
if name in rename_dict:
|
316 |
+
param = state_dict[name]
|
317 |
+
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
318 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
319 |
+
state_dict_[rename_dict[name]] = param
|
320 |
+
return state_dict_
|
diffsynth/models/sd_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/sd_vae_decoder.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .attention import Attention
|
3 |
+
from .sd_unet import ResnetBlock, UpSampler
|
4 |
+
from .tiler import TileWorker
|
5 |
+
|
6 |
+
|
7 |
+
class VAEAttentionBlock(torch.nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
10 |
+
super().__init__()
|
11 |
+
inner_dim = num_attention_heads * attention_head_dim
|
12 |
+
|
13 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
14 |
+
|
15 |
+
self.transformer_blocks = torch.nn.ModuleList([
|
16 |
+
Attention(
|
17 |
+
inner_dim,
|
18 |
+
num_attention_heads,
|
19 |
+
attention_head_dim,
|
20 |
+
bias_q=True,
|
21 |
+
bias_kv=True,
|
22 |
+
bias_out=True
|
23 |
+
)
|
24 |
+
for d in range(num_layers)
|
25 |
+
])
|
26 |
+
|
27 |
+
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
28 |
+
batch, _, height, width = hidden_states.shape
|
29 |
+
residual = hidden_states
|
30 |
+
|
31 |
+
hidden_states = self.norm(hidden_states)
|
32 |
+
inner_dim = hidden_states.shape[1]
|
33 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
34 |
+
|
35 |
+
for block in self.transformer_blocks:
|
36 |
+
hidden_states = block(hidden_states)
|
37 |
+
|
38 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
39 |
+
hidden_states = hidden_states + residual
|
40 |
+
|
41 |
+
return hidden_states, time_emb, text_emb, res_stack
|
42 |
+
|
43 |
+
|
44 |
+
class SDVAEDecoder(torch.nn.Module):
|
45 |
+
def __init__(self):
|
46 |
+
super().__init__()
|
47 |
+
self.scaling_factor = 0.18215
|
48 |
+
self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
|
49 |
+
self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
|
50 |
+
|
51 |
+
self.blocks = torch.nn.ModuleList([
|
52 |
+
# UNetMidBlock2D
|
53 |
+
ResnetBlock(512, 512, eps=1e-6),
|
54 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
55 |
+
ResnetBlock(512, 512, eps=1e-6),
|
56 |
+
# UpDecoderBlock2D
|
57 |
+
ResnetBlock(512, 512, eps=1e-6),
|
58 |
+
ResnetBlock(512, 512, eps=1e-6),
|
59 |
+
ResnetBlock(512, 512, eps=1e-6),
|
60 |
+
UpSampler(512),
|
61 |
+
# UpDecoderBlock2D
|
62 |
+
ResnetBlock(512, 512, eps=1e-6),
|
63 |
+
ResnetBlock(512, 512, eps=1e-6),
|
64 |
+
ResnetBlock(512, 512, eps=1e-6),
|
65 |
+
UpSampler(512),
|
66 |
+
# UpDecoderBlock2D
|
67 |
+
ResnetBlock(512, 256, eps=1e-6),
|
68 |
+
ResnetBlock(256, 256, eps=1e-6),
|
69 |
+
ResnetBlock(256, 256, eps=1e-6),
|
70 |
+
UpSampler(256),
|
71 |
+
# UpDecoderBlock2D
|
72 |
+
ResnetBlock(256, 128, eps=1e-6),
|
73 |
+
ResnetBlock(128, 128, eps=1e-6),
|
74 |
+
ResnetBlock(128, 128, eps=1e-6),
|
75 |
+
])
|
76 |
+
|
77 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
|
78 |
+
self.conv_act = torch.nn.SiLU()
|
79 |
+
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
80 |
+
|
81 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
82 |
+
hidden_states = TileWorker().tiled_forward(
|
83 |
+
lambda x: self.forward(x),
|
84 |
+
sample,
|
85 |
+
tile_size,
|
86 |
+
tile_stride,
|
87 |
+
tile_device=sample.device,
|
88 |
+
tile_dtype=sample.dtype
|
89 |
+
)
|
90 |
+
return hidden_states
|
91 |
+
|
92 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
93 |
+
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
94 |
+
if tiled:
|
95 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
96 |
+
|
97 |
+
# 1. pre-process
|
98 |
+
sample = sample / self.scaling_factor
|
99 |
+
hidden_states = self.post_quant_conv(sample)
|
100 |
+
hidden_states = self.conv_in(hidden_states)
|
101 |
+
time_emb = None
|
102 |
+
text_emb = None
|
103 |
+
res_stack = None
|
104 |
+
|
105 |
+
# 2. blocks
|
106 |
+
for i, block in enumerate(self.blocks):
|
107 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
108 |
+
|
109 |
+
# 3. output
|
110 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
111 |
+
hidden_states = self.conv_act(hidden_states)
|
112 |
+
hidden_states = self.conv_out(hidden_states)
|
113 |
+
|
114 |
+
return hidden_states
|
115 |
+
|
116 |
+
def state_dict_converter(self):
|
117 |
+
return SDVAEDecoderStateDictConverter()
|
118 |
+
|
119 |
+
|
120 |
+
class SDVAEDecoderStateDictConverter:
|
121 |
+
def __init__(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def from_diffusers(self, state_dict):
|
125 |
+
# architecture
|
126 |
+
block_types = [
|
127 |
+
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
|
128 |
+
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
129 |
+
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
130 |
+
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
131 |
+
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
|
132 |
+
]
|
133 |
+
|
134 |
+
# Rename each parameter
|
135 |
+
local_rename_dict = {
|
136 |
+
"post_quant_conv": "post_quant_conv",
|
137 |
+
"decoder.conv_in": "conv_in",
|
138 |
+
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
|
139 |
+
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
|
140 |
+
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
|
141 |
+
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
|
142 |
+
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
|
143 |
+
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
|
144 |
+
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
|
145 |
+
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
|
146 |
+
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
|
147 |
+
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
|
148 |
+
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
|
149 |
+
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
|
150 |
+
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
|
151 |
+
"decoder.conv_norm_out": "conv_norm_out",
|
152 |
+
"decoder.conv_out": "conv_out",
|
153 |
+
}
|
154 |
+
name_list = sorted([name for name in state_dict])
|
155 |
+
rename_dict = {}
|
156 |
+
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
|
157 |
+
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
158 |
+
for name in name_list:
|
159 |
+
names = name.split(".")
|
160 |
+
name_prefix = ".".join(names[:-1])
|
161 |
+
if name_prefix in local_rename_dict:
|
162 |
+
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
163 |
+
elif name.startswith("decoder.up_blocks"):
|
164 |
+
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
165 |
+
block_type_with_id = ".".join(names[:5])
|
166 |
+
if block_type_with_id != last_block_type_with_id[block_type]:
|
167 |
+
block_id[block_type] += 1
|
168 |
+
last_block_type_with_id[block_type] = block_type_with_id
|
169 |
+
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
170 |
+
block_id[block_type] += 1
|
171 |
+
block_type_with_id = ".".join(names[:5])
|
172 |
+
names = ["blocks", str(block_id[block_type])] + names[5:]
|
173 |
+
rename_dict[name] = ".".join(names)
|
174 |
+
|
175 |
+
# Convert state_dict
|
176 |
+
state_dict_ = {}
|
177 |
+
for name, param in state_dict.items():
|
178 |
+
if name in rename_dict:
|
179 |
+
state_dict_[rename_dict[name]] = param
|
180 |
+
return state_dict_
|
181 |
+
|
182 |
+
def from_civitai(self, state_dict):
|
183 |
+
rename_dict = {
|
184 |
+
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
|
185 |
+
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
|
186 |
+
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
|
187 |
+
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
|
188 |
+
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
189 |
+
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
190 |
+
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
191 |
+
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
192 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
193 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
194 |
+
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
195 |
+
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
196 |
+
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
197 |
+
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
198 |
+
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
199 |
+
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
200 |
+
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
201 |
+
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
202 |
+
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
203 |
+
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
204 |
+
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
205 |
+
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
206 |
+
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
207 |
+
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
208 |
+
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
209 |
+
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
210 |
+
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
211 |
+
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
212 |
+
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
213 |
+
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
214 |
+
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
|
215 |
+
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
|
216 |
+
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
217 |
+
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
218 |
+
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
219 |
+
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
220 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
221 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
222 |
+
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
223 |
+
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
224 |
+
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
225 |
+
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
226 |
+
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
227 |
+
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
228 |
+
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
229 |
+
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
230 |
+
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
231 |
+
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
232 |
+
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
233 |
+
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
234 |
+
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
235 |
+
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
236 |
+
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
237 |
+
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
238 |
+
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
239 |
+
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
240 |
+
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
241 |
+
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
242 |
+
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
243 |
+
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
244 |
+
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
245 |
+
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
246 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
247 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
248 |
+
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
249 |
+
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
250 |
+
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
251 |
+
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
252 |
+
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
253 |
+
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
254 |
+
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
255 |
+
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
256 |
+
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
257 |
+
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
258 |
+
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
259 |
+
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
260 |
+
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
261 |
+
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
262 |
+
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
263 |
+
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
264 |
+
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
265 |
+
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
266 |
+
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
267 |
+
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
268 |
+
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
269 |
+
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
270 |
+
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
271 |
+
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
272 |
+
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
273 |
+
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
274 |
+
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
275 |
+
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
276 |
+
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
277 |
+
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
278 |
+
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
279 |
+
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
280 |
+
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
281 |
+
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
282 |
+
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
283 |
+
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
284 |
+
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
285 |
+
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
286 |
+
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
287 |
+
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
288 |
+
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
289 |
+
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
290 |
+
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
291 |
+
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
292 |
+
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
293 |
+
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
294 |
+
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
295 |
+
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
296 |
+
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
297 |
+
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
298 |
+
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
299 |
+
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
300 |
+
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
301 |
+
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
302 |
+
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
303 |
+
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
304 |
+
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
305 |
+
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
306 |
+
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
307 |
+
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
308 |
+
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
309 |
+
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
310 |
+
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
311 |
+
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
312 |
+
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
313 |
+
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
314 |
+
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
315 |
+
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
316 |
+
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
317 |
+
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
318 |
+
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
319 |
+
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
320 |
+
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
321 |
+
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
322 |
+
"first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
|
323 |
+
"first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
|
324 |
+
}
|
325 |
+
state_dict_ = {}
|
326 |
+
for name in state_dict:
|
327 |
+
if name in rename_dict:
|
328 |
+
param = state_dict[name]
|
329 |
+
if "transformer_blocks" in rename_dict[name]:
|
330 |
+
param = param.squeeze()
|
331 |
+
state_dict_[rename_dict[name]] = param
|
332 |
+
return state_dict_
|
diffsynth/models/sd_vae_encoder.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_unet import ResnetBlock, DownSampler
|
3 |
+
from .sd_vae_decoder import VAEAttentionBlock
|
4 |
+
from .tiler import TileWorker
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class SDVAEEncoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 0.18215
|
12 |
+
self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
|
13 |
+
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# DownEncoderBlock2D
|
17 |
+
ResnetBlock(128, 128, eps=1e-6),
|
18 |
+
ResnetBlock(128, 128, eps=1e-6),
|
19 |
+
DownSampler(128, padding=0, extra_padding=True),
|
20 |
+
# DownEncoderBlock2D
|
21 |
+
ResnetBlock(128, 256, eps=1e-6),
|
22 |
+
ResnetBlock(256, 256, eps=1e-6),
|
23 |
+
DownSampler(256, padding=0, extra_padding=True),
|
24 |
+
# DownEncoderBlock2D
|
25 |
+
ResnetBlock(256, 512, eps=1e-6),
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
DownSampler(512, padding=0, extra_padding=True),
|
28 |
+
# DownEncoderBlock2D
|
29 |
+
ResnetBlock(512, 512, eps=1e-6),
|
30 |
+
ResnetBlock(512, 512, eps=1e-6),
|
31 |
+
# UNetMidBlock2D
|
32 |
+
ResnetBlock(512, 512, eps=1e-6),
|
33 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
34 |
+
ResnetBlock(512, 512, eps=1e-6),
|
35 |
+
])
|
36 |
+
|
37 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
38 |
+
self.conv_act = torch.nn.SiLU()
|
39 |
+
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
40 |
+
|
41 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
42 |
+
hidden_states = TileWorker().tiled_forward(
|
43 |
+
lambda x: self.forward(x),
|
44 |
+
sample,
|
45 |
+
tile_size,
|
46 |
+
tile_stride,
|
47 |
+
tile_device=sample.device,
|
48 |
+
tile_dtype=sample.dtype
|
49 |
+
)
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
53 |
+
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
54 |
+
if tiled:
|
55 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
56 |
+
|
57 |
+
# 1. pre-process
|
58 |
+
hidden_states = self.conv_in(sample)
|
59 |
+
time_emb = None
|
60 |
+
text_emb = None
|
61 |
+
res_stack = None
|
62 |
+
|
63 |
+
# 2. blocks
|
64 |
+
for i, block in enumerate(self.blocks):
|
65 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
66 |
+
|
67 |
+
# 3. output
|
68 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
69 |
+
hidden_states = self.conv_act(hidden_states)
|
70 |
+
hidden_states = self.conv_out(hidden_states)
|
71 |
+
hidden_states = self.quant_conv(hidden_states)
|
72 |
+
hidden_states = hidden_states[:, :4]
|
73 |
+
hidden_states *= self.scaling_factor
|
74 |
+
|
75 |
+
return hidden_states
|
76 |
+
|
77 |
+
def encode_video(self, sample, batch_size=8):
|
78 |
+
B = sample.shape[0]
|
79 |
+
hidden_states = []
|
80 |
+
|
81 |
+
for i in range(0, sample.shape[2], batch_size):
|
82 |
+
|
83 |
+
j = min(i + batch_size, sample.shape[2])
|
84 |
+
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
85 |
+
|
86 |
+
hidden_states_batch = self(sample_batch)
|
87 |
+
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
88 |
+
|
89 |
+
hidden_states.append(hidden_states_batch)
|
90 |
+
|
91 |
+
hidden_states = torch.concat(hidden_states, dim=2)
|
92 |
+
return hidden_states
|
93 |
+
|
94 |
+
def state_dict_converter(self):
|
95 |
+
return SDVAEEncoderStateDictConverter()
|
96 |
+
|
97 |
+
|
98 |
+
class SDVAEEncoderStateDictConverter:
|
99 |
+
def __init__(self):
|
100 |
+
pass
|
101 |
+
|
102 |
+
def from_diffusers(self, state_dict):
|
103 |
+
# architecture
|
104 |
+
block_types = [
|
105 |
+
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
106 |
+
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
107 |
+
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
108 |
+
'ResnetBlock', 'ResnetBlock',
|
109 |
+
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
|
110 |
+
]
|
111 |
+
|
112 |
+
# Rename each parameter
|
113 |
+
local_rename_dict = {
|
114 |
+
"quant_conv": "quant_conv",
|
115 |
+
"encoder.conv_in": "conv_in",
|
116 |
+
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
|
117 |
+
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
|
118 |
+
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
|
119 |
+
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
|
120 |
+
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
|
121 |
+
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
|
122 |
+
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
|
123 |
+
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
|
124 |
+
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
|
125 |
+
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
|
126 |
+
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
|
127 |
+
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
|
128 |
+
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
|
129 |
+
"encoder.conv_norm_out": "conv_norm_out",
|
130 |
+
"encoder.conv_out": "conv_out",
|
131 |
+
}
|
132 |
+
name_list = sorted([name for name in state_dict])
|
133 |
+
rename_dict = {}
|
134 |
+
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
135 |
+
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
136 |
+
for name in name_list:
|
137 |
+
names = name.split(".")
|
138 |
+
name_prefix = ".".join(names[:-1])
|
139 |
+
if name_prefix in local_rename_dict:
|
140 |
+
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
141 |
+
elif name.startswith("encoder.down_blocks"):
|
142 |
+
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
143 |
+
block_type_with_id = ".".join(names[:5])
|
144 |
+
if block_type_with_id != last_block_type_with_id[block_type]:
|
145 |
+
block_id[block_type] += 1
|
146 |
+
last_block_type_with_id[block_type] = block_type_with_id
|
147 |
+
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
148 |
+
block_id[block_type] += 1
|
149 |
+
block_type_with_id = ".".join(names[:5])
|
150 |
+
names = ["blocks", str(block_id[block_type])] + names[5:]
|
151 |
+
rename_dict[name] = ".".join(names)
|
152 |
+
|
153 |
+
# Convert state_dict
|
154 |
+
state_dict_ = {}
|
155 |
+
for name, param in state_dict.items():
|
156 |
+
if name in rename_dict:
|
157 |
+
state_dict_[rename_dict[name]] = param
|
158 |
+
return state_dict_
|
159 |
+
|
160 |
+
def from_civitai(self, state_dict):
|
161 |
+
rename_dict = {
|
162 |
+
"first_stage_model.encoder.conv_in.bias": "conv_in.bias",
|
163 |
+
"first_stage_model.encoder.conv_in.weight": "conv_in.weight",
|
164 |
+
"first_stage_model.encoder.conv_out.bias": "conv_out.bias",
|
165 |
+
"first_stage_model.encoder.conv_out.weight": "conv_out.weight",
|
166 |
+
"first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
167 |
+
"first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
168 |
+
"first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
169 |
+
"first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
170 |
+
"first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
171 |
+
"first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
172 |
+
"first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
173 |
+
"first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
174 |
+
"first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
175 |
+
"first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
176 |
+
"first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
177 |
+
"first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
178 |
+
"first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
179 |
+
"first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
180 |
+
"first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
181 |
+
"first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
182 |
+
"first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
183 |
+
"first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
184 |
+
"first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
185 |
+
"first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
186 |
+
"first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
187 |
+
"first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
188 |
+
"first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
189 |
+
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
190 |
+
"first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
191 |
+
"first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
192 |
+
"first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
193 |
+
"first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
194 |
+
"first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
195 |
+
"first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
196 |
+
"first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
197 |
+
"first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
198 |
+
"first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
199 |
+
"first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
200 |
+
"first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
201 |
+
"first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
202 |
+
"first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
203 |
+
"first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
204 |
+
"first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
205 |
+
"first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
206 |
+
"first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
207 |
+
"first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
208 |
+
"first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
209 |
+
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
210 |
+
"first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
211 |
+
"first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
212 |
+
"first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
213 |
+
"first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
214 |
+
"first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
215 |
+
"first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
216 |
+
"first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
217 |
+
"first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
218 |
+
"first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
219 |
+
"first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
220 |
+
"first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
221 |
+
"first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
222 |
+
"first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
223 |
+
"first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
224 |
+
"first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
225 |
+
"first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
226 |
+
"first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
227 |
+
"first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
228 |
+
"first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
229 |
+
"first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
230 |
+
"first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
231 |
+
"first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
232 |
+
"first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
233 |
+
"first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
234 |
+
"first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
235 |
+
"first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
236 |
+
"first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
237 |
+
"first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
238 |
+
"first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
239 |
+
"first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
240 |
+
"first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
241 |
+
"first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
242 |
+
"first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
243 |
+
"first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
244 |
+
"first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
245 |
+
"first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
246 |
+
"first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
247 |
+
"first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
248 |
+
"first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
249 |
+
"first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
250 |
+
"first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
251 |
+
"first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
252 |
+
"first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
253 |
+
"first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
254 |
+
"first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
255 |
+
"first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
256 |
+
"first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
257 |
+
"first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
258 |
+
"first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
259 |
+
"first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
260 |
+
"first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
261 |
+
"first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
262 |
+
"first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
263 |
+
"first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
264 |
+
"first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
265 |
+
"first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
266 |
+
"first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
|
267 |
+
"first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
|
268 |
+
"first_stage_model.quant_conv.bias": "quant_conv.bias",
|
269 |
+
"first_stage_model.quant_conv.weight": "quant_conv.weight",
|
270 |
+
}
|
271 |
+
state_dict_ = {}
|
272 |
+
for name in state_dict:
|
273 |
+
if name in rename_dict:
|
274 |
+
param = state_dict[name]
|
275 |
+
if "transformer_blocks" in rename_dict[name]:
|
276 |
+
param = param.squeeze()
|
277 |
+
state_dict_[rename_dict[name]] = param
|
278 |
+
return state_dict_
|
diffsynth/models/sdxl_ipadapter.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .svd_image_encoder import SVDImageEncoder
|
2 |
+
from transformers import CLIPImageProcessor
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
|
9 |
+
self.image_processor = CLIPImageProcessor()
|
10 |
+
|
11 |
+
def forward(self, image):
|
12 |
+
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
|
13 |
+
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
|
14 |
+
return super().forward(pixel_values)
|
15 |
+
|
16 |
+
|
17 |
+
class IpAdapterImageProjModel(torch.nn.Module):
|
18 |
+
def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
|
19 |
+
super().__init__()
|
20 |
+
self.cross_attention_dim = cross_attention_dim
|
21 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
22 |
+
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
23 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
24 |
+
|
25 |
+
def forward(self, image_embeds):
|
26 |
+
clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
27 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
28 |
+
return clip_extra_context_tokens
|
29 |
+
|
30 |
+
|
31 |
+
class IpAdapterModule(torch.nn.Module):
|
32 |
+
def __init__(self, input_dim, output_dim):
|
33 |
+
super().__init__()
|
34 |
+
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
35 |
+
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
36 |
+
|
37 |
+
def forward(self, hidden_states):
|
38 |
+
ip_k = self.to_k_ip(hidden_states)
|
39 |
+
ip_v = self.to_v_ip(hidden_states)
|
40 |
+
return ip_k, ip_v
|
41 |
+
|
42 |
+
|
43 |
+
class SDXLIpAdapter(torch.nn.Module):
|
44 |
+
def __init__(self):
|
45 |
+
super().__init__()
|
46 |
+
shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
|
47 |
+
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
|
48 |
+
self.image_proj = IpAdapterImageProjModel()
|
49 |
+
self.set_full_adapter()
|
50 |
+
|
51 |
+
def set_full_adapter(self):
|
52 |
+
map_list = sum([
|
53 |
+
[(7, i) for i in range(2)],
|
54 |
+
[(10, i) for i in range(2)],
|
55 |
+
[(15, i) for i in range(10)],
|
56 |
+
[(18, i) for i in range(10)],
|
57 |
+
[(25, i) for i in range(10)],
|
58 |
+
[(28, i) for i in range(10)],
|
59 |
+
[(31, i) for i in range(10)],
|
60 |
+
[(35, i) for i in range(2)],
|
61 |
+
[(38, i) for i in range(2)],
|
62 |
+
[(41, i) for i in range(2)],
|
63 |
+
[(21, i) for i in range(10)],
|
64 |
+
], [])
|
65 |
+
self.call_block_id = {i: j for j, i in enumerate(map_list)}
|
66 |
+
|
67 |
+
def set_less_adapter(self):
|
68 |
+
map_list = sum([
|
69 |
+
[(7, i) for i in range(2)],
|
70 |
+
[(10, i) for i in range(2)],
|
71 |
+
[(15, i) for i in range(10)],
|
72 |
+
[(18, i) for i in range(10)],
|
73 |
+
[(25, i) for i in range(10)],
|
74 |
+
[(28, i) for i in range(10)],
|
75 |
+
[(31, i) for i in range(10)],
|
76 |
+
[(35, i) for i in range(2)],
|
77 |
+
[(38, i) for i in range(2)],
|
78 |
+
[(41, i) for i in range(2)],
|
79 |
+
[(21, i) for i in range(10)],
|
80 |
+
], [])
|
81 |
+
self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
|
82 |
+
|
83 |
+
def forward(self, hidden_states, scale=1.0):
|
84 |
+
hidden_states = self.image_proj(hidden_states)
|
85 |
+
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
86 |
+
ip_kv_dict = {}
|
87 |
+
for (block_id, transformer_id) in self.call_block_id:
|
88 |
+
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
|
89 |
+
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
90 |
+
if block_id not in ip_kv_dict:
|
91 |
+
ip_kv_dict[block_id] = {}
|
92 |
+
ip_kv_dict[block_id][transformer_id] = {
|
93 |
+
"ip_k": ip_k,
|
94 |
+
"ip_v": ip_v,
|
95 |
+
"scale": scale
|
96 |
+
}
|
97 |
+
return ip_kv_dict
|
98 |
+
|
99 |
+
def state_dict_converter(self):
|
100 |
+
return SDXLIpAdapterStateDictConverter()
|
101 |
+
|
102 |
+
|
103 |
+
class SDXLIpAdapterStateDictConverter:
|
104 |
+
def __init__(self):
|
105 |
+
pass
|
106 |
+
|
107 |
+
def from_diffusers(self, state_dict):
|
108 |
+
state_dict_ = {}
|
109 |
+
for name in state_dict["ip_adapter"]:
|
110 |
+
names = name.split(".")
|
111 |
+
layer_id = str(int(names[0]) // 2)
|
112 |
+
name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
|
113 |
+
state_dict_[name_] = state_dict["ip_adapter"][name]
|
114 |
+
for name in state_dict["image_proj"]:
|
115 |
+
name_ = "image_proj." + name
|
116 |
+
state_dict_[name_] = state_dict["image_proj"][name]
|
117 |
+
return state_dict_
|
118 |
+
|
119 |
+
def from_civitai(self, state_dict):
|
120 |
+
return self.from_diffusers(state_dict)
|
121 |
+
|
diffsynth/models/sdxl_motion.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sd_motion import TemporalBlock
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
class SDXLMotionModel(torch.nn.Module):
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
self.motion_modules = torch.nn.ModuleList([
|
10 |
+
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
11 |
+
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
12 |
+
|
13 |
+
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
14 |
+
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
15 |
+
|
16 |
+
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
17 |
+
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
18 |
+
|
19 |
+
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
20 |
+
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
21 |
+
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
22 |
+
|
23 |
+
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
24 |
+
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
25 |
+
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
26 |
+
|
27 |
+
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
28 |
+
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
29 |
+
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
30 |
+
])
|
31 |
+
self.call_block_id = {
|
32 |
+
0: 0,
|
33 |
+
2: 1,
|
34 |
+
7: 2,
|
35 |
+
10: 3,
|
36 |
+
15: 4,
|
37 |
+
18: 5,
|
38 |
+
25: 6,
|
39 |
+
28: 7,
|
40 |
+
31: 8,
|
41 |
+
35: 9,
|
42 |
+
38: 10,
|
43 |
+
41: 11,
|
44 |
+
44: 12,
|
45 |
+
46: 13,
|
46 |
+
48: 14,
|
47 |
+
}
|
48 |
+
|
49 |
+
def forward(self):
|
50 |
+
pass
|
51 |
+
|
52 |
+
def state_dict_converter(self):
|
53 |
+
return SDMotionModelStateDictConverter()
|
54 |
+
|
55 |
+
|
56 |
+
class SDMotionModelStateDictConverter:
|
57 |
+
def __init__(self):
|
58 |
+
pass
|
59 |
+
|
60 |
+
def from_diffusers(self, state_dict):
|
61 |
+
rename_dict = {
|
62 |
+
"norm": "norm",
|
63 |
+
"proj_in": "proj_in",
|
64 |
+
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
65 |
+
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
66 |
+
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
67 |
+
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
68 |
+
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
69 |
+
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
70 |
+
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
71 |
+
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
72 |
+
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
73 |
+
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
74 |
+
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
75 |
+
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
76 |
+
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
77 |
+
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
78 |
+
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
79 |
+
"proj_out": "proj_out",
|
80 |
+
}
|
81 |
+
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
82 |
+
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
83 |
+
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
84 |
+
state_dict_ = {}
|
85 |
+
last_prefix, module_id = "", -1
|
86 |
+
for name in name_list:
|
87 |
+
names = name.split(".")
|
88 |
+
prefix_index = names.index("temporal_transformer") + 1
|
89 |
+
prefix = ".".join(names[:prefix_index])
|
90 |
+
if prefix != last_prefix:
|
91 |
+
last_prefix = prefix
|
92 |
+
module_id += 1
|
93 |
+
middle_name = ".".join(names[prefix_index:-1])
|
94 |
+
suffix = names[-1]
|
95 |
+
if "pos_encoder" in names:
|
96 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
97 |
+
else:
|
98 |
+
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
99 |
+
state_dict_[rename] = state_dict[name]
|
100 |
+
return state_dict_
|
101 |
+
|
102 |
+
def from_civitai(self, state_dict):
|
103 |
+
return self.from_diffusers(state_dict)
|
diffsynth/models/sdxl_text_encoder.py
ADDED
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_text_encoder import CLIPEncoderLayer
|
3 |
+
|
4 |
+
|
5 |
+
class SDXLTextEncoder(torch.nn.Module):
|
6 |
+
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
# token_embedding
|
10 |
+
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
11 |
+
|
12 |
+
# position_embeds (This is a fixed tensor)
|
13 |
+
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
14 |
+
|
15 |
+
# encoders
|
16 |
+
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
17 |
+
|
18 |
+
# attn_mask
|
19 |
+
self.attn_mask = self.attention_mask(max_position_embeddings)
|
20 |
+
|
21 |
+
# The text encoder is different to that in Stable Diffusion 1.x.
|
22 |
+
# It does not include final_layer_norm.
|
23 |
+
|
24 |
+
def attention_mask(self, length):
|
25 |
+
mask = torch.empty(length, length)
|
26 |
+
mask.fill_(float("-inf"))
|
27 |
+
mask.triu_(1)
|
28 |
+
return mask
|
29 |
+
|
30 |
+
def forward(self, input_ids, clip_skip=1):
|
31 |
+
embeds = self.token_embedding(input_ids) + self.position_embeds
|
32 |
+
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
33 |
+
for encoder_id, encoder in enumerate(self.encoders):
|
34 |
+
embeds = encoder(embeds, attn_mask=attn_mask)
|
35 |
+
if encoder_id + clip_skip == len(self.encoders):
|
36 |
+
break
|
37 |
+
return embeds
|
38 |
+
|
39 |
+
def state_dict_converter(self):
|
40 |
+
return SDXLTextEncoderStateDictConverter()
|
41 |
+
|
42 |
+
|
43 |
+
class SDXLTextEncoder2(torch.nn.Module):
|
44 |
+
def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
# token_embedding
|
48 |
+
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
49 |
+
|
50 |
+
# position_embeds (This is a fixed tensor)
|
51 |
+
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
52 |
+
|
53 |
+
# encoders
|
54 |
+
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
|
55 |
+
|
56 |
+
# attn_mask
|
57 |
+
self.attn_mask = self.attention_mask(max_position_embeddings)
|
58 |
+
|
59 |
+
# final_layer_norm
|
60 |
+
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
61 |
+
|
62 |
+
# text_projection
|
63 |
+
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
|
64 |
+
|
65 |
+
def attention_mask(self, length):
|
66 |
+
mask = torch.empty(length, length)
|
67 |
+
mask.fill_(float("-inf"))
|
68 |
+
mask.triu_(1)
|
69 |
+
return mask
|
70 |
+
|
71 |
+
def forward(self, input_ids, clip_skip=2):
|
72 |
+
embeds = self.token_embedding(input_ids) + self.position_embeds
|
73 |
+
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
74 |
+
for encoder_id, encoder in enumerate(self.encoders):
|
75 |
+
embeds = encoder(embeds, attn_mask=attn_mask)
|
76 |
+
if encoder_id + clip_skip == len(self.encoders):
|
77 |
+
hidden_states = embeds
|
78 |
+
embeds = self.final_layer_norm(embeds)
|
79 |
+
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
80 |
+
pooled_embeds = self.text_projection(pooled_embeds)
|
81 |
+
return pooled_embeds, hidden_states
|
82 |
+
|
83 |
+
def state_dict_converter(self):
|
84 |
+
return SDXLTextEncoder2StateDictConverter()
|
85 |
+
|
86 |
+
|
87 |
+
class SDXLTextEncoderStateDictConverter:
|
88 |
+
def __init__(self):
|
89 |
+
pass
|
90 |
+
|
91 |
+
def from_diffusers(self, state_dict):
|
92 |
+
rename_dict = {
|
93 |
+
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
94 |
+
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
95 |
+
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
96 |
+
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
97 |
+
}
|
98 |
+
attn_rename_dict = {
|
99 |
+
"self_attn.q_proj": "attn.to_q",
|
100 |
+
"self_attn.k_proj": "attn.to_k",
|
101 |
+
"self_attn.v_proj": "attn.to_v",
|
102 |
+
"self_attn.out_proj": "attn.to_out",
|
103 |
+
"layer_norm1": "layer_norm1",
|
104 |
+
"layer_norm2": "layer_norm2",
|
105 |
+
"mlp.fc1": "fc1",
|
106 |
+
"mlp.fc2": "fc2",
|
107 |
+
}
|
108 |
+
state_dict_ = {}
|
109 |
+
for name in state_dict:
|
110 |
+
if name in rename_dict:
|
111 |
+
param = state_dict[name]
|
112 |
+
if name == "text_model.embeddings.position_embedding.weight":
|
113 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
114 |
+
state_dict_[rename_dict[name]] = param
|
115 |
+
elif name.startswith("text_model.encoder.layers."):
|
116 |
+
param = state_dict[name]
|
117 |
+
names = name.split(".")
|
118 |
+
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
119 |
+
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
120 |
+
state_dict_[name_] = param
|
121 |
+
return state_dict_
|
122 |
+
|
123 |
+
def from_civitai(self, state_dict):
|
124 |
+
rename_dict = {
|
125 |
+
"conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
|
126 |
+
"conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
127 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
128 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
129 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
130 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
131 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
132 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
133 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
134 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
135 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
136 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
137 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
138 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
139 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
140 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
141 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
142 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
143 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
144 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
145 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
146 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
147 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
148 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
149 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
150 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
151 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
152 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
153 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
154 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
155 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
156 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
157 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
158 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
159 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
160 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
161 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
162 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
163 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
164 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
165 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
166 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
167 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
168 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
169 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
170 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
171 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
172 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
173 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
174 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
175 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
176 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
177 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
178 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
179 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
180 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
181 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
182 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
183 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
184 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
185 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
186 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
187 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
188 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
189 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
190 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
191 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
192 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
193 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
194 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
195 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
196 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
197 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
198 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
199 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
200 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
201 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
202 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
203 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
204 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
205 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
206 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
207 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
208 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
209 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
210 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
211 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
212 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
213 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
214 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
215 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
216 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
217 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
218 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
219 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
220 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
221 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
222 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
223 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
224 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
225 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
226 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
227 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
228 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
229 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
230 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
231 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
232 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
233 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
234 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
235 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
236 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
237 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
238 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
239 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
240 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
241 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
242 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
243 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
244 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
245 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
246 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
247 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
248 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
249 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
250 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
251 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
252 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
253 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
254 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
255 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
256 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
257 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
258 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
259 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
260 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
261 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
262 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
263 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
264 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
265 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
266 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
267 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
268 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
269 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
270 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
271 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
272 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
273 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
274 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
275 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
276 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
277 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
278 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
279 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
280 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
281 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
282 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
283 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
284 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
285 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
286 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
287 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
288 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
289 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
290 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
291 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
292 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
293 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
294 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
295 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
296 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
297 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
298 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
299 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
300 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
301 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
302 |
+
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
303 |
+
}
|
304 |
+
state_dict_ = {}
|
305 |
+
for name in state_dict:
|
306 |
+
if name in rename_dict:
|
307 |
+
param = state_dict[name]
|
308 |
+
if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
|
309 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
310 |
+
state_dict_[rename_dict[name]] = param
|
311 |
+
return state_dict_
|
312 |
+
|
313 |
+
|
314 |
+
class SDXLTextEncoder2StateDictConverter:
|
315 |
+
def __init__(self):
|
316 |
+
pass
|
317 |
+
|
318 |
+
def from_diffusers(self, state_dict):
|
319 |
+
rename_dict = {
|
320 |
+
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
321 |
+
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
322 |
+
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
323 |
+
"text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
324 |
+
"text_projection.weight": "text_projection.weight"
|
325 |
+
}
|
326 |
+
attn_rename_dict = {
|
327 |
+
"self_attn.q_proj": "attn.to_q",
|
328 |
+
"self_attn.k_proj": "attn.to_k",
|
329 |
+
"self_attn.v_proj": "attn.to_v",
|
330 |
+
"self_attn.out_proj": "attn.to_out",
|
331 |
+
"layer_norm1": "layer_norm1",
|
332 |
+
"layer_norm2": "layer_norm2",
|
333 |
+
"mlp.fc1": "fc1",
|
334 |
+
"mlp.fc2": "fc2",
|
335 |
+
}
|
336 |
+
state_dict_ = {}
|
337 |
+
for name in state_dict:
|
338 |
+
if name in rename_dict:
|
339 |
+
param = state_dict[name]
|
340 |
+
if name == "text_model.embeddings.position_embedding.weight":
|
341 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
342 |
+
state_dict_[rename_dict[name]] = param
|
343 |
+
elif name.startswith("text_model.encoder.layers."):
|
344 |
+
param = state_dict[name]
|
345 |
+
names = name.split(".")
|
346 |
+
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
347 |
+
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
348 |
+
state_dict_[name_] = param
|
349 |
+
return state_dict_
|
350 |
+
|
351 |
+
def from_civitai(self, state_dict):
|
352 |
+
rename_dict = {
|
353 |
+
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
|
354 |
+
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
|
355 |
+
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
|
356 |
+
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
|
357 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
358 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
359 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
360 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
361 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
362 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
363 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
364 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
365 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
366 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
367 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
368 |
+
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
369 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
370 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
371 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
372 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
373 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
374 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
375 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
376 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
377 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
378 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
379 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
380 |
+
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
381 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
382 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
383 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
384 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
385 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
386 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
387 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
388 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
389 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
390 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
391 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
392 |
+
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
393 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
394 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
395 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
396 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
397 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
398 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
399 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
400 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
401 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
402 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
403 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
404 |
+
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
405 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
406 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
407 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
408 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
409 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
410 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
411 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
412 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
413 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
414 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
415 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
416 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
417 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
418 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
419 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
420 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
421 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
422 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
423 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
424 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
425 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
426 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
427 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
428 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
429 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
430 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
431 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
432 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
433 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
434 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
435 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
436 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
437 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
438 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
439 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
440 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
441 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
442 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
443 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
444 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
445 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
446 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
447 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
448 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
449 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
450 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
451 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
452 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
453 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
454 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
455 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
456 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
457 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
458 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
459 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
460 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
461 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
462 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
463 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
464 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
465 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
466 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
467 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
468 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
469 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
470 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
471 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
472 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
473 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
474 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
475 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
476 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
477 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
478 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
479 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
480 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
481 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
482 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
483 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
484 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
485 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
486 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
487 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
488 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
489 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
490 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
491 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
492 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
493 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
494 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
495 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
496 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
497 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
498 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
499 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
500 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
501 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
502 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
503 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
504 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
505 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
506 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
507 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
508 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
509 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
510 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
511 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
512 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
513 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
514 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
515 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
516 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
517 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
518 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
519 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
520 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
521 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
522 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
523 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
524 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
525 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
526 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
527 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
528 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
529 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
530 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
531 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
532 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
533 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
534 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
535 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
536 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
537 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
538 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
539 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
540 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
541 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
542 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
543 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
544 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
545 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
546 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
547 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
548 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
549 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
550 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
551 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
552 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
553 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
554 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
555 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
556 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
557 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
558 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
559 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
560 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
561 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
562 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
563 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
564 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
565 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
566 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
567 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
568 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
569 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
570 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
571 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
572 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
573 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
574 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
575 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
576 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
577 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
578 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
579 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
580 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
581 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
582 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
583 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
584 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
585 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
586 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
587 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
588 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
589 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
590 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
591 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
592 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
593 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
594 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
595 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
596 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
597 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
598 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
599 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
600 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
601 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
602 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
603 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
604 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
605 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
606 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
607 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
608 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
609 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
610 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
611 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
612 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
613 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
614 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
615 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
616 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
617 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
618 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
619 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
620 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
621 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
622 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
623 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
624 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
625 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
626 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
627 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
628 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
629 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
630 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
631 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
632 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
633 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
634 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
635 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
636 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
637 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
638 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
639 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
640 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
641 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
642 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
643 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
644 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
645 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
646 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
647 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
648 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
649 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
650 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
651 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
652 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
653 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
654 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
655 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
656 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
657 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
658 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
659 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
660 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
661 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
662 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
663 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
664 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
665 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
666 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
667 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
668 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
669 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
670 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
671 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
672 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
673 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
674 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
675 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
676 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
677 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
678 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
679 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
680 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
681 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
682 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
683 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
684 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
685 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
686 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
687 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
688 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
689 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
690 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
691 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
692 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
693 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
694 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
695 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
696 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
697 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
698 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
699 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
700 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
701 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
702 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
703 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
704 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
705 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
706 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
707 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
708 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
709 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
710 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
711 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
712 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
713 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
714 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
715 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
716 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
717 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
718 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
719 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
720 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
721 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
722 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
723 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
724 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
725 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
726 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
727 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
728 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
729 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
730 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
731 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
732 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
733 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
734 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
735 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
736 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
737 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
738 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
739 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
740 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
741 |
+
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
742 |
+
}
|
743 |
+
state_dict_ = {}
|
744 |
+
for name in state_dict:
|
745 |
+
if name in rename_dict:
|
746 |
+
param = state_dict[name]
|
747 |
+
if name == "conditioner.embedders.1.model.positional_embedding":
|
748 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
749 |
+
elif name == "conditioner.embedders.1.model.text_projection":
|
750 |
+
param = param.T
|
751 |
+
if isinstance(rename_dict[name], str):
|
752 |
+
state_dict_[rename_dict[name]] = param
|
753 |
+
else:
|
754 |
+
length = param.shape[0] // 3
|
755 |
+
for i, rename in enumerate(rename_dict[name]):
|
756 |
+
state_dict_[rename] = param[i*length: i*length+length]
|
757 |
+
return state_dict_
|
diffsynth/models/sdxl_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/sdxl_vae_decoder.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
2 |
+
|
3 |
+
|
4 |
+
class SDXLVAEDecoder(SDVAEDecoder):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.scaling_factor = 0.13025
|
8 |
+
|
9 |
+
def state_dict_converter(self):
|
10 |
+
return SDXLVAEDecoderStateDictConverter()
|
11 |
+
|
12 |
+
|
13 |
+
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
diffsynth/models/sdxl_vae_encoder.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
2 |
+
|
3 |
+
|
4 |
+
class SDXLVAEEncoder(SDVAEEncoder):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.scaling_factor = 0.13025
|
8 |
+
|
9 |
+
def state_dict_converter(self):
|
10 |
+
return SDXLVAEEncoderStateDictConverter()
|
11 |
+
|
12 |
+
|
13 |
+
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
diffsynth/models/svd_image_encoder.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .sd_text_encoder import CLIPEncoderLayer
|
3 |
+
|
4 |
+
|
5 |
+
class CLIPVisionEmbeddings(torch.nn.Module):
|
6 |
+
def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
# class_embeds (This is a fixed tensor)
|
10 |
+
self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
|
11 |
+
|
12 |
+
# position_embeds
|
13 |
+
self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
14 |
+
|
15 |
+
# position_embeds (This is a fixed tensor)
|
16 |
+
self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))
|
17 |
+
|
18 |
+
def forward(self, pixel_values):
|
19 |
+
batch_size = pixel_values.shape[0]
|
20 |
+
patch_embeds = self.patch_embedding(pixel_values)
|
21 |
+
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
22 |
+
class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
|
23 |
+
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
|
24 |
+
return embeddings
|
25 |
+
|
26 |
+
|
27 |
+
class SVDImageEncoder(torch.nn.Module):
|
28 |
+
def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
|
29 |
+
super().__init__()
|
30 |
+
self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
|
31 |
+
self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
|
32 |
+
self.encoders = torch.nn.ModuleList([
|
33 |
+
CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
|
34 |
+
for _ in range(num_encoder_layers)])
|
35 |
+
self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
|
36 |
+
self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
|
37 |
+
|
38 |
+
def forward(self, pixel_values):
|
39 |
+
embeds = self.embeddings(pixel_values)
|
40 |
+
embeds = self.pre_layernorm(embeds)
|
41 |
+
for encoder_id, encoder in enumerate(self.encoders):
|
42 |
+
embeds = encoder(embeds)
|
43 |
+
embeds = self.post_layernorm(embeds[:, 0, :])
|
44 |
+
embeds = self.visual_projection(embeds)
|
45 |
+
return embeds
|
46 |
+
|
47 |
+
def state_dict_converter(self):
|
48 |
+
return SVDImageEncoderStateDictConverter()
|
49 |
+
|
50 |
+
|
51 |
+
class SVDImageEncoderStateDictConverter:
|
52 |
+
def __init__(self):
|
53 |
+
pass
|
54 |
+
|
55 |
+
def from_diffusers(self, state_dict):
|
56 |
+
rename_dict = {
|
57 |
+
"vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
|
58 |
+
"vision_model.embeddings.class_embedding": "embeddings.class_embedding",
|
59 |
+
"vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
|
60 |
+
"vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
|
61 |
+
"vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
|
62 |
+
"vision_model.post_layernorm.weight": "post_layernorm.weight",
|
63 |
+
"vision_model.post_layernorm.bias": "post_layernorm.bias",
|
64 |
+
"visual_projection.weight": "visual_projection.weight"
|
65 |
+
}
|
66 |
+
attn_rename_dict = {
|
67 |
+
"self_attn.q_proj": "attn.to_q",
|
68 |
+
"self_attn.k_proj": "attn.to_k",
|
69 |
+
"self_attn.v_proj": "attn.to_v",
|
70 |
+
"self_attn.out_proj": "attn.to_out",
|
71 |
+
"layer_norm1": "layer_norm1",
|
72 |
+
"layer_norm2": "layer_norm2",
|
73 |
+
"mlp.fc1": "fc1",
|
74 |
+
"mlp.fc2": "fc2",
|
75 |
+
}
|
76 |
+
state_dict_ = {}
|
77 |
+
for name in state_dict:
|
78 |
+
if name in rename_dict:
|
79 |
+
param = state_dict[name]
|
80 |
+
if name == "vision_model.embeddings.class_embedding":
|
81 |
+
param = state_dict[name].view(1, 1, -1)
|
82 |
+
elif name == "vision_model.embeddings.position_embedding.weight":
|
83 |
+
param = state_dict[name].unsqueeze(0)
|
84 |
+
state_dict_[rename_dict[name]] = param
|
85 |
+
elif name.startswith("vision_model.encoder.layers."):
|
86 |
+
param = state_dict[name]
|
87 |
+
names = name.split(".")
|
88 |
+
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
89 |
+
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
90 |
+
state_dict_[name_] = param
|
91 |
+
return state_dict_
|
92 |
+
|
93 |
+
def from_civitai(self, state_dict):
|
94 |
+
rename_dict = {
|
95 |
+
"conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
|
96 |
+
"conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
|
97 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
|
98 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
|
99 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
|
100 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
|
101 |
+
"conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
|
102 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
103 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
104 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
105 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
106 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
107 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
108 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
109 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
110 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
111 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
112 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
113 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
114 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
115 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
116 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
117 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
118 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
119 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
120 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
121 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
122 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
123 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
124 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
125 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
126 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
127 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
128 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
129 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
130 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
131 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
132 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
133 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
134 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
135 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
136 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
137 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
138 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
139 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
140 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
141 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
142 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
143 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
144 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
145 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
146 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
147 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
148 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
149 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
150 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
151 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
152 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
153 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
154 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
155 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
156 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
157 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
158 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
159 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
160 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
161 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
162 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
163 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
164 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
165 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
166 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
167 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
168 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
169 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
170 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
171 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
172 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
173 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
174 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
175 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
176 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
177 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
178 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
179 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
180 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
181 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
182 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
183 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
184 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
185 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
186 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
187 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
188 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
189 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
190 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
191 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
192 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
193 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
194 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
195 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
196 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
197 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
198 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
199 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
200 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
201 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
202 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
203 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
204 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
205 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
206 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
207 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
208 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
209 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
210 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
211 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
212 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
213 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
214 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
215 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
216 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
217 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
218 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
219 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
220 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
221 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
222 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
223 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
224 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
225 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
226 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
227 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
228 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
229 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
230 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
231 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
232 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
233 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
234 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
235 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
236 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
237 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
238 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
239 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
240 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
241 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
242 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
243 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
244 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
245 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
246 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
247 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
248 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
249 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
250 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
251 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
252 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
253 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
254 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
255 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
256 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
257 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
258 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
259 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
260 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
261 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
262 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
263 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
264 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
265 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
266 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
267 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
268 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
269 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
270 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
271 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
272 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
273 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
274 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
275 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
276 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
277 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
278 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
279 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
280 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
281 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
282 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
283 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
284 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
285 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
286 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
287 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
288 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
289 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
290 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
291 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
292 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
293 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
294 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
295 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
296 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
297 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
298 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
299 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
300 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
301 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
302 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
303 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
304 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
305 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
306 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
307 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
308 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
309 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
310 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
311 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
312 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
313 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
314 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
315 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
316 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
317 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
318 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
319 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
320 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
321 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
322 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
323 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
324 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
325 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
326 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
327 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
328 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
329 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
330 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
331 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
332 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
333 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
334 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
335 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
336 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
337 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
338 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
339 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
340 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
341 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
342 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
343 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
344 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
345 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
346 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
347 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
348 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
349 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
350 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
351 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
352 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
353 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
354 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
355 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
356 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
357 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
358 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
359 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
360 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
361 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
362 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
363 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
364 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
365 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
366 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
367 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
368 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
369 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
370 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
371 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
372 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
373 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
374 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
375 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
376 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
377 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
378 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
379 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
380 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
381 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
382 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
383 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
384 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
385 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
386 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
387 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
388 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
389 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
390 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
391 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
392 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
393 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
394 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
395 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
396 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
397 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
398 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
399 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
400 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
401 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
402 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
403 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
404 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
405 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
406 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
407 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
408 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
409 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
410 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
411 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
412 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
413 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
414 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
415 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
416 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
417 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
418 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
419 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
420 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
421 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
422 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
423 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
424 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
425 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
426 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
427 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
428 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
429 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
430 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
431 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
432 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
433 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
434 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
435 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
436 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
437 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
438 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
439 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
440 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
441 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
442 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
443 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
444 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
445 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
446 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
447 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
448 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
449 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
450 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
451 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
452 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
453 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
454 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
455 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
456 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
457 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
458 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
459 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
460 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
461 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
462 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
463 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
464 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
465 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
466 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
467 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
468 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
469 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
470 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
471 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
472 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
473 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
474 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
475 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
476 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
477 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
478 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
479 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
480 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
481 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
482 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
483 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
484 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
485 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
486 |
+
"conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
|
487 |
+
}
|
488 |
+
state_dict_ = {}
|
489 |
+
for name in state_dict:
|
490 |
+
if name in rename_dict:
|
491 |
+
param = state_dict[name]
|
492 |
+
if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
|
493 |
+
param = param.reshape((1, 1, param.shape[0]))
|
494 |
+
elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
|
495 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
496 |
+
elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
|
497 |
+
param = param.T
|
498 |
+
if isinstance(rename_dict[name], str):
|
499 |
+
state_dict_[rename_dict[name]] = param
|
500 |
+
else:
|
501 |
+
length = param.shape[0] // 3
|
502 |
+
for i, rename in enumerate(rename_dict[name]):
|
503 |
+
state_dict_[rename] = param[i*length: i*length+length]
|
504 |
+
return state_dict_
|
diffsynth/models/svd_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/svd_vae_decoder.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .attention import Attention
|
3 |
+
from .sd_unet import ResnetBlock, UpSampler
|
4 |
+
from .tiler import TileWorker
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
|
8 |
+
class VAEAttentionBlock(torch.nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
11 |
+
super().__init__()
|
12 |
+
inner_dim = num_attention_heads * attention_head_dim
|
13 |
+
|
14 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
15 |
+
|
16 |
+
self.transformer_blocks = torch.nn.ModuleList([
|
17 |
+
Attention(
|
18 |
+
inner_dim,
|
19 |
+
num_attention_heads,
|
20 |
+
attention_head_dim,
|
21 |
+
bias_q=True,
|
22 |
+
bias_kv=True,
|
23 |
+
bias_out=True
|
24 |
+
)
|
25 |
+
for d in range(num_layers)
|
26 |
+
])
|
27 |
+
|
28 |
+
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
29 |
+
batch, _, height, width = hidden_states.shape
|
30 |
+
residual = hidden_states
|
31 |
+
|
32 |
+
hidden_states = self.norm(hidden_states)
|
33 |
+
inner_dim = hidden_states.shape[1]
|
34 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
35 |
+
|
36 |
+
for block in self.transformer_blocks:
|
37 |
+
hidden_states = block(hidden_states)
|
38 |
+
|
39 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
40 |
+
hidden_states = hidden_states + residual
|
41 |
+
|
42 |
+
return hidden_states, time_emb, text_emb, res_stack
|
43 |
+
|
44 |
+
|
45 |
+
class TemporalResnetBlock(torch.nn.Module):
|
46 |
+
|
47 |
+
def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
|
48 |
+
super().__init__()
|
49 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
50 |
+
self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
|
51 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
52 |
+
self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
|
53 |
+
self.nonlinearity = torch.nn.SiLU()
|
54 |
+
self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
|
55 |
+
|
56 |
+
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
57 |
+
x_spatial = hidden_states
|
58 |
+
x = rearrange(hidden_states, "T C H W -> 1 C T H W")
|
59 |
+
x = self.norm1(x)
|
60 |
+
x = self.nonlinearity(x)
|
61 |
+
x = self.conv1(x)
|
62 |
+
x = self.norm2(x)
|
63 |
+
x = self.nonlinearity(x)
|
64 |
+
x = self.conv2(x)
|
65 |
+
x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
|
66 |
+
alpha = torch.sigmoid(self.mix_factor)
|
67 |
+
hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
|
68 |
+
return hidden_states, time_emb, text_emb, res_stack
|
69 |
+
|
70 |
+
|
71 |
+
class SVDVAEDecoder(torch.nn.Module):
|
72 |
+
def __init__(self):
|
73 |
+
super().__init__()
|
74 |
+
self.scaling_factor = 0.18215
|
75 |
+
self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
|
76 |
+
|
77 |
+
self.blocks = torch.nn.ModuleList([
|
78 |
+
# UNetMidBlock
|
79 |
+
ResnetBlock(512, 512, eps=1e-6),
|
80 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
81 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
82 |
+
ResnetBlock(512, 512, eps=1e-6),
|
83 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
84 |
+
# UpDecoderBlock
|
85 |
+
ResnetBlock(512, 512, eps=1e-6),
|
86 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
87 |
+
ResnetBlock(512, 512, eps=1e-6),
|
88 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
89 |
+
ResnetBlock(512, 512, eps=1e-6),
|
90 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
91 |
+
UpSampler(512),
|
92 |
+
# UpDecoderBlock
|
93 |
+
ResnetBlock(512, 512, eps=1e-6),
|
94 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
95 |
+
ResnetBlock(512, 512, eps=1e-6),
|
96 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
97 |
+
ResnetBlock(512, 512, eps=1e-6),
|
98 |
+
TemporalResnetBlock(512, 512, eps=1e-6),
|
99 |
+
UpSampler(512),
|
100 |
+
# UpDecoderBlock
|
101 |
+
ResnetBlock(512, 256, eps=1e-6),
|
102 |
+
TemporalResnetBlock(256, 256, eps=1e-6),
|
103 |
+
ResnetBlock(256, 256, eps=1e-6),
|
104 |
+
TemporalResnetBlock(256, 256, eps=1e-6),
|
105 |
+
ResnetBlock(256, 256, eps=1e-6),
|
106 |
+
TemporalResnetBlock(256, 256, eps=1e-6),
|
107 |
+
UpSampler(256),
|
108 |
+
# UpDecoderBlock
|
109 |
+
ResnetBlock(256, 128, eps=1e-6),
|
110 |
+
TemporalResnetBlock(128, 128, eps=1e-6),
|
111 |
+
ResnetBlock(128, 128, eps=1e-6),
|
112 |
+
TemporalResnetBlock(128, 128, eps=1e-6),
|
113 |
+
ResnetBlock(128, 128, eps=1e-6),
|
114 |
+
TemporalResnetBlock(128, 128, eps=1e-6),
|
115 |
+
])
|
116 |
+
|
117 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
|
118 |
+
self.conv_act = torch.nn.SiLU()
|
119 |
+
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
120 |
+
self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
|
121 |
+
|
122 |
+
|
123 |
+
def forward(self, sample):
|
124 |
+
# 1. pre-process
|
125 |
+
hidden_states = rearrange(sample, "C T H W -> T C H W")
|
126 |
+
hidden_states = hidden_states / self.scaling_factor
|
127 |
+
hidden_states = self.conv_in(hidden_states)
|
128 |
+
time_emb, text_emb, res_stack = None, None, None
|
129 |
+
|
130 |
+
# 2. blocks
|
131 |
+
for i, block in enumerate(self.blocks):
|
132 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
133 |
+
|
134 |
+
# 3. output
|
135 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
136 |
+
hidden_states = self.conv_act(hidden_states)
|
137 |
+
hidden_states = self.conv_out(hidden_states)
|
138 |
+
hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
|
139 |
+
hidden_states = self.time_conv_out(hidden_states)
|
140 |
+
|
141 |
+
return hidden_states
|
142 |
+
|
143 |
+
|
144 |
+
def build_mask(self, data, is_bound):
|
145 |
+
_, T, H, W = data.shape
|
146 |
+
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
147 |
+
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
148 |
+
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
149 |
+
border_width = (T + H + W) // 6
|
150 |
+
pad = torch.ones_like(t) * border_width
|
151 |
+
mask = torch.stack([
|
152 |
+
pad if is_bound[0] else t + 1,
|
153 |
+
pad if is_bound[1] else T - t,
|
154 |
+
pad if is_bound[2] else h + 1,
|
155 |
+
pad if is_bound[3] else H - h,
|
156 |
+
pad if is_bound[4] else w + 1,
|
157 |
+
pad if is_bound[5] else W - w
|
158 |
+
]).min(dim=0).values
|
159 |
+
mask = mask.clip(1, border_width)
|
160 |
+
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
161 |
+
mask = rearrange(mask, "T H W -> 1 T H W")
|
162 |
+
return mask
|
163 |
+
|
164 |
+
|
165 |
+
def decode_video(
|
166 |
+
self, sample,
|
167 |
+
batch_time=8, batch_height=128, batch_width=128,
|
168 |
+
stride_time=4, stride_height=32, stride_width=32,
|
169 |
+
progress_bar=lambda x:x
|
170 |
+
):
|
171 |
+
sample = sample.permute(1, 0, 2, 3)
|
172 |
+
data_device = sample.device
|
173 |
+
computation_device = self.conv_in.weight.device
|
174 |
+
torch_dtype = sample.dtype
|
175 |
+
_, T, H, W = sample.shape
|
176 |
+
|
177 |
+
weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
178 |
+
values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
179 |
+
|
180 |
+
# Split tasks
|
181 |
+
tasks = []
|
182 |
+
for t in range(0, T, stride_time):
|
183 |
+
for h in range(0, H, stride_height):
|
184 |
+
for w in range(0, W, stride_width):
|
185 |
+
if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
|
186 |
+
or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
|
187 |
+
or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
|
188 |
+
continue
|
189 |
+
tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
|
190 |
+
|
191 |
+
# Run
|
192 |
+
for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
|
193 |
+
sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
|
194 |
+
sample_batch = self.forward(sample_batch).to(data_device)
|
195 |
+
mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
|
196 |
+
values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
|
197 |
+
weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
|
198 |
+
values /= weight
|
199 |
+
return values
|
200 |
+
|
201 |
+
|
202 |
+
def state_dict_converter(self):
|
203 |
+
return SVDVAEDecoderStateDictConverter()
|
204 |
+
|
205 |
+
|
206 |
+
class SVDVAEDecoderStateDictConverter:
|
207 |
+
def __init__(self):
|
208 |
+
pass
|
209 |
+
|
210 |
+
def from_diffusers(self, state_dict):
|
211 |
+
static_rename_dict = {
|
212 |
+
"decoder.conv_in": "conv_in",
|
213 |
+
"decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
|
214 |
+
"decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
|
215 |
+
"decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
|
216 |
+
"decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
|
217 |
+
"decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
|
218 |
+
"decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
|
219 |
+
"decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
|
220 |
+
"decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
|
221 |
+
"decoder.conv_norm_out": "conv_norm_out",
|
222 |
+
"decoder.conv_out": "conv_out",
|
223 |
+
"decoder.time_conv_out": "time_conv_out"
|
224 |
+
}
|
225 |
+
prefix_rename_dict = {
|
226 |
+
"decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
|
227 |
+
"decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
|
228 |
+
"decoder.mid_block.resnets.0.time_mixer": "blocks.1",
|
229 |
+
"decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
|
230 |
+
"decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
|
231 |
+
"decoder.mid_block.resnets.1.time_mixer": "blocks.4",
|
232 |
+
|
233 |
+
"decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
|
234 |
+
"decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
|
235 |
+
"decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
|
236 |
+
"decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
|
237 |
+
"decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
|
238 |
+
"decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
|
239 |
+
"decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
|
240 |
+
"decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
|
241 |
+
"decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
|
242 |
+
|
243 |
+
"decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
|
244 |
+
"decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
|
245 |
+
"decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
|
246 |
+
"decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
|
247 |
+
"decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
|
248 |
+
"decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
|
249 |
+
"decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
|
250 |
+
"decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
|
251 |
+
"decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
|
252 |
+
|
253 |
+
"decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
|
254 |
+
"decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
|
255 |
+
"decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
|
256 |
+
"decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
|
257 |
+
"decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
|
258 |
+
"decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
|
259 |
+
"decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
|
260 |
+
"decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
|
261 |
+
"decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
|
262 |
+
|
263 |
+
"decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
|
264 |
+
"decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
|
265 |
+
"decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
|
266 |
+
"decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
|
267 |
+
"decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
|
268 |
+
"decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
|
269 |
+
"decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
|
270 |
+
"decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
|
271 |
+
"decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
|
272 |
+
}
|
273 |
+
suffix_rename_dict = {
|
274 |
+
"norm1.weight": "norm1.weight",
|
275 |
+
"conv1.weight": "conv1.weight",
|
276 |
+
"norm2.weight": "norm2.weight",
|
277 |
+
"conv2.weight": "conv2.weight",
|
278 |
+
"conv_shortcut.weight": "conv_shortcut.weight",
|
279 |
+
"norm1.bias": "norm1.bias",
|
280 |
+
"conv1.bias": "conv1.bias",
|
281 |
+
"norm2.bias": "norm2.bias",
|
282 |
+
"conv2.bias": "conv2.bias",
|
283 |
+
"conv_shortcut.bias": "conv_shortcut.bias",
|
284 |
+
"mix_factor": "mix_factor",
|
285 |
+
}
|
286 |
+
|
287 |
+
state_dict_ = {}
|
288 |
+
for name in static_rename_dict:
|
289 |
+
state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
|
290 |
+
state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
|
291 |
+
for prefix_name in prefix_rename_dict:
|
292 |
+
for suffix_name in suffix_rename_dict:
|
293 |
+
name = prefix_name + "." + suffix_name
|
294 |
+
name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
|
295 |
+
if name in state_dict:
|
296 |
+
state_dict_[name_] = state_dict[name]
|
297 |
+
|
298 |
+
return state_dict_
|
299 |
+
|
300 |
+
|
301 |
+
def from_civitai(self, state_dict):
|
302 |
+
rename_dict = {
|
303 |
+
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
|
304 |
+
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
|
305 |
+
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
|
306 |
+
"first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
|
307 |
+
"first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
|
308 |
+
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
|
309 |
+
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
|
310 |
+
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
|
311 |
+
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
|
312 |
+
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
|
313 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
|
314 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
|
315 |
+
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
|
316 |
+
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
|
317 |
+
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
|
318 |
+
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
|
319 |
+
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
320 |
+
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
321 |
+
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
322 |
+
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
323 |
+
"first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
|
324 |
+
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
325 |
+
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
326 |
+
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
327 |
+
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
328 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
|
329 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
|
330 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
|
331 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
|
332 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
|
333 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
|
334 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
|
335 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
|
336 |
+
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
|
337 |
+
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
|
338 |
+
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
|
339 |
+
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
|
340 |
+
"first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
|
341 |
+
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
|
342 |
+
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
|
343 |
+
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
|
344 |
+
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
|
345 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
|
346 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
|
347 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
|
348 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
|
349 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
|
350 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
|
351 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
|
352 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
|
353 |
+
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
|
354 |
+
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
|
355 |
+
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
|
356 |
+
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
|
357 |
+
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
|
358 |
+
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
|
359 |
+
"first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
|
360 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
|
361 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
|
362 |
+
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
|
363 |
+
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
|
364 |
+
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
|
365 |
+
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
|
366 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
|
367 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
|
368 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
|
369 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
|
370 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
|
371 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
|
372 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
|
373 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
|
374 |
+
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
|
375 |
+
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
|
376 |
+
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
|
377 |
+
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
|
378 |
+
"first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
|
379 |
+
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
|
380 |
+
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
|
381 |
+
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
|
382 |
+
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
|
383 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
|
384 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
|
385 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
|
386 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
|
387 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
|
388 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
|
389 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
|
390 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
|
391 |
+
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
|
392 |
+
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
|
393 |
+
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
|
394 |
+
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
|
395 |
+
"first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
|
396 |
+
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
|
397 |
+
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
|
398 |
+
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
|
399 |
+
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
|
400 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
|
401 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
|
402 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
|
403 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
|
404 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
|
405 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
|
406 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
|
407 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
|
408 |
+
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
|
409 |
+
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
|
410 |
+
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
|
411 |
+
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
|
412 |
+
"first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
|
413 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
|
414 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
|
415 |
+
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
|
416 |
+
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
|
417 |
+
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
|
418 |
+
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
|
419 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
|
420 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
|
421 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
|
422 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
|
423 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
|
424 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
|
425 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
|
426 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
|
427 |
+
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
|
428 |
+
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
|
429 |
+
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
|
430 |
+
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
|
431 |
+
"first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
|
432 |
+
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
|
433 |
+
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
|
434 |
+
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
|
435 |
+
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
|
436 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
|
437 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
|
438 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
|
439 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
|
440 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
|
441 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
|
442 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
|
443 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
|
444 |
+
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
|
445 |
+
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
|
446 |
+
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
|
447 |
+
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
|
448 |
+
"first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
|
449 |
+
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
|
450 |
+
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
|
451 |
+
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
|
452 |
+
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
|
453 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
|
454 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
|
455 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
|
456 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
|
457 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
|
458 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
|
459 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
|
460 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
|
461 |
+
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
|
462 |
+
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
|
463 |
+
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
|
464 |
+
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
|
465 |
+
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
|
466 |
+
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
|
467 |
+
"first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
|
468 |
+
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
|
469 |
+
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
|
470 |
+
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
|
471 |
+
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
|
472 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
|
473 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
|
474 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
|
475 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
|
476 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
|
477 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
|
478 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
|
479 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
|
480 |
+
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
|
481 |
+
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
|
482 |
+
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
|
483 |
+
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
|
484 |
+
"first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
|
485 |
+
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
|
486 |
+
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
|
487 |
+
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
|
488 |
+
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
|
489 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
|
490 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
|
491 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
|
492 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
|
493 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
|
494 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
|
495 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
|
496 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
|
497 |
+
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
|
498 |
+
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
|
499 |
+
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
|
500 |
+
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
|
501 |
+
"first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
|
502 |
+
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
|
503 |
+
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
|
504 |
+
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
|
505 |
+
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
|
506 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
|
507 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
|
508 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
|
509 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
|
510 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
|
511 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
|
512 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
|
513 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
|
514 |
+
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
|
515 |
+
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
|
516 |
+
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
|
517 |
+
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
|
518 |
+
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
|
519 |
+
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
|
520 |
+
"first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
|
521 |
+
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
|
522 |
+
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
|
523 |
+
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
|
524 |
+
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
|
525 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
|
526 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
|
527 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
|
528 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
|
529 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
|
530 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
|
531 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
|
532 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
|
533 |
+
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
|
534 |
+
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
|
535 |
+
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
|
536 |
+
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
|
537 |
+
"first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
|
538 |
+
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
|
539 |
+
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
|
540 |
+
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
|
541 |
+
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
|
542 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
|
543 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
|
544 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
|
545 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
|
546 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
|
547 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
|
548 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
|
549 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
|
550 |
+
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
|
551 |
+
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
|
552 |
+
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
|
553 |
+
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
|
554 |
+
"first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
|
555 |
+
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
|
556 |
+
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
|
557 |
+
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
|
558 |
+
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
|
559 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
|
560 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
|
561 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
|
562 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
|
563 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
|
564 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
|
565 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
|
566 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
|
567 |
+
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
|
568 |
+
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
|
569 |
+
}
|
570 |
+
state_dict_ = {}
|
571 |
+
for name in state_dict:
|
572 |
+
if name in rename_dict:
|
573 |
+
param = state_dict[name]
|
574 |
+
if "blocks.2.transformer_blocks.0" in rename_dict[name]:
|
575 |
+
param = param.squeeze()
|
576 |
+
state_dict_[rename_dict[name]] = param
|
577 |
+
return state_dict_
|
diffsynth/models/svd_vae_encoder.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
2 |
+
|
3 |
+
|
4 |
+
class SVDVAEEncoder(SDVAEEncoder):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.scaling_factor = 0.13025
|
8 |
+
|
9 |
+
def state_dict_converter(self):
|
10 |
+
return SVDVAEEncoderStateDictConverter()
|
11 |
+
|
12 |
+
|
13 |
+
class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
def from_diffusers(self, state_dict):
|
18 |
+
return super().from_diffusers(state_dict)
|
19 |
+
|
20 |
+
def from_civitai(self, state_dict):
|
21 |
+
rename_dict = {
|
22 |
+
"conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
|
23 |
+
"conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
|
24 |
+
"conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
|
25 |
+
"conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
|
26 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
27 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
28 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
29 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
30 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
31 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
32 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
33 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
34 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
35 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
36 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
37 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
38 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
39 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
40 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
41 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
42 |
+
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
43 |
+
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
44 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
45 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
46 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
47 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
48 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
49 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
50 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
51 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
52 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
53 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
54 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
55 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
56 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
57 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
58 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
59 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
60 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
61 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
62 |
+
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
63 |
+
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
64 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
65 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
66 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
67 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
68 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
69 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
70 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
71 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
72 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
73 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
74 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
75 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
76 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
77 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
78 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
79 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
80 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
81 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
82 |
+
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
83 |
+
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
84 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
85 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
86 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
87 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
88 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
89 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
90 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
91 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
92 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
93 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
94 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
95 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
96 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
97 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
98 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
99 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
100 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
101 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
102 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
103 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
104 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
105 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
106 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
107 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
108 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
109 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
110 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
111 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
112 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
113 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
114 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
115 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
116 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
117 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
118 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
119 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
120 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
121 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
122 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
123 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
124 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
125 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
126 |
+
"conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
|
127 |
+
"conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
|
128 |
+
"conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
|
129 |
+
"conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
|
130 |
+
}
|
131 |
+
state_dict_ = {}
|
132 |
+
for name in state_dict:
|
133 |
+
if name in rename_dict:
|
134 |
+
param = state_dict[name]
|
135 |
+
if "transformer_blocks" in rename_dict[name]:
|
136 |
+
param = param.squeeze()
|
137 |
+
state_dict_[rename_dict[name]] = param
|
138 |
+
return state_dict_
|