Updated files

Note: this view is limited to 50 files because the commit contains too many changes.
- README.md +1 -1
- sd/stable-diffusion-webui/CODEOWNERS +12 -12
- sd/stable-diffusion-webui/LICENSE.txt +663 -663
- sd/stable-diffusion-webui/README.md +162 -0
- sd/stable-diffusion-webui/extensions-builtin/LDSR/preload.py +6 -6
- sd/stable-diffusion-webui/extensions-builtin/Lora/extra_networks_lora.py +26 -26
- sd/stable-diffusion-webui/extensions-builtin/Lora/lora.py +207 -207
- sd/stable-diffusion-webui/extensions-builtin/Lora/preload.py +6 -6
- sd/stable-diffusion-webui/extensions-builtin/Lora/scripts/lora_script.py +38 -38
- sd/stable-diffusion-webui/extensions-builtin/Lora/ui_extra_networks_lora.py +30 -37
- sd/stable-diffusion-webui/extensions-builtin/ScuNET/preload.py +6 -6
- sd/stable-diffusion-webui/extensions-builtin/SwinIR/preload.py +6 -6
- sd/stable-diffusion-webui/extensions-builtin/SwinIR/swinir_model_arch_v2.py +1016 -1016
- sd/stable-diffusion-webui/html/extra-networks-card.html +1 -0
- sd/stable-diffusion-webui/html/footer.html +13 -13
- sd/stable-diffusion-webui/html/licenses.html +638 -419
- sd/stable-diffusion-webui/javascript/aspectRatioOverlay.js +113 -113
- sd/stable-diffusion-webui/javascript/contextMenus.js +177 -177
- sd/stable-diffusion-webui/javascript/edit-attention.js +95 -95
- sd/stable-diffusion-webui/javascript/extensions.js +49 -49
- sd/stable-diffusion-webui/javascript/extraNetworks.js +106 -106
- sd/stable-diffusion-webui/javascript/hints.js +1 -0
- sd/stable-diffusion-webui/javascript/hires_fix.js +22 -22
- sd/stable-diffusion-webui/javascript/localization.js +165 -165
- sd/stable-diffusion-webui/javascript/notification.js +1 -1
- sd/stable-diffusion-webui/javascript/progressbar.js +1 -1
- sd/stable-diffusion-webui/javascript/textualInversion.js +17 -17
- sd/stable-diffusion-webui/launch.py +375 -361
- sd/stable-diffusion-webui/modules/api/api.py +28 -17
- sd/stable-diffusion-webui/modules/api/models.py +24 -4
- sd/stable-diffusion-webui/modules/call_queue.py +109 -109
- sd/stable-diffusion-webui/modules/codeformer_model.py +143 -143
- sd/stable-diffusion-webui/modules/deepbooru_model.py +678 -678
- sd/stable-diffusion-webui/modules/errors.py +43 -43
- sd/stable-diffusion-webui/modules/esrgan_model.py +233 -233
- sd/stable-diffusion-webui/modules/esrgan_model_arch.py +464 -464
- sd/stable-diffusion-webui/modules/extensions.py +107 -107
- sd/stable-diffusion-webui/modules/extra_networks.py +147 -147
- sd/stable-diffusion-webui/modules/extra_networks_hypernet.py +27 -27
- sd/stable-diffusion-webui/modules/extras.py +258 -258
- sd/stable-diffusion-webui/modules/face_restoration.py +19 -19
- sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py +408 -402
- sd/stable-diffusion-webui/modules/gfpgan_model.py +116 -116
- sd/stable-diffusion-webui/modules/hashes.py +91 -91
- sd/stable-diffusion-webui/modules/hypernetworks/hypernetwork.py +811 -811
- sd/stable-diffusion-webui/modules/hypernetworks/ui.py +40 -40
- sd/stable-diffusion-webui/modules/images.py +669 -669
- sd/stable-diffusion-webui/modules/img2img.py +184 -184
- sd/stable-diffusion-webui/modules/interrogate.py +227 -227
- sd/stable-diffusion-webui/modules/localization.py +37 -37
README.md
CHANGED
@@ -3,7 +3,7 @@ license: apache-2.0
 title: Automatic Stable Diffusion
 sdk: gradio
 sdk_version: 3.16.2
-app_file: sd/stable-diffusion-webui/webui.py
+app_file: sd/stable-diffusion-webui/webui.py --api
 emoji: 🚀
 colorFrom: indigo
 colorTo: purple
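The one substantive change in this commit is the `--api` suffix added to the `app_file` line above, evidently intended to start webui with its HTTP API enabled (consistent with the `modules/api/api.py` and `modules/api/models.py` edits in the file summary). As a minimal sketch of what that flag exposes — assuming the Space actually forwards the flag to the launcher and serves the API at its root URL; the base URL below is a placeholder — a txt2img request looks roughly like this:

import base64
import requests

# Placeholder base URL: point this at the Space, or at a local webui
# started with --api (default http://127.0.0.1:7860).
BASE_URL = "http://127.0.0.1:7860"

# Minimal txt2img payload; field names follow the webui API of this era,
# but exact fields vary between webui versions.
payload = {
    "prompt": "a watercolor painting of a lighthouse at dusk",
    "steps": 20,
}

resp = requests.post(f"{BASE_URL}/sdapi/v1/txt2img", json=payload, timeout=300)
resp.raise_for_status()

# The API returns generated images as a list of base64-encoded PNGs
# in the "images" field of the JSON response.
for i, img_b64 in enumerate(resp.json()["images"]):
    with open(f"output_{i}.png", "wb") as f:
        f.write(base64.b64decode(img_b64))

Without `--api`, webui serves only the Gradio UI and the `/sdapi/v1/...` routes above return 404.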
sd/stable-diffusion-webui/CODEOWNERS
CHANGED
@@ -1,12 +1,12 @@
-* @AUTOMATIC1111
-
-# if you were managing a localization and were removed from this file, this is because
-# the intended way to do localizations now is via extensions. See:
-# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
-# Make a repo with your localization and since you are still listed as a collaborator
-# you can add it to the wiki page yourself. This change is because some people complained
-# the git commit log is cluttered with things unrelated to almost everyone and
-# because I believe this is the best overall for the project to handle localizations almost
-# entirely without my oversight.
-
-
+* @AUTOMATIC1111
+
+# if you were managing a localization and were removed from this file, this is because
+# the intended way to do localizations now is via extensions. See:
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
+# Make a repo with your localization and since you are still listed as a collaborator
+# you can add it to the wiki page yourself. This change is because some people complained
+# the git commit log is cluttered with things unrelated to almost everyone and
+# because I believe this is the best overall for the project to handle localizations almost
+# entirely without my oversight.
+
+
sd/stable-diffusion-webui/LICENSE.txt
CHANGED
@@ -1,663 +1,663 @@
|
|
1 |
-
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
-
Version 3, 19 November 2007
|
3 |
-
|
4 |
-
Copyright (c) 2023 AUTOMATIC1111
|
5 |
-
|
6 |
-
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
7 |
-
Everyone is permitted to copy and distribute verbatim copies
|
8 |
-
of this license document, but changing it is not allowed.
|
9 |
-
|
10 |
-
Preamble
|
11 |
-
|
12 |
-
The GNU Affero General Public License is a free, copyleft license for
|
13 |
-
software and other kinds of works, specifically designed to ensure
|
14 |
-
cooperation with the community in the case of network server software.
|
15 |
-
|
16 |
-
The licenses for most software and other practical works are designed
|
17 |
-
to take away your freedom to share and change the works. By contrast,
|
18 |
-
our General Public Licenses are intended to guarantee your freedom to
|
19 |
-
share and change all versions of a program--to make sure it remains free
|
20 |
-
software for all its users.
|
21 |
-
|
22 |
-
When we speak of free software, we are referring to freedom, not
|
23 |
-
price. Our General Public Licenses are designed to make sure that you
|
24 |
-
have the freedom to distribute copies of free software (and charge for
|
25 |
-
them if you wish), that you receive source code or can get it if you
|
26 |
-
want it, that you can change the software or use pieces of it in new
|
27 |
-
free programs, and that you know you can do these things.
|
28 |
-
|
29 |
-
Developers that use our General Public Licenses protect your rights
|
30 |
-
with two steps: (1) assert copyright on the software, and (2) offer
|
31 |
-
you this License which gives you legal permission to copy, distribute
|
32 |
-
and/or modify the software.
|
33 |
-
|
34 |
-
A secondary benefit of defending all users' freedom is that
|
35 |
-
improvements made in alternate versions of the program, if they
|
36 |
-
receive widespread use, become available for other developers to
|
37 |
-
incorporate. Many developers of free software are heartened and
|
38 |
-
encouraged by the resulting cooperation. However, in the case of
|
39 |
-
software used on network servers, this result may fail to come about.
|
40 |
-
The GNU General Public License permits making a modified version and
|
41 |
-
letting the public access it on a server without ever releasing its
|
42 |
-
source code to the public.
|
43 |
-
|
44 |
-
The GNU Affero General Public License is designed specifically to
|
45 |
-
ensure that, in such cases, the modified source code becomes available
|
46 |
-
to the community. It requires the operator of a network server to
|
47 |
-
provide the source code of the modified version running there to the
|
48 |
-
users of that server. Therefore, public use of a modified version, on
|
49 |
-
a publicly accessible server, gives the public access to the source
|
50 |
-
code of the modified version.
|
51 |
-
|
52 |
-
An older license, called the Affero General Public License and
|
53 |
-
published by Affero, was designed to accomplish similar goals. This is
|
54 |
-
a different license, not a version of the Affero GPL, but Affero has
|
55 |
-
released a new version of the Affero GPL which permits relicensing under
|
56 |
-
this license.
|
57 |
-
|
58 |
-
The precise terms and conditions for copying, distribution and
|
59 |
-
modification follow.
|
60 |
-
|
61 |
-
TERMS AND CONDITIONS
|
62 |
-
|
63 |
-
0. Definitions.
|
64 |
-
|
65 |
-
"This License" refers to version 3 of the GNU Affero General Public License.
|
66 |
-
|
67 |
-
"Copyright" also means copyright-like laws that apply to other kinds of
|
68 |
-
works, such as semiconductor masks.
|
69 |
-
|
70 |
-
"The Program" refers to any copyrightable work licensed under this
|
71 |
-
License. Each licensee is addressed as "you". "Licensees" and
|
72 |
-
"recipients" may be individuals or organizations.
|
73 |
-
|
74 |
-
To "modify" a work means to copy from or adapt all or part of the work
|
75 |
-
in a fashion requiring copyright permission, other than the making of an
|
76 |
-
exact copy. The resulting work is called a "modified version" of the
|
77 |
-
earlier work or a work "based on" the earlier work.
|
78 |
-
|
79 |
-
A "covered work" means either the unmodified Program or a work based
|
80 |
-
on the Program.
|
81 |
-
|
82 |
-
To "propagate" a work means to do anything with it that, without
|
83 |
-
permission, would make you directly or secondarily liable for
|
84 |
-
infringement under applicable copyright law, except executing it on a
|
85 |
-
computer or modifying a private copy. Propagation includes copying,
|
86 |
-
distribution (with or without modification), making available to the
|
87 |
-
public, and in some countries other activities as well.
|
88 |
-
|
89 |
-
To "convey" a work means any kind of propagation that enables other
|
90 |
-
parties to make or receive copies. Mere interaction with a user through
|
91 |
-
a computer network, with no transfer of a copy, is not conveying.
|
92 |
-
|
93 |
-
An interactive user interface displays "Appropriate Legal Notices"
|
94 |
-
to the extent that it includes a convenient and prominently visible
|
95 |
-
feature that (1) displays an appropriate copyright notice, and (2)
|
96 |
-
tells the user that there is no warranty for the work (except to the
|
97 |
-
extent that warranties are provided), that licensees may convey the
|
98 |
-
work under this License, and how to view a copy of this License. If
|
99 |
-
the interface presents a list of user commands or options, such as a
|
100 |
-
menu, a prominent item in the list meets this criterion.
|
101 |
-
|
102 |
-
1. Source Code.
|
103 |
-
|
104 |
-
The "source code" for a work means the preferred form of the work
|
105 |
-
for making modifications to it. "Object code" means any non-source
|
106 |
-
form of a work.
|
107 |
-
|
108 |
-
A "Standard Interface" means an interface that either is an official
|
109 |
-
standard defined by a recognized standards body, or, in the case of
|
110 |
-
interfaces specified for a particular programming language, one that
|
111 |
-
is widely used among developers working in that language.
|
112 |
-
|
113 |
-
The "System Libraries" of an executable work include anything, other
|
114 |
-
than the work as a whole, that (a) is included in the normal form of
|
115 |
-
packaging a Major Component, but which is not part of that Major
|
116 |
-
Component, and (b) serves only to enable use of the work with that
|
117 |
-
Major Component, or to implement a Standard Interface for which an
|
118 |
-
implementation is available to the public in source code form. A
|
119 |
-
"Major Component", in this context, means a major essential component
|
120 |
-
(kernel, window system, and so on) of the specific operating system
|
121 |
-
(if any) on which the executable work runs, or a compiler used to
|
122 |
-
produce the work, or an object code interpreter used to run it.
|
123 |
-
|
124 |
-
The "Corresponding Source" for a work in object code form means all
|
125 |
-
the source code needed to generate, install, and (for an executable
|
126 |
-
work) run the object code and to modify the work, including scripts to
|
127 |
-
control those activities. However, it does not include the work's
|
128 |
-
System Libraries, or general-purpose tools or generally available free
|
129 |
-
programs which are used unmodified in performing those activities but
|
130 |
-
which are not part of the work. For example, Corresponding Source
|
131 |
-
includes interface definition files associated with source files for
|
132 |
-
the work, and the source code for shared libraries and dynamically
|
133 |
-
linked subprograms that the work is specifically designed to require,
|
134 |
-
such as by intimate data communication or control flow between those
|
135 |
-
subprograms and other parts of the work.
|
136 |
-
|
137 |
-
The Corresponding Source need not include anything that users
|
138 |
-
can regenerate automatically from other parts of the Corresponding
|
139 |
-
Source.
|
140 |
-
|
141 |
-
The Corresponding Source for a work in source code form is that
|
142 |
-
same work.
|
143 |
-
|
144 |
-
2. Basic Permissions.
|
145 |
-
|
146 |
-
All rights granted under this License are granted for the term of
|
147 |
-
copyright on the Program, and are irrevocable provided the stated
|
148 |
-
conditions are met. This License explicitly affirms your unlimited
|
149 |
-
permission to run the unmodified Program. The output from running a
|
150 |
-
covered work is covered by this License only if the output, given its
|
151 |
-
content, constitutes a covered work. This License acknowledges your
|
152 |
-
rights of fair use or other equivalent, as provided by copyright law.
|
153 |
-
|
154 |
-
You may make, run and propagate covered works that you do not
|
155 |
-
convey, without conditions so long as your license otherwise remains
|
156 |
-
in force. You may convey covered works to others for the sole purpose
|
157 |
-
of having them make modifications exclusively for you, or provide you
|
158 |
-
with facilities for running those works, provided that you comply with
|
159 |
-
the terms of this License in conveying all material for which you do
|
160 |
-
not control copyright. Those thus making or running the covered works
|
161 |
-
for you must do so exclusively on your behalf, under your direction
|
162 |
-
and control, on terms that prohibit them from making any copies of
|
163 |
-
your copyrighted material outside their relationship with you.
|
164 |
-
|
165 |
-
Conveying under any other circumstances is permitted solely under
|
166 |
-
the conditions stated below. Sublicensing is not allowed; section 10
|
167 |
-
makes it unnecessary.
|
168 |
-
|
169 |
-
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
170 |
-
|
171 |
-
No covered work shall be deemed part of an effective technological
|
172 |
-
measure under any applicable law fulfilling obligations under article
|
173 |
-
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
174 |
-
similar laws prohibiting or restricting circumvention of such
|
175 |
-
measures.
|
176 |
-
|
177 |
-
When you convey a covered work, you waive any legal power to forbid
|
178 |
-
circumvention of technological measures to the extent such circumvention
|
179 |
-
is effected by exercising rights under this License with respect to
|
180 |
-
the covered work, and you disclaim any intention to limit operation or
|
181 |
-
modification of the work as a means of enforcing, against the work's
|
182 |
-
users, your or third parties' legal rights to forbid circumvention of
|
183 |
-
technological measures.
|
184 |
-
|
185 |
-
4. Conveying Verbatim Copies.
|
186 |
-
|
187 |
-
You may convey verbatim copies of the Program's source code as you
|
188 |
-
receive it, in any medium, provided that you conspicuously and
|
189 |
-
appropriately publish on each copy an appropriate copyright notice;
|
190 |
-
keep intact all notices stating that this License and any
|
191 |
-
non-permissive terms added in accord with section 7 apply to the code;
|
192 |
-
keep intact all notices of the absence of any warranty; and give all
|
193 |
-
recipients a copy of this License along with the Program.
|
194 |
-
|
195 |
-
You may charge any price or no price for each copy that you convey,
|
196 |
-
and you may offer support or warranty protection for a fee.
|
197 |
-
|
198 |
-
5. Conveying Modified Source Versions.
|
199 |
-
|
200 |
-
You may convey a work based on the Program, or the modifications to
|
201 |
-
produce it from the Program, in the form of source code under the
|
202 |
-
terms of section 4, provided that you also meet all of these conditions:
|
203 |
-
|
204 |
-
a) The work must carry prominent notices stating that you modified
|
205 |
-
it, and giving a relevant date.
|
206 |
-
|
207 |
-
b) The work must carry prominent notices stating that it is
|
208 |
-
released under this License and any conditions added under section
|
209 |
-
7. This requirement modifies the requirement in section 4 to
|
210 |
-
"keep intact all notices".
|
211 |
-
|
212 |
-
c) You must license the entire work, as a whole, under this
|
213 |
-
License to anyone who comes into possession of a copy. This
|
214 |
-
License will therefore apply, along with any applicable section 7
|
215 |
-
additional terms, to the whole of the work, and all its parts,
|
216 |
-
regardless of how they are packaged. This License gives no
|
217 |
-
permission to license the work in any other way, but it does not
|
218 |
-
invalidate such permission if you have separately received it.
|
219 |
-
|
220 |
-
d) If the work has interactive user interfaces, each must display
|
221 |
-
Appropriate Legal Notices; however, if the Program has interactive
|
222 |
-
interfaces that do not display Appropriate Legal Notices, your
|
223 |
-
work need not make them do so.
|
224 |
-
|
225 |
-
A compilation of a covered work with other separate and independent
|
226 |
-
works, which are not by their nature extensions of the covered work,
|
227 |
-
and which are not combined with it such as to form a larger program,
|
228 |
-
in or on a volume of a storage or distribution medium, is called an
|
229 |
-
"aggregate" if the compilation and its resulting copyright are not
|
230 |
-
used to limit the access or legal rights of the compilation's users
|
231 |
-
beyond what the individual works permit. Inclusion of a covered work
|
232 |
-
in an aggregate does not cause this License to apply to the other
|
233 |
-
parts of the aggregate.
|
234 |
-
|
235 |
-
6. Conveying Non-Source Forms.
|
236 |
-
|
237 |
-
You may convey a covered work in object code form under the terms
|
238 |
-
of sections 4 and 5, provided that you also convey the
|
239 |
-
machine-readable Corresponding Source under the terms of this License,
|
240 |
-
in one of these ways:
|
241 |
-
|
242 |
-
a) Convey the object code in, or embodied in, a physical product
|
243 |
-
(including a physical distribution medium), accompanied by the
|
244 |
-
Corresponding Source fixed on a durable physical medium
|
245 |
-
customarily used for software interchange.
|
246 |
-
|
247 |
-
b) Convey the object code in, or embodied in, a physical product
|
248 |
-
(including a physical distribution medium), accompanied by a
|
249 |
-
written offer, valid for at least three years and valid for as
|
250 |
-
long as you offer spare parts or customer support for that product
|
251 |
-
model, to give anyone who possesses the object code either (1) a
|
252 |
-
copy of the Corresponding Source for all the software in the
|
253 |
-
product that is covered by this License, on a durable physical
|
254 |
-
medium customarily used for software interchange, for a price no
|
255 |
-
more than your reasonable cost of physically performing this
|
256 |
-
conveying of source, or (2) access to copy the
|
257 |
-
Corresponding Source from a network server at no charge.
|
258 |
-
|
259 |
-
c) Convey individual copies of the object code with a copy of the
|
260 |
-
written offer to provide the Corresponding Source. This
|
261 |
-
alternative is allowed only occasionally and noncommercially, and
|
262 |
-
only if you received the object code with such an offer, in accord
|
263 |
-
with subsection 6b.
|
264 |
-
|
265 |
-
d) Convey the object code by offering access from a designated
|
266 |
-
place (gratis or for a charge), and offer equivalent access to the
|
267 |
-
Corresponding Source in the same way through the same place at no
|
268 |
-
further charge. You need not require recipients to copy the
|
269 |
-
Corresponding Source along with the object code. If the place to
|
270 |
-
copy the object code is a network server, the Corresponding Source
|
271 |
-
may be on a different server (operated by you or a third party)
|
272 |
-
that supports equivalent copying facilities, provided you maintain
|
273 |
-
clear directions next to the object code saying where to find the
|
274 |
-
Corresponding Source. Regardless of what server hosts the
|
275 |
-
Corresponding Source, you remain obligated to ensure that it is
|
276 |
-
available for as long as needed to satisfy these requirements.
|
277 |
-
|
278 |
-
e) Convey the object code using peer-to-peer transmission, provided
|
279 |
-
you inform other peers where the object code and Corresponding
|
280 |
-
Source of the work are being offered to the general public at no
|
281 |
-
charge under subsection 6d.
|
282 |
-
|
283 |
-
A separable portion of the object code, whose source code is excluded
|
284 |
-
from the Corresponding Source as a System Library, need not be
|
285 |
-
included in conveying the object code work.
|
286 |
-
|
287 |
-
A "User Product" is either (1) a "consumer product", which means any
|
288 |
-
tangible personal property which is normally used for personal, family,
|
289 |
-
or household purposes, or (2) anything designed or sold for incorporation
|
290 |
-
into a dwelling. In determining whether a product is a consumer product,
|
291 |
-
doubtful cases shall be resolved in favor of coverage. For a particular
|
292 |
-
product received by a particular user, "normally used" refers to a
|
293 |
-
typical or common use of that class of product, regardless of the status
|
294 |
-
of the particular user or of the way in which the particular user
|
295 |
-
actually uses, or expects or is expected to use, the product. A product
|
296 |
-
is a consumer product regardless of whether the product has substantial
|
297 |
-
commercial, industrial or non-consumer uses, unless such uses represent
|
298 |
-
the only significant mode of use of the product.
|
299 |
-
|
300 |
-
"Installation Information" for a User Product means any methods,
|
301 |
-
procedures, authorization keys, or other information required to install
|
302 |
-
and execute modified versions of a covered work in that User Product from
|
303 |
-
a modified version of its Corresponding Source. The information must
|
304 |
-
suffice to ensure that the continued functioning of the modified object
|
305 |
-
code is in no case prevented or interfered with solely because
|
306 |
-
modification has been made.
|
307 |
-
|
308 |
-
If you convey an object code work under this section in, or with, or
|
309 |
-
specifically for use in, a User Product, and the conveying occurs as
|
310 |
-
part of a transaction in which the right of possession and use of the
|
311 |
-
User Product is transferred to the recipient in perpetuity or for a
|
312 |
-
fixed term (regardless of how the transaction is characterized), the
|
313 |
-
Corresponding Source conveyed under this section must be accompanied
|
314 |
-
by the Installation Information. But this requirement does not apply
|
315 |
-
if neither you nor any third party retains the ability to install
|
316 |
-
modified object code on the User Product (for example, the work has
|
317 |
-
been installed in ROM).
|
318 |
-
|
319 |
-
The requirement to provide Installation Information does not include a
|
320 |
-
requirement to continue to provide support service, warranty, or updates
|
321 |
-
for a work that has been modified or installed by the recipient, or for
|
322 |
-
the User Product in which it has been modified or installed. Access to a
|
323 |
-
network may be denied when the modification itself materially and
|
324 |
-
adversely affects the operation of the network or violates the rules and
|
325 |
-
protocols for communication across the network.
|
326 |
-
|
327 |
-
Corresponding Source conveyed, and Installation Information provided,
|
328 |
-
in accord with this section must be in a format that is publicly
|
329 |
-
documented (and with an implementation available to the public in
|
330 |
-
source code form), and must require no special password or key for
|
331 |
-
unpacking, reading or copying.
|
332 |
-
|
333 |
-
7. Additional Terms.
|
334 |
-
|
335 |
-
"Additional permissions" are terms that supplement the terms of this
|
336 |
-
License by making exceptions from one or more of its conditions.
|
337 |
-
Additional permissions that are applicable to the entire Program shall
|
338 |
-
be treated as though they were included in this License, to the extent
|
339 |
-
that they are valid under applicable law. If additional permissions
|
340 |
-
apply only to part of the Program, that part may be used separately
|
341 |
-
under those permissions, but the entire Program remains governed by
|
342 |
-
this License without regard to the additional permissions.
|
343 |
-
|
344 |
-
When you convey a copy of a covered work, you may at your option
|
345 |
-
remove any additional permissions from that copy, or from any part of
|
346 |
-
it. (Additional permissions may be written to require their own
|
347 |
-
removal in certain cases when you modify the work.) You may place
|
348 |
-
additional permissions on material, added by you to a covered work,
|
349 |
-
for which you have or can give appropriate copyright permission.
|
350 |
-
|
351 |
-
Notwithstanding any other provision of this License, for material you
|
352 |
-
add to a covered work, you may (if authorized by the copyright holders of
|
353 |
-
that material) supplement the terms of this License with terms:
|
354 |
-
|
355 |
-
a) Disclaiming warranty or limiting liability differently from the
|
356 |
-
terms of sections 15 and 16 of this License; or
|
357 |
-
|
358 |
-
b) Requiring preservation of specified reasonable legal notices or
|
359 |
-
author attributions in that material or in the Appropriate Legal
|
360 |
-
Notices displayed by works containing it; or
|
361 |
-
|
362 |
-
c) Prohibiting misrepresentation of the origin of that material, or
|
363 |
-
requiring that modified versions of such material be marked in
|
364 |
-
reasonable ways as different from the original version; or
|
365 |
-
|
366 |
-
d) Limiting the use for publicity purposes of names of licensors or
|
367 |
-
authors of the material; or
|
368 |
-
|
369 |
-
e) Declining to grant rights under trademark law for use of some
|
370 |
-
trade names, trademarks, or service marks; or
|
371 |
-
|
372 |
-
f) Requiring indemnification of licensors and authors of that
|
373 |
-
material by anyone who conveys the material (or modified versions of
|
374 |
-
it) with contractual assumptions of liability to the recipient, for
|
375 |
-
any liability that these contractual assumptions directly impose on
|
376 |
-
those licensors and authors.
|
377 |
-
|
378 |
-
All other non-permissive additional terms are considered "further
|
379 |
-
restrictions" within the meaning of section 10. If the Program as you
|
380 |
-
received it, or any part of it, contains a notice stating that it is
|
381 |
-
governed by this License along with a term that is a further
|
382 |
-
restriction, you may remove that term. If a license document contains
|
383 |
-
a further restriction but permits relicensing or conveying under this
|
384 |
-
License, you may add to a covered work material governed by the terms
|
385 |
-
of that license document, provided that the further restriction does
|
386 |
-
not survive such relicensing or conveying.
|
387 |
-
|
388 |
-
If you add terms to a covered work in accord with this section, you
|
389 |
-
must place, in the relevant source files, a statement of the
|
390 |
-
additional terms that apply to those files, or a notice indicating
|
391 |
-
where to find the applicable terms.
|
392 |
-
|
393 |
-
Additional terms, permissive or non-permissive, may be stated in the
|
394 |
-
form of a separately written license, or stated as exceptions;
|
395 |
-
the above requirements apply either way.
|
396 |
-
|
397 |
-
8. Termination.
|
398 |
-
|
399 |
-
You may not propagate or modify a covered work except as expressly
|
400 |
-
provided under this License. Any attempt otherwise to propagate or
|
401 |
-
modify it is void, and will automatically terminate your rights under
|
402 |
-
this License (including any patent licenses granted under the third
|
403 |
-
paragraph of section 11).
|
404 |
-
|
405 |
-
However, if you cease all violation of this License, then your
|
406 |
-
license from a particular copyright holder is reinstated (a)
|
407 |
-
provisionally, unless and until the copyright holder explicitly and
|
408 |
-
finally terminates your license, and (b) permanently, if the copyright
|
409 |
-
holder fails to notify you of the violation by some reasonable means
|
410 |
-
prior to 60 days after the cessation.
|
411 |
-
|
412 |
-
Moreover, your license from a particular copyright holder is
|
413 |
-
reinstated permanently if the copyright holder notifies you of the
|
414 |
-
violation by some reasonable means, this is the first time you have
|
415 |
-
received notice of violation of this License (for any work) from that
|
416 |
-
copyright holder, and you cure the violation prior to 30 days after
|
417 |
-
your receipt of the notice.
|
418 |
-
|
419 |
-
Termination of your rights under this section does not terminate the
|
420 |
-
licenses of parties who have received copies or rights from you under
|
421 |
-
this License. If your rights have been terminated and not permanently
|
422 |
-
reinstated, you do not qualify to receive new licenses for the same
|
423 |
-
material under section 10.
|
424 |
-
|
425 |
-
9. Acceptance Not Required for Having Copies.
|
426 |
-
|
427 |
-
You are not required to accept this License in order to receive or
|
428 |
-
run a copy of the Program. Ancillary propagation of a covered work
|
429 |
-
occurring solely as a consequence of using peer-to-peer transmission
|
430 |
-
to receive a copy likewise does not require acceptance. However,
|
431 |
-
nothing other than this License grants you permission to propagate or
|
432 |
-
modify any covered work. These actions infringe copyright if you do
|
433 |
-
not accept this License. Therefore, by modifying or propagating a
|
434 |
-
covered work, you indicate your acceptance of this License to do so.
|
435 |
-
|
436 |
-
10. Automatic Licensing of Downstream Recipients.
|
437 |
-
|
438 |
-
Each time you convey a covered work, the recipient automatically
|
439 |
-
receives a license from the original licensors, to run, modify and
|
440 |
-
propagate that work, subject to this License. You are not responsible
|
441 |
-
for enforcing compliance by third parties with this License.
|
442 |
-
|
443 |
-
An "entity transaction" is a transaction transferring control of an
|
444 |
-
organization, or substantially all assets of one, or subdividing an
|
445 |
-
organization, or merging organizations. If propagation of a covered
|
446 |
-
work results from an entity transaction, each party to that
|
447 |
-
transaction who receives a copy of the work also receives whatever
|
448 |
-
licenses to the work the party's predecessor in interest had or could
|
449 |
-
give under the previous paragraph, plus a right to possession of the
|
450 |
-
Corresponding Source of the work from the predecessor in interest, if
|
451 |
-
the predecessor has it or can get it with reasonable efforts.
|
452 |
-
|
453 |
-
You may not impose any further restrictions on the exercise of the
|
454 |
-
rights granted or affirmed under this License. For example, you may
|
455 |
-
not impose a license fee, royalty, or other charge for exercise of
|
456 |
-
rights granted under this License, and you may not initiate litigation
|
457 |
-
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
458 |
-
any patent claim is infringed by making, using, selling, offering for
|
459 |
-
sale, or importing the Program or any portion of it.
|
460 |
-
|
461 |
-
11. Patents.
|
462 |
-
|
463 |
-
A "contributor" is a copyright holder who authorizes use under this
|
464 |
-
License of the Program or a work on which the Program is based. The
|
465 |
-
work thus licensed is called the contributor's "contributor version".
|
466 |
-
|
467 |
-
A contributor's "essential patent claims" are all patent claims
|
468 |
-
owned or controlled by the contributor, whether already acquired or
|
469 |
-
hereafter acquired, that would be infringed by some manner, permitted
|
470 |
-
by this License, of making, using, or selling its contributor version,
|
471 |
-
but do not include claims that would be infringed only as a
|
472 |
-
consequence of further modification of the contributor version. For
|
473 |
-
purposes of this definition, "control" includes the right to grant
|
474 |
-
patent sublicenses in a manner consistent with the requirements of
|
475 |
-
this License.
|
476 |
-
|
477 |
-
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
478 |
-
patent license under the contributor's essential patent claims, to
|
479 |
-
make, use, sell, offer for sale, import and otherwise run, modify and
|
480 |
-
propagate the contents of its contributor version.
|
481 |
-
|
482 |
-
In the following three paragraphs, a "patent license" is any express
|
483 |
-
agreement or commitment, however denominated, not to enforce a patent
|
484 |
-
(such as an express permission to practice a patent or covenant not to
|
485 |
-
sue for patent infringement). To "grant" such a patent license to a
|
486 |
-
party means to make such an agreement or commitment not to enforce a
|
487 |
-
patent against the party.
|
488 |
-
|
489 |
-
If you convey a covered work, knowingly relying on a patent license,
|
490 |
-
and the Corresponding Source of the work is not available for anyone
|
491 |
-
to copy, free of charge and under the terms of this License, through a
|
492 |
-
publicly available network server or other readily accessible means,
|
493 |
-
then you must either (1) cause the Corresponding Source to be so
|
494 |
-
available, or (2) arrange to deprive yourself of the benefit of the
|
495 |
-
patent license for this particular work, or (3) arrange, in a manner
|
496 |
-
consistent with the requirements of this License, to extend the patent
|
497 |
-
license to downstream recipients. "Knowingly relying" means you have
|
498 |
-
actual knowledge that, but for the patent license, your conveying the
|
499 |
-
covered work in a country, or your recipient's use of the covered work
|
500 |
-
in a country, would infringe one or more identifiable patents in that
|
501 |
-
country that you have reason to believe are valid.
|
502 |
-
|
503 |
-
If, pursuant to or in connection with a single transaction or
|
504 |
-
arrangement, you convey, or propagate by procuring conveyance of, a
|
505 |
-
covered work, and grant a patent license to some of the parties
|
506 |
-
receiving the covered work authorizing them to use, propagate, modify
|
507 |
-
or convey a specific copy of the covered work, then the patent license
|
508 |
-
you grant is automatically extended to all recipients of the covered
|
509 |
-
work and works based on it.
|
510 |
-
|
511 |
-
A patent license is "discriminatory" if it does not include within
|
512 |
-
the scope of its coverage, prohibits the exercise of, or is
|
513 |
-
conditioned on the non-exercise of one or more of the rights that are
|
514 |
-
specifically granted under this License. You may not convey a covered
|
515 |
-
work if you are a party to an arrangement with a third party that is
|
516 |
-
in the business of distributing software, under which you make payment
|
517 |
-
to the third party based on the extent of your activity of conveying
|
518 |
-
the work, and under which the third party grants, to any of the
|
519 |
-
parties who would receive the covered work from you, a discriminatory
|
520 |
-
patent license (a) in connection with copies of the covered work
|
521 |
-
conveyed by you (or copies made from those copies), or (b) primarily
|
522 |
-
for and in connection with specific products or compilations that
|
523 |
-
contain the covered work, unless you entered into that arrangement,
|
524 |
-
or that patent license was granted, prior to 28 March 2007.
|
525 |
-
|
526 |
-
Nothing in this License shall be construed as excluding or limiting
|
527 |
-
any implied license or other defenses to infringement that may
|
528 |
-
otherwise be available to you under applicable patent law.
|
529 |
-
|
530 |
-
12. No Surrender of Others' Freedom.
|
531 |
-
|
532 |
-
If conditions are imposed on you (whether by court order, agreement or
|
533 |
-
otherwise) that contradict the conditions of this License, they do not
|
534 |
-
excuse you from the conditions of this License. If you cannot convey a
|
535 |
-
covered work so as to satisfy simultaneously your obligations under this
|
536 |
-
License and any other pertinent obligations, then as a consequence you may
|
537 |
-
not convey it at all. For example, if you agree to terms that obligate you
|
538 |
-
to collect a royalty for further conveying from those to whom you convey
|
539 |
-
the Program, the only way you could satisfy both those terms and this
|
540 |
-
License would be to refrain entirely from conveying the Program.
|
541 |
-
|
542 |
-
13. Remote Network Interaction; Use with the GNU General Public License.
|
543 |
-
|
544 |
-
Notwithstanding any other provision of this License, if you modify the
|
545 |
-
Program, your modified version must prominently offer all users
|
546 |
-
interacting with it remotely through a computer network (if your version
|
547 |
-
supports such interaction) an opportunity to receive the Corresponding
|
548 |
-
Source of your version by providing access to the Corresponding Source
|
549 |
-
from a network server at no charge, through some standard or customary
|
550 |
-
means of facilitating copying of software. This Corresponding Source
|
551 |
-
shall include the Corresponding Source for any work covered by version 3
|
552 |
-
of the GNU General Public License that is incorporated pursuant to the
|
553 |
-
following paragraph.
|
554 |
-
|
555 |
-
Notwithstanding any other provision of this License, you have
|
556 |
-
permission to link or combine any covered work with a work licensed
|
557 |
-
under version 3 of the GNU General Public License into a single
|
558 |
-
combined work, and to convey the resulting work. The terms of this
|
559 |
-
License will continue to apply to the part which is the covered work,
|
560 |
-
but the work with which it is combined will remain governed by version
|
561 |
-
3 of the GNU General Public License.
|
562 |
-
|
563 |
-
14. Revised Versions of this License.
|
564 |
-
|
565 |
-
The Free Software Foundation may publish revised and/or new versions of
|
566 |
-
the GNU Affero General Public License from time to time. Such new versions
|
567 |
-
will be similar in spirit to the present version, but may differ in detail to
|
568 |
-
address new problems or concerns.
|
569 |
-
|
570 |
-
Each version is given a distinguishing version number. If the
|
571 |
-
Program specifies that a certain numbered version of the GNU Affero General
|
572 |
-
Public License "or any later version" applies to it, you have the
|
573 |
-
option of following the terms and conditions either of that numbered
|
574 |
-
version or of any later version published by the Free Software
|
575 |
-
Foundation. If the Program does not specify a version number of the
|
576 |
-
GNU Affero General Public License, you may choose any version ever published
|
577 |
-
by the Free Software Foundation.
|
578 |
-
|
579 |
-
If the Program specifies that a proxy can decide which future
|
580 |
-
versions of the GNU Affero General Public License can be used, that proxy's
|
581 |
-
public statement of acceptance of a version permanently authorizes you
|
582 |
-
to choose that version for the Program.
|
583 |
-
|
584 |
-
Later license versions may give you additional or different
|
585 |
-
permissions. However, no additional obligations are imposed on any
|
586 |
-
author or copyright holder as a result of your choosing to follow a
|
587 |
-
later version.
|
588 |
-
|
589 |
-
15. Disclaimer of Warranty.
|
590 |
-
|
591 |
-
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
-
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
-
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
-
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
-
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
-
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
-
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
-
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
-
|
600 |
-
16. Limitation of Liability.
|
601 |
-
|
602 |
-
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
-
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
-
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
-
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
-
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
-
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
-
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
-
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
-
SUCH DAMAGES.
|
611 |
-
|
612 |
-
17. Interpretation of Sections 15 and 16.
|
613 |
-
|
614 |
-
If the disclaimer of warranty and limitation of liability provided
|
615 |
-
above cannot be given local legal effect according to their terms,
|
616 |
-
reviewing courts shall apply local law that most closely approximates
|
617 |
-
an absolute waiver of all civil liability in connection with the
|
618 |
-
Program, unless a warranty or assumption of liability accompanies a
|
619 |
-
copy of the Program in return for a fee.
|
620 |
-
|
621 |
-
END OF TERMS AND CONDITIONS
|
622 |
-
|
623 |
-
How to Apply These Terms to Your New Programs
|
624 |
-
|
625 |
-
If you develop a new program, and you want it to be of the greatest
|
626 |
-
possible use to the public, the best way to achieve this is to make it
|
627 |
-
free software which everyone can redistribute and change under these terms.
|
628 |
-
|
629 |
-
To do so, attach the following notices to the program. It is safest
|
630 |
-
to attach them to the start of each source file to most effectively
|
631 |
-
state the exclusion of warranty; and each file should have at least
|
632 |
-
the "copyright" line and a pointer to where the full notice is found.
|
633 |
-
|
634 |
-
<one line to give the program's name and a brief idea of what it does.>
|
635 |
-
Copyright (C) <year> <name of author>
|
636 |
-
|
637 |
-
This program is free software: you can redistribute it and/or modify
|
638 |
-
it under the terms of the GNU Affero General Public License as published by
|
639 |
-
the Free Software Foundation, either version 3 of the License, or
|
640 |
-
(at your option) any later version.
|
641 |
-
|
642 |
-
This program is distributed in the hope that it will be useful,
|
643 |
-
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
-
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
-
GNU Affero General Public License for more details.
|
646 |
-
|
647 |
-
You should have received a copy of the GNU Affero General Public License
|
648 |
-
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
-
|
650 |
-
Also add information on how to contact you by electronic and paper mail.
|
651 |
-
|
652 |
-
If your software can interact with users remotely through a computer
|
653 |
-
network, you should also make sure that it provides a way for users to
|
654 |
-
get its source. For example, if your program is a web application, its
|
655 |
-
interface could display a "Source" link that leads users to an archive
|
656 |
-
of the code. There are many ways you could offer source, and different
|
657 |
-
solutions will be better for different programs; see section 13 for the
|
658 |
-
specific requirements.
|
659 |
-
|
660 |
-
You should also get your employer (if you work as a programmer) or school,
|
661 |
-
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
662 |
-
For more information on this, and how to apply and follow the GNU AGPL, see
|
663 |
-
<https://www.gnu.org/licenses/>.
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (c) 2023 AUTOMATIC1111
|
5 |
+
|
6 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
7 |
+
Everyone is permitted to copy and distribute verbatim copies
|
8 |
+
of this license document, but changing it is not allowed.
|
9 |
+
|
10 |
+
Preamble
|
11 |
+
|
12 |
+
The GNU Affero General Public License is a free, copyleft license for
|
13 |
+
software and other kinds of works, specifically designed to ensure
|
14 |
+
cooperation with the community in the case of network server software.
|
15 |
+
|
16 |
+
The licenses for most software and other practical works are designed
|
17 |
+
to take away your freedom to share and change the works. By contrast,
|
18 |
+
our General Public Licenses are intended to guarantee your freedom to
|
19 |
+
share and change all versions of a program--to make sure it remains free
|
20 |
+
software for all its users.
|
21 |
+
|
22 |
+
When we speak of free software, we are referring to freedom, not
|
23 |
+
price. Our General Public Licenses are designed to make sure that you
|
24 |
+
have the freedom to distribute copies of free software (and charge for
|
25 |
+
them if you wish), that you receive source code or can get it if you
|
26 |
+
want it, that you can change the software or use pieces of it in new
|
27 |
+
free programs, and that you know you can do these things.
|
28 |
+
|
29 |
+
Developers that use our General Public Licenses protect your rights
|
30 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
31 |
+
you this License which gives you legal permission to copy, distribute
|
32 |
+
and/or modify the software.
|
33 |
+
|
34 |
+
A secondary benefit of defending all users' freedom is that
|
35 |
+
improvements made in alternate versions of the program, if they
|
36 |
+
receive widespread use, become available for other developers to
|
37 |
+
incorporate. Many developers of free software are heartened and
|
38 |
+
encouraged by the resulting cooperation. However, in the case of
|
39 |
+
software used on network servers, this result may fail to come about.
|
40 |
+
The GNU General Public License permits making a modified version and
|
41 |
+
letting the public access it on a server without ever releasing its
|
42 |
+
source code to the public.
|
43 |
+
|
44 |
+
The GNU Affero General Public License is designed specifically to
|
45 |
+
ensure that, in such cases, the modified source code becomes available
|
46 |
+
to the community. It requires the operator of a network server to
|
47 |
+
provide the source code of the modified version running there to the
|
48 |
+
users of that server. Therefore, public use of a modified version, on
|
49 |
+
a publicly accessible server, gives the public access to the source
|
50 |
+
code of the modified version.
|
51 |
+
|
52 |
+
An older license, called the Affero General Public License and
|
53 |
+
published by Affero, was designed to accomplish similar goals. This is
|
54 |
+
a different license, not a version of the Affero GPL, but Affero has
|
55 |
+
released a new version of the Affero GPL which permits relicensing under
|
56 |
+
this license.
|
57 |
+
|
58 |
+
The precise terms and conditions for copying, distribution and
|
59 |
+
modification follow.
|
60 |
+
|
61 |
+
TERMS AND CONDITIONS
|
62 |
+
|
63 |
+
0. Definitions.
|
64 |
+
|
65 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
66 |
+
|
67 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
68 |
+
works, such as semiconductor masks.
|
69 |
+
|
70 |
+
"The Program" refers to any copyrightable work licensed under this
|
71 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
72 |
+
"recipients" may be individuals or organizations.
|
73 |
+
|
74 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
75 |
+
in a fashion requiring copyright permission, other than the making of an
|
76 |
+
exact copy. The resulting work is called a "modified version" of the
|
77 |
+
earlier work or a work "based on" the earlier work.
|
78 |
+
|
79 |
+
A "covered work" means either the unmodified Program or a work based
|
80 |
+
on the Program.
|
81 |
+
|
82 |
+
To "propagate" a work means to do anything with it that, without
|
83 |
+
permission, would make you directly or secondarily liable for
|
84 |
+
infringement under applicable copyright law, except executing it on a
|
85 |
+
computer or modifying a private copy. Propagation includes copying,
|
86 |
+
distribution (with or without modification), making available to the
|
87 |
+
public, and in some countries other activities as well.
|
88 |
+
|
89 |
+
To "convey" a work means any kind of propagation that enables other
|
90 |
+
parties to make or receive copies. Mere interaction with a user through
|
91 |
+
a computer network, with no transfer of a copy, is not conveying.
|
92 |
+
|
93 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
94 |
+
to the extent that it includes a convenient and prominently visible
|
95 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
96 |
+
tells the user that there is no warranty for the work (except to the
|
97 |
+
extent that warranties are provided), that licensees may convey the
|
98 |
+
work under this License, and how to view a copy of this License. If
|
99 |
+
the interface presents a list of user commands or options, such as a
|
100 |
+
menu, a prominent item in the list meets this criterion.
|
101 |
+
|
102 |
+
1. Source Code.
|
103 |
+
|
104 |
+
The "source code" for a work means the preferred form of the work
|
105 |
+
for making modifications to it. "Object code" means any non-source
|
106 |
+
form of a work.
|
107 |
+
|
108 |
+
A "Standard Interface" means an interface that either is an official
|
109 |
+
standard defined by a recognized standards body, or, in the case of
|
110 |
+
interfaces specified for a particular programming language, one that
|
111 |
+
is widely used among developers working in that language.
|
112 |
+
|
113 |
+
The "System Libraries" of an executable work include anything, other
|
114 |
+
than the work as a whole, that (a) is included in the normal form of
|
115 |
+
packaging a Major Component, but which is not part of that Major
|
116 |
+
Component, and (b) serves only to enable use of the work with that
|
117 |
+
Major Component, or to implement a Standard Interface for which an
|
118 |
+
implementation is available to the public in source code form. A
|
119 |
+
"Major Component", in this context, means a major essential component
|
120 |
+
(kernel, window system, and so on) of the specific operating system
|
121 |
+
(if any) on which the executable work runs, or a compiler used to
|
122 |
+
produce the work, or an object code interpreter used to run it.
|
123 |
+
|
124 |
+
The "Corresponding Source" for a work in object code form means all
|
125 |
+
the source code needed to generate, install, and (for an executable
|
126 |
+
work) run the object code and to modify the work, including scripts to
|
127 |
+
control those activities. However, it does not include the work's
|
128 |
+
System Libraries, or general-purpose tools or generally available free
|
129 |
+
programs which are used unmodified in performing those activities but
|
130 |
+
which are not part of the work. For example, Corresponding Source
|
131 |
+
includes interface definition files associated with source files for
|
132 |
+
the work, and the source code for shared libraries and dynamically
|
133 |
+
linked subprograms that the work is specifically designed to require,
|
134 |
+
such as by intimate data communication or control flow between those
|
135 |
+
subprograms and other parts of the work.
|
136 |
+
|
137 |
+
The Corresponding Source need not include anything that users
|
138 |
+
can regenerate automatically from other parts of the Corresponding
|
139 |
+
Source.
|
140 |
+
|
141 |
+
The Corresponding Source for a work in source code form is that
|
142 |
+
same work.
|
143 |
+
|
144 |
+
2. Basic Permissions.
|
145 |
+
|
146 |
+
All rights granted under this License are granted for the term of
|
147 |
+
copyright on the Program, and are irrevocable provided the stated
|
148 |
+
conditions are met. This License explicitly affirms your unlimited
|
149 |
+
permission to run the unmodified Program. The output from running a
|
150 |
+
covered work is covered by this License only if the output, given its
|
151 |
+
content, constitutes a covered work. This License acknowledges your
|
152 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
153 |
+
|
154 |
+
You may make, run and propagate covered works that you do not
|
155 |
+
convey, without conditions so long as your license otherwise remains
|
156 |
+
in force. You may convey covered works to others for the sole purpose
|
157 |
+
of having them make modifications exclusively for you, or provide you
|
158 |
+
with facilities for running those works, provided that you comply with
|
159 |
+
the terms of this License in conveying all material for which you do
|
160 |
+
not control copyright. Those thus making or running the covered works
|
161 |
+
for you must do so exclusively on your behalf, under your direction
|
162 |
+
and control, on terms that prohibit them from making any copies of
|
163 |
+
your copyrighted material outside their relationship with you.
|
164 |
+
|
165 |
+
Conveying under any other circumstances is permitted solely under
|
166 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
167 |
+
makes it unnecessary.
|
168 |
+
|
169 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
170 |
+
|
171 |
+
No covered work shall be deemed part of an effective technological
|
172 |
+
measure under any applicable law fulfilling obligations under article
|
173 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
174 |
+
similar laws prohibiting or restricting circumvention of such
|
175 |
+
measures.
|
176 |
+
|
177 |
+
When you convey a covered work, you waive any legal power to forbid
|
178 |
+
circumvention of technological measures to the extent such circumvention
|
179 |
+
is effected by exercising rights under this License with respect to
|
180 |
+
the covered work, and you disclaim any intention to limit operation or
|
181 |
+
modification of the work as a means of enforcing, against the work's
|
182 |
+
users, your or third parties' legal rights to forbid circumvention of
|
183 |
+
technological measures.
|
184 |
+
|
185 |
+
4. Conveying Verbatim Copies.
|
186 |
+
|
187 |
+
You may convey verbatim copies of the Program's source code as you
|
188 |
+
receive it, in any medium, provided that you conspicuously and
|
189 |
+
appropriately publish on each copy an appropriate copyright notice;
|
190 |
+
keep intact all notices stating that this License and any
|
191 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
192 |
+
keep intact all notices of the absence of any warranty; and give all
|
193 |
+
recipients a copy of this License along with the Program.
|
194 |
+
|
195 |
+
You may charge any price or no price for each copy that you convey,
|
196 |
+
and you may offer support or warranty protection for a fee.
|
197 |
+
|
198 |
+
5. Conveying Modified Source Versions.
|
199 |
+
|
200 |
+
You may convey a work based on the Program, or the modifications to
|
201 |
+
produce it from the Program, in the form of source code under the
|
202 |
+
terms of section 4, provided that you also meet all of these conditions:
|
203 |
+
|
204 |
+
a) The work must carry prominent notices stating that you modified
|
205 |
+
it, and giving a relevant date.
|
206 |
+
|
207 |
+
b) The work must carry prominent notices stating that it is
|
208 |
+
released under this License and any conditions added under section
|
209 |
+
7. This requirement modifies the requirement in section 4 to
|
210 |
+
"keep intact all notices".
|
211 |
+
|
212 |
+
c) You must license the entire work, as a whole, under this
|
213 |
+
License to anyone who comes into possession of a copy. This
|
214 |
+
License will therefore apply, along with any applicable section 7
|
215 |
+
additional terms, to the whole of the work, and all its parts,
|
216 |
+
regardless of how they are packaged. This License gives no
|
217 |
+
permission to license the work in any other way, but it does not
|
218 |
+
invalidate such permission if you have separately received it.
|
219 |
+
|
220 |
+
d) If the work has interactive user interfaces, each must display
|
221 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
222 |
+
interfaces that do not display Appropriate Legal Notices, your
|
223 |
+
work need not make them do so.
|
224 |
+
|
225 |
+
A compilation of a covered work with other separate and independent
|
226 |
+
works, which are not by their nature extensions of the covered work,
|
227 |
+
and which are not combined with it such as to form a larger program,
|
228 |
+
in or on a volume of a storage or distribution medium, is called an
|
229 |
+
"aggregate" if the compilation and its resulting copyright are not
|
230 |
+
used to limit the access or legal rights of the compilation's users
|
231 |
+
beyond what the individual works permit. Inclusion of a covered work
|
232 |
+
in an aggregate does not cause this License to apply to the other
|
233 |
+
parts of the aggregate.
|
234 |
+
|
235 |
+
6. Conveying Non-Source Forms.
|
236 |
+
|
237 |
+
You may convey a covered work in object code form under the terms
|
238 |
+
of sections 4 and 5, provided that you also convey the
|
239 |
+
machine-readable Corresponding Source under the terms of this License,
|
240 |
+
in one of these ways:
|
241 |
+
|
242 |
+
a) Convey the object code in, or embodied in, a physical product
|
243 |
+
(including a physical distribution medium), accompanied by the
|
244 |
+
Corresponding Source fixed on a durable physical medium
|
245 |
+
customarily used for software interchange.
|
246 |
+
|
247 |
+
b) Convey the object code in, or embodied in, a physical product
|
248 |
+
(including a physical distribution medium), accompanied by a
|
249 |
+
written offer, valid for at least three years and valid for as
|
250 |
+
long as you offer spare parts or customer support for that product
|
251 |
+
model, to give anyone who possesses the object code either (1) a
|
252 |
+
copy of the Corresponding Source for all the software in the
|
253 |
+
product that is covered by this License, on a durable physical
|
254 |
+
medium customarily used for software interchange, for a price no
|
255 |
+
more than your reasonable cost of physically performing this
|
256 |
+
conveying of source, or (2) access to copy the
|
257 |
+
Corresponding Source from a network server at no charge.
|
258 |
+
|
259 |
+
c) Convey individual copies of the object code with a copy of the
|
260 |
+
written offer to provide the Corresponding Source. This
|
261 |
+
alternative is allowed only occasionally and noncommercially, and
|
262 |
+
only if you received the object code with such an offer, in accord
|
263 |
+
with subsection 6b.
|
264 |
+
|
265 |
+
d) Convey the object code by offering access from a designated
|
266 |
+
place (gratis or for a charge), and offer equivalent access to the
|
267 |
+
Corresponding Source in the same way through the same place at no
|
268 |
+
further charge. You need not require recipients to copy the
|
269 |
+
Corresponding Source along with the object code. If the place to
|
270 |
+
copy the object code is a network server, the Corresponding Source
|
271 |
+
may be on a different server (operated by you or a third party)
|
272 |
+
that supports equivalent copying facilities, provided you maintain
|
273 |
+
clear directions next to the object code saying where to find the
|
274 |
+
Corresponding Source. Regardless of what server hosts the
|
275 |
+
Corresponding Source, you remain obligated to ensure that it is
|
276 |
+
available for as long as needed to satisfy these requirements.
|
277 |
+
|
278 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
279 |
+
you inform other peers where the object code and Corresponding
|
280 |
+
Source of the work are being offered to the general public at no
|
281 |
+
charge under subsection 6d.
|
282 |
+
|
283 |
+
A separable portion of the object code, whose source code is excluded
|
284 |
+
from the Corresponding Source as a System Library, need not be
|
285 |
+
included in conveying the object code work.
|
286 |
+
|
287 |
+
A "User Product" is either (1) a "consumer product", which means any
|
288 |
+
tangible personal property which is normally used for personal, family,
|
289 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
290 |
+
into a dwelling. In determining whether a product is a consumer product,
|
291 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
292 |
+
product received by a particular user, "normally used" refers to a
|
293 |
+
typical or common use of that class of product, regardless of the status
|
294 |
+
of the particular user or of the way in which the particular user
|
295 |
+
actually uses, or expects or is expected to use, the product. A product
|
296 |
+
is a consumer product regardless of whether the product has substantial
|
297 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
298 |
+
the only significant mode of use of the product.
|
299 |
+
|
300 |
+
"Installation Information" for a User Product means any methods,
|
301 |
+
procedures, authorization keys, or other information required to install
|
302 |
+
and execute modified versions of a covered work in that User Product from
|
303 |
+
a modified version of its Corresponding Source. The information must
|
304 |
+
suffice to ensure that the continued functioning of the modified object
|
305 |
+
code is in no case prevented or interfered with solely because
|
306 |
+
modification has been made.
|
307 |
+
|
308 |
+
If you convey an object code work under this section in, or with, or
|
309 |
+
specifically for use in, a User Product, and the conveying occurs as
|
310 |
+
part of a transaction in which the right of possession and use of the
|
311 |
+
User Product is transferred to the recipient in perpetuity or for a
|
312 |
+
fixed term (regardless of how the transaction is characterized), the
|
313 |
+
Corresponding Source conveyed under this section must be accompanied
|
314 |
+
by the Installation Information. But this requirement does not apply
|
315 |
+
if neither you nor any third party retains the ability to install
|
316 |
+
modified object code on the User Product (for example, the work has
|
317 |
+
been installed in ROM).
|
318 |
+
|
319 |
+
The requirement to provide Installation Information does not include a
|
320 |
+
requirement to continue to provide support service, warranty, or updates
|
321 |
+
for a work that has been modified or installed by the recipient, or for
|
322 |
+
the User Product in which it has been modified or installed. Access to a
|
323 |
+
network may be denied when the modification itself materially and
|
324 |
+
adversely affects the operation of the network or violates the rules and
|
325 |
+
protocols for communication across the network.
|
326 |
+
|
327 |
+
Corresponding Source conveyed, and Installation Information provided,
|
328 |
+
in accord with this section must be in a format that is publicly
|
329 |
+
documented (and with an implementation available to the public in
|
330 |
+
source code form), and must require no special password or key for
|
331 |
+
unpacking, reading or copying.
|
332 |
+
|
333 |
+
7. Additional Terms.
|
334 |
+
|
335 |
+
"Additional permissions" are terms that supplement the terms of this
|
336 |
+
License by making exceptions from one or more of its conditions.
|
337 |
+
Additional permissions that are applicable to the entire Program shall
|
338 |
+
be treated as though they were included in this License, to the extent
|
339 |
+
that they are valid under applicable law. If additional permissions
|
340 |
+
apply only to part of the Program, that part may be used separately
|
341 |
+
under those permissions, but the entire Program remains governed by
|
342 |
+
this License without regard to the additional permissions.
|
343 |
+
|
344 |
+
When you convey a copy of a covered work, you may at your option
|
345 |
+
remove any additional permissions from that copy, or from any part of
|
346 |
+
it. (Additional permissions may be written to require their own
|
347 |
+
removal in certain cases when you modify the work.) You may place
|
348 |
+
additional permissions on material, added by you to a covered work,
|
349 |
+
for which you have or can give appropriate copyright permission.
|
350 |
+
|
351 |
+
Notwithstanding any other provision of this License, for material you
|
352 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
353 |
+
that material) supplement the terms of this License with terms:
|
354 |
+
|
355 |
+
a) Disclaiming warranty or limiting liability differently from the
|
356 |
+
terms of sections 15 and 16 of this License; or
|
357 |
+
|
358 |
+
b) Requiring preservation of specified reasonable legal notices or
|
359 |
+
author attributions in that material or in the Appropriate Legal
|
360 |
+
Notices displayed by works containing it; or
|
361 |
+
|
362 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
363 |
+
requiring that modified versions of such material be marked in
|
364 |
+
reasonable ways as different from the original version; or
|
365 |
+
|
366 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
367 |
+
authors of the material; or
|
368 |
+
|
369 |
+
e) Declining to grant rights under trademark law for use of some
|
370 |
+
trade names, trademarks, or service marks; or
|
371 |
+
|
372 |
+
f) Requiring indemnification of licensors and authors of that
|
373 |
+
material by anyone who conveys the material (or modified versions of
|
374 |
+
it) with contractual assumptions of liability to the recipient, for
|
375 |
+
any liability that these contractual assumptions directly impose on
|
376 |
+
those licensors and authors.
|
377 |
+
|
378 |
+
All other non-permissive additional terms are considered "further
|
379 |
+
restrictions" within the meaning of section 10. If the Program as you
|
380 |
+
received it, or any part of it, contains a notice stating that it is
|
381 |
+
governed by this License along with a term that is a further
|
382 |
+
restriction, you may remove that term. If a license document contains
|
383 |
+
a further restriction but permits relicensing or conveying under this
|
384 |
+
License, you may add to a covered work material governed by the terms
|
385 |
+
of that license document, provided that the further restriction does
|
386 |
+
not survive such relicensing or conveying.
|
387 |
+
|
388 |
+
If you add terms to a covered work in accord with this section, you
|
389 |
+
must place, in the relevant source files, a statement of the
|
390 |
+
additional terms that apply to those files, or a notice indicating
|
391 |
+
where to find the applicable terms.
|
392 |
+
|
393 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
394 |
+
form of a separately written license, or stated as exceptions;
|
395 |
+
the above requirements apply either way.
|
396 |
+
|
397 |
+
8. Termination.
|
398 |
+
|
399 |
+
You may not propagate or modify a covered work except as expressly
|
400 |
+
provided under this License. Any attempt otherwise to propagate or
|
401 |
+
modify it is void, and will automatically terminate your rights under
|
402 |
+
this License (including any patent licenses granted under the third
|
403 |
+
paragraph of section 11).
|
404 |
+
|
405 |
+
However, if you cease all violation of this License, then your
|
406 |
+
license from a particular copyright holder is reinstated (a)
|
407 |
+
provisionally, unless and until the copyright holder explicitly and
|
408 |
+
finally terminates your license, and (b) permanently, if the copyright
|
409 |
+
holder fails to notify you of the violation by some reasonable means
|
410 |
+
prior to 60 days after the cessation.
|
411 |
+
|
412 |
+
Moreover, your license from a particular copyright holder is
|
413 |
+
reinstated permanently if the copyright holder notifies you of the
|
414 |
+
violation by some reasonable means, this is the first time you have
|
415 |
+
received notice of violation of this License (for any work) from that
|
416 |
+
copyright holder, and you cure the violation prior to 30 days after
|
417 |
+
your receipt of the notice.
|
418 |
+
|
419 |
+
Termination of your rights under this section does not terminate the
|
420 |
+
licenses of parties who have received copies or rights from you under
|
421 |
+
this License. If your rights have been terminated and not permanently
|
422 |
+
reinstated, you do not qualify to receive new licenses for the same
|
423 |
+
material under section 10.
|
424 |
+
|
425 |
+
9. Acceptance Not Required for Having Copies.
|
426 |
+
|
427 |
+
You are not required to accept this License in order to receive or
|
428 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
429 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
430 |
+
to receive a copy likewise does not require acceptance. However,
|
431 |
+
nothing other than this License grants you permission to propagate or
|
432 |
+
modify any covered work. These actions infringe copyright if you do
|
433 |
+
not accept this License. Therefore, by modifying or propagating a
|
434 |
+
covered work, you indicate your acceptance of this License to do so.
|
435 |
+
|
436 |
+
10. Automatic Licensing of Downstream Recipients.
|
437 |
+
|
438 |
+
Each time you convey a covered work, the recipient automatically
|
439 |
+
receives a license from the original licensors, to run, modify and
|
440 |
+
propagate that work, subject to this License. You are not responsible
|
441 |
+
for enforcing compliance by third parties with this License.
|
442 |
+
|
443 |
+
An "entity transaction" is a transaction transferring control of an
|
444 |
+
organization, or substantially all assets of one, or subdividing an
|
445 |
+
organization, or merging organizations. If propagation of a covered
|
446 |
+
work results from an entity transaction, each party to that
|
447 |
+
transaction who receives a copy of the work also receives whatever
|
448 |
+
licenses to the work the party's predecessor in interest had or could
|
449 |
+
give under the previous paragraph, plus a right to possession of the
|
450 |
+
Corresponding Source of the work from the predecessor in interest, if
|
451 |
+
the predecessor has it or can get it with reasonable efforts.
|
452 |
+
|
453 |
+
You may not impose any further restrictions on the exercise of the
|
454 |
+
rights granted or affirmed under this License. For example, you may
|
455 |
+
not impose a license fee, royalty, or other charge for exercise of
|
456 |
+
rights granted under this License, and you may not initiate litigation
|
457 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
458 |
+
any patent claim is infringed by making, using, selling, offering for
|
459 |
+
sale, or importing the Program or any portion of it.
|
460 |
+
|
461 |
+
11. Patents.
|
462 |
+
|
463 |
+
A "contributor" is a copyright holder who authorizes use under this
|
464 |
+
License of the Program or a work on which the Program is based. The
|
465 |
+
work thus licensed is called the contributor's "contributor version".
|
466 |
+
|
467 |
+
A contributor's "essential patent claims" are all patent claims
|
468 |
+
owned or controlled by the contributor, whether already acquired or
|
469 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
470 |
+
by this License, of making, using, or selling its contributor version,
|
471 |
+
but do not include claims that would be infringed only as a
|
472 |
+
consequence of further modification of the contributor version. For
|
473 |
+
purposes of this definition, "control" includes the right to grant
|
474 |
+
patent sublicenses in a manner consistent with the requirements of
|
475 |
+
this License.
|
476 |
+
|
477 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
478 |
+
patent license under the contributor's essential patent claims, to
|
479 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
480 |
+
propagate the contents of its contributor version.
|
481 |
+
|
482 |
+
In the following three paragraphs, a "patent license" is any express
|
483 |
+
agreement or commitment, however denominated, not to enforce a patent
|
484 |
+
(such as an express permission to practice a patent or covenant not to
|
485 |
+
sue for patent infringement). To "grant" such a patent license to a
|
486 |
+
party means to make such an agreement or commitment not to enforce a
|
487 |
+
patent against the party.
|
488 |
+
|
489 |
+
If you convey a covered work, knowingly relying on a patent license,
|
490 |
+
and the Corresponding Source of the work is not available for anyone
|
491 |
+
to copy, free of charge and under the terms of this License, through a
|
492 |
+
publicly available network server or other readily accessible means,
|
493 |
+
then you must either (1) cause the Corresponding Source to be so
|
494 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
495 |
+
patent license for this particular work, or (3) arrange, in a manner
|
496 |
+
consistent with the requirements of this License, to extend the patent
|
497 |
+
license to downstream recipients. "Knowingly relying" means you have
|
498 |
+
actual knowledge that, but for the patent license, your conveying the
|
499 |
+
covered work in a country, or your recipient's use of the covered work
|
500 |
+
in a country, would infringe one or more identifiable patents in that
|
501 |
+
country that you have reason to believe are valid.
|
502 |
+
|
503 |
+
If, pursuant to or in connection with a single transaction or
|
504 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
505 |
+
covered work, and grant a patent license to some of the parties
|
506 |
+
receiving the covered work authorizing them to use, propagate, modify
|
507 |
+
or convey a specific copy of the covered work, then the patent license
|
508 |
+
you grant is automatically extended to all recipients of the covered
|
509 |
+
work and works based on it.
|
510 |
+
|
511 |
+
A patent license is "discriminatory" if it does not include within
|
512 |
+
the scope of its coverage, prohibits the exercise of, or is
|
513 |
+
conditioned on the non-exercise of one or more of the rights that are
|
514 |
+
specifically granted under this License. You may not convey a covered
|
515 |
+
work if you are a party to an arrangement with a third party that is
|
516 |
+
in the business of distributing software, under which you make payment
|
517 |
+
to the third party based on the extent of your activity of conveying
|
518 |
+
the work, and under which the third party grants, to any of the
|
519 |
+
parties who would receive the covered work from you, a discriminatory
|
520 |
+
patent license (a) in connection with copies of the covered work
|
521 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
522 |
+
for and in connection with specific products or compilations that
|
523 |
+
contain the covered work, unless you entered into that arrangement,
|
524 |
+
or that patent license was granted, prior to 28 March 2007.
|
525 |
+
|
526 |
+
Nothing in this License shall be construed as excluding or limiting
|
527 |
+
any implied license or other defenses to infringement that may
|
528 |
+
otherwise be available to you under applicable patent law.
|
529 |
+
|
530 |
+
12. No Surrender of Others' Freedom.
|
531 |
+
|
532 |
+
If conditions are imposed on you (whether by court order, agreement or
|
533 |
+
otherwise) that contradict the conditions of this License, they do not
|
534 |
+
excuse you from the conditions of this License. If you cannot convey a
|
535 |
+
covered work so as to satisfy simultaneously your obligations under this
|
536 |
+
License and any other pertinent obligations, then as a consequence you may
|
537 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
538 |
+
to collect a royalty for further conveying from those to whom you convey
|
539 |
+
the Program, the only way you could satisfy both those terms and this
|
540 |
+
License would be to refrain entirely from conveying the Program.
|
541 |
+
|
542 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
543 |
+
|
544 |
+
Notwithstanding any other provision of this License, if you modify the
|
545 |
+
Program, your modified version must prominently offer all users
|
546 |
+
interacting with it remotely through a computer network (if your version
|
547 |
+
supports such interaction) an opportunity to receive the Corresponding
|
548 |
+
Source of your version by providing access to the Corresponding Source
|
549 |
+
from a network server at no charge, through some standard or customary
|
550 |
+
means of facilitating copying of software. This Corresponding Source
|
551 |
+
shall include the Corresponding Source for any work covered by version 3
|
552 |
+
of the GNU General Public License that is incorporated pursuant to the
|
553 |
+
following paragraph.
|
554 |
+
|
555 |
+
Notwithstanding any other provision of this License, you have
|
556 |
+
permission to link or combine any covered work with a work licensed
|
557 |
+
under version 3 of the GNU General Public License into a single
|
558 |
+
combined work, and to convey the resulting work. The terms of this
|
559 |
+
License will continue to apply to the part which is the covered work,
|
560 |
+
but the work with which it is combined will remain governed by version
|
561 |
+
3 of the GNU General Public License.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU Affero General Public License from time to time. Such new versions
|
567 |
+
will be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU Affero General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
635 |
+
Copyright (C) <year> <name of author>
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU Affero General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU Affero General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU Affero General Public License
|
648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If your software can interact with users remotely through a computer
|
653 |
+
network, you should also make sure that it provides a way for users to
|
654 |
+
get its source. For example, if your program is a web application, its
|
655 |
+
interface could display a "Source" link that leads users to an archive
|
656 |
+
of the code. There are many ways you could offer source, and different
|
657 |
+
solutions will be better for different programs; see section 13 for the
|
658 |
+
specific requirements.
|
659 |
+
|
660 |
+
You should also get your employer (if you work as a programmer) or school,
|
661 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
662 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
663 |
+
<https://www.gnu.org/licenses/>.
|
sd/stable-diffusion-webui/README.md
ADDED
@@ -0,0 +1,162 @@
# Stable Diffusion web UI
A browser interface based on the Gradio library for Stable Diffusion.

![](screenshot.png)

## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
- Original txt2img and img2img modes
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
- Color Sketch
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
  - a man in a ((tuxedo)) - will pay more attention to tuxedo
  - a man in a (tuxedo:1.21) - alternative syntax
  - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
  - have as many embeddings as you want and use any names you like for them
  - use multiple embeddings with different numbers of vectors per token
  - works with half precision floating point numbers
  - train embeddings on 8GB (also reports of 6GB working)
- Extras tab with:
  - GFPGAN, neural network that fixes faces
  - CodeFormer, face restoration tool as an alternative to GFPGAN
  - RealESRGAN, neural network upscaler
  - ESRGAN, neural network upscaler with a lot of third party models
  - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
  - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
  - Adjust sampler eta values (noise multiplier)
  - More advanced noise setting options
- Interrupt processing at any time
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Live prompt token length validation
- Generation parameters
  - parameters you used to generate images are saved with that image
  - in PNG chunks for PNG, in EXIF for JPEG
  - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
  - can be disabled in settings
  - drag and drop an image/text-parameters to promptbox
  - (a sketch of reading these parameters back out of a PNG appears at the end of this README)
- Read Generation Parameters Button, loads parameters in promptbox to UI
- Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/min/max/step values for UI elements via text config
- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
  - Can use a separate neural network to produce previews with almost no VRAM or compute requirements
- Negative prompt, an extra text field that allows you to list what you don't want to see in the generated image
- Styles, a way to save parts of prompts and easily apply them via dropdown later
- Variations, a way to generate the same image but with tiny differences
- Seed resizing, a way to generate the same image but at a slightly different resolution
- CLIP interrogator, a button that tries to guess a prompt from an image
- Prompt Editing, a way to change the prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing, process a group of files using img2img
- Img2img Alternative, reverse Euler method of cross attention control
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
- Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from the community
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
  - separate prompts using uppercase `AND`
  - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards (add --xformers to commandline args)
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
- Generate forever option
- Training tab
  - hypernetworks and embeddings options
  - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API (a minimal usage sketch follows this list)
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
- Eased resolution restriction: the generated image's dimensions must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from the settings screen

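As a minimal sketch of the API mentioned above: the webui exposes a REST endpoint for text-to-image when started with the `--api` flag. The host/port below are the defaults, and any payload fields beyond `prompt` and `steps` should be treated as assumptions rather than a documented contract; check the live `/docs` page of your own instance for the authoritative schema.

```python
# Minimal sketch: generate an image via the webui HTTP API.
# Assumes the server was launched with --api on the default 127.0.0.1:7860.
import base64

import requests  # third-party; pip install requests

payload = {
    "prompt": "a photo of a corgi wearing a tuxedo",
    "steps": 20,
}

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=600)
resp.raise_for_status()

# The response carries generated images as base64-encoded strings.
for i, img_b64 in enumerate(resp.json()["images"]):
    with open(f"txt2img-{i}.png", "wb") as f:
        f.write(base64.b64decode(img_b64))
```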
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.

Alternatively, use online services (like Google Colab):

- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)

### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as a normal, non-administrator user.

### Automatic Installation on Linux
1. Install the dependencies:
```bash
# Debian-based:
sudo apt install wget git python3 python3-venv
# Red Hat-based:
sudo dnf install wget git python3
# Arch-based:
sudo pacman -S wget git python3
```
2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
```bash
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
```
3. Run `webui.sh`.
### Installation on Apple Silicon

Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).

## Contributing
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)

## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).

## Credits
Licenses for borrowed code can be found in the `Settings -> Licenses` screen, and also in the `html/licenses.html` file.

- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer
- ESRGAN - https://github.com/xinntao/ESRGAN
- SwinIR - https://github.com/JingyunLiang/SwinIR
- Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion
- MiDaS - https://github.com/isl-org/MiDaS
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
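As promised in the "Generation parameters" feature entry above, here is a minimal sketch of reading the embedded parameters back out of a saved PNG. It assumes parameter embedding was left enabled and that the text-chunk key is `parameters`; that key name is an assumption based on the webui's defaults, not a stable public contract.

```python
# Minimal sketch: read generation parameters back out of a webui PNG.
# The "parameters" text-chunk key is an assumption (webui default).
from PIL import Image  # third-party; pip install Pillow

img = Image.open("txt2img-0.png")
params = img.info.get("parameters")  # PNG tEXt chunks land in Image.info
if params is None:
    print("no embedded generation parameters found")
else:
    print(params)  # prompt, negative prompt, steps, sampler, seed, ...
```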
sd/stable-diffusion-webui/extensions-builtin/LDSR/preload.py
CHANGED
@@ -1,6 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
sd/stable-diffusion-webui/extensions-builtin/Lora/extra_networks_lora.py
CHANGED
@@ -1,26 +1,26 @@
from modules import extra_networks, shared
import lora

class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        lora.load_loras(names, multipliers)

    def deactivate(self, p):
        pass
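To make the `activate()` flow above concrete, here is a small sketch of how a prompt tag such as `<lora:myStyle:0.8>` ends up as the `names` and `multipliers` lists. The regex and the example tag names are illustrative assumptions; in the webui the tag parsing itself happens upstream in `modules/extra_networks.py`, and `activate()` only receives the already-split items.

```python
# Illustrative sketch: how "<lora:name:multiplier>" prompt tags become
# the (names, multipliers) lists built in activate().
# The regex and example tags are assumptions for demonstration only.
import re

prompt = "a castle at dusk <lora:myStyle:0.8> <lora:lineart>"

# items[0] is the lora name; items[1], if present, is the multiplier.
params_list = [tag.split(":") for tag in re.findall(r"<lora:([^>]+)>", prompt)]

names = [items[0] for items in params_list]
multipliers = [float(items[1]) if len(items) > 1 else 1.0 for items in params_list]

print(names)        # ['myStyle', 'lineart']
print(multipliers)  # [0.8, 1.0]
```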
sd/stable-diffusion-webui/extensions-builtin/Lora/lora.py
CHANGED
@@ -1,207 +1,207 @@
|
|
1 |
-
import glob
|
2 |
-
import os
|
3 |
-
import re
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from modules import shared, devices, sd_models
|
7 |
-
|
8 |
-
re_digits = re.compile(r"\d+")
|
9 |
-
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
10 |
-
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
11 |
-
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
12 |
-
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
13 |
-
|
14 |
-
|
15 |
-
def convert_diffusers_name_to_compvis(key):
|
16 |
-
def match(match_list, regex):
|
17 |
-
r = re.match(regex, key)
|
18 |
-
if not r:
|
19 |
-
return False
|
20 |
-
|
21 |
-
match_list.clear()
|
22 |
-
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
23 |
-
return True
|
24 |
-
|
25 |
-
m = []
|
26 |
-
|
27 |
-
if match(m, re_unet_down_blocks):
|
28 |
-
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
29 |
-
|
30 |
-
if match(m, re_unet_mid_blocks):
|
31 |
-
return f"diffusion_model_middle_block_1_{m[1]}"
|
32 |
-
|
33 |
-
if match(m, re_unet_up_blocks):
|
34 |
-
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
35 |
-
|
36 |
-
if match(m, re_text_block):
|
37 |
-
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
38 |
-
|
39 |
-
return key
|
40 |
-
|
41 |
-
|
42 |
-
class LoraOnDisk:
|
43 |
-
def __init__(self, name, filename):
|
44 |
-
self.name = name
|
45 |
-
self.filename = filename
|
46 |
-
|
47 |
-
|
48 |
-
class LoraModule:
|
49 |
-
def __init__(self, name):
|
50 |
-
self.name = name
|
51 |
-
self.multiplier = 1.0
|
52 |
-
self.modules = {}
|
53 |
-
self.mtime = None
|
54 |
-
|
55 |
-
|
56 |
-
class LoraUpDownModule:
|
57 |
-
def __init__(self):
|
58 |
-
self.up = None
|
59 |
-
self.down = None
|
60 |
-
self.alpha = None
|
61 |
-
|
62 |
-
|
63 |
-
def assign_lora_names_to_compvis_modules(sd_model):
|
64 |
-
lora_layer_mapping = {}
|
65 |
-
|
66 |
-
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
67 |
-
lora_name = name.replace(".", "_")
|
68 |
-
lora_layer_mapping[lora_name] = module
|
69 |
-
module.lora_layer_name = lora_name
|
70 |
-
|
71 |
-
for name, module in shared.sd_model.model.named_modules():
|
72 |
-
lora_name = name.replace(".", "_")
|
73 |
-
lora_layer_mapping[lora_name] = module
|
74 |
-
module.lora_layer_name = lora_name
|
75 |
-
|
76 |
-
sd_model.lora_layer_mapping = lora_layer_mapping
|
77 |
-
|
78 |
-
|
79 |
-
def load_lora(name, filename):
|
80 |
-
lora = LoraModule(name)
|
81 |
-
lora.mtime = os.path.getmtime(filename)
|
82 |
-
|
83 |
-
sd = sd_models.read_state_dict(filename)
|
84 |
-
|
85 |
-
keys_failed_to_match = []
|
86 |
-
|
87 |
-
for key_diffusers, weight in sd.items():
|
88 |
-
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
|
89 |
-
key, lora_key = fullkey.split(".", 1)
|
90 |
-
|
91 |
-
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
92 |
-
if sd_module is None:
|
93 |
-
keys_failed_to_match.append(key_diffusers)
|
94 |
-
continue
|
95 |
-
|
96 |
-
lora_module = lora.modules.get(key, None)
|
97 |
-
if lora_module is None:
|
98 |
-
lora_module = LoraUpDownModule()
|
99 |
-
lora.modules[key] = lora_module
|
100 |
-
|
101 |
-
if lora_key == "alpha":
|
102 |
-
lora_module.alpha = weight.item()
|
103 |
-
continue
|
104 |
-
|
105 |
-
if type(sd_module) == torch.nn.Linear:
|
106 |
-
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
107 |
-
elif type(sd_module) == torch.nn.Conv2d:
|
108 |
-
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
109 |
-
else:
|
110 |
-
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
111 |
-
|
112 |
-
with torch.no_grad():
|
113 |
-
module.weight.copy_(weight)
|
114 |
-
|
115 |
-
module.to(device=devices.device, dtype=devices.dtype)
|
116 |
-
|
117 |
-
if lora_key == "lora_up.weight":
|
118 |
-
lora_module.up = module
|
119 |
-
elif lora_key == "lora_down.weight":
|
120 |
-
lora_module.down = module
|
121 |
-
else:
|
122 |
-
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
123 |
-
|
124 |
-
if len(keys_failed_to_match) > 0:
|
125 |
-
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
126 |
-
|
127 |
-
return lora
|
128 |
-
|
129 |
-
|
130 |
-
def load_loras(names, multipliers=None):
|
131 |
-
already_loaded = {}
|
132 |
-
|
133 |
-
for lora in loaded_loras:
|
134 |
-
if lora.name in names:
|
135 |
-
already_loaded[lora.name] = lora
|
136 |
-
|
137 |
-
loaded_loras.clear()
|
138 |
-
|
139 |
-
loras_on_disk = [available_loras.get(name, None) for name in names]
|
140 |
-
if any([x is None for x in loras_on_disk]):
|
141 |
-
list_available_loras()
|
142 |
-
|
143 |
-
loras_on_disk = [available_loras.get(name, None) for name in names]
|
144 |
-
|
145 |
-
for i, name in enumerate(names):
|
146 |
-
lora = already_loaded.get(name, None)
|
147 |
-
|
148 |
-
lora_on_disk = loras_on_disk[i]
|
149 |
-
if lora_on_disk is not None:
|
150 |
-
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
151 |
-
lora = load_lora(name, lora_on_disk.filename)
|
152 |
-
|
153 |
-
if lora is None:
|
154 |
-
print(f"Couldn't find Lora with name {name}")
|
155 |
-
continue
|
156 |
-
|
157 |
-
lora.multiplier = multipliers[i] if multipliers else 1.0
|
158 |
-
loaded_loras.append(lora)
|
159 |
-
|
160 |
-
|
161 |
-
def lora_forward(module, input, res):
|
162 |
-
if len(loaded_loras) == 0:
|
163 |
-
return res
|
164 |
-
|
165 |
-
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
166 |
-
for lora in loaded_loras:
|
167 |
-
module = lora.modules.get(lora_layer_name, None)
|
168 |
-
if module is not None:
|
169 |
-
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
170 |
-
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
171 |
-
else:
|
172 |
-
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
173 |
-
|
174 |
-
return res
|
175 |
-
|
176 |
-
|
177 |
-
def lora_Linear_forward(self, input):
|
178 |
-
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
|
179 |
-
|
180 |
-
|
181 |
-
def lora_Conv2d_forward(self, input):
|
182 |
-
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
183 |
-
|
184 |
-
|
185 |
-
def list_available_loras():
|
186 |
-
available_loras.clear()
|
187 |
-
|
188 |
-
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
189 |
-
|
190 |
-
candidates = \
|
191 |
-
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
|
192 |
-
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
|
193 |
-
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
|
194 |
-
|
195 |
-
for filename in sorted(candidates):
|
196 |
-
if os.path.isdir(filename):
|
197 |
-
continue
|
198 |
-
|
199 |
-
name = os.path.splitext(os.path.basename(filename))[0]
|
200 |
-
|
201 |
-
available_loras[name] = LoraOnDisk(name, filename)
|
202 |
-
|
203 |
-
|
204 |
-
available_loras = {}
|
205 |
-
loaded_loras = []
|
206 |
-
|
207 |
-
list_available_loras()
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from modules import shared, devices, sd_models
|
7 |
+
|
8 |
+
re_digits = re.compile(r"\d+")
|
9 |
+
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
10 |
+
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
11 |
+
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
12 |
+
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
13 |
+
|
14 |
+
|
15 |
+
def convert_diffusers_name_to_compvis(key):
|
16 |
+
def match(match_list, regex):
|
17 |
+
r = re.match(regex, key)
|
18 |
+
if not r:
|
19 |
+
return False
|
20 |
+
|
21 |
+
match_list.clear()
|
22 |
+
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
23 |
+
return True
|
24 |
+
|
25 |
+
m = []
|
26 |
+
|
27 |
+
if match(m, re_unet_down_blocks):
|
28 |
+
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
29 |
+
|
30 |
+
if match(m, re_unet_mid_blocks):
|
31 |
+
return f"diffusion_model_middle_block_1_{m[1]}"
|
32 |
+
|
33 |
+
if match(m, re_unet_up_blocks):
|
34 |
+
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
35 |
+
|
36 |
+
if match(m, re_text_block):
|
37 |
+
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
38 |
+
|
39 |
+
return key
|
40 |
+
|
41 |
+
|
42 |
+
class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None


def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping


def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = []

    for key_diffusers, weight in sd.items():
        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
        key, lora_key = fullkey.split(".", 1)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
        if sd_module is None:
            keys_failed_to_match.append(key_diffusers)
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.device, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

    if len(keys_failed_to_match) > 0:
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora


def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_loras.get(name, None) for name in names]
    if any([x is None for x in loras_on_disk]):
        list_available_loras()

        loras_on_disk = [available_loras.get(name, None) for name in names]

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)

        lora_on_disk = loras_on_disk[i]
        if lora_on_disk is not None:
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                lora = load_lora(name, lora_on_disk.filename)

        if lora is None:
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)


def lora_forward(module, input, res):
    if len(loaded_loras) == 0:
        return res

    lora_layer_name = getattr(module, 'lora_layer_name', None)
    for lora in loaded_loras:
        module = lora.modules.get(lora_layer_name, None)
        if module is not None:
            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
            else:
                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

    return res


def lora_Linear_forward(self, input):
    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))


def lora_Conv2d_forward(self, input):
    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))


def list_available_loras():
    available_loras.clear()

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

    for filename in sorted(candidates):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        available_loras[name] = LoraOnDisk(name, filename)


available_loras = {}
loaded_loras = []

list_available_loras()
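For reference, the update applied by lora_forward above is the standard low-rank offset: the frozen layer's result is shifted by up(down(x)), scaled by the user-set multiplier and, when the checkpoint stores an alpha, by alpha divided by the rank (recovered as module.up.weight.shape[1]). A minimal self-contained sketch of the same arithmetic, with illustrative sizes that are not part of the webui code:

import torch

# Illustrative sizes only; rank, alpha and multiplier are hypothetical values.
d_in, d_out, rank, alpha, multiplier = 320, 320, 4, 4.0, 0.8

base = torch.nn.Linear(d_in, d_out)               # the frozen original layer
down = torch.nn.Linear(d_in, rank, bias=False)    # plays the role of lora_down.weight
up = torch.nn.Linear(rank, d_out, bias=False)     # plays the role of lora_up.weight

x = torch.randn(1, d_in)
res = base(x)                                     # original forward output
res = res + up(down(x)) * multiplier * (alpha / rank)   # same term lora_forward adds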
sd/stable-diffusion-webui/extensions-builtin/Lora/preload.py
CHANGED
@@ -1,6 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
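Each built-in extension ships a preload.py like the one above; before the UI starts, the webui is expected to call preload(parser) on its argparse parser so the flag ends up on shared.cmd_opts. A hypothetical sketch of that handshake (the standalone parser and the shortened default here are illustrative, not webui API):

import argparse

def preload(parser):
    # same shape as the hook above, with the default shortened for the sketch
    parser.add_argument("--lora-dir", type=str, default="models/Lora")

parser = argparse.ArgumentParser()
preload(parser)                                   # what the loader does per extension
cmd_opts = parser.parse_args(["--lora-dir", "/data/loras"])
print(cmd_opts.lora_dir)                          # -> /data/loras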
sd/stable-diffusion-webui/extensions-builtin/Lora/scripts/lora_script.py
CHANGED
@@ -1,38 +1,38 @@
import torch
import gradio as gr

import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared


def unload():
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora


def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())


if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward

script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)


shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
}))
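The script above installs LoRA by monkey-patching: the pristine torch.nn.Linear.forward is stashed once on the torch.nn module (the hasattr guard keeps a script reload from overwriting the saved copy with the already-patched version), the replacement calls the saved function first, and unload() swaps the original back. The pattern in isolation, as a minimal sketch:

import torch

if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

def patched_forward(self, input):
    res = torch.nn.Linear_forward_before_lora(self, input)  # original computation
    return res  # a real patch adds the LoRA delta here, as lora_forward does

torch.nn.Linear.forward = patched_forward
out = torch.nn.Linear(4, 4)(torch.randn(1, 4))    # routed through patched_forward

# what unload() does: restore the saved original
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora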
sd/stable-diffusion-webui/extensions-builtin/Lora/ui_extra_networks_lora.py
CHANGED
@@ -1,37 +1,30 @@
import json
import os
import lora

from modules import shared, ui_extra_networks


class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Lora')

    def refresh(self):
        lora.list_available_loras()

    def list_items(self):
        for name, lora_on_disk in lora.available_loras.items():
            path, ext = os.path.splitext(lora_on_disk.filename)
            yield {
                "name": name,
                "filename": path,
                "preview": self._find_preview(path),
                "description": self._find_description(path),
                "search_term": self.search_terms_from_path(lora_on_disk.filename),
                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir]
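Note that list_items does not emit finished prompt text: the "prompt" field is a JavaScript expression assembled as a string, which the front-end evaluates with the user's default multiplier when a card is clicked. For a card named foo it comes out as in this sketch:

import json

name = "foo"
prompt = json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">")
print(prompt)   # "<lora:foo:" + opts.extra_networks_default_multiplier + ">"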
sd/stable-diffusion-webui/extensions-builtin/ScuNET/preload.py
CHANGED
@@ -1,6 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
sd/stable-diffusion-webui/extensions-builtin/SwinIR/preload.py
CHANGED
@@ -1,6 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
sd/stable-diffusion-webui/extensions-builtin/SwinIR/swinir_model_arch_v2.py
CHANGED
@@ -1,1017 +1,1017 @@
# -----------------------------------------------------------------------------------
# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
# Written by Conde and Choi et al.
# -----------------------------------------------------------------------------------

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        #assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.reduction(x)
        x = self.norm(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        flops += H * W * self.dim // 2
        return flops

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained_window_size (int): Local window size in pre-training.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 pretrained_window_size=0):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 pretrained_window_size=pretrained_window_size)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

    def _init_respostnorm(self):
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1],
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops

class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B Ph*Pw C
        return x

    def flops(self):
        flops = 0
        return flops


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class Upsample_hf(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample_hf, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.num_feat * 3 * 9
        return flops



class Swin2SR(nn.Module):
    r""" Swin2SR
        A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 **kwargs):
        super(Swin2SR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection

                         )
            self.layers.append(layer)

        if self.upsampler == 'pixelshuffle_hf':
            self.layers_hf = nn.ModuleList()
            for i_layer in range(self.num_layers):
                layer = RSTB(dim=embed_dim,
                             input_resolution=(patches_resolution[0],
                                               patches_resolution[1]),
                             depth=depths[i_layer],
                             num_heads=num_heads[i_layer],
                             window_size=window_size,
                             mlp_ratio=self.mlp_ratio,
                             qkv_bias=qkv_bias,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                             norm_layer=norm_layer,
                             downsample=None,
                             use_checkpoint=use_checkpoint,
                             img_size=img_size,
                             patch_size=patch_size,
                             resi_connection=resi_connection

                             )
                self.layers_hf.append(layer)

        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffle_aux':
            self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_after_aux = nn.Sequential(
                nn.Conv2d(3, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffle_hf':
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.upsample_hf = Upsample_hf(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
                                               nn.LeakyReLU(inplace=True))
            self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
            self.conv_before_upsample_hf = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward_features_hf(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers_hf:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffle_aux':
            bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
            bicubic = self.conv_bicubic(bicubic)
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            aux = self.conv_aux(x)  # b, 3, LR_H, LR_W
            x = self.conv_after_aux(aux)
            x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
            x = self.conv_last(x)
            aux = aux / self.img_range + self.mean
        elif self.upsampler == 'pixelshuffle_hf':
            # for classical SR with HF
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x_before = self.conv_before_upsample(x)
            x_out = self.conv_last(self.upsample(x_before))

            x_hf = self.conv_first_hf(x_before)
            x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
            x_hf = self.conv_before_upsample_hf(x_hf)
            x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
            x = x_out + x_hf
            x_hf = x_hf / self.img_range + self.mean

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean
        if self.upsampler == "pixelshuffle_aux":
            return x[:, :, :H*self.upscale, :W*self.upscale], aux

        elif self.upsampler == "pixelshuffle_hf":
            x_out = x_out / self.img_range + self.mean
            return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]

        else:
            return x[:, :, :H*self.upscale, :W*self.upscale]

    def flops(self):
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops


if __name__ == '__main__':
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = Swin2SR(upscale=2, img_size=(height, width),
                    window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
                    embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
|
3 |
+
# Written by Conde and Choi et al.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.checkpoint as checkpoint
|
12 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
13 |
+
|
14 |
+
|
15 |
+
class Mlp(nn.Module):
|
16 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
17 |
+
super().__init__()
|
18 |
+
out_features = out_features or in_features
|
19 |
+
hidden_features = hidden_features or in_features
|
20 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
21 |
+
self.act = act_layer()
|
22 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
23 |
+
self.drop = nn.Dropout(drop)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.fc1(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.drop(x)
|
29 |
+
x = self.fc2(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
def window_partition(x, window_size):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
x: (B, H, W, C)
|
38 |
+
window_size (int): window size
|
39 |
+
Returns:
|
40 |
+
windows: (num_windows*B, window_size, window_size, C)
|
41 |
+
"""
|
42 |
+
B, H, W, C = x.shape
|
43 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
44 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
45 |
+
return windows
|
46 |
+
|
47 |
+
|
48 |
+
def window_reverse(windows, window_size, H, W):
|
49 |
+
"""
|
50 |
+
Args:
|
51 |
+
windows: (num_windows*B, window_size, window_size, C)
|
52 |
+
window_size (int): Window size
|
53 |
+
H (int): Height of image
|
54 |
+
W (int): Width of image
|
55 |
+
Returns:
|
56 |
+
x: (B, H, W, C)
|
57 |
+
"""
|
58 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
59 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
60 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
61 |
+
return x
|
62 |
+
|
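window_partition and window_reverse are exact inverses whenever H and W are multiples of window_size, which check_image_size later guarantees by padding. A minimal round-trip sketch, with illustrative sizes, assuming the two functions above are in scope:

```python
import torch

# Hypothetical round-trip check: partition into 8x8 windows, then restore.
B, H, W, C, window_size = 2, 16, 16, 8, 8
x = torch.randn(B, H, W, C)
windows = window_partition(x, window_size)   # (B * H/ws * W/ws, ws, ws, C)
assert windows.shape == (8, 8, 8, 8)
restored = window_reverse(windows, window_size, H, W)
assert torch.equal(restored, x)              # only views/permutes, so exact
```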
63 |
+
class WindowAttention(nn.Module):
|
64 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
65 |
+
It supports both shifted and non-shifted windows.
|
66 |
+
Args:
|
67 |
+
dim (int): Number of input channels.
|
68 |
+
window_size (tuple[int]): The height and width of the window.
|
69 |
+
num_heads (int): Number of attention heads.
|
70 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
71 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
72 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
73 |
+
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
|
74 |
+
"""
|
75 |
+
|
76 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
77 |
+
pretrained_window_size=[0, 0]):
|
78 |
+
|
79 |
+
super().__init__()
|
80 |
+
self.dim = dim
|
81 |
+
self.window_size = window_size # Wh, Ww
|
82 |
+
self.pretrained_window_size = pretrained_window_size
|
83 |
+
self.num_heads = num_heads
|
84 |
+
|
85 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
86 |
+
|
87 |
+
# mlp to generate continuous relative position bias
|
88 |
+
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
|
89 |
+
nn.ReLU(inplace=True),
|
90 |
+
nn.Linear(512, num_heads, bias=False))
|
91 |
+
|
92 |
+
# get relative_coords_table
|
93 |
+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
|
94 |
+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
|
95 |
+
relative_coords_table = torch.stack(
|
96 |
+
torch.meshgrid([relative_coords_h,
|
97 |
+
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
|
98 |
+
if pretrained_window_size[0] > 0:
|
99 |
+
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
|
100 |
+
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
|
101 |
+
else:
|
102 |
+
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
|
103 |
+
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
|
104 |
+
relative_coords_table *= 8 # normalize to -8, 8
|
105 |
+
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
106 |
+
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
107 |
+
|
108 |
+
self.register_buffer("relative_coords_table", relative_coords_table)
|
109 |
+
|
110 |
+
# get pair-wise relative position index for each token inside the window
|
111 |
+
coords_h = torch.arange(self.window_size[0])
|
112 |
+
coords_w = torch.arange(self.window_size[1])
|
113 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
114 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
115 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
116 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
117 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
118 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
119 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
120 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
121 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
122 |
+
|
123 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
124 |
+
if qkv_bias:
|
125 |
+
self.q_bias = nn.Parameter(torch.zeros(dim))
|
126 |
+
self.v_bias = nn.Parameter(torch.zeros(dim))
|
127 |
+
else:
|
128 |
+
self.q_bias = None
|
129 |
+
self.v_bias = None
|
130 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
131 |
+
self.proj = nn.Linear(dim, dim)
|
132 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
133 |
+
self.softmax = nn.Softmax(dim=-1)
|
134 |
+
|
135 |
+
def forward(self, x, mask=None):
|
136 |
+
"""
|
137 |
+
Args:
|
138 |
+
x: input features with shape of (num_windows*B, N, C)
|
139 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
140 |
+
"""
|
141 |
+
B_, N, C = x.shape
|
142 |
+
qkv_bias = None
|
143 |
+
if self.q_bias is not None:
|
144 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
145 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
146 |
+
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
147 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
148 |
+
|
149 |
+
# cosine attention
|
150 |
+
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
151 |
+
logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
|
152 |
+
attn = attn * logit_scale
|
153 |
+
|
154 |
+
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
|
155 |
+
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
156 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
157 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
158 |
+
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
159 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
160 |
+
|
161 |
+
if mask is not None:
|
162 |
+
nW = mask.shape[0]
|
163 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
164 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
165 |
+
attn = self.softmax(attn)
|
166 |
+
else:
|
167 |
+
attn = self.softmax(attn)
|
168 |
+
|
169 |
+
attn = self.attn_drop(attn)
|
170 |
+
|
171 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
172 |
+
x = self.proj(x)
|
173 |
+
x = self.proj_drop(x)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def extra_repr(self) -> str:
|
177 |
+
return f'dim={self.dim}, window_size={self.window_size}, ' \
|
178 |
+
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
|
179 |
+
|
180 |
+
def flops(self, N):
|
181 |
+
# calculate flops for 1 window with token length of N
|
182 |
+
flops = 0
|
183 |
+
# qkv = self.qkv(x)
|
184 |
+
flops += N * self.dim * 3 * self.dim
|
185 |
+
# attn = (q @ k.transpose(-2, -1))
|
186 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
187 |
+
# x = (attn @ v)
|
188 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
189 |
+
# x = self.proj(x)
|
190 |
+
flops += N * self.dim * self.dim
|
191 |
+
return flops
|
192 |
+
|
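The cosine attention above replaces the usual dot-product-over-sqrt(d) scaling with L2-normalized q and k and a learned per-head temperature clamped at 1/0.01. A standalone sketch of just that step; names and sizes here are illustrative, not part of the module:

```python
import torch
import torch.nn.functional as F

num_heads, N, head_dim = 6, 64, 10
q = torch.randn(1, num_heads, N, head_dim)
k = torch.randn(1, num_heads, N, head_dim)
logit_scale = torch.log(10 * torch.ones(num_heads, 1, 1))  # learnable in the module

# cosine similarity between tokens, scaled by the clamped temperature
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
attn = attn * torch.clamp(logit_scale, max=torch.log(torch.tensor(100.0))).exp()
print(attn.shape)  # torch.Size([1, 6, 64, 64])
```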
193 |
+
class SwinTransformerBlock(nn.Module):
|
194 |
+
r""" Swin Transformer Block.
|
195 |
+
Args:
|
196 |
+
dim (int): Number of input channels.
|
197 |
+
input_resolution (tuple[int]): Input resolution.
|
198 |
+
num_heads (int): Number of attention heads.
|
199 |
+
window_size (int): Window size.
|
200 |
+
shift_size (int): Shift size for SW-MSA.
|
201 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
202 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
203 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
204 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
205 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
206 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
207 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
208 |
+
pretrained_window_size (int): Window size in pre-training.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
212 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
213 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
214 |
+
super().__init__()
|
215 |
+
self.dim = dim
|
216 |
+
self.input_resolution = input_resolution
|
217 |
+
self.num_heads = num_heads
|
218 |
+
self.window_size = window_size
|
219 |
+
self.shift_size = shift_size
|
220 |
+
self.mlp_ratio = mlp_ratio
|
221 |
+
if min(self.input_resolution) <= self.window_size:
|
222 |
+
# if window size is larger than input resolution, we don't partition windows
|
223 |
+
self.shift_size = 0
|
224 |
+
self.window_size = min(self.input_resolution)
|
225 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
|
226 |
+
|
227 |
+
self.norm1 = norm_layer(dim)
|
228 |
+
self.attn = WindowAttention(
|
229 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
230 |
+
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
231 |
+
pretrained_window_size=to_2tuple(pretrained_window_size))
|
232 |
+
|
233 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
234 |
+
self.norm2 = norm_layer(dim)
|
235 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
236 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
237 |
+
|
238 |
+
if self.shift_size > 0:
|
239 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
240 |
+
else:
|
241 |
+
attn_mask = None
|
242 |
+
|
243 |
+
self.register_buffer("attn_mask", attn_mask)
|
244 |
+
|
245 |
+
def calculate_mask(self, x_size):
|
246 |
+
# calculate attention mask for SW-MSA
|
247 |
+
H, W = x_size
|
248 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
249 |
+
h_slices = (slice(0, -self.window_size),
|
250 |
+
slice(-self.window_size, -self.shift_size),
|
251 |
+
slice(-self.shift_size, None))
|
252 |
+
w_slices = (slice(0, -self.window_size),
|
253 |
+
slice(-self.window_size, -self.shift_size),
|
254 |
+
slice(-self.shift_size, None))
|
255 |
+
cnt = 0
|
256 |
+
for h in h_slices:
|
257 |
+
for w in w_slices:
|
258 |
+
img_mask[:, h, w, :] = cnt
|
259 |
+
cnt += 1
|
260 |
+
|
261 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
262 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
263 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
264 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
265 |
+
|
266 |
+
return attn_mask
|
267 |
+
|
268 |
+
def forward(self, x, x_size):
|
269 |
+
H, W = x_size
|
270 |
+
B, L, C = x.shape
|
271 |
+
#assert L == H * W, "input feature has wrong size"
|
272 |
+
|
273 |
+
shortcut = x
|
274 |
+
x = x.view(B, H, W, C)
|
275 |
+
|
276 |
+
# cyclic shift
|
277 |
+
if self.shift_size > 0:
|
278 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
279 |
+
else:
|
280 |
+
shifted_x = x
|
281 |
+
|
282 |
+
# partition windows
|
283 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
284 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
285 |
+
|
286 |
+
# W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
|
287 |
+
if self.input_resolution == x_size:
|
288 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
289 |
+
else:
|
290 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
291 |
+
|
292 |
+
# merge windows
|
293 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
294 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
295 |
+
|
296 |
+
# reverse cyclic shift
|
297 |
+
if self.shift_size > 0:
|
298 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
299 |
+
else:
|
300 |
+
x = shifted_x
|
301 |
+
x = x.view(B, H * W, C)
|
302 |
+
x = shortcut + self.drop_path(self.norm1(x))
|
303 |
+
|
304 |
+
# FFN
|
305 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
306 |
+
|
307 |
+
return x
|
308 |
+
|
309 |
+
def extra_repr(self) -> str:
|
310 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
311 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
312 |
+
|
313 |
+
def flops(self):
|
314 |
+
flops = 0
|
315 |
+
H, W = self.input_resolution
|
316 |
+
# norm1
|
317 |
+
flops += self.dim * H * W
|
318 |
+
# W-MSA/SW-MSA
|
319 |
+
nW = H * W / self.window_size / self.window_size
|
320 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
321 |
+
# mlp
|
322 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
323 |
+
# norm2
|
324 |
+
flops += self.dim * H * W
|
325 |
+
return flops
|
326 |
+
|
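Note the res-post-norm ordering in forward: norm1 and norm2 are applied after the attention and MLP branches rather than before them, one of the SwinV2 changes. A minimal forward pass through one shifted block, with illustrative sizes, assuming the classes above are loaded:

```python
import torch

blk = SwinTransformerBlock(dim=60, input_resolution=(16, 16), num_heads=6,
                           window_size=8, shift_size=4)
x = torch.randn(1, 16 * 16, 60)      # B, H*W, C
out = blk(x, x_size=(16, 16))
print(out.shape)                     # torch.Size([1, 256, 60])
```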
327 |
+
class PatchMerging(nn.Module):
|
328 |
+
r""" Patch Merging Layer.
|
329 |
+
Args:
|
330 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
331 |
+
dim (int): Number of input channels.
|
332 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
336 |
+
super().__init__()
|
337 |
+
self.input_resolution = input_resolution
|
338 |
+
self.dim = dim
|
339 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
340 |
+
self.norm = norm_layer(2 * dim)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
"""
|
344 |
+
x: B, H*W, C
|
345 |
+
"""
|
346 |
+
H, W = self.input_resolution
|
347 |
+
B, L, C = x.shape
|
348 |
+
assert L == H * W, "input feature has wrong size"
|
349 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
|
350 |
+
|
351 |
+
x = x.view(B, H, W, C)
|
352 |
+
|
353 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
354 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
355 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
356 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
357 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
358 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
359 |
+
|
360 |
+
x = self.reduction(x)
|
361 |
+
x = self.norm(x)
|
362 |
+
|
363 |
+
return x
|
364 |
+
|
365 |
+
def extra_repr(self) -> str:
|
366 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
367 |
+
|
368 |
+
def flops(self):
|
369 |
+
H, W = self.input_resolution
|
370 |
+
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
371 |
+
flops += H * W * self.dim // 2
|
372 |
+
return flops
|
373 |
+
|
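PatchMerging halves each spatial dimension and doubles the channel count; as in the blocks, the norm runs after the reduction. A quick shape check with illustrative sizes:

```python
import torch

pm = PatchMerging(input_resolution=(16, 16), dim=60)
x = torch.randn(1, 16 * 16, 60)
print(pm(x).shape)   # torch.Size([1, 64, 120]): (H/2)*(W/2) tokens, 2*C channels
```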
374 |
+
class BasicLayer(nn.Module):
|
375 |
+
""" A basic Swin Transformer layer for one stage.
|
376 |
+
Args:
|
377 |
+
dim (int): Number of input channels.
|
378 |
+
input_resolution (tuple[int]): Input resolution.
|
379 |
+
depth (int): Number of blocks.
|
380 |
+
num_heads (int): Number of attention heads.
|
381 |
+
window_size (int): Local window size.
|
382 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
383 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
384 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
385 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
386 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
387 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
388 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
389 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
390 |
+
pretrained_window_size (int): Local window size in pre-training.
|
391 |
+
"""
|
392 |
+
|
393 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
394 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
395 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
396 |
+
pretrained_window_size=0):
|
397 |
+
|
398 |
+
super().__init__()
|
399 |
+
self.dim = dim
|
400 |
+
self.input_resolution = input_resolution
|
401 |
+
self.depth = depth
|
402 |
+
self.use_checkpoint = use_checkpoint
|
403 |
+
|
404 |
+
# build blocks
|
405 |
+
self.blocks = nn.ModuleList([
|
406 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
407 |
+
num_heads=num_heads, window_size=window_size,
|
408 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
409 |
+
mlp_ratio=mlp_ratio,
|
410 |
+
qkv_bias=qkv_bias,
|
411 |
+
drop=drop, attn_drop=attn_drop,
|
412 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
413 |
+
norm_layer=norm_layer,
|
414 |
+
pretrained_window_size=pretrained_window_size)
|
415 |
+
for i in range(depth)])
|
416 |
+
|
417 |
+
# patch merging layer
|
418 |
+
if downsample is not None:
|
419 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
420 |
+
else:
|
421 |
+
self.downsample = None
|
422 |
+
|
423 |
+
def forward(self, x, x_size):
|
424 |
+
for blk in self.blocks:
|
425 |
+
if self.use_checkpoint:
|
426 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
427 |
+
else:
|
428 |
+
x = blk(x, x_size)
|
429 |
+
if self.downsample is not None:
|
430 |
+
x = self.downsample(x)
|
431 |
+
return x
|
432 |
+
|
433 |
+
def extra_repr(self) -> str:
|
434 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
435 |
+
|
436 |
+
def flops(self):
|
437 |
+
flops = 0
|
438 |
+
for blk in self.blocks:
|
439 |
+
flops += blk.flops()
|
440 |
+
if self.downsample is not None:
|
441 |
+
flops += self.downsample.flops()
|
442 |
+
return flops
|
443 |
+
|
444 |
+
def _init_respostnorm(self):
|
445 |
+
for blk in self.blocks:
|
446 |
+
nn.init.constant_(blk.norm1.bias, 0)
|
447 |
+
nn.init.constant_(blk.norm1.weight, 0)
|
448 |
+
nn.init.constant_(blk.norm2.bias, 0)
|
449 |
+
nn.init.constant_(blk.norm2.weight, 0)
|
450 |
+
|
451 |
+
class PatchEmbed(nn.Module):
|
452 |
+
r""" Image to Patch Embedding
|
453 |
+
Args:
|
454 |
+
img_size (int): Image size. Default: 224.
|
455 |
+
patch_size (int): Patch token size. Default: 4.
|
456 |
+
in_chans (int): Number of input image channels. Default: 3.
|
457 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
458 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
459 |
+
"""
|
460 |
+
|
461 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
462 |
+
super().__init__()
|
463 |
+
img_size = to_2tuple(img_size)
|
464 |
+
patch_size = to_2tuple(patch_size)
|
465 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
466 |
+
self.img_size = img_size
|
467 |
+
self.patch_size = patch_size
|
468 |
+
self.patches_resolution = patches_resolution
|
469 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
470 |
+
|
471 |
+
self.in_chans = in_chans
|
472 |
+
self.embed_dim = embed_dim
|
473 |
+
|
474 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
475 |
+
if norm_layer is not None:
|
476 |
+
self.norm = norm_layer(embed_dim)
|
477 |
+
else:
|
478 |
+
self.norm = None
|
479 |
+
|
480 |
+
def forward(self, x):
|
481 |
+
B, C, H, W = x.shape
|
482 |
+
# FIXME look at relaxing size constraints
|
483 |
+
# assert H == self.img_size[0] and W == self.img_size[1],
|
484 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
485 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
486 |
+
if self.norm is not None:
|
487 |
+
x = self.norm(x)
|
488 |
+
return x
|
489 |
+
|
490 |
+
def flops(self):
|
491 |
+
Ho, Wo = self.patches_resolution
|
492 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
493 |
+
if self.norm is not None:
|
494 |
+
flops += Ho * Wo * self.embed_dim
|
495 |
+
return flops
|
496 |
+
|
497 |
+
class RSTB(nn.Module):
|
498 |
+
"""Residual Swin Transformer Block (RSTB).
|
499 |
+
|
500 |
+
Args:
|
501 |
+
dim (int): Number of input channels.
|
502 |
+
input_resolution (tuple[int]): Input resolution.
|
503 |
+
depth (int): Number of blocks.
|
504 |
+
num_heads (int): Number of attention heads.
|
505 |
+
window_size (int): Local window size.
|
506 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
507 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
508 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
509 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
510 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
511 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
512 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
513 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
514 |
+
img_size: Input image size.
|
515 |
+
patch_size: Patch size.
|
516 |
+
resi_connection: The convolutional block before residual connection.
|
517 |
+
"""
|
518 |
+
|
519 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
520 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
521 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
522 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
523 |
+
super(RSTB, self).__init__()
|
524 |
+
|
525 |
+
self.dim = dim
|
526 |
+
self.input_resolution = input_resolution
|
527 |
+
|
528 |
+
self.residual_group = BasicLayer(dim=dim,
|
529 |
+
input_resolution=input_resolution,
|
530 |
+
depth=depth,
|
531 |
+
num_heads=num_heads,
|
532 |
+
window_size=window_size,
|
533 |
+
mlp_ratio=mlp_ratio,
|
534 |
+
qkv_bias=qkv_bias,
|
535 |
+
drop=drop, attn_drop=attn_drop,
|
536 |
+
drop_path=drop_path,
|
537 |
+
norm_layer=norm_layer,
|
538 |
+
downsample=downsample,
|
539 |
+
use_checkpoint=use_checkpoint)
|
540 |
+
|
541 |
+
if resi_connection == '1conv':
|
542 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
543 |
+
elif resi_connection == '3conv':
|
544 |
+
# to save parameters and memory
|
545 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
546 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
547 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
548 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
549 |
+
|
550 |
+
self.patch_embed = PatchEmbed(
|
551 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
552 |
+
norm_layer=None)
|
553 |
+
|
554 |
+
self.patch_unembed = PatchUnEmbed(
|
555 |
+
img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
|
556 |
+
norm_layer=None)
|
557 |
+
|
558 |
+
def forward(self, x, x_size):
|
559 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
560 |
+
|
561 |
+
def flops(self):
|
562 |
+
flops = 0
|
563 |
+
flops += self.residual_group.flops()
|
564 |
+
H, W = self.input_resolution
|
565 |
+
flops += H * W * self.dim * self.dim * 9
|
566 |
+
flops += self.patch_embed.flops()
|
567 |
+
flops += self.patch_unembed.flops()
|
568 |
+
|
569 |
+
return flops
|
570 |
+
|
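An RSTB runs its BasicLayer on the token sequence, converts back to a feature map for the conv, re-embeds, and adds a long residual around the whole group. A minimal usage sketch with illustrative sizes; it assumes the full module is loaded, since PatchUnEmbed is defined just below this class:

```python
import torch

rstb = RSTB(dim=60, input_resolution=(16, 16), depth=2, num_heads=6,
            window_size=8, img_size=16, patch_size=1)
x = torch.randn(1, 16 * 16, 60)
print(rstb(x, x_size=(16, 16)).shape)   # torch.Size([1, 256, 60])
```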
571 |
+
class PatchUnEmbed(nn.Module):
|
572 |
+
r""" Image to Patch Unembedding
|
573 |
+
|
574 |
+
Args:
|
575 |
+
img_size (int): Image size. Default: 224.
|
576 |
+
patch_size (int): Patch token size. Default: 4.
|
577 |
+
in_chans (int): Number of input image channels. Default: 3.
|
578 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
579 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
580 |
+
"""
|
581 |
+
|
582 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
583 |
+
super().__init__()
|
584 |
+
img_size = to_2tuple(img_size)
|
585 |
+
patch_size = to_2tuple(patch_size)
|
586 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
587 |
+
self.img_size = img_size
|
588 |
+
self.patch_size = patch_size
|
589 |
+
self.patches_resolution = patches_resolution
|
590 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
591 |
+
|
592 |
+
self.in_chans = in_chans
|
593 |
+
self.embed_dim = embed_dim
|
594 |
+
|
595 |
+
def forward(self, x, x_size):
|
596 |
+
B, HW, C = x.shape
|
597 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
598 |
+
return x
|
599 |
+
|
600 |
+
def flops(self):
|
601 |
+
flops = 0
|
602 |
+
return flops
|
603 |
+
|
604 |
+
|
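PatchEmbed projects a feature map and flattens it into a token sequence, and PatchUnEmbed reshapes tokens back into a feature map; they are shape-wise inverses, although PatchEmbed also applies its projection conv. An illustrative shape check with patch_size=1, as Swin2SR uses:

```python
import torch

pe = PatchEmbed(img_size=16, patch_size=1, in_chans=60, embed_dim=60)
pu = PatchUnEmbed(img_size=16, patch_size=1, in_chans=60, embed_dim=60)
x = torch.randn(1, 60, 16, 16)
tokens = pe(x)                             # (1, 256, 60): B, Ph*Pw, C
print(tokens.shape)
print(pu(tokens, x_size=(16, 16)).shape)   # torch.Size([1, 60, 16, 16])
```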
605 |
+
class Upsample(nn.Sequential):
|
606 |
+
"""Upsample module.
|
607 |
+
|
608 |
+
Args:
|
609 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
610 |
+
num_feat (int): Channel number of intermediate features.
|
611 |
+
"""
|
612 |
+
|
613 |
+
def __init__(self, scale, num_feat):
|
614 |
+
m = []
|
615 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
616 |
+
for _ in range(int(math.log(scale, 2))):
|
617 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
618 |
+
m.append(nn.PixelShuffle(2))
|
619 |
+
elif scale == 3:
|
620 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
621 |
+
m.append(nn.PixelShuffle(3))
|
622 |
+
else:
|
623 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
624 |
+
super(Upsample, self).__init__(*m)
|
625 |
+
|
626 |
+
class Upsample_hf(nn.Sequential):
|
627 |
+
"""Upsample module.
|
628 |
+
|
629 |
+
Args:
|
630 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
631 |
+
num_feat (int): Channel number of intermediate features.
|
632 |
+
"""
|
633 |
+
|
634 |
+
def __init__(self, scale, num_feat):
|
635 |
+
m = []
|
636 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
637 |
+
for _ in range(int(math.log(scale, 2))):
|
638 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
639 |
+
m.append(nn.PixelShuffle(2))
|
640 |
+
elif scale == 3:
|
641 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
642 |
+
m.append(nn.PixelShuffle(3))
|
643 |
+
else:
|
644 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
645 |
+
super(Upsample_hf, self).__init__(*m)
|
646 |
+
|
647 |
+
|
648 |
+
class UpsampleOneStep(nn.Sequential):
|
649 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
650 |
+
Used in lightweight SR to save parameters.
|
651 |
+
|
652 |
+
Args:
|
653 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
654 |
+
num_feat (int): Channel number of intermediate features.
|
655 |
+
|
656 |
+
"""
|
657 |
+
|
658 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
659 |
+
self.num_feat = num_feat
|
660 |
+
self.input_resolution = input_resolution
|
661 |
+
m = []
|
662 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
663 |
+
m.append(nn.PixelShuffle(scale))
|
664 |
+
super(UpsampleOneStep, self).__init__(*m)
|
665 |
+
|
666 |
+
def flops(self):
|
667 |
+
H, W = self.input_resolution
|
668 |
+
flops = H * W * self.num_feat * 3 * 9
|
669 |
+
return flops
|
670 |
+
|
671 |
+
|
672 |
+
|
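Upsample stacks one conv + PixelShuffle(2) stage per factor of two (or a single x3 stage), while UpsampleOneStep fuses the whole factor into one conv + one PixelShuffle to save parameters. A shape sketch for x4, with illustrative sizes:

```python
import torch

up = Upsample(scale=4, num_feat=64)   # two conv + PixelShuffle(2) stages
print(up(torch.randn(1, 64, 16, 16)).shape)         # torch.Size([1, 64, 64, 64])

one_step = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
print(one_step(torch.randn(1, 60, 16, 16)).shape)   # torch.Size([1, 3, 64, 64])
```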
673 |
+
class Swin2SR(nn.Module):
|
674 |
+
r""" Swin2SR
|
675 |
+
A PyTorch impl of: `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
|
676 |
+
|
677 |
+
Args:
|
678 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
679 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
680 |
+
in_chans (int): Number of input image channels. Default: 3
|
681 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
682 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
683 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
684 |
+
window_size (int): Window size. Default: 7
|
685 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
686 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
687 |
+
drop_rate (float): Dropout rate. Default: 0
|
688 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
689 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
690 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
691 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
692 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
693 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
694 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
695 |
+
img_range: Image range. 1. or 255.
|
696 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
697 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
698 |
+
"""
|
699 |
+
|
700 |
+
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
701 |
+
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
702 |
+
window_size=7, mlp_ratio=4., qkv_bias=True,
|
703 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
704 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
705 |
+
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
706 |
+
**kwargs):
|
707 |
+
super(Swin2SR, self).__init__()
|
708 |
+
num_in_ch = in_chans
|
709 |
+
num_out_ch = in_chans
|
710 |
+
num_feat = 64
|
711 |
+
self.img_range = img_range
|
712 |
+
if in_chans == 3:
|
713 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
714 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
715 |
+
else:
|
716 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
717 |
+
self.upscale = upscale
|
718 |
+
self.upsampler = upsampler
|
719 |
+
self.window_size = window_size
|
720 |
+
|
721 |
+
#####################################################################################################
|
722 |
+
################################### 1, shallow feature extraction ###################################
|
723 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
724 |
+
|
725 |
+
#####################################################################################################
|
726 |
+
################################### 2, deep feature extraction ######################################
|
727 |
+
self.num_layers = len(depths)
|
728 |
+
self.embed_dim = embed_dim
|
729 |
+
self.ape = ape
|
730 |
+
self.patch_norm = patch_norm
|
731 |
+
self.num_features = embed_dim
|
732 |
+
self.mlp_ratio = mlp_ratio
|
733 |
+
|
734 |
+
# split image into non-overlapping patches
|
735 |
+
self.patch_embed = PatchEmbed(
|
736 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
737 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
738 |
+
num_patches = self.patch_embed.num_patches
|
739 |
+
patches_resolution = self.patch_embed.patches_resolution
|
740 |
+
self.patches_resolution = patches_resolution
|
741 |
+
|
742 |
+
# merge non-overlapping patches into image
|
743 |
+
self.patch_unembed = PatchUnEmbed(
|
744 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
745 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
746 |
+
|
747 |
+
# absolute position embedding
|
748 |
+
if self.ape:
|
749 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
750 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
751 |
+
|
752 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
753 |
+
|
754 |
+
# stochastic depth
|
755 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
756 |
+
|
757 |
+
# build Residual Swin Transformer blocks (RSTB)
|
758 |
+
self.layers = nn.ModuleList()
|
759 |
+
for i_layer in range(self.num_layers):
|
760 |
+
layer = RSTB(dim=embed_dim,
|
761 |
+
input_resolution=(patches_resolution[0],
|
762 |
+
patches_resolution[1]),
|
763 |
+
depth=depths[i_layer],
|
764 |
+
num_heads=num_heads[i_layer],
|
765 |
+
window_size=window_size,
|
766 |
+
mlp_ratio=self.mlp_ratio,
|
767 |
+
qkv_bias=qkv_bias,
|
768 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
769 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
770 |
+
norm_layer=norm_layer,
|
771 |
+
downsample=None,
|
772 |
+
use_checkpoint=use_checkpoint,
|
773 |
+
img_size=img_size,
|
774 |
+
patch_size=patch_size,
|
775 |
+
resi_connection=resi_connection
|
776 |
+
|
777 |
+
)
|
778 |
+
self.layers.append(layer)
|
779 |
+
|
780 |
+
if self.upsampler == 'pixelshuffle_hf':
|
781 |
+
self.layers_hf = nn.ModuleList()
|
782 |
+
for i_layer in range(self.num_layers):
|
783 |
+
layer = RSTB(dim=embed_dim,
|
784 |
+
input_resolution=(patches_resolution[0],
|
785 |
+
patches_resolution[1]),
|
786 |
+
depth=depths[i_layer],
|
787 |
+
num_heads=num_heads[i_layer],
|
788 |
+
window_size=window_size,
|
789 |
+
mlp_ratio=self.mlp_ratio,
|
790 |
+
qkv_bias=qkv_bias,
|
791 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
792 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
793 |
+
norm_layer=norm_layer,
|
794 |
+
downsample=None,
|
795 |
+
use_checkpoint=use_checkpoint,
|
796 |
+
img_size=img_size,
|
797 |
+
patch_size=patch_size,
|
798 |
+
resi_connection=resi_connection
|
799 |
+
|
800 |
+
)
|
801 |
+
self.layers_hf.append(layer)
|
802 |
+
|
803 |
+
self.norm = norm_layer(self.num_features)
|
804 |
+
|
805 |
+
# build the last conv layer in deep feature extraction
|
806 |
+
if resi_connection == '1conv':
|
807 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
808 |
+
elif resi_connection == '3conv':
|
809 |
+
# to save parameters and memory
|
810 |
+
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
811 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
812 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
813 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
814 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
815 |
+
|
816 |
+
#####################################################################################################
|
817 |
+
################################ 3, high quality image reconstruction ################################
|
818 |
+
if self.upsampler == 'pixelshuffle':
|
819 |
+
# for classical SR
|
820 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
821 |
+
nn.LeakyReLU(inplace=True))
|
822 |
+
self.upsample = Upsample(upscale, num_feat)
|
823 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
824 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
825 |
+
self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
826 |
+
self.conv_before_upsample = nn.Sequential(
|
827 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
828 |
+
nn.LeakyReLU(inplace=True))
|
829 |
+
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
830 |
+
self.conv_after_aux = nn.Sequential(
|
831 |
+
nn.Conv2d(3, num_feat, 3, 1, 1),
|
832 |
+
nn.LeakyReLU(inplace=True))
|
833 |
+
self.upsample = Upsample(upscale, num_feat)
|
834 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
835 |
+
|
836 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
837 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
838 |
+
nn.LeakyReLU(inplace=True))
|
839 |
+
self.upsample = Upsample(upscale, num_feat)
|
840 |
+
self.upsample_hf = Upsample_hf(upscale, num_feat)
|
841 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
842 |
+
self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
|
843 |
+
nn.LeakyReLU(inplace=True))
|
844 |
+
self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
845 |
+
self.conv_before_upsample_hf = nn.Sequential(
|
846 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
847 |
+
nn.LeakyReLU(inplace=True))
|
848 |
+
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
849 |
+
|
850 |
+
elif self.upsampler == 'pixelshuffledirect':
|
851 |
+
# for lightweight SR (to save parameters)
|
852 |
+
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
853 |
+
(patches_resolution[0], patches_resolution[1]))
|
854 |
+
elif self.upsampler == 'nearest+conv':
|
855 |
+
# for real-world SR (fewer artifacts)
|
856 |
+
assert self.upscale == 4, 'only x4 is supported now.'
|
857 |
+
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
858 |
+
nn.LeakyReLU(inplace=True))
|
859 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
860 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
861 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
862 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
863 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
864 |
+
else:
|
865 |
+
# for image denoising and JPEG compression artifact reduction
|
866 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
867 |
+
|
868 |
+
self.apply(self._init_weights)
|
869 |
+
|
870 |
+
def _init_weights(self, m):
|
871 |
+
if isinstance(m, nn.Linear):
|
872 |
+
trunc_normal_(m.weight, std=.02)
|
873 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
874 |
+
nn.init.constant_(m.bias, 0)
|
875 |
+
elif isinstance(m, nn.LayerNorm):
|
876 |
+
nn.init.constant_(m.bias, 0)
|
877 |
+
nn.init.constant_(m.weight, 1.0)
|
878 |
+
|
879 |
+
@torch.jit.ignore
|
880 |
+
def no_weight_decay(self):
|
881 |
+
return {'absolute_pos_embed'}
|
882 |
+
|
883 |
+
@torch.jit.ignore
|
884 |
+
def no_weight_decay_keywords(self):
|
885 |
+
return {'relative_position_bias_table'}
|
886 |
+
|
887 |
+
def check_image_size(self, x):
|
888 |
+
_, _, h, w = x.size()
|
889 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
890 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
891 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
892 |
+
return x
|
893 |
+
|
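check_image_size reflect-pads the bottom and right edges so that H and W become multiples of window_size before windowing. The same computation in isolation, with an illustrative input size:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 50, 70)
window_size = 8
mod_pad_h = (window_size - x.shape[2] % window_size) % window_size   # 6
mod_pad_w = (window_size - x.shape[3] % window_size) % window_size   # 2
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
print(x.shape)   # torch.Size([1, 3, 56, 72])
```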
894 |
+
def forward_features(self, x):
|
895 |
+
x_size = (x.shape[2], x.shape[3])
|
896 |
+
x = self.patch_embed(x)
|
897 |
+
if self.ape:
|
898 |
+
x = x + self.absolute_pos_embed
|
899 |
+
x = self.pos_drop(x)
|
900 |
+
|
901 |
+
for layer in self.layers:
|
902 |
+
x = layer(x, x_size)
|
903 |
+
|
904 |
+
x = self.norm(x) # B L C
|
905 |
+
x = self.patch_unembed(x, x_size)
|
906 |
+
|
907 |
+
return x
|
908 |
+
|
909 |
+
def forward_features_hf(self, x):
|
910 |
+
x_size = (x.shape[2], x.shape[3])
|
911 |
+
x = self.patch_embed(x)
|
912 |
+
if self.ape:
|
913 |
+
x = x + self.absolute_pos_embed
|
914 |
+
x = self.pos_drop(x)
|
915 |
+
|
916 |
+
for layer in self.layers_hf:
|
917 |
+
x = layer(x, x_size)
|
918 |
+
|
919 |
+
x = self.norm(x) # B L C
|
920 |
+
x = self.patch_unembed(x, x_size)
|
921 |
+
|
922 |
+
return x
|
923 |
+
|
924 |
+
def forward(self, x):
|
925 |
+
H, W = x.shape[2:]
|
926 |
+
x = self.check_image_size(x)
|
927 |
+
|
928 |
+
self.mean = self.mean.type_as(x)
|
929 |
+
x = (x - self.mean) * self.img_range
|
930 |
+
|
931 |
+
if self.upsampler == 'pixelshuffle':
|
932 |
+
# for classical SR
|
933 |
+
x = self.conv_first(x)
|
934 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
935 |
+
x = self.conv_before_upsample(x)
|
936 |
+
x = self.conv_last(self.upsample(x))
|
937 |
+
elif self.upsampler == 'pixelshuffle_aux':
|
938 |
+
bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
|
939 |
+
bicubic = self.conv_bicubic(bicubic)
|
940 |
+
x = self.conv_first(x)
|
941 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
942 |
+
x = self.conv_before_upsample(x)
|
943 |
+
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
|
944 |
+
x = self.conv_after_aux(aux)
|
945 |
+
x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
|
946 |
+
x = self.conv_last(x)
|
947 |
+
aux = aux / self.img_range + self.mean
|
948 |
+
elif self.upsampler == 'pixelshuffle_hf':
|
949 |
+
# for classical SR with HF
|
950 |
+
x = self.conv_first(x)
|
951 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
952 |
+
x_before = self.conv_before_upsample(x)
|
953 |
+
x_out = self.conv_last(self.upsample(x_before))
|
954 |
+
|
955 |
+
x_hf = self.conv_first_hf(x_before)
|
956 |
+
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
957 |
+
x_hf = self.conv_before_upsample_hf(x_hf)
|
958 |
+
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
|
959 |
+
x = x_out + x_hf
|
960 |
+
x_hf = x_hf / self.img_range + self.mean
|
961 |
+
|
962 |
+
elif self.upsampler == 'pixelshuffledirect':
|
963 |
+
# for lightweight SR
|
964 |
+
x = self.conv_first(x)
|
965 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
966 |
+
x = self.upsample(x)
|
967 |
+
elif self.upsampler == 'nearest+conv':
|
968 |
+
# for real-world SR
|
969 |
+
x = self.conv_first(x)
|
970 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
971 |
+
x = self.conv_before_upsample(x)
|
972 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
973 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
974 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
975 |
+
else:
|
976 |
+
# for image denoising and JPEG compression artifact reduction
|
977 |
+
x_first = self.conv_first(x)
|
978 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
979 |
+
x = x + self.conv_last(res)
|
980 |
+
|
981 |
+
x = x / self.img_range + self.mean
|
982 |
+
if self.upsampler == "pixelshuffle_aux":
|
983 |
+
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
984 |
+
|
985 |
+
elif self.upsampler == "pixelshuffle_hf":
|
986 |
+
x_out = x_out / self.img_range + self.mean
|
987 |
+
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
988 |
+
|
989 |
+
else:
|
990 |
+
return x[:, :, :H*self.upscale, :W*self.upscale]
|
991 |
+
|
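Because the input was padded up to a multiple of window_size, every return path above crops back to H*upscale by W*upscale, so the padded border never reaches the output. The bookkeeping as plain arithmetic, with hypothetical sizes:

```python
H, W, window_size, upscale = 50, 70, 8, 4
mod_pad_h = (window_size - H % window_size) % window_size   # 6
mod_pad_w = (window_size - W % window_size) % window_size   # 2
print(H + mod_pad_h, W + mod_pad_w)   # padded to 56 x 72 before the network
print(H * upscale, W * upscale)       # output cropped to 200 x 280
```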
992 |
+
def flops(self):
|
993 |
+
flops = 0
|
994 |
+
H, W = self.patches_resolution
|
995 |
+
flops += H * W * 3 * self.embed_dim * 9
|
996 |
+
flops += self.patch_embed.flops()
|
997 |
+
for i, layer in enumerate(self.layers):
|
998 |
+
flops += layer.flops()
|
999 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
1000 |
+
flops += self.upsample.flops()
|
1001 |
+
return flops
|
1002 |
+
|
1003 |
+
|
1004 |
+
if __name__ == '__main__':
|
1005 |
+
upscale = 4
|
1006 |
+
window_size = 8
|
1007 |
+
height = (1024 // upscale // window_size + 1) * window_size
|
1008 |
+
width = (720 // upscale // window_size + 1) * window_size
|
1009 |
+
model = Swin2SR(upscale=2, img_size=(height, width),
|
1010 |
+
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
1011 |
+
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
1012 |
+
print(model)
|
1013 |
+
print(height, width, model.flops() / 1e9)
|
1014 |
+
|
1015 |
+
x = torch.randn((1, 3, height, width))
|
1016 |
+
x = model(x)
|
1017 |
print(x.shape)
|
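For reference, the demo builds the model with upscale=2 even though the size arithmetic above uses upscale=4, so height = (1024 // 4 // 8 + 1) * 8 = 264 and width = (720 // 4 // 8 + 1) * 8 = 184, and the final print should show the x2 output:

```python
# Expected result of the demo above (a sketch, not part of the diff):
# x.shape == torch.Size([1, 3, 528, 368])   # (1, 3, height * 2, width * 2)
```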
sd/stable-diffusion-webui/html/extra-networks-card.html
CHANGED
@@ -7,6 +7,7 @@
|
|
7 |
<span style="display:none" class='search_term'>{search_term}</span>
|
8 |
</div>
|
9 |
<span class='name'>{name}</span>
|
|
|
10 |
</div>
|
11 |
</div>
|
12 |
|
|
|
7 |
<span style="display:none" class='search_term'>{search_term}</span>
|
8 |
</div>
|
9 |
<span class='name'>{name}</span>
|
10 |
+
<span class='description'>{description}</span>
|
11 |
</div>
|
12 |
</div>
|
13 |
|
sd/stable-diffusion-webui/html/footer.html
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
<div>
|
2 |
-
<a href="/docs">API</a>
|
3 |
-
•
|
4 |
-
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
5 |
-
•
|
6 |
-
<a href="https://gradio.app">Gradio</a>
|
7 |
-
•
|
8 |
-
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
|
9 |
-
</div>
|
10 |
-
<br />
|
11 |
-
<div class="versions">
|
12 |
-
{versions}
|
13 |
-
</div>
|
|
|
1 |
+
<div>
|
2 |
+
<a href="/docs">API</a>
|
3 |
+
•
|
4 |
+
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
5 |
+
•
|
6 |
+
<a href="https://gradio.app">Gradio</a>
|
7 |
+
•
|
8 |
+
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
|
9 |
+
</div>
|
10 |
+
<br />
|
11 |
+
<div class="versions">
|
12 |
+
{versions}
|
13 |
+
</div>
|
sd/stable-diffusion-webui/html/licenses.html
CHANGED
@@ -1,419 +1,638 @@
|
|
1 |
-
<style>
|
2 |
-
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
|
3 |
-
#licenses small {font-size: 0.95em; opacity: 0.85;}
|
4 |
-
#licenses pre { margin: 1em 0 2em 0;}
|
5 |
-
</style>
|
6 |
-
|
7 |
-
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
|
8 |
-
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
|
9 |
-
<pre>
|
10 |
-
S-Lab License 1.0
|
11 |
-
|
12 |
-
Copyright 2022 S-Lab
|
13 |
-
|
14 |
-
Redistribution and use for non-commercial purpose in source and
|
15 |
-
binary forms, with or without modification, are permitted provided
|
16 |
-
that the following conditions are met:
|
17 |
-
|
18 |
-
1. Redistributions of source code must retain the above copyright
|
19 |
-
notice, this list of conditions and the following disclaimer.
|
20 |
-
|
21 |
-
2. Redistributions in binary form must reproduce the above copyright
|
22 |
-
notice, this list of conditions and the following disclaimer in
|
23 |
-
the documentation and/or other materials provided with the
|
24 |
-
distribution.
|
25 |
-
|
26 |
-
3. Neither the name of the copyright holder nor the names of its
|
27 |
-
contributors may be used to endorse or promote products derived
|
28 |
-
from this software without specific prior written permission.
|
29 |
-
|
30 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
31 |
-
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
32 |
-
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
33 |
-
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
34 |
-
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
35 |
-
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
36 |
-
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
37 |
-
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
38 |
-
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
39 |
-
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
40 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
41 |
-
|
42 |
-
In the event that redistribution and/or use for commercial purpose in
|
43 |
-
source or binary forms, with or without modification is required,
|
44 |
-
please contact the contributor(s) of the work.
|
45 |
-
</pre>
|
46 |
-
|
47 |
-
|
48 |
-
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
|
49 |
-
<small>Code for architecture and reading models copied.</small>
|
50 |
-
<pre>
|
51 |
-
MIT License
|
52 |
-
|
53 |
-
Copyright (c) 2021 victorca25
|
54 |
-
|
55 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
56 |
-
of this software and associated documentation files (the "Software"), to deal
|
57 |
-
in the Software without restriction, including without limitation the rights
|
58 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
59 |
-
copies of the Software, and to permit persons to whom the Software is
|
60 |
-
furnished to do so, subject to the following conditions:
|
61 |
-
|
62 |
-
The above copyright notice and this permission notice shall be included in all
|
63 |
-
copies or substantial portions of the Software.
|
64 |
-
|
65 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
66 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
67 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
68 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
69 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
70 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
71 |
-
SOFTWARE.
|
72 |
-
</pre>
|
73 |
-
|
74 |
-
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
|
75 |
-
<small>Some code is copied to support ESRGAN models.</small>
|
76 |
-
<pre>
|
77 |
-
BSD 3-Clause License
|
78 |
-
|
79 |
-
Copyright (c) 2021, Xintao Wang
|
80 |
-
All rights reserved.
|
81 |
-
|
82 |
-
Redistribution and use in source and binary forms, with or without
|
83 |
-
modification, are permitted provided that the following conditions are met:
|
84 |
-
|
85 |
-
1. Redistributions of source code must retain the above copyright notice, this
|
86 |
-
list of conditions and the following disclaimer.
|
87 |
-
|
88 |
-
2. Redistributions in binary form must reproduce the above copyright notice,
|
89 |
-
this list of conditions and the following disclaimer in the documentation
|
90 |
-
and/or other materials provided with the distribution.
|
91 |
-
|
92 |
-
3. Neither the name of the copyright holder nor the names of its
|
93 |
-
contributors may be used to endorse or promote products derived from
|
94 |
-
this software without specific prior written permission.
|
95 |
-
|
96 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
97 |
-
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
98 |
-
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
99 |
-
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
100 |
-
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
101 |
-
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
102 |
-
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
103 |
-
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
104 |
-
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
105 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
106 |
-
</pre>
|
107 |
-
|
108 |
-
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
|
109 |
-
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
|
110 |
-
<pre>
|
111 |
-
MIT License
|
112 |
-
|
113 |
-
Copyright (c) 2022 InvokeAI Team
|
114 |
-
|
115 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
116 |
-
of this software and associated documentation files (the "Software"), to deal
|
117 |
-
in the Software without restriction, including without limitation the rights
|
118 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
119 |
-
copies of the Software, and to permit persons to whom the Software is
|
120 |
-
furnished to do so, subject to the following conditions:
|
121 |
-
|
122 |
-
The above copyright notice and this permission notice shall be included in all
|
123 |
-
copies or substantial portions of the Software.
|
124 |
-
|
125 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
126 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
127 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
128 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
129 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
130 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
131 |
-
SOFTWARE.
|
132 |
-
</pre>
|
133 |
-
|
134 |
-
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
|
135 |
-
<small>Code added by contributors, most likely copied from this repository.</small>
|
136 |
-
<pre>
|
137 |
-
MIT License
|
138 |
-
|
139 |
-
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
140 |
-
|
141 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
142 |
-
of this software and associated documentation files (the "Software"), to deal
|
143 |
-
in the Software without restriction, including without limitation the rights
|
144 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
145 |
-
copies of the Software, and to permit persons to whom the Software is
|
146 |
-
furnished to do so, subject to the following conditions:
|
147 |
-
|
148 |
-
The above copyright notice and this permission notice shall be included in all
|
149 |
-
copies or substantial portions of the Software.
|
150 |
-
|
151 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
152 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
153 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
154 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
155 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
156 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
157 |
-
SOFTWARE.
|
158 |
-
</pre>
|
159 |
-
|
160 |
-
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
|
161 |
-
<small>Some small amounts of code borrowed and reworked.</small>
|
162 |
-
<pre>
|
163 |
-
MIT License
|
164 |
-
|
165 |
-
Copyright (c) 2022 pharmapsychotic
|
166 |
-
|
167 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
168 |
-
of this software and associated documentation files (the "Software"), to deal
|
169 |
-
in the Software without restriction, including without limitation the rights
|
170 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
171 |
-
copies of the Software, and to permit persons to whom the Software is
|
172 |
-
furnished to do so, subject to the following conditions:
|
173 |
-
|
174 |
-
The above copyright notice and this permission notice shall be included in all
|
175 |
-
copies or substantial portions of the Software.
|
176 |
-
|
177 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
178 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
179 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
180 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
181 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
182 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
183 |
-
SOFTWARE.
|
184 |
-
</pre>
|
185 |
-
|
186 |
-
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
187 |
-
<small>Code added by contributors, most likely copied from this repository.</small>
|
188 |
-
|
189 |
-
<pre>
|
190 |
-
Apache License
|
191 |
-
Version 2.0, January 2004
|
192 |
-
http://www.apache.org/licenses/
|
193 |
-
|
194 |
-
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
195 |
-
|
196 |
-
1. Definitions.
|
197 |
-
|
198 |
-
"License" shall mean the terms and conditions for use, reproduction,
|
199 |
-
and distribution as defined by Sections 1 through 9 of this document.
|
200 |
-
|
201 |
-
"Licensor" shall mean the copyright owner or entity authorized by
|
202 |
-
the copyright owner that is granting the License.
|
203 |
-
|
204 |
-
"Legal Entity" shall mean the union of the acting entity and all
|
205 |
-
other entities that control, are controlled by, or are under common
|
206 |
-
control with that entity. For the purposes of this definition,
|
207 |
-
"control" means (i) the power, direct or indirect, to cause the
|
208 |
-
direction or management of such entity, whether by contract or
|
209 |
-
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
210 |
-
outstanding shares, or (iii) beneficial ownership of such entity.
|
211 |
-
|
212 |
-
"You" (or "Your") shall mean an individual or Legal Entity
|
213 |
-
exercising permissions granted by this License.
|
214 |
-
|
215 |
-
"Source" form shall mean the preferred form for making modifications,
|
216 |
-
including but not limited to software source code, documentation
|
217 |
-
source, and configuration files.
|
218 |
-
|
219 |
-
"Object" form shall mean any form resulting from mechanical
|
220 |
-
transformation or translation of a Source form, including but
|
221 |
-
not limited to compiled object code, generated documentation,
|
222 |
-
and conversions to other media types.
|
223 |
-
|
224 |
-
"Work" shall mean the work of authorship, whether in Source or
|
225 |
-
Object form, made available under the License, as indicated by a
|
226 |
-
copyright notice that is included in or attached to the work
|
227 |
-
(an example is provided in the Appendix below).
|
228 |
-
|
229 |
-
"Derivative Works" shall mean any work, whether in Source or Object
|
230 |
-
form, that is based on (or derived from) the Work and for which the
|
231 |
-
editorial revisions, annotations, elaborations, or other modifications
|
232 |
-
represent, as a whole, an original work of authorship. For the purposes
|
233 |
-
of this License, Derivative Works shall not include works that remain
|
234 |
-
separable from, or merely link (or bind by name) to the interfaces of,
|
235 |
-
the Work and Derivative Works thereof.
|
236 |
-
|
237 |
-
"Contribution" shall mean any work of authorship, including
|
238 |
-
the original version of the Work and any modifications or additions
|
239 |
-
to that Work or Derivative Works thereof, that is intentionally
|
240 |
-
submitted to Licensor for inclusion in the Work by the copyright owner
|
241 |
-
or by an individual or Legal Entity authorized to submit on behalf of
|
242 |
-
the copyright owner. For the purposes of this definition, "submitted"
|
243 |
-
means any form of electronic, verbal, or written communication sent
|
244 |
-
to the Licensor or its representatives, including but not limited to
|
245 |
-
communication on electronic mailing lists, source code control systems,
|
246 |
-
and issue tracking systems that are managed by, or on behalf of, the
|
247 |
-
Licensor for the purpose of discussing and improving the Work, but
|
248 |
-
excluding communication that is conspicuously marked or otherwise
|
249 |
-
designated in writing by the copyright owner as "Not a Contribution."
|
250 |
-
|
251 |
-
"Contributor" shall mean Licensor and any individual or Legal Entity
|
252 |
-
on behalf of whom a Contribution has been received by Licensor and
|
253 |
-
subsequently incorporated within the Work.
|
254 |
-
|
255 |
-
2. Grant of Copyright License. Subject to the terms and conditions of
|
256 |
-
this License, each Contributor hereby grants to You a perpetual,
|
257 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
258 |
-
copyright license to reproduce, prepare Derivative Works of,
|
259 |
-
publicly display, publicly perform, sublicense, and distribute the
|
260 |
-
Work and such Derivative Works in Source or Object form.
|
261 |
-
|
262 |
-
3. Grant of Patent License. Subject to the terms and conditions of
|
263 |
-
this License, each Contributor hereby grants to You a perpetual,
|
264 |
-
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
265 |
-
(except as stated in this section) patent license to make, have made,
|
266 |
-
use, offer to sell, sell, import, and otherwise transfer the Work,
|
267 |
-
where such license applies only to those patent claims licensable
|
268 |
-
by such Contributor that are necessarily infringed by their
|
269 |
-
Contribution(s) alone or by combination of their Contribution(s)
|
270 |
-
with the Work to which such Contribution(s) was submitted. If You
|
271 |
-
institute patent litigation against any entity (including a
|
272 |
-
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
273 |
-
or a Contribution incorporated within the Work constitutes direct
|
274 |
-
or contributory patent infringement, then any patent licenses
|
275 |
-
granted to You under this License for that Work shall terminate
|
276 |
-
as of the date such litigation is filed.
|
277 |
-
|
278 |
-
4. Redistribution. You may reproduce and distribute copies of the
|
279 |
-
Work or Derivative Works thereof in any medium, with or without
|
280 |
-
modifications, and in Source or Object form, provided that You
|
281 |
-
meet the following conditions:
|
282 |
-
|
283 |
-
(a) You must give any other recipients of the Work or
|
284 |
-
Derivative Works a copy of this License; and
|
285 |
-
|
286 |
-
(b) You must cause any modified files to carry prominent notices
|
287 |
-
stating that You changed the files; and
|
288 |
-
|
289 |
-
(c) You must retain, in the Source form of any Derivative Works
|
290 |
-
that You distribute, all copyright, patent, trademark, and
|
291 |
-
attribution notices from the Source form of the Work,
|
292 |
-
excluding those notices that do not pertain to any part of
|
293 |
-
the Derivative Works; and
|
294 |
-
|
295 |
-
(d) If the Work includes a "NOTICE" text file as part of its
|
296 |
-
distribution, then any Derivative Works that You distribute must
|
297 |
-
include a readable copy of the attribution notices contained
|
298 |
-
within such NOTICE file, excluding those notices that do not
|
299 |
-
pertain to any part of the Derivative Works, in at least one
|
300 |
-
of the following places: within a NOTICE text file distributed
|
301 |
-
as part of the Derivative Works; within the Source form or
|
302 |
-
documentation, if provided along with the Derivative Works; or,
|
303 |
-
within a display generated by the Derivative Works, if and
|
304 |
-
wherever such third-party notices normally appear. The contents
|
305 |
-
of the NOTICE file are for informational purposes only and
|
306 |
-
do not modify the License. You may add Your own attribution
|
307 |
-
notices within Derivative Works that You distribute, alongside
|
308 |
-
or as an addendum to the NOTICE text from the Work, provided
|
309 |
-
that such additional attribution notices cannot be construed
|
310 |
-
as modifying the License.
|
311 |
-
|
312 |
-
You may add Your own copyright statement to Your modifications and
|
313 |
-
may provide additional or different license terms and conditions
|
314 |
-
for use, reproduction, or distribution of Your modifications, or
|
315 |
-
for any such Derivative Works as a whole, provided Your use,
|
316 |
-
reproduction, and distribution of the Work otherwise complies with
|
317 |
-
the conditions stated in this License.
|
318 |
-
|
319 |
-
5. Submission of Contributions. Unless You explicitly state otherwise,
|
320 |
-
any Contribution intentionally submitted for inclusion in the Work
|
321 |
-
by You to the Licensor shall be under the terms and conditions of
|
322 |
-
this License, without any additional terms or conditions.
|
323 |
-
Notwithstanding the above, nothing herein shall supersede or modify
|
324 |
-
the terms of any separate license agreement you may have executed
|
325 |
-
with Licensor regarding such Contributions.
|
326 |
-
|
327 |
-
6. Trademarks. This License does not grant permission to use the trade
|
328 |
-
names, trademarks, service marks, or product names of the Licensor,
|
329 |
-
except as required for reasonable and customary use in describing the
|
330 |
-
origin of the Work and reproducing the content of the NOTICE file.
|
331 |
-
|
332 |
-
7. Disclaimer of Warranty. Unless required by applicable law or
|
333 |
-
agreed to in writing, Licensor provides the Work (and each
|
334 |
-
Contributor provides its Contributions) on an "AS IS" BASIS,
|
335 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
336 |
-
implied, including, without limitation, any warranties or conditions
|
337 |
-
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
338 |
-
PARTICULAR PURPOSE. You are solely responsible for determining the
|
339 |
-
appropriateness of using or redistributing the Work and assume any
|
340 |
-
risks associated with Your exercise of permissions under this License.
|
341 |
-
|
342 |
-
8. Limitation of Liability. In no event and under no legal theory,
|
343 |
-
whether in tort (including negligence), contract, or otherwise,
|
344 |
-
unless required by applicable law (such as deliberate and grossly
|
345 |
-
negligent acts) or agreed to in writing, shall any Contributor be
|
346 |
-
liable to You for damages, including any direct, indirect, special,
|
347 |
-
incidental, or consequential damages of any character arising as a
|
348 |
-
result of this License or out of the use or inability to use the
|
349 |
-
Work (including but not limited to damages for loss of goodwill,
|
350 |
-
work stoppage, computer failure or malfunction, or any and all
|
351 |
-
other commercial damages or losses), even if such Contributor
|
352 |
-
has been advised of the possibility of such damages.
|
353 |
-
|
354 |
-
9. Accepting Warranty or Additional Liability. While redistributing
|
355 |
-
the Work or Derivative Works thereof, You may choose to offer,
|
356 |
-
and charge a fee for, acceptance of support, warranty, indemnity,
|
357 |
-
or other liability obligations and/or rights consistent with this
|
358 |
-
License. However, in accepting such obligations, You may act only
|
359 |
-
on Your own behalf and on Your sole responsibility, not on behalf
|
360 |
-
of any other Contributor, and only if You agree to indemnify,
|
361 |
-
defend, and hold each Contributor harmless for any liability
|
362 |
-
incurred by, or claims asserted against, such Contributor by reason
|
363 |
-
of your accepting any such warranty or additional liability.
|
364 |
-
|
365 |
-
END OF TERMS AND CONDITIONS
|
366 |
-
|
367 |
-
APPENDIX: How to apply the Apache License to your work.
|
368 |
-
|
369 |
-
To apply the Apache License to your work, attach the following
|
370 |
-
boilerplate notice, with the fields enclosed by brackets "[]"
|
371 |
-
replaced with your own identifying information. (Don't include
|
372 |
-
the brackets!) The text should be enclosed in the appropriate
|
373 |
-
comment syntax for the file format. We also recommend that a
|
374 |
-
file or class name and description of purpose be included on the
|
375 |
-
same "printed page" as the copyright notice for easier
|
376 |
-
identification within third-party archives.
|
377 |
-
|
378 |
-
Copyright [2021] [SwinIR Authors]
|
379 |
-
|
380 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
381 |
-
you may not use this file except in compliance with the License.
|
382 |
-
You may obtain a copy of the License at
|
383 |
-
|
384 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
385 |
-
|
386 |
-
Unless required by applicable law or agreed to in writing, software
|
387 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
388 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
389 |
-
See the License for the specific language governing permissions and
|
390 |
-
limitations under the License.
|
391 |
-
</pre>
|
392 |
-
|
393 |
-
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
|
394 |
-
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
|
395 |
-
<pre>
|
396 |
-
MIT License
|
397 |
-
|
398 |
-
Copyright (c) 2023 Alex Birch
|
399 |
-
Copyright (c) 2023 Amin Rezaei
|
400 |
-
|
401 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
402 |
-
of this software and associated documentation files (the "Software"), to deal
|
403 |
-
in the Software without restriction, including without limitation the rights
|
404 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
405 |
-
copies of the Software, and to permit persons to whom the Software is
|
406 |
-
furnished to do so, subject to the following conditions:
|
407 |
-
|
408 |
-
The above copyright notice and this permission notice shall be included in all
|
409 |
-
copies or substantial portions of the Software.
|
410 |
-
|
411 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
412 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
413 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
414 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
415 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
416 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
417 |
-
SOFTWARE.
|
418 |
-
</pre>
|
419 |
-
|
<style>
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
#licenses small {font-size: 0.95em; opacity: 0.85;}
#licenses pre { margin: 1em 0 2em 0;}
</style>

<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
<pre>
S-Lab License 1.0

Copyright 2022 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
</pre>

<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
<small>Code for architecture and reading models copied.</small>
<pre>
MIT License

Copyright (c) 2021 victorca25

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
<small>Some code is copied to support ESRGAN models.</small>
<pre>
BSD 3-Clause License

Copyright (c) 2021, Xintao Wang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
</pre>

<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
<pre>
MIT License

Copyright (c) 2022 InvokeAI Team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
<small>Code added by contributors, most likely copied from this repository.</small>
<pre>
MIT License

Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
<small>Some small amounts of code borrowed and reworked.</small>
<pre>
MIT License

Copyright (c) 2022 pharmapsychotic

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contributors, most likely copied from this repository.</small>

<pre>
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2021] [SwinIR Authors]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
</pre>

<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
<pre>
MIT License

Copyright (c) 2023 Alex Birch
Copyright (c) 2023 Amin Rezaei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

<h2><a href="https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/LICENSE">Scaled Dot Product Attention</a></h2>
<small>Some small amounts of code borrowed and reworked.</small>
<pre>
Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
</pre>

sd/stable-diffusion-webui/javascript/aspectRatioOverlay.js
CHANGED
@@ -1,113 +1,113 @@
|
|
1 |
-
|
2 |
-
let currentWidth = null;
|
3 |
-
let currentHeight = null;
|
4 |
-
let arFrameTimeout = setTimeout(function(){},0);
|
5 |
-
|
6 |
-
function dimensionChange(e, is_width, is_height){
|
7 |
-
|
8 |
-
if(is_width){
|
9 |
-
currentWidth = e.target.value*1.0
|
10 |
-
}
|
11 |
-
if(is_height){
|
12 |
-
currentHeight = e.target.value*1.0
|
13 |
-
}
|
14 |
-
|
15 |
-
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
|
16 |
-
|
17 |
-
if(!inImg2img){
|
18 |
-
return;
|
19 |
-
}
|
20 |
(javascript/aspectRatioOverlay.js — the old and new versions are identical; the file content is listed once below.)

let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0);

function dimensionChange(e, is_width, is_height){

    if(is_width){
        currentWidth = e.target.value*1.0
    }
    if(is_height){
        currentHeight = e.target.value*1.0
    }

    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))

    if(!inImg2img){
        return;
    }

    var targetElement = null;

    var tabIndex = get_tab_index('mode_img2img')
    if(tabIndex == 0){ // img2img
        targetElement = gradioApp().querySelector('div[data-testid=image] img');
    } else if(tabIndex == 1){ //Sketch
        targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
    } else if(tabIndex == 2){ // Inpaint
        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
    } else if(tabIndex == 3){ // Inpaint sketch
        targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
    }

    if(targetElement){

        var arPreviewRect = gradioApp().querySelector('#imageARPreview');
        if(!arPreviewRect){
            arPreviewRect = document.createElement('div')
            arPreviewRect.id = "imageARPreview";
            gradioApp().getRootNode().appendChild(arPreviewRect)
        }

        var viewportOffset = targetElement.getBoundingClientRect();

        viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )

        scaledx = targetElement.naturalWidth*viewportscale
        scaledy = targetElement.naturalHeight*viewportscale

        cleintRectTop = (viewportOffset.top+window.scrollY)
        cleintRectLeft = (viewportOffset.left+window.scrollX)
        cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
        cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)

        viewRectTop = cleintRectCentreY-(scaledy/2)
        viewRectLeft = cleintRectCentreX-(scaledx/2)
        arRectWidth = scaledx
        arRectHeight = scaledy

        arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
        arscaledx = currentWidth*arscale
        arscaledy = currentHeight*arscale

        arRectTop = cleintRectCentreY-(arscaledy/2)
        arRectLeft = cleintRectCentreX-(arscaledx/2)
        arRectWidth = arscaledx
        arRectHeight = arscaledy

        arPreviewRect.style.top = arRectTop+'px';
        arPreviewRect.style.left = arRectLeft+'px';
        arPreviewRect.style.width = arRectWidth+'px';
        arPreviewRect.style.height = arRectHeight+'px';

        clearTimeout(arFrameTimeout);
        arFrameTimeout = setTimeout(function(){
            arPreviewRect.style.display = 'none';
        },2000);

        arPreviewRect.style.display = 'block';

    }

}

onUiUpdate(function(){
    var arPreviewRect = gradioApp().querySelector('#imageARPreview');
    if(arPreviewRect){
        arPreviewRect.style.display = 'none';
    }
    var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
    if(inImg2img){
        let inputs = gradioApp().querySelectorAll('input');
        inputs.forEach(function(e){
            var is_width = e.parentElement.id == "img2img_width"
            var is_height = e.parentElement.id == "img2img_height"

            if((is_width || is_height) && !e.classList.contains('scrollwatch')){
                e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
                e.classList.add('scrollwatch')
            }
            if(is_width){
                currentWidth = e.value*1.0
            }
            if(is_height){
                currentHeight = e.value*1.0
            }
        })
    }
});
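The overlay math in dimensionChange is a double aspect-ratio fit: the image's natural size is first fitted into its on-screen box, and the requested width/height is then fitted into that rectangle. A minimal standalone sketch of the same computation (function and variable names are illustrative, not from the file):

// Scale (w, h) to the largest size that fits inside (maxW, maxH) while
// preserving aspect ratio -- the same Math.min trick used above.
function fitBox(w, h, maxW, maxH){
    const scale = Math.min(maxW / w, maxH / h);
    return { width: w * scale, height: h * scale };
}

// Example: a 1024x768 image shown in a 512x512 viewport, previewing a 768x512 target.
const shown = fitBox(1024, 768, 512, 512);                    // { width: 512, height: 384 }
const preview = fitBox(768, 512, shown.width, shown.height);  // { width: 512, height: ~341.3 }
console.log(preview);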
sd/stable-diffusion-webui/javascript/contextMenus.js
CHANGED
@@ -1,177 +1,177 @@
(The old and new versions are identical; the file content is listed once below.)

contextMenuInit = function(){
    let eventListenerApplied=false;
    let menuSpecs = new Map();

    const uid = function(){
        return Date.now().toString(36) + Math.random().toString(36).substr(2);
    }

    function showContextMenu(event,element,menuEntries){
        let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
        let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;

        let oldMenu = gradioApp().querySelector('#context-menu')
        if(oldMenu){
            oldMenu.remove()
        }

        let tabButton = uiCurrentTab
        let baseStyle = window.getComputedStyle(tabButton)

        const contextMenu = document.createElement('nav')
        contextMenu.id = "context-menu"
        contextMenu.style.background = baseStyle.background
        contextMenu.style.color = baseStyle.color
        contextMenu.style.fontFamily = baseStyle.fontFamily
        contextMenu.style.top = posy+'px'
        contextMenu.style.left = posx+'px'

        const contextMenuList = document.createElement('ul')
        contextMenuList.className = 'context-menu-items';
        contextMenu.append(contextMenuList);

        menuEntries.forEach(function(entry){
            let contextMenuEntry = document.createElement('a')
            contextMenuEntry.innerHTML = entry['name']
            contextMenuEntry.addEventListener("click", function(e) {
                entry['func']();
            })
            contextMenuList.append(contextMenuEntry);
        })

        gradioApp().getRootNode().appendChild(contextMenu)

        let menuWidth = contextMenu.offsetWidth + 4;
        let menuHeight = contextMenu.offsetHeight + 4;

        let windowWidth = window.innerWidth;
        let windowHeight = window.innerHeight;

        if ( (windowWidth - posx) < menuWidth ) {
            contextMenu.style.left = windowWidth - menuWidth + "px";
        }

        if ( (windowHeight - posy) < menuHeight ) {
            contextMenu.style.top = windowHeight - menuHeight + "px";
        }
    }

    function appendContextMenuOption(targetElementSelector,entryName,entryFunction){

        currentItems = menuSpecs.get(targetElementSelector)

        if(!currentItems){
            currentItems = []
            menuSpecs.set(targetElementSelector,currentItems);
        }
        let newItem = {'id':targetElementSelector+'_'+uid(),
                       'name':entryName,
                       'func':entryFunction,
                       'isNew':true}

        currentItems.push(newItem)
        return newItem['id']
    }

    function removeContextMenuOption(uid){
        menuSpecs.forEach(function(v,k) {
            let index = -1
            v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
            if(index>=0){
                v.splice(index, 1);
            }
        })
    }

    function addContextMenuEventListener(){
        if(eventListenerApplied){
            return;
        }
        gradioApp().addEventListener("click", function(e) {
            let source = e.composedPath()[0]
            if(source.id && source.id.indexOf('check_progress')>-1){
                return
            }

            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
        });
        gradioApp().addEventListener("contextmenu", function(e) {
            let oldMenu = gradioApp().querySelector('#context-menu')
            if(oldMenu){
                oldMenu.remove()
            }
            menuSpecs.forEach(function(v,k) {
                if(e.composedPath()[0].matches(k)){
                    showContextMenu(e,e.composedPath()[0],v)
                    e.preventDefault()
                    return
                }
            })
        });
        eventListenerApplied=true
    }

    return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
}

initResponse = contextMenuInit();
appendContextMenuOption = initResponse[0];
removeContextMenuOption = initResponse[1];
addContextMenuEventListener = initResponse[2];

(function(){
    //Start example Context Menu Items
    let generateOnRepeat = function(genbuttonid,interruptbuttonid){
        let genbutton = gradioApp().querySelector(genbuttonid);
        let interruptbutton = gradioApp().querySelector(interruptbuttonid);
        if(!interruptbutton.offsetParent){
            genbutton.click();
        }
        clearInterval(window.generateOnRepeatInterval)
        window.generateOnRepeatInterval = setInterval(function(){
            if(!interruptbutton.offsetParent){
                genbutton.click();
            }
        },
        500)
    }

    appendContextMenuOption('#txt2img_generate','Generate forever',function(){
        generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
    })
    appendContextMenuOption('#img2img_generate','Generate forever',function(){
        generateOnRepeat('#img2img_generate','#img2img_interrupt');
    })

    let cancelGenerateForever = function(){
        clearInterval(window.generateOnRepeatInterval)
    }

    appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
    appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)

    appendContextMenuOption('#roll','Roll three',
        function(){
            let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
            setTimeout(function(){rollbutton.click()},100)
            setTimeout(function(){rollbutton.click()},200)
            setTimeout(function(){rollbutton.click()},300)
        }
    )
})();
//End example Context Menu Items

onUiUpdate(function(){
    addContextMenuEventListener()
});
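Since contextMenuInit exposes its registry through the returned triple, other scripts can add and remove entries the same way the built-in "Generate forever" items do. A minimal sketch (selector and handler are hypothetical):

// Hypothetical: add a right-click entry on the txt2img Generate button,
// then remove it again using the id that appendContextMenuOption returns.
let entryId = appendContextMenuOption('#txt2img_generate', 'Say hello', function(){
    console.log('hello from a custom context-menu entry');
});
removeContextMenuOption(entryId);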
sd/stable-diffusion-webui/javascript/edit-attention.js
CHANGED
@@ -1,96 +1,96 @@
(The old and new versions are identical; the file content is listed once below.)

function keyupEditAttention(event){
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
    if (! (event.metaKey || event.ctrlKey)) return;

    let isPlus = event.key == "ArrowUp"
    let isMinus = event.key == "ArrowDown"
    if (!isPlus && !isMinus) return;

    let selectionStart = target.selectionStart;
    let selectionEnd = target.selectionEnd;
    let text = target.value;

    function selectCurrentParenthesisBlock(OPEN, CLOSE){
        if (selectionStart !== selectionEnd) return false;

        // Find opening parenthesis around current cursor
        const before = text.substring(0, selectionStart);
        let beforeParen = before.lastIndexOf(OPEN);
        if (beforeParen == -1) return false;
        let beforeParenClose = before.lastIndexOf(CLOSE);
        while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
            beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
            beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
        }

        // Find closing parenthesis around current cursor
        const after = text.substring(selectionStart);
        let afterParen = after.indexOf(CLOSE);
        if (afterParen == -1) return false;
        let afterParenOpen = after.indexOf(OPEN);
        while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
            afterParen = after.indexOf(CLOSE, afterParen + 1);
            afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
        }
        if (beforeParen === -1 || afterParen === -1) return false;

        // Set the selection to the text between the parenthesis
        const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
        const lastColon = parenContent.lastIndexOf(":");
        selectionStart = beforeParen + 1;
        selectionEnd = selectionStart + lastColon;
        target.setSelectionRange(selectionStart, selectionEnd);
        return true;
    }

    // If the user hasn't selected anything, let's select their current parenthesis block
    if(! selectCurrentParenthesisBlock('<', '>')){
        selectCurrentParenthesisBlock('(', ')')
    }

    event.preventDefault();

    closeCharacter = ')'
    delta = opts.keyedit_precision_attention

    if (selectionStart > 0 && text[selectionStart - 1] == '<'){
        closeCharacter = '>'
        delta = opts.keyedit_precision_extra
    } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {

        // do not include spaces at the end
        while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
            selectionEnd -= 1;
        }
        if(selectionStart == selectionEnd){
            return
        }

        text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);

        selectionStart += 1;
        selectionEnd += 1;
    }

    end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
    weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
    if (isNaN(weight)) return;

    weight += isPlus ? delta : -delta;
    weight = parseFloat(weight.toPrecision(12));
    if(String(weight).length == 1) weight += ".0"

    text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);

    target.focus();
    target.value = text;
    target.selectionStart = selectionStart;
    target.selectionEnd = selectionEnd;

    updateInput(target)
}

addEventListener('keydown', (event) => {
    keyupEditAttention(event);
});
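To make the behavior concrete, a hypothetical before/after, assuming the default keyedit_precision_attention step of 0.1:

//   "a photo of a cat"         select "cat", press Ctrl+ArrowUp
//   "a photo of a (cat:1.1)"   the word is wrapped at weight 1.0, then stepped up
// With the cursor inside an existing (...) or <...> block and nothing selected,
// selectCurrentParenthesisBlock re-selects the block contents, and each further
// keypress only adjusts the trailing :weight value.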
sd/stable-diffusion-webui/javascript/extensions.js
CHANGED
@@ -1,49 +1,49 @@
(The old and new versions are identical; the file content is listed once below.)

function extensions_apply(_, _){
    var disable = []
    var update = []

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))

        if(x.name.startsWith("update_") && x.checked)
            update.push(x.name.substr(7))
    })

    restart_reload()

    return [JSON.stringify(disable), JSON.stringify(update)]
}

function extensions_check(){
    var disable = []

    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
        if(x.name.startsWith("enable_") && ! x.checked)
            disable.push(x.name.substr(7))
    })

    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
        x.innerHTML = "Loading..."
    })

    var id = randomId()
    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
    })

    return [id, JSON.stringify(disable)]
}

function install_extension_from_index(button, url){
    button.disabled = "disabled"
    button.value = "Installing..."

    textarea = gradioApp().querySelector('#extension_to_install textarea')
    textarea.value = url
    updateInput(textarea)

    gradioApp().querySelector('#install_extension_button').click()
}
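Both checkbox prefixes, enable_ and update_, are seven characters long, which is what the substr(7) calls rely on to recover the extension name. A tiny sketch of the convention (extension name hypothetical):

// A checkbox named "enable_<name>" toggles the extension; "update_<name>"
// marks it for a git update when settings are applied.
["enable_my-extension", "update_my-extension"].forEach(function(name){
    console.log(name.substr(7));  // -> "my-extension" both times
});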
sd/stable-diffusion-webui/javascript/extraNetworks.js
CHANGED
@@ -1,107 +1,107 @@
(Removed and re-added; the only captured difference is at line 81, where the separator inserted before a card's text became the configurable opts.extra_networks_add_text_separator. The removed line 81 appears truncated in this capture as "textarea.value = textarea.value +". The re-added version is listed once below.)

function setupExtraNetworksForTab(tabname){
    gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')

    var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
    var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
    var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
    var close = gradioApp().getElementById(tabname+'_extra_close')

    search.classList.add('search')
    tabs.appendChild(search)
    tabs.appendChild(refresh)
    tabs.appendChild(close)

    search.addEventListener("input", function(evt){
        searchTerm = search.value.toLowerCase()

        gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
            text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
            elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
        })
    });
}

var activePromptTextarea = {};

function setupExtraNetworks(){
    setupExtraNetworksForTab('txt2img')
    setupExtraNetworksForTab('img2img')

    function registerPrompt(tabname, id){
        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");

        if (! activePromptTextarea[tabname]){
            activePromptTextarea[tabname] = textarea
        }

        textarea.addEventListener("focus", function(){
            activePromptTextarea[tabname] = textarea;
        });
    }

    registerPrompt('txt2img', 'txt2img_prompt')
    registerPrompt('txt2img', 'txt2img_neg_prompt')
    registerPrompt('img2img', 'img2img_prompt')
    registerPrompt('img2img', 'img2img_neg_prompt')
}

onUiLoaded(setupExtraNetworks)

var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;

function tryToRemoveExtraNetworkFromPrompt(textarea, text){
    var m = text.match(re_extranet)
    if(! m) return false

    var partToSearch = m[1]
    var replaced = false
    var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
        m = found.match(re_extranet);
        if(m[1] == partToSearch){
            replaced = true;
            return ""
        }
        return found;
    })

    if(replaced){
        textarea.value = newTextareaText
        return true;
    }

    return false
}

function cardClicked(tabname, textToAdd, allowNegativePrompt){
    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")

    if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
    }

    updateInput(textarea)
}

function saveCardPreview(event, tabname, filename){
    var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
    var button = gradioApp().getElementById(tabname + '_save_preview')

    textarea.value = filename
    updateInput(textarea)

    button.click()

    event.stopPropagation()
    event.preventDefault()
}

function extraNetworksSearchButton(tabs_id, event){
    searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
    button = event.target
    text = button.classList.contains("search-all") ? "" : button.textContent.trim()

    searchTextarea.value = text
    updateInput(searchTextarea)
}
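The two regexes drive the toggle behavior: clicking a card whose <type:name:weight> tag is already in the prompt removes it instead of appending it again. A quick sketch of what they capture (tag values illustrative):

var m = "masterpiece <lora:myLora:0.8>".match(re_extranet);
console.log(m[1]);  // -> "lora:myLora" -- the weight is deliberately excluded
                    // from the captured group, so a re-click removes the tag
                    // at any weight.
// re_extranet_g additionally matches the leading whitespace (\s+), so the
// separator in front of the tag is swallowed when the tag is removed.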
sd/stable-diffusion-webui/javascript/hints.js
CHANGED
@@ -6,6 +6,7 @@ titles = {
(Only change: a tooltip for the new UniPC sampler is added.)

 "GFPGAN": "Restore low quality faces using GFPGAN neural network",
 "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
 "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+"UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
 "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",

 "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
sd/stable-diffusion-webui/javascript/hires_fix.js
CHANGED
@@ -1,22 +1,22 @@
(The old and new versions are identical; the file content is listed once below.)

function setInactive(elem, inactive){
    if(inactive){
        elem.classList.add('inactive')
    } else{
        elem.classList.remove('inactive')
    }
}

function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
    hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
    hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
    hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')

    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""

    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)

    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
}
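The three setInactive calls encode the precedence between the hires-fix controls: explicit resize targets override "upscale by", and a resize box set to 0 is unused. A condensed restatement of the rule (illustrative, not from the file):

function hiresControlState(useOldBehavior, hrResizeX, hrResizeY){
    return {
        upscaleByInactive: useOldBehavior || hrResizeX > 0 || hrResizeY > 0,
        resizeXInactive:   useOldBehavior || hrResizeX == 0,
        resizeYInactive:   useOldBehavior || hrResizeY == 0,
    };
}
console.log(hiresControlState(false, 1024, 0));
// -> { upscaleByInactive: true, resizeXInactive: false, resizeYInactive: true }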
sd/stable-diffusion-webui/javascript/localization.js
CHANGED
@@ -1,165 +1,165 @@
(The old and new versions are identical; the file content is listed once below.)

// localization = {} -- the dict with translations is created by the backend

ignore_ids_for_localization={
    setting_sd_hypernetwork: 'OPTION',
    setting_sd_model_checkpoint: 'OPTION',
    setting_realesrgan_enabled_models: 'OPTION',
    modelmerger_primary_model_name: 'OPTION',
    modelmerger_secondary_model_name: 'OPTION',
    modelmerger_tertiary_model_name: 'OPTION',
    train_embedding: 'OPTION',
    train_hypernetwork: 'OPTION',
    txt2img_styles: 'OPTION',
    img2img_styles: 'OPTION',
    setting_random_artist_categories: 'SPAN',
    setting_face_restoration_model: 'SPAN',
    setting_realesrgan_enabled_models: 'SPAN',
    extras_upscaler_1: 'SPAN',
    extras_upscaler_2: 'SPAN',
}

re_num = /^[\.\d]+$/
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u

original_lines = {}
translated_lines = {}

function textNodesUnder(el){
    var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
    while(n=walk.nextNode()) a.push(n);
    return a;
}

function canBeTranslated(node, text){
    if(! text) return false;
    if(! node.parentElement) return false;

    parentType = node.parentElement.nodeName
    if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;

    if (parentType=='OPTION' || parentType=='SPAN'){
        pnode = node
        for(var level=0; level<4; level++){
            pnode = pnode.parentElement
            if(! pnode) break;

            if(ignore_ids_for_localization[pnode.id] == parentType) return false;
        }
    }

    if(re_num.test(text)) return false;
    if(re_emoji.test(text)) return false;
    return true
}

function getTranslation(text){
    if(! text) return undefined

    if(translated_lines[text] === undefined){
        original_lines[text] = 1
    }

    tl = localization[text]
    if(tl !== undefined){
        translated_lines[tl] = 1
    }

    return tl
}

function processTextNode(node){
    text = node.textContent.trim()

    if(! canBeTranslated(node, text)) return

    tl = getTranslation(text)
    if(tl !== undefined){
        node.textContent = tl
    }
}

function processNode(node){
    if(node.nodeType == 3){
        processTextNode(node)
        return
    }

    if(node.title){
        tl = getTranslation(node.title)
        if(tl !== undefined){
            node.title = tl
        }
    }

    if(node.placeholder){
        tl = getTranslation(node.placeholder)
        if(tl !== undefined){
            node.placeholder = tl
        }
    }

    textNodesUnder(node).forEach(function(node){
        processTextNode(node)
    })
}

function dumpTranslations(){
    dumped = {}
    if (localization.rtl) {
        dumped.rtl = true
    }

    Object.keys(original_lines).forEach(function(text){
        if(dumped[text] !== undefined) return

        dumped[text] = localization[text] || text
    })

    return dumped
}

onUiUpdate(function(m){
    m.forEach(function(mutation){
        mutation.addedNodes.forEach(function(node){
            processNode(node)
        })
    });
})


document.addEventListener("DOMContentLoaded", function() {
    processNode(gradioApp())

    if (localization.rtl) { // if the language is from right to left,
        (new MutationObserver((mutations, observer) => { // wait for the style to load
            mutations.forEach(mutation => {
                mutation.addedNodes.forEach(node => {
                    if (node.tagName === 'STYLE') {
                        observer.disconnect();

                        for (const x of node.sheet.rules) { // find all rtl media rules
                            if (Array.from(x.media || []).includes('rtl')) {
                                x.media.appendMedium('all'); // enable them
                            }
                        }
                    }
                })
            });
        })).observe(gradioApp(), { childList: true });
    }
})

function download_localization() {
    text = JSON.stringify(dumpTranslations(), null, 4)

    var element = document.createElement('a');
    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
    element.setAttribute('download', "localization.json");
    element.style.display = 'none';
    document.body.appendChild(element);

    element.click();

    document.body.removeChild(element);
}
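The translation table consumed above is a flat key-to-string map with an optional rtl flag; dumpTranslations produces the same shape, falling back to the key itself for untranslated strings. A hypothetical excerpt (keys and translations are illustrative only):

// Hypothetical localization object, as the backend would inject it:
localization = {
    "rtl": true,                  // right-to-left language: rtl media rules get enabled
    "Generate": "...",            // a translated UI string would go here
    "Batch count": "Batch count"  // untranslated keys dump as themselves
};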
sd/stable-diffusion-webui/javascript/notification.js
CHANGED
@@ -15,7 +15,7 @@ onUiUpdate(function(){
(Only change: gallery previews are now looked up inside the tab's "_results" container.)

     }
 }

-const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
+const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');

 if (galleryPreviews == null) return;
sd/stable-diffusion-webui/javascript/progressbar.js
CHANGED
@@ -139,7 +139,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
(Only change: the progress bar now gets an explicit "block" display value instead of the empty string when shown.)

 var divProgress = document.createElement('div')
 divProgress.className='progressDiv'
-divProgress.style.display = opts.show_progressbar ? "" : "none"
+divProgress.style.display = opts.show_progressbar ? "block" : "none"
 var divInner = document.createElement('div')
 divInner.className='progress'
sd/stable-diffusion-webui/javascript/textualInversion.js
CHANGED
@@ -1,17 +1,17 @@
(The old and new versions are identical; the file content is listed once below.)

function start_training_textual_inversion(){
    gradioApp().querySelector('#ti_error').innerHTML=''

    var id = randomId()
    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
    })

    var res = args_to_array(arguments)

    res[0] = id

    return res
}
sd/stable-diffusion-webui/launch.py
CHANGED
@@ -1,361 +1,375 @@
(Old version; the capture is only clean through line 165, and the diff header indicates the re-added version grew from 361 to 375 lines.)

# this script installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform
import argparse
import json

dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False


def check_python_version():
    is_windows = platform.system() == "Windows"
    major = sys.version_info.major
    minor = sys.version_info.minor
    micro = sys.version_info.micro

    if is_windows:
        supported_minors = [10]
    else:
        supported_minors = [7, 8, 9, 10, 11]

    if not (major == 3 and minor in supported_minors):
        import modules.errors

        modules.errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION

This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.

You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/

{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}

Use --skip-python-version-check to suppress this warning.
""")


def commit_hash():
    global stored_commit_hash

    if stored_commit_hash is not None:
        return stored_commit_hash

    try:
        stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        stored_commit_hash = "<none>"

    return stored_commit_hash


def extract_arg(args, name):
    return [x for x in args if x != name], name in args


def extract_opt(args, name):
    opt = None
    is_present = False
    if name in args:
        is_present = True
        idx = args.index(name)
        del args[idx]
        if idx < len(args) and args[idx][0] != "-":
            opt = args[idx]
            del args[idx]
    return args, is_present, opt


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    if desc is not None:
        print(desc)

    if live:
        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
        if result.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}""")

        return ""

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

    if result.returncode != 0:
        message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
        raise RuntimeError(message)

    return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return result.returncode == 0


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def repo_dir(name):
    return os.path.join(dir_repos, name)


def run_python(code, desc=None, errdesc=None):
    return run(f'"{python}" -c "{code}"', desc, errdesc)


def run_pip(args, desc=None):
    if skip_install:
        return

    index_url_line = f' --index-url {index_url}' if index_url != '' else ''
    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
    return check_run(f'"{python}" -c "{code}"')


def git_clone(url, dir, name, commithash=None):
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")


def

[lines 166 onward of the old launch.py are garbled in this capture; only scattered fragments survive, reproduced as captured:]

sys.argv

if not

if not is_installed("
|
307 |
-
run_pip(
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# this script installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
import importlib.util
import shlex
import platform
import argparse
import json

dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False


def check_python_version():
    is_windows = platform.system() == "Windows"
    major = sys.version_info.major
    minor = sys.version_info.minor
    micro = sys.version_info.micro

    if is_windows:
        supported_minors = [10]
    else:
        supported_minors = [7, 8, 9, 10, 11]

    if not (major == 3 and minor in supported_minors):
        import modules.errors

        modules.errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION

This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.

You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/

{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}

Use --skip-python-version-check to suppress this warning.
""")


def commit_hash():
    global stored_commit_hash

    if stored_commit_hash is not None:
        return stored_commit_hash

    try:
        stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        stored_commit_hash = "<none>"

    return stored_commit_hash


def extract_arg(args, name):
    return [x for x in args if x != name], name in args


def extract_opt(args, name):
    opt = None
    is_present = False
    if name in args:
        is_present = True
        idx = args.index(name)
        del args[idx]
        if idx < len(args) and args[idx][0] != "-":
            opt = args[idx]
            del args[idx]
    return args, is_present, opt


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    if desc is not None:
        print(desc)

    if live:
        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
        if result.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}""")

        return ""

    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)

    if result.returncode != 0:

        message = f"""{errdesc or 'Error running command'}.
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
        raise RuntimeError(message)

    return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return result.returncode == 0


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def repo_dir(name):
    return os.path.join(dir_repos, name)


def run_python(code, desc=None, errdesc=None):
    return run(f'"{python}" -c "{code}"', desc, errdesc)


def run_pip(args, desc=None):
    if skip_install:
        return

    index_url_line = f' --index-url {index_url}' if index_url != '' else ''
    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
    return check_run(f'"{python}" -c "{code}"')


def git_clone(url, dir, name, commithash=None):
    # TODO clone into temporary dir and move if successful

    if os.path.exists(dir):
        if commithash is None:
            return

        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
        if current_hash == commithash:
            return

        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
        return

    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")

    if commithash is not None:
        run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")


def git_pull_recursive(dir):
    for subdir, _, _ in os.walk(dir):
        if os.path.exists(os.path.join(subdir, '.git')):
            try:
                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
                print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
            except subprocess.CalledProcessError as e:
                print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")


def version_check(commit):
    try:
        import requests
        commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
        if commit != "<none>" and commits['commit']['sha'] != commit:
            print("--------------------------------------------------------")
            print("| You are not up to date with the most recent release. |")
            print("| Consider running `git pull` to update.               |")
            print("--------------------------------------------------------")
        elif commits['commit']['sha'] == commit:
            print("You are up to date with the most recent release.")
        else:
            print("Not a git clone, can't perform version check.")
    except Exception as e:
        print("version check failed", e)


def run_extension_installer(extension_dir):
    path_installer = os.path.join(extension_dir, "install.py")
    if not os.path.isfile(path_installer):
        return

    try:
        env = os.environ.copy()
        env['PYTHONPATH'] = os.path.abspath(".")

        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
    except Exception as e:
        print(e, file=sys.stderr)


def list_extensions(settings_file):
    settings = {}

    try:
        if os.path.isfile(settings_file):
            with open(settings_file, "r", encoding="utf8") as file:
                settings = json.load(file)
    except Exception as e:
        print(e, file=sys.stderr)

    disabled_extensions = set(settings.get('disabled_extensions', []))

    return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]


def run_extensions_installers(settings_file):
    if not os.path.isdir(dir_extensions):
        return

    for dirname_extension in list_extensions(settings_file):
        run_extension_installer(os.path.join(dir_extensions, dirname_extension))


def prepare_environment():
    global skip_install

    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")

    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
    taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
    taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

    sys.argv += shlex.split(commandline_args)

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
    args, _ = parser.parse_known_args(sys.argv)

    sys.argv, _ = extract_arg(sys.argv, '-f')
    sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
    sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
    sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
    sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
    sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
    sys.argv, update_check = extract_arg(sys.argv, '--update-check')
    sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
    sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
    xformers = '--xformers' in sys.argv
    ngrok = '--ngrok' in sys.argv

    if not skip_python_version_check:
        check_python_version()

    commit = commit_hash()

    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    if not skip_torch_cuda_test:
        run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

    if not is_installed("gfpgan"):
        run_pip(f"install {gfpgan_package}", "gfpgan")

    if not is_installed("clip"):
        run_pip(f"install {clip_package}", "clip")

    if not is_installed("open_clip"):
        run_pip(f"install {openclip_package}", "open_clip")

    if (not is_installed("xformers") or reinstall_xformers) and xformers:
        if platform.system() == "Windows":
            if platform.python_version().startswith("3.10"):
                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
            else:
                print("Installation of xformers is not supported in this version of Python.")
                print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
                if not is_installed("xformers"):
                    exit(0)
        elif platform.system() == "Linux":
            run_pip(f"install {xformers_package}", "xformers")

    if not is_installed("pyngrok") and ngrok:
        run_pip("install pyngrok", "ngrok")

    os.makedirs(dir_repos, exist_ok=True)

    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
    git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

    if not is_installed("lpips"):
        run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

    run_pip(f"install -r {requirements_file}", "requirements for Web UI")

    run_extensions_installers(settings_file=args.ui_settings_file)

    if update_check:
        version_check(commit)

    if update_all_extensions:
        git_pull_recursive(dir_extensions)

    if "--exit" in sys.argv:
        print("Exiting because of --exit argument")
        exit(0)

    if run_tests:
        exitcode = tests(test_dir)
        exit(exitcode)


def tests(test_dir):
    if "--api" not in sys.argv:
        sys.argv.append("--api")
    if "--ckpt" not in sys.argv:
        sys.argv.append("--ckpt")
        sys.argv.append("./test/test_files/empty.pt")
    if "--skip-torch-cuda-test" not in sys.argv:
        sys.argv.append("--skip-torch-cuda-test")
    if "--disable-nan-check" not in sys.argv:
        sys.argv.append("--disable-nan-check")

    print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")

    os.environ['COMMANDLINE_ARGS'] = ""
    with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
        proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)

    import test.server_poll
    exitcode = test.server_poll.run_tests(proc, test_dir)

    print(f"Stopping Web UI process with id {proc.pid}")
    proc.kill()
    return exitcode


def start():
    print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
    import webui
    if '--nowebui' in sys.argv:
        webui.api_only()
    else:
        webui.webui()


if __name__ == "__main__":
    prepare_environment()
    start()
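For reference, a minimal sketch (not part of the diff) of how the `extract_arg` / `extract_opt` helpers above consume launcher flags from an argv-style list; the input list here is hypothetical:

    # Hypothetical argv list, for illustration only.
    argv = ["launch.py", "--tests", "testdir", "--update-check", "--api"]

    # extract_opt removes the flag and, if the next token is not another
    # flag, also removes and returns it as the option's value.
    argv, run_tests, test_dir = extract_opt(argv, "--tests")
    # argv == ["launch.py", "--update-check", "--api"]; run_tests == True; test_dir == "testdir"

    # extract_arg removes a bare flag and reports whether it was present.
    argv, update_check = extract_arg(argv, "--update-check")
    # argv == ["launch.py", "--api"]; update_check == True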
sd/stable-diffusion-webui/modules/api/api.py
CHANGED
@@ -150,6 +150,7 @@ class Api:
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+       self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

    def add_api_route(self, path: str, endpoint, **kwargs):
        if shared.cmd_opts.api_auth:
@@ -174,36 +175,44 @@ class Api:
        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx
+
+   def get_scripts_list(self):
+       t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
+       i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+
+       return ScriptsList(txt2img=t2ilist, img2img=i2ilist)

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)

-       populate = txt2imgreq.copy(update={
+       populate = txt2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
-           "do_not_save_samples": True,
-           "do_not_save_grid": True
-           }
-       )
+           "do_not_save_samples": not txt2imgreq.save_images,
+           "do_not_save_grid": not txt2imgreq.save_images,
+       })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

        args = vars(populate)
        args.pop('script_name', None)

+       send_images = args.pop('send_images', True)
+       args.pop('save_images', None)
+
        with self.queue_lock:
            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+           p.outpath_grids = opts.outdir_txt2img_grids
+           p.outpath_samples = opts.outdir_txt2img_samples

            shared.state.begin()
            if script is not None:
-               p.outpath_grids = opts.outdir_txt2img_grids
-               p.outpath_samples = opts.outdir_txt2img_samples
                p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
                processed = scripts.scripts_txt2img.run(p, *p.script_args)
            else:
                processed = process_images(p)
            shared.state.end()

-       b64images = list(map(encode_pil_to_base64, processed.images))
+       b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

@@ -218,13 +227,12 @@ class Api:
        if mask:
            mask = decode_base64_to_image(mask)

-       populate = img2imgreq.copy(update={
+       populate = img2imgreq.copy(update={  # Override __init__ params
            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
-           "do_not_save_samples": True,
-           "do_not_save_grid": True,
-           "mask": mask
-           }
-       )
+           "do_not_save_samples": not img2imgreq.save_images,
+           "do_not_save_grid": not img2imgreq.save_images,
+           "mask": mask,
+       })
        if populate.sampler_name:
            populate.sampler_index = None  # prevent a warning later on

@@ -232,21 +240,24 @@ class Api:
        args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
        args.pop('script_name', None)

+       send_images = args.pop('send_images', True)
+       args.pop('save_images', None)
+
        with self.queue_lock:
            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
            p.init_images = [decode_base64_to_image(x) for x in init_images]
+           p.outpath_grids = opts.outdir_img2img_grids
+           p.outpath_samples = opts.outdir_img2img_samples

            shared.state.begin()
            if script is not None:
-               p.outpath_grids = opts.outdir_img2img_grids
-               p.outpath_samples = opts.outdir_img2img_samples
                p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
                processed = scripts.scripts_img2img.run(p, *p.script_args)
            else:
                processed = process_images(p)
            shared.state.end()

-       b64images = list(map(encode_pil_to_base64, processed.images))
+       b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

        if not img2imgreq.include_init_images:
            img2imgreq.init_images = None
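A quick sketch of how a client might exercise the new behaviour; the host, port, and prompt below are hypothetical, while the endpoint paths and the `send_images` / `save_images` fields come from the routes and request models changed above:

    import requests

    base = "http://127.0.0.1:7860"  # hypothetical local webui instance

    # New endpoint: list selectable script titles for txt2img and img2img.
    print(requests.get(f"{base}/sdapi/v1/scripts").json())

    # send_images=False skips base64-encoding images into the response;
    # save_images=True asks the server to write samples/grids to its output dirs.
    payload = {"prompt": "a photo of a cat", "steps": 20, "send_images": False, "save_images": True}
    r = requests.post(f"{base}/sdapi/v1/txt2img", json=payload)
    print(r.json()["images"])  # [] because send_images was False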
sd/stable-diffusion-webui/modules/api/models.py
CHANGED
@@ -14,8 +14,8 @@ API_NOT_ALLOWED = [
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
-   "do_not_save_samples",
-   "do_not_save_grid",
+   # "do_not_save_samples",
+   # "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
@@ -100,13 +100,29 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingTxt2Img",
    StableDiffusionProcessingTxt2Img,
-   [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+   [
+       {"key": "sampler_index", "type": str, "default": "Euler"},
+       {"key": "script_name", "type": str, "default": None},
+       {"key": "script_args", "type": list, "default": []},
+       {"key": "send_images", "type": bool, "default": True},
+       {"key": "save_images", "type": bool, "default": False},
+   ]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    "StableDiffusionProcessingImg2Img",
    StableDiffusionProcessingImg2Img,
-   [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude": True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
+   [
+       {"key": "sampler_index", "type": str, "default": "Euler"},
+       {"key": "init_images", "type": list, "default": None},
+       {"key": "denoising_strength", "type": float, "default": 0.75},
+       {"key": "mask", "type": str, "default": None},
+       {"key": "include_init_images", "type": bool, "default": False, "exclude": True},
+       {"key": "script_name", "type": str, "default": None},
+       {"key": "script_args", "type": list, "default": []},
+       {"key": "send_images", "type": bool, "default": True},
+       {"key": "save_images", "type": bool, "default": False},
+   ]
).generate_model()

class TextToImageResponse(BaseModel):
@@ -267,3 +283,7 @@ class EmbeddingsResponse(BaseModel):
class MemoryResponse(BaseModel):
    ram: dict = Field(title="RAM", description="System memory stats")
    cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
+class ScriptsList(BaseModel):
+   txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+   img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
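As a rough illustration (assuming `PydanticModelGenerator` turns each `{"key", "type", "default"}` entry above into an optional field on the generated request model), the txt2img request model now validates a body like this; the prompt value is hypothetical:

    req = StableDiffusionTxt2ImgProcessingAPI(
        prompt="a mountain lake at dawn",  # hypothetical prompt
        sampler_index="Euler",             # default taken from the field spec above
        send_images=True,                  # new field, defaults to True
        save_images=False,                 # new field, defaults to False
    )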
sd/stable-diffusion-webui/modules/call_queue.py
CHANGED
@@ -1,109 +1,109 @@
import html
import sys
import threading
import traceback
import time

from modules import shared, progress

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
            res = func(*args, **kwargs)

        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

        return res

    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):
        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
        if run_memmon:
            shared.mem_mon.monitor()
        t = time.perf_counter()

        try:
            res = list(func(*args, **kwargs))
        except Exception as e:
            # When printing out our debug argument list, do not print out more than a MB of text
            max_debug_str_len = 131072  # (1024*1024)/8

            print("Error completing request", file=sys.stderr)
            argStr = f"Arguments: {str(args)} {str(kwargs)}"
            print(argStr[:max_debug_str_len], file=sys.stderr)
            if len(argStr) > max_debug_str_len:
                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)

            print(traceback.format_exc(), file=sys.stderr)

            shared.state.job = ""
            shared.state.job_count = 0

            if extra_outputs_array is None:
                extra_outputs_array = [None, '']

            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0

        if not add_stats:
            return tuple(res)

        elapsed = time.perf_counter() - t
        elapsed_m = int(elapsed // 60)
        elapsed_s = elapsed % 60
        elapsed_text = f"{elapsed_s:.2f}s"
        if elapsed_m > 0:
            elapsed_text = f"{elapsed_m}m "+elapsed_text

        if run_memmon:
            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
            active_peak = mem_stats['active_peak']
            reserved_peak = mem_stats['reserved_peak']
            sys_peak = mem_stats['system_peak']
            sys_total = mem_stats['total']
            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)

            vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
        else:
            vram_html = ''

        # last item is always HTML
        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"

        return tuple(res)

    return f
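A minimal sketch of what `wrap_queued_call` does for a caller (the wrapped worker function here is hypothetical): it simply serializes calls through the shared `queue_lock`, so concurrent requests run one at a time:

    def generate(prompt):  # hypothetical worker function
        return f"result for {prompt}"

    queued_generate = wrap_queued_call(generate)

    # Concurrent callers now acquire queue_lock one at a time.
    print(queued_generate("a cat"))  # -> "result for a cat"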
sd/stable-diffusion-webui/modules/codeformer_model.py
CHANGED
@@ -1,143 +1,143 @@
import os
import sys
import traceback

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import models_path

# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils.download_util import load_file_from_url
        from basicsr.utils import imwrite, img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface
        from modules.shared import cmd_opts

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):

                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper
-               model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
+               model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None
                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                checkpoint = torch.load(ckpt_path)['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                if hasattr(retinaface, 'device'):
                    retinaface.device = devices.device_codeformer
                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                np_image = np_image[:, :, ::-1]

                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
                    except Exception as error:
                        print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img

        global have_codeformer
        have_codeformer = True

        global codeformer
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)

    except Exception:
        print("Error setting up CodeFormer:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    # sys.path = stored_sys_path
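For orientation, a hedged usage sketch (only meaningful inside the webui environment, where `modules.paths.paths` and the model download can resolve): after `setup_model` has run, the module-level `codeformer` restorer can be applied to a NumPy image array. The directory and image below are placeholders:

    import numpy as np

    setup_model(dirname="models/Codeformer")  # hypothetical models directory

    img = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image array
    # w is the CodeFormer weight; when None, restore falls back to
    # shared.opts.code_former_weight.
    restored = codeformer.restore(img, w=0.7)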
sd/stable-diffusion-webui/modules/deepbooru_model.py
CHANGED
@@ -1,678 +1,678 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules import devices

# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more


class DeepDanbooruModel(nn.Module):
    def __init__(self):
        super(DeepDanbooruModel, self).__init__()

        self.tags = []

        self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
        self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
        self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
        self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
        self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
        self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
        self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
        self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
        self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
106 |
-
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
107 |
-
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
108 |
-
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
109 |
-
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
110 |
-
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
111 |
-
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
112 |
-
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
113 |
-
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
114 |
-
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
115 |
-
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
116 |
-
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
117 |
-
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
118 |
-
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
119 |
-
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
120 |
-
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
121 |
-
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
122 |
-
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
123 |
-
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
124 |
-
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
125 |
-
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
126 |
-
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
127 |
-
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
128 |
-
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
129 |
-
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
130 |
-
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
131 |
-
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
132 |
-
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
133 |
-
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
134 |
-
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
135 |
-
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
136 |
-
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
137 |
-
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
138 |
-
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
139 |
-
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
140 |
-
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
141 |
-
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
142 |
-
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
143 |
-
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
144 |
-
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
145 |
-
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
146 |
-
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
147 |
-
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
148 |
-
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
149 |
-
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
150 |
-
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
151 |
-
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
152 |
-
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
153 |
-
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
154 |
-
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
155 |
-
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
156 |
-
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
157 |
-
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
158 |
-
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
159 |
-
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
160 |
-
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
161 |
-
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
162 |
-
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
163 |
-
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
164 |
-
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
165 |
-
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
166 |
-
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
167 |
-
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
168 |
-
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
169 |
-
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
170 |
-
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
171 |
-
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
172 |
-
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
173 |
-
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
174 |
-
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
175 |
-
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
176 |
-
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
177 |
-
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
178 |
-
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
179 |
-
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
180 |
-
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
181 |
-
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
182 |
-
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
183 |
-
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
184 |
-
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
185 |
-
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
186 |
-
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
187 |
-
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
188 |
-
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
189 |
-
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
190 |
-
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
191 |
-
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
192 |
-
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
193 |
-
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
194 |
-
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
195 |
-
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
196 |
-
|
197 |
-
def forward(self, *inputs):
|
198 |
-
t_358, = inputs
|
199 |
-
t_359 = t_358.permute(*[0, 3, 1, 2])
|
200 |
-
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
201 |
-
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
202 |
-
t_361 = F.relu(t_360)
|
203 |
-
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
204 |
-
t_362 = self.n_MaxPool_0(t_361)
|
205 |
-
t_363 = self.n_Conv_1(t_362)
|
206 |
-
t_364 = self.n_Conv_2(t_362)
|
207 |
-
t_365 = F.relu(t_364)
|
208 |
-
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
209 |
-
t_366 = self.n_Conv_3(t_365_padded)
|
210 |
-
t_367 = F.relu(t_366)
|
211 |
-
t_368 = self.n_Conv_4(t_367)
|
212 |
-
t_369 = torch.add(t_368, t_363)
|
213 |
-
t_370 = F.relu(t_369)
|
214 |
-
t_371 = self.n_Conv_5(t_370)
|
215 |
-
t_372 = F.relu(t_371)
|
216 |
-
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
217 |
-
t_373 = self.n_Conv_6(t_372_padded)
|
218 |
-
t_374 = F.relu(t_373)
|
219 |
-
t_375 = self.n_Conv_7(t_374)
|
220 |
-
t_376 = torch.add(t_375, t_370)
|
221 |
-
t_377 = F.relu(t_376)
|
222 |
-
t_378 = self.n_Conv_8(t_377)
|
223 |
-
t_379 = F.relu(t_378)
|
224 |
-
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
225 |
-
t_380 = self.n_Conv_9(t_379_padded)
|
226 |
-
t_381 = F.relu(t_380)
|
227 |
-
t_382 = self.n_Conv_10(t_381)
|
228 |
-
t_383 = torch.add(t_382, t_377)
|
229 |
-
t_384 = F.relu(t_383)
|
230 |
-
t_385 = self.n_Conv_11(t_384)
|
231 |
-
t_386 = self.n_Conv_12(t_384)
|
232 |
-
t_387 = F.relu(t_386)
|
233 |
-
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
234 |
-
t_388 = self.n_Conv_13(t_387_padded)
|
235 |
-
t_389 = F.relu(t_388)
|
236 |
-
t_390 = self.n_Conv_14(t_389)
|
237 |
-
t_391 = torch.add(t_390, t_385)
|
238 |
-
t_392 = F.relu(t_391)
|
239 |
-
t_393 = self.n_Conv_15(t_392)
|
240 |
-
t_394 = F.relu(t_393)
|
241 |
-
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
242 |
-
t_395 = self.n_Conv_16(t_394_padded)
|
243 |
-
t_396 = F.relu(t_395)
|
244 |
-
t_397 = self.n_Conv_17(t_396)
|
245 |
-
t_398 = torch.add(t_397, t_392)
|
246 |
-
t_399 = F.relu(t_398)
|
247 |
-
t_400 = self.n_Conv_18(t_399)
|
248 |
-
t_401 = F.relu(t_400)
|
249 |
-
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
250 |
-
t_402 = self.n_Conv_19(t_401_padded)
|
251 |
-
t_403 = F.relu(t_402)
|
252 |
-
t_404 = self.n_Conv_20(t_403)
|
253 |
-
t_405 = torch.add(t_404, t_399)
|
254 |
-
t_406 = F.relu(t_405)
|
255 |
-
t_407 = self.n_Conv_21(t_406)
|
256 |
-
t_408 = F.relu(t_407)
|
257 |
-
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
258 |
-
t_409 = self.n_Conv_22(t_408_padded)
|
259 |
-
t_410 = F.relu(t_409)
|
260 |
-
t_411 = self.n_Conv_23(t_410)
|
261 |
-
t_412 = torch.add(t_411, t_406)
|
262 |
-
t_413 = F.relu(t_412)
|
263 |
-
t_414 = self.n_Conv_24(t_413)
|
264 |
-
t_415 = F.relu(t_414)
|
265 |
-
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
266 |
-
t_416 = self.n_Conv_25(t_415_padded)
|
267 |
-
t_417 = F.relu(t_416)
|
268 |
-
t_418 = self.n_Conv_26(t_417)
|
269 |
-
t_419 = torch.add(t_418, t_413)
|
270 |
-
t_420 = F.relu(t_419)
|
271 |
-
t_421 = self.n_Conv_27(t_420)
|
272 |
-
t_422 = F.relu(t_421)
|
273 |
-
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
274 |
-
t_423 = self.n_Conv_28(t_422_padded)
|
275 |
-
t_424 = F.relu(t_423)
|
276 |
-
t_425 = self.n_Conv_29(t_424)
|
277 |
-
t_426 = torch.add(t_425, t_420)
|
278 |
-
t_427 = F.relu(t_426)
|
279 |
-
t_428 = self.n_Conv_30(t_427)
|
280 |
-
t_429 = F.relu(t_428)
|
281 |
-
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
282 |
-
t_430 = self.n_Conv_31(t_429_padded)
|
283 |
-
t_431 = F.relu(t_430)
|
284 |
-
t_432 = self.n_Conv_32(t_431)
|
285 |
-
t_433 = torch.add(t_432, t_427)
|
286 |
-
t_434 = F.relu(t_433)
|
287 |
-
t_435 = self.n_Conv_33(t_434)
|
288 |
-
t_436 = F.relu(t_435)
|
289 |
-
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
290 |
-
t_437 = self.n_Conv_34(t_436_padded)
|
291 |
-
t_438 = F.relu(t_437)
|
292 |
-
t_439 = self.n_Conv_35(t_438)
|
293 |
-
t_440 = torch.add(t_439, t_434)
|
294 |
-
t_441 = F.relu(t_440)
|
295 |
-
t_442 = self.n_Conv_36(t_441)
|
296 |
-
t_443 = self.n_Conv_37(t_441)
|
297 |
-
t_444 = F.relu(t_443)
|
298 |
-
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
299 |
-
t_445 = self.n_Conv_38(t_444_padded)
|
300 |
-
t_446 = F.relu(t_445)
|
301 |
-
t_447 = self.n_Conv_39(t_446)
|
302 |
-
t_448 = torch.add(t_447, t_442)
|
303 |
-
t_449 = F.relu(t_448)
|
304 |
-
t_450 = self.n_Conv_40(t_449)
|
305 |
-
t_451 = F.relu(t_450)
|
306 |
-
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
307 |
-
t_452 = self.n_Conv_41(t_451_padded)
|
308 |
-
t_453 = F.relu(t_452)
|
309 |
-
t_454 = self.n_Conv_42(t_453)
|
310 |
-
t_455 = torch.add(t_454, t_449)
|
311 |
-
t_456 = F.relu(t_455)
|
312 |
-
t_457 = self.n_Conv_43(t_456)
|
313 |
-
t_458 = F.relu(t_457)
|
314 |
-
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
315 |
-
t_459 = self.n_Conv_44(t_458_padded)
|
316 |
-
t_460 = F.relu(t_459)
|
317 |
-
t_461 = self.n_Conv_45(t_460)
|
318 |
-
t_462 = torch.add(t_461, t_456)
|
319 |
-
t_463 = F.relu(t_462)
|
320 |
-
t_464 = self.n_Conv_46(t_463)
|
321 |
-
t_465 = F.relu(t_464)
|
322 |
-
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
323 |
-
t_466 = self.n_Conv_47(t_465_padded)
|
324 |
-
t_467 = F.relu(t_466)
|
325 |
-
t_468 = self.n_Conv_48(t_467)
|
326 |
-
t_469 = torch.add(t_468, t_463)
|
327 |
-
t_470 = F.relu(t_469)
|
328 |
-
t_471 = self.n_Conv_49(t_470)
|
329 |
-
t_472 = F.relu(t_471)
|
330 |
-
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
331 |
-
t_473 = self.n_Conv_50(t_472_padded)
|
332 |
-
t_474 = F.relu(t_473)
|
333 |
-
t_475 = self.n_Conv_51(t_474)
|
334 |
-
t_476 = torch.add(t_475, t_470)
|
335 |
-
t_477 = F.relu(t_476)
|
336 |
-
t_478 = self.n_Conv_52(t_477)
|
337 |
-
t_479 = F.relu(t_478)
|
338 |
-
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
339 |
-
t_480 = self.n_Conv_53(t_479_padded)
|
340 |
-
t_481 = F.relu(t_480)
|
341 |
-
t_482 = self.n_Conv_54(t_481)
|
342 |
-
t_483 = torch.add(t_482, t_477)
|
343 |
-
t_484 = F.relu(t_483)
|
344 |
-
t_485 = self.n_Conv_55(t_484)
|
345 |
-
t_486 = F.relu(t_485)
|
346 |
-
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
347 |
-
t_487 = self.n_Conv_56(t_486_padded)
|
348 |
-
t_488 = F.relu(t_487)
|
349 |
-
t_489 = self.n_Conv_57(t_488)
|
350 |
-
t_490 = torch.add(t_489, t_484)
|
351 |
-
t_491 = F.relu(t_490)
|
352 |
-
t_492 = self.n_Conv_58(t_491)
|
353 |
-
t_493 = F.relu(t_492)
|
354 |
-
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
355 |
-
t_494 = self.n_Conv_59(t_493_padded)
|
356 |
-
t_495 = F.relu(t_494)
|
357 |
-
t_496 = self.n_Conv_60(t_495)
|
358 |
-
t_497 = torch.add(t_496, t_491)
|
359 |
-
t_498 = F.relu(t_497)
|
360 |
-
t_499 = self.n_Conv_61(t_498)
|
361 |
-
t_500 = F.relu(t_499)
|
362 |
-
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
363 |
-
t_501 = self.n_Conv_62(t_500_padded)
|
364 |
-
t_502 = F.relu(t_501)
|
365 |
-
t_503 = self.n_Conv_63(t_502)
|
366 |
-
t_504 = torch.add(t_503, t_498)
|
367 |
-
t_505 = F.relu(t_504)
|
368 |
-
t_506 = self.n_Conv_64(t_505)
|
369 |
-
t_507 = F.relu(t_506)
|
370 |
-
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
371 |
-
t_508 = self.n_Conv_65(t_507_padded)
|
372 |
-
t_509 = F.relu(t_508)
|
373 |
-
t_510 = self.n_Conv_66(t_509)
|
374 |
-
t_511 = torch.add(t_510, t_505)
|
375 |
-
t_512 = F.relu(t_511)
|
376 |
-
t_513 = self.n_Conv_67(t_512)
|
377 |
-
t_514 = F.relu(t_513)
|
378 |
-
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
379 |
-
t_515 = self.n_Conv_68(t_514_padded)
|
380 |
-
t_516 = F.relu(t_515)
|
381 |
-
t_517 = self.n_Conv_69(t_516)
|
382 |
-
t_518 = torch.add(t_517, t_512)
|
383 |
-
t_519 = F.relu(t_518)
|
384 |
-
t_520 = self.n_Conv_70(t_519)
|
385 |
-
t_521 = F.relu(t_520)
|
386 |
-
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
387 |
-
t_522 = self.n_Conv_71(t_521_padded)
|
388 |
-
t_523 = F.relu(t_522)
|
389 |
-
t_524 = self.n_Conv_72(t_523)
|
390 |
-
t_525 = torch.add(t_524, t_519)
|
391 |
-
t_526 = F.relu(t_525)
|
392 |
-
t_527 = self.n_Conv_73(t_526)
|
393 |
-
t_528 = F.relu(t_527)
|
394 |
-
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
395 |
-
t_529 = self.n_Conv_74(t_528_padded)
|
396 |
-
t_530 = F.relu(t_529)
|
397 |
-
t_531 = self.n_Conv_75(t_530)
|
398 |
-
t_532 = torch.add(t_531, t_526)
|
399 |
-
t_533 = F.relu(t_532)
|
400 |
-
t_534 = self.n_Conv_76(t_533)
|
401 |
-
t_535 = F.relu(t_534)
|
402 |
-
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
403 |
-
t_536 = self.n_Conv_77(t_535_padded)
|
404 |
-
t_537 = F.relu(t_536)
|
405 |
-
t_538 = self.n_Conv_78(t_537)
|
406 |
-
t_539 = torch.add(t_538, t_533)
|
407 |
-
t_540 = F.relu(t_539)
|
408 |
-
t_541 = self.n_Conv_79(t_540)
|
409 |
-
t_542 = F.relu(t_541)
|
410 |
-
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
411 |
-
t_543 = self.n_Conv_80(t_542_padded)
|
412 |
-
t_544 = F.relu(t_543)
|
413 |
-
t_545 = self.n_Conv_81(t_544)
|
414 |
-
t_546 = torch.add(t_545, t_540)
|
415 |
-
t_547 = F.relu(t_546)
|
416 |
-
t_548 = self.n_Conv_82(t_547)
|
417 |
-
t_549 = F.relu(t_548)
|
418 |
-
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
419 |
-
t_550 = self.n_Conv_83(t_549_padded)
|
420 |
-
t_551 = F.relu(t_550)
|
421 |
-
t_552 = self.n_Conv_84(t_551)
|
422 |
-
t_553 = torch.add(t_552, t_547)
|
423 |
-
t_554 = F.relu(t_553)
|
424 |
-
t_555 = self.n_Conv_85(t_554)
|
425 |
-
t_556 = F.relu(t_555)
|
426 |
-
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
427 |
-
t_557 = self.n_Conv_86(t_556_padded)
|
428 |
-
t_558 = F.relu(t_557)
|
429 |
-
t_559 = self.n_Conv_87(t_558)
|
430 |
-
t_560 = torch.add(t_559, t_554)
|
431 |
-
t_561 = F.relu(t_560)
|
432 |
-
t_562 = self.n_Conv_88(t_561)
|
433 |
-
t_563 = F.relu(t_562)
|
434 |
-
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
435 |
-
t_564 = self.n_Conv_89(t_563_padded)
|
436 |
-
t_565 = F.relu(t_564)
|
437 |
-
t_566 = self.n_Conv_90(t_565)
|
438 |
-
t_567 = torch.add(t_566, t_561)
|
439 |
-
t_568 = F.relu(t_567)
|
440 |
-
t_569 = self.n_Conv_91(t_568)
|
441 |
-
t_570 = F.relu(t_569)
|
442 |
-
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
443 |
-
t_571 = self.n_Conv_92(t_570_padded)
|
444 |
-
t_572 = F.relu(t_571)
|
445 |
-
t_573 = self.n_Conv_93(t_572)
|
446 |
-
t_574 = torch.add(t_573, t_568)
|
447 |
-
t_575 = F.relu(t_574)
|
448 |
-
t_576 = self.n_Conv_94(t_575)
|
449 |
-
t_577 = F.relu(t_576)
|
450 |
-
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
451 |
-
t_578 = self.n_Conv_95(t_577_padded)
|
452 |
-
t_579 = F.relu(t_578)
|
453 |
-
t_580 = self.n_Conv_96(t_579)
|
454 |
-
t_581 = torch.add(t_580, t_575)
|
455 |
-
t_582 = F.relu(t_581)
|
456 |
-
t_583 = self.n_Conv_97(t_582)
|
457 |
-
t_584 = F.relu(t_583)
|
458 |
-
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
459 |
-
t_585 = self.n_Conv_98(t_584_padded)
|
460 |
-
t_586 = F.relu(t_585)
|
461 |
-
t_587 = self.n_Conv_99(t_586)
|
462 |
-
t_588 = self.n_Conv_100(t_582)
|
463 |
-
t_589 = torch.add(t_587, t_588)
|
464 |
-
t_590 = F.relu(t_589)
|
465 |
-
t_591 = self.n_Conv_101(t_590)
|
466 |
-
t_592 = F.relu(t_591)
|
467 |
-
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
468 |
-
t_593 = self.n_Conv_102(t_592_padded)
|
469 |
-
t_594 = F.relu(t_593)
|
470 |
-
t_595 = self.n_Conv_103(t_594)
|
471 |
-
t_596 = torch.add(t_595, t_590)
|
472 |
-
t_597 = F.relu(t_596)
|
473 |
-
t_598 = self.n_Conv_104(t_597)
|
474 |
-
t_599 = F.relu(t_598)
|
475 |
-
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
476 |
-
t_600 = self.n_Conv_105(t_599_padded)
|
477 |
-
t_601 = F.relu(t_600)
|
478 |
-
t_602 = self.n_Conv_106(t_601)
|
479 |
-
t_603 = torch.add(t_602, t_597)
|
480 |
-
t_604 = F.relu(t_603)
|
481 |
-
t_605 = self.n_Conv_107(t_604)
|
482 |
-
t_606 = F.relu(t_605)
|
483 |
-
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
484 |
-
t_607 = self.n_Conv_108(t_606_padded)
|
485 |
-
t_608 = F.relu(t_607)
|
486 |
-
t_609 = self.n_Conv_109(t_608)
|
487 |
-
t_610 = torch.add(t_609, t_604)
|
488 |
-
t_611 = F.relu(t_610)
|
489 |
-
t_612 = self.n_Conv_110(t_611)
|
490 |
-
t_613 = F.relu(t_612)
|
491 |
-
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
492 |
-
t_614 = self.n_Conv_111(t_613_padded)
|
493 |
-
t_615 = F.relu(t_614)
|
494 |
-
t_616 = self.n_Conv_112(t_615)
|
495 |
-
t_617 = torch.add(t_616, t_611)
|
496 |
-
t_618 = F.relu(t_617)
|
497 |
-
t_619 = self.n_Conv_113(t_618)
|
498 |
-
t_620 = F.relu(t_619)
|
499 |
-
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
500 |
-
t_621 = self.n_Conv_114(t_620_padded)
|
501 |
-
t_622 = F.relu(t_621)
|
502 |
-
t_623 = self.n_Conv_115(t_622)
|
503 |
-
t_624 = torch.add(t_623, t_618)
|
504 |
-
t_625 = F.relu(t_624)
|
505 |
-
t_626 = self.n_Conv_116(t_625)
|
506 |
-
t_627 = F.relu(t_626)
|
507 |
-
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
508 |
-
t_628 = self.n_Conv_117(t_627_padded)
|
509 |
-
t_629 = F.relu(t_628)
|
510 |
-
t_630 = self.n_Conv_118(t_629)
|
511 |
-
t_631 = torch.add(t_630, t_625)
|
512 |
-
t_632 = F.relu(t_631)
|
513 |
-
t_633 = self.n_Conv_119(t_632)
|
514 |
-
t_634 = F.relu(t_633)
|
515 |
-
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
516 |
-
t_635 = self.n_Conv_120(t_634_padded)
|
517 |
-
t_636 = F.relu(t_635)
|
518 |
-
t_637 = self.n_Conv_121(t_636)
|
519 |
-
t_638 = torch.add(t_637, t_632)
|
520 |
-
t_639 = F.relu(t_638)
|
521 |
-
t_640 = self.n_Conv_122(t_639)
|
522 |
-
t_641 = F.relu(t_640)
|
523 |
-
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
524 |
-
t_642 = self.n_Conv_123(t_641_padded)
|
525 |
-
t_643 = F.relu(t_642)
|
526 |
-
t_644 = self.n_Conv_124(t_643)
|
527 |
-
t_645 = torch.add(t_644, t_639)
|
528 |
-
t_646 = F.relu(t_645)
|
529 |
-
t_647 = self.n_Conv_125(t_646)
|
530 |
-
t_648 = F.relu(t_647)
|
531 |
-
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
532 |
-
t_649 = self.n_Conv_126(t_648_padded)
|
533 |
-
t_650 = F.relu(t_649)
|
534 |
-
t_651 = self.n_Conv_127(t_650)
|
535 |
-
t_652 = torch.add(t_651, t_646)
|
536 |
-
t_653 = F.relu(t_652)
|
537 |
-
t_654 = self.n_Conv_128(t_653)
|
538 |
-
t_655 = F.relu(t_654)
|
539 |
-
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
540 |
-
t_656 = self.n_Conv_129(t_655_padded)
|
541 |
-
t_657 = F.relu(t_656)
|
542 |
-
t_658 = self.n_Conv_130(t_657)
|
543 |
-
t_659 = torch.add(t_658, t_653)
|
544 |
-
t_660 = F.relu(t_659)
|
545 |
-
t_661 = self.n_Conv_131(t_660)
|
546 |
-
t_662 = F.relu(t_661)
|
547 |
-
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
548 |
-
t_663 = self.n_Conv_132(t_662_padded)
|
549 |
-
t_664 = F.relu(t_663)
|
550 |
-
t_665 = self.n_Conv_133(t_664)
|
551 |
-
t_666 = torch.add(t_665, t_660)
|
552 |
-
t_667 = F.relu(t_666)
|
553 |
-
t_668 = self.n_Conv_134(t_667)
|
554 |
-
t_669 = F.relu(t_668)
|
555 |
-
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
556 |
-
t_670 = self.n_Conv_135(t_669_padded)
|
557 |
-
t_671 = F.relu(t_670)
|
558 |
-
t_672 = self.n_Conv_136(t_671)
|
559 |
-
t_673 = torch.add(t_672, t_667)
|
560 |
-
t_674 = F.relu(t_673)
|
561 |
-
t_675 = self.n_Conv_137(t_674)
|
562 |
-
t_676 = F.relu(t_675)
|
563 |
-
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
564 |
-
t_677 = self.n_Conv_138(t_676_padded)
|
565 |
-
t_678 = F.relu(t_677)
|
566 |
-
t_679 = self.n_Conv_139(t_678)
|
567 |
-
t_680 = torch.add(t_679, t_674)
|
568 |
-
t_681 = F.relu(t_680)
|
569 |
-
t_682 = self.n_Conv_140(t_681)
|
570 |
-
t_683 = F.relu(t_682)
|
571 |
-
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
572 |
-
t_684 = self.n_Conv_141(t_683_padded)
|
573 |
-
t_685 = F.relu(t_684)
|
574 |
-
t_686 = self.n_Conv_142(t_685)
|
575 |
-
t_687 = torch.add(t_686, t_681)
|
576 |
-
t_688 = F.relu(t_687)
|
577 |
-
t_689 = self.n_Conv_143(t_688)
|
578 |
-
t_690 = F.relu(t_689)
|
579 |
-
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
580 |
-
t_691 = self.n_Conv_144(t_690_padded)
|
581 |
-
t_692 = F.relu(t_691)
|
582 |
-
t_693 = self.n_Conv_145(t_692)
|
583 |
-
t_694 = torch.add(t_693, t_688)
|
584 |
-
t_695 = F.relu(t_694)
|
585 |
-
t_696 = self.n_Conv_146(t_695)
|
586 |
-
t_697 = F.relu(t_696)
|
587 |
-
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
588 |
-
t_698 = self.n_Conv_147(t_697_padded)
|
589 |
-
t_699 = F.relu(t_698)
|
590 |
-
t_700 = self.n_Conv_148(t_699)
|
591 |
-
t_701 = torch.add(t_700, t_695)
|
592 |
-
t_702 = F.relu(t_701)
|
593 |
-
t_703 = self.n_Conv_149(t_702)
|
594 |
-
t_704 = F.relu(t_703)
|
595 |
-
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
596 |
-
t_705 = self.n_Conv_150(t_704_padded)
|
597 |
-
t_706 = F.relu(t_705)
|
598 |
-
t_707 = self.n_Conv_151(t_706)
|
599 |
-
t_708 = torch.add(t_707, t_702)
|
600 |
-
t_709 = F.relu(t_708)
|
601 |
-
t_710 = self.n_Conv_152(t_709)
|
602 |
-
t_711 = F.relu(t_710)
|
603 |
-
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
604 |
-
t_712 = self.n_Conv_153(t_711_padded)
|
605 |
-
t_713 = F.relu(t_712)
|
606 |
-
t_714 = self.n_Conv_154(t_713)
|
607 |
-
t_715 = torch.add(t_714, t_709)
|
608 |
-
t_716 = F.relu(t_715)
|
609 |
-
t_717 = self.n_Conv_155(t_716)
|
610 |
-
t_718 = F.relu(t_717)
|
611 |
-
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
612 |
-
t_719 = self.n_Conv_156(t_718_padded)
|
613 |
-
t_720 = F.relu(t_719)
|
614 |
-
t_721 = self.n_Conv_157(t_720)
|
615 |
-
t_722 = torch.add(t_721, t_716)
|
616 |
-
t_723 = F.relu(t_722)
|
617 |
-
t_724 = self.n_Conv_158(t_723)
|
618 |
-
t_725 = self.n_Conv_159(t_723)
|
619 |
-
t_726 = F.relu(t_725)
|
620 |
-
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
621 |
-
t_727 = self.n_Conv_160(t_726_padded)
|
622 |
-
t_728 = F.relu(t_727)
|
623 |
-
t_729 = self.n_Conv_161(t_728)
|
624 |
-
t_730 = torch.add(t_729, t_724)
|
625 |
-
t_731 = F.relu(t_730)
|
626 |
-
t_732 = self.n_Conv_162(t_731)
|
627 |
-
t_733 = F.relu(t_732)
|
628 |
-
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
629 |
-
t_734 = self.n_Conv_163(t_733_padded)
|
630 |
-
t_735 = F.relu(t_734)
|
631 |
-
t_736 = self.n_Conv_164(t_735)
|
632 |
-
t_737 = torch.add(t_736, t_731)
|
633 |
-
t_738 = F.relu(t_737)
|
634 |
-
t_739 = self.n_Conv_165(t_738)
|
635 |
-
t_740 = F.relu(t_739)
|
636 |
-
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
637 |
-
t_741 = self.n_Conv_166(t_740_padded)
|
638 |
-
t_742 = F.relu(t_741)
|
639 |
-
t_743 = self.n_Conv_167(t_742)
|
640 |
-
t_744 = torch.add(t_743, t_738)
|
641 |
-
t_745 = F.relu(t_744)
|
642 |
-
t_746 = self.n_Conv_168(t_745)
|
643 |
-
t_747 = self.n_Conv_169(t_745)
|
644 |
-
t_748 = F.relu(t_747)
|
645 |
-
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
646 |
-
t_749 = self.n_Conv_170(t_748_padded)
|
647 |
-
t_750 = F.relu(t_749)
|
648 |
-
t_751 = self.n_Conv_171(t_750)
|
649 |
-
t_752 = torch.add(t_751, t_746)
|
650 |
-
t_753 = F.relu(t_752)
|
651 |
-
t_754 = self.n_Conv_172(t_753)
|
652 |
-
t_755 = F.relu(t_754)
|
653 |
-
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
654 |
-
t_756 = self.n_Conv_173(t_755_padded)
|
655 |
-
t_757 = F.relu(t_756)
|
656 |
-
t_758 = self.n_Conv_174(t_757)
|
657 |
-
t_759 = torch.add(t_758, t_753)
|
658 |
-
t_760 = F.relu(t_759)
|
659 |
-
t_761 = self.n_Conv_175(t_760)
|
660 |
-
t_762 = F.relu(t_761)
|
661 |
-
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
662 |
-
t_763 = self.n_Conv_176(t_762_padded)
|
663 |
-
t_764 = F.relu(t_763)
|
664 |
-
t_765 = self.n_Conv_177(t_764)
|
665 |
-
t_766 = torch.add(t_765, t_760)
|
666 |
-
t_767 = F.relu(t_766)
|
667 |
-
t_768 = self.n_Conv_178(t_767)
|
668 |
-
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
669 |
-
t_770 = torch.squeeze(t_769, 3)
|
670 |
-
t_770 = torch.squeeze(t_770, 2)
|
671 |
-
t_771 = torch.sigmoid(t_770)
|
672 |
-
return t_771
|
673 |
-
|
674 |
-
def load_state_dict(self, state_dict, **kwargs):
|
675 |
-
self.tags = state_dict.get('tags', [])
|
676 |
-
|
677 |
-
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
678 |
-
|
|
|
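Up to here the diff shows the copy of modules/deepbooru_model.py being removed; the re-added copy, reproduced below, is identical line for line as far as this diff shows. The class is a line-for-line PyTorch port of the DeepDanbooru tagger network: stages of bottleneck residual blocks (a 1x1 channel reduction, a padded 3x3 convolution, and a 1x1 expansion added back onto the shortcut, with a strided 1x1 projection wherever the spatial resolution drops), finished by a global average pool and a sigmoid over 9176 independent tag scores. The following is a minimal usage sketch, not part of the diff: it assumes the model-resnet_custom_v3.pt checkpoint from the TorchDeepDanbooru repository referenced in the file's header comment (whose state dict carries the tag list under a 'tags' key, which is exactly what load_state_dict below expects), and it assumes modules.devices is importable, i.e. that it runs inside this webui's environment; the example image name, the 448x448 input size, and the 0.5 cutoff are illustrative choices rather than values the class enforces.

import numpy as np
import torch
from PIL import Image

model = DeepDanbooruModel()
model.load_state_dict(torch.load('model-resnet_custom_v3.pt', map_location='cpu'))  # assumed checkpoint file
model.eval()

# forward() takes NHWC float input in [0, 1] and permutes it to NCHW itself
image = Image.open('example.jpg').convert('RGB').resize((448, 448))  # hypothetical input image
x = torch.from_numpy(np.array(image, dtype=np.float32)[None] / 255.0)

with torch.no_grad():
    probs = model(x)[0]  # one sigmoid score per entry of model.tags

for tag, score in zip(model.tags, probs.tolist()):
    if score >= 0.5:  # illustrative cutoff
        print(f'{tag}: {score:.3f}')

Because load_state_dict() strips the extra 'tags' entry before delegating to nn.Module, the tag vocabulary travels inside the checkpoint and stays aligned with the 9176-channel output of n_Conv_178.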
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules import devices

# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more


class DeepDanbooruModel(nn.Module):
    def __init__(self):
        super(DeepDanbooruModel, self).__init__()

        self.tags = []

        self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
        self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
        self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
        self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
        self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
        self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
        self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
        self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
        self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
        self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
        self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
        self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
        self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
        self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
        self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
        self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
        self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
        self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
        self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
        self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
        self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
        self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
        self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
        self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
        self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
        self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
        self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
        self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
        self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
        self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
        self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
        self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
        self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
        self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
        self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
        self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
        self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
        self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
        self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
        self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)

    def forward(self, *inputs):
        t_358, = inputs
        t_359 = t_358.permute(*[0, 3, 1, 2])
        t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
        t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
        t_361 = F.relu(t_360)
        t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
        t_362 = self.n_MaxPool_0(t_361)
        t_363 = self.n_Conv_1(t_362)
        t_364 = self.n_Conv_2(t_362)
        t_365 = F.relu(t_364)
        t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
        t_366 = self.n_Conv_3(t_365_padded)
        t_367 = F.relu(t_366)
        t_368 = self.n_Conv_4(t_367)
        t_369 = torch.add(t_368, t_363)
        t_370 = F.relu(t_369)
        t_371 = self.n_Conv_5(t_370)
        t_372 = F.relu(t_371)
        t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
        t_373 = self.n_Conv_6(t_372_padded)
        t_374 = F.relu(t_373)
        t_375 = self.n_Conv_7(t_374)
        t_376 = torch.add(t_375, t_370)
        t_377 = F.relu(t_376)
        t_378 = self.n_Conv_8(t_377)
        t_379 = F.relu(t_378)
        t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
        t_380 = self.n_Conv_9(t_379_padded)
        t_381 = F.relu(t_380)
        t_382 = self.n_Conv_10(t_381)
        t_383 = torch.add(t_382, t_377)
        t_384 = F.relu(t_383)
        t_385 = self.n_Conv_11(t_384)
        t_386 = self.n_Conv_12(t_384)
        t_387 = F.relu(t_386)
        t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
        t_388 = self.n_Conv_13(t_387_padded)
        t_389 = F.relu(t_388)
        t_390 = self.n_Conv_14(t_389)
        t_391 = torch.add(t_390, t_385)
        t_392 = F.relu(t_391)
        t_393 = self.n_Conv_15(t_392)
        t_394 = F.relu(t_393)
        t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
        t_395 = self.n_Conv_16(t_394_padded)
        t_396 = F.relu(t_395)
        t_397 = self.n_Conv_17(t_396)
        t_398 = torch.add(t_397, t_392)
        t_399 = F.relu(t_398)
        t_400 = self.n_Conv_18(t_399)
        t_401 = F.relu(t_400)
        t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
        t_402 = self.n_Conv_19(t_401_padded)
        t_403 = F.relu(t_402)
        t_404 = self.n_Conv_20(t_403)
        t_405 = torch.add(t_404, t_399)
        t_406 = F.relu(t_405)
        t_407 = self.n_Conv_21(t_406)
        t_408 = F.relu(t_407)
        t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
        t_409 = self.n_Conv_22(t_408_padded)
        t_410 = F.relu(t_409)
        t_411 = self.n_Conv_23(t_410)
        t_412 = torch.add(t_411, t_406)
        t_413 = F.relu(t_412)
        t_414 = self.n_Conv_24(t_413)
        t_415 = F.relu(t_414)
        t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
        t_416 = self.n_Conv_25(t_415_padded)
        t_417 = F.relu(t_416)
        t_418 = self.n_Conv_26(t_417)
        t_419 = torch.add(t_418, t_413)
        t_420 = F.relu(t_419)
        t_421 = self.n_Conv_27(t_420)
        t_422 = F.relu(t_421)
        t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
        t_423 = self.n_Conv_28(t_422_padded)
        t_424 = F.relu(t_423)
        t_425 = self.n_Conv_29(t_424)
        t_426 = torch.add(t_425, t_420)
        t_427 = F.relu(t_426)
        t_428 = self.n_Conv_30(t_427)
        t_429 = F.relu(t_428)
        t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
        t_430 = self.n_Conv_31(t_429_padded)
        t_431 = F.relu(t_430)
        t_432 = self.n_Conv_32(t_431)
        t_433 = torch.add(t_432, t_427)
        t_434 = F.relu(t_433)
        t_435 = self.n_Conv_33(t_434)
        t_436 = F.relu(t_435)
        t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
        t_437 = self.n_Conv_34(t_436_padded)
        t_438 = F.relu(t_437)
        t_439 = self.n_Conv_35(t_438)
        t_440 = torch.add(t_439, t_434)
        t_441 = F.relu(t_440)
        t_442 = self.n_Conv_36(t_441)
        t_443 = self.n_Conv_37(t_441)
        t_444 = F.relu(t_443)
        t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
        t_445 = self.n_Conv_38(t_444_padded)
        t_446 = F.relu(t_445)
        t_447 = self.n_Conv_39(t_446)
        t_448 = torch.add(t_447, t_442)
        t_449 = F.relu(t_448)
        t_450 = self.n_Conv_40(t_449)
        t_451 = F.relu(t_450)
        t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
        t_452 = self.n_Conv_41(t_451_padded)
        t_453 = F.relu(t_452)
        t_454 = self.n_Conv_42(t_453)
        t_455 = torch.add(t_454, t_449)
        t_456 = F.relu(t_455)
        t_457 = self.n_Conv_43(t_456)
        t_458 = F.relu(t_457)
        t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
        t_459 = self.n_Conv_44(t_458_padded)
        t_460 = F.relu(t_459)
        t_461 = self.n_Conv_45(t_460)
        t_462 = torch.add(t_461, t_456)
        t_463 = F.relu(t_462)
        t_464 = self.n_Conv_46(t_463)
        t_465 = F.relu(t_464)
        t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
        t_466 = self.n_Conv_47(t_465_padded)
        t_467 = F.relu(t_466)
        t_468 = self.n_Conv_48(t_467)
        t_469 = torch.add(t_468, t_463)
        t_470 = F.relu(t_469)
        t_471 = self.n_Conv_49(t_470)
        t_472 = F.relu(t_471)
        t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
        t_473 = self.n_Conv_50(t_472_padded)
        t_474 = F.relu(t_473)
        t_475 = self.n_Conv_51(t_474)
        t_476 = torch.add(t_475, t_470)
        t_477 = F.relu(t_476)
        t_478 = self.n_Conv_52(t_477)
        t_479 = F.relu(t_478)
        t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
        t_480 = self.n_Conv_53(t_479_padded)
        t_481 = F.relu(t_480)
        t_482 = self.n_Conv_54(t_481)
        t_483 = torch.add(t_482, t_477)
        t_484 = F.relu(t_483)
        t_485 = self.n_Conv_55(t_484)
        t_486 = F.relu(t_485)
        t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
        t_487 = self.n_Conv_56(t_486_padded)
        t_488 = F.relu(t_487)
        t_489 = self.n_Conv_57(t_488)
        t_490 = torch.add(t_489, t_484)
        t_491 = F.relu(t_490)
        t_492 = self.n_Conv_58(t_491)
        t_493 = F.relu(t_492)
        t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
        t_494 = self.n_Conv_59(t_493_padded)
        t_495 = F.relu(t_494)
        t_496 = self.n_Conv_60(t_495)
        t_497 = torch.add(t_496, t_491)
        t_498 = F.relu(t_497)
        t_499 = self.n_Conv_61(t_498)
        t_500 = F.relu(t_499)
        t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
        t_501 = self.n_Conv_62(t_500_padded)
        t_502 = F.relu(t_501)
        t_503 = self.n_Conv_63(t_502)
        t_504 = torch.add(t_503, t_498)
        t_505 = F.relu(t_504)
        t_506 = self.n_Conv_64(t_505)
        t_507 = F.relu(t_506)
        t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
        t_508 = self.n_Conv_65(t_507_padded)
        t_509 = F.relu(t_508)
        t_510 = self.n_Conv_66(t_509)
        t_511 = torch.add(t_510, t_505)
        t_512 = F.relu(t_511)
        t_513 = self.n_Conv_67(t_512)
        t_514 = F.relu(t_513)
        t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
        t_515 = self.n_Conv_68(t_514_padded)
        t_516 = F.relu(t_515)
        t_517 = self.n_Conv_69(t_516)
        t_518 = torch.add(t_517, t_512)
        t_519 = F.relu(t_518)
        t_520 = self.n_Conv_70(t_519)
        t_521 = F.relu(t_520)
        t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
        t_522 = self.n_Conv_71(t_521_padded)
        t_523 = F.relu(t_522)
        t_524 = self.n_Conv_72(t_523)
        t_525 = torch.add(t_524, t_519)
        t_526 = F.relu(t_525)
        t_527 = self.n_Conv_73(t_526)
        t_528 = F.relu(t_527)
        t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
        t_529 = self.n_Conv_74(t_528_padded)
        t_530 = F.relu(t_529)
        t_531 = self.n_Conv_75(t_530)
        t_532 = torch.add(t_531, t_526)
        t_533 = F.relu(t_532)
        t_534 = self.n_Conv_76(t_533)
        t_535 = F.relu(t_534)
        t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
        t_536 = self.n_Conv_77(t_535_padded)
        t_537 = F.relu(t_536)
        t_538 = self.n_Conv_78(t_537)
        t_539 = torch.add(t_538, t_533)
        t_540 = F.relu(t_539)
        t_541 = self.n_Conv_79(t_540)
        t_542 = F.relu(t_541)
        t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
        t_543 = self.n_Conv_80(t_542_padded)
        t_544 = F.relu(t_543)
        t_545 = self.n_Conv_81(t_544)
        t_546 = torch.add(t_545, t_540)
        t_547 = F.relu(t_546)
        t_548 = self.n_Conv_82(t_547)
        t_549 = F.relu(t_548)
        t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
        t_550 = self.n_Conv_83(t_549_padded)
        t_551 = F.relu(t_550)
        t_552 = self.n_Conv_84(t_551)
        t_553 = torch.add(t_552, t_547)
        t_554 = F.relu(t_553)
        t_555 = self.n_Conv_85(t_554)
        t_556 = F.relu(t_555)
        t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
        t_557 = self.n_Conv_86(t_556_padded)
        t_558 = F.relu(t_557)
        t_559 = self.n_Conv_87(t_558)
        t_560 = torch.add(t_559, t_554)
        t_561 = F.relu(t_560)
        t_562 = self.n_Conv_88(t_561)
        t_563 = F.relu(t_562)
        t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
        t_564 = self.n_Conv_89(t_563_padded)
        t_565 = F.relu(t_564)
        t_566 = self.n_Conv_90(t_565)
        t_567 = torch.add(t_566, t_561)
        t_568 = F.relu(t_567)
        t_569 = self.n_Conv_91(t_568)
        t_570 = F.relu(t_569)
        t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
        t_571 = self.n_Conv_92(t_570_padded)
        t_572 = F.relu(t_571)
        t_573 = self.n_Conv_93(t_572)
        t_574 = torch.add(t_573, t_568)
        t_575 = F.relu(t_574)
        t_576 = self.n_Conv_94(t_575)
        t_577 = F.relu(t_576)
        t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
        t_578 = self.n_Conv_95(t_577_padded)
        t_579 = F.relu(t_578)
        t_580 = self.n_Conv_96(t_579)
        t_581 = torch.add(t_580, t_575)
        t_582 = F.relu(t_581)
        t_583 = self.n_Conv_97(t_582)
        t_584 = F.relu(t_583)
        t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
        t_585 = self.n_Conv_98(t_584_padded)
        t_586 = F.relu(t_585)
        t_587 = self.n_Conv_99(t_586)
        t_588 = self.n_Conv_100(t_582)
        t_589 = torch.add(t_587, t_588)
        t_590 = F.relu(t_589)
        t_591 = self.n_Conv_101(t_590)
        t_592 = F.relu(t_591)
        t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
        t_593 = self.n_Conv_102(t_592_padded)
        t_594 = F.relu(t_593)
        t_595 = self.n_Conv_103(t_594)
|
471 |
+
t_596 = torch.add(t_595, t_590)
|
472 |
+
t_597 = F.relu(t_596)
|
473 |
+
t_598 = self.n_Conv_104(t_597)
|
474 |
+
t_599 = F.relu(t_598)
|
475 |
+
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
476 |
+
t_600 = self.n_Conv_105(t_599_padded)
|
477 |
+
t_601 = F.relu(t_600)
|
478 |
+
t_602 = self.n_Conv_106(t_601)
|
479 |
+
t_603 = torch.add(t_602, t_597)
|
480 |
+
t_604 = F.relu(t_603)
|
481 |
+
t_605 = self.n_Conv_107(t_604)
|
482 |
+
t_606 = F.relu(t_605)
|
483 |
+
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
484 |
+
t_607 = self.n_Conv_108(t_606_padded)
|
485 |
+
t_608 = F.relu(t_607)
|
486 |
+
t_609 = self.n_Conv_109(t_608)
|
487 |
+
t_610 = torch.add(t_609, t_604)
|
488 |
+
t_611 = F.relu(t_610)
|
489 |
+
t_612 = self.n_Conv_110(t_611)
|
490 |
+
t_613 = F.relu(t_612)
|
491 |
+
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
492 |
+
t_614 = self.n_Conv_111(t_613_padded)
|
493 |
+
t_615 = F.relu(t_614)
|
494 |
+
t_616 = self.n_Conv_112(t_615)
|
495 |
+
t_617 = torch.add(t_616, t_611)
|
496 |
+
t_618 = F.relu(t_617)
|
497 |
+
t_619 = self.n_Conv_113(t_618)
|
498 |
+
t_620 = F.relu(t_619)
|
499 |
+
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
500 |
+
t_621 = self.n_Conv_114(t_620_padded)
|
501 |
+
t_622 = F.relu(t_621)
|
502 |
+
t_623 = self.n_Conv_115(t_622)
|
503 |
+
t_624 = torch.add(t_623, t_618)
|
504 |
+
t_625 = F.relu(t_624)
|
505 |
+
t_626 = self.n_Conv_116(t_625)
|
506 |
+
t_627 = F.relu(t_626)
|
507 |
+
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
508 |
+
t_628 = self.n_Conv_117(t_627_padded)
|
509 |
+
t_629 = F.relu(t_628)
|
510 |
+
t_630 = self.n_Conv_118(t_629)
|
511 |
+
t_631 = torch.add(t_630, t_625)
|
512 |
+
t_632 = F.relu(t_631)
|
513 |
+
t_633 = self.n_Conv_119(t_632)
|
514 |
+
t_634 = F.relu(t_633)
|
515 |
+
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
516 |
+
t_635 = self.n_Conv_120(t_634_padded)
|
517 |
+
t_636 = F.relu(t_635)
|
518 |
+
t_637 = self.n_Conv_121(t_636)
|
519 |
+
t_638 = torch.add(t_637, t_632)
|
520 |
+
t_639 = F.relu(t_638)
|
521 |
+
t_640 = self.n_Conv_122(t_639)
|
522 |
+
t_641 = F.relu(t_640)
|
523 |
+
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
524 |
+
t_642 = self.n_Conv_123(t_641_padded)
|
525 |
+
t_643 = F.relu(t_642)
|
526 |
+
t_644 = self.n_Conv_124(t_643)
|
527 |
+
t_645 = torch.add(t_644, t_639)
|
528 |
+
t_646 = F.relu(t_645)
|
529 |
+
t_647 = self.n_Conv_125(t_646)
|
530 |
+
t_648 = F.relu(t_647)
|
531 |
+
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
532 |
+
t_649 = self.n_Conv_126(t_648_padded)
|
533 |
+
t_650 = F.relu(t_649)
|
534 |
+
t_651 = self.n_Conv_127(t_650)
|
535 |
+
t_652 = torch.add(t_651, t_646)
|
536 |
+
t_653 = F.relu(t_652)
|
537 |
+
t_654 = self.n_Conv_128(t_653)
|
538 |
+
t_655 = F.relu(t_654)
|
539 |
+
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
540 |
+
t_656 = self.n_Conv_129(t_655_padded)
|
541 |
+
t_657 = F.relu(t_656)
|
542 |
+
t_658 = self.n_Conv_130(t_657)
|
543 |
+
t_659 = torch.add(t_658, t_653)
|
544 |
+
t_660 = F.relu(t_659)
|
545 |
+
t_661 = self.n_Conv_131(t_660)
|
546 |
+
t_662 = F.relu(t_661)
|
547 |
+
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
548 |
+
t_663 = self.n_Conv_132(t_662_padded)
|
549 |
+
t_664 = F.relu(t_663)
|
550 |
+
t_665 = self.n_Conv_133(t_664)
|
551 |
+
t_666 = torch.add(t_665, t_660)
|
552 |
+
t_667 = F.relu(t_666)
|
553 |
+
t_668 = self.n_Conv_134(t_667)
|
554 |
+
t_669 = F.relu(t_668)
|
555 |
+
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
556 |
+
t_670 = self.n_Conv_135(t_669_padded)
|
557 |
+
t_671 = F.relu(t_670)
|
558 |
+
t_672 = self.n_Conv_136(t_671)
|
559 |
+
t_673 = torch.add(t_672, t_667)
|
560 |
+
t_674 = F.relu(t_673)
|
561 |
+
t_675 = self.n_Conv_137(t_674)
|
562 |
+
t_676 = F.relu(t_675)
|
563 |
+
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
564 |
+
t_677 = self.n_Conv_138(t_676_padded)
|
565 |
+
t_678 = F.relu(t_677)
|
566 |
+
t_679 = self.n_Conv_139(t_678)
|
567 |
+
t_680 = torch.add(t_679, t_674)
|
568 |
+
t_681 = F.relu(t_680)
|
569 |
+
t_682 = self.n_Conv_140(t_681)
|
570 |
+
t_683 = F.relu(t_682)
|
571 |
+
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
572 |
+
t_684 = self.n_Conv_141(t_683_padded)
|
573 |
+
t_685 = F.relu(t_684)
|
574 |
+
t_686 = self.n_Conv_142(t_685)
|
575 |
+
t_687 = torch.add(t_686, t_681)
|
576 |
+
t_688 = F.relu(t_687)
|
577 |
+
t_689 = self.n_Conv_143(t_688)
|
578 |
+
t_690 = F.relu(t_689)
|
579 |
+
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
580 |
+
t_691 = self.n_Conv_144(t_690_padded)
|
581 |
+
t_692 = F.relu(t_691)
|
582 |
+
t_693 = self.n_Conv_145(t_692)
|
583 |
+
t_694 = torch.add(t_693, t_688)
|
584 |
+
t_695 = F.relu(t_694)
|
585 |
+
t_696 = self.n_Conv_146(t_695)
|
586 |
+
t_697 = F.relu(t_696)
|
587 |
+
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
588 |
+
t_698 = self.n_Conv_147(t_697_padded)
|
589 |
+
t_699 = F.relu(t_698)
|
590 |
+
t_700 = self.n_Conv_148(t_699)
|
591 |
+
t_701 = torch.add(t_700, t_695)
|
592 |
+
t_702 = F.relu(t_701)
|
593 |
+
t_703 = self.n_Conv_149(t_702)
|
594 |
+
t_704 = F.relu(t_703)
|
595 |
+
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
596 |
+
t_705 = self.n_Conv_150(t_704_padded)
|
597 |
+
t_706 = F.relu(t_705)
|
598 |
+
t_707 = self.n_Conv_151(t_706)
|
599 |
+
t_708 = torch.add(t_707, t_702)
|
600 |
+
t_709 = F.relu(t_708)
|
601 |
+
t_710 = self.n_Conv_152(t_709)
|
602 |
+
t_711 = F.relu(t_710)
|
603 |
+
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
604 |
+
t_712 = self.n_Conv_153(t_711_padded)
|
605 |
+
t_713 = F.relu(t_712)
|
606 |
+
t_714 = self.n_Conv_154(t_713)
|
607 |
+
t_715 = torch.add(t_714, t_709)
|
608 |
+
t_716 = F.relu(t_715)
|
609 |
+
t_717 = self.n_Conv_155(t_716)
|
610 |
+
t_718 = F.relu(t_717)
|
611 |
+
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
612 |
+
t_719 = self.n_Conv_156(t_718_padded)
|
613 |
+
t_720 = F.relu(t_719)
|
614 |
+
t_721 = self.n_Conv_157(t_720)
|
615 |
+
t_722 = torch.add(t_721, t_716)
|
616 |
+
t_723 = F.relu(t_722)
|
617 |
+
t_724 = self.n_Conv_158(t_723)
|
618 |
+
t_725 = self.n_Conv_159(t_723)
|
619 |
+
t_726 = F.relu(t_725)
|
620 |
+
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
621 |
+
t_727 = self.n_Conv_160(t_726_padded)
|
622 |
+
t_728 = F.relu(t_727)
|
623 |
+
t_729 = self.n_Conv_161(t_728)
|
624 |
+
t_730 = torch.add(t_729, t_724)
|
625 |
+
t_731 = F.relu(t_730)
|
626 |
+
t_732 = self.n_Conv_162(t_731)
|
627 |
+
t_733 = F.relu(t_732)
|
628 |
+
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
629 |
+
t_734 = self.n_Conv_163(t_733_padded)
|
630 |
+
t_735 = F.relu(t_734)
|
631 |
+
t_736 = self.n_Conv_164(t_735)
|
632 |
+
t_737 = torch.add(t_736, t_731)
|
633 |
+
t_738 = F.relu(t_737)
|
634 |
+
t_739 = self.n_Conv_165(t_738)
|
635 |
+
t_740 = F.relu(t_739)
|
636 |
+
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
637 |
+
t_741 = self.n_Conv_166(t_740_padded)
|
638 |
+
t_742 = F.relu(t_741)
|
639 |
+
t_743 = self.n_Conv_167(t_742)
|
640 |
+
t_744 = torch.add(t_743, t_738)
|
641 |
+
t_745 = F.relu(t_744)
|
642 |
+
t_746 = self.n_Conv_168(t_745)
|
643 |
+
t_747 = self.n_Conv_169(t_745)
|
644 |
+
t_748 = F.relu(t_747)
|
645 |
+
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
646 |
+
t_749 = self.n_Conv_170(t_748_padded)
|
647 |
+
t_750 = F.relu(t_749)
|
648 |
+
t_751 = self.n_Conv_171(t_750)
|
649 |
+
t_752 = torch.add(t_751, t_746)
|
650 |
+
t_753 = F.relu(t_752)
|
651 |
+
t_754 = self.n_Conv_172(t_753)
|
652 |
+
t_755 = F.relu(t_754)
|
653 |
+
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
654 |
+
t_756 = self.n_Conv_173(t_755_padded)
|
655 |
+
t_757 = F.relu(t_756)
|
656 |
+
t_758 = self.n_Conv_174(t_757)
|
657 |
+
t_759 = torch.add(t_758, t_753)
|
658 |
+
t_760 = F.relu(t_759)
|
659 |
+
t_761 = self.n_Conv_175(t_760)
|
660 |
+
t_762 = F.relu(t_761)
|
661 |
+
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
662 |
+
t_763 = self.n_Conv_176(t_762_padded)
|
663 |
+
t_764 = F.relu(t_763)
|
664 |
+
t_765 = self.n_Conv_177(t_764)
|
665 |
+
t_766 = torch.add(t_765, t_760)
|
666 |
+
t_767 = F.relu(t_766)
|
667 |
+
t_768 = self.n_Conv_178(t_767)
|
668 |
+
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
669 |
+
t_770 = torch.squeeze(t_769, 3)
|
670 |
+
t_770 = torch.squeeze(t_770, 2)
|
671 |
+
t_771 = torch.sigmoid(t_770)
|
672 |
+
return t_771
|
673 |
+
|
674 |
+
def load_state_dict(self, state_dict, **kwargs):
|
675 |
+
self.tags = state_dict.get('tags', [])
|
676 |
+
|
677 |
+
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
678 |
+
|
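A minimal usage sketch (editorial addition, not part of the diff): tagging one image with the model above. It assumes a DeepDanbooru checkpoint whose state dict carries the 'tags' list, as load_state_dict expects, and that the network takes an NHWC float batch in [0, 1]; the checkpoint filename, image name, and 0.5 threshold are illustrative.

import numpy as np
import torch
from PIL import Image
from modules.deepbooru_model import DeepDanbooruModel

model = DeepDanbooruModel()
model.load_state_dict(torch.load("model-resnet_custom_v3.pt", map_location="cpu"))
model.eval()

pic = Image.open("example.png").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.expand_dims(np.array(pic, dtype=np.float32) / 255.0, 0))

with torch.no_grad():
    probs = model(x)[0]  # sigmoid outputs, one probability per tag

for tag, p in zip(model.tags, probs):
    if p >= 0.5:  # illustrative threshold
        print(f"{tag}: {p:.2f}")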
sd/stable-diffusion-webui/modules/errors.py
CHANGED
@@ -1,43 +1,43 @@
import sys
import traceback


def print_error_explanation(message):
    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])

    print('=' * max_len, file=sys.stderr)
    for line in lines:
        print(line, file=sys.stderr)
    print('=' * max_len, file=sys.stderr)


def display(e: Exception, task):
    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
        print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
        """)


already_displayed = {}


def display_once(e: Exception, task):
    if task in already_displayed:
        return

    display(e, task)

    already_displayed[task] = 1


def run(code, task):
    try:
        code()
    except Exception as e:
        display(e, task)  # display() takes the exception first, then the task description
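A short usage sketch (editorial, not from the diff) showing how the helpers above are meant to be called; the failing function and task names are made up.

from modules import errors

def flaky_startup_step():
    raise RuntimeError("boom")

errors.run(flaky_startup_step, "running example startup step")  # logs the task name, exception type and traceback
errors.display_once(ValueError("bad value"), "parsing config")  # reported on the first call...
errors.display_once(ValueError("bad value"), "parsing config")  # ...and silently skipped afterwards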
sd/stable-diffusion-webui/modules/esrgan_model.py
CHANGED
@@ -1,233 +1,233 @@
import os

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts


def mod2normal(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' in state_dict:
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if 'RDB' in k:
                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
        crt_net['model.3.weight'] = state_dict['upconv1.weight']
        crt_net['model.3.bias'] = state_dict['upconv1.bias']
        crt_net['model.6.weight'] = state_dict['upconv2.weight']
        crt_net['model.6.bias'] = state_dict['upconv2.bias']
        crt_net['model.8.weight'] = state_dict['HRconv.weight']
        crt_net['model.8.bias'] = state_dict['HRconv.bias']
        crt_net['model.10.weight'] = state_dict['conv_last.weight']
        crt_net['model.10.bias'] = state_dict['conv_last.bias']
        state_dict = crt_net
    return state_dict


def resrgan2normal(state_dict, nb=23):
    # this code is copied from https://github.com/victorca25/iNNfer
    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
        re8x = 0
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']

        for k in items.copy():
            if "rdb" in k:
                ori_k = k.replace('body.', 'model.1.sub.')
                ori_k = ori_k.replace('.rdb', '.RDB')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)

        crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
        crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
        crt_net['model.3.weight'] = state_dict['conv_up1.weight']
        crt_net['model.3.bias'] = state_dict['conv_up1.bias']
        crt_net['model.6.weight'] = state_dict['conv_up2.weight']
        crt_net['model.6.bias'] = state_dict['conv_up2.bias']

        if 'conv_up3.weight' in state_dict:
            # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
            re8x = 3
            crt_net['model.9.weight'] = state_dict['conv_up3.weight']
            crt_net['model.9.bias'] = state_dict['conv_up3.bias']

        crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
        crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
        crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
        crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']

        state_dict = crt_net
    return state_dict


def infer_params(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    plus = False

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if (part_num > scalemin
                    and parts[0] == "model"
                    and parts[2] == "weight"):
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        if not plus and "conv1x1" in block:
            plus = True

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2 ** scale2x

    return in_nc, out_nc, nf, nb, plus, scale


class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        if len(model_paths) == 0:
            # no local models: register the downloadable default scaler
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            self.scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

            scaler_data = UpscalerData(name, file, self, 4)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):  # test for None before touching the filesystem
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None

        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

        if "params_ema" in state_dict:
            state_dict = state_dict["params_ema"]
        elif "params" in state_dict:
            state_dict = state_dict["params"]
            num_conv = 16 if "realesr-animevideov3" in filename else 32
            model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
            model.load_state_dict(state_dict)
            model.eval()
            return model

        if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
            nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
            state_dict = resrgan2normal(state_dict, nb)
        elif "conv_first.weight" in state_dict:
            state_dict = mod2normal(state_dict)
        elif "model.0.weight" not in state_dict:
            raise Exception("The file is not a recognized ESRGAN model.")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)

        model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
        model.load_state_dict(state_dict)
        model.eval()

        return model


def upscale_without_tiling(model, img):
    img = np.array(img)
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)
    with torch.no_grad():
        output = model(img)
    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)
    output = output.astype(np.uint8)
    output = output[:, :, ::-1]
    return Image.fromarray(output, 'RGB')


def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

    for y, h, row in grid.tiles:
        newrow = []
        for tiledata in row:
            x, w, tile = tiledata

            output = upscale_without_tiling(model, tile)
            scale_factor = output.width // tile.width

            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output
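A usage sketch (editorial): running one image through the upscaler above. It assumes an initialized webui environment (shared.opts with the ESRGAN_tile settings and devices.device_esrgan must exist) and that UpscalerData exposes a data_path attribute; the directory and file names are illustrative.

from PIL import Image
from modules.esrgan_model import UpscalerESRGAN

upscaler = UpscalerESRGAN("models/ESRGAN")  # hypothetical model directory
img = Image.open("input.png").convert("RGB")
# data_path is either a local .pth file or the download URL of the default model
result = upscaler.do_upscale(img, upscaler.scalers[0].data_path)
result.save("input_4x.png")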
sd/stable-diffusion-webui/modules/esrgan_model_arch.py
CHANGED
@@ -1,464 +1,464 @@
|
|
1 |
-
# this file is adapted from https://github.com/victorca25/iNNfer
|
2 |
-
|
3 |
-
from collections import OrderedDict
|
4 |
-
import math
|
5 |
-
import functools
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
|
10 |
-
|
11 |
-
####################
|
12 |
-
# RRDBNet Generator
|
13 |
-
####################
|
14 |
-
|
15 |
-
class RRDBNet(nn.Module):
|
16 |
-
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
17 |
-
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
18 |
-
finalact=None, gaussian_noise=False, plus=False):
|
19 |
-
super(RRDBNet, self).__init__()
|
20 |
-
n_upscale = int(math.log(upscale, 2))
|
21 |
-
if upscale == 3:
|
22 |
-
n_upscale = 1
|
23 |
-
|
24 |
-
self.resrgan_scale = 0
|
25 |
-
if in_nc % 16 == 0:
|
26 |
-
self.resrgan_scale = 1
|
27 |
-
elif in_nc != 4 and in_nc % 4 == 0:
|
28 |
-
self.resrgan_scale = 2
|
29 |
-
|
30 |
-
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
31 |
-
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
32 |
-
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
33 |
-
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
34 |
-
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
35 |
-
|
36 |
-
if upsample_mode == 'upconv':
|
37 |
-
upsample_block = upconv_block
|
38 |
-
elif upsample_mode == 'pixelshuffle':
|
39 |
-
upsample_block = pixelshuffle_block
|
40 |
-
else:
|
41 |
-
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
42 |
-
if upscale == 3:
|
43 |
-
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
44 |
-
else:
|
45 |
-
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
46 |
-
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
47 |
-
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
48 |
-
|
49 |
-
outact = act(finalact) if finalact else None
|
50 |
-
|
51 |
-
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
52 |
-
*upsampler, HR_conv0, HR_conv1, outact)
|
53 |
-
|
54 |
-
def forward(self, x, outm=None):
|
55 |
-
if self.resrgan_scale == 1:
|
56 |
-
feat = pixel_unshuffle(x, scale=4)
|
57 |
-
elif self.resrgan_scale == 2:
|
58 |
-
feat = pixel_unshuffle(x, scale=2)
|
59 |
-
else:
|
60 |
-
feat = x
|
61 |
-
|
62 |
-
return self.model(feat)
|
63 |
-
|
64 |
-
|
65 |
-
class RRDB(nn.Module):
|
66 |
-
"""
|
67 |
-
Residual in Residual Dense Block
|
68 |
-
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
69 |
-
"""
|
70 |
-
|
71 |
-
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
72 |
-
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
73 |
-
spectral_norm=False, gaussian_noise=False, plus=False):
|
74 |
-
super(RRDB, self).__init__()
|
75 |
-
# This is for backwards compatibility with existing models
|
76 |
-
if nr == 3:
|
77 |
-
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
78 |
-
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
79 |
-
gaussian_noise=gaussian_noise, plus=plus)
|
80 |
-
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
81 |
-
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
82 |
-
gaussian_noise=gaussian_noise, plus=plus)
|
83 |
-
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
84 |
-
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
85 |
-
gaussian_noise=gaussian_noise, plus=plus)
|
86 |
-
else:
|
87 |
-
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
88 |
-
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
89 |
-
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
90 |
-
self.RDBs = nn.Sequential(*RDB_list)
|
91 |
-
|
92 |
-
def forward(self, x):
|
93 |
-
if hasattr(self, 'RDB1'):
|
94 |
-
out = self.RDB1(x)
|
95 |
-
out = self.RDB2(out)
|
96 |
-
out = self.RDB3(out)
|
97 |
-
else:
|
98 |
-
out = self.RDBs(x)
|
99 |
-
return out * 0.2 + x
|
100 |
-
|
101 |
-
|
102 |
-
class ResidualDenseBlock_5C(nn.Module):
|
103 |
-
"""
|
104 |
-
Residual Dense Block
|
105 |
-
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
106 |
-
Modified options that can be used:
|
107 |
-
- "Partial Convolution based Padding" arXiv:1811.11718
|
108 |
-
- "Spectral normalization" arXiv:1802.05957
|
109 |
-
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
110 |
-
{Rakotonirina} and A. {Rasoanaivo}
|
111 |
-
"""
|
112 |
-
|
113 |
-
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
114 |
-
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
115 |
-
spectral_norm=False, gaussian_noise=False, plus=False):
|
116 |
-
super(ResidualDenseBlock_5C, self).__init__()
|
117 |
-
|
118 |
-
self.noise = GaussianNoise() if gaussian_noise else None
|
119 |
-
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
120 |
-
|
121 |
-
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
122 |
-
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
123 |
-
spectral_norm=spectral_norm)
|
124 |
-
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
125 |
-
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
126 |
-
spectral_norm=spectral_norm)
|
127 |
-
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
128 |
-
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
129 |
-
spectral_norm=spectral_norm)
|
130 |
-
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
131 |
-
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
132 |
-
spectral_norm=spectral_norm)
|
133 |
-
if mode == 'CNA':
|
134 |
-
last_act = None
|
135 |
-
else:
|
136 |
-
last_act = act_type
|
137 |
-
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
138 |
-
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
139 |
-
spectral_norm=spectral_norm)
|
140 |
-
|
141 |
-
def forward(self, x):
|
142 |
-
x1 = self.conv1(x)
|
143 |
-
x2 = self.conv2(torch.cat((x, x1), 1))
|
144 |
-
if self.conv1x1:
|
145 |
-
x2 = x2 + self.conv1x1(x)
|
146 |
-
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
147 |
-
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
148 |
-
if self.conv1x1:
|
149 |
-
x4 = x4 + x2
|
150 |
-
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
151 |
-
if self.noise:
|
152 |
-
return self.noise(x5.mul(0.2) + x)
|
153 |
-
else:
|
154 |
-
return x5 * 0.2 + x
|
155 |
-
|
156 |
-
|
157 |
-
####################
|
158 |
-
# ESRGANplus
|
159 |
-
####################
|
160 |
-
|
161 |
-
class GaussianNoise(nn.Module):
|
162 |
-
def __init__(self, sigma=0.1, is_relative_detach=False):
|
163 |
-
super().__init__()
|
164 |
-
self.sigma = sigma
|
165 |
-
self.is_relative_detach = is_relative_detach
|
166 |
-
self.noise = torch.tensor(0, dtype=torch.float)
|
167 |
-
|
168 |
-
def forward(self, x):
|
169 |
-
if self.training and self.sigma != 0:
|
170 |
-
self.noise = self.noise.to(x.device)
|
171 |
-
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
172 |
-
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
173 |
-
x = x + sampled_noise
|
174 |
-
return x
|
175 |
-
|
176 |
-
def conv1x1(in_planes, out_planes, stride=1):
|
177 |
-
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
178 |
-
|
179 |
-
|
180 |
-
####################
|
181 |
-
# SRVGGNetCompact
|
182 |
-
####################
|
183 |
-
|
184 |
-
class SRVGGNetCompact(nn.Module):
|
185 |
-
"""A compact VGG-style network structure for super-resolution.
|
186 |
-
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
187 |
-
"""
|
188 |
-
|
189 |
-
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
190 |
-
super(SRVGGNetCompact, self).__init__()
|
191 |
-
self.num_in_ch = num_in_ch
|
192 |
-
self.num_out_ch = num_out_ch
|
193 |
-
self.num_feat = num_feat
|
194 |
-
self.num_conv = num_conv
|
195 |
-
self.upscale = upscale
|
196 |
-
self.act_type = act_type
|
197 |
-
|
198 |
-
self.body = nn.ModuleList()
|
199 |
-
# the first conv
|
200 |
-
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
201 |
-
# the first activation
|
202 |
-
if act_type == 'relu':
|
203 |
-
activation = nn.ReLU(inplace=True)
|
204 |
-
elif act_type == 'prelu':
|
205 |
-
activation = nn.PReLU(num_parameters=num_feat)
|
206 |
-
elif act_type == 'leakyrelu':
|
207 |
-
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
208 |
-
self.body.append(activation)
|
209 |
-
|
210 |
-
# the body structure
|
211 |
-
for _ in range(num_conv):
|
212 |
-
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
213 |
-
# activation
|
214 |
-
if act_type == 'relu':
|
215 |
-
activation = nn.ReLU(inplace=True)
|
216 |
-
elif act_type == 'prelu':
|
217 |
-
activation = nn.PReLU(num_parameters=num_feat)
|
218 |
-
elif act_type == 'leakyrelu':
|
219 |
-
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
220 |
-
self.body.append(activation)
|
221 |
-
|
222 |
-
# the last conv
|
223 |
-
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
224 |
-
# upsample
|
225 |
-
self.upsampler = nn.PixelShuffle(upscale)
|
226 |
-
|
227 |
-
def forward(self, x):
|
228 |
-
out = x
|
229 |
-
for i in range(0, len(self.body)):
|
230 |
-
out = self.body[i](out)
|
231 |
-
|
232 |
-
out = self.upsampler(out)
|
233 |
-
# add the nearest upsampled image, so that the network learns the residual
|
234 |
-
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
235 |
-
out += base
|
236 |
-
return out
|
237 |
-
|
238 |
-
|
239 |
-
####################
|
240 |
-
# Upsampler
|
241 |
-
####################
|
242 |
-
|
243 |
-
class Upsample(nn.Module):
|
244 |
-
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
245 |
-
The input data is assumed to be of the form
|
246 |
-
`minibatch x channels x [optional depth] x [optional height] x width`.
|
247 |
-
"""
|
248 |
-
|
249 |
-
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
250 |
-
super(Upsample, self).__init__()
|
251 |
-
if isinstance(scale_factor, tuple):
|
252 |
-
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
253 |
-
else:
|
254 |
-
self.scale_factor = float(scale_factor) if scale_factor else None
|
255 |
-
self.mode = mode
|
256 |
-
self.size = size
|
257 |
-
self.align_corners = align_corners
|
258 |
-
|
259 |
-
def forward(self, x):
|
260 |
-
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
261 |
-
|
262 |
-
def extra_repr(self):
|
263 |
-
if self.scale_factor is not None:
|
264 |
-
info = 'scale_factor=' + str(self.scale_factor)
|
265 |
-
else:
|
266 |
-
info = 'size=' + str(self.size)
|
267 |
-
info += ', mode=' + self.mode
|
268 |
-
return info
|
269 |
-
|
270 |
-
|
271 |
-
def pixel_unshuffle(x, scale):
|
272 |
-
""" Pixel unshuffle.
|
273 |
-
Args:
|
274 |
-
x (Tensor): Input feature with shape (b, c, hh, hw).
|
275 |
-
scale (int): Downsample ratio.
|
276 |
-
Returns:
|
277 |
-
Tensor: the pixel unshuffled feature.
|
278 |
-
"""
|
279 |
-
b, c, hh, hw = x.size()
|
280 |
-
out_channel = c * (scale**2)
|
281 |
-
assert hh % scale == 0 and hw % scale == 0
|
282 |
-
h = hh // scale
|
283 |
-
w = hw // scale
|
284 |
-
x_view = x.view(b, c, h, scale, w, scale)
|
285 |
-
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
286 |
-
|
287 |
-
|
288 |
-
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
289 |
-
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
290 |
-
"""
|
291 |
-
Pixel shuffle layer
|
292 |
-
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
293 |
-
Neural Network, CVPR17)
|
294 |
-
"""
|
295 |
-
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
296 |
-
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
297 |
-
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
298 |
-
|
299 |
-
n = norm(norm_type, out_nc) if norm_type else None
|
300 |
-
a = act(act_type) if act_type else None
|
301 |
-
return sequential(conv, pixel_shuffle, n, a)
|
302 |
-
|
303 |
-
|
304 |
-
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
305 |
-
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
306 |
-
""" Upconv layer """
|
307 |
-
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
308 |
-
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
309 |
-
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
310 |
-
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
311 |
-
return sequential(upsample, conv)
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
####################
|
321 |
-
# Basic blocks
|
322 |
-
####################
|
323 |
-
|
324 |
-
|
325 |
-
def make_layer(basic_block, num_basic_block, **kwarg):
|
326 |
-
"""Make layers by stacking the same blocks.
|
327 |
-
Args:
|
328 |
-
basic_block (nn.module): nn.module class for basic block. (block)
|
329 |
-
num_basic_block (int): number of blocks. (n_layers)
|
330 |
-
Returns:
|
331 |
-
nn.Sequential: Stacked blocks in nn.Sequential.
|
332 |
-
"""
|
333 |
-
layers = []
|
334 |
-
for _ in range(num_basic_block):
|
335 |
-
layers.append(basic_block(**kwarg))
|
336 |
-
return nn.Sequential(*layers)
|
337 |
-
|
338 |
-
|
339 |
-
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
340 |
-
""" activation helper """
|
341 |
-
act_type = act_type.lower()
|
342 |
-
if act_type == 'relu':
|
343 |
-
layer = nn.ReLU(inplace)
|
344 |
-
elif act_type in ('leakyrelu', 'lrelu'):
|
345 |
-
layer = nn.LeakyReLU(neg_slope, inplace)
|
346 |
-
elif act_type == 'prelu':
|
347 |
-
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
348 |
-
elif act_type == 'tanh': # [-1, 1] range output
|
349 |
-
layer = nn.Tanh()
|
350 |
-
elif act_type == 'sigmoid': # [0, 1] range output
|
351 |
-
layer = nn.Sigmoid()
|
352 |
-
else:
|
353 |
-
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
354 |
-
return layer
|
355 |
-
|
356 |
-
|
357 |
-
class Identity(nn.Module):
|
358 |
-
def __init__(self, *kwargs):
|
359 |
-
super(Identity, self).__init__()
|
360 |
-
|
361 |
-
def forward(self, x, *kwargs):
|
362 |
-
return x
|
363 |
-
|
364 |
-
|
365 |
-
def norm(norm_type, nc):
|
366 |
-
""" Return a normalization layer """
|
367 |
-
norm_type = norm_type.lower()
|
368 |
-
if norm_type == 'batch':
|
369 |
-
layer = nn.BatchNorm2d(nc, affine=True)
|
370 |
-
elif norm_type == 'instance':
|
371 |
-
layer = nn.InstanceNorm2d(nc, affine=False)
|
372 |
-
elif norm_type == 'none':
|
373 |
-
def norm_layer(x): return Identity()
|
374 |
-
else:
|
375 |
-
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
376 |
-
return layer
|
377 |
-
|
378 |
-
|
379 |
-
def pad(pad_type, padding):
|
380 |
-
""" padding layer helper """
|
381 |
-
pad_type = pad_type.lower()
|
382 |
-
if padding == 0:
|
383 |
-
return None
|
384 |
-
if pad_type == 'reflect':
|
385 |
-
layer = nn.ReflectionPad2d(padding)
|
386 |
-
elif pad_type == 'replicate':
|
387 |
-
layer = nn.ReplicationPad2d(padding)
|
388 |
-
elif pad_type == 'zero':
|
389 |
-
layer = nn.ZeroPad2d(padding)
|
390 |
-
else:
|
391 |
-
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
392 |
-
return layer
|
393 |
-
|
394 |
-
|
395 |
-
def get_valid_padding(kernel_size, dilation):
|
396 |
-
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
397 |
-
padding = (kernel_size - 1) // 2
|
398 |
-
return padding
|
399 |
-
|
400 |
-
|
401 |
-
class ShortcutBlock(nn.Module):
|
402 |
-
""" Elementwise sum the output of a submodule to its input """
|
403 |
-
def __init__(self, submodule):
|
404 |
-
super(ShortcutBlock, self).__init__()
|
405 |
-
self.sub = submodule
|
406 |
-
|
407 |
-
def forward(self, x):
|
408 |
-
output = x + self.sub(x)
|
409 |
-
return output
|
410 |
-
|
411 |
-
def __repr__(self):
|
412 |
-
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
413 |
-
|
414 |
-
|
415 |
-
def sequential(*args):
|
416 |
-
""" Flatten Sequential. It unwraps nn.Sequential. """
|
417 |
-
if len(args) == 1:
|
418 |
-
if isinstance(args[0], OrderedDict):
|
419 |
-
raise NotImplementedError('sequential does not support OrderedDict input.')
|
420 |
-
return args[0] # No sequential is needed.
|
421 |
-
modules = []
|
422 |
-
for module in args:
|
423 |
-
if isinstance(module, nn.Sequential):
|
424 |
-
for submodule in module.children():
|
425 |
-
modules.append(submodule)
|
426 |
-
elif isinstance(module, nn.Module):
|
427 |
-
modules.append(module)
|
428 |
-
return nn.Sequential(*modules)
|
429 |
-
|
430 |
-
|
431 |
-
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
432 |
-
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
433 |
-
spectral_norm=False):
|
434 |
-
""" Conv layer with padding, normalization, activation """
|
435 |
-
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
436 |
-
padding = get_valid_padding(kernel_size, dilation)
|
437 |
-
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
438 |
-
padding = padding if pad_type == 'zero' else 0
|
439 |
-
|
440 |
-
if convtype=='PartialConv2D':
|
441 |
-
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
442 |
-
dilation=dilation, bias=bias, groups=groups)
|
443 |
-
elif convtype=='DeformConv2D':
|
444 |
-
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
445 |
-
dilation=dilation, bias=bias, groups=groups)
|
446 |
-
elif convtype=='Conv3D':
|
447 |
-
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
448 |
-
dilation=dilation, bias=bias, groups=groups)
|
449 |
-
else:
|
450 |
-
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
451 |
-
dilation=dilation, bias=bias, groups=groups)
|
452 |
-
|
453 |
-
if spectral_norm:
|
454 |
-
c = nn.utils.spectral_norm(c)
|
455 |
-
|
456 |
-
a = act(act_type) if act_type else None
|
457 |
-
if 'CNA' in mode:
|
458 |
-
n = norm(norm_type, out_nc) if norm_type else None
|
459 |
-
return sequential(p, c, n, a)
|
460 |
-
elif mode == 'NAC':
|
461 |
-
if norm_type is None and act_type is not None:
|
462 |
-
a = act(act_type, inplace=False)
|
463 |
-
n = norm(norm_type, in_nc) if norm_type else None
|
464 |
-
return sequential(n, a, p, c)
|
|
|
# this file is adapted from https://github.com/victorca25/iNNfer

from collections import OrderedDict
import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F


####################
# RRDBNet Generator
####################

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                 finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.resrgan_scale = 0
        if in_nc % 16 == 0:
            self.resrgan_scale = 1
        elif in_nc != 4 and in_nc % 4 == 0:
            self.resrgan_scale = 2

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                          norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
                          gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        outact = act(finalact) if finalact else None

        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
                                *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        if self.resrgan_scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.resrgan_scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x

        return self.model(feat)


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()
        # This is for backwards compatibility with existing models
        if nr == 3:
            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus)
        else:
            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
            self.RDBs = nn.Sequential(*RDB_list)

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            out = self.RDB1(x)
            out = self.RDB2(out)
            out = self.RDB3(out)
        else:
            out = self.RDBs(x)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. Rakotonirina and A. Rasoanaivo
    """

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()

        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        if self.noise:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x


####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            self.noise = self.noise.to(x.device)
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


####################
# SRVGGNetCompact
####################

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.
    This class is copied from https://github.com/xinntao/Real-ESRGAN
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out


####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super(Upsample, self).__init__()
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

    def extra_repr(self):
        if self.scale_factor is not None:
            info = 'scale_factor=' + str(self.scale_factor)
        else:
            info = 'size=' + str(self.size)
        info += ', mode=' + self.mode
        return info


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.
    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """ Upconv layer """
    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
    return sequential(upsample, conv)


####################
# Basic blocks
####################


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """ activation helper """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type in ('leakyrelu', 'lrelu'):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act_type == 'tanh':  # [-1, 1] range output
        layer = nn.Tanh()
    elif act_type == 'sigmoid':  # [0, 1] range output
        layer = nn.Sigmoid()
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class Identity(nn.Module):
    def __init__(self, *args):
        super(Identity, self).__init__()

    def forward(self, x, *args):
        return x


def norm(norm_type, nc):
    """ Return a normalization layer """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        layer = Identity()
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    """ padding layer helper """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        layer = nn.ZeroPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ShortcutBlock(nn.Module):
    """ Elementwise sum the output of a submodule to its input """
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')


def sequential(*args):
    """ Flatten Sequential. It unwraps nn.Sequential. """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
               spectral_norm=False):
    """ Conv layer with padding, normalization, activation """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    if convtype == 'PartialConv2D':
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                          dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'DeformConv2D':
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'Conv3D':
        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)
    else:
        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
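Since the file above is a self-contained torch module, a quick smoke test makes the data flow concrete. A minimal sketch, assuming the module is importable as modules.esrgan_model_arch (its path in this repo); the nf=64/nb=23 settings are the common 4x ESRGAN configuration, not something this file pins down:

import torch
from modules import esrgan_model_arch as arch

model = arch.RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)
model.eval()

with torch.no_grad():
    lr = torch.rand(1, 3, 64, 64)   # dummy low-resolution RGB batch
    sr = model(lr)                  # 3 input channels keep resrgan_scale at 0, so the plain 4x path runs
print(sr.shape)                     # torch.Size([1, 3, 256, 256])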
sd/stable-diffusion-webui/modules/extensions.py
CHANGED
@@ -1,107 +1,107 @@
import os
import sys
import traceback

import time
import git

from modules import paths, shared

extensions = []
extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

if not os.path.exists(extensions_dir):
    os.makedirs(extensions_dir)

def active():
    return [x for x in extensions if x.enabled]


class Extension:
    def __init__(self, name, path, enabled=True, is_builtin=False):
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.version = ''

        repo = None
        try:
            if os.path.exists(os.path.join(path, ".git")):
                repo = git.Repo(path)
        except Exception:
            print(f"Error reading github repository info from {path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                self.status = 'unknown'
                head = repo.head.commit
                ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
                self.version = f'{head.hexsha[:8]} ({ts})'

            except Exception:
                self.remote = None

    def list_files(self, subdir, extension):
        from modules import scripts

        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        for fetch in repo.remote().fetch("--dry-run"):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self):
        repo = git.Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # which happens because WSL2 Docker sets 755 file permissions instead of 644.
        repo.git.fetch('--all')
        repo.git.reset('--hard', 'origin')


def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
        return

    paths = []
    for dirname in [extensions_dir, extensions_builtin_dir]:
        if not os.path.isdir(dirname):
            return

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))

    for dirname, path, is_builtin in paths:
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
        extensions.append(extension)
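For orientation, this is roughly how the module above is driven by the rest of the webui; a hedged sketch, assuming `shared.opts` has already been initialized by the launcher:

from modules import extensions

extensions.list_extensions()        # scans extensions/ and extensions-builtin/
for ext in extensions.active():     # skips anything listed in opts.disabled_extensions
    kind = "builtin" if ext.is_builtin else "user"
    print(ext.name, ext.version or "(no git info)", kind)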
sd/stable-diffusion-webui/modules/extra_networks.py
CHANGED
@@ -1,147 +1,147 @@
import re
from collections import defaultdict

from modules import errors

extra_network_registry = {}


def initialize():
    extra_network_registry.clear()


def register_extra_network(extra_network):
    extra_network_registry[extra_network.name] = extra_network


class ExtraNetworkParams:
    def __init__(self, items=None):
        self.items = items or []


class ExtraNetwork:
    def __init__(self, name):
        self.name = name

    def activate(self, p, params_list):
        """
        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
        Passes arguments related to this extra network in params_list.
        The user passes arguments by specifying this in the prompt:

        <name:arg1:arg2:arg3>

        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any number of text arguments
        separated by colons.

        Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made, with an empty params_list -
        in this case, all effects of this extra network should be disabled.

        Can be called multiple times before deactivate() - each new call should override the previous call completely.

        For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:

        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"

        params_list will be:

        [
            ExtraNetworkParams(items=["agm", "1.1"]),
            ExtraNetworkParams(items=["ray"])
        ]

        """
        raise NotImplementedError

    def deactivate(self, p):
        """
        Called at the end of processing for housekeeping. No need to do anything here.
        """

        raise NotImplementedError


def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list"""

    for extra_network_name, extra_network_args in extra_network_data.items():
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            print(f"Skipping unknown extra network: {extra_network_name}")
            continue

        try:
            extra_network.activate(p, extra_network_args)
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

    for extra_network_name, extra_network in extra_network_registry.items():
        args = extra_network_data.get(extra_network_name, None)
        if args is not None:
            continue

        try:
            extra_network.activate(p, [])
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name}")


def deactivate(p, extra_network_data):
    """call deactivate for extra networks in extra_network_data in specified order, then call
    deactivate for all remaining registered networks"""

    for extra_network_name, extra_network_args in extra_network_data.items():
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            continue

        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {extra_network_name}")

    for extra_network_name, extra_network in extra_network_registry.items():
        args = extra_network_data.get(extra_network_name, None)
        if args is not None:
            continue

        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")


re_extra_net = re.compile(r"<(\w+):([^>]+)>")


def parse_prompt(prompt):
    res = defaultdict(list)

    def found(m):
        name = m.group(1)
        args = m.group(2)

        res[name].append(ExtraNetworkParams(items=args.split(":")))

        return ""

    prompt = re.sub(re_extra_net, found, prompt)

    return prompt, res


def parse_prompts(prompts):
    res = []
    extra_data = None

    for prompt in prompts:
        updated_prompt, parsed_extra_data = parse_prompt(prompt)

        if extra_data is None:
            extra_data = parsed_extra_data

        res.append(updated_prompt)

    return res, extra_data
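To see what parse_prompt actually produces, here is a small sketch (only the import path is assumed); the <name:args> tags are cut out of the prompt text and collected per network name:

from modules.extra_networks import parse_prompt

prompt, data = parse_prompt("1girl, <hypernet:agm:1.1> <lora:ray>")
print(prompt)                        # "1girl,  " -- both tags stripped out
print(data["hypernet"][0].items)     # ['agm', '1.1']
print(data["lora"][0].items)         # ['ray']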
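And a skeletal ExtraNetwork implementation to make the activate/deactivate contract concrete; the 'debug' name and the print body are illustrative only, not part of the webui:

from modules.extra_networks import ExtraNetwork, register_extra_network

class DebugNetwork(ExtraNetwork):
    def __init__(self):
        super().__init__('debug')

    def activate(self, p, params_list):
        # called on every run; params_list is [] when the prompt has no <debug:...> tag
        for params in params_list:
            print("debug network args:", params.items)

    def deactivate(self, p):
        pass  # nothing to clean up

register_extra_network(DebugNetwork())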
sd/stable-diffusion-webui/modules/extra_networks_hypernet.py
CHANGED
@@ -1,27 +1,27 @@
from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork


class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('hypernet')

    def activate(self, p, params_list):
        additional = shared.opts.sd_hypernetwork

        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
            p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        hypernetwork.load_hypernetworks(names, multipliers)

    def deactivate(self, p):
        pass
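A sketch of the registration step that wires this class in; in the webui this happens during script initialization rather than in this file:

from modules import extra_networks
from modules.extra_networks_hypernet import ExtraNetworkHypernet

extra_networks.initialize()
extra_networks.register_extra_network(ExtraNetworkHypernet())
# from here on, <hypernet:mynet:0.8> in a prompt reaches activate()
# with params.items == ['mynet', '0.8']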
sd/stable-diffusion-webui/modules/extras.py
CHANGED
@@ -1,258 +1,258 @@
import os
import re
import shutil


import torch
import tqdm

from modules import shared, images, sd_models, sd_vae, sd_models_config
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch


def run_pnginfo(image):
    if image is None:
        return '', '', ''

    geninfo, items = images.read_info_from_image(image)
    items = {**{'parameters': geninfo}, **items}

    info = ''
    for key, text in items.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"

    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info


def create_config(ckpt_result, config_source, a, b, c):
    def config(x):
        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
        return res if res != shared.sd_default_config else None

    if config_source == 0:
        cfg = config(a) or config(b) or config(c)
    elif config_source == 1:
        cfg = config(b)
    elif config_source == 2:
        cfg = config(c)
    else:
        cfg = None

    if cfg is None:
        return

    filename, _ = os.path.splitext(ckpt_result)
    checkpoint_filename = filename + ".yaml"

    print("Copying config:")
    print("   from:", cfg)
    print("     to:", checkpoint_filename)
    shutil.copyfile(cfg, checkpoint_filename)


checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]


def to_half(tensor, enable):
    if enable and tensor.dtype == torch.float:
        return tensor.half()

    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    shared.state.begin()
    shared.state.job = 'model-merge'

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    def filename_weighted_sum():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        Ma = round(1 - multiplier, 2)
        Mb = round(multiplier, 2)

        return f"{Ma}({a}) + {Mb}({b})"

    def filename_add_difference():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        c = tertiary_model_info.model_name
        M = round(multiplier, 2)

        return f"{a} + {M}({b} - {c})"

    def filename_nothing():
        return primary_model_info.model_name

    theta_funcs = {
        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
        "Add difference": (filename_add_difference, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }
    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)

    if not primary_model_name:
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = sd_models.checkpoints_list[primary_model_name]

    if theta_func2 and not secondary_model_name:
        return fail("Failed: Merging requires a secondary model.")

    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None

    if theta_func1 and not tertiary_model_name:
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")

    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

    result_is_inpainting_model = False
    result_is_instruct_pix2pix_model = False

    if theta_func2:
        shared.state.textinfo = "Loading B"
        print(f"Loading {secondary_model_info.filename}...")
        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        shared.state.textinfo = "Loading C"
        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

        shared.state.textinfo = 'Merging B and C'
        shared.state.sampling_steps = len(theta_1.keys())
        for key in tqdm.tqdm(theta_1.keys()):
            if key in checkpoint_dict_skip_on_merge:
                continue

            if 'model' in key:
                if key in theta_2:
                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
                    theta_1[key] = theta_func1(theta_1[key], t2)
                else:
                    theta_1[key] = torch.zeros_like(theta_1[key])

            shared.state.sampling_step += 1
        del theta_2

        shared.state.nextjob()

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

    print("Merging...")
    shared.state.textinfo = 'Merging A and B'
    shared.state.sampling_steps = len(theta_0.keys())
    for key in tqdm.tqdm(theta_0.keys()):
        if theta_1 and 'model' in key and key in theta_1:

            if key in checkpoint_dict_skip_on_merge:
                continue

            a = theta_0[key]
            b = theta_1[key]

            # this enables merging an inpainting model (A) with another one (B);
            # where a normal model would have 4 channels for the latent space, an inpainting model
            # has another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
                if a.shape[1] == 4 and b.shape[1] == 8:
                    raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

                if a.shape[1] == 8 and b.shape[1] == 4:  # if we have an Instruct-Pix2Pix model...
                    # merge only the channels the models have in common; otherwise we get an error due to dimension mismatch
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_instruct_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            theta_0[key] = to_half(theta_0[key], save_as_half)

        shared.state.sampling_step += 1

    del theta_1

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        shared.state.textinfo = 'Baking in VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)

        del vae_dict

    if save_as_half and not theta_func2:
        for key in theta_0.keys():
            theta_0[key] = to_half(theta_0[key], save_as_half)

    if discard_weights:
        regex = re.compile(discard_weights)
        for key in list(theta_0):
            if re.search(regex, key):
                theta_0.pop(key, None)

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = filename_generator() if custom_name == '' else custom_name
    filename += ".inpainting" if result_is_inpainting_model else ""
    filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.nextjob()
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)

    sd_models.list_models()

    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

    print(f"Checkpoint saved to {output_modelname}.")
    shared.state.textinfo = "Checkpoint saved"
    shared.state.end()

    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
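The three interpolation modes reduce to simple per-tensor arithmetic. A minimal sketch on toy tensors, mirroring the formulas run_modelmerger applies to every state-dict key:

import torch

theta0 = torch.tensor([1.0, 2.0])    # model A's weights for one key
theta1 = torch.tensor([3.0, 6.0])    # model B
theta2 = torch.tensor([2.0, 2.0])    # model C, used only by "Add difference"
alpha = 0.5                          # the multiplier slider

weighted_sum = (1 - alpha) * theta0 + alpha * theta1     # tensor([2., 4.])
add_difference = theta0 + alpha * (theta1 - theta2)      # tensor([1.5000, 4.0000])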
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
10 |
+
from modules.ui_common import plaintext_to_html
|
11 |
+
import gradio as gr
|
12 |
+
import safetensors.torch
|
13 |
+
|
14 |
+
|
15 |
+
def run_pnginfo(image):
|
16 |
+
if image is None:
|
17 |
+
return '', '', ''
|
18 |
+
|
19 |
+
geninfo, items = images.read_info_from_image(image)
|
20 |
+
items = {**{'parameters': geninfo}, **items}
|
21 |
+
|
22 |
+
info = ''
|
23 |
+
for key, text in items.items():
|
24 |
+
info += f"""
|
25 |
+
<div>
|
26 |
+
<p><b>{plaintext_to_html(str(key))}</b></p>
|
27 |
+
<p>{plaintext_to_html(str(text))}</p>
|
28 |
+
</div>
|
29 |
+
""".strip()+"\n"
|
30 |
+
|
31 |
+
if len(info) == 0:
|
32 |
+
message = "Nothing found in the image."
|
33 |
+
info = f"<div><p>{message}<p></div>"
|
34 |
+
|
35 |
+
return '', geninfo, info
|
36 |
+
|
37 |
+
|
38 |
+
def create_config(ckpt_result, config_source, a, b, c):
|
39 |
+
def config(x):
|
40 |
+
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
41 |
+
return res if res != shared.sd_default_config else None
|
42 |
+
|
43 |
+
if config_source == 0:
|
44 |
+
cfg = config(a) or config(b) or config(c)
|
45 |
+
elif config_source == 1:
|
46 |
+
cfg = config(b)
|
47 |
+
elif config_source == 2:
|
48 |
+
cfg = config(c)
|
49 |
+
else:
|
50 |
+
cfg = None
|
51 |
+
|
52 |
+
if cfg is None:
|
53 |
+
return
|
54 |
+
|
55 |
+
filename, _ = os.path.splitext(ckpt_result)
|
56 |
+
checkpoint_filename = filename + ".yaml"
|
57 |
+
|
58 |
+
print("Copying config:")
|
59 |
+
print(" from:", cfg)
|
60 |
+
print(" to:", checkpoint_filename)
|
61 |
+
shutil.copyfile(cfg, checkpoint_filename)
|
62 |
+
|
63 |
+
|
64 |
+
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
65 |
+
|
66 |
+
|
67 |
+
def to_half(tensor, enable):
|
68 |
+
if enable and tensor.dtype == torch.float:
|
69 |
+
return tensor.half()
|
70 |
+
|
71 |
+
return tensor
|
72 |
+
|
73 |
+
|
74 |
+
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
75 |
+
shared.state.begin()
|
76 |
+
shared.state.job = 'model-merge'
|
77 |
+
|
78 |
+
def fail(message):
|
79 |
+
shared.state.textinfo = message
|
80 |
+
shared.state.end()
|
81 |
+
return [*[gr.update() for _ in range(4)], message]
|
82 |
+
|
83 |
+
def weighted_sum(theta0, theta1, alpha):
|
84 |
+
return ((1 - alpha) * theta0) + (alpha * theta1)
|
85 |
+
|
86 |
+
def get_difference(theta1, theta2):
|
87 |
+
return theta1 - theta2
|
88 |
+
|
89 |
+
def add_difference(theta0, theta1_2_diff, alpha):
|
90 |
+
return theta0 + (alpha * theta1_2_diff)
|
91 |
+
|
92 |
+
def filename_weighted_sum():
|
93 |
+
a = primary_model_info.model_name
|
94 |
+
b = secondary_model_info.model_name
|
95 |
+
Ma = round(1 - multiplier, 2)
|
96 |
+
Mb = round(multiplier, 2)
|
97 |
+
|
98 |
+
return f"{Ma}({a}) + {Mb}({b})"
|
99 |
+
|
100 |
+
def filename_add_difference():
|
101 |
+
a = primary_model_info.model_name
|
102 |
+
b = secondary_model_info.model_name
|
103 |
+
c = tertiary_model_info.model_name
|
104 |
+
M = round(multiplier, 2)
|
105 |
+
|
106 |
+
return f"{a} + {M}({b} - {c})"
|
107 |
+
|
108 |
+
def filename_nothing():
|
109 |
+
return primary_model_info.model_name
|
110 |
+
|
111 |
+
theta_funcs = {
|
112 |
+
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
113 |
+
"Add difference": (filename_add_difference, get_difference, add_difference),
|
114 |
+
"No interpolation": (filename_nothing, None, None),
|
115 |
+
}
|
116 |
+
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
117 |
+
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
118 |
+
|
119 |
+
if not primary_model_name:
|
120 |
+
return fail("Failed: Merging requires a primary model.")
|
121 |
+
|
122 |
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
123 |
+
|
124 |
+
if theta_func2 and not secondary_model_name:
|
125 |
+
return fail("Failed: Merging requires a secondary model.")
|
126 |
+
|
127 |
+
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
128 |
+
|
129 |
+
if theta_func1 and not tertiary_model_name:
|
130 |
+
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
131 |
+
|
132 |
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
133 |
+
|
134 |
+
result_is_inpainting_model = False
|
135 |
+
result_is_instruct_pix2pix_model = False
|
136 |
+
|
137 |
+
if theta_func2:
|
138 |
+
shared.state.textinfo = f"Loading B"
|
139 |
+
print(f"Loading {secondary_model_info.filename}...")
|
140 |
+
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
141 |
+
else:
|
142 |
+
theta_1 = None
|
143 |
+
|
144 |
+
if theta_func1:
|
145 |
+
shared.state.textinfo = f"Loading C"
|
146 |
+
print(f"Loading {tertiary_model_info.filename}...")
|
147 |
+
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
148 |
+
|
149 |
+
shared.state.textinfo = 'Merging B and C'
|
150 |
+
shared.state.sampling_steps = len(theta_1.keys())
|
151 |
+
for key in tqdm.tqdm(theta_1.keys()):
|
152 |
+
if key in checkpoint_dict_skip_on_merge:
|
153 |
+
continue
|
154 |
+
|
155 |
+
if 'model' in key:
|
156 |
+
if key in theta_2:
|
157 |
+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
158 |
+
theta_1[key] = theta_func1(theta_1[key], t2)
|
159 |
+
else:
|
160 |
+
theta_1[key] = torch.zeros_like(theta_1[key])
|
161 |
+
|
162 |
+
shared.state.sampling_step += 1
|
163 |
+
del theta_2
|
164 |
+
|
165 |
+
shared.state.nextjob()
|
166 |
+
|
167 |
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
168 |
+
print(f"Loading {primary_model_info.filename}...")
|
169 |
+
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
170 |
+
|
171 |
+
print("Merging...")
|
172 |
+
shared.state.textinfo = 'Merging A and B'
|
173 |
+
shared.state.sampling_steps = len(theta_0.keys())
|
174 |
+
for key in tqdm.tqdm(theta_0.keys()):
|
175 |
+
if theta_1 and 'model' in key and key in theta_1:
|
176 |
+
|
177 |
+
if key in checkpoint_dict_skip_on_merge:
|
178 |
+
continue
|
179 |
+
|
            a = theta_0[key]
            b = theta_1[key]

            # this enables merging an inpainting model (A) with another one (B);
            # where a normal model would have 4 channels for the latent space, an inpainting model
            # has another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
                if a.shape[1] == 4 and b.shape[1] == 8:
                    raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

                if a.shape[1] == 8 and b.shape[1] == 4:  # if we have an Instruct-Pix2Pix model...
                    # merge only the channels the models have in common; otherwise we get a dimension mismatch
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_instruct_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            theta_0[key] = to_half(theta_0[key], save_as_half)

            shared.state.sampling_step += 1

    del theta_1

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        shared.state.textinfo = 'Baking in VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)

        del vae_dict

    if save_as_half and not theta_func2:
        for key in theta_0.keys():
            theta_0[key] = to_half(theta_0[key], save_as_half)

    if discard_weights:
        regex = re.compile(discard_weights)
        for key in list(theta_0):
            if re.search(regex, key):
                theta_0.pop(key, None)

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = filename_generator() if custom_name == '' else custom_name
    filename += ".inpainting" if result_is_inpainting_model else ""
    filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.nextjob()
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)

    sd_models.list_models()

    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

    print(f"Checkpoint saved to {output_modelname}.")
    shared.state.textinfo = "Checkpoint saved"
    shared.state.end()

    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
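The channel-mismatch branch above deserves a concrete illustration: a 9-channel inpainting (or 8-channel instruct-pix2pix) conv layer can only be interpolated with a 4-channel layer on the input channels both models share. A minimal sketch of that slice-merge, assuming a plain weighted sum stands in for theta_func2 (the tensor shapes and names here are illustrative, not taken from the repo):

# Sketch (not repo code): weighted-sum merge restricted to the first four
# input channels, mirroring the inpainting/instruct-pix2pix branch above.
import torch

def weighted_sum(a, b, multiplier):
    # stand-in for theta_func2: linear interpolation between two tensors
    return a * (1 - multiplier) + b * multiplier

a = torch.randn(320, 9, 3, 3)  # conv weight from an inpainting model: 9 input channels
b = torch.randn(320, 4, 3, 3)  # same layer in a normal model: 4 input channels

merged = a.clone()
merged[:, 0:4, :, :] = weighted_sum(a[:, 0:4, :, :], b, 0.5)
# channels 4:9 (masked-image latents plus the mask) keep A's weights untouched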
sd/stable-diffusion-webui/modules/face_restoration.py
CHANGED
@@ -1,19 +1,19 @@
(The old and new revisions of this file are line-for-line identical; the contents are shown once.)

from modules import shared


class FaceRestoration:
    def name(self):
        return "None"

    def restore(self, np_image):
        return np_image


def restore_faces(np_image):
    face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
    if len(face_restorers) == 0:
        return np_image

    face_restorer = face_restorers[0]

    return face_restorer.restore(np_image)
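This module is the entire plug-in surface for face restoration: restore_faces() dispatches to whichever registered restorer matches the face_restoration_model option. A hypothetical restorer only needs to subclass and register itself; a sketch, assuming this module is importable as modules.face_restoration:

# Hypothetical restorer registration, following the pattern GFPGAN uses below.
from modules import face_restoration, shared

class FaceRestorerNoop(face_restoration.FaceRestoration):
    def name(self):
        return "Noop"

    def restore(self, np_image):
        return np_image  # a real implementation would return the restored image

shared.face_restorers.append(FaceRestorerNoop())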
sd/stable-diffusion-webui/modules/generation_parameters_copypaste.py
CHANGED
@@ -1,402 +1,408 @@
(Only the new revision is reproduced below; the old revision's tail, apparently an earlier, shorter infotext-to-setting mapping and related code, did not survive extraction.)

import base64
import html
import io
import math
import os
import re
from pathlib import Path

import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image

re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())

paste_fields = {}
registered_param_bindings = []


class ParamBinding:
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component


def reset():
    paste_fields.clear()


def quote(text):
    if ',' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'


def image_from_url_text(filedata):
    if filedata is None:
        return None

    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if type(filedata) == dict and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        assert is_in_right_dir, 'trying to open image file outside of allowed directories'

        return Image.open(filename)

    if type(filedata) == list:
        if len(filedata) == 0:
            return None

        filedata = filedata[0]

    if filedata.startswith("data:image/png;base64,"):
        filedata = filedata[len("data:image/png;base64,"):]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image


def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

    # backwards compatibility for existing extensions
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields


def create_buttons(tabs_list):
    buttons = {}
    for tab in tabs_list:
        buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
    return buttons


def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button"""
    for tabname, button in buttons.items():
        source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
        source_tabname = send_generate_info if isinstance(send_generate_info, str) else None

        register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))


def register_paste_params_button(binding: ParamBinding):
    registered_param_bindings.append(binding)


def connect_paste_params_buttons():
    binding: ParamBinding
    for binding in registered_param_bindings:
        destination_image_component = paste_fields[binding.tabname]["init_img"]
        fields = paste_fields[binding.tabname]["fields"]
        override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]

        destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
        destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                outputs=[field for field, name in fields if name in paste_field_names],
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
        )


def send_image_and_dimensions(x):
    if isinstance(x, Image.Image):
        img = x
    else:
        img = image_from_url_text(x)

    if shared.opts.send_size and isinstance(img, Image.Image):
        w = img.width
        h = img.height
    else:
        w = gr.update()
        h = gr.update()

    return img, w, h


def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.

    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.

    If the infotext has no hash, then a hypernet with the same name will be selected instead.
    """
    hypernet_name = hypernet_name.lower()
    if hypernet_hash is not None:
        # Try to match the hash in the name
        for hypernet_key in shared.hypernetworks.keys():
            result = re_hypernet_hash.search(hypernet_key)
            if result is not None and result[1] == hypernet_hash:
                return hypernet_key
    else:
        # Fall back to a hypernet with the same name
        for hypernet_key in shared.hypernetworks.keys():
            if hypernet_key.lower().startswith(hypernet_name):
                return hypernet_key

    return None


def restore_old_hires_fix_params(res):
    """for infotexts that specify old First pass size parameter, convert it into
    width, height, and hr scale"""

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if shared.opts.use_old_hires_fix_width_height:
        hires_width = int(res.get("Hires resize-1", 0))
        hires_height = int(res.get("Hires resize-2", 0))

        if hires_width and hires_height:
            res['Size-1'] = hires_width
            res['Size-2'] = hires_height
            return

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    if firstpass_width == 0 or firstpass_height == 0:
        from modules import processing
        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    returns a dict with field values
    """

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    for i, line in enumerate(lines):
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
        m = re_imagesize.match(v)
        if m is not None:
            res[k+"-1"] = m.group(1)
            res[k+"-2"] = m.group(2)
        else:
            res[k] = v

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    restore_old_hires_fix_params(res)

    return res


settings_map = {}


infotext_to_setting_name_mapping = [
    ('Clip skip', 'CLIP_stop_at_last_layers', ),
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Noise multiplier', 'initial_noise_multiplier'),
    ('Eta', 'eta_ancestral'),
    ('Eta DDIM', 'eta_ddim'),
    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
    ('UniPC variant', 'uni_pc_variant'),
    ('UniPC skip type', 'uni_pc_skip_type'),
    ('UniPC order', 'uni_pc_order'),
    ('UniPC lower order final', 'uni_pc_lower_order_final'),
]


def create_override_settings_dict(text_pairs):
    """creates processing's override_settings parameters from gradio's multiselect

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']

    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """

    res = {}

    params = {}
    for pair in text_pairs:
        k, v = pair.split(":", maxsplit=1)

        params[k] = v.strip()

    for param_name, setting_name in infotext_to_setting_name_mapping:
        value = params.get(param_name, None)

        if value is None:
            continue

        res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res


def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []

        for output, key in paste_fields:
            if callable(key):
                v = key(params)
            else:
                v = params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    valtype = type(output.value)

                    if valtype == bool and v == "False":
                        val = False
                    else:
                        val = valtype(v)

                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())

        return res

    if override_settings_component is not None:
        def paste_settings(params):
            vals = {}

            for param_name, setting_name in infotext_to_setting_name_mapping:
                v = params.get(param_name, None)
                if v is None:
                    continue

                if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
                    continue

                v = shared.opts.cast_value(setting_name, v)
                current_value = getattr(shared.opts, setting_name, None)

                if v == current_value:
                    continue

                vals[param_name] = v

            vals_pairs = [f"{k}: {v}" for k, v in vals.items()]

            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)

        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
    )
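For reference, here is what parse_generation_parameters() produces for a typical infotext. This is a usage sketch with an invented prompt; it assumes a running webui environment, since the hires-fix fixup consults shared.opts:

# Sketch: feeding an infotext string to the parser above.
infotext = """a cozy cabin in the woods
Negative prompt: blurry, low quality
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 42, Size: 512x768, Model hash: 45dee52b"""

params = parse_generation_parameters(infotext)
params["Prompt"]     # 'a cozy cabin in the woods'
params["Size-1"]     # '512' -- a 'Size: WxH' pair is split into Size-1/Size-2
params["Clip skip"]  # '1'   -- filled in when the infotext omits it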
sd/stable-diffusion-webui/modules/gfpgan_model.py
CHANGED
@@ -1,116 +1,116 @@
(The old and new revisions of this file are line-for-line identical; the contents are shown once.)

import os
import sys
import traceback

import facexlib
import gfpgan

import modules.face_restoration
from modules import paths, shared, devices, modelloader

model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(paths.models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None


def gfpgann():
    global loaded_gfpgan_model
    global model_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
        return loaded_gfpgan_model

    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    if len(models) == 1 and "http" in models[0]:
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None
    if hasattr(facexlib.detection.retinaface, 'device'):
        facexlib.detection.retinaface.device = devices.device_gfpgan
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
    loaded_gfpgan_model = model

    return model


def send_model_to(model, device):
    model.gfpgan.to(device)
    model.face_helper.face_det.to(device)
    model.face_helper.face_parse.to(device)


def gfpgan_fix_faces(np_image):
    model = gfpgann()
    if model is None:
        return np_image

    send_model_to(model, devices.device_gfpgan)

    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]

    model.face_helper.clean_all()

    if shared.opts.face_restoration_unload:
        send_model_to(model, devices.cpu)

    return np_image


gfpgan_constructor = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        from gfpgan import GFPGANer
        from facexlib import detection, parsing
        global user_path
        global have_gfpgan
        global gfpgan_constructor

        load_file_from_url_orig = gfpgan.utils.load_file_from_url
        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))

        def facex_load_file_from_url(**kwargs):
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))

        def facex_load_file_from_url2(**kwargs):
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = facex_load_file_from_url
        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
        user_path = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANer

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
            def name(self):
                return "GFPGAN"

            def restore(self, np_image):
                return gfpgan_fix_faces(np_image)

        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
        print("Error setting up GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
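One detail that is easy to miss in gfpgan_fix_faces(): the [:, :, ::-1] slices flip the channel axis because the webui passes images around as RGB NumPy arrays while GFPGAN (via OpenCV) expects BGR. A tiny self-contained illustration:

import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255            # pure red in RGB order
bgr = rgb[:, :, ::-1]        # reversing the last axis converts RGB -> BGR
assert bgr[0, 0, 2] == 255   # red now sits in the last channel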
sd/stable-diffusion-webui/modules/hashes.py
CHANGED
@@ -1,91 +1,91 @@
(The old and new revisions of this file are line-for-line identical; the contents are shown once.)

import hashlib
import json
import os.path

import filelock

from modules import shared
from modules.paths import data_path


cache_filename = os.path.join(data_path, "cache.json")
cache_data = None


def dump_cache():
    with filelock.FileLock(cache_filename+".lock"):
        with open(cache_filename, "w", encoding="utf8") as file:
            json.dump(cache_data, file, indent=4)


def cache(subsection):
    global cache_data

    if cache_data is None:
        with filelock.FileLock(cache_filename+".lock"):
            if not os.path.isfile(cache_filename):
                cache_data = {}
            else:
                with open(cache_filename, "r", encoding="utf8") as file:
                    cache_data = json.load(file)

    s = cache_data.get(subsection, {})
    cache_data[subsection] = s

    return s


def calculate_sha256(filename):
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)

    return hash_sha256.hexdigest()


def sha256_from_cache(filename, title):
    hashes = cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    if title not in hashes:
        return None

    cached_sha256 = hashes[title].get("sha256", None)
    cached_mtime = hashes[title].get("mtime", 0)

    if ondisk_mtime > cached_mtime or cached_sha256 is None:
        return None

    return cached_sha256


def sha256(filename, title):
    hashes = cache("hashes")

    sha256_value = sha256_from_cache(filename, title)
    if sha256_value is not None:
        return sha256_value

    if shared.cmd_opts.no_hashing:
        return None

    print(f"Calculating sha256 for {filename}: ", end='')
    sha256_value = calculate_sha256(filename)
    print(f"{sha256_value}")

    hashes[title] = {
        "mtime": os.path.getmtime(filename),
        "sha256": sha256_value,
    }

    dump_cache()

    return sha256_value
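Usage sketch for the mtime-guarded cache above (hypothetical path and title; requires the webui's shared options to be loaded): the first call hashes the file and records its mtime in cache.json, the second call is served from the cache without re-reading the file, and touching the file invalidates the entry because the on-disk mtime becomes newer than the cached one.

h1 = sha256("models/Stable-diffusion/model.safetensors", "checkpoint/model")  # computes and caches
h2 = sha256("models/Stable-diffusion/model.safetensors", "checkpoint/model")  # cache hit
assert h1 == h2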
sd/stable-diffusion-webui/modules/hypernetworks/hypernetwork.py
CHANGED
@@ -1,811 +1,811 @@
|
|
1 |
-
import csv
|
2 |
-
import datetime
|
3 |
-
import glob
|
4 |
-
import html
|
5 |
-
import os
|
6 |
-
import sys
|
7 |
-
import traceback
|
8 |
-
import inspect
|
9 |
-
|
10 |
-
import modules.textual_inversion.dataset
|
11 |
-
import torch
|
12 |
-
import tqdm
|
13 |
-
from einops import rearrange, repeat
|
14 |
-
from ldm.util import default
|
15 |
-
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
|
16 |
-
from modules.textual_inversion import textual_inversion, logging
|
17 |
-
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
18 |
-
from torch import einsum
|
19 |
-
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
20 |
-
|
21 |
-
from collections import defaultdict, deque
|
22 |
-
from statistics import stdev, mean
|
23 |
-
|
24 |
-
|
25 |
-
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
26 |
-
|
27 |
-
class HypernetworkModule(torch.nn.Module):
|
28 |
-
activation_dict = {
|
29 |
-
"linear": torch.nn.Identity,
|
30 |
-
"relu": torch.nn.ReLU,
|
31 |
-
"leakyrelu": torch.nn.LeakyReLU,
|
32 |
-
"elu": torch.nn.ELU,
|
33 |
-
"swish": torch.nn.Hardswish,
|
34 |
-
"tanh": torch.nn.Tanh,
|
35 |
-
"sigmoid": torch.nn.Sigmoid,
|
36 |
-
}
|
37 |
-
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
38 |
-
|
39 |
-
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
40 |
-
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
41 |
-
super().__init__()
|
42 |
-
|
43 |
-
self.multiplier = 1.0
|
44 |
-
|
45 |
-
assert layer_structure is not None, "layer_structure must not be None"
|
46 |
-
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
47 |
-
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
48 |
-
|
49 |
-
linears = []
|
50 |
-
for i in range(len(layer_structure) - 1):
|
51 |
-
|
52 |
-
# Add a fully-connected layer
|
53 |
-
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
54 |
-
|
55 |
-
# Add an activation func except last layer
|
56 |
-
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
57 |
-
pass
|
58 |
-
elif activation_func in self.activation_dict:
|
59 |
-
linears.append(self.activation_dict[activation_func]())
|
60 |
-
else:
|
61 |
-
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
62 |
-
|
63 |
-
# Add layer normalization
|
64 |
-
if add_layer_norm:
|
65 |
-
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
66 |
-
|
67 |
-
# Everything should be now parsed into dropout structure, and applied here.
|
68 |
-
# Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
|
69 |
-
if dropout_structure is not None and dropout_structure[i+1] > 0:
|
70 |
-
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
|
71 |
-
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
|
72 |
-
# Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
|
73 |
-
|
74 |
-
self.linear = torch.nn.Sequential(*linears)
|
75 |
-
|
76 |
-
if state_dict is not None:
|
77 |
-
self.fix_old_state_dict(state_dict)
|
78 |
-
self.load_state_dict(state_dict)
|
79 |
-
else:
|
80 |
-
for layer in self.linear:
|
81 |
-
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
82 |
-
w, b = layer.weight.data, layer.bias.data
|
83 |
-
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
84 |
-
normal_(w, mean=0.0, std=0.01)
|
85 |
-
normal_(b, mean=0.0, std=0)
|
86 |
-
elif weight_init == 'XavierUniform':
|
87 |
-
xavier_uniform_(w)
|
88 |
-
zeros_(b)
|
89 |
-
elif weight_init == 'XavierNormal':
|
90 |
-
xavier_normal_(w)
|
91 |
-
zeros_(b)
|
92 |
-
elif weight_init == 'KaimingUniform':
|
93 |
-
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
94 |
-
zeros_(b)
|
95 |
-
elif weight_init == 'KaimingNormal':
|
96 |
-
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
97 |
-
zeros_(b)
|
98 |
-
else:
|
99 |
-
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
100 |
-
self.to(devices.device)
|
101 |
-
|
102 |
-
def fix_old_state_dict(self, state_dict):
|
103 |
-
changes = {
|
104 |
-
'linear1.bias': 'linear.0.bias',
|
105 |
-
'linear1.weight': 'linear.0.weight',
|
106 |
-
'linear2.bias': 'linear.1.bias',
|
107 |
-
'linear2.weight': 'linear.1.weight',
|
108 |
-
}
|
109 |
-
|
110 |
-
for fr, to in changes.items():
|
111 |
-
x = state_dict.get(fr, None)
|
112 |
-
if x is None:
|
113 |
-
continue
|
114 |
-
|
115 |
-
del state_dict[fr]
|
116 |
-
state_dict[to] = x
|
117 |
-
|
118 |
-
def forward(self, x):
|
119 |
-
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
120 |
-
|
121 |
-
def trainables(self):
|
122 |
-
layer_structure = []
|
123 |
-
for layer in self.linear:
|
124 |
-
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
125 |
-
layer_structure += [layer.weight, layer.bias]
|
126 |
-
return layer_structure
|
127 |
-
|
128 |
-
|
129 |
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
|
130 |
-
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
131 |
-
if layer_structure is None:
|
132 |
-
layer_structure = [1, 2, 1]
|
133 |
-
if not use_dropout:
|
134 |
-
return [0] * len(layer_structure)
|
135 |
-
dropout_values = [0]
|
136 |
-
dropout_values.extend([0.3] * (len(layer_structure) - 3))
|
137 |
-
if last_layer_dropout:
|
138 |
-
dropout_values.append(0.3)
|
139 |
-
else:
|
140 |
-
dropout_values.append(0)
|
141 |
-
dropout_values.append(0)
|
142 |
-
return dropout_values
|
143 |
-
|
144 |
-
|
145 |
-
class Hypernetwork:
|
146 |
-
filename = None
|
147 |
-
name = None
|
148 |
-
|
149 |
-
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
150 |
-
self.filename = None
|
151 |
-
self.name = name
|
152 |
-
self.layers = {}
|
153 |
-
self.step = 0
|
154 |
-
self.sd_checkpoint = None
|
155 |
-
self.sd_checkpoint_name = None
|
156 |
-
self.layer_structure = layer_structure
|
157 |
-
self.activation_func = activation_func
|
158 |
-
self.weight_init = weight_init
|
159 |
-
self.add_layer_norm = add_layer_norm
|
160 |
-
self.use_dropout = use_dropout
|
161 |
-
self.activate_output = activate_output
|
162 |
-
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
163 |
-
self.dropout_structure = kwargs.get('dropout_structure', None)
|
164 |
-
if self.dropout_structure is None:
|
165 |
-
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
166 |
-
self.optimizer_name = None
|
167 |
-
self.optimizer_state_dict = None
|
168 |
-
self.optional_info = None
|
169 |
-
|
170 |
-
for size in enable_sizes or []:
|
171 |
-
self.layers[size] = (
|
172 |
-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
173 |
-
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
174 |
-
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
175 |
-
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
176 |
-
)
|
177 |
-
self.eval()
|
178 |
-
|
179 |
-
def weights(self):
|
180 |
-
res = []
|
181 |
-
for k, layers in self.layers.items():
|
182 |
-
for layer in layers:
|
183 |
-
res += layer.parameters()
|
184 |
-
return res
|
185 |
-
|
186 |
-
def train(self, mode=True):
|
187 |
-
for k, layers in self.layers.items():
|
188 |
-
for layer in layers:
|
189 |
-
layer.train(mode=mode)
|
190 |
-
for param in layer.parameters():
|
191 |
-
param.requires_grad = mode
|
192 |
-
|
193 |
-
def to(self, device):
|
194 |
-
for k, layers in self.layers.items():
|
195 |
-
for layer in layers:
|
196 |
-
layer.to(device)
|
197 |
-
|
198 |
-
return self
|
199 |
-
|
200 |
-
def set_multiplier(self, multiplier):
|
201 |
-
for k, layers in self.layers.items():
|
202 |
-
for layer in layers:
|
203 |
-
layer.multiplier = multiplier
|
204 |
-
|
205 |
-
return self
|
206 |
-
|
207 |
-
def eval(self):
|
208 |
-
for k, layers in self.layers.items():
|
209 |
-
for layer in layers:
|
210 |
-
layer.eval()
|
211 |
-
for param in layer.parameters():
|
212 |
-
param.requires_grad = False
|
213 |
-
|
214 |
-
def save(self, filename):
|
215 |
-
state_dict = {}
|
216 |
-
optimizer_saved_dict = {}
|
217 |
-
|
218 |
-
for k, v in self.layers.items():
|
219 |
-
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
220 |
-
|
221 |
-
state_dict['step'] = self.step
|
222 |
-
state_dict['name'] = self.name
|
223 |
-
state_dict['layer_structure'] = self.layer_structure
|
224 |
-
state_dict['activation_func'] = self.activation_func
|
225 |
-
state_dict['is_layer_norm'] = self.add_layer_norm
|
226 |
-
state_dict['weight_initialization'] = self.weight_init
|
227 |
-
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
228 |
-
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
229 |
-
state_dict['activate_output'] = self.activate_output
|
230 |
-
state_dict['use_dropout'] = self.use_dropout
|
231 |
-
state_dict['dropout_structure'] = self.dropout_structure
|
232 |
-
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
233 |
-
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
234 |
-
|
235 |
-
if self.optimizer_name is not None:
|
236 |
-
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
237 |
-
|
238 |
-
torch.save(state_dict, filename)
|
239 |
-
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
240 |
-
optimizer_saved_dict['hash'] = self.shorthash()
|
241 |
-
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
242 |
-
torch.save(optimizer_saved_dict, filename + '.optim')
|
243 |
-
|
    def load(self, filename):
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.optional_info = state_dict.get('optional_info', None)
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
        # The dropout structure should have the same length as the layer structure; every value should be in [0, 1), and the last value must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        if shared.opts.print_hypernet_extra:
            if self.optional_info is not None:
                print(f"  INFO:\n {self.optional_info}\n")

            print(f"  Layer structure: {self.layer_structure}")
            print(f"  Activation function: {self.activation_func}")
            print(f"  Weight initialization: {self.weight_init}")
            print(f"  Layer norm: {self.add_layer_norm}")
            print(f"  Dropout usage: {self.use_dropout}")
            print(f"  Activate last layer: {self.activate_output}")
            print(f"  Dropout structure: {self.dropout_structure}")

        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}

        if self.shorthash() == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
            if shared.opts.print_hypernet_extra:
                print("Loaded existing optimizer from checkpoint")
                print(f"Optimizer name is {self.optimizer_name}")
        else:
            self.optimizer_name = "AdamW"
            if shared.opts.print_hypernet_extra:
                print("No saved optimizer exists in checkpoint")

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def shorthash(self):
        sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')

        return sha256[0:10] if sha256 else None
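
    # Illustrative sketch (editorial, not in the original file): shorthash()
    # truncates the cached SHA-256 of the .pt file to its first 10 hex digits;
    # load() compares it against the 'hash' stored in the .optim file so a stale
    # optimizer state is silently ignored. With a hypothetical digest:
    #
    #     >>> "af3c9d71e2b8450f..."[0:10]
    #     'af3c9d71e2'
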
def list_hypernetworks(path):
    res = {}
    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name] = filename
    return res
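
# Illustrative sketch (editorial, hypothetical directory layout): the returned dict
# maps basenames without extension to full paths, including subdirectories:
#
#     >>> list_hypernetworks("models/hypernetworks")
#     {'anime': 'models/hypernetworks/anime.pt', 'portrait': 'models/hypernetworks/styles/portrait.pt'}
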
def load_hypernetwork(name):
    path = shared.hypernetworks.get(name, None)

    if path is None:
        return None

    hypernetwork = Hypernetwork()

    try:
        hypernetwork.load(path)
    except Exception:
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return None

    return hypernetwork


def load_hypernetworks(names, multipliers=None):
    already_loaded = {}

    for hypernetwork in shared.loaded_hypernetworks:
        if hypernetwork.name in names:
            already_loaded[hypernetwork.name] = hypernetwork

    shared.loaded_hypernetworks.clear()

    for i, name in enumerate(names):
        hypernetwork = already_loaded.get(name, None)
        if hypernetwork is None:
            hypernetwork = load_hypernetwork(name)

        if hypernetwork is None:
            continue

        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
        shared.loaded_hypernetworks.append(hypernetwork)
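
# Illustrative usage sketch (editorial, hypothetical names): activate two
# hypernetworks at different strengths; networks already present in
# shared.loaded_hypernetworks are reused rather than re-read from disk.
#
#     >>> load_hypernetworks(["anime", "portrait"], multipliers=[1.0, 0.4])
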
def find_closest_hypernetwork_name(search: str):
    if not search:
        return None
    search = search.lower()
    applicable = [name for name in shared.hypernetworks if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return applicable[0]
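
# Illustrative sketch (editorial): the match is the shortest listed name containing
# the search string, case-insensitively:
#
#     >>> sorted([n for n in ["anime_v2", "anime"] if "anime" in n.lower()], key=len)[0]
#     'anime'
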
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)

    if hypernetwork_layers is None:
        return context_k, context_v

    if layer is not None:
        layer.hyper_k = hypernetwork_layers[0]
        layer.hyper_v = hypernetwork_layers[1]

    context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
    context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
    return context_k, context_v


def apply_hypernetworks(hypernetworks, context, layer=None):
    context_k = context
    context_v = context
    for hypernetwork in hypernetworks:
        context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)

    return context_k, context_v
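
# Illustrative sketch (editorial): each hypernetwork rewrites the cross-attention
# context once for keys and once for values, picking the module pair that matches
# the context width (context_k.shape[2]). Conceptually, per matching pair (hk, hv):
#
#     context_k = hk(context_k)    # hk(x) = x + hk.linear(x) * multiplier
#     context_v = hv(context_v)
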
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
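
# Shape sketch for the einsum attention above (editorial; b=batch, h=heads,
# n=query tokens, j=context tokens, d=per-head dim):
#
#     q: (b*h, n, d), k: (b*h, j, d)     ->  sim: (b*h, n, j)
#     attn: (b*h, n, j), v: (b*h, j, d)  ->  out: (b*h, n, d)  ->  (b, n, h*d)
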
def stack_conds(conds):
    if len(conds) == 1:
        return torch.stack(conds)

    # same as in reconstruct_multicond_batch
    token_count = max([x.shape[0] for x in conds])
    for i in range(len(conds)):
        if conds[i].shape[0] != token_count:
            last_vector = conds[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
            conds[i] = torch.vstack([conds[i], last_vector_repeated])

    return torch.stack(conds)
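
# Illustrative sketch (editorial): shorter cond tensors are padded by repeating
# their final token vector until every tensor has token_count rows, e.g. padding a
# 2-token tensor up to token_count = 3:
#
#     >>> a = torch.ones(2, 4)
#     >>> torch.vstack([a, a[-1:].repeat([1, 1])]).shape
#     torch.Size([3, 4])
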
def statistics(data):
    if len(data) < 2:
        std = 0
    else:
        std = stdev(data)
    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
    recent_data = data[-32:]
    if len(recent_data) < 2:
        std = 0
    else:
        std = stdev(recent_data)
    recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
    return total_information, recent_information
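
# Illustrative sketch (editorial): both strings report mean loss with a
# standard-error bar (sample stdev / sqrt(n)); the second restricts itself to the
# last 32 entries.
#
#     >>> data = [0.21, 0.19, 0.22, 0.18]
#     >>> f"loss:{mean(data):.3f}\u00b1({stdev(data) / (len(data) ** 0.5):.3f})"
#     'loss:0.200±(0.009)'
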
def report_statistics(loss_info: dict):
    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
    for key in keys:
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    assert name, "Name cannot be empty!"

    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    if type(layer_structure) == str:
        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]

    if use_dropout and dropout_structure and type(dropout_structure) == str:
        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
    else:
        dropout_structure = [0] * len(layer_structure)

    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        activation_func=activation_func,
        weight_init=weight_init,
        add_layer_norm=add_layer_norm,
        use_dropout=use_dropout,
        dropout_structure=dropout_structure
    )
    hypernet.save(fn)

    shared.reload_hypernetworks()
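
# Illustrative usage sketch (editorial, hypothetical values): create an empty
# hypernetwork for the 768- and 1280-wide cross-attention layers with one hidden
# layer of 1.5x width:
#
#     >>> create_hypernetwork("mynet", enable_sizes=["768", "1280"], overwrite_old=False,
#     ...                     layer_structure="1, 1.5, 1", activation_func="relu",
#     ...                     weight_init="XavierNormal", use_dropout=False)
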
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
    template_file = template_file.path

    path = shared.hypernetworks.get(hypernetwork_name, None)
    hypernetwork = Hypernetwork()
    hypernetwork.load(path)
    shared.loaded_hypernetworks = [hypernetwork]

    shared.state.job = "train-hypernetwork"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    checkpoint = sd_models.select_checkpoint()

    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)

    if shared.opts.save_training_settings_to_txt:
        saved_params = dict(
            model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
            **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
        )
        logging.save_settings_to_file(log_directory, {**saved_params, **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    weights = hypernetwork.weights()
    hypernetwork.train()

    # Use the optimizer type saved with the hypernetwork; it could also be exposed as a UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if the optimizer type can differ from the saved one.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n images processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal
    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda: deque(maxlen=1024))
    loss_logging = deque(maxlen=len(ds) * 3)  # this should be a configurable parameter; it keeps 3 epochs (3 * dataset size) of losses
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))

    steps_without_grad = 0

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for i in range((steps - initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
                        del w
                    else:
                        loss = shared.sd_model.forward(x, c)[0] / gradient_step
                    del x
                    del c

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                loss_logging.append(_loss_step)
                if clip_grad:
                    clip_grad(weights, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1

                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
                pbar.set_description(description)
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

                if shared.opts.training_enable_tensorboard:
                    epoch_num = hypernetwork.step // len(ds)
                    epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
                    mean_loss = sum(loss_logging) / len(loss_logging)
                    textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)
                    hypernetwork.eval()
                    rng_state = torch.get_rng_state()
                    cuda_rng_state = None
                    if torch.cuda.is_available():
                        cuda_rng_state = torch.cuda.get_rng_state_all()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

                    p.disable_extra_networks = True

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
                    torch.set_rng_state(rng_state)
                    if torch.cuda.is_available():
                        torch.cuda.set_rng_state_all(cuda_rng_state)
                    hypernetwork.train()
                    if image is not None:
                        shared.state.assign_current_image(image)
                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            textual_inversion.tensorboard_add_image(tensorboard_writer,
                                                                    f"Validation at epoch {epoch_num}", image,
                                                                    hypernetwork.step)
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = hypernetwork.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval()
        # report_statistics(loss_dict)
        sd_hijack_checkpoint.remove()

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)

    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)
    shared.parallel_processing_allowed = old_parallel_processing_allowed

    return hypernetwork, filename


def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    old_hypernetwork_name = hypernetwork.name
    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
    try:
        hypernetwork.sd_checkpoint = checkpoint.shorthash
        hypernetwork.sd_checkpoint_name = checkpoint.model_name
        hypernetwork.name = hypernetwork_name
        hypernetwork.save(filename)
    except:
        hypernetwork.sd_checkpoint = old_sd_checkpoint
        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
        hypernetwork.name = old_hypernetwork_name
        raise
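
# Illustrative sketch of the step arithmetic in train_hypernetwork above
# (editorial): one optimizer step consumes batch_size * gradient_step images, so
# with 512 dataset images, batch_size=2 and gradient_step=4:
#
#     >>> 512 // 2 // 4                    # steps_per_epoch
#     64
#     >>> 512 // 2 - (512 // 2) % 4        # max_steps_per_epoch (dataloader batches kept)
#     256
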
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect

import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

from collections import defaultdict, deque
from statistics import stdev, mean


optimizer_dict = {optim_name: cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
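
# Illustrative sketch (editorial): optimizer_dict collects every optimizer class in
# torch.optim by name, so the "optimizer_name" stored in a .optim file can be
# instantiated generically (hypothetical weights and learning rate):
#
#     >>> optimizer_dict["AdamW"](params=weights, lr=5e-4)
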
class HypernetworkModule(torch.nn.Module):
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None):
        super().__init__()

        self.multiplier = 1.0

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        for i in range(len(layer_structure) - 1):

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

            # Add an activation func except last layer
            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Add layer normalization
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

            # Everything should be now parsed into dropout structure, and applied here.
            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
            if dropout_structure is not None and dropout_structure[i+1] > 0:
                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when it's True, [0, 0.3, 0.3, 0].

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            for layer in self.linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    w, b = layer.weight.data, layer.bias.data
                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                        normal_(w, mean=0.0, std=0.01)
                        normal_(b, mean=0.0, std=0)
                    elif weight_init == 'XavierUniform':
                        xavier_uniform_(w)
                        zeros_(b)
                    elif weight_init == 'XavierNormal':
                        xavier_normal_(w)
                        zeros_(b)
                    elif weight_init == 'KaimingUniform':
                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    elif weight_init == 'KaimingNormal':
                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    else:
                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
        self.to(devices.device)

    def fix_old_state_dict(self, state_dict):
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue

            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x):
        return x + self.linear(x) * (self.multiplier if not self.training else 1)

    def trainables(self):
        layer_structure = []
        for layer in self.linear:
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
        return layer_structure
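
# Illustrative sketch (editorial): forward() is a gated residual; at inference the
# learned correction is scaled by the per-network multiplier, while in training
# mode it is applied at full strength:
#
#     out = x + linear(x) * multiplier    # multiplier is treated as 1 while training
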

# param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    if not use_dropout:
        return [0] * len(layer_structure)
    dropout_values = [0]
    dropout_values.extend([0.3] * (len(layer_structure) - 3))
    if last_layer_dropout:
        dropout_values.append(0.3)
    else:
        dropout_values.append(0)
    dropout_values.append(0)
    return dropout_values
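
# Illustrative behavior of parse_dropout_structure (editorial): dropout never
# follows the input or the output layer, and hidden layers default to p=0.3:
#
#     >>> parse_dropout_structure([1, 2, 1], True, True)
#     [0, 0.3, 0]
#     >>> parse_dropout_structure([1, 2, 2, 1], True, False)
#     [0, 0.3, 0, 0]
#     >>> parse_dropout_structure([1, 2, 2, 1], True, True)
#     [0, 0.3, 0.3, 0]
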

class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.dropout_structure = kwargs.get('dropout_structure', None)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.optional_info = None

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
            )
        self.eval()

    def weights(self):
        res = []
        for k, layers in self.layers.items():
            for layer in layers:
                res += layer.parameters()
        return res

    def train(self, mode=True):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.to(device)

        return self

    def set_multiplier(self, multiplier):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.multiplier = multiplier

        return self

    def eval(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False
+
shared.sd_model.first_stage_model.to(devices.device)
|
794 |
+
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
795 |
+
|
796 |
+
return hypernetwork, filename
|
797 |
+
|
798 |
+
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
799 |
+
old_hypernetwork_name = hypernetwork.name
|
800 |
+
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
801 |
+
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
802 |
+
try:
|
803 |
+
hypernetwork.sd_checkpoint = checkpoint.shorthash
|
804 |
+
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
805 |
+
hypernetwork.name = hypernetwork_name
|
806 |
+
hypernetwork.save(filename)
|
807 |
+
except:
|
808 |
+
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
809 |
+
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
810 |
+
hypernetwork.name = old_hypernetwork_name
|
811 |
+
raise
|
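The loop above combines gradient accumulation with mixed-precision loss scaling: the loss is divided by gradient_step, gradients accumulate over several micro-batches, and the optimizer steps only once per accumulation window. A minimal standalone sketch of that pattern (the model, data, and gradient_step here are illustrative stand-ins, not the webui's objects):

import torch

# Stand-ins for the webui objects; the accumulation pattern is what matters.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
gradient_step = 4  # number of micro-batches to accumulate before stepping

for j in range(20):
    x = torch.randn(8, 4, device=device)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        # divide by gradient_step so the accumulated gradient matches one large batch
        loss = model(x).pow(2).mean() / gradient_step
    scaler.scale(loss).backward()

    if (j + 1) % gradient_step != 0:
        continue  # keep accumulating, exactly like the loop above

    scaler.step(optimizer)   # unscales gradients, then calls optimizer.step()
    scaler.update()          # adapts the loss scale for the next window
    optimizer.zero_grad(set_to_none=True)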
sd/stable-diffusion-webui/modules/hypernetworks/ui.py
CHANGED
@@ -1,40 +1,40 @@
(Both sides of this hunk are identical in content; the update appears to be a whitespace or line-ending change only. The file:)

import html
import os
import re

import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared

not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""


def train_hypernetwork(*args):
    shared.loaded_hypernetworks = []

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

    try:
        sd_hijack.undo_optimizations()

        hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)

        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
        return res, ""
    except Exception:
        raise
    finally:
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()
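train_hypernetwork leans on try/finally to restore global state: whatever happens inside training, the text encoder and VAE go back to the GPU and the attention optimizations are reapplied. The same guard pattern in isolation (the three callables are hypothetical stand-ins, not webui APIs):

def run_training(train, disable_optimizations, restore_optimizations):
    """Run `train` with optimizations disabled, restoring them no matter what."""
    try:
        disable_optimizations()
        return train()
    finally:
        # executes on normal return, on exception, and on KeyboardInterrupt
        restore_optimizations()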
sd/stable-diffusion-webui/modules/images.py
CHANGED
@@ -1,669 +1,669 @@
(The only content change visible in this hunk is one line in _atomically_save_image, marked with -/+ further down; the rest of the file is unchanged.)

import datetime
import sys
import traceback

import pytz
import io
import math
import os
from collections import namedtuple
import re

import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import json
import hashlib

from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


def image_grid(imgs, batch_size=1, rows=None):
    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        elif opts.grid_prevent_empty_spots:
            rows = math.floor(math.sqrt(len(imgs)))
            while len(imgs) % rows != 0:
                rows -= 1
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)
    if rows > len(imgs):
        rows = len(imgs)

    cols = math.ceil(len(imgs) / rows)

    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')

    for i, img in enumerate(params.imgs):
        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

    return grid


Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = math.ceil((w - overlap) / non_overlap_width)
    rows = math.ceil((h - overlap) / non_overlap_height)

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    def make_mask_image(r):
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image
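split_grid and combine_grid implement overlap-blended tiling: adjacent tiles share `overlap` pixels, and combine_grid cross-fades the seams with linear-ramp alpha masks. A minimal roundtrip sketch (the per-tile step is a placeholder for real work such as upscaling each tile):

from PIL import Image

img = Image.new("RGB", (1280, 960), "gray")          # any image bigger than one tile
grid = split_grid(img, tile_w=512, tile_h=512, overlap=64)

for y, h, row in grid.tiles:                          # each row: [y, tile_h, [[x, tile_w, tile], ...]]
    for tile_info in row:
        x, w, tile = tile_info
        tile_info[2] = tile                           # placeholder: process the 512x512 tile here

out = combine_grid(grid)                              # blends the 64px overlaps back together
assert out.size == img.size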
class GridAnnotation:
    def __init__(self, text='', is_active=True):
        self.text = text
        self.is_active = is_active
        self.size = None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
    def wrap(drawing, text, font, line_length):
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def get_font(fontsize):
        try:
            return ImageFont.truetype(opts.font or Roboto, fontsize)
        except Exception:
            return ImageFont.truetype(Roboto, fontsize)

    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        for i, line in enumerate(lines):
            fnt = initial_fnt
            fontsize = initial_fontsize
            while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
                fontsize -= 1
                fnt = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    fnt = get_font(fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")

    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + (width + margin) * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col], fnt, fontsize)

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row], fnt, fontsize)

    return result


def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
    prompts = all_prompts[1:]
    boundary = math.ceil(len(prompts) / 2)

    prompts_horiz = prompts[:boundary]
    prompts_vert = prompts[boundary:]

    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
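draw_prompt_matrix labels each grid cell by treating the cell index as a bitmask: cell `pos` marks prompt `i` active when bit `i` of `pos` is set, so `1 << len(prompts)` cells enumerate every on/off combination. The indexing on its own, with made-up prompts:

prompts = ["red", "fog", "sketch"]
for pos in range(1 << len(prompts)):                 # 2**3 = 8 combinations
    active = [x for i, x in enumerate(prompts) if pos & (1 << i) != 0]
    print(pos, active)
# 0 [], 1 ['red'], 2 ['fog'], 3 ['red', 'fog'], ... 7 ['red', 'fog', 'sketch']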
def resize_image(resize_mode, im, width, height, upscaler_name=None):
    """
    Resizes an image with the specified resize_mode, width, and height.

    Args:
        resize_mode: The mode to use when resizing the image.
            0: Resize the image to the specified width and height.
            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
        im: The image to resize.
        width: The width to resize the image to.
        height: The height to resize the image to.
        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
    """

    upscaler_name = upscaler_name or opts.upscaler_for_img2img

    def resize(im, w, h):
        if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
            assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

        if im.width != w or im.height != h:
            im = im.resize((w, h), resample=LANCZOS)

        return im

    if resize_mode == 0:
        res = resize(im, width, height)

    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res
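The three resize modes trade cropping against padding: mode 0 stretches to the target, mode 1 scales to fill and crops the overflow, mode 2 scales to fit and fills the leftover bands with edge pixels from the resized image. A sketch of the three calls (runs inside the webui process, where `opts` is populated; the source size is arbitrary):

from PIL import Image

src = Image.new("RGB", (800, 600), "white")   # 4:3 source

stretched = resize_image(0, src, 512, 512)    # mode 0: plain resize, aspect ratio changes
cropped = resize_image(1, src, 512, 512)      # mode 1: scale to fill, sides cropped off
padded = resize_image(2, src, 512, 512)       # mode 2: scale to fit, top/bottom filled from edge rows
assert stretched.size == cropped.size == padded.size == (512, 512)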
invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128


def sanitize_filename_part(text, replace_spaces=True):
    if text is None:
        return None

    if replace_spaces:
        text = text.replace(' ', '_')

    text = text.translate({ord(x): '_' for x in invalid_filename_chars})
    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
    text = text.rstrip(invalid_filename_postfix)
    return text


class FilenameGenerator:
    replacements = {
        'seed': lambda self: self.seed if self.seed is not None else '',
        'steps': lambda self: self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
        'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
        'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
        'prompt': lambda self: sanitize_filename_part(self.prompt),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image

    def prompt_no_style(self):
        if self.p is None or self.prompt is None:
            return None

        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')

                prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()

        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)

    def datetime(self, *args):
        time_datetime = datetime.datetime.now()

        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError as _:
            time_zone = None

        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError) as _:
            formatted_time = time_zone_time.strftime(self.default_time_format)

        return sanitize_filename_part(formatted_time, replace_spaces=False)

    def apply(self, x):
        res = ''

        for m in re_pattern.finditer(x):
            text, pattern = m.groups()
            res += text

            if pattern is None:
                continue

            pattern_args = []
            while True:
                m = re_pattern_arg.match(pattern)
                if m is None:
                    break

                pattern, arg = m.groups()
                pattern_args.insert(0, arg)

            fun = self.replacements.get(pattern.lower())
            if fun is not None:
                try:
                    replacement = fun(self, *pattern_args)
                except Exception:
                    replacement = None
                    print(f"Error adding [{pattern}] to filename", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                if replacement is not None:
                    res += str(replacement)
                    continue

            res += f'[{pattern}]'

        return res
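FilenameGenerator.apply expands bracketed tokens via `re_pattern`, peeling optional `<...>` arguments off the right with `re_pattern_arg`; unknown tokens are left in place. A sketch of the expansion (run inside the webui, where shared.opts is populated; p=None keeps it minimal, and the outputs shown are illustrative):

namegen = FilenameGenerator(p=None, seed=1234, prompt="a red fox", image=None)

print(namegen.apply("[seed]-[prompt_words]"))   # e.g. "1234-a red fox"
print(namegen.apply("[datetime<%Y-%m-%d>]"))    # e.g. "2023-02-19"
print(namegen.apply("[not_a_token]"))           # "[not_a_token]" — unknown tokens pass through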
def get_next_sequence_number(path, basename):
    """
    Determines and returns the next sequence number to use when saving an image in the specified directory.

    The sequence starts at 0.
    """
    result = -1
    if basename != '':
        basename = basename + "-"

    prefix_length = len(basename)
    for p in os.listdir(path):
        if p.startswith(basename):
            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
            try:
                result = max(int(l[0]), result)
            except ValueError:
                pass

    return result + 1
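get_next_sequence_number scans the directory for files whose names start with a numeric prefix and returns one past the highest it finds, so numbering survives restarts and ignores non-matching files. For instance:

import os, tempfile

d = tempfile.mkdtemp()
for name in ("00000-a.png", "00003-b.png", "junk.txt"):
    open(os.path.join(d, name), "w").close()

print(get_next_sequence_number(d, ""))   # 4: one past the highest numeric prefix found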
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    """Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
        basename (`str`):
            The base filename which will be applied to `filename pattern`.
        seed, prompt, short_filename,
        extension (`str`):
            Image file extension, default is `png`.
        pnginfo_section_name (`str`):
            Specify the name of the section which `info` will be saved in.
        info (`str` or `PngImagePlugin.iTXt`):
            PNG info chunks.
        existing_info (`dict`):
            Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
        no_prompt:
            TODO I don't know its meaning.
        p (`StableDiffusionProcessing`)
        forced_filename (`str`):
            If specified, `basename` and filename pattern will be ignored.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the saved image.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    """
    namegen = FilenameGenerator(p, seed, prompt, image)

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        if short_filename or seed is None:
            file_decoration = ""
        elif opts.save_to_dirs:
            file_decoration = opts.samples_filename_pattern or "[seed]"
        else:
            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        add_number = opts.save_images_add_number or file_decoration == ''

        if file_decoration != "" and add_number:
            file_decoration = "-" + file_decoration

        file_decoration = namegen.apply(file_decoration) + suffix

        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                if not os.path.exists(fullfn):
                    break
        else:
            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")

    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)

    def _atomically_save_image(image_to_save, filename_without_extension, extension):
        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
        temp_file_path = filename_without_extension + ".tmp"
        image_format = Image.registered_extensions()[extension]

        if extension.lower() == '.png':
            pnginfo_data = PngImagePlugin.PngInfo()
            if opts.enable_pnginfo:
                for k, v in params.pnginfo.items():
                    pnginfo_data.add_text(k, str(v))

            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
            if image_to_save.mode == 'RGBA':
                image_to_save = image_to_save.convert("RGB")
            elif image_to_save.mode == 'I;16':
                image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

-           image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+           image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

            if opts.enable_pnginfo and info is not None:
                exif_bytes = piexif.dump({
                    "Exif": {
                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
                    },
                })

                piexif.insert(exif_bytes, temp_file_path)
        else:
            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

        # atomically rename the file with correct extension
        os.replace(temp_file_path, filename_without_extension + extension)

    fullfn_without_extension, extension = os.path.splitext(params.filename)
    _atomically_save_image(image, fullfn_without_extension, extension)

    image.already_saved_as = fullfn

    oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
        elif oversize:
            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)

        try:
            _atomically_save_image(image, fullfn_without_extension, ".jpg")
        except Exception as e:
            errors.display(e, "saving image as downscaled JPG")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(info + "\n")
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn


def read_info_from_image(image):
    items = image.info or {}

    geninfo = items.pop('parameters', None)

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        if exif_comment:
            items['exif comment'] = exif_comment
            geninfo = exif_comment

    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                  'loop', 'background', 'timestamp', 'duration']:
        items.pop(field, None)

    if items.get("Software", None) == "NovelAI":
        try:
            json_info = json.loads(items["Comment"])
            sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")

            geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
        except Exception:
            print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return geninfo, items


def image_data(data):
    try:
        image = Image.open(io.BytesIO(data))
        textinfo, _ = read_info_from_image(image)
        return textinfo, None
    except Exception:
        pass

    try:
        text = data.decode('utf8')
        assert len(text) < 10000
        return text, None

    except Exception:
        pass

    return '', None


def flatten(img, bgcolor):
    """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""

    if img.mode == "RGBA":
        background = Image.new('RGBA', img.size, bgcolor)
        background.paste(img, mask=img)
        img = background

    return img.convert('RGB')
|
|
1 |
+
import datetime
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import pytz
|
6 |
+
import io
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from collections import namedtuple
|
10 |
+
import re
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import piexif
|
14 |
+
import piexif.helper
|
15 |
+
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
16 |
+
from fonts.ttf import Roboto
|
17 |
+
import string
|
18 |
+
import json
|
19 |
+
import hashlib
|
20 |
+
|
21 |
+
from modules import sd_samplers, shared, script_callbacks, errors
|
22 |
+
from modules.shared import opts, cmd_opts
|
23 |
+
|
24 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
25 |
+
|
26 |
+
|
27 |
+
def image_grid(imgs, batch_size=1, rows=None):
|
28 |
+
if rows is None:
|
29 |
+
if opts.n_rows > 0:
|
30 |
+
rows = opts.n_rows
|
31 |
+
elif opts.n_rows == 0:
|
32 |
+
rows = batch_size
|
33 |
+
elif opts.grid_prevent_empty_spots:
|
34 |
+
rows = math.floor(math.sqrt(len(imgs)))
|
35 |
+
while len(imgs) % rows != 0:
|
36 |
+
rows -= 1
|
37 |
+
else:
|
38 |
+
rows = math.sqrt(len(imgs))
|
39 |
+
rows = round(rows)
|
40 |
+
if rows > len(imgs):
|
41 |
+
rows = len(imgs)
|
42 |
+
|
43 |
+
cols = math.ceil(len(imgs) / rows)
|
44 |
+
|
45 |
+
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
46 |
+
script_callbacks.image_grid_callback(params)
|
47 |
+
|
48 |
+
w, h = imgs[0].size
|
49 |
+
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
|
50 |
+
|
51 |
+
for i, img in enumerate(params.imgs):
|
52 |
+
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
|
53 |
+
|
54 |
+
return grid
|
55 |
+
|
56 |
+
|
57 |
+
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
58 |
+
|
59 |
+
|
60 |
+
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
61 |
+
w = image.width
|
62 |
+
h = image.height
|
63 |
+
|
64 |
+
non_overlap_width = tile_w - overlap
|
65 |
+
non_overlap_height = tile_h - overlap
|
66 |
+
|
67 |
+
cols = math.ceil((w - overlap) / non_overlap_width)
|
68 |
+
rows = math.ceil((h - overlap) / non_overlap_height)
|
69 |
+
|
70 |
+
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
71 |
+
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
72 |
+
|
73 |
+
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
74 |
+
for row in range(rows):
|
75 |
+
row_images = []
|
76 |
+
|
77 |
+
y = int(row * dy)
|
78 |
+
|
79 |
+
if y + tile_h >= h:
|
80 |
+
y = h - tile_h
|
81 |
+
|
82 |
+
for col in range(cols):
|
83 |
+
x = int(col * dx)
|
84 |
+
|
85 |
+
if x + tile_w >= w:
|
86 |
+
x = w - tile_w
|
87 |
+
|
88 |
+
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
89 |
+
|
90 |
+
row_images.append([x, tile_w, tile])
|
91 |
+
|
92 |
+
grid.tiles.append([y, tile_h, row_images])
|
93 |
+
|
94 |
+
return grid
|
95 |
+
|
96 |
+
|
97 |
+
def combine_grid(grid):
|
98 |
+
def make_mask_image(r):
|
99 |
+
r = r * 255 / grid.overlap
|
100 |
+
r = r.astype(np.uint8)
|
101 |
+
return Image.fromarray(r, 'L')
|
102 |
+
|
103 |
+
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
104 |
+
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
105 |
+
|
106 |
+
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
107 |
+
for y, h, row in grid.tiles:
|
108 |
+
combined_row = Image.new("RGB", (grid.image_w, h))
|
109 |
+
for x, w, tile in row:
|
110 |
+
if x == 0:
|
111 |
+
combined_row.paste(tile, (0, 0))
|
112 |
+
continue
|
113 |
+
|
114 |
+
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
115 |
+
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
116 |
+
|
117 |
+
if y == 0:
|
118 |
+
combined_image.paste(combined_row, (0, 0))
|
119 |
+
continue
|
120 |
+
|
121 |
+
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
122 |
+
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
123 |
+
|
124 |
+
return combined_image
|
125 |
+
|
126 |
+
|
127 |
+
class GridAnnotation:
|
128 |
+
def __init__(self, text='', is_active=True):
|
129 |
+
self.text = text
|
130 |
+
self.is_active = is_active
|
131 |
+
self.size = None
|
132 |
+
|
133 |
+
|
134 |
+
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
135 |
+
def wrap(drawing, text, font, line_length):
|
136 |
+
lines = ['']
|
137 |
+
for word in text.split():
|
138 |
+
line = f'{lines[-1]} {word}'.strip()
|
139 |
+
if drawing.textlength(line, font=font) <= line_length:
|
140 |
+
lines[-1] = line
|
141 |
+
else:
|
142 |
+
lines.append(word)
|
143 |
+
return lines
|
144 |
+
|
145 |
+
def get_font(fontsize):
|
146 |
+
try:
|
147 |
+
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
148 |
+
except Exception:
|
149 |
+
return ImageFont.truetype(Roboto, fontsize)
|
150 |
+
|
151 |
+
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
152 |
+
for i, line in enumerate(lines):
|
153 |
+
fnt = initial_fnt
|
154 |
+
fontsize = initial_fontsize
|
155 |
+
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
156 |
+
fontsize -= 1
|
157 |
+
fnt = get_font(fontsize)
|
158 |
+
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
159 |
+
|
160 |
+
if not line.is_active:
|
161 |
+
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
162 |
+
|
163 |
+
draw_y += line.size[1] + line_spacing
|
164 |
+
|
165 |
+
fontsize = (width + height) // 25
|
166 |
+
line_spacing = fontsize // 2
|
167 |
+
|
168 |
+
fnt = get_font(fontsize)
|
169 |
+
|
170 |
+
color_active = (0, 0, 0)
|
171 |
+
color_inactive = (153, 153, 153)
|
172 |
+
|
173 |
+
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
174 |
+
|
175 |
+
cols = im.width // width
|
176 |
+
rows = im.height // height
|
177 |
+
|
178 |
+
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
179 |
+
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
180 |
+
|
181 |
+
calc_img = Image.new("RGB", (1, 1), "white")
|
182 |
+
calc_d = ImageDraw.Draw(calc_img)
|
183 |
+
|
184 |
+
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
185 |
+
items = [] + texts
|
186 |
+
texts.clear()
|
187 |
+
|
188 |
+
for line in items:
|
189 |
+
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
190 |
+
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
191 |
+
|
192 |
+
for line in texts:
|
193 |
+
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
194 |
+
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
195 |
+
line.allowed_width = allowed_width
|
196 |
+
|
197 |
+
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
198 |
+
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
199 |
+
|
200 |
+
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
201 |
+
|
202 |
+
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
203 |
+
|
204 |
+
for row in range(rows):
|
205 |
+
for col in range(cols):
|
206 |
+
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
207 |
+
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
208 |
+
|
209 |
+
d = ImageDraw.Draw(result)
|
210 |
+
|
211 |
+
for col in range(cols):
|
212 |
+
x = pad_left + (width + margin) * col + width / 2
|
213 |
+
y = pad_top / 2 - hor_text_heights[col] / 2
|
214 |
+
|
215 |
+
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
216 |
+
|
217 |
+
for row in range(rows):
|
218 |
+
x = pad_left / 2
|
219 |
+
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
220 |
+
|
221 |
+
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
222 |
+
|
223 |
+
return result
|
224 |
+
|
225 |
+
|
226 |
+
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
227 |
+
prompts = all_prompts[1:]
|
228 |
+
boundary = math.ceil(len(prompts) / 2)
|
229 |
+
|
230 |
+
prompts_horiz = prompts[:boundary]
|
231 |
+
prompts_vert = prompts[boundary:]
|
232 |
+
|
233 |
+
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
234 |
+
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
235 |
+
|
236 |
+
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
237 |
+
|
238 |
+
|
239 |
+
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
240 |
+
"""
|
241 |
+
Resizes an image with the specified resize_mode, width, and height.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
resize_mode: The mode to use when resizing the image.
|
245 |
+
0: Resize the image to the specified width and height.
|
246 |
+
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
247 |
+
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
248 |
+
im: The image to resize.
|
249 |
+
width: The width to resize the image to.
|
250 |
+
height: The height to resize the image to.
|
251 |
+
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
252 |
+
"""
|
253 |
+
|
254 |
+
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
255 |
+
|
256 |
+
def resize(im, w, h):
|
257 |
+
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
|
258 |
+
return im.resize((w, h), resample=LANCZOS)
|
259 |
+
|
260 |
+
scale = max(w / im.width, h / im.height)
|
261 |
+
|
262 |
+
if scale > 1.0:
|
263 |
+
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
264 |
+
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
|
265 |
+
|
266 |
+
upscaler = upscalers[0]
|
267 |
+
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
268 |
+
|
269 |
+
if im.width != w or im.height != h:
|
270 |
+
im = im.resize((w, h), resample=LANCZOS)
|
271 |
+
|
272 |
+
return im
|
273 |
+
|
274 |
+
if resize_mode == 0:
|
275 |
+
res = resize(im, width, height)
|
276 |
+
|
277 |
+
elif resize_mode == 1:
|
278 |
+
ratio = width / height
|
279 |
+
src_ratio = im.width / im.height
|
280 |
+
|
281 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
282 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
283 |
+
|
284 |
+
resized = resize(im, src_w, src_h)
|
285 |
+
res = Image.new("RGB", (width, height))
|
286 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
287 |
+
|
288 |
+
else:
|
289 |
+
ratio = width / height
|
290 |
+
src_ratio = im.width / im.height
|
291 |
+
|
292 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
293 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
294 |
+
|
295 |
+
resized = resize(im, src_w, src_h)
|
296 |
+
res = Image.new("RGB", (width, height))
|
297 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
298 |
+
|
299 |
+
if ratio < src_ratio:
|
300 |
+
fill_height = height // 2 - src_h // 2
|
301 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
302 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
303 |
+
elif ratio > src_ratio:
|
304 |
+
fill_width = width // 2 - src_w // 2
|
305 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
306 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
307 |
+
|
308 |
+
return res
|
309 |
+
|
310 |
+
|
311 |
+
invalid_filename_chars = '<>:"/\\|?*\n'
|
312 |
+
invalid_filename_prefix = ' '
|
313 |
+
invalid_filename_postfix = ' .'
|
314 |
+
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
315 |
+
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
316 |
+
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
317 |
+
max_filename_part_length = 128
|
318 |
+
|
319 |
+
|
320 |
+
def sanitize_filename_part(text, replace_spaces=True):
|
321 |
+
if text is None:
|
322 |
+
return None
|
323 |
+
|
324 |
+
if replace_spaces:
|
325 |
+
text = text.replace(' ', '_')
|
326 |
+
|
327 |
+
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
328 |
+
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
329 |
+
text = text.rstrip(invalid_filename_postfix)
|
330 |
+
return text
|
331 |
+
|
332 |
+
|
333 |
+
class FilenameGenerator:
|
334 |
+
replacements = {
|
335 |
+
'seed': lambda self: self.seed if self.seed is not None else '',
|
336 |
+
'steps': lambda self: self.p and self.p.steps,
|
337 |
+
'cfg': lambda self: self.p and self.p.cfg_scale,
|
338 |
+
'width': lambda self: self.image.width,
|
339 |
+
'height': lambda self: self.image.height,
|
340 |
+
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
341 |
+
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
342 |
+
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
343 |
+
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
|
344 |
+
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
345 |
+
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
346 |
+
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
347 |
+
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
348 |
+
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
349 |
+
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
350 |
+
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
351 |
+
'prompt_words': lambda self: self.prompt_words(),
|
352 |
+
}
|
353 |
+
default_time_format = '%Y%m%d%H%M%S'
|
354 |
+
|
355 |
+
def __init__(self, p, seed, prompt, image):
|
356 |
+
self.p = p
|
357 |
+
self.seed = seed
|
358 |
+
self.prompt = prompt
|
359 |
+
self.image = image
|
360 |
+
|
361 |
+
def prompt_no_style(self):
|
362 |
+
if self.p is None or self.prompt is None:
|
363 |
+
return None
|
364 |
+
|
365 |
+
prompt_no_style = self.prompt
|
366 |
+
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
|
367 |
+
if len(style) > 0:
|
368 |
+
for part in style.split("{prompt}"):
|
369 |
+
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
370 |
+
|
371 |
+
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
372 |
+
|
373 |
+
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
|
374 |
+
|
375 |
+
def prompt_words(self):
|
376 |
+
words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
|
377 |
+
if len(words) == 0:
|
378 |
+
words = ["empty"]
|
379 |
+
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
|
380 |
+
|
381 |
+
def datetime(self, *args):
|
382 |
+
time_datetime = datetime.datetime.now()
|
383 |
+
|
384 |
+
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
385 |
+
try:
|
386 |
+
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
387 |
+
except pytz.exceptions.UnknownTimeZoneError as _:
|
388 |
+
time_zone = None
|
389 |
+
|
390 |
+
time_zone_time = time_datetime.astimezone(time_zone)
|
391 |
+
try:
|
392 |
+
formatted_time = time_zone_time.strftime(time_format)
|
393 |
+
except (ValueError, TypeError) as _:
|
394 |
+
formatted_time = time_zone_time.strftime(self.default_time_format)
|
395 |
+
|
396 |
+
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
397 |
+
|
398 |
+
def apply(self, x):
|
399 |
+
res = ''
|
400 |
+
|
401 |
+
for m in re_pattern.finditer(x):
|
402 |
+
text, pattern = m.groups()
|
403 |
+
res += text
|
404 |
+
|
405 |
+
if pattern is None:
|
406 |
+
continue
|
407 |
+
|
408 |
+
pattern_args = []
|
409 |
+
while True:
|
410 |
+
m = re_pattern_arg.match(pattern)
|
411 |
+
if m is None:
|
412 |
+
break
|
413 |
+
|
414 |
+
pattern, arg = m.groups()
|
415 |
+
pattern_args.insert(0, arg)
|
416 |
+
|
417 |
+
fun = self.replacements.get(pattern.lower())
|
418 |
+
if fun is not None:
|
419 |
+
try:
|
420 |
+
replacement = fun(self, *pattern_args)
|
421 |
+
except Exception:
|
422 |
+
replacement = None
|
423 |
+
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
424 |
+
print(traceback.format_exc(), file=sys.stderr)
|
425 |
+
|
426 |
+
if replacement is not None:
|
427 |
+
res += str(replacement)
|
428 |
+
continue
|
429 |
+
|
430 |
+
res += f'[{pattern}]'
|
431 |
+
|
432 |
+
return res
|
433 |
+
|
434 |
+
|
435 |
+
def get_next_sequence_number(path, basename):
|
436 |
+
"""
|
437 |
+
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
438 |
+
|
439 |
+
The sequence starts at 0.
|
440 |
+
"""
|
441 |
+
result = -1
|
442 |
+
if basename != '':
|
443 |
+
basename = basename + "-"
|
444 |
+
|
445 |
+
prefix_length = len(basename)
|
446 |
+
for p in os.listdir(path):
|
447 |
+
if p.startswith(basename):
|
448 |
+
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
449 |
+
try:
|
450 |
+
result = max(int(l[0]), result)
|
451 |
+
except ValueError:
|
452 |
+
pass
|
453 |
+
|
454 |
+
return result + 1
|
455 |
+
|

def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    """Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image in. Note: the `save_to_dirs` option will cause the image to be saved into a subdirectory.
        basename (`str`):
            The base filename, to which the filename pattern is applied.
        seed, prompt, short_filename:
            Used when building the filename (see the pattern logic below).
        extension (`str`):
            Image file extension; default is `png`.
        pnginfo_section_name (`str`):
            The name of the section in which `info` will be saved.
        info (`str` or `PngImagePlugin.iTXt`):
            PNG info chunks.
        existing_info (`dict`):
            Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
        no_prompt:
            If true, prevents saving into a prompt-derived subdirectory (see the `save_to_dirs` logic below).
        p (`StableDiffusionProcessing`)
        forced_filename (`str`):
            If specified, `basename` and the filename pattern will be ignored.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the saved image.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    """
    namegen = FilenameGenerator(p, seed, prompt, image)

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        if short_filename or seed is None:
            file_decoration = ""
        elif opts.save_to_dirs:
            file_decoration = opts.samples_filename_pattern or "[seed]"
        else:
            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        add_number = opts.save_images_add_number or file_decoration == ''

        if file_decoration != "" and add_number:
            file_decoration = "-" + file_decoration

        file_decoration = namegen.apply(file_decoration) + suffix

        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                if not os.path.exists(fullfn):
                    break
        else:
            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")

    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)

    def _atomically_save_image(image_to_save, filename_without_extension, extension):
        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
        temp_file_path = filename_without_extension + ".tmp"
        image_format = Image.registered_extensions()[extension]

        if extension.lower() == '.png':
            pnginfo_data = PngImagePlugin.PngInfo()
            if opts.enable_pnginfo:
                for k, v in params.pnginfo.items():
                    pnginfo_data.add_text(k, str(v))

            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
            if image_to_save.mode == 'RGBA':
                image_to_save = image_to_save.convert("RGB")
            elif image_to_save.mode == 'I;16':
                # 0.0038910505836576 is 1/257: maps 16-bit values (0..65535) onto 0..255
                image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

            if opts.enable_pnginfo and info is not None:
                exif_bytes = piexif.dump({
                    "Exif": {
                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
                    },
                })

                piexif.insert(exif_bytes, temp_file_path)
        else:
            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

        # atomically rename the file with correct extension
        os.replace(temp_file_path, filename_without_extension + extension)

    fullfn_without_extension, extension = os.path.splitext(params.filename)
    _atomically_save_image(image, fullfn_without_extension, extension)

    image.already_saved_as = fullfn

    oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
        elif oversize:
            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)

        try:
            _atomically_save_image(image, fullfn_without_extension, ".jpg")
        except Exception as e:
            errors.display(e, "saving image as downscaled JPG")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(info + "\n")
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn

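For orientation, a minimal sketch of how this saver is typically driven (the parameter names are those of save_image above; the concrete values and paths are invented):

# Hypothetical call: saves the sample under the given path with a
# "[seed]-[prompt_spaces]"-style name and embeds the generation parameters
# as the "parameters" PNG text chunk.
fullfn, txt_fullfn = save_image(
    image,                        # a PIL.Image produced by the sampler
    path="outputs/txt2img-images",
    basename="",
    seed=1234,
    prompt="a red fox",
    extension="png",
    info="a red fox\nSteps: 20, Sampler: Euler a, Seed: 1234",
    p=p,                          # the StableDiffusionProcessing instance
)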


def read_info_from_image(image):
    items = image.info or {}

    geninfo = items.pop('parameters', None)

    if "exif" in items:
        exif = piexif.load(items["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        if exif_comment:
            items['exif comment'] = exif_comment
            geninfo = exif_comment

    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                  'loop', 'background', 'timestamp', 'duration']:
        items.pop(field, None)

    if items.get("Software", None) == "NovelAI":
        try:
            json_info = json.loads(items["Comment"])
            sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")

            geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
        except Exception:
            print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return geninfo, items
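A hypothetical round trip (the path and the printed output are invented): re-open a PNG written by save_image above and recover the generation parameters:

from PIL import Image
img = Image.open("outputs/txt2img-images/00012-1234.png")
geninfo, extra_items = read_info_from_image(img)
print(geninfo)  # e.g. "a red fox\nSteps: 20, Sampler: Euler a, Seed: 1234"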


def image_data(data):
    try:
        image = Image.open(io.BytesIO(data))
        textinfo, _ = read_info_from_image(image)
        return textinfo, None
    except Exception:
        pass

    try:
        text = data.decode('utf8')
        assert len(text) < 10000
        return text, None

    except Exception:
        pass

    return '', None


def flatten(img, bgcolor):
    """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""

    if img.mode == "RGBA":
        background = Image.new('RGBA', img.size, bgcolor)
        background.paste(img, mask=img)
        img = background

    return img.convert('RGB')
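A typical use of flatten, with an invented input file: composite a transparent PNG over white before JPEG export, since JPEG has no alpha channel:

rgba = Image.open("sticker.png")    # invented path, assumed mode "RGBA"
rgb = flatten(rgba, "#ffffff")      # transparency replaced with white
rgb.save("sticker.jpg", quality=90)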
sd/stable-diffusion-webui/modules/img2img.py
CHANGED
@@ -1,184 +1,184 @@
(the diff removes and re-adds all 184 lines with identical content, most likely a line-ending or whitespace normalization; the file is shown once below)
import math
import os
import sys
import traceback

import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops

from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts


def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
    processing.fix_seed(p)

    images = shared.listfiles(input_dir)

    is_inpaint_batch = False
    if inpaint_mask_dir:
        inpaint_masks = shared.listfiles(inpaint_mask_dir)
        is_inpaint_batch = len(inpaint_masks) > 0
    if is_inpaint_batch:
        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

    print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

    save_normally = output_dir == ''

    p.do_not_save_grid = True
    p.do_not_save_samples = not save_normally

    state.job_count = len(images) * p.n_iter

    for i, image in enumerate(images):
        state.job = f"{i+1} out of {len(images)}"
        if state.skipped:
            state.skipped = False

        if state.interrupted:
            break

        img = Image.open(image)
        # Use the EXIF orientation of photos taken by smartphones.
        img = ImageOps.exif_transpose(img)
        p.init_images = [img] * p.batch_size

        if is_inpaint_batch:
            # try to find corresponding mask for an image using simple filename matching
            mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
            # if not found use first one ("same mask for all images" use-case)
            if mask_image_path not in inpaint_masks:
                mask_image_path = inpaint_masks[0]
            mask_image = Image.open(mask_image_path)
            p.image_mask = mask_image

        proc = modules.scripts.scripts_img2img.run(p, *args)
        if proc is None:
            proc = process_images(p)

        for n, processed_image in enumerate(proc.images):
            filename = os.path.basename(image)

            if n > 0:
                left, right = os.path.splitext(filename)
                filename = f"{left}-{n}{right}"

            if not save_normally:
                os.makedirs(output_dir, exist_ok=True)
                if processed_image.mode == 'RGBA':
                    processed_image = processed_image.convert("RGB")
                processed_image.save(os.path.join(output_dir, filename))

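A hypothetical invocation (directory names and script_args invented), mirroring how img2img below dispatches batch mode; each input image yields n_iter * batch_size outputs, and inpaint masks are matched by filename with the first mask as fallback:

# p is a prepared StableDiffusionProcessingImg2Img as constructed in img2img() below.
process_batch(
    p,
    "inputs/batch",     # input_dir: source images
    "outputs/batch",    # output_dir: '' would save through the normal pipeline instead
    "inputs/masks",     # inpaint_mask_dir: optional; same-named mask, else masks[0]
    script_args,        # forwarded to the img2img script runner
)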

def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
    override_settings = create_override_settings_dict(override_settings_texts)

    is_batch = mode == 5

    if mode == 0:  # img2img
        image = init_img.convert("RGB")
        mask = None
    elif mode == 1:  # img2img sketch
        image = sketch.convert("RGB")
        mask = None
    elif mode == 2:  # inpaint
        image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
        mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
        image = image.convert("RGB")
    elif mode == 3:  # inpaint sketch
        image = inpaint_color_sketch
        orig = inpaint_color_sketch_orig or inpaint_color_sketch
        pred = np.any(np.array(image) != np.array(orig), axis=-1)
        mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
        mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
        blur = ImageFilter.GaussianBlur(mask_blur)
        image = Image.composite(image.filter(blur), orig, mask.filter(blur))
        image = image.convert("RGB")
    elif mode == 4:  # inpaint upload mask
        image = init_img_inpaint
        mask = init_mask_inpaint
    else:
        image = None
        mask = None

    # Use the EXIF orientation of photos taken by smartphones.
    if image is not None:
        image = ImageOps.exif_transpose(image)

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=prompt,
        negative_prompt=negative_prompt,
        styles=prompt_styles,
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        seed_enable_extras=seed_enable_extras,
        sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        image_cfg_scale=image_cfg_scale,
        inpaint_full_res=inpaint_full_res,
        inpaint_full_res_padding=inpaint_full_res_padding,
        inpainting_mask_invert=inpainting_mask_invert,
        override_settings=override_settings,
    )

    p.scripts = modules.scripts.scripts_txt2img
    p.script_args = args

    if shared.cmd_opts.enable_console_prompts:
        print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

    p.extra_generation_params["Mask blur"] = mask_blur

    if is_batch:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)

        processed = Processed(p, [], p.seed, "")
    else:
        processed = modules.scripts.scripts_img2img.run(p, *args)
        if processed is None:
            processed = process_images(p)

    p.close()

    shared.total_tqdm.clear()

    generation_info_js = processed.js()
    if opts.samples_log_stdout:
        print(generation_info_js)

    if opts.do_not_show_images:
        processed.images = []

    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
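The inpaint-sketch branch (mode 3) above derives its mask from pixel differences; a standalone sketch of that idea with invented inputs:

# Any pixel the user painted over (differing from the original in any
# channel) becomes mask value 255, everything untouched stays 0.
import numpy as np
from PIL import Image

orig = Image.open("photo.png").convert("RGB")             # invented paths,
painted = Image.open("photo_painted.png").convert("RGB")  # same dimensions assumed
changed = np.any(np.array(painted) != np.array(orig), axis=-1)
mask = Image.fromarray(changed.astype(np.uint8) * 255, "L")
mask.save("derived_mask.png")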
sd/stable-diffusion-webui/modules/interrogate.py
CHANGED
@@ -1,227 +1,227 @@
(another whole-file re-add with no content change; the file is shown once below)
import os
import shutil  # assumption: added here so the cleanup below can delete a directory
import sys
import traceback
from collections import namedtuple
from pathlib import Path
import re

import torch
import torch.hub

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors

blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")

def category_types():
    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = content_dir + "_tmp"
    category_types = ["artists", "flavors", "mediums", "movements"]

    try:
        os.makedirs(tmpdir)
        for category_type in category_types:
            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
        os.rename(tmpdir, content_dir)

    except Exception as e:
        errors.display(e, "downloading default CLIP interrogate categories")
    finally:
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)  # the temp dir may be non-empty; os.remove() cannot delete a directory


class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    dtype = None
    running_on_cpu = None

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.skip_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
            return self.loaded_categories

        self.loaded_categories = []

        if os.path.exists(self.content_dir):
            self.skip_categories = shared.opts.interrogate_clip_skip_categories
            category_types = []
            for filename in Path(self.content_dir).glob('*.txt'):
                category_types.append(filename.stem)
                if filename.stem in self.skip_categories:
                    continue
                m = re_topn.search(filename.stem)
                topn = 1 if m is None else int(m.group(1))
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))

        return self.loaded_categories

    def create_fake_fairscale(self):
        class FakeFairscale:
            def checkpoint_wrapper(self):
                pass

        sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale

    def load_blip_model(self):
        self.create_fake_fairscale()
        import models.blip

        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        if self.running_on_cpu:
            model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
        else:
            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)

        model.eval()
        model = model.to(devices.device_interrogate)

        return model, preprocess

    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(devices.device_interrogate)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(devices.device_interrogate)

        self.dtype = next(self.clip_model.parameters()).dtype

    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

    def rank(self, image_features, text_array, top_count=1):
        import clip

        devices.torch_gc()

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

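The ranking above is cosine similarity between L2-normalized embeddings, scaled by 100 and softmaxed; a self-contained sketch with random tensors standing in for the CLIP features (shapes invented):

import torch

image_features = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(50, 768), dim=-1)

# dot product of unit vectors = cosine similarity; the *100 sharpens the softmax
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = similarity.topk(3, dim=-1)  # best 3 of the 50 candidates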
    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
        res = ""
        shared.state.begin()
        shared.state.job = 'interrogate'
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                for name, topn, items in self.categories():
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        if shared.opts.interrogate_return_ranks:
                            res += f", ({match}:{score/100:.3f})"
                        else:
                            res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        self.unload()
        shared.state.end()

        return res
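Typical use from the rest of the webui, via the shared.interrogator instance referenced by category_types() above (the image source and the printed result are invented):

from PIL import Image
pil_image = Image.open("photo.png").convert("RGB")
prompt = shared.interrogator.interrogate(pil_image)
# e.g. "a fox standing in a field" followed by one ", (match:score)" or
# ", match" per category, as built by the loop above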
sd/stable-diffusion-webui/modules/localization.py
CHANGED
@@ -1,37 +1,37 @@
(likewise a no-op re-add of all 37 lines; the file is shown once below)
import json
import os
import sys
import traceback


localizations = {}


def list_localizations(dirname):
    localizations.clear()

    for file in os.listdir(dirname):
        fn, ext = os.path.splitext(file)
        if ext.lower() != ".json":
            continue

        localizations[fn] = os.path.join(dirname, file)

    from modules import scripts
    for file in scripts.list_scripts("localizations", ".json"):
        fn, ext = os.path.splitext(file.filename)
        localizations[fn] = file.path


def localization_js(current_localization_name):
    fn = localizations.get(current_localization_name, None)
    data = {}
    if fn is not None:
        try:
            with open(fn, "r", encoding="utf8") as file:
                data = json.load(file)
        except Exception:
            print(f"Error loading localization from {fn}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return f"var localization = {json.dumps(data)}\n"
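The returned string is emitted as JavaScript for the frontend; a hypothetical round trip (the directory, file name, and translations are invented):

list_localizations("localizations")   # scans *.json files plus extension scripts
print(localization_js("de_DE"))
# -> 'var localization = {"Generate": "Generieren", ...}\n'
print(localization_js("missing"))     # unknown names fall back to an empty dict
# -> 'var localization = {}\n'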